From 51c7c4acc6eddb58c526196f68354e73b18a59be Mon Sep 17 00:00:00 2001
From: Vik Paruchuri
Date: Mon, 6 Jan 2025 21:39:45 -0500
Subject: [PATCH] Refactor OCR error model

---
 ocr_app.py                                   |  9 ++-
 surya/model/ocr_error/model.py               | 27 ---------
 surya/models.py                              | 20 +++++++
 surya/ocr_error.py                           | 48 ---------------
 surya/ocr_error/__init__.py                  | 60 +++++++++++++++++++
 surya/ocr_error/loader.py                    | 41 +++++++++++++
 surya/ocr_error/model/__init__.py            |  0
 .../ocr_error => ocr_error/model}/config.py  |  0
 .../ocr_error => ocr_error/model}/encoder.py |  2 +-
 surya/{model => }/ocr_error/tokenizer.py     |  0
 surya/table_rec/__init__.py                  |  0
 surya/table_rec/loader.py                    |  0
 surya/table_rec/model/__init__.py            |  0
 13 files changed, 126 insertions(+), 81 deletions(-)
 delete mode 100644 surya/model/ocr_error/model.py
 create mode 100644 surya/models.py
 delete mode 100644 surya/ocr_error.py
 create mode 100644 surya/ocr_error/__init__.py
 create mode 100644 surya/ocr_error/loader.py
 create mode 100644 surya/ocr_error/model/__init__.py
 rename surya/{model/ocr_error => ocr_error/model}/config.py (100%)
 rename surya/{model/ocr_error => ocr_error/model}/encoder.py (99%)
 rename surya/{model => }/ocr_error/tokenizer.py (100%)
 create mode 100644 surya/table_rec/__init__.py
 create mode 100644 surya/table_rec/loader.py
 create mode 100644 surya/table_rec/model/__init__.py

diff --git a/ocr_app.py b/ocr_app.py
index 8d46eb9c..b29c7894 100644
--- a/ocr_app.py
+++ b/ocr_app.py
@@ -8,10 +8,10 @@
 from surya.detection import DetectionPredictor
 from surya.recognition import RecognitionPredictor
 from surya.layout import LayoutPredictor
+from surya.ocr_error import OCRErrorPredictor
 
 from surya.model.table_rec.model import load_model as load_table_model
 from surya.model.table_rec.processor import load_processor as load_table_processor
-from surya.model.ocr_error.model import load_model as load_ocr_error_model, load_tokenizer as load_ocr_error_processor
 
 from surya.postprocessing.heatmap import draw_polys_on_image, draw_bboxes_on_image
 from surya.postprocessing.text import draw_text_on_image
@@ -22,7 +22,6 @@
 from surya.tables import batch_table_recognition
 from surya.postprocessing.util import rescale_bbox
 from pdftext.extraction import plain_text_output
-from surya.ocr_error import batch_ocr_error_detection
 
 
 @st.cache_resource()
@@ -43,7 +42,7 @@ def load_table_cached():
 
 @st.cache_resource()
 def load_ocr_error_cached():
-    return load_ocr_error_model(), load_ocr_error_processor()
+    return OCRErrorPredictor()
 
 
 def run_ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pages=15):
@@ -68,7 +67,7 @@ def run_ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pag
     for i in range(0, len(text), sample_gap):
         samples.append(text[i:i + sample_len])
 
-    results = batch_ocr_error_detection(samples, ocr_error_model, ocr_error_processor)
+    results = error_predictor(samples)
     label = "This PDF has good text."
     if results.labels.count("bad") / len(results.labels) > .2:
         label = "This PDF may have garbled or bad OCR text."
@@ -175,7 +174,7 @@ def page_counter(pdf_file):
 recognition_predictor = load_rec_cached()
 layout_predictor = load_layout_cached()
 table_model, table_processor = load_table_cached()
-ocr_error_model, ocr_error_processor = load_ocr_error_cached()
+error_predictor = load_ocr_error_cached()
 
 
 st.markdown("""
diff --git a/surya/model/ocr_error/model.py b/surya/model/ocr_error/model.py
deleted file mode 100644
index 72ed5816..00000000
--- a/surya/model/ocr_error/model.py
+++ /dev/null
@@ -1,27 +0,0 @@
-from __future__ import annotations
-import torch
-
-from surya.model.ocr_error.encoder import DistilBertForSequenceClassification
-from surya.model.ocr_error.config import DistilBertConfig
-from surya.model.ocr_error.tokenizer import DistilBertTokenizer
-from surya.settings import settings
-
-def load_model(checkpoint=settings.OCR_ERROR_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE) -> DistilBertForSequenceClassification:
-    config = DistilBertConfig.from_pretrained(checkpoint)
-    model = DistilBertForSequenceClassification.from_pretrained(checkpoint, torch_dtype=dtype, config=config).to(device).eval()
-
-    if settings.OCR_ERROR_STATIC_CACHE:
-        torch.set_float32_matmul_precision('high')
-        torch._dynamo.config.cache_size_limit = 1
-        torch._dynamo.config.suppress_errors = False
-
-        print(f"Compiling detection model {checkpoint} on device {device} with dtype {dtype}")
-        model = torch.compile(model)
-
-    return model
-
-def load_tokenizer(checkpoint=settings.OCR_ERROR_MODEL_CHECKPOINT):
-    tokenizer = DistilBertTokenizer.from_pretrained(checkpoint)
-    return tokenizer
-
-
diff --git a/surya/models.py b/surya/models.py
new file mode 100644
index 00000000..44c5ced7
--- /dev/null
+++ b/surya/models.py
@@ -0,0 +1,20 @@
+from typing import Dict
+
+import torch
+
+from surya.common.predictor import BasePredictor
+from surya.detection import DetectionPredictor
+from surya.layout import LayoutPredictor
+from surya.ocr_error import OCRErrorPredictor
+from surya.recognition import RecognitionPredictor
+
+
+def load_predictors(
+        device: str | torch.device | None = None,
+        dtype: torch.dtype | str | None = None) -> Dict[str, BasePredictor]:
+    return {
+        "layout": LayoutPredictor(device=device, dtype=dtype),
+        "ocr_error": OCRErrorPredictor(device=device, dtype=dtype),
+        "recognition": RecognitionPredictor(device=device, dtype=dtype),
+        "detection": DetectionPredictor(device=device, dtype=dtype),
+    }
\ No newline at end of file
diff --git a/surya/ocr_error.py b/surya/ocr_error.py
deleted file mode 100644
index 0cf714f4..00000000
--- a/surya/ocr_error.py
+++ /dev/null
@@ -1,48 +0,0 @@
-from typing import List, Optional
-from math import ceil
-from tqdm import tqdm
-import torch
-import numpy as np
-
-from surya.model.ocr_error.model import DistilBertTokenizer
-from surya.model.ocr_error.encoder import DistilBertForSequenceClassification
-from surya.model.ocr_error.config import ID2LABEL
-from surya.settings import settings
-from surya.schema import OCRErrorDetectionResult
-
-def get_batch_size():
-    batch_size = settings.OCR_ERROR_BATCH_SIZE
-    if batch_size is None:
-        batch_size = 8
-        if settings.TORCH_DEVICE_MODEL == "mps":
-            batch_size = 8
-        if settings.TORCH_DEVICE_MODEL == "cuda":
-            batch_size = 64
-    return batch_size
-
-def batch_ocr_error_detection(
-    texts: List[str],
-    model: DistilBertForSequenceClassification,
-    tokenizer: DistilBertTokenizer,
-    batch_size: Optional[int] = None
-):
-    if batch_size is None:
-        batch_size = get_batch_size()
-
-    num_batches = ceil(len(texts)/batch_size)
-    texts_processed = tokenizer(texts, padding='longest', truncation=True, return_tensors='pt')
-    predictions = []
-    for batch_idx in tqdm(range(num_batches)):
-        start_idx, end_idx = batch_idx*batch_size, (batch_idx+1)*batch_size
-        batch_input_ids = texts_processed.input_ids[start_idx:end_idx].to(model.device)
-        batch_attention_mask = texts_processed.attention_mask[start_idx:end_idx].to(model.device)
-
-        with torch.inference_mode():
-            pred = model(batch_input_ids, attention_mask=batch_attention_mask)
-            logits = pred.logits.detach().cpu().numpy().astype(np.float32)
-            predictions.extend(np.argmax(logits, axis=1).tolist())
-
-    return OCRErrorDetectionResult(
-        texts=texts,
-        labels=[ID2LABEL[p] for p in predictions]
-    )
diff --git a/surya/ocr_error/__init__.py b/surya/ocr_error/__init__.py
new file mode 100644
index 00000000..4b02e89a
--- /dev/null
+++ b/surya/ocr_error/__init__.py
@@ -0,0 +1,60 @@
+import math
+from typing import List, Optional
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from surya.common.predictor import BasePredictor
+from surya.ocr_error.loader import OCRErrorLoader
+from surya.ocr_error.model.config import ID2LABEL
+from surya.schema import OCRErrorDetectionResult
+from surya.settings import settings
+
+
+class OCRErrorPredictor(BasePredictor):
+    model_loader_cls = OCRErrorLoader
+
+    def __call__(
+        self,
+        texts: List[str],
+        batch_size: Optional[int] = None
+    ):
+        return self.batch_ocr_error_detection(texts, batch_size)
+
+    @staticmethod
+    def get_batch_size():
+        batch_size = settings.OCR_ERROR_BATCH_SIZE
+        if batch_size is None:
+            batch_size = 8
+            if settings.TORCH_DEVICE_MODEL == "mps":
+                batch_size = 8
+            if settings.TORCH_DEVICE_MODEL == "cuda":
+                batch_size = 64
+        return batch_size
+
+    def batch_ocr_error_detection(
+        self,
+        texts: List[str],
+        batch_size: Optional[int] = None
+    ):
+        if batch_size is None:
+            batch_size = self.get_batch_size()
+
+        num_batches = math.ceil(len(texts) / batch_size)
+        texts_processed = self.processor(texts, padding='longest', truncation=True, return_tensors='pt')
+        predictions = []
+        for batch_idx in tqdm(range(num_batches)):
+            start_idx, end_idx = batch_idx * batch_size, (batch_idx + 1) * batch_size
+            batch_input_ids = texts_processed.input_ids[start_idx:end_idx].to(self.model.device)
+            batch_attention_mask = texts_processed.attention_mask[start_idx:end_idx].to(self.model.device)
+
+            with torch.inference_mode():
+                pred = self.model(batch_input_ids, attention_mask=batch_attention_mask)
+                logits = pred.logits.detach().cpu().numpy().astype(np.float32)
+                predictions.extend(np.argmax(logits, axis=1).tolist())
+
+        return OCRErrorDetectionResult(
+            texts=texts,
+            labels=[ID2LABEL[p] for p in predictions]
+        )
\ No newline at end of file
diff --git a/surya/ocr_error/loader.py b/surya/ocr_error/loader.py
new file mode 100644
index 00000000..3f3de546
--- /dev/null
+++ b/surya/ocr_error/loader.py
@@ -0,0 +1,41 @@
+from typing import Optional
+
+import torch
+
+from surya.common.load import ModelLoader
+from surya.ocr_error.model.config import DistilBertConfig
+from surya.ocr_error.model.encoder import DistilBertForSequenceClassification
+from surya.ocr_error.tokenizer import DistilBertTokenizer
+from surya.settings import settings
+
+
+class OCRErrorLoader(ModelLoader):
+    def __init__(self, checkpoint: Optional[str] = None):
+        super().__init__(checkpoint)
+
+        if self.checkpoint is None:
+            self.checkpoint = settings.OCR_ERROR_MODEL_CHECKPOINT
+
+    def model(
+        self,
+        device=settings.TORCH_DEVICE_MODEL,
+        dtype=settings.MODEL_DTYPE
+    ) -> DistilBertForSequenceClassification:
+        config = DistilBertConfig.from_pretrained(self.checkpoint)
+        model = DistilBertForSequenceClassification.from_pretrained(self.checkpoint, torch_dtype=dtype, config=config).to(
+            device).eval()
+
+        if settings.OCR_ERROR_STATIC_CACHE:
+            torch.set_float32_matmul_precision('high')
+            torch._dynamo.config.cache_size_limit = 1
+            torch._dynamo.config.suppress_errors = False
+
+            print(f"Compiling OCR error model {self.checkpoint} on device {device} with dtype {dtype}")
+            model = torch.compile(model)
+
+        return model
+
+    def processor(
+        self
+    ) -> DistilBertTokenizer:
+        return DistilBertTokenizer.from_pretrained(self.checkpoint)
\ No newline at end of file
diff --git a/surya/ocr_error/model/__init__.py b/surya/ocr_error/model/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/surya/model/ocr_error/config.py b/surya/ocr_error/model/config.py
similarity index 100%
rename from surya/model/ocr_error/config.py
rename to surya/ocr_error/model/config.py
diff --git a/surya/model/ocr_error/encoder.py b/surya/ocr_error/model/encoder.py
similarity index 99%
rename from surya/model/ocr_error/encoder.py
rename to surya/ocr_error/model/encoder.py
index 4e27700d..1be59ef5 100644
--- a/surya/model/ocr_error/encoder.py
+++ b/surya/ocr_error/model/encoder.py
@@ -21,7 +21,7 @@
 from flash_attn import flash_attn_func, flash_attn_varlen_func
 from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 
-from surya.model.ocr_error.config import DistilBertConfig
+from surya.ocr_error.model.config import DistilBertConfig
 
 
 def _get_unpad_data(attention_mask):
diff --git a/surya/model/ocr_error/tokenizer.py b/surya/ocr_error/tokenizer.py
similarity index 100%
rename from surya/model/ocr_error/tokenizer.py
rename to surya/ocr_error/tokenizer.py
diff --git a/surya/table_rec/__init__.py b/surya/table_rec/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/surya/table_rec/loader.py b/surya/table_rec/loader.py
new file mode 100644
index 00000000..e69de29b
diff --git a/surya/table_rec/model/__init__.py b/surya/table_rec/model/__init__.py
new file mode 100644
index 00000000..e69de29b
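
Usage sketch: after this refactor, OCR error detection goes through the OCRErrorPredictor class instead of load_model/load_tokenizer plus batch_ocr_error_detection, and surya/models.py adds load_predictors() to build all predictors at once. The snippet below is illustrative only; the sample strings are hypothetical, and the "good" label string is assumed from ID2LABEL (ocr_app.py only confirms the "bad" label).

    from surya.models import load_predictors
    from surya.ocr_error import OCRErrorPredictor

    # Build every predictor (layout, ocr_error, recognition, detection)
    # with shared device/dtype defaults...
    predictors = load_predictors()

    # ...or construct only the OCR error predictor, as ocr_app.py now does.
    error_predictor = OCRErrorPredictor()

    # Hypothetical text samples pulled from a PDF text layer.
    samples = [
        "This is a clean sentence extracted from a digital PDF.",
        "Th1s l00ks l1ke g@rbl3d OCR 0utput.",
    ]

    # __call__ delegates to batch_ocr_error_detection and returns an
    # OCRErrorDetectionResult with parallel `texts` and `labels` lists.
    result = error_predictor(samples)
    print(result.labels)  # e.g. ["good", "bad"]; label strings come from ID2LABEL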