diff --git a/surya/detection/__init__.py b/surya/detection/__init__.py index 7c38ba1..337af78 100644 --- a/surya/detection/__init__.py +++ b/surya/detection/__init__.py @@ -143,7 +143,7 @@ def batch_generator(self, iterable, batch_size=None): for i in range(0, len(iterable), batch_size): yield iterable[i:i+batch_size] - def __call__(self, images, text_boxes, batch_size=None, include_maps=False) -> List[TextDetectionResult]: + def __call__(self, images, text_boxes: List[List[List[float]]], batch_size=None, include_maps=False) -> List[TextDetectionResult]: detection_generator = self.batch_detection(images, batch_size=batch_size, static_cache=settings.DETECTOR_STATIC_CACHE) text_box_generator = self.batch_generator(text_boxes) diff --git a/tests/test_inline_detection.py b/tests/test_inline_detection.py index b63105e..19bbe28 100644 --- a/tests/test_inline_detection.py +++ b/tests/test_inline_detection.py @@ -1,12 +1,15 @@ from PIL import Image, ImageDraw -def test_latex_ocr(inline_detection_predictor): +def test_inline_detection(inline_detection_predictor, detection_predictor): img = Image.new('RGB', (1024, 1024), color='white') draw = ImageDraw.Draw(img) draw.text((10, 200), "where B(x, ϵ) is the norm ball with radius xadv = x + ϵ · sgn (∇xL(f (x, w), y))", fill='black', font_size=48) - inline_detection_results = inline_detection_predictor([img]) + detection_results = detection_predictor([img]) + bboxes = [[bb.bbox for bb in detection_results[0].bboxes]] + + inline_detection_results = inline_detection_predictor([img], bboxes) assert len(inline_detection_results) == 1 assert inline_detection_results[0].image_bbox == [0, 0, 1024, 1024]