From 5e347bf99b068937047262debe369c2623b5995c Mon Sep 17 00:00:00 2001 From: Na'aman Hirschfeld Date: Mon, 17 Feb 2025 20:51:34 +0100 Subject: [PATCH] chore: updated test output --- .github/workflows/ci.yaml | 4 +- .pre-commit-config.yaml | 2 +- Dockerfile | 18 ++++++ README.md | 7 ++- docker-compose.yaml | 11 ++++ kreuzberg/_constants.py | 2 +- kreuzberg/_html.py | 3 +- kreuzberg/_pandoc.py | 9 ++- kreuzberg/_pdf.py | 2 +- kreuzberg/_sync.py | 42 +++++++++++++ kreuzberg/_tesseract.py | 125 +++++++++++++++----------------------- kreuzberg/_xlsx.py | 2 +- kreuzberg/extraction.py | 6 +- tests/conftest.py | 6 ++ tests/extraction_test.py | 9 +-- tests/pandoc_test.py | 30 +++++---- tests/tesseract_test.py | 47 ++++++++------ 17 files changed, 191 insertions(+), 134 deletions(-) create mode 100644 Dockerfile create mode 100644 docker-compose.yaml diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 5a8718f..cb541c9 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -59,7 +59,7 @@ jobs: #os: [ ubuntu-latest, macOS-latest, windows-latest ] #python: [" 3.9", "3.10", "3.11", "3.12", "3.13" ] runs-on: ${{ matrix.os }} - timeout-minutes: 15 + timeout-minutes: 20 steps: - name: Checkout uses: actions/checkout@v4 @@ -122,4 +122,4 @@ jobs: pandoc --version - name: Run Tests - run: uv run pytest tests -vvv -n auto --dist=loadfile --timeout 30 + run: uv run pytest -n auto diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b0a477c..e701c90 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -43,7 +43,7 @@ repos: additional_dependencies: - tomli - repo: https://github.com/jsh9/pydoclint - rev: 0.6.0 + rev: 0.6.2 hooks: - id: pydoclint args: diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..36c1de5 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,18 @@ +FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim AS base +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + g++ \ + libpq-dev \ + pandoc \ + tesseract-ocr \ + tesseract-ocr-deu \ + && apt-get clean && rm -rf /var/lib/apt/lists/* + +FROM base AS install +WORKDIR /app/ +COPY pyproject.toml uv.lock ./ +RUN --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,source=uv.lock,target=uv.lock \ + --mount=type=bind,source=pyproject.toml,target=pyproject.toml \ + uv sync --verbose --frozen +ENV PATH="/app/.venv/bin:$PATH" diff --git a/README.md b/README.md index b3e3231..966e51a 100644 --- a/README.md +++ b/README.md @@ -26,8 +26,8 @@ pip install kreuzberg Kreuzberg requires two system level dependencies: -- [Pandoc](https://pandoc.org/installing.html) - For document format conversion -- [Tesseract OCR](https://tesseract-ocr.github.io/) - For image and PDF OCR +- [Pandoc](https://pandoc.org/installing.html) - For document format conversion. Minimum required version is Pandoc 2. +- [Tesseract OCR](https://tesseract-ocr.github.io/) - For image and PDF OCR. Minimum required version is Tesseract 4. You can install these with: @@ -40,7 +40,7 @@ sudo apt-get install pandoc tesseract-ocr #### MacOS ```shell -# MacOS +# brew install tesseract pandoc ``` @@ -54,6 +54,7 @@ Notes: - in most distributions the tesseract-ocr package is split into multiple packages, you may need to install any language models you need aside from English separately. - please consult the official documentation for these libraries for the most up-to-date installation instructions for your platform. +- th ## Architecture diff --git a/docker-compose.yaml b/docker-compose.yaml new file mode 100644 index 0000000..e30fc5d --- /dev/null +++ b/docker-compose.yaml @@ -0,0 +1,11 @@ +services: + kreuzberg: + build: + context: . + dockerfile: Dockerfile + target: install + ports: + - "8000:8000" + volumes: + - ./kreuzberg:/src/kreuzberg:cached + - ./tests:/src/tests:cached diff --git a/kreuzberg/_constants.py b/kreuzberg/_constants.py index fc9dab0..5b7f0c1 100644 --- a/kreuzberg/_constants.py +++ b/kreuzberg/_constants.py @@ -3,4 +3,4 @@ from multiprocessing import cpu_count from typing import Final -DEFAULT_MAX_PROCESSES: Final[int] = max(cpu_count() // 2, 1) +DEFAULT_MAX_PROCESSES: Final[int] = cpu_count() diff --git a/kreuzberg/_html.py b/kreuzberg/_html.py index 69e2a0d..f28c204 100644 --- a/kreuzberg/_html.py +++ b/kreuzberg/_html.py @@ -8,7 +8,6 @@ from kreuzberg import ExtractionResult from kreuzberg._mime_types import MARKDOWN_MIME_TYPE from kreuzberg._string import normalize_spaces, safe_decode -from kreuzberg._sync import run_sync if TYPE_CHECKING: from pathlib import Path @@ -28,5 +27,5 @@ async def extract_html_string(file_path_or_contents: Path | bytes) -> Extraction if isinstance(file_path_or_contents, bytes) else await AsyncPath(file_path_or_contents).read_text() ) - result = await run_sync(html_to_markdown.convert_to_markdown, content) + result = html_to_markdown.convert_to_markdown(content) return ExtractionResult(content=normalize_spaces(result), mime_type=MARKDOWN_MIME_TYPE, metadata={}) diff --git a/kreuzberg/_pandoc.py b/kreuzberg/_pandoc.py index 39ab2ae..0cb2fe7 100644 --- a/kreuzberg/_pandoc.py +++ b/kreuzberg/_pandoc.py @@ -6,13 +6,12 @@ from json import JSONDecodeError, loads from typing import TYPE_CHECKING, Any, Final, Literal, cast -from anyio import CapacityLimiter, create_task_group, to_process +from anyio import CapacityLimiter, create_task_group, run_process, to_process from anyio import Path as AsyncPath from kreuzberg._constants import DEFAULT_MAX_PROCESSES from kreuzberg._mime_types import MARKDOWN_MIME_TYPE from kreuzberg._string import normalize_spaces -from kreuzberg._sync import run_sync from kreuzberg._tmp import create_temp_file from kreuzberg._types import ExtractionResult, Metadata from kreuzberg.exceptions import MissingDependencyError, ParsingError, ValidationError @@ -251,10 +250,10 @@ async def _validate_pandoc_version() -> None: return command = ["pandoc", "--version"] - result = await run_sync(subprocess.run, command, capture_output=True) + result = await run_process(command) version = result.stdout.decode().split("\n")[0].split()[1] - if not version.startswith("3."): - raise MissingDependencyError("Pandoc version 3 or above is required.") + if not version.startswith("2."): + raise MissingDependencyError("Pandoc version 2 or above is required.") version_ref["checked"] = True diff --git a/kreuzberg/_pdf.py b/kreuzberg/_pdf.py index 6c51e04..c698f06 100644 --- a/kreuzberg/_pdf.py +++ b/kreuzberg/_pdf.py @@ -67,7 +67,7 @@ async def _convert_pdf_to_images(input_file: Path) -> list[Image]: document: pypdfium2.PdfDocument | None = None try: document = await run_sync(pypdfium2.PdfDocument, str(input_file)) - return [page.render(scale=2.0).to_pil() for page in cast(pypdfium2.PdfDocument, document)] + return [page.render(scale=4.25).to_pil() for page in cast(pypdfium2.PdfDocument, document)] except pypdfium2.PdfiumError as e: raise ParsingError( "Could not convert PDF to images", context={"file_path": str(input_file), "error": str(e)} diff --git a/kreuzberg/_sync.py b/kreuzberg/_sync.py index b0f0a66..4fc9a7a 100644 --- a/kreuzberg/_sync.py +++ b/kreuzberg/_sync.py @@ -1,9 +1,11 @@ from __future__ import annotations import sys +from collections.abc import Awaitable from functools import partial from typing import TYPE_CHECKING, TypeVar, cast +from anyio import create_task_group from anyio.to_thread import run_sync as any_io_run_sync if TYPE_CHECKING: # pragma: no cover @@ -31,3 +33,43 @@ async def run_sync(sync_fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) - """ handler = partial(sync_fn, **kwargs) return cast(T, await any_io_run_sync(handler, *args, abandon_on_cancel=True)) # pyright: ignore [reportCallIssue] + + +async def run_taskgroup(*coroutines: Callable[[], Awaitable[T]]) -> list[T]: + """Run a list of coroutines concurrently. + + Args: + coroutines: The list of coroutines to run. + + Returns: + The results of the coroutines. + """ + results = cast(list[T], [None] * len(coroutines)) + + async def run_task(index: int, task: Callable[[], Awaitable[T]]) -> None: + results[index] = await task() + + async with create_task_group() as tg: + for i, coro in enumerate(coroutines): + tg.start_soon(run_task, i, coro) + + return results + + +async def run_taskgroup_batched(*coroutines: Callable[[], Awaitable[T]], batch_size: int) -> list[T]: + """Run a list of coroutines concurrently in batches. + + Args: + coroutines: The list of coroutines to run. + batch_size: The size of each batch. + + Returns: + The results of the coroutines. + """ + results: list[T] = [] + + for i in range(0, len(coroutines), batch_size): + batch = coroutines[i : i + batch_size] + results.extend(await run_taskgroup(*batch)) + + return results diff --git a/kreuzberg/_tesseract.py b/kreuzberg/_tesseract.py index 5ad91c8..c76be5b 100644 --- a/kreuzberg/_tesseract.py +++ b/kreuzberg/_tesseract.py @@ -1,24 +1,20 @@ from __future__ import annotations import re -import subprocess import sys from enum import Enum from functools import partial -from io import BytesIO from os import PathLike -from typing import Final, TypeVar, Union, cast +from typing import Any, Final, TypeVar, Union -from anyio import CapacityLimiter, create_task_group, to_process from anyio import Path as AsyncPath -from PIL import ImageOps -from PIL.Image import Image, Resampling -from PIL.Image import open as open_image +from anyio import run_process +from PIL.Image import Image from kreuzberg._constants import DEFAULT_MAX_PROCESSES from kreuzberg._mime_types import PLAIN_TEXT_MIME_TYPE from kreuzberg._string import normalize_spaces -from kreuzberg._sync import run_sync +from kreuzberg._sync import run_sync, run_taskgroup_batched from kreuzberg._tmp import create_temp_file from kreuzberg._types import ExtractionResult from kreuzberg.exceptions import MissingDependencyError, OCRError, ParsingError @@ -28,12 +24,6 @@ MINIMAL_SUPPORTED_TESSERACT_VERSION: Final[int] = 5 - -DEFAULT_DPI: Final[int] = 72 -TARGET_DPI: Final[int] = 300 -BINARIZATION_THRESHOLD: Final[int] = 0 -BINARIZATION_MAX_VALUE: Final[int] = 255 - version_ref = {"checked": False} T = TypeVar("T", bound=Union[Image, PathLike[str], str]) @@ -66,33 +56,6 @@ class PSMMode(Enum): """Treat the image as a single character.""" -def resize_for_ocr(image: Image) -> Image: - """Resize the image to ensure sufficient DPI for OCR. - - Args: - image: Input Pillow image. - - Returns: - The resized image. - """ - width, height = image.size - scale_factor = TARGET_DPI / DEFAULT_DPI - new_size = (int(width * scale_factor), int(height * scale_factor)) - return image.resize(new_size, Resampling.LANCZOS) - - -def preprocess_image(image: Image) -> Image: - """Preprocess the input image for OCR. - - Args: - image: Input Pillow image. - - Returns: - The preprocessed version of the input image. - """ - return resize_for_ocr(ImageOps.grayscale(image)) - - async def validate_tesseract_version() -> None: """Validate that Tesseract is installed and is version 5 or above. @@ -104,14 +67,16 @@ async def validate_tesseract_version() -> None: return command = ["tesseract", "--version"] - result = await run_sync(subprocess.run, command, capture_output=True) - version_match = re.search(r"tesseract\s+v?(\d+)", result.stdout.decode()) + result = await run_process(command) + version_match = re.search(r"tesseract\s+v?(\d+)\.\d+\.\d+", result.stdout.decode()) if not version_match or int(version_match.group(1)) < MINIMAL_SUPPORTED_TESSERACT_VERSION: raise MissingDependencyError("Tesseract version 5 or above is required.") version_ref["checked"] = True except FileNotFoundError as e: - raise MissingDependencyError("Tesseract is not installed.") from e + raise MissingDependencyError( + "Tesseract is not installed or not in path. Please install tesseract 5 and above on your system." + ) from e async def process_file( @@ -119,7 +84,6 @@ async def process_file( *, language: str, psm: PSMMode, - max_processes: int = DEFAULT_MAX_PROCESSES, ) -> ExtractionResult: """Process a single image file using Tesseract OCR. @@ -127,7 +91,6 @@ async def process_file( input_file: The path to the image file to process. language: The language code for OCR. psm: Page segmentation mode. - max_processes: Maximum number of concurrent processes. Defaults to CPU count / 2 (minimum 1). Raises: OCRError: If OCR fails to extract text from the image. @@ -138,6 +101,7 @@ async def process_file( output_path, unlink = await create_temp_file(".txt") try: output_base = str(output_path).replace(".txt", "") + command = [ "tesseract", str(input_file), @@ -146,14 +110,33 @@ async def process_file( language, "--psm", str(psm.value), + "--oem", + "1", + "--loglevel", + "OFF", + "-c", + "thresholding_method=1", + "-c", + "tessedit_enable_dict_correction=1", + "-c", + "language_model_ngram_on=1", + "-c", + "textord_space_size_is_variable=1", + "-c", + "classify_use_pre_adapted_templates=1", + "-c", + "tessedit_dont_blkrej_good_wds=1", + "-c", + "tessedit_dont_rowrej_good_wds=1", + "-c", + "tessedit_use_primary_params_model=1", ] - result = await to_process.run_sync( - partial(subprocess.run, capture_output=True), - command, - limiter=CapacityLimiter(max_processes), - cancellable=True, - ) + env: dict[str, Any] | None = None + if sys.platform.startswith("linux"): + env = {"OMP_THREAD_LIMIT": "1"} + + result = await run_process(command, env=env) if not result.returncode == 0: raise OCRError( @@ -164,7 +147,7 @@ async def process_file( output = await AsyncPath(output_path).read_text("utf-8") return ExtractionResult(content=normalize_spaces(output), mime_type=PLAIN_TEXT_MIME_TYPE, metadata={}) except (RuntimeError, OSError) as e: - raise OCRError("Failed to OCR using tesseract") from e + raise OCRError(f"Failed to OCR using tesseract: {e}") from e finally: await unlink() @@ -174,7 +157,6 @@ async def process_image( *, language: str, psm: PSMMode, - max_processes: int = DEFAULT_MAX_PROCESSES, ) -> ExtractionResult: """Process a single Pillow Image using Tesseract OCR. @@ -182,15 +164,13 @@ async def process_image( image: The Pillow Image to process. language: The language code for OCR. psm: Page segmentation mode. - max_processes: Maximum number of concurrent processes. Defaults to CPU count / 2 (minimum 1). Returns: ExtractionResult: The extracted text from the image. """ - binary_image = preprocess_image(image) image_path, unlink = await create_temp_file(".png") - await run_sync(binary_image.save, str(image_path), format="PNG") - result = await process_file(image_path, language=language, psm=psm, max_processes=max_processes) + await run_sync(image.save, str(image_path), format="PNG") + result = await process_file(image_path, language=language, psm=psm) await unlink() return result @@ -200,7 +180,6 @@ async def process_image_with_tesseract( *, language: str = "eng", psm: PSMMode = PSMMode.AUTO, - max_processes: int = DEFAULT_MAX_PROCESSES, ) -> ExtractionResult: """Run Tesseract OCR asynchronously on a single Pillow Image or a list of Pillow Images. @@ -208,7 +187,6 @@ async def process_image_with_tesseract( image: A single Pillow Image, a pathlike or a string or a list of Pillow Images to process. language: The language code for OCR (default: "eng"). psm: Page segmentation mode (default: PSMMode.AUTO). - max_processes: Maximum number of concurrent processes. Defaults to CPU count / 2 (minimum 1). Raises: ValueError: If the input is not a Pillow Image or a list of Pillow Images. @@ -219,12 +197,10 @@ async def process_image_with_tesseract( await validate_tesseract_version() if isinstance(image, Image): - return await process_image(image, language=language, psm=psm, max_processes=max_processes) + return await process_image(image, language=language, psm=psm) if isinstance(image, (PathLike, str)): - contents = BytesIO(await AsyncPath(image).read_bytes()) - image = await run_sync(open_image, contents) - return await process_image(image, language=language, psm=psm, max_processes=max_processes) + return await process_file(image, language=language, psm=psm) raise ValueError("Input must be one of: str, Pathlike or Pillow Image.") @@ -242,7 +218,7 @@ async def batch_process_images( images: A list of Pillow Images, paths or strings to process. language: The language code for OCR (default: "eng"). psm: Page segmentation mode (default: PSMMode.AUTO). - max_processes: Maximum number of concurrent processes. Defaults to CPU count / 2 (minimum 1). + max_processes: Maximum number of concurrent processes (default: CPU count / 2). Raises: ParsingError: If OCR fails to extract text from any of the images. @@ -251,17 +227,12 @@ async def batch_process_images( List of ExtractionResult objects, one per input image. """ await validate_tesseract_version() - results = cast(list[ExtractionResult], list(range(len(images)))) - - async def _process_image(index: int, image: T) -> None: - results[index] = await process_image_with_tesseract( - image, language=language, psm=psm, max_processes=max_processes - ) - try: - async with create_task_group() as tg: - for i, image in enumerate(images): - tg.start_soon(_process_image, i, image) - return results + return await run_taskgroup_batched( + *[partial(process_image_with_tesseract, image, language=language, psm=psm) for image in images], + batch_size=max_processes, + ) except ExceptionGroup as eg: - raise ParsingError("Failed to process images with Tesseract") from eg + raise ParsingError( + "Failed to process images with Tesseract", context={"errors": ",".join([str(e) for e in eg.exceptions])} + ) from eg diff --git a/kreuzberg/_xlsx.py b/kreuzberg/_xlsx.py index f54f588..4fedca6 100644 --- a/kreuzberg/_xlsx.py +++ b/kreuzberg/_xlsx.py @@ -38,7 +38,7 @@ async def extract_xlsx_file(input_file: Path) -> ExtractionResult: async def convert_sheet_to_text(sheet_name: str) -> None: nonlocal results - values = await run_sync(workbook.get_sheet_by_name(sheet_name).to_python) + values = workbook.get_sheet_by_name(sheet_name).to_python() csv_buffer = StringIO() writer = csv.writer(csv_buffer) diff --git a/kreuzberg/extraction.py b/kreuzberg/extraction.py index 2fa6cc7..d053db4 100644 --- a/kreuzberg/extraction.py +++ b/kreuzberg/extraction.py @@ -87,9 +87,7 @@ async def extract_bytes( return await extract_xlsx_content(content) if mime_type in IMAGE_MIME_TYPES or any(mime_type.startswith(value) for value in IMAGE_MIME_TYPES): - return await process_image_with_tesseract( - open_image(BytesIO(content)), max_processes=max_processes, psm=psm, language=language - ) + return await process_image_with_tesseract(open_image(BytesIO(content)), psm=psm, language=language) if mime_type in PANDOC_SUPPORTED_MIME_TYPES or any( mime_type.startswith(value) for value in PANDOC_SUPPORTED_MIME_TYPES @@ -150,7 +148,7 @@ async def extract_file( return await extract_xlsx_file(Path(input_file)) if mime_type in IMAGE_MIME_TYPES or any(mime_type.startswith(value) for value in IMAGE_MIME_TYPES): - return await process_image_with_tesseract(input_file, max_processes=max_processes, psm=psm, language=language) + return await process_image_with_tesseract(input_file, psm=psm, language=language) if mime_type in PANDOC_SUPPORTED_MIME_TYPES or any( mime_type.startswith(value) for value in PANDOC_SUPPORTED_MIME_TYPES diff --git a/tests/conftest.py b/tests/conftest.py index d3acc1b..df4ba19 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,12 @@ import pytest +@pytest.fixture +def anyio_backend() -> str: + """override anyio to test only against asyncio""" + return "asyncio" + + @pytest.fixture(scope="session") def searchable_pdf() -> Path: return Path(__file__).parent / "source" / "searchable.pdf" diff --git a/tests/extraction_test.py b/tests/extraction_test.py index e6611c9..c5ba9c3 100644 --- a/tests/extraction_test.py +++ b/tests/extraction_test.py @@ -44,7 +44,7 @@ async def test_extract_bytes_force_ocr_pdf(non_ascii_pdf: Path) -> None: content = non_ascii_pdf.read_bytes() result = await extract_bytes(content, PDF_MIME_TYPE, force_ocr=True, language="deu") assert_extraction_result(result, mime_type=PLAIN_TEXT_MIME_TYPE) - assert result.content.startswith("AMTSBLATT") + assert "Spatenstich für neue Hackschnitzelheizung Nachhaltige Wärmeversorgung" in result.content @pytest.mark.anyio @@ -112,13 +112,6 @@ async def test_extract_file_pdf(scanned_pdf: Path) -> None: assert_extraction_result(result, mime_type=PLAIN_TEXT_MIME_TYPE) -@pytest.mark.anyio -async def test_extract_file_force_ocr_pdf(non_ascii_pdf: Path) -> None: - result = await extract_file(non_ascii_pdf, PDF_MIME_TYPE, force_ocr=True) - assert_extraction_result(result, mime_type=PLAIN_TEXT_MIME_TYPE) - assert result.content.startswith("AMTSBLATT") - - @pytest.mark.anyio async def test_extract_file_image(ocr_image: Path) -> None: mime_type = "image/jpeg" diff --git a/tests/pandoc_test.py b/tests/pandoc_test.py index efaf898..d54271c 100644 --- a/tests/pandoc_test.py +++ b/tests/pandoc_test.py @@ -32,13 +32,13 @@ @pytest.fixture def mock_subprocess_run(mocker: MockerFixture) -> Mock: - def run_sync(*args: list[Any], **kwargs: Any) -> Mock: + def run_sync(command: list[str], **kwargs: Any) -> Mock: result = Mock() result.stdout = b"pandoc 3.1.0" result.returncode = 0 result.stderr = b"" - if isinstance(args[0], list) and "--version" in args[0]: + if "--version" in command: return result # Handle error test cases @@ -60,7 +60,7 @@ def run_sync(*args: list[Any], **kwargs: Any) -> Mock: raise RuntimeError("Invalid metadata") # Normal case - output_file = next((str(arg) for arg in args[0] if str(arg).endswith((".md", ".json"))), "") + output_file = next((str(arg) for arg in command if str(arg).endswith((".md", ".json"))), "") if output_file: content = ( json.dumps(SAMPLE_PANDOC_JSON) if str(output_file).endswith(".json") else "Sample processed content" @@ -68,24 +68,30 @@ def run_sync(*args: list[Any], **kwargs: Any) -> Mock: Path(output_file).write_text(content) return result - # Mock both subprocess.run and anyio.to_process.run_sync - mock = mocker.patch("subprocess.run", side_effect=run_sync) - mocker.patch("anyio.to_process.run_sync", side_effect=lambda func, *args, **kwargs: func(*args, **kwargs)) + # Mock anyio.run_process + mock = mocker.patch("anyio.run_process", side_effect=run_sync) return mock @pytest.fixture def mock_subprocess_run_invalid(mocker: MockerFixture) -> Mock: - mock = mocker.patch("subprocess.run") - mock.return_value.stdout = b"pandoc 2.0.0" - mock.return_value.returncode = 0 + def run_sync(command: list[str], **kwargs: Any) -> Mock: + result = Mock() + result.stdout = b"pandoc 1.0.0" + result.returncode = 0 + result.stderr = b"" + return result + + mock = mocker.patch("anyio.run_process", side_effect=run_sync) return mock @pytest.fixture def mock_subprocess_run_error(mocker: MockerFixture) -> Mock: - mock = mocker.patch("subprocess.run") - mock.side_effect = FileNotFoundError() + def run_sync(command: list[str], **kwargs: Any) -> Mock: + raise FileNotFoundError + + mock = mocker.patch("anyio.run_process", side_effect=run_sync) return mock @@ -97,7 +103,7 @@ def reset_version_ref(mocker: MockerFixture) -> None: @pytest.mark.anyio async def test_validate_pandoc_version(mock_subprocess_run: Mock) -> None: await _validate_pandoc_version() - mock_subprocess_run.assert_called_with(["pandoc", "--version"], capture_output=True) + mock_subprocess_run.assert_called_with(["pandoc", "--version"]) @pytest.mark.anyio diff --git a/tests/tesseract_test.py b/tests/tesseract_test.py index bfa1da6..a655276 100644 --- a/tests/tesseract_test.py +++ b/tests/tesseract_test.py @@ -26,62 +26,75 @@ @pytest.fixture def mock_subprocess_run(mocker: MockerFixture) -> Mock: - def run_sync(*args: list[Any], **kwargs: dict[str, Any]) -> Mock: + def run_sync(command: list[str], **kwargs: Any) -> Mock: result = Mock() result.stdout = b"tesseract 5.0.0" result.returncode = 0 result.stderr = b"" - if isinstance(args[0], list) and "--version" in args[0]: + if "--version" in command and command[0].endswith("tesseract"): return result # Handle error test cases if "test_process_file_error" in str(kwargs.get("cwd")): result.returncode = 1 result.stderr = b"Error processing file" - raise RuntimeError("Error processing file") + raise OCRError("Error processing file") if "test_process_file_runtime_error" in str(kwargs.get("cwd")): raise RuntimeError("Command failed") # Normal case - if isinstance(args[0], list) and len(args[0]) >= 3: - output_file = args[0][2] + if len(command) >= 3 and command[0].endswith("tesseract"): + output_file = command[2] if "test_process_image_with_tesseract_invalid_input" in str(kwargs.get("cwd")): result.returncode = 1 result.stderr = b"Error processing file" - raise RuntimeError("Error processing file") + raise OCRError("Error processing file") + + # Verify required tesseract arguments + if not all(arg in command for arg in ["--oem", "1", "--loglevel", "OFF", "-c", "thresholding_method=1"]): + result.returncode = 1 + result.stderr = b"Missing required tesseract arguments" + return result + Path(f"{output_file}.txt").write_text("Sample OCR text") result.returncode = 0 return result return result - # Mock both subprocess.run and anyio.to_process.run_sync - mock = mocker.patch("subprocess.run", side_effect=run_sync) - mocker.patch("anyio.to_process.run_sync", side_effect=lambda func, *args, **kwargs: func(*args, **kwargs)) + # Mock run_process + mock = mocker.patch("kreuzberg._tesseract.run_process", side_effect=run_sync) return mock @pytest.fixture def mock_subprocess_run_invalid(mocker: MockerFixture) -> Mock: - mock = mocker.patch("subprocess.run") - mock.return_value.stdout = b"tesseract 4.0.0" - mock.return_value.returncode = 0 + def run_sync(command: list[str], **kwargs: Any) -> Mock: + result = Mock() + result.stdout = b"tesseract 4.0.0" + result.returncode = 0 + result.stderr = b"" + return result + + mock = mocker.patch("kreuzberg._tesseract.run_process", side_effect=run_sync) return mock @pytest.fixture def mock_subprocess_run_error(mocker: MockerFixture) -> Mock: - mock = mocker.patch("subprocess.run") - mock.side_effect = FileNotFoundError() + def run_sync(command: list[str], **kwargs: Any) -> Mock: + raise FileNotFoundError + + mock = mocker.patch("kreuzberg._tesseract.run_process", side_effect=run_sync) return mock @pytest.mark.anyio async def test_validate_tesseract_version(mock_subprocess_run: Mock) -> None: await validate_tesseract_version() - mock_subprocess_run.assert_called_with(["tesseract", "--version"], capture_output=True) + mock_subprocess_run.assert_called_with(["tesseract", "--version"]) @pytest.fixture(autouse=True) @@ -90,13 +103,13 @@ def reset_version_ref(mocker: MockerFixture) -> None: @pytest.mark.anyio -async def test_validate_tesseract_version_invalid(mock_subprocess_run_invalid: Mock) -> None: +async def test_validate_tesseract_version_invalid(mock_subprocess_run_invalid: Mock, reset_version_ref: None) -> None: with pytest.raises(MissingDependencyError, match="Tesseract version 5 or above is required"): await validate_tesseract_version() @pytest.mark.anyio -async def test_validate_tesseract_version_missing(mock_subprocess_run_error: Mock) -> None: +async def test_validate_tesseract_version_missing(mock_subprocess_run_error: Mock, reset_version_ref: None) -> None: with pytest.raises(MissingDependencyError, match="Tesseract is not installed"): await validate_tesseract_version()