-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3b0f698
commit 4eb67cf
Showing
4 changed files
with
119 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import collections | ||
import copy | ||
import json | ||
from pathlib import Path | ||
|
||
import click | ||
|
||
from benchmark.utils.metrics import precision_recall | ||
from surya.debug.draw import draw_polys_on_image | ||
from surya.input.processing import convert_if_not_rgb | ||
from surya.common.util import rescale_bbox | ||
from surya.settings import settings | ||
from surya.detection import DetectionPredictor, InlineDetectionPredictor | ||
|
||
import os | ||
import time | ||
from tabulate import tabulate | ||
import datasets | ||
|
||
|
||
@click.command(help="Benchmark inline math detection model.") | ||
@click.option("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "benchmark")) | ||
@click.option("--max_rows", type=int, help="Maximum number of pdf pages to OCR.", default=100) | ||
@click.option("--debug", is_flag=True, help="Enable debug mode.", default=False) | ||
def main(results_dir: str, max_rows: int, debug: bool): | ||
det_predictor = DetectionPredictor() | ||
inline_det_predictor = InlineDetectionPredictor() | ||
|
||
dataset = datasets.load_dataset(settings.INLINE_MATH_BENCH_DATASET_NAME, split=f"train[:{max_rows}]") | ||
images = list(dataset["image"]) | ||
images = convert_if_not_rgb(images) | ||
correct_boxes = [] | ||
for i, boxes in enumerate(dataset["bboxes"]): | ||
img_size = images[i].size | ||
# Rescale from normalized 0-1 vals to image size | ||
correct_boxes.append([rescale_bbox(b, (1, 1), img_size) for b in boxes]) | ||
|
||
if settings.DETECTOR_STATIC_CACHE: | ||
# Run through one batch to compile the model | ||
det_predictor(images[:1]) | ||
inline_det_predictor(images[:1], [[]]) | ||
|
||
start = time.time() | ||
det_results = det_predictor(images) | ||
|
||
# Reformat text boxes to inline math input format | ||
text_boxes = [] | ||
for result in det_results: | ||
text_boxes.append([b.bbox for b in result.bboxes]) | ||
|
||
inline_results = inline_det_predictor(images, text_boxes) | ||
surya_time = time.time() - start | ||
|
||
result_path = Path(results_dir) / "inline_math_bench" | ||
result_path.mkdir(parents=True, exist_ok=True) | ||
|
||
page_metrics = collections.OrderedDict() | ||
for idx, (sb, cb) in enumerate(zip(inline_results, correct_boxes)): | ||
surya_boxes = [s.bbox for s in sb.bboxes] | ||
surya_polys = [s.polygon for s in sb.bboxes] | ||
|
||
surya_metrics = precision_recall(surya_boxes, cb) | ||
|
||
page_metrics[idx] = { | ||
"surya": surya_metrics, | ||
} | ||
|
||
if debug: | ||
bbox_image = draw_polys_on_image(surya_polys, copy.deepcopy(images[idx])) | ||
bbox_image.save(result_path / f"{idx}_bbox.png") | ||
|
||
mean_metrics = {} | ||
metric_types = sorted(page_metrics[0]["surya"].keys()) | ||
models = ["surya"] | ||
|
||
for k in models: | ||
for m in metric_types: | ||
metric = [] | ||
for page in page_metrics: | ||
metric.append(page_metrics[page][k][m]) | ||
if k not in mean_metrics: | ||
mean_metrics[k] = {} | ||
mean_metrics[k][m] = sum(metric) / len(metric) | ||
|
||
out_data = { | ||
"times": { | ||
"surya": surya_time, | ||
}, | ||
"metrics": mean_metrics, | ||
"page_metrics": page_metrics | ||
} | ||
|
||
with open(result_path / "results.json", "w+", encoding="utf-8") as f: | ||
json.dump(out_data, f, indent=4) | ||
|
||
table_headers = ["Model", "Time (s)", "Time per page (s)"] + metric_types | ||
table_data = [ | ||
["surya", surya_time, surya_time / len(images)] + [mean_metrics["surya"][m] for m in metric_types], | ||
] | ||
|
||
print(tabulate(table_data, headers=table_headers, tablefmt="github")) | ||
print("Precision and recall are over the mutual coverage of the detected boxes and the ground truth boxes at a .5 threshold. There is a precision penalty for multiple boxes overlapping reference lines.") | ||
print(f"Wrote results to {result_path}") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters