-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtest.py
71 lines (52 loc) · 2.23 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
import numpy as np
import cv2
from utils.datasets import letterbox
from utils.general import non_max_suppression, scale_coords
from utils.plots import Annotator
# Run a trained detector (crosswalk / red light / green light) over a video:
# each frame is letterboxed, passed through the model, filtered with NMS,
# annotated, shown in a preview window, and written to data/output.mp4.
# Press "q" in the preview window to stop early.

MODEL_PATH = 'runs/train/exp4/weights/best.pt'
img_size = 640  # target size passed to letterbox()
conf_thres = 0.5  # confidence threshold
iou_thres = 0.45  # NMS IoU threshold
max_det = 1000  # maximum detections per image
classes = None  # filter by class (None keeps every class)
agnostic_nms = False  # class-agnostic NMS

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# NOTE(review): torch.load unpickles arbitrary Python objects, so only load
# trusted checkpoints. This checkpoint stores a whole model object under
# 'ema' / 'model'; newer torch releases default to weights_only=True and may
# need weights_only=False here -- confirm against the installed version.
ckpt = torch.load(MODEL_PATH, map_location=device)
# Use the EMA weights when present, otherwise the plain model; cast to fp32,
# fuse layers and switch to eval mode for inference.
model = ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()

# Labels indexed by class id: crosswalk, red light, green light (Korean text,
# drawn with the font data/malgun.ttf below). Maintained by hand instead of
# reading model.names -- keep the order in sync with the training classes.
class_names = ['횡단보도', '빨간불', '초록불']  # model.names
stride = int(model.stride.max())
colors = ((50, 50, 50), (0, 0, 255), (0, 255, 0))  # BGR: (gray, red, green)


def _preprocess(frame):
    """Turn a BGR uint8 HxWx3 frame into a (1, 3, H', W') float tensor in [0, 1].

    The frame is letterboxed to ``img_size`` (stride-aligned), converted from
    HWC to CHW, its channels reversed (BGR -> RGB), and moved to ``device``.
    """
    padded = letterbox(frame, img_size, stride=stride)[0]
    # HWC -> CHW, then reverse the channel axis so BGR becomes RGB.
    chw_rgb = np.ascontiguousarray(padded.transpose((2, 0, 1))[::-1])
    tensor = torch.from_numpy(chw_rgb).to(device).float() / 255.0
    return tensor.unsqueeze(0)  # add the batch dimension


cap = cv2.VideoCapture('data/sample.mp4')
if not cap.isOpened():
    raise SystemExit("cannot open input video 'data/sample.mp4'")
# CAP_PROP_FPS reads as 0 when the backend cannot report it; fall back to
# 30 FPS rather than creating a 0-FPS writer.
fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
frame_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
              int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
out = cv2.VideoWriter('data/output.mp4', fourcc, fps, frame_size)

try:
    # Inference only: skip building an autograd graph for every frame.
    with torch.no_grad():
        while cap.isOpened():
            ret, img = cap.read()
            if not ret:
                break  # end of stream or read failure

            # preprocess
            img_input = _preprocess(img)

            # inference
            pred = model(img_input, augment=False, visualize=False)[0]

            # postprocess: NMS for the single image in the batch, then map the
            # boxes from letterboxed coordinates back onto the original frame.
            pred = non_max_suppression(pred, conf_thres, iou_thres, classes,
                                       agnostic_nms, max_det=max_det)[0]
            pred = pred.cpu().numpy()
            pred[:, :4] = scale_coords(img_input.shape[2:], pred[:, :4], img.shape).round()

            annotator = Annotator(img.copy(), line_width=3, example=str(class_names),
                                  font='data/malgun.ttf')
            # Each detection row: x1, y1, x2, y2, confidence, class id.
            for *xyxy, conf, cls_id in pred:
                cls_idx = int(cls_id)
                label = '%s %d' % (class_names[cls_idx], float(conf) * 100)
                annotator.box_label(xyxy, label, color=colors[cls_idx])
            result_img = annotator.result()

            cv2.imshow('result', result_img)
            out.write(result_img)
            # Mask to the low byte; waitKey may set extra high bits on some platforms.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
finally:
    # Always release the capture and writer (finalizing the output file) and
    # close the preview window, even if inference raises.
    cap.release()
    out.release()
    cv2.destroyAllWindows()