Skip to content

Commit

Permalink
Refactor table rec
Browse files: browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Jan 7, 2025
1 parent 51c7c4a commit 8530131
Show file tree
Hide file tree
Showing 18 changed files with 418 additions and 478 deletions.
11 changes: 4 additions & 7 deletions detect_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
from collections import defaultdict

from surya.input.load import load_from_folder, load_from_file
from surya.model.detection.model import load_model, load_processor
from surya.detection import batch_text_detection
from surya.postprocessing.affinity import draw_lines_on_image
from surya.detection import DetectionPredictor
from surya.detection.affinity import draw_lines_on_image
from surya.postprocessing.heatmap import draw_polys_on_image
from surya.settings import settings
import os
Expand All @@ -23,9 +22,7 @@ def main():
parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False)
args = parser.parse_args()

checkpoint = settings.DETECTOR_MODEL_CHECKPOINT
model = load_model(checkpoint=checkpoint)
processor = load_processor(checkpoint=checkpoint)
det_predictor = DetectionPredictor()

if os.path.isdir(args.input_path):
images, names, _ = load_from_folder(args.input_path, args.max)
Expand All @@ -35,7 +32,7 @@ def main():
folder_name = os.path.basename(args.input_path).split(".")[0]

start = time.time()
predictions = batch_text_detection(images, model, processor, include_maps=args.debug)
predictions = det_predictor(images, include_maps=args.debug)
result_path = os.path.join(args.results_dir, folder_name)
os.makedirs(result_path, exist_ok=True)
end = time.time()
Expand Down
45 changes: 9 additions & 36 deletions ocr_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,44 +5,22 @@
import pypdfium2
import streamlit as st

from surya.detection import DetectionPredictor
from surya.recognition import RecognitionPredictor
from surya.layout import LayoutPredictor
from surya.ocr_error import OCRErrorPredictor
from surya.models import load_predictors

from surya.model.table_rec.model import load_model as load_table_model
from surya.model.table_rec.processor import load_processor as load_table_processor
from surya.postprocessing.heatmap import draw_polys_on_image, draw_bboxes_on_image

from surya.postprocessing.text import draw_text_on_image
from PIL import Image
from surya.recognition.languages import CODE_TO_LANGUAGE, replace_lang_with_code
from surya.schema import OCRResult, TextDetectionResult, LayoutResult, TableResult
from surya.settings import settings
from surya.tables import batch_table_recognition
from surya.postprocessing.util import rescale_bbox
from pdftext.extraction import plain_text_output


@st.cache_resource()
def load_det_cached():
return DetectionPredictor()

@st.cache_resource()
def load_rec_cached():
return RecognitionPredictor()

@st.cache_resource()
def load_layout_cached():
return LayoutPredictor()

@st.cache_resource()
def load_table_cached():
return load_table_model(), load_table_processor()

@st.cache_resource()
def load_ocr_error_cached():
return OCRErrorPredictor()
def load_predictors_cached():
return load_predictors()


def run_ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pages=15):
Expand All @@ -67,22 +45,22 @@ def run_ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pag
for i in range(0, len(text), sample_gap):
samples.append(text[i:i + sample_len])

results = error_predictor(samples)
results = predictors["ocr_error"](samples)
label = "This PDF has good text."
if results.labels.count("bad") / len(results.labels) > .2:
label = "This PDF may have garbled or bad OCR text."
return label, results.labels


def text_detection(img) -> (Image.Image, TextDetectionResult):
pred = det_predictor([img])[0]
pred = predictors["detection"]([img])[0]
polygons = [p.polygon for p in pred.bboxes]
det_img = draw_polys_on_image(polygons, img.copy())
return det_img, pred


def layout_detection(img) -> (Image.Image, LayoutResult):
pred = layout_predictor([img])[0]
pred = predictors["layout"]([img])[0]
polygons = [p.polygon for p in pred.bboxes]
labels = [f"{p.label}-{p.position}" for p in pred.bboxes]
layout_img = draw_polys_on_image(polygons, img.copy(), labels=labels, label_font_size=18)
Expand All @@ -105,7 +83,7 @@ def table_recognition(img, highres_img, skip_table_detection: bool) -> (Image.Im
)
layout_tables.append(highres_bbox)

table_preds = batch_table_recognition(table_imgs, table_model, table_processor)
table_preds = predictors["table_rec"](table_imgs)
table_img = img.copy()

for results, table_bbox in zip(table_preds, layout_tables):
Expand All @@ -132,7 +110,7 @@ def table_recognition(img, highres_img, skip_table_detection: bool) -> (Image.Im
# Function for OCR
def ocr(img, highres_img, langs: List[str]) -> (Image.Image, OCRResult):
replace_lang_with_code(langs)
img_pred = recognition_predictor([img], [langs], det_predictor, highres_images=[highres_img])[0]
img_pred = predictors["recognition"]([img], [langs], predictors["detection"], highres_images=[highres_img])[0]

bboxes = [l.bbox for l in img_pred.text_lines]
text = [l.text for l in img_pred.text_lines]
Expand Down
Expand Up @@ -170,12 +148,7 @@ def page_counter(pdf_file):
st.set_page_config(layout="wide")
col1, col2 = st.columns([.5, .5])

det_predictor = load_det_cached()
recognition_predictor = load_rec_cached()
layout_predictor = load_layout_cached()
table_model, table_processor = load_table_cached()
error_predictor = load_ocr_error_cached()

predictors = load_predictors_cached()

st.markdown("""
# Surya OCR Demo
Expand Down
Empty file removed surya/languages.py
Empty file.
35 changes: 0 additions & 35 deletions surya/model/table_rec/columns.py

This file was deleted.

39 changes: 0 additions & 39 deletions surya/model/table_rec/model.py

This file was deleted.

5 changes: 4 additions & 1 deletion surya/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@
from surya.layout import LayoutPredictor
from surya.ocr_error import OCRErrorPredictor
from surya.recognition import RecognitionPredictor
from surya.table_rec import TableRecPredictor


def load_predictors(
device: str | torch.device | None = None,
dtype: torch.dtype | str | None = None) -> Dict[str, BasePredictor]:
dtype: torch.dtype | str | None = None
) -> Dict[str, BasePredictor]:
return {
"layout": LayoutPredictor(device=device, dtype=dtype),
"ocr_error": OCRErrorPredictor(device=device, dtype=dtype),
"recognition": RecognitionPredictor(device=device, dtype=dtype),
"detection": DetectionPredictor(device=device, dtype=dtype),
"table_rec": TableRecPredictor(device=device, dtype=dtype)
}
4 changes: 2 additions & 2 deletions surya/ocr_error/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from tqdm import tqdm

from surya.common.predictor import BasePredictor
from surya.ocr_error.loader import OCRErrorLoader
from surya.ocr_error.loader import OCRErrorModelLoader
from surya.ocr_error.model.config import ID2LABEL
from surya.schema import OCRErrorDetectionResult
from surya.settings import settings


class OCRErrorPredictor(BasePredictor):
model_loader_cls = OCRErrorLoader
model_loader_cls = OCRErrorModelLoader

def __call__(
self,
Expand Down
2 changes: 1 addition & 1 deletion surya/ocr_error/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from surya.settings import settings


class OCRErrorLoader(ModelLoader):
class OCRErrorModelLoader(ModelLoader):
def __init__(self, checkpoint: Optional[str] = None):
super().__init__(checkpoint)

Expand Down
Loading

0 comments on commit 8530131

Please sign in to comment.