-
Notifications
You must be signed in to change notification settings - Fork 53
/
Copy pathtest.py
101 lines (71 loc) · 2.8 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from models import SegmentNet, DecisionNet, weights_init_normal
from dataset import KolektorDataset
import torch.nn as nn
import torch
from torchvision import datasets
from torchvision.utils import save_image
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader
import os
import sys
import argparse
import time
import PIL.Image as Image
def _str2bool(value):
    """Convert a command-line token such as ``true``/``false``/``1``/``0`` to bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value (including ``"False"``) becomes True and the flag can never
    be switched off. This parser interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string is kept as False because bool("") was the only way to
    # disable the flag with the previous ``type=bool`` declaration.
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


# Command-line options for the test script.
parser = argparse.ArgumentParser()
parser.add_argument("--cuda", type=_str2bool, default=True, help="run the networks on the GPU (true/false)")
parser.add_argument("--test_seg_epoch", type=int, default=60, help="epoch of the segment-net checkpoint to load (0 = do not load)")
parser.add_argument("--test_dec_epoch", type=int, default=60, help="epoch of the decision-net checkpoint to load (0 = do not load)")
parser.add_argument("--img_height", type=int, default=704, help="size of image height")
parser.add_argument("--img_width", type=int, default=256, help="size of image width")
# The dataset location used to be hard-coded; the default keeps the old value
# so existing invocations behave the same.
parser.add_argument("--data_root", type=str, default="/home/sean/Projects/SegDecNet/Data", help="root directory of the dataset")
opt = parser.parse_args()
print(opt)

# Root directory of the dataset; test results are written below it as well.
dataSetRoot = opt.data_root
# ***********************************************************************
# Build nets
segment_net = SegmentNet(init_weights=True)
decision_net = DecisionNet(init_weights=True)

if opt.cuda:
    segment_net = segment_net.cuda()
    decision_net = decision_net.cuda()

# Load pretrained weights; an epoch of 0 means "keep the freshly initialised
# weights". The checkpoint is deserialised onto the CPU first so that a file
# saved from GPU tensors can still be read on a machine without CUDA;
# load_state_dict then copies the values into the parameters wherever the
# network currently lives.
if opt.test_seg_epoch != 0:
    segment_net.load_state_dict(
        torch.load(
            "./saved_models/segment_net_%d.pth" % (opt.test_seg_epoch),
            map_location="cpu",
        )
    )

if opt.test_dec_epoch != 0:
    decision_net.load_state_dict(
        torch.load(
            "./saved_models/decision_net_%d.pth" % (opt.test_dec_epoch),
            map_location="cpu",
        )
    )
# Preprocessing applied to every test image: bicubic resize to the requested
# (height, width), then conversion to a tensor. The Normalize step stays
# disabled, as in the original pipeline.
_input_size = (opt.img_height, opt.img_width)
transforms_ = transforms.Compose(
    [
        transforms.Resize(_input_size, Image.BICUBIC),
        transforms.ToTensor(),
    ]
)

# Test split of the dataset; no transform is passed for the masks.
_test_dataset = KolektorDataset(
    dataSetRoot,
    transforms_=transforms_,
    transforms_mask=None,
    subFold="Test",
    isTrain=False,
)

# One image per batch, in dataset order, loaded in the main process.
testloader = DataLoader(_test_dataset, batch_size=1, shuffle=False, num_workers=0)
# Run the segment net followed by the decision net on every test image, time
# the forward pass, and save the input plus the predicted segmentation map
# with the OK/NG verdict encoded in the file name.
#
# NOTE(review): eval mode is left disabled exactly as in the original script,
# so any BatchNorm/Dropout layers inside the nets (if present) run in training
# mode. Confirm whether that is intentional before enabling the lines below.
#segment_net.eval()
#decision_net.eval()

# Follow --cuda instead of calling .cuda() unconditionally, so the inputs end
# up on the same device as the networks and CPU-only runs do not crash.
device = torch.device("cuda" if opt.cuda else "cpu")

# The output directory does not depend on the image, so create it once.
save_path_str = os.path.join(dataSetRoot, "testResult")
os.makedirs(save_path_str, exist_ok=True)

for i, testBatch in enumerate(testloader):
    # CUDA work is queued asynchronously; synchronising before reading the
    # clock makes the measured interval cover the finished computation.
    if opt.cuda:
        torch.cuda.synchronize()
    t1 = time.perf_counter()

    imgTest = testBatch["img"].to(device)

    # Inference only: no autograd graph is needed for either network.
    with torch.no_grad():
        rstTest = segment_net(imgTest)
        fTest = rstTest["f"]
        segTest = rstTest["seg"]
        cTest = decision_net(fTest, segTest)

    if opt.cuda:
        torch.cuda.synchronize()
    t2 = time.perf_counter()

    # A decision score above 0.5 marks the sample as defective ("NG").
    labelStr = "NG" if cTest.item() > 0.5 else "OK"

    print("processing image NO %d, time comsuption %fs"%(i, t2 - t1))

    # Tensors created under no_grad carry no graph, so the deprecated .data
    # accessor is unnecessary.
    save_image(imgTest, "%s/img_%d_%s.jpg"% (save_path_str, i, labelStr))
    save_image(segTest, "%s/img_%d_seg_%s.jpg"% (save_path_str, i, labelStr))