Skip to content

Commit

Permalink
Add inline detection benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Feb 10, 2025
1 parent 3b0f698 commit 4eb67cf
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/benchmarks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ jobs:
run: |
poetry run python benchmark/detection.py --max_rows 2
poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/det_bench/results.json --bench_type detection
- name: Run inline detection benchmarj
run: |
poetry run python benchmark/inline_detection.py --max_rows 5
poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/inline_math_bench/results.json --bench_type inline_detection
- name: Run recognition benchmark test
run: |
poetry run python benchmark/recognition.py --max_rows 2
Expand Down
107 changes: 107 additions & 0 deletions benchmark/inline_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import collections
import copy
import json
from pathlib import Path

import click

from benchmark.utils.metrics import precision_recall
from surya.debug.draw import draw_polys_on_image
from surya.input.processing import convert_if_not_rgb
from surya.common.util import rescale_bbox
from surya.settings import settings
from surya.detection import DetectionPredictor, InlineDetectionPredictor

import os
import time
from tabulate import tabulate
import datasets


@click.command(help="Benchmark inline math detection model.")
@click.option("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
@click.option("--max_rows", type=int, help="Maximum number of pdf pages to OCR.", default=100)
@click.option("--debug", is_flag=True, help="Enable debug mode.", default=False)
def main(results_dir: str, max_rows: int, debug: bool):
det_predictor = DetectionPredictor()
inline_det_predictor = InlineDetectionPredictor()

dataset = datasets.load_dataset(settings.INLINE_MATH_BENCH_DATASET_NAME, split=f"train[:{max_rows}]")
images = list(dataset["image"])
images = convert_if_not_rgb(images)
correct_boxes = []
for i, boxes in enumerate(dataset["bboxes"]):
img_size = images[i].size
# Rescale from normalized 0-1 vals to image size
correct_boxes.append([rescale_bbox(b, (1, 1), img_size) for b in boxes])

if settings.DETECTOR_STATIC_CACHE:
# Run through one batch to compile the model
det_predictor(images[:1])
inline_det_predictor(images[:1], [[]])

start = time.time()
det_results = det_predictor(images)

# Reformat text boxes to inline math input format
text_boxes = []
for result in det_results:
text_boxes.append([b.bbox for b in result.bboxes])

inline_results = inline_det_predictor(images, text_boxes)
surya_time = time.time() - start

result_path = Path(results_dir) / "inline_math_bench"
result_path.mkdir(parents=True, exist_ok=True)

page_metrics = collections.OrderedDict()
for idx, (sb, cb) in enumerate(zip(inline_results, correct_boxes)):
surya_boxes = [s.bbox for s in sb.bboxes]
surya_polys = [s.polygon for s in sb.bboxes]

surya_metrics = precision_recall(surya_boxes, cb)

page_metrics[idx] = {
"surya": surya_metrics,
}

if debug:
bbox_image = draw_polys_on_image(surya_polys, copy.deepcopy(images[idx]))
bbox_image.save(result_path / f"{idx}_bbox.png")

mean_metrics = {}
metric_types = sorted(page_metrics[0]["surya"].keys())
models = ["surya"]

for k in models:
for m in metric_types:
metric = []
for page in page_metrics:
metric.append(page_metrics[page][k][m])
if k not in mean_metrics:
mean_metrics[k] = {}
mean_metrics[k][m] = sum(metric) / len(metric)

out_data = {
"times": {
"surya": surya_time,
},
"metrics": mean_metrics,
"page_metrics": page_metrics
}

with open(result_path / "results.json", "w+", encoding="utf-8") as f:
json.dump(out_data, f, indent=4)

table_headers = ["Model", "Time (s)", "Time per page (s)"] + metric_types
table_data = [
["surya", surya_time, surya_time / len(images)] + [mean_metrics["surya"][m] for m in metric_types],
]

print(tabulate(table_data, headers=table_headers, tablefmt="github"))
print("Precision and recall are over the mutual coverage of the detected boxes and the ground truth boxes at a .5 threshold. There is a precision penalty for multiple boxes overlapping reference lines.")
print(f"Wrote results to {result_path}")


if __name__ == "__main__":
main()
7 changes: 7 additions & 0 deletions benchmark/utils/verify_benchmark_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ def verify_det(data):
raise ValueError("Scores do not meet the required threshold")


def verify_inline_det(data):
scores = data["metrics"]["surya"]
if scores["precision"] <= 0.5 or scores["recall"] <= 0.5:
raise ValueError("Scores do not meet the required threshold")

def verify_rec(data):
scores = data["surya"]
if scores["avg_score"] <= 0.9:
Expand Down Expand Up @@ -62,6 +67,8 @@ def main(file_path, bench_type):
verify_table_rec(data)
elif bench_type == "texify":
verify_texify(data)
elif bench_type == "inline_detection":
verify_inline_det(data)
else:
raise ValueError("Invalid benchmark type")

Expand Down
1 change: 1 addition & 0 deletions surya/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def TORCH_DEVICE_MODEL(self) -> str:
# Inline math detection
INLINE_MATH_MODEL_CHECKPOINT: str = "datalab-to/inline_math_det0@75aafc7aa3d494ece6496d28038c91f0d2518a43"
INLINE_MATH_THRESHOLD: float = 0.9 #Threshold for inline math detection (above this is considered inline-math)
INLINE_MATH_BENCH_DATASET_NAME: str = "datalab-to/inline_detection_bench"

# Text recognition
RECOGNITION_MODEL_CHECKPOINT: str = "vikp/surya_rec2@6611509b2c3a32c141703ce19adc899d9d0abf41"
Expand Down

0 comments on commit 4eb67cf

Please sign in to comment.