forked from thtrieu/darkflow
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathDarkflowModel.py
90 lines (81 loc) · 3.12 KB
/
DarkflowModel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#!/usr/bin/env python
# Imports
import os

from darkflow.net.build import TFNet
import tensorflow as tf
class DarkflowModel:
    """Wrapper around darkflow's ``TFNet`` for loading a network and running inference.

    The network is described by a darkflow model config. Weights come from exactly one
    source, chosen in this priority order: a frozen protobuf (``model_pb`` plus its
    ``.meta`` file), a weights file (``weights``), or a checkpoint (``checkpoint``).
    """

    def __init__(self,
                 model_config,
                 threshold: float,
                 labels='./labels.txt',
                 gpu_usage: float = 0.5,
                 weights=None,
                 model_pb=None,
                 model_meta=None,
                 checkpoint=None,
                 construct: bool = True):
        """Record the configuration, read the label file and optionally build the network.

        Args:
            model_config: Darkflow model config, passed to TFNet as ``model``.
            threshold: Detection threshold, passed to TFNet as ``threshold``.
            labels: Path to the label file (one label per line).
            gpu_usage: GPU usage value, passed to TFNet as ``gpu``.
            weights: Weights file to load; used when ``model_pb`` is not given.
            model_pb: Frozen graph (.pb) to load; takes priority over weights/checkpoint.
            model_meta: ``.meta`` file paired with ``model_pb``. If omitted, the .pb path
                with its extension replaced by ``.meta`` is used.
            checkpoint: Checkpoint to load; used when neither ``model_pb`` nor ``weights``
                is given.
            construct: If True, build the TFNet immediately via ``construct()``.

        Raises:
            ValueError: If none of ``model_pb``, ``weights`` or ``checkpoint`` is given.
            OSError: If the label file cannot be opened.
        """
        self.print_this("BEGIN INITIALIZATION")
        self.model_config = model_config
        self.labels = labels
        self.model_pb = model_pb
        self.model_meta = model_meta
        self.use_pb_load = False
        # Always bound, so construct() never hits a missing attribute.
        self.load = None

        if model_pb is not None:
            if model_meta is None:
                # Darkflow's protobuf path needs both files; assume the .meta sits next
                # to the .pb with the same stem (as darkflow's --savepb writes them).
                self.model_meta = os.path.splitext(model_pb)[0] + '.meta'
            self.use_pb_load = True
            weights_info = "Loading from protobuf file: " + str(self.model_pb)
        elif weights is not None:
            self.load = weights
            weights_info = "Loading from weights file: " + str(self.load)
        elif checkpoint is not None:
            self.load = checkpoint
            weights_info = "Loading from checkpoint: " + str(self.load)
        else:
            # Previously this fell through to an UnboundLocalError/AttributeError.
            raise ValueError(
                "One of model_pb, weights or checkpoint must be provided")

        self.id_dict = self.load_id_dict(self.labels)
        self.threshold = threshold
        self.gpu_usage = gpu_usage
        self.model = None
        if construct:
            self.construct()
        self.print_this("Loading from config file: " + str(self.model_config))
        self.print_this("Loading labels from: " + str(self.labels))
        self.print_this("Detection Threshold: " + str(self.threshold))
        self.print_this("GPU Usage: " + str(self.gpu_usage))
        self.print_this(weights_info)
        self.print_this("INITIALIZATION COMPLETE")

    def load_id_dict(self, labels):
        """Map each label (one per line, whitespace-stripped) to its zero-based line index.

        If a label appears more than once, the last occurrence wins.
        """
        with open(labels, 'r') as lf:
            return {line.strip(): idx for idx, line in enumerate(lf)}

    def _build_options(self):
        """Return the options dict handed to ``TFNet``.

        Darkflow only takes the protobuf route when both camelCase keys ``pbLoad``
        and ``metaLoad`` are set; otherwise the network is loaded via ``load``.
        """
        options = {
            'model': self.model_config,
            'labels': self.labels,
            'threshold': self.threshold,
            'gpu': self.gpu_usage,
        }
        if self.use_pb_load:
            options['pbLoad'] = self.model_pb
            options['metaLoad'] = self.model_meta
        else:
            options['load'] = self.load
        return options

    def construct(self):
        """Build the darkflow ``TFNet`` from the stored options and store it in ``self.model``."""
        self.print_this("Constructing network")
        config = tf.ConfigProto(log_device_placement=True)
        config.gpu_options.allow_growth = True
        # NOTE(review): this session is never handed to TFNet (which opens its own);
        # it is kept because opening it first may influence process-wide GPU setup.
        # Confirm before removing.
        with tf.Session(config=config):
            if self.use_pb_load:
                self.print_this("from protobuf")
            self.model = TFNet(self._build_options())
        self.print_this("Network Contructed")

    def infer_one(self, im):
        """Run ``TFNet.return_predict`` on a single image and return its result.

        ``im`` is passed through unchanged (presumably a decoded image array as
        darkflow expects — confirm with callers).

        Raises:
            RuntimeError: If the network has not been built yet.
        """
        if self.model is None:
            raise RuntimeError("Network not constructed; call construct() first")
        self.print_this("Inferring on image... ")
        result = self.model.return_predict(im)
        self.print_this("Inference complete!")
        return result

    def print_this(self, to_print):
        """Print ``to_print`` prefixed with the ``[DARKFLOW-MODEL]`` tag."""
        print("[DARKFLOW-MODEL]: ", end="")
        print(to_print)