-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathdemo_show.py
98 lines (64 loc) · 2.23 KB
/
demo_show.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
from torch.autograd import Variable
import utils
import dataset
from PIL import Image
import matplotlib.pyplot as plt
import collections
import os
import models.crnn as crnn
# Input locations. Each can be overridden through an environment variable so
# the demo can run on another machine without editing this file; the defaults
# are the original hard-coded paths.
model_path = os.environ.get(
    'CRNN_MODEL_PATH',
    '/data_2/project_2021/crnn/crnn.pytorch-master/data/netCRNN_28500_1.pth')
# Single-image path; currently not used by this script.
img_path = os.environ.get(
    'CRNN_IMG_PATH',
    '/data_1/everyday/xian/20210308_crnn/IIIT5K/train/108_12.png')
# Characters the recognizer can emit.
alphabet = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
# Directory whose images are recognized one by one below.
dir_img = os.environ.get(
    'CRNN_IMG_DIR',
    "/data_1/everyday/xian/20210308_crnn/IIIT5K/test/")

# One output class per alphabet character plus one extra class
# (presumably the CTC blank label -- TODO confirm against utils.strLabelConverter).
nclass = len(alphabet) + 1
# NOTE(review): constructor arguments presumably are (imgH=32, nc=1 grayscale
# channel, nclass, nh=256 hidden units) -- confirm against models/crnn.py.
model = crnn.CRNN(32, 1, nclass, 256)
if torch.cuda.is_available():
    model = model.cuda()
# Load the pretrained weights into `model`.
print('loading pretrained model from %s' % model_path)
# map_location='cpu' lets a checkpoint saved on a GPU also load on a CPU-only
# host; load_state_dict then copies each tensor onto the model's own device.
load_model_ = torch.load(model_path, map_location='cpu')

# Checkpoints saved from an nn.DataParallel wrapper prefix every key with
# "module.". Strip the prefix only where it is present: slicing k[7:]
# unconditionally would corrupt the keys of a plain (non-wrapped) checkpoint.
_DP_PREFIX = 'module.'
state_dict_rename = collections.OrderedDict()
for k, v in load_model_.items():
    name = k[len(_DP_PREFIX):] if k.startswith(_DP_PREFIX) else k
    state_dict_rename[name] = v
model.load_state_dict(state_dict_rename)
# Recognize the text in every image under dir_img and display each image.
converter = utils.strLabelConverter(alphabet)
# Resize/normalize every input; the target size is presumably
# (width=100, height=32) -- TODO confirm against dataset.resizeNormalize.
transformer = dataset.resizeNormalize((100, 32))

# Inference only: put the model in evaluation mode once, up front, instead of
# on every iteration (affects layers such as dropout/batch-norm, if present).
model.eval()

# sorted() gives a reproducible order (os.listdir order is arbitrary);
# sub-directories are skipped because they cannot be opened as images.
list_img = sorted(
    entry for entry in os.listdir(dir_img)
    if os.path.isfile(os.path.join(dir_img, entry))
)
for cnt, img_name in enumerate(list_img):
    print(cnt, img_name)
    # os.path.join works whether or not dir_img ends with a separator.
    path_img = os.path.join(dir_img, img_name)
    try:
        # copy() decodes the pixels fully, so the file handle is closed as
        # soon as the with-block exits; the copy serves both the
        # recognition input and the display below (one open per file).
        with Image.open(path_img) as pil_img:
            image_show = pil_img.copy()
    except OSError as err:  # not an image / unreadable: report and continue
        print('skipping %s: %s' % (path_img, err))
        continue

    # Grayscale ('L') input, resized/normalized, then a leading batch dim.
    image = transformer(image_show.convert('L'))
    if torch.cuda.is_available():
        image = image.cuda()
    image = image.view(1, *image.size())

    # No autograd graph is needed for inference; no_grad saves memory.
    with torch.no_grad():
        preds = model(image)

    # Greedy decoding: take the best class along dim 2 at every step, then
    # flatten to 1-D. NOTE(review): assumes the model output is
    # (seq_len, batch, nclass) -- TODO confirm against models/crnn.py.
    _, preds = preds.max(2)
    preds = preds.transpose(1, 0).contiguous().view(-1)
    preds_size = torch.IntTensor([preds.size(0)])
    # raw=True presumably keeps blanks/repeats; raw=False collapses them.
    raw_pred = converter.decode(preds, preds_size, raw=True)
    sim_pred = converter.decode(preds, preds_size, raw=False)
    print('%-20s => %-20s' % (raw_pred, sim_pred))

    plt.figure("show")
    plt.imshow(image_show)
    plt.show()