Skip to content

Commit

Permalink
update the dataset API
Browse files Browse the repository at this point in the history
  • Loading branch information
nico-zck committed Apr 23, 2021
1 parent beb714c commit f3e6cec
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 58 deletions.
2 changes: 1 addition & 1 deletion dataset_api/eval_pixel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
@Author : Nico
"""
from collections import OrderedDict
from typing import Tuple

import numpy as np
from typing import Tuple


def _compute_confusion_matrix(binary_target: np.ndarray, binary_pred: np.ndarray):
Expand Down
4 changes: 2 additions & 2 deletions dataset_api/eval_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def __f1(pre, rec):


def evaluation_region(binary_pixel_pred: np.ndarray, binary_pixel_target: np.ndarray,
diff_scale: bool = False, return_info: bool = False):
diff_size: bool = False, return_info: bool = False):
binary_pixel_pred = binary_pixel_pred.squeeze()
binary_pixel_target = binary_pixel_target.squeeze()

Expand Down Expand Up @@ -288,7 +288,7 @@ def evaluation_region(binary_pixel_pred: np.ndarray, binary_pixel_target: np.nda
region_metrics['Rec'] = _recall_for_region_target(info_target_overlap, iou_thresh=IOUs,
ol_thresh=OVERLAP_THRESH)
region_metrics['F1'] = __f1(region_metrics['Pre'], region_metrics['Rec'])
if diff_scale:
if diff_size:
region_metrics['Pre_small'] = _precision_for_region_pred(info_pred_overlap, iou_thresh=IOUs,
ol_thresh=OVERLAP_THRESH, scale='small')
region_metrics['Rec_small'] = _recall_for_region_target(info_target_overlap, iou_thresh=IOUs,
Expand Down
50 changes: 34 additions & 16 deletions dataset_api/zl_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ class ZLEval:
Evaluator for the ZJU-Leaper Dataset.
"""

def __init__(self, binary_pixel_target: np.ndarray, binary_pixel_pred: np.ndarray):
def __init__(self, binary_pixel_target: np.ndarray, binary_pixel_pred: np.ndarray,
eval_diff_size: bool = True):
"""
Create an object to evaluate an inspection algorithm.
:param binary_pixel_target: binarized pixel-wise ground truths.
Expand All @@ -39,6 +40,8 @@ def __init__(self, binary_pixel_target: np.ndarray, binary_pixel_pred: np.ndarra
self.binary_pixel_target = binary_pixel_target
self.binary_pixel_pred = binary_pixel_pred

self.eval_diff_size = eval_diff_size

def evaluate(self) -> dict:
"""
Calculate all metrics on the results and return metrics as a dict.
Expand All @@ -50,30 +53,45 @@ def evaluate(self) -> dict:
binary_pixel_target=self.binary_pixel_target)
region_metrics, info_region = evaluation_region(binary_pixel_pred=self.binary_pixel_pred,
binary_pixel_target=self.binary_pixel_target,
diff_scale=True, return_info=True)
diff_size=self.eval_diff_size, return_info=True)
sample_metrics = evaluation_sample(binary_pixel_pred=self.binary_pixel_pred,
binary_pixel_target=self.binary_pixel_target, info_region=info_region)
summary_score = 0.4 * pixel_metrics['Dice'] + 0.4 * region_metrics['F1'] + 0.2 * sample_metrics['F1']

self.metrics_dict = OrderedDict(
Pix_Dice=pixel_metrics['Dice'],
Reg_F1=region_metrics['F1'],
Sam_F1=sample_metrics['F1'],

## other metrics
Reg_Pre_S=region_metrics['Pre_small'],
Reg_Rec_S=region_metrics['Rec_small'],
Reg_Pre_M=region_metrics['Pre_medium'],
Reg_Rec_M=region_metrics['Rec_medium'],
Reg_Pre_L=region_metrics['Pre_large'],
Reg_Rec_L=region_metrics['Rec_large'],
Reg_F1_S=region_metrics['F1_small'],
Reg_F1_M=region_metrics['F1_medium'],
Reg_F1_L=region_metrics['F1_large'],
# Pix_Pre=pixel_metrics['Pre'],
# Pix_Rec=pixel_metrics['Rec'],
F1_Pix=pixel_metrics['Dice'],

# Reg_Pre=region_metrics['Pre'],
# Reg_Rec=region_metrics['Rec'],
F1_Reg=region_metrics['F1'],

# Sam_Pre=sample_metrics['Pre'],
# Sam_Rec=sample_metrics['Rec'],
F1_Sam=sample_metrics['F1'],
# Sam_Acc=sample_metrics['Acc'],
# Sam_FPR=sample_metrics['FPR'],

SCORE=summary_score,
)

if self.eval_diff_size:
self.metrics_dict.update(dict(
## other metrics
Reg_Pre_S=region_metrics['Pre_small'],
Reg_Rec_S=region_metrics['Rec_small'],
Reg_F1_S=region_metrics['F1_small'],

Reg_Pre_M=region_metrics['Pre_medium'],
Reg_Rec_M=region_metrics['Rec_medium'],
Reg_F1_M=region_metrics['F1_medium'],

Reg_Pre_L=region_metrics['Pre_large'],
Reg_Rec_L=region_metrics['Rec_large'],
Reg_F1_L=region_metrics['F1_large'],
))

return self.metrics_dict

def summarize(self):
Expand Down
79 changes: 40 additions & 39 deletions dataset_api/zl_fabric.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ def info(self) -> dict:
:return: {'pattern': int, 'id': int, 'defective': boolean}
"""
xml_path = self.xml_path.format(id=self.id)
p_id = int(ET.parse(xml_path).find('pattern').text)
p_id = int(ET.parse(xml_path).find('pattern_id').text)
defect_flag = bool(int(ET.parse(xml_path).find('defective').text))
info = {'pattern': p_id, 'id': self.id, 'defective': defect_flag}
info = {'pattern_id': p_id, 'id': self.id, 'defective': defect_flag}
return info

def annotation(self, ann_type: str = None):
Expand Down Expand Up @@ -139,7 +139,7 @@ def annotation(self, ann_type: str = None):


class ZLFabric:
def __init__(self, dir: str, fabric: Union[str, int], setting: Union[str, int], seed: int = None):
def __init__(self, dir: str, fabric: Union[str, int], setting: Union[str, int], seed: int = 0):
"""
Create an object to manage ZJU-Leaper dataset.
Expand Down Expand Up @@ -178,9 +178,9 @@ def __init__(self, dir: str, fabric: Union[str, int], setting: Union[str, int],

self.rnd = random.Random(seed)

def create_zl_imgs_given_ids(self, ids: list, subset: str, ann_type: str) -> ZLIMGS:
def _create_zl_imgs_given_ids(self, ids: list, subset: str, ann_type: str) -> ZLIMGS:
"""
Create ZLImage objects given the image IDs
:param ids:
:param subset: ["none", "small", "train", "dev", "test"]
:param ann_type: ["none", "label", "bbox", "mask"]
Expand All @@ -192,6 +192,7 @@ def create_zl_imgs_given_ids(self, ids: list, subset: str, ann_type: str) -> ZLI
if subset == 'none':
ids = []
elif subset == 'small':
### WARNING: the alteration of random seed will effect the sample of "small" subset
ids = self.rnd.sample(ids, len(ids) // 10)
else:
pass
Expand Down Expand Up @@ -219,20 +220,20 @@ def prepare_train(self) -> Tuple[ZLIMGS, ZLIMGS, ZLIMGS, ZLIMGS]:
ids_train_defect = ids_json['defect']['train']

# train
zlimgs_train_normal = self.create_zl_imgs_given_ids(ids=ids_train_normal,
subset=CONFIG[self.setting]['normal_train'],
ann_type=CONFIG[self.setting]['ann_train'])
zlimgs_train_defect = self.create_zl_imgs_given_ids(ids=ids_train_defect,
subset=CONFIG[self.setting]['defect_train'],
ann_type=CONFIG[self.setting]['ann_train'])
zlimgs_train_normal = self._create_zl_imgs_given_ids(ids=ids_train_normal,
subset=CONFIG[self.setting]['normal_train'],
ann_type=CONFIG[self.setting]['ann_train'])
zlimgs_train_defect = self._create_zl_imgs_given_ids(ids=ids_train_defect,
subset=CONFIG[self.setting]['defect_train'],
ann_type=CONFIG[self.setting]['ann_train'])

# train eval
zlimgs_train_eval_normal = self.create_zl_imgs_given_ids(ids=ids_train_normal,
subset=CONFIG[self.setting]['normal_train'],
ann_type=CONFIG[self.setting]['ann_eval'])
zlimgs_train_eval_defect = self.create_zl_imgs_given_ids(ids=ids_train_defect,
subset=CONFIG[self.setting]['defect_train'],
ann_type=CONFIG[self.setting]['ann_eval'])
zlimgs_train_eval_normal = self._create_zl_imgs_given_ids(ids=ids_train_normal,
subset=CONFIG[self.setting]['normal_train'],
ann_type=CONFIG[self.setting]['ann_eval'])
zlimgs_train_eval_defect = self._create_zl_imgs_given_ids(ids=ids_train_defect,
subset=CONFIG[self.setting]['defect_train'],
ann_type=CONFIG[self.setting]['ann_eval'])

return zlimgs_train_normal, zlimgs_train_defect, zlimgs_train_eval_normal, zlimgs_train_eval_defect

Expand Down Expand Up @@ -267,24 +268,24 @@ def prepare_k_fold(self, k_fold: int, shuffle: bool = True) \
ids_k_train_defect = list(itertools.chain(*_ids_folds_defect))

# train
zlimgs_k_train_normal = self.create_zl_imgs_given_ids(ids=ids_k_train_normal,
subset=CONFIG[self.setting]['normal_train'],
ann_type=CONFIG[self.setting]['ann_train'])
zlimgs_k_train_defect = self.create_zl_imgs_given_ids(ids=ids_k_train_defect,
subset=CONFIG[self.setting]['defect_train'],
ann_type=CONFIG[self.setting]['ann_train'])
zlimgs_k_train_normal = self._create_zl_imgs_given_ids(ids=ids_k_train_normal,
subset=CONFIG[self.setting]['normal_train'],
ann_type=CONFIG[self.setting]['ann_train'])
zlimgs_k_train_defect = self._create_zl_imgs_given_ids(ids=ids_k_train_defect,
subset=CONFIG[self.setting]['defect_train'],
ann_type=CONFIG[self.setting]['ann_train'])
# train-eval
zlimgs_k_train_eval_normal = self.create_zl_imgs_given_ids(ids=ids_k_train_normal,
subset=CONFIG[self.setting]['normal_train'],
ann_type=CONFIG[self.setting]['ann_eval'])
zlimgs_k_train_eval_defect = self.create_zl_imgs_given_ids(ids=ids_k_train_defect,
subset=CONFIG[self.setting]['defect_train'],
ann_type=CONFIG[self.setting]['ann_eval'])
zlimgs_k_train_eval_normal = self._create_zl_imgs_given_ids(ids=ids_k_train_normal,
subset=CONFIG[self.setting]['normal_train'],
ann_type=CONFIG[self.setting]['ann_eval'])
zlimgs_k_train_eval_defect = self._create_zl_imgs_given_ids(ids=ids_k_train_defect,
subset=CONFIG[self.setting]['defect_train'],
ann_type=CONFIG[self.setting]['ann_eval'])
# dev
zlimgs_k_dev_normal = self.create_zl_imgs_given_ids(ids=ids_k_dev_normal, subset='dev',
ann_type=CONFIG[self.setting]['ann_eval'])
zlimgs_k_dev_defect = self.create_zl_imgs_given_ids(ids=ids_k_dev_defect, subset='dev',
ann_type=CONFIG[self.setting]['ann_eval'])
zlimgs_k_dev_normal = self._create_zl_imgs_given_ids(ids=ids_k_dev_normal, subset='dev',
ann_type=CONFIG[self.setting]['ann_eval'])
zlimgs_k_dev_defect = self._create_zl_imgs_given_ids(ids=ids_k_dev_defect, subset='dev',
ann_type=CONFIG[self.setting]['ann_eval'])
zlimgs_folds.append((zlimgs_k_train_normal, zlimgs_k_train_defect,
zlimgs_k_train_eval_normal, zlimgs_k_train_eval_defect,
zlimgs_k_dev_normal, zlimgs_k_dev_defect))
Expand All @@ -301,10 +302,10 @@ def prepare_test(self) -> Tuple[ZLIMGS, ZLIMGS]:
ids_test_defect = ids_json['defect']['test']

# test
zlimgs_test_normal = self.create_zl_imgs_given_ids(ids=ids_test_normal,
subset=CONFIG[self.setting]['normal_test'],
ann_type=CONFIG[self.setting]['ann_test'])
zlimgs_test_defect = self.create_zl_imgs_given_ids(ids=ids_test_defect,
subset=CONFIG[self.setting]['defect_test'],
ann_type=CONFIG[self.setting]['ann_test'])
zlimgs_test_normal = self._create_zl_imgs_given_ids(ids=ids_test_normal,
subset=CONFIG[self.setting]['normal_test'],
ann_type=CONFIG[self.setting]['ann_test'])
zlimgs_test_defect = self._create_zl_imgs_given_ids(ids=ids_test_defect,
subset=CONFIG[self.setting]['defect_test'],
ann_type=CONFIG[self.setting]['ann_test'])
return zlimgs_test_normal, zlimgs_test_defect

0 comments on commit f3e6cec

Please sign in to comment.