-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
53 lines (40 loc) · 1.54 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import base64
import torch
from io import BytesIO
from PIL import Image
from torchvision import transforms
import joblib
from utils import ctob, nms
class YoLo:
    """Run a TorchScript YOLO-style detector on one base64-encoded image.

    The image is decoded and preprocessed in ``__init__``; ``getbbox`` runs the
    model and post-processes its raw output with ``utils.ctob`` / ``utils.nms``.
    """

    # Model artefacts; relative paths are resolved against the current working
    # directory of the process. Override on a subclass to point elsewhere.
    MODEL_PATH = 'Model/model_398_trace.pt'
    CLASSES_PATH = 'Model/class_dict.pickle'
    # Side length of the square input fed to the model.
    INPUT_SIZE = 256

    # Process-wide caches keyed by path, so each artefact is read from disk
    # once instead of on every prediction.
    _model_cache = {}
    _classes_cache = {}

    def __init__(self, image):
        """Decode and preprocess *image*.

        Args:
            image: Base64-encoded image file contents (``str`` or ``bytes``).

        Sets:
            self.image: Float tensor of shape (1, 3, INPUT_SIZE, INPUT_SIZE).
            self.anchor: Anchor tensor of shape (3, 3, 2) handed to ``ctob``.
        """
        transformer = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(size=(self.INPUT_SIZE, self.INPUT_SIZE)),
            # Per-channel mean/std (the commonly used ImageNet statistics).
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        with Image.open(BytesIO(base64.b64decode(image))) as img:
            # Force exactly 3 channels: RGBA / palette / greyscale inputs would
            # otherwise produce a tensor that Normalize and the model reject.
            rgb = img.convert('RGB')
        # (3, H, W) -> (1, 3, H, W): a batch of one.
        self.image = transformer(rgb).unsqueeze(0)
        # Presumably 3 detection scales x 3 anchors x (w, h), normalised to the
        # input size -- TODO confirm against the training configuration.
        self.anchor = torch.tensor([[[0.2788, 0.2163], [0.3750, 0.4760], [0.8966, 0.7837]],
                                    [[0.0721, 0.1466], [0.1490, 0.1082], [0.1418, 0.2861]],
                                    [[0.0240, 0.0312], [0.0385, 0.0721], [0.0793, 0.0553]]])

    def load_model(self):
        """Return the TorchScript model in eval mode, loading it on first use."""
        model = YoLo._model_cache.get(self.MODEL_PATH)
        if model is None:
            # The input tensor is always on CPU, so load the weights there too
            # (also lets a model traced on GPU load on a CPU-only host).
            model = torch.jit.load(self.MODEL_PATH, map_location='cpu')
            model.eval()
            YoLo._model_cache[self.MODEL_PATH] = model
        return model

    def load_classes(self):
        """Return the class dictionary stored at ``CLASSES_PATH`` (cached).

        NOTE: ``joblib.load`` unpickles the file, so it must only ever point at
        a trusted, locally shipped artefact.
        """
        classes = YoLo._classes_cache.get(self.CLASSES_PATH)
        if classes is None:
            classes = joblib.load(self.CLASSES_PATH)
            YoLo._classes_cache[self.CLASSES_PATH] = classes
        return classes

    def predict(self):
        """Run the model on ``self.image`` without gradient tracking.

        Returns:
            The raw model output, unchanged.
        """
        model = self.load_model()
        model.eval()
        with torch.no_grad():
            return model(self.image)

    def getbbox(self):
        """Run the model and post-process its output into detections.

        Returns:
            Whatever ``utils.nms`` returns for the first (only) batch element.
        """
        prediction = self.predict()
        bboxes = ctob(prediction, self.anchor)
        # 0.05 / 0.5 are presumably a score threshold and an IoU threshold --
        # confirm the argument order against utils.nms.
        return nms(bboxes[0], self.load_classes(), 0.05, 0.5)
def predict_box(image):
    """Detect objects in a base64-encoded image.

    Args:
        image: Base64-encoded image file contents, passed straight to ``YoLo``.

    Returns:
        The detections produced by ``YoLo.getbbox`` for that image.
    """
    return YoLo(image).getbbox()