diff --git a/.gitignore b/.gitignore
index 7290d7c..dc0f659 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,6 +7,7 @@ Medical_Datasets/
 lfw/
 logs/
 model_data/
+.temp_map_out/
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/get_map.py b/get_map.py
index f4a0ecb..5b92c9f 100644
--- a/get_map.py
+++ b/get_map.py
@@ -15,11 +15,11 @@ if __name__ == "__main__":
     '''
-    Unlike AP, Recall and Precision are not area metrics, so the network's Recall and Precision differ at different threshold values.
-    The Recall and Precision reported by the mAP calculation are the values obtained at a prediction confidence threshold of 0.5.
+    Unlike AP, Recall and Precision are not area metrics, so the network's Recall and Precision differ at different threshold (Confidence) values.
+    By default, the Recall and Precision computed by this script are the values obtained at a threshold (Confidence) of 0.5.
 
-    The txt files obtained here in ./map_out/detection-results/ contain more boxes than a direct predict, because a low threshold is used here
-    in order to compute Recall and Precision under different thresholds and thereby calculate the mAP.
+    Because of how mAP is calculated, the network must obtain nearly all of its predicted boxes so that Recall and Precision can be computed under different thresholds.
+    Therefore, the txt files in map_out/detection-results/ produced by this script generally contain more boxes than a direct predict; the goal is to list every possible predicted box.
     '''
     #------------------------------------------------------------------------------------------------------------------#
     #   map_mode specifies what this script computes when it runs
@@ -30,16 +30,41 @@
     #   map_mode = 4: use the COCO toolbox to compute the 0.50:0.95 mAP of the current dataset. Requires the prediction results, the ground-truth boxes and an installed pycocotools.
     #-------------------------------------------------------------------------------------------------------------------#
     map_mode        = 0
-    #-------------------------------------------------------#
+    #--------------------------------------------------------------------------------------#
     #   classes_path here specifies the classes whose VOC mAP is measured
     #   In general it should match the classes_path used for training and prediction
-    #-------------------------------------------------------#
+    #--------------------------------------------------------------------------------------#
     classes_path    = 'model_data/voc_classes.txt'
-    #-------------------------------------------------------#
-    #   MINOVERLAP specifies the mAP0.x you want to obtain
+    #--------------------------------------------------------------------------------------#
+    #   MINOVERLAP specifies the mAP0.x you want to obtain; please look up what mAP0.x means if you are unsure.
     #   For example, to compute mAP0.75, set MINOVERLAP = 0.75.
-    #-------------------------------------------------------#
+    #
+    #   When a predicted box overlaps a ground-truth box by more than MINOVERLAP, the prediction is counted as a positive sample; otherwise it is a negative sample.
+    #   Therefore, the larger MINOVERLAP is, the more accurately a box must be predicted to count as a positive sample, and the lower the resulting mAP.
+    #--------------------------------------------------------------------------------------#
     MINOVERLAP      = 0.5
+    #--------------------------------------------------------------------------------------#
+    #   Because of how mAP is calculated, the network must obtain nearly all of its predicted boxes in order to compute the mAP at all.
+    #   Therefore, confidence should be set as low as possible to obtain every possible predicted box.
+    #
+    #   This value is normally left unchanged. Since computing mAP requires nearly all predicted boxes, the confidence here must not be modified casually.
+    #   To obtain Recall and Precision at a different threshold, modify score_threhold below.
+    #--------------------------------------------------------------------------------------#
+    confidence      = 0.001
+    #--------------------------------------------------------------------------------------#
+    #   The non-maximum-suppression IoU used during prediction; a larger value means less strict NMS.
+    #
+    #   This value is normally left unchanged.
+    #--------------------------------------------------------------------------------------#
+    nms_iou         = 0.5
+    #---------------------------------------------------------------------------------------------------------------#
+    #   Unlike AP, Recall and Precision are not area metrics, so they differ at different threshold values.
+    #
+    #   By default, the Recall and Precision computed by this script are the values obtained at a threshold of 0.5 (defined here as score_threhold).
+    #   Because computing mAP requires nearly all predicted boxes, the confidence defined above must not be modified casually.
+    #   A separate score_threhold is therefore defined to represent the threshold, so that the Recall and Precision corresponding to it can be reported during the mAP calculation.
+    #---------------------------------------------------------------------------------------------------------------#
+    score_threhold  = 0.5
     #-------------------------------------------------------#
     #   map_vis specifies whether to enable visualization of the VOC mAP calculation
     #-------------------------------------------------------#
@@ -69,7 +94,7 @@
     if map_mode == 0 or map_mode == 1:
         print("Load model.")
-        yolo = YOLO(confidence = 0.001, nms_iou = 0.5)
+        yolo = YOLO(confidence = confidence, nms_iou = nms_iou)
         print("Load model done.")
 
         print("Get predict result.")
@@ -109,7 +134,7 @@
     if map_mode == 0 or map_mode == 3:
         print("Get map.")
-        get_map(MINOVERLAP, True, path = map_out_path)
+        get_map(MINOVERLAP, True, score_threhold = score_threhold, path = map_out_path)
         print("Get map done.")
 
     if map_mode == 4:
diff --git a/train.py b/train.py
index e3cd43f..d013415 100644
--- a/train.py
+++ b/train.py
@@ -10,7 +10,7 @@
 from nets.yolo import get_train_model, yolo_body
 from nets.yolo_training import get_lr_scheduler
-from utils.callbacks import LossHistory, ModelCheckpoint
+from utils.callbacks import EvalCallback, LossHistory, ModelCheckpoint
 from utils.dataloader import YoloDatasets
 from utils.utils import get_anchors, get_classes, show_config
 from utils.utils_fit import fit_one_epoch
@@ -173,6 +173,17 @@
     #------------------------------------------------------------------#
     save_dir        = 'logs'
     #------------------------------------------------------------------#
+    #   eval_flag       whether to run evaluation during training; evaluation uses the validation set
+    #                   the evaluation experience is better with pycocotools installed.
+    #   eval_period     how many epochs between evaluations; frequent evaluation is not recommended,
+    #                   since evaluation takes considerable time and evaluating often makes training very slow
+    #   The mAP obtained here will differ from the one produced by get_map.py, for two reasons:
+    #   (1) the mAP obtained here is the mAP of the validation set.
+    #   (2) the evaluation parameters here are set conservatively in order to speed up evaluation.
+    #------------------------------------------------------------------#
+    eval_flag       = True
+    eval_period     = 10
+    #------------------------------------------------------------------#
     #   num_workers     whether to use multithreaded data loading; 1 disables multithreading
     #                   enabling it speeds up data loading but uses more memory
     #                   in keras, multithreading is sometimes actually much slower
@@ -335,6 +346,8 @@
         time_str        = datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S')
         log_dir         = os.path.join(save_dir, "loss_" + str(time_str))
         loss_history    = LossHistory(log_dir)
+        eval_callback   = EvalCallback(model_body, input_shape, anchors, anchors_mask, class_names, num_classes, val_lines, log_dir, \
+                                        eval_flag=eval_flag, period=eval_period)
         #---------------------------------------#
         #   Start model training
         #---------------------------------------#
@@ -386,7 +399,7 @@
                 lr = lr_scheduler_func(epoch)
                 K.set_value(optimizer.lr, lr)
 
-                fit_one_epoch(model_body, loss_history, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, 
+                fit_one_epoch(model_body, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, 
                             end_epoch, input_shape, anchors, anchors_mask, num_classes, save_period, save_dir, strategy)
 
                 train_dataloader.on_epoch_end()
@@ -419,7 +432,9 @@
                                 monitor = 'val_loss', save_weights_only = True, save_best_only = True, period = 1)
         early_stopping  = EarlyStopping(monitor='val_loss', min_delta = 0, patience = 10, verbose = 1)
         lr_scheduler    = LearningRateScheduler(lr_scheduler_func, verbose = 1)
-        callbacks       = [logging, loss_history, checkpoint, checkpoint_last, checkpoint_best, lr_scheduler]
+        eval_callback   = EvalCallback(model_body, input_shape, anchors, anchors_mask, class_names, num_classes, val_lines, log_dir, \
+                                        eval_flag=eval_flag, period=eval_period)
+        callbacks       = [logging, loss_history, checkpoint, checkpoint_last, checkpoint_best, lr_scheduler, eval_callback]
 
         if start_epoch < end_epoch:
             print('Train on {} samples, val on {} samples, with batch size {}.'.format(num_train, num_val, batch_size))
@@ -456,7 +471,7 @@
             #---------------------------------------#
             lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch)
             lr_scheduler    = LearningRateScheduler(lr_scheduler_func, verbose = 1)
-            callbacks       = [logging, loss_history, checkpoint, checkpoint_last, checkpoint_best, lr_scheduler]
+            callbacks       = [logging, loss_history, checkpoint, checkpoint_last, checkpoint_best, lr_scheduler, eval_callback]
 
             for i in range(len(model_body.layers)):
                 model_body.layers[i].trainable = True
diff --git a/utils/callbacks.py b/utils/callbacks.py
index 92d893e..b2a621c 100644
--- a/utils/callbacks.py
+++ b/utils/callbacks.py
@@ -4,10 +4,22 @@
 import matplotlib
 matplotlib.use('Agg')
 from matplotlib import pyplot as plt
-import numpy as np
 import scipy.signal
+
+import shutil
+import numpy as np
+import tensorflow as tf
+
 from tensorflow import keras
 from tensorflow.keras import backend as K
+from tensorflow.keras.layers import Input, Lambda
+from tensorflow.keras.models import Model
+from PIL import Image
+from tqdm import tqdm
+from .utils import cvtColor, preprocess_input, resize_image
+from .utils_bbox import DecodeBox
+from .utils_map import get_coco_map, get_map
+
 
 class LossHistory(keras.callbacks.Callback):
     def __init__(self, log_dir):
@@ -75,6 +87,169 @@ def on_epoch_end(self, batch, logs=None):
         if self.verbose > 0:
             print('Setting learning rate to %s.' % (learning_rate))
 
+class EvalCallback(keras.callbacks.Callback):
+    def __init__(self, model_body, input_shape, anchors, anchors_mask, class_names, num_classes, val_lines, log_dir,\
+            map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5, letterbox_image=True, MINOVERLAP=0.5, eval_flag=True, period=1):
+        super(EvalCallback, self).__init__()
+
+        self.model_body         = model_body
+        self.input_shape        = input_shape
+        self.anchors            = anchors
+        self.anchors_mask       = anchors_mask
+        self.class_names        = class_names
+        self.num_classes        = num_classes
+        self.val_lines          = val_lines
+        self.log_dir            = log_dir
+        self.map_out_path       = map_out_path
+        self.max_boxes          = max_boxes
+        self.confidence         = confidence
+        self.nms_iou            = nms_iou
+        self.letterbox_image    = letterbox_image
+        self.MINOVERLAP         = MINOVERLAP
+        self.eval_flag          = eval_flag
+        self.period             = period
+
+        #---------------------------------------------------------#
+        #   The DecodeBox function post-processes the raw predictions.
+        #   Post-processing includes decoding, non-maximum suppression, confidence filtering, etc.
+        #---------------------------------------------------------#
+        self.input_image_shape = Input([2,], batch_size=1)
+        inputs  = [*self.model_body.output, self.input_image_shape]
+        outputs = Lambda(
+            DecodeBox,
+            output_shape = (1,),
+            name = 'yolo_eval',
+            arguments = {
+                'anchors'           : self.anchors,
+                'num_classes'       : self.num_classes,
+                'input_shape'       : self.input_shape,
+                'anchor_mask'       : self.anchors_mask,
+                'confidence'        : self.confidence,
+                'nms_iou'           : self.nms_iou,
+                'max_boxes'         : self.max_boxes,
+                'letterbox_image'   : self.letterbox_image
+            }
+        )(inputs)
+        self.yolo_model = Model([self.model_body.input, self.input_image_shape], outputs)
+
+        self.maps       = [0]
+        self.epoches    = [0]
+        if self.eval_flag:
+            with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
+                f.write(str(0))
+                f.write("\n")
+
+    @tf.function
+    def get_pred(self, image_data, input_image_shape):
+        out_boxes, out_scores, out_classes = self.yolo_model([image_data, input_image_shape], training=False)
+        return out_boxes, out_scores, out_classes
+
+    def get_map_txt(self, image_id, image, class_names, map_out_path):
+        f = open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"), "w")
+        #---------------------------------------------------------#
+        #   Convert the image to RGB here to prevent grayscale images from raising errors during prediction.
+        #---------------------------------------------------------#
+        image       = cvtColor(image)
+        #---------------------------------------------------------#
+        #   Add gray bars to the image for a distortion-free resize.
+        #   A direct resize can also be used for detection.
+        #---------------------------------------------------------#
+        image_data  = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
+        #---------------------------------------------------------#
+        #   Add the batch_size dimension and normalize.
+        #---------------------------------------------------------#
+        image_data  = np.expand_dims(preprocess_input(np.array(image_data, dtype='float32')), 0)
+
+        #---------------------------------------------------------#
+        #   Feed the image into the network for prediction!
+        #---------------------------------------------------------#
+        input_image_shape = np.expand_dims(np.array([image.size[1], image.size[0]], dtype='float32'), 0)
+        outputs     = self.get_pred(image_data, input_image_shape)
+        out_boxes, out_scores, out_classes = [out.numpy() for out in outputs]
+
+        top_100     = np.argsort(out_scores)[::-1][:self.max_boxes]
+        out_boxes   = out_boxes[top_100]
+        out_scores  = out_scores[top_100]
+        out_classes = out_classes[top_100]
+
+        for i, c in enumerate(out_classes):
+            predicted_class             = self.class_names[int(c)]
+            try:
+                score                   = str(out_scores[i].numpy())
+            except:
+                score                   = str(out_scores[i])
+            top, left, bottom, right    = out_boxes[i]
+            if predicted_class not in class_names:
+                continue
+
+            f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)), str(int(bottom))))
+
+        f.close()
+        return
+
+    def on_epoch_end(self, epoch, logs=None):
+        temp_epoch = epoch + 1
+        if temp_epoch % self.period == 0 and self.eval_flag:
+            if not os.path.exists(self.map_out_path):
+                os.makedirs(self.map_out_path)
+            if not os.path.exists(os.path.join(self.map_out_path, "ground-truth")):
+                os.makedirs(os.path.join(self.map_out_path, "ground-truth"))
+            if not os.path.exists(os.path.join(self.map_out_path, "detection-results")):
+                os.makedirs(os.path.join(self.map_out_path, "detection-results"))
+            print("Get map.")
+            for annotation_line in tqdm(self.val_lines):
+                line        = annotation_line.split()
+                image_id    = os.path.basename(line[0]).split('.')[0]
+                #------------------------------#
+                #   Read the image and convert it to RGB.
+                #------------------------------#
+                image       = Image.open(line[0])
+                #------------------------------#
+                #   Get the ground-truth boxes.
+                #------------------------------#
+                gt_boxes    = np.array([np.array(list(map(int, box.split(',')))) for box in line[1:]])
+                #------------------------------#
+                #   Generate the prediction txt.
+                #------------------------------#
+                self.get_map_txt(image_id, image, self.class_names, self.map_out_path)
+
+                #------------------------------#
+                #   Generate the ground-truth txt.
+                #------------------------------#
+                with open(os.path.join(self.map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f:
+                    for box in gt_boxes:
+                        left, top, right, bottom, obj = box
+                        obj_name = self.class_names[obj]
+                        new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
+
+            print("Calculate Map.")
+            try:
+                temp_map = get_coco_map(class_names = self.class_names, path = self.map_out_path)[1]
+            except:
+                temp_map = get_map(self.MINOVERLAP, False, path = self.map_out_path)
+            self.maps.append(temp_map)
+            self.epoches.append(temp_epoch)
+
+            with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
+                f.write(str(temp_map))
+                f.write("\n")
+
+            plt.figure()
+            plt.plot(self.epoches, self.maps, 'red', linewidth = 2, label='train map')
+
+            plt.grid(True)
+            plt.xlabel('Epoch')
+            plt.ylabel('Map %s'%str(self.MINOVERLAP))
+            plt.title('A Map Curve')
+            plt.legend(loc="upper right")
+
+            plt.savefig(os.path.join(self.log_dir, "epoch_map.png"))
+            plt.cla()
+            plt.close("all")
+
+            print("Get map done.")
+            shutil.rmtree(self.map_out_path)
+
 class ModelCheckpoint(keras.callbacks.Callback):
     def __init__(self, filepath, monitor='val_loss', verbose=0,
                  save_best_only=False, save_weights_only=False,
diff --git a/utils/utils_fit.py b/utils/utils_fit.py
index d7b5b09..a203915 100644
--- a/utils/utils_fit.py
+++ b/utils/utils_fit.py
@@ -74,7 +74,7 @@ def distributed_val_step(images, targets, net, optimizer):
             axis=None)
         return distributed_val_step
 
-def fit_one_epoch(net, loss_history, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, 
+def fit_one_epoch(net, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, 
                 input_shape, anchors, anchors_mask, num_classes, save_period, save_dir, strategy):
     train_step  = get_train_step_fn(input_shape, anchors, anchors_mask, num_classes, strategy)
 
@@ -113,6 +113,7 @@ def fit_one_epoch(net, loss_history, optimizer, epoch, epoch_step, epoch_step_va
 
     logs = {'loss': loss.numpy() / epoch_step, 'val_loss': val_loss.numpy() / epoch_step_val}
     loss_history.on_epoch_end([], logs)
+    eval_callback.on_epoch_end(epoch, logs)
     print('Epoch:'+ str(epoch+1) + '/' + str(Epoch))
     print('Total Loss: %.3f || Val Loss: %.3f ' % (loss / epoch_step, val_loss / epoch_step_val))
 
diff --git a/utils/utils_map.py b/utils/utils_map.py
index 7b69140..c777322 100644
--- a/utils/utils_map.py
+++ b/utils/utils_map.py
@@ -5,9 +5,15 @@ import os
 import shutil
 import sys
-
+try:
+    from pycocotools.coco import COCO
+    from pycocotools.cocoeval import COCOeval
+except:
+    pass
 import cv2
-import matplotlib.pyplot as plt
+import matplotlib
+matplotlib.use('Agg')
+from matplotlib import pyplot as plt
 import numpy as np
 
 '''
@@ -267,7 +273,7 @@ def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, out
     # close the plot
     plt.close()
 
-def get_map(MINOVERLAP, draw_plot, path = './map_out'):
+def get_map(MINOVERLAP, draw_plot, score_threhold=0.5, path = './map_out'):
     GT_PATH = os.path.join(path, 'ground-truth')
     DR_PATH = os.path.join(path, 'detection-results')
     IMG_PATH = os.path.join(path, 'images-optional')
@@ -287,7 +293,13 @@ def get_map(MINOVERLAP, draw_plot, path = './map_out'):
 
     if os.path.exists(RESULTS_FILES_PATH):
        shutil.rmtree(RESULTS_FILES_PATH)
+    else:
+        os.makedirs(RESULTS_FILES_PATH)
     if draw_plot:
+        try:
+            matplotlib.use('TkAgg')
+        except:
+            pass
         os.makedirs(os.path.join(RESULTS_FILES_PATH, "AP"))
         os.makedirs(os.path.join(RESULTS_FILES_PATH, "F1"))
         os.makedirs(os.path.join(RESULTS_FILES_PATH, "Recall"))
@@ -421,12 +433,12 @@ def get_map(MINOVERLAP, draw_plot, path = './map_out'):
         tp = [0] * nd
         fp = [0] * nd
         score = [0] * nd
-        score05_idx = 0
+        score_threhold_idx = 0
         for idx, detection in enumerate(dr_data):
             file_id = detection["file_id"]
             score[idx] = float(detection["confidence"])
-            if score[idx] > 0.5:
-                score05_idx = idx
+            if score[idx] >= score_threhold:
+                score_threhold_idx = idx
 
             if show_animation:
                 ground_truth_img = glob.glob1(IMG_PATH, file_id + ".*")
@@ -564,9 +576,9 @@ def get_map(MINOVERLAP, draw_plot, path = './map_out'):
             text = "{0:.2f}%".format(ap*100) + " = " + class_name + " AP " #class_name + " AP = {0:.2f}%".format(ap*100)
 
             if len(prec)>0:
-                F1_text = "{0:.2f}".format(F1[score05_idx]) + " = " + class_name + " F1 "
-                Recall_text = "{0:.2f}%".format(rec[score05_idx]*100) + " = " + class_name + " Recall "
-                Precision_text = "{0:.2f}%".format(prec[score05_idx]*100) + " = " + class_name + " Precision "
+                F1_text = "{0:.2f}".format(F1[score_threhold_idx]) + " = " + class_name + " F1 "
+                Recall_text = "{0:.2f}%".format(rec[score_threhold_idx]*100) + " = " + class_name + " Recall "
+                Precision_text = "{0:.2f}%".format(prec[score_threhold_idx]*100) + " = " + class_name + " Precision "
             else:
                 F1_text = "0.00" + " = " + class_name + " F1 "
                 Recall_text = "0.00%" + " = " + class_name + " Recall "
@@ -575,11 +587,12 @@ def get_map(MINOVERLAP, draw_plot, path = './map_out'):
             rounded_prec = [ '%.2f' % elem for elem in prec ]
             rounded_rec = [ '%.2f' % elem for elem in rec ]
             results_file.write(text + "\n Precision: " + str(rounded_prec) + "\n Recall :" + str(rounded_rec) + "\n\n")
+
             if len(prec)>0:
-                print(text + "\t||\tscore_threhold=0.5 : " + "F1=" + "{0:.2f}".format(F1[score05_idx])\
-                    + " ; Recall=" + "{0:.2f}%".format(rec[score05_idx]*100) + " ; Precision=" + "{0:.2f}%".format(prec[score05_idx]*100))
+                print(text + "\t||\tscore_threhold=" + str(score_threhold) + " : " + "F1=" + "{0:.2f}".format(F1[score_threhold_idx])\
+                    + " ; Recall=" + "{0:.2f}%".format(rec[score_threhold_idx]*100) + " ; Precision=" + "{0:.2f}%".format(prec[score_threhold_idx]*100))
             else:
-                print(text + "\t||\tscore_threhold=0.5 : F1=0.00% ; Recall=0.00% ; Precision=0.00%")
+                print(text + "\t||\tscore_threhold=" + str(score_threhold) + " : " + "F1=0.00% ; Recall=0.00% ; Precision=0.00%")
             ap_dictionary[class_name] = ap
 
             n_images = counter_images_per_class[class_name]
@@ -605,7 +618,7 @@ def get_map(MINOVERLAP, draw_plot, path = './map_out'):
 
                 plt.cla()
                 plt.plot(score, F1, "-", color='orangered')
-                plt.title('class: ' + F1_text + "\nscore_threhold=0.5")
+                plt.title('class: ' + F1_text + "\nscore_threhold=" + str(score_threhold))
                 plt.xlabel('Score_Threhold')
                 plt.ylabel('F1')
                 axes = plt.gca()
@@ -615,7 +628,7 @@ def get_map(MINOVERLAP, draw_plot, path = './map_out'):
 
                 plt.cla()
                 plt.plot(score, rec, "-H", color='gold')
-                plt.title('class: ' + Recall_text + "\nscore_threhold=0.5")
+                plt.title('class: ' + Recall_text + "\nscore_threhold=" + str(score_threhold))
                 plt.xlabel('Score_Threhold')
                 plt.ylabel('Recall')
                 axes = plt.gca()
@@ -625,7 +638,7 @@ def get_map(MINOVERLAP, draw_plot, path = './map_out'):
 
                 plt.cla()
                 plt.plot(score, prec, "-s", color='palevioletred')
-                plt.title('class: ' + Precision_text + "\nscore_threhold=0.5")
+                plt.title('class: ' + Precision_text + "\nscore_threhold=" + str(score_threhold))
                 plt.xlabel('Score_Threhold')
                 plt.ylabel('Precision')
                 axes = plt.gca()
@@ -636,7 +649,9 @@ def get_map(MINOVERLAP, draw_plot, path = './map_out'):
 
     if show_animation:
         cv2.destroyAllWindows()
-
+    if n_classes == 0:
+        print("No classes were detected; check that the label files and the classes_path in get_map.py are set correctly.")
+        return 0
     results_file.write("\n# mAP of all classes\n")
     mAP = sum_AP / n_classes
     text = "mAP = {0:.2f}%".format(mAP*100)
@@ -780,6 +795,7 @@ def get_map(MINOVERLAP, draw_plot, path = './map_out'):
         plot_color,
         ""
         )
+    return mAP
 
 def preprocess_gt(gt_path, class_names):
     image_ids = os.listdir(gt_path)
@@ -820,6 +836,8 @@ def preprocess_gt(gt_path, class_names):
                 class_name = class_name[:-1]
             left, top, right, bottom = float(left), float(top), float(right), float(bottom)
+            if class_name not in class_names:
+                continue
             cls_id  = class_names.index(class_name) + 1
             bbox    = [left, top, right - left, bottom - top, difficult, str(image_id), cls_id, (right - left) * (bottom - top) - 10.0]
             boxes_per_image.append(bbox)
@@ -865,6 +883,8 @@ def preprocess_dr(dr_path, class_names):
             left, top, right, bottom = float(left), float(top), float(right), float(bottom)
             result                  = {}
             result["image_id"]      = str(image_id)
+            if class_name not in class_names:
+                continue
             result["category_id"]   = class_names.index(class_name) + 1
             result["bbox"]          = [left, top, right - left, bottom - top]
             result["score"]         = float(confidence)
@@ -872,9 +892,6 @@ def preprocess_dr(dr_path, class_names):
     return results
 
 def get_coco_map(class_names, path):
-    from pycocotools.coco import COCO
-    from pycocotools.cocoeval import COCOeval
-
     GT_PATH     = os.path.join(path, 'ground-truth')
     DR_PATH     = os.path.join(path, 'detection-results')
     COCO_PATH   = os.path.join(path, 'coco_eval')
@@ -892,6 +909,9 @@ def get_coco_map(class_names, path):
     with open(DR_JSON_PATH, "w") as f:
         results_dr  = preprocess_dr(DR_PATH, class_names)
         json.dump(results_dr, f, indent=4)
+        if len(results_dr) == 0:
+            print("No objects were detected.")
+            return [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
 
     cocoGt      = COCO(GT_JSON_PATH)
     cocoDt      = cocoGt.loadRes(DR_JSON_PATH)
@@ -899,3 +919,5 @@ def get_coco_map(class_names, path):
     cocoEval.evaluate()
     cocoEval.accumulate()
     cocoEval.summarize()
+
+    return cocoEval.stats
\ No newline at end of file
diff --git a/yolo.py b/yolo.py
index 313d5c4..ec30622 100644
--- a/yolo.py
+++ b/yolo.py
@@ -288,7 +288,6 @@ def sigmoid(x):
         print("Save to the " + heatmap_save_path)
         plt.show()
-
     #---------------------------------------------------#
     #   Detect images
    #---------------------------------------------------#
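
Note on the confidence / score_threhold split introduced in get_map.py: detections are collected at a very low confidence (0.001) because AP integrates the whole precision-recall curve, while the reported Recall, Precision and F1 are read off at a single operating point (score_threhold). The sketch below is a standalone illustration of that distinction, not the repository's get_map implementation; the (score, is_true_positive) pairs are assumed to come from IoU matching against ground truth at MINOVERLAP, and the helper name is hypothetical.

import numpy as np

def summarize(detections, num_gt, score_threhold=0.5):
    # detections: list of (score, is_true_positive) pairs for one class.
    # Sort by descending confidence, as the mAP computation does.
    detections = sorted(detections, key=lambda d: d[0], reverse=True)
    scores = np.array([d[0] for d in detections], dtype=np.float32)
    tp     = np.array([d[1] for d in detections], dtype=np.float32)

    # Cumulative precision/recall over the score-sorted detections.
    cum_tp = np.cumsum(tp)
    cum_fp = np.cumsum(1.0 - tp)
    recall    = cum_tp / max(num_gt, 1)
    precision = cum_tp / np.maximum(cum_tp + cum_fp, 1e-6)

    # Operating point: last detection (in descending-score order) whose score
    # is still >= score_threhold, mirroring score_threhold_idx in the diff.
    keep = scores >= score_threhold
    if keep.any():
        idx = np.where(keep)[0][-1]
        point = (float(recall[idx]), float(precision[idx]))
    else:
        point = (0.0, 0.0)

    # AP: area under the precision-recall curve ("all-points" interpolation).
    mrec = np.concatenate(([0.0], recall, [1.0]))
    mpre = np.concatenate(([0.0], precision, [0.0]))
    mpre = np.maximum.accumulate(mpre[::-1])[::-1]
    ap   = float(np.sum((mrec[1:] - mrec[:-1]) * mpre[1:]))
    return point, ap

# Example: 6 detections for a class with 4 ground-truth boxes.
dets = [(0.95, 1), (0.90, 1), (0.60, 0), (0.55, 1), (0.30, 0), (0.05, 1)]
(recall_05, precision_05), ap = summarize(dets, num_gt=4)
print(recall_05, precision_05, ap)

Raising score_threhold moves the operating point but leaves AP unchanged, which is why the diff keeps confidence fixed at 0.001 and exposes score_threhold separately.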
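
EvalCallback appends one mAP value per evaluation (plus an initial 0) to epoch_map.txt under logs/loss_<timestamp>/ and saves epoch_map.png. If you want to re-plot that curve after training, a small helper along these lines should work; the function name, the epoch axis reconstructed from eval_period, and the output file name are assumptions for illustration only.

import os
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt

def replot_epoch_map(log_dir, eval_period=10, minoverlap=0.5):
    # epoch_map.txt holds one mAP value per line; the first entry is the initial 0.
    with open(os.path.join(log_dir, "epoch_map.txt")) as f:
        maps = [float(line) for line in f if line.strip()]
    # Evaluations happen every eval_period epochs, so rebuild the epoch axis from that.
    epochs = [i * eval_period for i in range(len(maps))]

    plt.figure()
    plt.plot(epochs, maps, 'red', linewidth=2, label='val map')
    plt.grid(True)
    plt.xlabel('Epoch')
    plt.ylabel('Map %s' % minoverlap)
    plt.legend(loc="upper right")
    plt.savefig(os.path.join(log_dir, "epoch_map_replot.png"))
    plt.close("all")

# Example (adjust the directory name to your own run):
# replot_epoch_map("logs/loss_2023_01_01_00_00_00", eval_period=10)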