Skip to content

Commit

Permalink
Refactor layout
Browse files (browse the repository at this point in the history)
  • Loading branch information
VikParuchuri committed Jan 6, 2025
1 parent 1de6d91 commit d7f1567
Show file tree
Hide file tree
Showing 19 changed files with 345 additions and 386 deletions.
9 changes: 3 additions & 6 deletions detect_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
from collections import defaultdict

from surya.input.load import load_from_folder, load_from_file
from surya.layout import batch_layout_detection
from surya.model.layout.model import load_model
from surya.model.layout.processor import load_processor
from surya.layout import LayoutPredictor
from surya.postprocessing.heatmap import draw_polys_on_image
from surya.settings import settings
import os
Expand All @@ -22,8 +20,7 @@ def main():
parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False)
args = parser.parse_args()

model = load_model()
processor = load_processor()
layout_predictor = LayoutPredictor()

if os.path.isdir(args.input_path):
images, names, _ = load_from_folder(args.input_path, args.max)
Expand All @@ -33,7 +30,7 @@ def main():
folder_name = os.path.basename(args.input_path).split(".")[0]

start = time.time()
layout_predictions = batch_layout_detection(images, model, processor)
layout_predictions = layout_predictor(images)
result_path = os.path.join(args.results_dir, folder_name)
os.makedirs(result_path, exist_ok=True)
if args.debug:
Expand Down
12 changes: 5 additions & 7 deletions ocr_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@
import pypdfium2
import streamlit as st

from surya.layout import batch_layout_detection
from surya.detection import DetectionPredictor
from surya.recognition import RecognitionPredictor
from surya.layout import LayoutPredictor

from surya.model.layout.model import load_model as load_layout_model
from surya.model.layout.processor import load_processor as load_layout_processor
from surya.model.table_rec.model import load_model as load_table_model
from surya.model.table_rec.processor import load_processor as load_table_processor
from surya.model.ocr_error.model import load_model as load_ocr_error_model, load_tokenizer as load_ocr_error_processor
Expand All @@ -37,7 +35,7 @@ def load_rec_cached():

@st.cache_resource()
def load_layout_cached():
return load_layout_model(), load_layout_processor()
return LayoutPredictor()

@st.cache_resource()
def load_table_cached():
Expand Down Expand Up @@ -85,7 +83,7 @@ def text_detection(img) -> (Image.Image, TextDetectionResult):


def layout_detection(img) -> (Image.Image, LayoutResult):
pred = batch_layout_detection([img], layout_model, layout_processor)[0]
pred = layout_predictor([img])[0]
polygons = [p.polygon for p in pred.bboxes]
labels = [f"{p.label}-{p.position}" for p in pred.bboxes]
layout_img = draw_polys_on_image(polygons, img.copy(), labels=labels, label_font_size=18)
Expand Down Expand Up @@ -139,7 +137,7 @@ def ocr(img, highres_img, langs: List[str]) -> (Image.Image, OCRResult):

bboxes = [l.bbox for l in img_pred.text_lines]
text = [l.text for l in img_pred.text_lines]
rec_img = draw_text_on_image(bboxes, text, img.size, langs, has_math="_math" in langs)
rec_img = draw_text_on_image(bboxes, text, img.size, langs)
return rec_img, img_pred


Expand Down Expand Up @@ -175,7 +173,7 @@ def page_counter(pdf_file):

det_predictor = load_det_cached()
recognition_predictor = load_rec_cached()
layout_model, layout_processor = load_layout_cached()
layout_predictor = load_layout_cached()
table_model, table_processor = load_table_cached()
ocr_error_model, ocr_error_processor = load_ocr_error_cached()

Expand Down
9 changes: 1 addition & 8 deletions ocr_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def main():
for idx, (name, image, pred, langs) in enumerate(zip(names, images, predictions_by_image, image_langs)):
bboxes = [l.bbox for l in pred.text_lines]
pred_text = [l.text for l in pred.text_lines]
page_image = draw_text_on_image(bboxes, pred_text, image.size, langs, has_math="_math" in langs if langs else False)
page_image = draw_text_on_image(bboxes, pred_text, image.size, langs)
page_image.save(os.path.join(result_path, f"{name}_{idx}_text.png"))

out_preds = defaultdict(list)
Expand All @@ -81,10 +81,3 @@ def main():

if __name__ == "__main__":
main()







4 changes: 2 additions & 2 deletions surya/detection/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def model(
self,
device: Optional[torch.device | str] = None,
dtype: Optional[torch.dtype | str] = None
):
) -> EfficientViTForSemanticSegmentation:
config = EfficientViTConfig.from_pretrained(self.checkpoint)
model = EfficientViTForSemanticSegmentation.from_pretrained(self.checkpoint, torch_dtype=dtype, config=config,
ignore_mismatched_sizes=True)
Expand All @@ -39,5 +39,5 @@ def model(
print(f"Loaded detection model {self.checkpoint} on device {device} with dtype {dtype}")
return model

def processor(self):
def processor(self) -> SegformerImageProcessor:
return SegformerImageProcessor.from_pretrained(self.checkpoint)
249 changes: 0 additions & 249 deletions surya/layout.py

This file was deleted.

Loading

0 comments on commit d7f1567

Please sign in to comment.