diff --git a/kreuzberg/_constants.py b/kreuzberg/_constants.py index 5b7f0c1..99efd2a 100644 --- a/kreuzberg/_constants.py +++ b/kreuzberg/_constants.py @@ -4,3 +4,5 @@ from typing import Final DEFAULT_MAX_PROCESSES: Final[int] = cpu_count() +MINIMAL_SUPPORTED_TESSERACT_VERSION: Final[int] = 5 +MINIMAL_SUPPORTED_PANDOC_VERSION: Final[int] = 2 diff --git a/kreuzberg/_pandoc.py b/kreuzberg/_pandoc.py index 0cdf638..bc6e8a6 100644 --- a/kreuzberg/_pandoc.py +++ b/kreuzberg/_pandoc.py @@ -1,20 +1,22 @@ from __future__ import annotations -import subprocess +import re import sys from functools import partial from json import JSONDecodeError, loads from typing import TYPE_CHECKING, Any, Final, Literal, cast -from anyio import CapacityLimiter, create_task_group, run_process, to_process from anyio import Path as AsyncPath +from anyio import run_process -from kreuzberg._constants import DEFAULT_MAX_PROCESSES +from kreuzberg import ValidationError +from kreuzberg._constants import MINIMAL_SUPPORTED_PANDOC_VERSION from kreuzberg._mime_types import MARKDOWN_MIME_TYPE from kreuzberg._string import normalize_spaces +from kreuzberg._sync import run_taskgroup from kreuzberg._tmp import create_temp_file from kreuzberg._types import ExtractionResult, Metadata -from kreuzberg.exceptions import MissingDependencyError, ParsingError, ValidationError +from kreuzberg.exceptions import MissingDependencyError, ParsingError if TYPE_CHECKING: # pragma: no cover from collections.abc import Mapping @@ -23,10 +25,8 @@ if sys.version_info < (3, 11): # pragma: no cover from exceptiongroup import ExceptionGroup # type: ignore[import-not-found] - version_ref: Final[dict[str, bool]] = {"checked": False} - # Block-level node types in Pandoc AST BLOCK_HEADER: Final = "Header" # Header with level, attributes and inline content BLOCK_PARA: Final = "Para" # Paragraph containing inline content @@ -228,20 +228,15 @@ def _extract_metadata(raw_meta: dict[str, Any]) -> Metadata: def 
_get_pandoc_type_from_mime_type(mime_type: str) -> str: - if mime_type not in MIMETYPE_TO_PANDOC_TYPE_MAPPING or not any( - mime_type.startswith(value) for value in MIMETYPE_TO_PANDOC_TYPE_MAPPING - ): - raise ValidationError( - f"Unsupported mime type: {mime_type}", - context={ - "mime_type": mime_type, - "supported_mimetypes": ",".join(sorted(MIMETYPE_TO_PANDOC_TYPE_MAPPING)), - }, + if pandoc_type := (MIMETYPE_TO_PANDOC_TYPE_MAPPING.get(mime_type, "")): + return pandoc_type + + if any(k.startswith(mime_type) for k in MIMETYPE_TO_PANDOC_TYPE_MAPPING): + return next( + MIMETYPE_TO_PANDOC_TYPE_MAPPING[k] for k in MIMETYPE_TO_PANDOC_TYPE_MAPPING if k.startswith(mime_type) ) - return MIMETYPE_TO_PANDOC_TYPE_MAPPING.get(mime_type) or next( - MIMETYPE_TO_PANDOC_TYPE_MAPPING[k] for k in MIMETYPE_TO_PANDOC_TYPE_MAPPING if k.startswith(mime_type) - ) + raise ValidationError(f"Unsupported mime type: {mime_type}", context={"mime_type": mime_type, "supported_mimetypes": ",".join(sorted(MIMETYPE_TO_PANDOC_TYPE_MAPPING))}) async def _validate_pandoc_version() -> None: @@ -251,20 +246,18 @@ async def _validate_pandoc_version() -> None: command = ["pandoc", "--version"] result = await run_process(command) - version = result.stdout.decode().split("\n")[0].split()[1] - major_version = int(version.split(".")[0]) - if major_version < 2: - raise MissingDependencyError("Pandoc version 2 or above is required.") + + version_match = re.search(r"pandoc\s+v?(\d+)\.\d+(?:\.\d+)?", result.stdout.decode()) + if not version_match or int(version_match.group(1)) < MINIMAL_SUPPORTED_PANDOC_VERSION: + raise MissingDependencyError(f"Pandoc version {MINIMAL_SUPPORTED_PANDOC_VERSION} or above is required") version_ref["checked"] = True except FileNotFoundError as e: - raise MissingDependencyError("Pandoc is not installed.") from e + raise MissingDependencyError("Pandoc is not installed") from e -async def _handle_extract_metadata( - input_file: str | PathLike[str], *, mime_type: str, max_processes: int = DEFAULT_MAX_PROCESSES -) -> Metadata: +async def _handle_extract_metadata(input_file: str | PathLike[str], *, mime_type: str) -> Metadata: 
pandoc_type = _get_pandoc_type_from_mime_type(mime_type) metadata_file, unlink = await create_temp_file(".json") try: @@ -276,15 +269,10 @@ async def _handle_extract_metadata( "--standalone", "--quiet", "--output", - metadata_file, + str(metadata_file), ] - result = await to_process.run_sync( - partial(subprocess.run, capture_output=True), - command, - cancellable=True, - limiter=CapacityLimiter(max_processes), - ) + result = await run_process(command) if result.returncode != 0: raise ParsingError("Failed to extract file data", context={"file": str(input_file), "error": result.stderr}) @@ -297,9 +285,7 @@ async def _handle_extract_metadata( await unlink() -async def _handle_extract_file( - input_file: str | PathLike[str], *, mime_type: str, max_processes: int = DEFAULT_MAX_PROCESSES -) -> str: +async def _handle_extract_file(input_file: str | PathLike[str], *, mime_type: str) -> str: pandoc_type = _get_pandoc_type_from_mime_type(mime_type) output_path, unlink = await create_temp_file(".md") try: @@ -315,12 +301,7 @@ async def _handle_extract_file( command.extend(["--output", str(output_path)]) - result = await to_process.run_sync( - partial(subprocess.run, capture_output=True), - command, - cancellable=True, - limiter=CapacityLimiter(max_processes), - ) + result = await run_process(command) if result.returncode != 0: raise ParsingError("Failed to extract file data", context={"file": str(input_file), "error": result.stderr}) @@ -334,15 +315,12 @@ async def _handle_extract_file( await unlink() -async def process_file_with_pandoc( - input_file: str | PathLike[str], *, mime_type: str, max_processes: int = DEFAULT_MAX_PROCESSES -) -> ExtractionResult: +async def process_file_with_pandoc(input_file: str | PathLike[str], *, mime_type: str) -> ExtractionResult: """Process a single file using Pandoc and convert to markdown. Args: input_file: The path to the file to process. mime_type: The mime type of the file. - max_processes: Maximum number of concurrent processes. 
Defaults to CPU count / 2 (minimum 1). Raises: ParsingError: If the file data could not be extracted. @@ -354,44 +332,27 @@ async def process_file_with_pandoc( _get_pandoc_type_from_mime_type(mime_type) - metadata: Metadata = {} - content: str = "" - try: - async with create_task_group() as tg: - - async def _get_metadata() -> None: - nonlocal metadata - metadata = await _handle_extract_metadata(input_file, mime_type=mime_type, max_processes=max_processes) - - async def _get_content() -> None: - nonlocal content - content = await _handle_extract_file(input_file, mime_type=mime_type, max_processes=max_processes) + metadata, content = await run_taskgroup( + partial(_handle_extract_metadata, input_file, mime_type=mime_type), + partial(_handle_extract_file, input_file, mime_type=mime_type), + ) - tg.start_soon(_get_metadata) - tg.start_soon(_get_content) + return ExtractionResult( + content=normalize_spaces(cast(str, content)), + metadata=cast(Metadata, metadata), + mime_type=MARKDOWN_MIME_TYPE, + ) except ExceptionGroup as eg: - raise ParsingError( - "Failed to extract file data", - context={"file": str(input_file), "errors": ",".join([str(e) for e in eg.exceptions])}, - ) from eg.exceptions[0] - - return ExtractionResult( - content=normalize_spaces(content), - metadata=metadata, - mime_type=MARKDOWN_MIME_TYPE, - ) + raise ParsingError("Failed to process file", context={"file": str(input_file), "errors": eg.exceptions}) from eg -async def process_content_with_pandoc( - content: bytes, *, mime_type: str, max_processes: int = DEFAULT_MAX_PROCESSES -) -> ExtractionResult: +async def process_content_with_pandoc(content: bytes, *, mime_type: str) -> ExtractionResult: """Process content using Pandoc and convert to markdown. Args: content: The content to process. mime_type: The mime type of the content. - max_processes: Maximum number of concurrent processes. Defaults to CPU count / 2 (minimum 1). 
Returns: ExtractionResult @@ -400,7 +361,7 @@ async def process_content_with_pandoc( input_file, unlink = await create_temp_file(f".{extension}") await AsyncPath(input_file).write_bytes(content) - result = await process_file_with_pandoc(input_file, mime_type=mime_type, max_processes=max_processes) + result = await process_file_with_pandoc(input_file, mime_type=mime_type) await unlink() return result diff --git a/kreuzberg/_sync.py b/kreuzberg/_sync.py index 4fc9a7a..8ace17f 100644 --- a/kreuzberg/_sync.py +++ b/kreuzberg/_sync.py @@ -1,7 +1,6 @@ from __future__ import annotations import sys -from collections.abc import Awaitable from functools import partial from typing import TYPE_CHECKING, TypeVar, cast @@ -9,7 +8,7 @@ from anyio.to_thread import run_sync as any_io_run_sync if TYPE_CHECKING: # pragma: no cover - from collections.abc import Callable + from collections.abc import Callable, Coroutine if sys.version_info >= (3, 10): from typing import ParamSpec @@ -35,32 +34,32 @@ async def run_sync(sync_fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) - return cast(T, await any_io_run_sync(handler, *args, abandon_on_cancel=True)) # pyright: ignore [reportCallIssue] -async def run_taskgroup(*coroutines: Callable[[], Awaitable[T]]) -> list[T]: +async def run_taskgroup(*async_tasks: Callable[[], Coroutine[None, None, T]]) -> list[T]: """Run a list of coroutines concurrently. Args: - coroutines: The list of coroutines to run. + *async_tasks: The list of coroutines to run. Returns: The results of the coroutines. 
""" - results = cast(list[T], [None] * len(coroutines)) + results = cast(list[T], [None] * len(async_tasks)) - async def run_task(index: int, task: Callable[[], Awaitable[T]]) -> None: + async def run_task(index: int, task: Callable[[], Coroutine[None, None, T]]) -> None: results[index] = await task() async with create_task_group() as tg: - for i, coro in enumerate(coroutines): - tg.start_soon(run_task, i, coro) + for i, t in enumerate(async_tasks): + tg.start_soon(run_task, i, t) return results -async def run_taskgroup_batched(*coroutines: Callable[[], Awaitable[T]], batch_size: int) -> list[T]: +async def run_taskgroup_batched(*async_tasks: Callable[[], Coroutine[None, None, T]], batch_size: int) -> list[T]: """Run a list of coroutines concurrently in batches. Args: - coroutines: The list of coroutines to run. + *async_tasks: The list of coroutines to run. batch_size: The size of each batch. Returns: @@ -68,8 +67,8 @@ async def run_taskgroup_batched(*coroutines: Callable[[], Awaitable[T]], batch_s """ results: list[T] = [] - for i in range(0, len(coroutines), batch_size): - batch = coroutines[i : i + batch_size] + for i in range(0, len(async_tasks), batch_size): + batch = async_tasks[i : i + batch_size] results.extend(await run_taskgroup(*batch)) return results diff --git a/kreuzberg/_tesseract.py b/kreuzberg/_tesseract.py index c76be5b..d5c2ca2 100644 --- a/kreuzberg/_tesseract.py +++ b/kreuzberg/_tesseract.py @@ -5,13 +5,13 @@ from enum import Enum from functools import partial from os import PathLike -from typing import Any, Final, TypeVar, Union +from typing import Any, TypeVar, Union from anyio import Path as AsyncPath from anyio import run_process from PIL.Image import Image -from kreuzberg._constants import DEFAULT_MAX_PROCESSES +from kreuzberg._constants import DEFAULT_MAX_PROCESSES, MINIMAL_SUPPORTED_TESSERACT_VERSION from kreuzberg._mime_types import PLAIN_TEXT_MIME_TYPE from kreuzberg._string import normalize_spaces from kreuzberg._sync import 
run_sync, run_taskgroup_batched @@ -22,8 +22,6 @@ if sys.version_info < (3, 11): # pragma: no cover from exceptiongroup import ExceptionGroup # type: ignore[import-not-found] -MINIMAL_SUPPORTED_TESSERACT_VERSION: Final[int] = 5 - version_ref = {"checked": False} T = TypeVar("T", bound=Union[Image, PathLike[str], str]) @@ -233,6 +231,4 @@ async def batch_process_images( batch_size=max_processes, ) except ExceptionGroup as eg: - raise ParsingError( - "Failed to process images with Tesseract", context={"errors": ",".join([str(e) for e in eg.exceptions])} - ) from eg + raise ParsingError("Failed to process images with Tesseract", context={"errors": eg.exceptions}) from eg diff --git a/kreuzberg/_xlsx.py b/kreuzberg/_xlsx.py index 187b73a..3c28884 100644 --- a/kreuzberg/_xlsx.py +++ b/kreuzberg/_xlsx.py @@ -1,23 +1,47 @@ from __future__ import annotations import csv +import sys +from functools import partial from io import StringIO -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING from anyio import Path as AsyncPath -from anyio import create_task_group from python_calamine import CalamineWorkbook from kreuzberg import ExtractionResult, ParsingError from kreuzberg._mime_types import MARKDOWN_MIME_TYPE from kreuzberg._pandoc import process_file_with_pandoc from kreuzberg._string import normalize_spaces -from kreuzberg._sync import run_sync +from kreuzberg._sync import run_sync, run_taskgroup from kreuzberg._tmp import create_temp_file if TYPE_CHECKING: # pragma: no cover from pathlib import Path +if sys.version_info < (3, 11): # pragma: no cover + from exceptiongroup import ExceptionGroup # type: ignore[import-not-found] + + +async def convert_sheet_to_text(workbook: CalamineWorkbook, sheet_name: str) -> str: + values = workbook.get_sheet_by_name(sheet_name).to_python() + + csv_buffer = StringIO() + writer = csv.writer(csv_buffer) + + for row in values: + writer.writerow(row) + + csv_data = csv_buffer.getvalue() + csv_buffer.close() + + csv_path, 
unlink = await create_temp_file(".csv") + await AsyncPath(csv_path).write_text(csv_data) + + result = await process_file_with_pandoc(csv_path, mime_type="text/csv") + await unlink() + return f"## {sheet_name}\n\n{normalize_spaces(result.content)}" + async def extract_xlsx_file(input_file: Path) -> ExtractionResult: """Extract text from an XLSX file by converting it to CSV and then to markdown. @@ -33,33 +57,9 @@ async def extract_xlsx_file(input_file: Path) -> ExtractionResult: """ try: workbook: CalamineWorkbook = await run_sync(CalamineWorkbook.from_path, str(input_file)) - - results = cast(list[str], [None] * len(workbook.sheet_names)) - - async def convert_sheet_to_text(sheet_name: str) -> None: - nonlocal results - values = workbook.get_sheet_by_name(sheet_name).to_python() - - csv_buffer = StringIO() - writer = csv.writer(csv_buffer) - - for row in values: - writer.writerow(row) - - csv_data = csv_buffer.getvalue() - csv_buffer.close() - - csv_path, unlink = await create_temp_file(".csv") - await AsyncPath(csv_path).write_text(csv_data) - - result = await process_file_with_pandoc(csv_path, mime_type="text/csv") - - results[workbook.sheet_names.index(sheet_name)] = f"## {sheet_name}\n\n{normalize_spaces(result.content)}" - await unlink() - - async with create_task_group() as tg: - for sheet_name in workbook.sheet_names: - tg.start_soon(convert_sheet_to_text, sheet_name) + results = await run_taskgroup( + *[partial(convert_sheet_to_text, workbook, sheet_name) for sheet_name in workbook.sheet_names] + ) return ExtractionResult( content="\n\n".join(results), @@ -69,8 +69,8 @@ async def convert_sheet_to_text(sheet_name: str) -> None: except ExceptionGroup as eg: raise ParsingError( "Failed to extract file data", - context={"file": str(input_file), "errors": ",".join([str(e) for e in eg.exceptions])}, - ) from eg.exceptions[0] + context={"file": str(input_file), "errors": eg.exceptions}, + ) from eg async def extract_xlsx_content(content: bytes) -> 
ExtractionResult: diff --git a/kreuzberg/exceptions.py b/kreuzberg/exceptions.py index 4c6a49f..ae0f2db 100644 --- a/kreuzberg/exceptions.py +++ b/kreuzberg/exceptions.py @@ -14,9 +14,28 @@ def __init__(self, message: str, *, context: Any = None) -> None: self.context = context super().__init__(message) + def _serialize_context(self, obj: Any) -> Any: + """Recursively serialize context objects to ensure JSON compatibility.""" + if isinstance(obj, bytes): + return obj.decode("utf-8", errors="replace") + if isinstance(obj, dict): + return {k: self._serialize_context(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [self._serialize_context(x) for x in obj] + if isinstance(obj, Exception): + return { + "type": obj.__class__.__name__, + "message": str(obj), + } + return obj + def __str__(self) -> str: """Return a string representation of the exception.""" - ctx = f"\n\nContext: {dumps(self.context)}" if self.context else "" + if self.context: + serialized_context = self._serialize_context(self.context) + ctx = f"\n\nContext: {dumps(serialized_context)}" + else: + ctx = "" return f"{self.__class__.__name__}: {super().__str__()}{ctx}" diff --git a/kreuzberg/extraction.py b/kreuzberg/extraction.py index d053db4..6aad397 100644 --- a/kreuzberg/extraction.py +++ b/kreuzberg/extraction.py @@ -92,7 +92,7 @@ async def extract_bytes( if mime_type in PANDOC_SUPPORTED_MIME_TYPES or any( mime_type.startswith(value) for value in PANDOC_SUPPORTED_MIME_TYPES ): - return await process_content_with_pandoc(content=content, mime_type=mime_type, max_processes=max_processes) + return await process_content_with_pandoc(content=content, mime_type=mime_type) if mime_type == POWER_POINT_MIME_TYPE or mime_type.startswith(POWER_POINT_MIME_TYPE): return await extract_pptx_file_content(content) @@ -153,7 +153,7 @@ async def extract_file( if mime_type in PANDOC_SUPPORTED_MIME_TYPES or any( mime_type.startswith(value) for value in PANDOC_SUPPORTED_MIME_TYPES ): - return await 
process_file_with_pandoc(input_file=input_file, mime_type=mime_type, max_processes=max_processes) + return await process_file_with_pandoc(input_file=input_file, mime_type=mime_type) if mime_type == POWER_POINT_MIME_TYPE or mime_type.startswith(POWER_POINT_MIME_TYPE): return await extract_pptx_file_content(Path(input_file)) diff --git a/pyproject.toml b/pyproject.toml index 7e2d9f6..231acf2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,6 +101,7 @@ lint.per-file-ignores."tests/**/*.*" = [ "PT013", "RUF012", "S", + "SLF001", ] lint.isort.known-first-party = [ "kreuzberg", "tests" ] lint.mccabe.max-complexity = 15 diff --git a/tests/exceptions_test.py b/tests/exceptions_test.py new file mode 100644 index 0000000..1b970f0 --- /dev/null +++ b/tests/exceptions_test.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +from kreuzberg.exceptions import KreuzbergError + + +def test_kreuzberg_error_serialize_context_with_bytes() -> None: + error = KreuzbergError("Test error", context={"data": b"test bytes"}) + serialized = error._serialize_context(error.context) + assert serialized == {"data": "test bytes"} diff --git a/tests/extraction_test.py b/tests/extraction_test.py index c5ba9c3..6dae158 100644 --- a/tests/extraction_test.py +++ b/tests/extraction_test.py @@ -228,6 +228,9 @@ def assert_extraction_result(result: ExtractionResult, *, mime_type: str) -> Non Args: result: The extraction result to check. mime_type: The expected mime type. + + Raises: + AssertionError: If the extraction result does not have the expected properties. 
""" assert isinstance(result.content, str) assert result.content.strip() diff --git a/tests/pandoc_test.py b/tests/pandoc_test.py index cb58a3f..35d6efb 100644 --- a/tests/pandoc_test.py +++ b/tests/pandoc_test.py @@ -1,15 +1,19 @@ from __future__ import annotations -import json -from pathlib import Path -from typing import TYPE_CHECKING, Any, cast -from unittest.mock import AsyncMock, Mock +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Coroutine + from unittest.mock import Mock import pytest from kreuzberg import ExtractionResult from kreuzberg._pandoc import ( MIMETYPE_TO_PANDOC_TYPE_MAPPING, + _extract_inline_text, + _extract_inlines, + _extract_meta_value, _get_pandoc_type_from_mime_type, _handle_extract_file, _handle_extract_metadata, @@ -17,10 +21,11 @@ process_content_with_pandoc, process_file_with_pandoc, ) -from kreuzberg._tmp import create_temp_file from kreuzberg.exceptions import MissingDependencyError, ParsingError, ValidationError if TYPE_CHECKING: + from collections.abc import Callable + from pytest_mock import MockerFixture SAMPLE_PANDOC_JSON = { @@ -31,280 +36,236 @@ @pytest.fixture -def mock_subprocess_run(mocker: MockerFixture) -> Mock: - def run_sync(command: list[str], **kwargs: Any) -> Mock: - result = Mock() - result.returncode = 0 - result.stderr = b"" - - if "--version" in command: - result.stdout = b"pandoc 3.1.0" - return result - - # Handle error test cases - if "test_process_file_error" in str(kwargs.get("cwd")): - result.returncode = 1 - result.stderr = b"Error processing file" - raise ParsingError("Error processing file", context={"error": "Error processing file"}) - - # Handle empty result test case - if "test_process_content_empty_result" in str(kwargs.get("cwd")): - result.returncode = 1 - result.stderr = b"Empty content" - raise ParsingError("Empty content", context={"error": "Empty content"}) - - # Handle metadata error test case - if "test_extract_metadata_error" in 
str(kwargs.get("cwd")): - result.returncode = 1 - result.stderr = b"Invalid metadata" - raise ParsingError("Invalid metadata", context={"error": "Invalid metadata"}) - - # Handle runtime error test case - if "test_process_file_runtime_error" in str(kwargs.get("cwd")): - raise RuntimeError("Command failed") - - # Normal case - output_file = next((str(arg) for arg in command if str(arg).endswith((".md", ".json"))), "") - if output_file: - content = ( - json.dumps(SAMPLE_PANDOC_JSON) if str(output_file).endswith(".json") else "Sample processed content" - ) - Path(output_file).write_text(content) - return result - - # Mock anyio.run_process - mock = mocker.patch("anyio.run_process", side_effect=run_sync) - return mock +def mock_run_process(mocker: MockerFixture) -> Mock: + return mocker.patch("kreuzberg._pandoc.run_process") @pytest.fixture -def mock_subprocess_run_invalid(mocker: MockerFixture) -> Mock: - def run_sync(command: list[str], **kwargs: Any) -> Mock: - result = Mock() - result.stdout = b"pandoc 1.0.0" - result.returncode = 0 - result.stderr = b"" - return result - - mock = mocker.patch("anyio.run_process", side_effect=run_sync) - return mock +def mock_version_check(mocker: MockerFixture) -> None: + mocker.patch("kreuzberg._pandoc.version_ref", {"checked": True}) @pytest.fixture -def mock_subprocess_run_error(mocker: MockerFixture) -> Mock: - def run_sync(command: list[str], **kwargs: Any) -> Mock: - raise FileNotFoundError - - mock = mocker.patch("anyio.run_process", side_effect=run_sync) - return mock - - -@pytest.fixture(autouse=True) -def reset_version_ref(mocker: MockerFixture) -> None: - mocker.patch("kreuzberg._pandoc.version_ref", {"checked": False}) +def mock_run_taskgroup(mocker: MockerFixture) -> Mock: + return mocker.patch("kreuzberg._pandoc.run_taskgroup") @pytest.mark.anyio -async def test_validate_pandoc_version(mock_subprocess_run: Mock) -> None: - await _validate_pandoc_version() - mock_subprocess_run.assert_called_with(["pandoc", 
"--version"]) +@pytest.mark.parametrize( + "major_version, should_raise", + [ + (1, True), + (2, False), + (3, False), + ], +) +async def test_validate_pandoc_version( + mocker: MockerFixture, mock_run_process: Mock, major_version: int, should_raise: bool +) -> None: + mocker.patch("kreuzberg._pandoc.version_ref", {"checked": False}) + mock_run_process.return_value.returncode = 0 + mock_run_process.return_value.stderr = b"" + mock_run_process.return_value.stdout = f"pandoc {major_version}.1.0".encode() -@pytest.mark.anyio -async def test_validate_pandoc_version_invalid(mock_subprocess_run_invalid: Mock) -> None: - with pytest.raises(MissingDependencyError, match="Pandoc version 3 or above is required"): + if should_raise: + with pytest.raises(MissingDependencyError): + await _validate_pandoc_version() + else: await _validate_pandoc_version() + mock_run_process.assert_called_with(["pandoc", "--version"]) -@pytest.mark.anyio -async def test_validate_pandoc_version_missing(mock_subprocess_run_error: Mock) -> None: - with pytest.raises(MissingDependencyError, match="Pandoc is not installed"): - await _validate_pandoc_version() +@pytest.mark.parametrize( + "node, expected_output", + [ + ({"t": "Str", "c": "Hello"}, "Hello"), + ({"t": "Space", "c": " "}, " "), + ({"t": "Emph", "c": [{"t": "Str", "c": "Emphasized"}]}, "Emphasized"), + ], +) +def test_extract_inline_text(node: dict[str, Any], expected_output: str) -> None: + assert _extract_inline_text(node) == expected_output -@pytest.mark.anyio -async def test_get_pandoc_type_from_mime_type_valid() -> None: - for mime_type in MIMETYPE_TO_PANDOC_TYPE_MAPPING: - extension = _get_pandoc_type_from_mime_type(mime_type) - assert isinstance(extension, str) - assert extension +@pytest.mark.parametrize( + "nodes, expected_output", + [ + ([{"t": "Str", "c": "Hello"}, {"t": "Space", "c": " "}, {"t": "Str", "c": "World"}], "Hello World"), + ([{"t": "Emph", "c": [{"t": "Str", "c": "Emphasized"}]}], "Emphasized"), + ], +) +def 
test_extract_inlines(nodes: list[dict[str, Any]], expected_output: str) -> None: + assert _extract_inlines(nodes) == expected_output -@pytest.mark.anyio -async def test_get_pandoc_type_from_mime_type_invalid() -> None: - with pytest.raises(ValidationError, match="Unsupported mime type"): - _get_pandoc_type_from_mime_type("invalid/mime-type") +@pytest.mark.parametrize( + "node, expected_output", + [ + ({"t": "MetaString", "c": "Test String"}, "Test String"), + ({"t": "MetaInlines", "c": [{"t": "Str", "c": "Inline String"}]}, "Inline String"), + ({"t": "MetaList", "c": [{"t": "MetaString", "c": "List Item"}]}, ["List Item"]), + ], +) +def test_extract_meta_value(node: Any, expected_output: Any) -> None: + assert _extract_meta_value(node) == expected_output -@pytest.mark.anyio -async def test_process_file_success(mock_subprocess_run: Mock, docx_document: Path) -> None: - result = await process_file_with_pandoc( - docx_document, mime_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document" - ) - assert isinstance(result, ExtractionResult) - assert result.content.strip() == "Sample processed content" +@pytest.mark.parametrize("mime_type, expected_type", MIMETYPE_TO_PANDOC_TYPE_MAPPING.items()) +def test_get_pandoc_type_from_mime_type_all_mappings(mime_type: str, expected_type: str) -> None: + assert _get_pandoc_type_from_mime_type(mime_type) == expected_type -@pytest.mark.anyio -async def test_process_file_error(mock_subprocess_run: Mock, docx_document: Path) -> None: - def side_effect(*args: list[Any], **_: Any) -> Mock: - if args[0][0] == "pandoc" and "--version" in args[0]: - mock_subprocess_run.return_value.stdout = b"pandoc 3.1.0" - return cast(Mock, mock_subprocess_run.return_value) - mock_subprocess_run.return_value.returncode = 1 - mock_subprocess_run.return_value.stderr = b"Error processing file" - raise RuntimeError("Error processing file") - - mock_subprocess_run.side_effect = side_effect - with pytest.raises(ParsingError, match="Failed 
to extract file data"): - await process_file_with_pandoc( - docx_document, mime_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document" - ) +# Mock the Pandoc version check +@pytest.fixture(autouse=True) +def mock_pandoc_version(mocker: MockerFixture) -> None: + # Mock the version_ref to avoid version check + mocker.patch("kreuzberg._pandoc.version_ref", {"checked": True}) -@pytest.mark.anyio -async def test_process_content_success(mock_subprocess_run: Mock) -> None: - result = await process_content_with_pandoc( - b"test content", mime_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document" - ) - assert isinstance(result, ExtractionResult) - assert result.content.strip() == "Sample processed content" +@pytest.fixture +def mock_temp_file(mocker: MockerFixture) -> None: + # Mock create_temp_file to return a predictable path + async def mock_create(_: Any) -> tuple[str, Callable[[], Coroutine[None, None, None]]]: + async def mock_unlink() -> None: + pass -@pytest.mark.anyio -async def test_extract_metadata_error(mock_subprocess_run: Mock, docx_document: Path) -> None: - def side_effect(*args: list[Any], **_: Any) -> Mock: - if args[0][0] == "pandoc" and "--version" in args[0]: - mock_subprocess_run.return_value.stdout = b"pandoc 3.1.0" - return cast(Mock, mock_subprocess_run.return_value) - mock_subprocess_run.return_value.returncode = 1 - mock_subprocess_run.return_value.stderr = b"Error extracting metadata" - raise RuntimeError("Error extracting metadata") - - mock_subprocess_run.side_effect = side_effect - with pytest.raises(ParsingError, match="Failed to extract file data"): - await _handle_extract_metadata( - docx_document, mime_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document" - ) + return "/tmp/test", mock_unlink + mocker.patch("kreuzberg._pandoc.create_temp_file", side_effect=mock_create) -@pytest.mark.anyio -async def test_extract_metadata_runtime_error(mock_subprocess_run: Mock, 
docx_document: Path) -> None: - mock_subprocess_run.side_effect = RuntimeError("Command failed") - with pytest.raises(ParsingError, match="Failed to extract file data"): - await _handle_extract_metadata( - docx_document, mime_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document" - ) +@pytest.fixture +def mock_async_path(mocker: MockerFixture) -> None: + # Mock AsyncPath operations + mock_path = mocker.patch("kreuzberg._pandoc.AsyncPath") + mock_path.return_value.read_text = mocker.AsyncMock(return_value="Test content") + mock_path.return_value.write_bytes = mocker.AsyncMock() @pytest.mark.anyio -async def test_integration_validate_pandoc_version() -> None: - await _validate_pandoc_version() +async def test_handle_extract_file(mock_run_process: Mock) -> None: + mock_run_process.return_value.returncode = 0 + mock_run_process.return_value.stdout = b"Test content" + + result = await _handle_extract_file("dummy_file", mime_type="application/csl+json") + assert isinstance(result, str) @pytest.mark.anyio -async def test_integration_process_file(markdown_document: Path) -> None: - result = await process_file_with_pandoc(markdown_document, mime_type="text/x-markdown") +async def test_process_file_with_pandoc(mock_version_check: Mock, mock_run_taskgroup: Mock) -> None: + mock_run_taskgroup.return_value = ({"title": "Test Document"}, "Test Content") + + result = await process_file_with_pandoc("dummy_file", mime_type="application/csl+json") assert isinstance(result, ExtractionResult) - assert isinstance(result.content, str) - assert result.content.strip() + assert result.metadata["title"] == "Test Document" + assert result.content == "Test Content" @pytest.mark.anyio -async def test_integration_process_content() -> None: - content = b"# Test\nThis is a test file." 
- result = await process_content_with_pandoc(content, mime_type="text/x-markdown") +async def test_process_content_with_pandoc(mock_version_check: Mock, mock_run_taskgroup: Mock) -> None: + mock_run_taskgroup.return_value = ({"title": "Test Document"}, "Test Content") + result = await process_content_with_pandoc(b"Test Content", mime_type="application/csl+json") assert isinstance(result, ExtractionResult) - assert isinstance(result.content, str) - assert result.content.strip() + assert result.metadata["title"] == "Test Document" + assert result.content == "Test Content" @pytest.mark.anyio -async def test_integration_extract_metadata(markdown_document: Path) -> None: - result = await _handle_extract_metadata(markdown_document, mime_type="text/x-markdown") - assert isinstance(result, dict) +async def test_validate_pandoc_version_file_not_found(mocker: MockerFixture) -> None: + mocker.patch("kreuzberg._pandoc.version_ref", {"checked": False}) + mock_run = mocker.patch("kreuzberg._pandoc.run_process") + mock_run.side_effect = FileNotFoundError() + + with pytest.raises(MissingDependencyError, match="Pandoc is not installed"): + await _validate_pandoc_version() @pytest.mark.anyio -async def test_process_file_runtime_error(mock_subprocess_run: Mock, docx_document: Path) -> None: - def side_effect(*args: list[Any], **_: Any) -> Mock: - if args[0][0] == "pandoc" and "--version" in args[0]: - mock_subprocess_run.return_value.stdout = b"pandoc 3.1.0" - return cast(Mock, mock_subprocess_run.return_value) - raise RuntimeError("Pandoc error") - - mock_subprocess_run.side_effect = side_effect - with pytest.raises(ParsingError, match="Failed to extract file data"): - await process_file_with_pandoc( - docx_document, mime_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document" - ) +async def test_validate_pandoc_version_invalid_output(mocker: MockerFixture) -> None: + mocker.patch("kreuzberg._pandoc.version_ref", {"checked": False}) + mock_run = 
mocker.patch("kreuzberg._pandoc.run_process") + mock_run.return_value.stdout = b"invalid version output" + + with pytest.raises(MissingDependencyError, match="Pandoc version 2 or above is required"): + await _validate_pandoc_version() @pytest.mark.anyio -async def test_process_content_empty_result(mock_subprocess_run: Mock) -> None: - def side_effect(*args: list[Any], **_: Any) -> Mock: - if args[0][0] == "pandoc" and "--version" in args[0]: - mock_subprocess_run.return_value.stdout = b"pandoc 3.1.0" - return cast(Mock, mock_subprocess_run.return_value) - output_file = next((str(arg) for arg in args[0] if str(arg).endswith((".md", ".json"))), "") - if output_file: - if str(output_file).endswith(".json"): - Path(output_file).write_text('{"pandoc-api-version":[1,22,2,1],"meta":{},"blocks":[]}') - else: - Path(output_file).write_text("") - mock_subprocess_run.return_value.returncode = 0 - return cast(Mock, mock_subprocess_run.return_value) - raise RuntimeError("Empty content") - - mock_subprocess_run.side_effect = side_effect - result = await process_content_with_pandoc( - b"content", mime_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document" - ) - assert isinstance(result, ExtractionResult) - assert result.content == "" - assert result.metadata == {} +async def test_validate_pandoc_version_parse_error(mocker: MockerFixture) -> None: + mocker.patch("kreuzberg._pandoc.version_ref", {"checked": False}) + mock_run = mocker.patch("kreuzberg._pandoc.run_process") + mock_run.return_value.stdout = b"pandoc abc" + + with pytest.raises(MissingDependencyError, match="Pandoc version 2 or above is required"): + await _validate_pandoc_version() @pytest.mark.anyio -async def test_process_file_invalid_mime_type(mock_subprocess_run: Mock, docx_document: Path) -> None: - with pytest.raises(ValidationError, match="Unsupported mime type"): - await process_file_with_pandoc(docx_document, mime_type="invalid/mime-type") +async def 
test_handle_extract_metadata_runtime_error( + mock_run_process: Mock, mock_temp_file: None, mock_async_path: None +) -> None: + mock_run_process.side_effect = RuntimeError("Runtime error") + + with pytest.raises(ParsingError, match="Failed to extract file data"): + await _handle_extract_metadata("dummy_file", mime_type="application/csl+json") @pytest.mark.anyio -async def test_process_content_invalid_mime_type(mock_subprocess_run: Mock) -> None: - with pytest.raises(ValidationError, match="Unsupported mime type"): - await process_content_with_pandoc(b"content", mime_type="invalid/mime-type") +async def test_handle_extract_file_runtime_error( + mock_run_process: Mock, mock_temp_file: None, mock_async_path: None +) -> None: + mock_run_process.side_effect = RuntimeError("Runtime error") + + with pytest.raises(ParsingError, match="Failed to extract file data"): + await _handle_extract_file("dummy_file", mime_type="application/csl+json") @pytest.mark.anyio -async def test_handle_extract_metadata_os_error( - mock_subprocess_run: Mock, mocker: MockerFixture, docx_document: Path +async def test_process_content_with_pandoc_runtime_error( + mock_version_check: None, mock_temp_file: None, mock_async_path: None, mock_run_process: Mock ) -> None: - await create_temp_file(".json") - mock_path = Mock(read_text=AsyncMock(side_effect=OSError)) + mock_run_process.side_effect = RuntimeError("Runtime error") - mocker.patch("kreuzberg._pandoc.AsyncPath", return_value=mock_path) - with pytest.raises(ParsingError) as exc_info: - await _handle_extract_metadata( - docx_document, mime_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document" - ) + with pytest.raises(ParsingError, match="Failed to process file"): + await process_content_with_pandoc(b"Test content", mime_type="application/csl+json") - assert "Failed to extract file data" in str(exc_info.value) + +def test_get_pandoc_type_unsupported_mime() -> None: + with pytest.raises(ValidationError, match="Unsupported 
mime type: invalid/type"): + _get_pandoc_type_from_mime_type("invalid/type") + + +def test_get_pandoc_type_prefix_match() -> None: + assert _get_pandoc_type_from_mime_type("application/csl") == "csljson" @pytest.mark.anyio -async def test_handle_extract_file_os_error( - mock_subprocess_run: Mock, mocker: MockerFixture, docx_document: Path +async def test_handle_extract_metadata_error( + mock_run_process: Mock, mock_temp_file: None, mock_async_path: None ) -> None: - await create_temp_file(".md") - mock_path = Mock(read_text=AsyncMock(side_effect=OSError)) + mock_run_process.return_value.returncode = 1 + mock_run_process.return_value.stderr = b"Error processing file" + + with pytest.raises(ParsingError, match="Failed to extract file data"): + await _handle_extract_metadata("dummy_file", mime_type="application/csl+json") + + +@pytest.mark.anyio +async def test_handle_extract_file_error(mock_run_process: Mock, mock_temp_file: None, mock_async_path: None) -> None: + mock_run_process.return_value.returncode = 1 + mock_run_process.return_value.stderr = b"Error processing file" + + with pytest.raises(ParsingError, match="Failed to extract file data"): + await _handle_extract_file("dummy_file", mime_type="application/csl+json") - mocker.patch("kreuzberg._pandoc.AsyncPath", return_value=mock_path) - with pytest.raises(ParsingError) as exc_info: - await _handle_extract_file( - docx_document, mime_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document" - ) - assert "Failed to extract file data" in str(exc_info.value) +@pytest.mark.anyio +async def test_process_content_with_pandoc_error( + mock_version_check: None, mock_temp_file: None, mock_async_path: None, mock_run_process: Mock +) -> None: + mock_run_process.return_value.returncode = 1 + mock_run_process.return_value.stderr = b"Error processing file" + with pytest.raises(ValidationError, match="Unsupported mime type: invalid/type"): + await process_content_with_pandoc(b"Invalid content", 
mime_type="invalid/type") diff --git a/tests/string_test.py b/tests/string_test.py index f875fbf..5c49665 100644 --- a/tests/string_test.py +++ b/tests/string_test.py @@ -20,6 +20,12 @@ def test_safe_decode(byte_data: bytes, encoding: str | None, expected: str) -> N assert safe_decode(byte_data, encoding) == expected +def test_safe_decode_with_invalid_encoding_value() -> None: + test_bytes = b"Hello World" + result = safe_decode(test_bytes, encoding="invalid-encoding") + assert result == "Hello World" + + def test_safe_decode_with_detected_encoding() -> None: text = "Hello 世界" byte_data = text.encode("utf-8") diff --git a/tests/tesseract_test.py b/tests/tesseract_test.py index a655276..812c20d 100644 --- a/tests/tesseract_test.py +++ b/tests/tesseract_test.py @@ -25,7 +25,7 @@ @pytest.fixture -def mock_subprocess_run(mocker: MockerFixture) -> Mock: +def mock_run_process(mocker: MockerFixture) -> Mock: def run_sync(command: list[str], **kwargs: Any) -> Mock: result = Mock() result.stdout = b"tesseract 5.0.0" @@ -64,13 +64,11 @@ def run_sync(command: list[str], **kwargs: Any) -> Mock: return result - # Mock run_process - mock = mocker.patch("kreuzberg._tesseract.run_process", side_effect=run_sync) - return mock + return mocker.patch("kreuzberg._tesseract.run_process", side_effect=run_sync) @pytest.fixture -def mock_subprocess_run_invalid(mocker: MockerFixture) -> Mock: +def mock_run_process_invalid(mocker: MockerFixture) -> Mock: def run_sync(command: list[str], **kwargs: Any) -> Mock: result = Mock() result.stdout = b"tesseract 4.0.0" @@ -78,23 +76,21 @@ def run_sync(command: list[str], **kwargs: Any) -> Mock: result.stderr = b"" return result - mock = mocker.patch("kreuzberg._tesseract.run_process", side_effect=run_sync) - return mock + return mocker.patch("kreuzberg._tesseract.run_process", side_effect=run_sync) @pytest.fixture -def mock_subprocess_run_error(mocker: MockerFixture) -> Mock: +def mock_run_process_error(mocker: MockerFixture) -> Mock: def 
run_sync(command: list[str], **kwargs: Any) -> Mock: raise FileNotFoundError - mock = mocker.patch("kreuzberg._tesseract.run_process", side_effect=run_sync) - return mock + return mocker.patch("kreuzberg._tesseract.run_process", side_effect=run_sync) @pytest.mark.anyio -async def test_validate_tesseract_version(mock_subprocess_run: Mock) -> None: +async def test_validate_tesseract_version(mock_run_process: Mock) -> None: await validate_tesseract_version() - mock_subprocess_run.assert_called_with(["tesseract", "--version"]) + mock_run_process.assert_called_with(["tesseract", "--version"]) @pytest.fixture(autouse=True) @@ -103,49 +99,49 @@ def reset_version_ref(mocker: MockerFixture) -> None: @pytest.mark.anyio -async def test_validate_tesseract_version_invalid(mock_subprocess_run_invalid: Mock, reset_version_ref: None) -> None: +async def test_validate_tesseract_version_invalid(mock_run_process_invalid: Mock, reset_version_ref: None) -> None: with pytest.raises(MissingDependencyError, match="Tesseract version 5 or above is required"): await validate_tesseract_version() @pytest.mark.anyio -async def test_validate_tesseract_version_missing(mock_subprocess_run_error: Mock, reset_version_ref: None) -> None: +async def test_validate_tesseract_version_missing(mock_run_process_error: Mock, reset_version_ref: None) -> None: with pytest.raises(MissingDependencyError, match="Tesseract is not installed"): await validate_tesseract_version() @pytest.mark.anyio -async def test_process_file(mock_subprocess_run: Mock, ocr_image: Path) -> None: +async def test_process_file(mock_run_process: Mock, ocr_image: Path) -> None: result = await process_file(ocr_image, language="eng", psm=PSMMode.AUTO) assert isinstance(result, ExtractionResult) assert result.content.strip() == "Sample OCR text" @pytest.mark.anyio -async def test_process_file_with_options(mock_subprocess_run: Mock, ocr_image: Path) -> None: +async def test_process_file_with_options(mock_run_process: Mock, ocr_image: Path) -> 
None: result = await process_file(ocr_image, language="eng", psm=PSMMode.AUTO) assert isinstance(result, ExtractionResult) assert result.content.strip() == "Sample OCR text" @pytest.mark.anyio -async def test_process_file_error(mock_subprocess_run: Mock, ocr_image: Path) -> None: - mock_subprocess_run.return_value.returncode = 1 - mock_subprocess_run.return_value.stderr = b"Error processing file" - mock_subprocess_run.side_effect = None +async def test_process_file_error(mock_run_process: Mock, ocr_image: Path) -> None: + mock_run_process.return_value.returncode = 1 + mock_run_process.return_value.stderr = b"Error processing file" + mock_run_process.side_effect = None with pytest.raises(OCRError, match="OCR failed with a non-0 return code"): await process_file(ocr_image, language="eng", psm=PSMMode.AUTO) @pytest.mark.anyio -async def test_process_file_runtime_error(mock_subprocess_run: Mock, ocr_image: Path) -> None: - mock_subprocess_run.side_effect = RuntimeError() +async def test_process_file_runtime_error(mock_run_process: Mock, ocr_image: Path) -> None: + mock_run_process.side_effect = RuntimeError() with pytest.raises(OCRError, match="Failed to OCR using tesseract"): await process_file(ocr_image, language="eng", psm=PSMMode.AUTO) @pytest.mark.anyio -async def test_process_image(mock_subprocess_run: Mock) -> None: +async def test_process_image(mock_run_process: Mock) -> None: image = Image.new("RGB", (100, 100)) result = await process_image(image, language="eng", psm=PSMMode.AUTO) assert isinstance(result, ExtractionResult) @@ -153,7 +149,7 @@ async def test_process_image(mock_subprocess_run: Mock) -> None: @pytest.mark.anyio -async def test_process_image_with_tesseract_pillow(mock_subprocess_run: Mock) -> None: +async def test_process_image_with_tesseract_pillow(mock_run_process: Mock) -> None: image = Image.new("RGB", (100, 100)) result = await process_image_with_tesseract(image) assert isinstance(result, ExtractionResult) @@ -161,7 +157,7 @@ async def 
test_process_image_with_tesseract_pillow(mock_subprocess_run: Mock) -> @pytest.mark.anyio -async def test_process_image_with_tesseract_path(mock_subprocess_run: Mock, ocr_image: Path) -> None: +async def test_process_image_with_tesseract_path(mock_run_process: Mock, ocr_image: Path) -> None: result = await process_image_with_tesseract(ocr_image) assert isinstance(result, ExtractionResult) assert result.content.strip() == "Sample OCR text" @@ -174,7 +170,7 @@ async def test_process_image_with_tesseract_invalid_input() -> None: @pytest.mark.anyio -async def test_batch_process_images_pillow(mock_subprocess_run: Mock) -> None: +async def test_batch_process_images_pillow(mock_run_process: Mock) -> None: images = [Image.new("RGB", (100, 100)) for _ in range(3)] results = await batch_process_images(images, language="eng", psm=PSMMode.AUTO, max_processes=1) assert isinstance(results, list) @@ -183,7 +179,7 @@ async def test_batch_process_images_pillow(mock_subprocess_run: Mock) -> None: @pytest.mark.anyio -async def test_batch_process_images_paths(mock_subprocess_run: Mock, ocr_image: Path) -> None: +async def test_batch_process_images_paths(mock_run_process: Mock, ocr_image: Path) -> None: images = [str(ocr_image)] * 3 results = await batch_process_images(images, language="eng", psm=PSMMode.AUTO, max_processes=1) assert isinstance(results, list) @@ -192,7 +188,7 @@ async def test_batch_process_images_paths(mock_subprocess_run: Mock, ocr_image: @pytest.mark.anyio -async def test_batch_process_images_mixed(mock_subprocess_run: Mock, ocr_image: Path) -> None: +async def test_batch_process_images_mixed(mock_run_process: Mock, ocr_image: Path) -> None: images: list[Image.Image | PathLike[str] | str] = [ Image.new("RGB", (100, 100)), str(ocr_image), @@ -283,15 +279,31 @@ async def test_integration_batch_process_images_mixed(ocr_image: Path) -> None: @pytest.mark.anyio -async def test_batch_process_images_exception_group(mock_subprocess_run: Mock) -> None: +async def 
test_batch_process_images_exception_group(mock_run_process: Mock) -> None: def side_effect(*args: list[Any], **kwargs: dict[str, Any]) -> Mock: if args[0][0] == "tesseract" and "--version" in args[0]: - mock_subprocess_run.return_value.stdout = b"tesseract 5.0.0" - return cast(Mock, mock_subprocess_run.return_value) + mock_run_process.return_value.stdout = b"tesseract 5.0.0" + return cast(Mock, mock_run_process.return_value) raise RuntimeError("Tesseract error") - mock_subprocess_run.side_effect = side_effect + mock_run_process.side_effect = side_effect image = Image.new("RGB", (100, 100)) with pytest.raises(ParsingError, match="Failed to process images with Tesseract"): await batch_process_images([image], language="eng", psm=PSMMode.AUTO, max_processes=1) + + +@pytest.mark.anyio +async def test_process_file_linux(mocker: MockerFixture) -> None: + # Mock sys.platform to simulate Linux + mocker.patch("sys.platform", "linux") + + mock_run = mocker.patch("kreuzberg._tesseract.run_process") + mock_run.return_value.returncode = 0 + mock_run.return_value.stdout = b"test output" + + await process_file("test.png", language="eng", psm=PSMMode.AUTO) + + # Verify that OMP_THREAD_LIMIT was set for Linux + mock_run.assert_called_once() + assert mock_run.call_args[1]["env"] == {"OMP_THREAD_LIMIT": "1"} diff --git a/tests/xlsx_test.py b/tests/xlsx_test.py index 5da35f1..7e25f45 100644 --- a/tests/xlsx_test.py +++ b/tests/xlsx_test.py @@ -2,14 +2,23 @@ from __future__ import annotations -from pathlib import Path +import sys +from typing import TYPE_CHECKING import pytest -from kreuzberg import ExtractionResult +from kreuzberg import ExtractionResult, ParsingError from kreuzberg._mime_types import MARKDOWN_MIME_TYPE from kreuzberg._xlsx import extract_xlsx_file -from kreuzberg.exceptions import ParsingError + +if TYPE_CHECKING: + from pathlib import Path + + from pytest_mock import MockerFixture + + +if sys.version_info < (3, 11): # pragma: no cover + from exceptiongroup import 
ExceptionGroup # type: ignore[import-not-found] @pytest.mark.anyio @@ -21,15 +30,6 @@ async def test_extract_xlsx_file(excel_document: Path) -> None: assert result.mime_type == "text/markdown" -@pytest.mark.anyio -async def test_extract_xlsx_file_invalid() -> None: - """Test that attempting to extract from an invalid Excel file raises an error.""" - with pytest.raises(ParsingError) as exc_info: - await extract_xlsx_file(Path("/invalid/path.xlsx")) - - assert "Could not extract text from XLSX" in str(exc_info.value) - - @pytest.mark.anyio async def test_extract_xlsx_multi_sheet_file(excel_multi_sheet_document: Path) -> None: """Test extracting text from an Excel file with multiple sheets.""" @@ -63,3 +63,17 @@ async def test_extract_xlsx_multi_sheet_file(excel_multi_sheet_document: Path) - assert "Beetroot" in second_sheet_content assert "1.0" in second_sheet_content assert "2.0" in second_sheet_content + + +@pytest.mark.anyio +async def test_extract_xlsx_file_exception_group(mocker: MockerFixture, excel_multi_sheet_document: Path) -> None: + # Mock openpyxl to raise multiple exceptions + mock_load = mocker.patch("kreuzberg._xlsx.run_taskgroup") + exceptions = [ValueError("Error 1"), ValueError("Error 2")] + mock_load.side_effect = ExceptionGroup("test group", exceptions) + + with pytest.raises(ParsingError) as exc_info: + await extract_xlsx_file(excel_multi_sheet_document) + + assert "Failed to extract file data" in str(exc_info.value) + assert len(exc_info.value.context["errors"]) == 2