Commit

Merge pull request #2 from Bhashini-IITJ/east_script_changes
East script changes
anikde authored Oct 25, 2024
2 parents 9629d7b + 998b254 commit eb9d3f5
Showing 17 changed files with 116 additions and 148 deletions.
2 changes: 2 additions & 0 deletions EAST/.gitignore
@@ -0,0 +1,2 @@
__pycache__
results
24 changes: 24 additions & 0 deletions EAST/README.md
@@ -0,0 +1,24 @@


## Scene Text Detection - EAST
To get started, create a virtual environment and install PyTorch (version > 2.4).
### Installation
```commandline
conda create -n east_infer python=3.12
conda activate east_infer
conda install pytorch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 pytorch-cuda=11.8 -c pytorch -c nvidia
cd SceneTextDetection/EAST/
pip install -r requirements.txt
```
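
As a quick sanity check (not part of the documented setup), you can confirm that PyTorch is importable and whether CUDA is visible:
```commandline
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```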

### Inference

The script `infer.py` is used for inference. Run `python infer.py -h` for details on the available CLI options.

Model checkpoints can be downloaded from the GitHub release [assets](https://github.com/Bhashini-IITJ/SceneTextDetection/releases/tag/EAST).
```commandline
python infer.py --image_path ../demo_images/image_90.jpg --model_checkpoint tmp/epoch_990_checkpoint.pth.tar
```
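
`infer.py` can also be imported and called from Python. The sketch below is illustrative: it assumes it is run from the `EAST/` directory with the same image and checkpoint paths as the command above. `predict` returns a dictionary whose `detections` entry is a list of 4-point polygons.
```python
from infer import predict

# Same paths as the CLI example above (adjust to your setup).
result = predict(
    image_path="../demo_images/image_90.jpg",
    model_checkpoint="tmp/epoch_990_checkpoint.pth.tar",
)

# Each detection is a quadrilateral given as four [x, y] corner points.
for quad in result["detections"]:
    print(quad)
```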

### Acknowledgement
EAST re-implementation [repository](https://github.com/foamliu/EAST).
File renamed without changes.
77 changes: 77 additions & 0 deletions EAST/infer.py
@@ -0,0 +1,77 @@
import os
import torch
import cv2
import numpy as np
import time
import warnings
import config as cfg
from model import East
import utils

# Suppress warnings
warnings.filterwarnings("ignore")

def predict(image_path, model_checkpoint):
    # Load image and convert from BGR (OpenCV default) to RGB
    im = cv2.imread(image_path)[:, :, ::-1]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize the EAST model and load checkpoint
    model = East(device)
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu_ids)

    # Load the model checkpoint with weights_only=True
    checkpoint = torch.load(model_checkpoint, map_location=torch.device(device), weights_only=True)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    # Resize image and convert to tensor format (C, H, W)
    im_resized, (ratio_h, ratio_w) = utils.resize_image(im)
    im_resized = im_resized.astype(np.float32).transpose(2, 0, 1)
    im_tensor = torch.from_numpy(im_resized).unsqueeze(0).cpu()

    # Inference
    timer = {'net': 0, 'restore': 0, 'nms': 0}
    start = time.time()
    score, geometry = model(im_tensor)
    timer['net'] = time.time() - start

    # Process output
    score = score.permute(0, 2, 3, 1).data.cpu().numpy()
    geometry = geometry.permute(0, 2, 3, 1).data.cpu().numpy()

    # Detect boxes
    boxes, timer = utils.detect(
        score_map=score, geo_map=geometry, timer=timer,
        score_map_thresh=cfg.score_map_thresh, box_thresh=cfg.box_thresh,
        nms_thres=cfg.box_thresh
    )
    bbox_result_dict = {'detections': []}

    # Parse detected boxes and rescale coordinates to the original image size
    if boxes is not None:
        boxes = boxes[:, :8].reshape((-1, 4, 2))
        boxes[:, :, 0] /= ratio_w
        boxes[:, :, 1] /= ratio_h
        for box in boxes:
            box = utils.sort_poly(box.astype(np.int32))
            # Skip degenerate boxes with very short sides
            if np.linalg.norm(box[0] - box[1]) < 5 or np.linalg.norm(box[3] - box[0]) < 5:
                continue
            bbox_result_dict['detections'].append([
                [int(coord[0]), int(coord[1])] for coord in box
            ])

    return bbox_result_dict

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description='Text detection using EAST model')
    parser.add_argument('--image_path', type=str, required=True, help='Path to the input image')
    parser.add_argument('--model_checkpoint', type=str, required=True, help='Path to the model checkpoint file')
    args = parser.parse_args()

    # Run prediction and print the results as a dictionary
    detection_result = predict(args.image_path, args.model_checkpoint)
    print(detection_result)
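
A quick way to inspect the result is to draw the returned quadrilaterals back onto the image. This snippet is illustrative only (not part of the commit) and assumes the dictionary format returned by `predict` above:
```python
import cv2
import numpy as np

from infer import predict

# Hypothetical paths; any image/checkpoint pair accepted by predict() works.
image_path = "../demo_images/image_90.jpg"
result = predict(image_path, "tmp/epoch_990_checkpoint.pth.tar")

img = cv2.imread(image_path)
for quad in result["detections"]:
    # Each quad is four [x, y] corner points; draw it as a closed polygon.
    pts = np.array(quad, dtype=np.int32).reshape(-1, 1, 2)
    cv2.polylines(img, [pts], isClosed=True, color=(0, 255, 0), thickness=2)
cv2.imwrite("detections_vis.jpg", img)
```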
File renamed without changes.
10 changes: 5 additions & 5 deletions East/model.py → EAST/model.py
@@ -114,7 +114,7 @@ def _initialize_weights(self):
                 m.bias.data.zero_()
 
 
-def mobilenet(pretrained=True, **kwargs):
+def mobilenet(device, pretrained=True, **kwargs):
     """
     Constructs a MobileNetV2 model.
     Args:
@@ -123,7 +123,7 @@ def mobilenet(pretrained=True, **kwargs):
     model = MobileNetV2()
     if pretrained:
         model_dict = model.state_dict()
-        pretrained_dict = torch.load(cfg.pretrained_basemodel_path,map_location=torch.device('cpu'))
+        pretrained_dict = torch.load(cfg.pretrained_basemodel_path,map_location=torch.device(f'{device}'), weights_only=True)
         pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
         model_dict.update(pretrained_dict)
         model.load_state_dict(model_dict)
@@ -134,9 +134,9 @@ def mobilenet(pretrained=True, **kwargs):
 
 
 class East(nn.Module):
-    def __init__(self):
+    def __init__(self, device):
         super(East, self).__init__()
-        self.mobilenet = mobilenet(True)
+        self.mobilenet = mobilenet(device, pretrained=True)
         # self.si for stage i
         self.s1 = nn.Sequential(*list(self.mobilenet.children())[0][0:4])
         self.s2 = nn.Sequential(*list(self.mobilenet.children())[0][4:7])
@@ -237,4 +237,4 @@ def forward(self, images):
         return F_score, F_geometry
 
 
-model=East()
+# model=East(device)
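
With this change, `East` takes the target device explicitly. A minimal instantiation sketch mirroring what `infer.py` does (illustrative, not part of the diff):
```python
import torch
from model import East

# Pick GPU when available; the device is passed through to the
# pretrained-backbone loading inside mobilenet().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = East(device)
model.eval()
```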
File renamed without changes.
2 changes: 0 additions & 2 deletions East/requirements.txt → EAST/requirements.txt
@@ -1,4 +1,2 @@
-torch==2.4.1
-torchvision==0.19.1
 shapely==2.0.6
 opencv-python==4.10.0.84
File renamed without changes.
File renamed without changes.
4 changes: 2 additions & 2 deletions East/utils.py → EAST/utils.py
@@ -184,7 +184,7 @@ def detect(score_map, geo_map, timer, score_map_thresh=0.8, box_thresh=0.1, nms_
     start = time.time()
     text_box_restored = preprossing.restore_rectangle(xy_text[:, ::-1] * 4,
                                                        geo_map[xy_text[:, 0], xy_text[:, 1], :])  # N*4*2
-    print('{} text boxes before nms'.format(text_box_restored.shape[0]))
+    # print('{} text boxes before nms'.format(text_box_restored.shape[0]))
     boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
     boxes[:, :8] = text_box_restored.reshape((-1, 8))
     boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]
@@ -193,7 +193,7 @@ def detect(score_map, geo_map, timer, score_map_thresh=0.8, box_thresh=0.1, nms_
     start = time.time()
     boxes = locality_aware_nms.nms_locality(boxes.astype(np.float64), nms_thres)
     timer['nms'] = time.time() - start
-    print(timer['nms'])
+    # print(timer['nms'])
     if boxes.shape[0] == 0:
         return None, timer
25 changes: 0 additions & 25 deletions East/README.md

This file was deleted.

113 changes: 0 additions & 113 deletions East/infer.py

This file was deleted.

7 changes: 6 additions & 1 deletion README.md
@@ -1 +1,6 @@
# SceneTextDetection
This repository provides implementations of various scene text detection models for detecting text in images.
1. **EAST**: An Efficient and Accurate Scene Text Detector ([paper](https://arxiv.org/abs/1704.03155))

# Fine-tuning scheme
The models in this repository have been fine-tuned on the [Bharat Scene Text Dataset (BSTD)](https://github.com/Bhashini-IITJ/BharatSceneTextDataset), a large-scale dataset designed specifically for scene text detection tasks. The demo images used in this repository are also sourced from this dataset.
Binary file added demo_images/image_142.jpg
Binary file added demo_images/image_90.jpg
Binary file added demo_images/image_944.jpg
