Commit 43bf7b2

chore: fix sync issues

Goldziher committed Feb 19, 2025
1 parent 967cb11 commit 43bf7b2
Showing 14 changed files with 371 additions and 388 deletions.
2 changes: 2 additions & 0 deletions kreuzberg/_constants.py
@@ -4,3 +4,5 @@
 from typing import Final
 
 DEFAULT_MAX_PROCESSES: Final[int] = cpu_count()
+MINIMAL_SUPPORTED_TESSERACT_VERSION: Final[int] = 5
+MINIMAL_SUPPORTED_PANDOC_VERSION: Final[int] = 2
111 changes: 36 additions & 75 deletions kreuzberg/_pandoc.py
@@ -1,20 +1,22 @@
 from __future__ import annotations
 
-import subprocess
+import re
 import sys
 from functools import partial
 from json import JSONDecodeError, loads
 from typing import TYPE_CHECKING, Any, Final, Literal, cast
 
-from anyio import CapacityLimiter, create_task_group, run_process, to_process
 from anyio import Path as AsyncPath
+from anyio import run_process
 
-from kreuzberg._constants import DEFAULT_MAX_PROCESSES
+from kreuzberg import ValidationError
+from kreuzberg._constants import MINIMAL_SUPPORTED_PANDOC_VERSION
 from kreuzberg._mime_types import MARKDOWN_MIME_TYPE
 from kreuzberg._string import normalize_spaces
+from kreuzberg._sync import run_taskgroup
 from kreuzberg._tmp import create_temp_file
 from kreuzberg._types import ExtractionResult, Metadata
-from kreuzberg.exceptions import MissingDependencyError, ParsingError, ValidationError
+from kreuzberg.exceptions import MissingDependencyError, ParsingError
 
 if TYPE_CHECKING:  # pragma: no cover
     from collections.abc import Mapping
@@ -23,10 +25,8 @@
 if sys.version_info < (3, 11):  # pragma: no cover
     from exceptiongroup import ExceptionGroup  # type: ignore[import-not-found]
 
-
 version_ref: Final[dict[str, bool]] = {"checked": False}
 
-
 # Block-level node types in Pandoc AST
 BLOCK_HEADER: Final = "Header"  # Header with level, attributes and inline content
 BLOCK_PARA: Final = "Para"  # Paragraph containing inline content
@@ -228,20 +228,15 @@ def _extract_metadata(raw_meta: dict[str, Any]) -> Metadata:
 
 
 def _get_pandoc_type_from_mime_type(mime_type: str) -> str:
-    if mime_type not in MIMETYPE_TO_PANDOC_TYPE_MAPPING or not any(
-        mime_type.startswith(value) for value in MIMETYPE_TO_PANDOC_TYPE_MAPPING
-    ):
-        raise ValidationError(
-            f"Unsupported mime type: {mime_type}",
-            context={
-                "mime_type": mime_type,
-                "supported_mimetypes": ",".join(sorted(MIMETYPE_TO_PANDOC_TYPE_MAPPING)),
-            },
-        )
-
-    return MIMETYPE_TO_PANDOC_TYPE_MAPPING.get(mime_type) or next(
-        MIMETYPE_TO_PANDOC_TYPE_MAPPING[k] for k in MIMETYPE_TO_PANDOC_TYPE_MAPPING if k.startswith(mime_type)
-    )
+    if pandoc_type := (MIMETYPE_TO_PANDOC_TYPE_MAPPING.get(mime_type, "")):
+        return pandoc_type
+
+    if any(k.startswith(mime_type) for k in MIMETYPE_TO_PANDOC_TYPE_MAPPING):
+        return next(
+            MIMETYPE_TO_PANDOC_TYPE_MAPPING[k] for k in MIMETYPE_TO_PANDOC_TYPE_MAPPING if k.startswith(mime_type)
+        )
+
+    raise ValidationError(f"Unsupported mime type: {mime_type}")
 
 
 async def _validate_pandoc_version() -> None:
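The rewritten lookup tries an exact key first, then falls back to prefix matching before rejecting. A minimal standalone sketch of the same three-step logic (the toy mapping is hypothetical, not the module's real `MIMETYPE_TO_PANDOC_TYPE_MAPPING`, and it raises a plain `ValueError` instead of the library's `ValidationError`):

```python
from typing import Final

# Hypothetical subset of a mime-type -> pandoc format mapping.
MAPPING: Final[dict[str, str]] = {
    "application/csl+json": "csljson",
    "application/docbook+xml": "docbook",
}

def lookup(mime_type: str) -> str:
    # 1. Exact match via the walrus operator (a single dict lookup).
    if pandoc_type := MAPPING.get(mime_type, ""):
        return pandoc_type
    # 2. Prefix fallback: the first key extending the query wins.
    if any(k.startswith(mime_type) for k in MAPPING):
        return next(MAPPING[k] for k in MAPPING if k.startswith(mime_type))
    # 3. Anything else is rejected.
    raise ValueError(f"Unsupported mime type: {mime_type}")

print(lookup("application/csl+json"))  # csljson
print(lookup("application/docbook"))   # docbook (prefix match)
```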
@@ -251,20 +246,18 @@ async def _validate_pandoc_version() -> None:
 
         command = ["pandoc", "--version"]
         result = await run_process(command)
-        version = result.stdout.decode().split("\n")[0].split()[1]
-        major_version = int(version.split(".")[0])
-        if major_version < 2:
-            raise MissingDependencyError("Pandoc version 2 or above is required.")
+        version_match = re.search(r"pandoc\s+v?(\d+)\.\d+\.\d+", result.stdout.decode())
+        if not version_match or int(version_match.group(1)) < MINIMAL_SUPPORTED_PANDOC_VERSION:
+            raise MissingDependencyError("Pandoc version 2 or above is required")
 
         version_ref["checked"] = True
 
     except FileNotFoundError as e:
-        raise MissingDependencyError("Pandoc is not installed.") from e
+        raise MissingDependencyError("Pandoc is not installed") from e
 
 
-async def _handle_extract_metadata(
-    input_file: str | PathLike[str], *, mime_type: str, max_processes: int = DEFAULT_MAX_PROCESSES
-) -> Metadata:
+async def _handle_extract_metadata(input_file: str | PathLike[str], *, mime_type: str) -> Metadata:
     pandoc_type = _get_pandoc_type_from_mime_type(mime_type)
     metadata_file, unlink = await create_temp_file(".json")
     try:
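The regex-based check is more robust than splitting on whitespace, since it tolerates both `pandoc 3.1.11` and `pandoc v2.19.2` style banners. A quick self-contained sketch of the matching behavior (the sample banner strings are made up):

```python
import re

MINIMAL_SUPPORTED_PANDOC_VERSION = 2

def is_supported(banner: str) -> bool:
    # Same pattern as above: optional "v", capture the major version.
    match = re.search(r"pandoc\s+v?(\d+)\.\d+\.\d+", banner)
    return bool(match) and int(match.group(1)) >= MINIMAL_SUPPORTED_PANDOC_VERSION

assert is_supported("pandoc 3.1.11\nFeatures: +server +lua")
assert is_supported("pandoc v2.19.2")
assert not is_supported("pandoc 1.19.2.1")     # major version too old
assert not is_supported("pandoc, no version")  # no parsable version
```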
@@ -276,15 +269,10 @@ async def _handle_extract_metadata(
             "--standalone",
             "--quiet",
             "--output",
-            metadata_file,
+            str(metadata_file),
         ]
 
-        result = await to_process.run_sync(
-            partial(subprocess.run, capture_output=True),
-            command,
-            cancellable=True,
-            limiter=CapacityLimiter(max_processes),
-        )
+        result = await run_process(command)
 
         if result.returncode != 0:
             raise ParsingError("Failed to extract file data", context={"file": str(input_file), "error": result.stderr})
@@ -297,9 +285,7 @@ async def _handle_extract_metadata(
         await unlink()
 
 
-async def _handle_extract_file(
-    input_file: str | PathLike[str], *, mime_type: str, max_processes: int = DEFAULT_MAX_PROCESSES
-) -> str:
+async def _handle_extract_file(input_file: str | PathLike[str], *, mime_type: str) -> str:
     pandoc_type = _get_pandoc_type_from_mime_type(mime_type)
     output_path, unlink = await create_temp_file(".md")
     try:
@@ -315,12 +301,7 @@
 
         command.extend(["--output", str(output_path)])
 
-        result = await to_process.run_sync(
-            partial(subprocess.run, capture_output=True),
-            command,
-            cancellable=True,
-            limiter=CapacityLimiter(max_processes),
-        )
+        result = await run_process(command)
 
         if result.returncode != 0:
             raise ParsingError("Failed to extract file data", context={"file": str(input_file), "error": result.stderr})
@@ -334,15 +315,12 @@ async def process_file_with_pandoc(
         await unlink()
 
 
-async def process_file_with_pandoc(
-    input_file: str | PathLike[str], *, mime_type: str, max_processes: int = DEFAULT_MAX_PROCESSES
-) -> ExtractionResult:
+async def process_file_with_pandoc(input_file: str | PathLike[str], *, mime_type: str) -> ExtractionResult:
     """Process a single file using Pandoc and convert to markdown.
 
     Args:
         input_file: The path to the file to process.
         mime_type: The mime type of the file.
-        max_processes: Maximum number of concurrent processes. Defaults to CPU count / 2 (minimum 1).
 
     Raises:
         ParsingError: If the file data could not be extracted.
@@ -354,44 +332,27 @@ async def process_file_with_pandoc(
 
     _get_pandoc_type_from_mime_type(mime_type)
 
-    metadata: Metadata = {}
-    content: str = ""
-
     try:
-        async with create_task_group() as tg:
-
-            async def _get_metadata() -> None:
-                nonlocal metadata
-                metadata = await _handle_extract_metadata(input_file, mime_type=mime_type, max_processes=max_processes)
-
-            async def _get_content() -> None:
-                nonlocal content
-                content = await _handle_extract_file(input_file, mime_type=mime_type, max_processes=max_processes)
+        metadata, content = await run_taskgroup(
+            partial(_handle_extract_metadata, input_file, mime_type=mime_type),
+            partial(_handle_extract_file, input_file, mime_type=mime_type),
+        )
 
-            tg.start_soon(_get_metadata)
-            tg.start_soon(_get_content)
+        return ExtractionResult(
+            content=normalize_spaces(cast(str, content)),
+            metadata=cast(Metadata, metadata),
+            mime_type=MARKDOWN_MIME_TYPE,
+        )
     except ExceptionGroup as eg:
-        raise ParsingError(
-            "Failed to extract file data",
-            context={"file": str(input_file), "errors": ",".join([str(e) for e in eg.exceptions])},
-        ) from eg.exceptions[0]
-
-    return ExtractionResult(
-        content=normalize_spaces(content),
-        metadata=metadata,
-        mime_type=MARKDOWN_MIME_TYPE,
-    )
+        raise ParsingError("Failed to process file", context={"file": str(input_file), "errors": eg.exceptions}) from eg
 
 
-async def process_content_with_pandoc(
-    content: bytes, *, mime_type: str, max_processes: int = DEFAULT_MAX_PROCESSES
-) -> ExtractionResult:
+async def process_content_with_pandoc(content: bytes, *, mime_type: str) -> ExtractionResult:
     """Process content using Pandoc and convert to markdown.
 
     Args:
         content: The content to process.
         mime_type: The mime type of the content.
-        max_processes: Maximum number of concurrent processes. Defaults to CPU count / 2 (minimum 1).
 
     Returns:
         ExtractionResult
@@ -400,7 +361,7 @@ async def process_content_with_pandoc(
     input_file, unlink = await create_temp_file(f".{extension}")
 
     await AsyncPath(input_file).write_bytes(content)
-    result = await process_file_with_pandoc(input_file, mime_type=mime_type, max_processes=max_processes)
+    result = await process_file_with_pandoc(input_file, mime_type=mime_type)
 
     await unlink()
    return result
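With `max_processes` gone from these helpers, the call site shrinks to a file path and mime type. A hedged usage sketch (the path and mime type below are hypothetical, and the import targets the module touched in this commit):

```python
import anyio

from kreuzberg._pandoc import process_file_with_pandoc

async def main() -> None:
    result = await process_file_with_pandoc(
        "sample.epub",  # hypothetical input file
        mime_type="application/epub+zip",
    )
    # ExtractionResult fields per this module: content, metadata, mime_type.
    print(result.mime_type)
    print(result.content[:200])

anyio.run(main)
```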
23 changes: 11 additions & 12 deletions kreuzberg/_sync.py
@@ -1,15 +1,14 @@
 from __future__ import annotations
 
 import sys
-from collections.abc import Awaitable
 from functools import partial
 from typing import TYPE_CHECKING, TypeVar, cast
 
 from anyio import create_task_group
 from anyio.to_thread import run_sync as any_io_run_sync
 
 if TYPE_CHECKING:  # pragma: no cover
-    from collections.abc import Callable
+    from collections.abc import Callable, Coroutine
 
 if sys.version_info >= (3, 10):
     from typing import ParamSpec
@@ -35,41 +34,41 @@ async def run_sync(sync_fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
     return cast(T, await any_io_run_sync(handler, *args, abandon_on_cancel=True))  # pyright: ignore [reportCallIssue]
 
 
-async def run_taskgroup(*coroutines: Callable[[], Awaitable[T]]) -> list[T]:
+async def run_taskgroup(*async_tasks: Callable[[], Coroutine[None, None, T]]) -> list[T]:
     """Run a list of coroutines concurrently.
 
     Args:
-        coroutines: The list of coroutines to run.
+        *async_tasks: The list of coroutines to run.
 
     Returns:
         The results of the coroutines.
     """
-    results = cast(list[T], [None] * len(coroutines))
+    results = cast(list[T], [None] * len(async_tasks))
 
-    async def run_task(index: int, task: Callable[[], Awaitable[T]]) -> None:
+    async def run_task(index: int, task: Callable[[], Coroutine[None, None, T]]) -> None:
         results[index] = await task()
 
     async with create_task_group() as tg:
-        for i, coro in enumerate(coroutines):
-            tg.start_soon(run_task, i, coro)
+        for i, t in enumerate(async_tasks):
+            tg.start_soon(run_task, i, t)
 
     return results
 
 
-async def run_taskgroup_batched(*coroutines: Callable[[], Awaitable[T]], batch_size: int) -> list[T]:
+async def run_taskgroup_batched(*async_tasks: Callable[[], Coroutine[None, None, T]], batch_size: int) -> list[T]:
     """Run a list of coroutines concurrently in batches.
 
     Args:
-        coroutines: The list of coroutines to run.
+        *async_tasks: The list of coroutines to run.
         batch_size: The size of each batch.
 
     Returns:
         The results of the coroutines.
     """
     results: list[T] = []
 
-    for i in range(0, len(coroutines), batch_size):
-        batch = coroutines[i : i + batch_size]
+    for i in range(0, len(async_tasks), batch_size):
+        batch = async_tasks[i : i + batch_size]
         results.extend(await run_taskgroup(*batch))
 
     return results
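The parameters are renamed and their annotation tightened from `Awaitable` to `Coroutine`; either way, both helpers expect zero-argument callables (typically `functools.partial` around an async function), so each coroutine is created lazily inside the task group. A self-contained usage sketch:

```python
import anyio
from functools import partial

from kreuzberg._sync import run_taskgroup, run_taskgroup_batched

async def fetch(name: str) -> str:
    await anyio.sleep(0.01)  # stand-in for real async work
    return name.upper()

async def main() -> None:
    # Results are ordered by submission, not by completion.
    print(await run_taskgroup(partial(fetch, "a"), partial(fetch, "b")))  # ['A', 'B']

    # The batched variant runs at most batch_size tasks at a time.
    tasks = [partial(fetch, str(i)) for i in range(10)]
    print(await run_taskgroup_batched(*tasks, batch_size=4))

anyio.run(main)
```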
10 changes: 3 additions & 7 deletions kreuzberg/_tesseract.py
@@ -5,13 +5,13 @@
 from enum import Enum
 from functools import partial
 from os import PathLike
-from typing import Any, Final, TypeVar, Union
+from typing import Any, TypeVar, Union
 
 from anyio import Path as AsyncPath
 from anyio import run_process
 from PIL.Image import Image
 
-from kreuzberg._constants import DEFAULT_MAX_PROCESSES
+from kreuzberg._constants import DEFAULT_MAX_PROCESSES, MINIMAL_SUPPORTED_TESSERACT_VERSION
 from kreuzberg._mime_types import PLAIN_TEXT_MIME_TYPE
 from kreuzberg._string import normalize_spaces
 from kreuzberg._sync import run_sync, run_taskgroup_batched
@@ -22,8 +22,6 @@
 if sys.version_info < (3, 11):  # pragma: no cover
     from exceptiongroup import ExceptionGroup  # type: ignore[import-not-found]
 
-MINIMAL_SUPPORTED_TESSERACT_VERSION: Final[int] = 5
-
 version_ref = {"checked": False}
 
 T = TypeVar("T", bound=Union[Image, PathLike[str], str])
@@ -233,6 +231,4 @@ async def batch_process_images(
             batch_size=max_processes,
         )
     except ExceptionGroup as eg:
-        raise ParsingError(
-            "Failed to process images with Tesseract", context={"errors": ",".join([str(e) for e in eg.exceptions])}
-        ) from eg
+        raise ParsingError("Failed to process images with Tesseract", context={"errors": eg.exceptions}) from eg