-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtest_dataset.py
71 lines (58 loc) · 2.92 KB
/
test_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import argparse
import torch
from torchvision import transforms
from src.dataset import CocoDataset, Resizer, Normalizer
from src.config import COCO_CLASSES, colors
import cv2
import shutil
def get_args():
    """Build the command-line parser for this script and parse ``sys.argv``.

    Returns:
        argparse.Namespace: carries ``image_size``, ``data_path``,
        ``cls_threshold``, ``nms_threshold``, ``pretrained_model`` and
        ``output``; each falls back to the default listed below.
    """
    # The title string is passed positionally, so argparse uses it as the
    # program name (``prog``) shown in usage/help output.
    cli = argparse.ArgumentParser(
        "SparseSwin: Swin transformer with sparse transformer block"
    )

    # (flag, keyword arguments for add_argument), registered in this order.
    option_table = (
        ("--image_size", {
            "type": int,
            "default": 512,
            "help": "The common width and height for all images",
        }),
        ("--data_path", {
            "type": str,
            "default": "data/COCO",
            "help": "the root folder of dataset",
        }),
        ("--cls_threshold", {"type": float, "default": 0.5}),
        ("--nms_threshold", {"type": float, "default": 0.5}),
        ("--pretrained_model", {
            "type": str,
            "default": "trained_models/signatrix_sparseswin_coco.pth",
        }),
        ("--output", {"type": str, "default": "predictions"}),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    return cli.parse_args()
def _draw_detection(image, box, label, prob):
    """Draw one detection (box outline plus filled caption) on ``image`` in place.

    Args:
        image: image array as returned by ``cv2.imread``; modified in place.
        box: four values ``(xmin, ymin, xmax, ymax)``; each is rounded to the
            nearest integer pixel before drawing.
        label: integer class index used to look up ``COCO_CLASSES`` and ``colors``.
        prob: confidence score, rendered with two decimals in the caption.
    """
    xmin, ymin, xmax, ymax = (int(round(float(coord), 0)) for coord in box)
    color = colors[label]
    caption = COCO_CLASSES[label] + ' : %.2f' % prob

    cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, 2)
    text_size = cv2.getTextSize(caption, cv2.FONT_HERSHEY_PLAIN, 1, 1)[0]
    # Filled rectangle (thickness -1) behind the caption so the white text
    # stays readable on any background.
    cv2.rectangle(
        image, (xmin, ymin),
        (xmin + text_size[0] + 3, ymin + text_size[1] + 4), color, -1)
    cv2.putText(
        image, caption,
        (xmin, ymin + text_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 1,
        (255, 255, 255), 1)


def test(opt):
    """Run the saved model over the ``val2017`` split and write annotated images.

    For every dataset sample the model is run under ``torch.no_grad()``. When
    it returns at least one box, the source image is read from disk, every
    detection scoring at least ``opt.cls_threshold`` is drawn on it, and the
    result is saved as ``<opt.output>/<image stem>_prediction.jpg``.

    Side effects: an existing ``opt.output`` directory is deleted and
    recreated, and the sample index is printed for each sample. A CUDA device
    is required because the model and inputs are moved with ``.cuda()``.

    Args:
        opt: namespace providing ``pretrained_model``, ``data_path``,
            ``output`` and ``cls_threshold`` (as produced by ``get_args``).

    Raises:
        OSError: if a source image cannot be read by OpenCV.
    """
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only load checkpoints from a trusted source.
    model = torch.load(opt.pretrained_model)
    model.cuda()
    # Inference mode: dropout / batch-norm layers must not run in training mode.
    model.eval()

    dataset = CocoDataset(opt.data_path, set='val2017', transform=transforms.Compose([Normalizer(), Resizer()]))

    # Start from an empty output directory on every run.
    if os.path.isdir(opt.output):
        shutil.rmtree(opt.output)
    os.makedirs(opt.output)

    for index in range(len(dataset)):
        print(index)
        data = dataset[index]
        scale = data['scale']
        with torch.no_grad():
            # Last axis is moved to the front and a batch axis of size 1 is
            # added (presumably HWC -> 1xCxHxW; confirm against CocoDataset).
            scores, labels, boxes = model(data['img'].cuda().permute(2, 0, 1).float().unsqueeze(dim=0))
            # Undo the dataset's scale factor: boxes are drawn on the image
            # as read from disk, not on the transformed sample.
            boxes /= scale

        if boxes.shape[0] == 0:
            continue

        image_info = dataset.coco.loadImgs(dataset.image_ids[index])[0]
        path = os.path.join(dataset.root_dir, 'images', dataset.set_name, image_info['file_name'])
        output_image = cv2.imread(path)
        if output_image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError("Could not read image: {}".format(path))

        for box_id in range(boxes.shape[0]):
            pred_prob = float(scores[box_id])
            if pred_prob < opt.cls_threshold:
                # Skip rather than stop: correct whether or not the model
                # returns its detections sorted by score.
                continue
            _draw_detection(output_image, boxes[box_id, :], int(labels[box_id]), pred_prob)

        # splitext handles any extension length (".jpg", ".jpeg", ...).
        stem = os.path.splitext(image_info["file_name"])[0]
        cv2.imwrite("{}/{}_prediction.jpg".format(opt.output, stem), output_image)
if __name__ == "__main__":
    # Script entry point: parse the command-line options and hand them
    # straight to the evaluation/visualisation routine.
    test(get_args())