Skip to content

Commit

Permalink
Refactor OCR error model
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Jan 7, 2025
1 parent d7f1567 commit 51c7c4a
Show file tree
Hide file tree
Showing 13 changed files with 126 additions and 81 deletions.
9 changes: 4 additions & 5 deletions ocr_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from surya.detection import DetectionPredictor
from surya.recognition import RecognitionPredictor
from surya.layout import LayoutPredictor
from surya.ocr_error import OCRErrorPredictor

from surya.model.table_rec.model import load_model as load_table_model
from surya.model.table_rec.processor import load_processor as load_table_processor
from surya.model.ocr_error.model import load_model as load_ocr_error_model, load_tokenizer as load_ocr_error_processor
from surya.postprocessing.heatmap import draw_polys_on_image, draw_bboxes_on_image

from surya.postprocessing.text import draw_text_on_image
Expand All @@ -22,7 +22,6 @@
from surya.tables import batch_table_recognition
from surya.postprocessing.util import rescale_bbox
from pdftext.extraction import plain_text_output
from surya.ocr_error import batch_ocr_error_detection


@st.cache_resource()
Expand All @@ -43,7 +42,7 @@ def load_table_cached():

@st.cache_resource()
def load_ocr_error_cached():
return load_ocr_error_model(), load_ocr_error_processor()
return OCRErrorPredictor()


def run_ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pages=15):
Expand All @@ -68,7 +67,7 @@ def run_ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pag
for i in range(0, len(text), sample_gap):
samples.append(text[i:i + sample_len])

results = batch_ocr_error_detection(samples, ocr_error_model, ocr_error_processor)
results = error_predictor(samples)
label = "This PDF has good text."
if results.labels.count("bad") / len(results.labels) > .2:
label = "This PDF may have garbled or bad OCR text."
Expand Down Expand Up @@ -175,7 +174,7 @@ def page_counter(pdf_file):
recognition_predictor = load_rec_cached()
layout_predictor = load_layout_cached()
table_model, table_processor = load_table_cached()
ocr_error_model, ocr_error_processor = load_ocr_error_cached()
error_predictor = load_ocr_error_cached()


st.markdown("""
Expand Down
27 changes: 0 additions & 27 deletions surya/model/ocr_error/model.py

This file was deleted.

20 changes: 20 additions & 0 deletions surya/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Dict

import torch

from surya.common.predictor import BasePredictor
from surya.detection import DetectionPredictor
from surya.layout import LayoutPredictor
from surya.ocr_error import OCRErrorPredictor
from surya.recognition import RecognitionPredictor


def load_predictors(
        device: str | torch.device | None = None,
        dtype: torch.dtype | str | None = None) -> Dict[str, BasePredictor]:
    """Construct one predictor per task and return them keyed by task name.

    Args:
        device: Passed unchanged as the ``device`` keyword to every predictor
            constructor.
        dtype: Passed unchanged as the ``dtype`` keyword to every predictor
            constructor.

    Returns:
        A dict with the keys ``"layout"``, ``"ocr_error"``, ``"recognition"``
        and ``"detection"``, each mapped to a newly constructed predictor.
    """
    # Listed in construction order; the resulting dict keeps this order.
    predictor_classes = (
        ("layout", LayoutPredictor),
        ("ocr_error", OCRErrorPredictor),
        ("recognition", RecognitionPredictor),
        ("detection", DetectionPredictor),
    )
    return {
        task: predictor_cls(device=device, dtype=dtype)
        for task, predictor_cls in predictor_classes
    }
48 changes: 0 additions & 48 deletions surya/ocr_error.py

This file was deleted.

60 changes: 60 additions & 0 deletions surya/ocr_error/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import math
from typing import List, Optional

import numpy as np
import torch
from tqdm import tqdm

from surya.common.predictor import BasePredictor
from surya.ocr_error.loader import OCRErrorLoader
from surya.ocr_error.model.config import ID2LABEL
from surya.schema import OCRErrorDetectionResult
from surya.settings import settings


class OCRErrorPredictor(BasePredictor):
    """Label text snippets with the OCR error classification model.

    The model and tokenizer are read from ``self.model`` and
    ``self.processor``. Neither is assigned in this class; presumably
    ``BasePredictor`` creates them through ``model_loader_cls`` -- confirm
    against the base class.
    """
    model_loader_cls = OCRErrorLoader

    def __call__(
        self,
        texts: List[str],
        batch_size: Optional[int] = None
    ) -> OCRErrorDetectionResult:
        """Shorthand for :meth:`batch_ocr_error_detection`."""
        return self.batch_ocr_error_detection(texts, batch_size)

    @staticmethod
    def get_batch_size() -> int:
        """Return the batch size to use when the caller does not pass one.

        ``settings.OCR_ERROR_BATCH_SIZE`` wins when it is set. Otherwise the
        default is 64 when ``settings.TORCH_DEVICE_MODEL`` is ``"cuda"`` and
        8 for every other device (including ``"mps"``).
        """
        batch_size = settings.OCR_ERROR_BATCH_SIZE
        if batch_size is None:
            batch_size = 64 if settings.TORCH_DEVICE_MODEL == "cuda" else 8
        return batch_size

    def batch_ocr_error_detection(
        self,
        texts: List[str],
        batch_size: Optional[int] = None
    ) -> OCRErrorDetectionResult:
        """Classify every text and return one label per input, in order.

        Args:
            texts: Text snippets to classify.
            batch_size: Number of texts per forward pass. ``None`` selects
                the value from :meth:`get_batch_size`.

        Returns:
            An ``OCRErrorDetectionResult`` holding the input texts and, for
            each one, the ``ID2LABEL`` entry of its highest-scoring class.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size is None:
            batch_size = self.get_batch_size()
        if batch_size < 1:
            # A zero size would divide by zero below, and a negative one
            # would silently produce a result with no labels.
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")

        # len() rather than truthiness so that a None argument still raises
        # instead of being treated as "no texts".
        if len(texts) == 0:
            # Nothing to classify; skip tokenization and the model entirely.
            return OCRErrorDetectionResult(texts=[], labels=[])

        num_batches = math.ceil(len(texts) / batch_size)
        # All texts are tokenized together, so every batch is padded to the
        # longest sequence of the whole input rather than of its own batch.
        texts_processed = self.processor(texts, padding='longest', truncation=True, return_tensors='pt')
        device = self.model.device

        predictions = []
        for batch_idx in tqdm(range(num_batches)):
            start_idx = batch_idx * batch_size
            end_idx = start_idx + batch_size
            batch_input_ids = texts_processed.input_ids[start_idx:end_idx].to(device)
            batch_attention_mask = texts_processed.attention_mask[start_idx:end_idx].to(device)

            with torch.inference_mode():
                pred = self.model(batch_input_ids, attention_mask=batch_attention_mask)
                # Upcast in torch before handing over to NumPy: NumPy has no
                # bfloat16 dtype, so Tensor.numpy() fails on bfloat16 logits.
                logits = pred.logits.detach().to(torch.float32).cpu().numpy()
                predictions.extend(np.argmax(logits, axis=1).tolist())

        return OCRErrorDetectionResult(
            texts=texts,
            labels=[ID2LABEL[p] for p in predictions]
        )
41 changes: 41 additions & 0 deletions surya/ocr_error/loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from typing import Optional

import torch

from surya.common.load import ModelLoader
from surya.ocr_error.model.config import DistilBertConfig
from surya.ocr_error.model.encoder import DistilBertForSequenceClassification
from surya.ocr_error.tokenizer import DistilBertTokenizer
from surya.settings import settings


class OCRErrorLoader(ModelLoader):
    """Load the OCR error classification model and its tokenizer."""

    def __init__(self, checkpoint: Optional[str] = None):
        """Remember the checkpoint to load.

        Args:
            checkpoint: Checkpoint identifier handed to ``from_pretrained``.
                When ``self.checkpoint`` is still ``None`` after the base
                initializer runs, ``settings.OCR_ERROR_MODEL_CHECKPOINT``
                is used.
        """
        super().__init__(checkpoint)

        if self.checkpoint is None:
            self.checkpoint = settings.OCR_ERROR_MODEL_CHECKPOINT

    def model(
        self,
        device=settings.TORCH_DEVICE_MODEL,
        dtype=settings.MODEL_DTYPE
    ) -> DistilBertForSequenceClassification:
        """Load the classifier, move it to ``device`` and put it in eval mode.

        Args:
            device: Target device. ``None`` falls back to
                ``settings.TORCH_DEVICE_MODEL``.
            dtype: Value passed as ``torch_dtype`` when loading the weights.
                ``None`` falls back to ``settings.MODEL_DTYPE``.

        Returns:
            The loaded model, wrapped by ``torch.compile`` when
            ``settings.OCR_ERROR_STATIC_CACHE`` is truthy.
        """
        # Callers may forward an explicit None to mean "use the configured
        # default"; resolve it here instead of passing None on to torch.
        if device is None:
            device = settings.TORCH_DEVICE_MODEL
        if dtype is None:
            dtype = settings.MODEL_DTYPE

        config = DistilBertConfig.from_pretrained(self.checkpoint)
        model = DistilBertForSequenceClassification.from_pretrained(
            self.checkpoint,
            torch_dtype=dtype,
            config=config,
        ).to(device).eval()

        if settings.OCR_ERROR_STATIC_CACHE:
            # NOTE: these three statements change process-wide torch state,
            # not just this model.
            torch.set_float32_matmul_precision('high')
            torch._dynamo.config.cache_size_limit = 1
            torch._dynamo.config.suppress_errors = False

            print(f"Compiling OCR error model {self.checkpoint} on device {device} with dtype {dtype}")
            model = torch.compile(model)

        return model

    def processor(
        self
    ) -> DistilBertTokenizer:
        """Load the tokenizer stored with the same checkpoint as the model."""
        return DistilBertTokenizer.from_pretrained(self.checkpoint)
Empty file.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa

from surya.model.ocr_error.config import DistilBertConfig
from surya.ocr_error.model.config import DistilBertConfig


def _get_unpad_data(attention_mask):
Expand Down
File renamed without changes.
Empty file added surya/table_rec/__init__.py
Empty file.
Empty file added surya/table_rec/loader.py
Empty file.
Empty file.

0 comments on commit 51c7c4a

Please sign in to comment.