diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml
index f3b5e75..31977d6 100644
--- a/.github/workflows/benchmarks.yml
+++ b/.github/workflows/benchmarks.yml
@@ -25,6 +25,10 @@ jobs:
         run: |
           poetry run python benchmark/detection.py --max_rows 2
           poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/det_bench/results.json --bench_type detection
+      - name: Run inline detection benchmark
+        run: |
+          poetry run python benchmark/inline_detection.py --max_rows 5
+          poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/inline_math_bench/results.json --bench_type inline_detection
       - name: Run recognition benchmark test
         run: |
           poetry run python benchmark/recognition.py --max_rows 2
diff --git a/benchmark/inline_detection.py b/benchmark/inline_detection.py
new file mode 100644
index 0000000..5fc42d0
--- /dev/null
+++ b/benchmark/inline_detection.py
@@ -0,0 +1,107 @@
+import collections
+import copy
+import json
+from pathlib import Path
+
+import click
+
+from benchmark.utils.metrics import precision_recall
+from surya.debug.draw import draw_polys_on_image
+from surya.input.processing import convert_if_not_rgb
+from surya.common.util import rescale_bbox
+from surya.settings import settings
+from surya.detection import DetectionPredictor, InlineDetectionPredictor
+
+import os
+import time
+from tabulate import tabulate
+import datasets
+
+
+@click.command(help="Benchmark inline math detection model.")
+@click.option("--results_dir", type=str, help="Directory to write benchmark results to.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
+@click.option("--max_rows", type=int, help="Maximum number of benchmark images to run detection on.", default=100)
+@click.option("--debug", is_flag=True, help="Enable debug mode.", default=False)
+def main(results_dir: str, max_rows: int, debug: bool):
+    det_predictor = DetectionPredictor()
+    inline_det_predictor = InlineDetectionPredictor()
+
+    dataset = datasets.load_dataset(settings.INLINE_MATH_BENCH_DATASET_NAME, split=f"train[:{max_rows}]")
+    images = list(dataset["image"])
+    images = convert_if_not_rgb(images)
+    correct_boxes = []
+    for i, boxes in enumerate(dataset["bboxes"]):
+        img_size = images[i].size
+        # Rescale from normalized 0-1 vals to image size
+        correct_boxes.append([rescale_bbox(b, (1, 1), img_size) for b in boxes])
+
+    if settings.DETECTOR_STATIC_CACHE:
+        # Run through one batch to compile the model
+        det_predictor(images[:1])
+        inline_det_predictor(images[:1], [[]])
+
+    start = time.time()
+    det_results = det_predictor(images)
+
+    # Reformat text boxes to inline math input format
+    text_boxes = []
+    for result in det_results:
+        text_boxes.append([b.bbox for b in result.bboxes])
+
+    inline_results = inline_det_predictor(images, text_boxes)
+    surya_time = time.time() - start
+
+    result_path = Path(results_dir) / "inline_math_bench"
+    result_path.mkdir(parents=True, exist_ok=True)
+
+    page_metrics = collections.OrderedDict()
+    for idx, (sb, cb) in enumerate(zip(inline_results, correct_boxes)):
+        surya_boxes = [s.bbox for s in sb.bboxes]
+        surya_polys = [s.polygon for s in sb.bboxes]
+
+        surya_metrics = precision_recall(surya_boxes, cb)
+
+        page_metrics[idx] = {
+            "surya": surya_metrics,
+        }
+
+        if debug:
+            bbox_image = draw_polys_on_image(surya_polys, copy.deepcopy(images[idx]))
+            bbox_image.save(result_path / f"{idx}_bbox.png")
+
+    mean_metrics = {}
+    metric_types = sorted(page_metrics[0]["surya"].keys())
+    models = ["surya"]
+
+    for k in models:
+        for m in metric_types:
+            metric = []
+            for page in page_metrics:
+                metric.append(page_metrics[page][k][m])
+            if k not in mean_metrics:
+                mean_metrics[k] = {}
+            mean_metrics[k][m] = sum(metric) / len(metric)
+
+    out_data = {
+        "times": {
+            "surya": surya_time,
+        },
+        "metrics": mean_metrics,
+        "page_metrics": page_metrics
+    }
+
+    with open(result_path / "results.json", "w+", encoding="utf-8") as f:
+        json.dump(out_data, f, indent=4)
+
+    table_headers = ["Model", "Time (s)", "Time per page (s)"] + metric_types
+    table_data = [
+        ["surya", surya_time, surya_time / len(images)] + [mean_metrics["surya"][m] for m in metric_types],
+    ]
+
+    print(tabulate(table_data, headers=table_headers, tablefmt="github"))
+    print("Precision and recall are over the mutual coverage of the detected boxes and the ground truth boxes at a .5 threshold. There is a precision penalty for multiple boxes overlapping reference lines.")
+    print(f"Wrote results to {result_path}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/benchmark/utils/verify_benchmark_scores.py b/benchmark/utils/verify_benchmark_scores.py
index 9b17b2d..f6b9084 100644
--- a/benchmark/utils/verify_benchmark_scores.py
+++ b/benchmark/utils/verify_benchmark_scores.py
@@ -18,6 +18,11 @@ def verify_det(data):
         raise ValueError("Scores do not meet the required threshold")
 
 
+def verify_inline_det(data):
+    scores = data["metrics"]["surya"]
+    if scores["precision"] <= 0.5 or scores["recall"] <= 0.5:
+        raise ValueError("Scores do not meet the required threshold")
+
 def verify_rec(data):
     scores = data["surya"]
     if scores["avg_score"] <= 0.9:
@@ -62,6 +67,8 @@ def main(file_path, bench_type):
         verify_table_rec(data)
     elif bench_type == "texify":
         verify_texify(data)
+    elif bench_type == "inline_detection":
+        verify_inline_det(data)
     else:
         raise ValueError("Invalid benchmark type")
diff --git a/surya/settings.py b/surya/settings.py
index 3ae313c..91bdd0b 100644
--- a/surya/settings.py
+++ b/surya/settings.py
@@ -58,6 +58,7 @@ def TORCH_DEVICE_MODEL(self) -> str:
     # Inline math detection
     INLINE_MATH_MODEL_CHECKPOINT: str = "datalab-to/inline_math_det0@75aafc7aa3d494ece6496d28038c91f0d2518a43"
     INLINE_MATH_THRESHOLD: float = 0.9 #Threshold for inline math detection (above this is considered inline-math)
+    INLINE_MATH_BENCH_DATASET_NAME: str = "datalab-to/inline_detection_bench"
 
     # Text recognition
     RECOGNITION_MODEL_CHECKPOINT: str = "vikp/surya_rec2@6611509b2c3a32c141703ce19adc899d9d0abf41"
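To check the new benchmark locally, the same two commands the CI step runs can be invoked from the repo root (this assumes the default RESULT_DIR, so output lands under results/benchmark/inline_math_bench/ as in the workflow; --max_rows can be raised toward the script's default of 100 for a more representative score):

    poetry run python benchmark/inline_detection.py --max_rows 5
    poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/inline_math_bench/results.json --bench_type inline_detection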