Skip to content

Commit

Permalink
update layout_postprocess (#2637)
Browse files Browse the repository at this point in the history
* update layout_postprocess

* update layout_postprocess

* update layout_postprocess
  • Loading branch information
Sunting78 authored Dec 12, 2024
1 parent e98fde5 commit 4d6b62f
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 6 deletions.
119 changes: 115 additions & 4 deletions paddlex/inference/components/task_related/det.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,55 @@ def apply(self, img):
}


def compute_iou(box1, box2):
x1 = max(box1[0], box2[0])
y1 = max(box1[1], box2[1])
x2 = min(box1[2], box2[2])
y2 = min(box1[3], box2[3])
inter_area = max(0, x2 - x1 + 1) * max(0, y2 - y1 + 1)
box1_area = (box1[2] - box1[0] + 1) * (box1[3] - box1[1] + 1)
box2_area = (box2[2] - box2[0] + 1) * (box2[3] - box2[1] + 1)
iou = inter_area / float(box1_area + box2_area - inter_area)
return iou


def is_box_mostly_inside(inner_box, outer_box, threshold=0.9):
x1 = max(inner_box[0], outer_box[0])
y1 = max(inner_box[1], outer_box[1])
x2 = min(inner_box[2], outer_box[2])
y2 = min(inner_box[3], outer_box[3])
inter_area = max(0, x2 - x1 + 1) * max(0, y2 - y1 + 1)
inner_box_area = (inner_box[2] - inner_box[0] + 1) * (inner_box[3] - inner_box[1] + 1)
return (inter_area / inner_box_area) >= threshold


def non_max_suppression(boxes, scores, iou_threshold):
if len(boxes) == 0:
return []
x1 = boxes[:, 0]
y1 = boxes[:, 1]
x2 = boxes[:, 2]
y2 = boxes[:, 3]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])

w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
iou = inter / (areas[i] + areas[order[1:]] - inter)
inds = np.where(iou <= iou_threshold)[0]
order = order[inds + 1]
return keep


class DetPostProcess(BaseComponent):
"""Save Result Transform"""

Expand All @@ -236,15 +285,78 @@ class DetPostProcess(BaseComponent):
DEAULT_INPUTS = {"boxes": "boxes", "img_size": "ori_img_size"}
DEAULT_OUTPUTS = {"boxes": "boxes"}

def __init__(self, threshold=0.5, labels=None):
def __init__(self, threshold=0.5, labels=None, layout_postprocess=False):
super().__init__()
self.threshold = threshold
self.labels = labels
self.layout_postprocess = layout_postprocess

def apply(self, boxes, img_size):
"""apply"""
expect_boxes = (boxes[:, 1] > self.threshold) & (boxes[:, 0] > -1)
boxes = boxes[expect_boxes, :]
if isinstance(self.threshold, float):
expect_boxes = (boxes[:, 1] > self.threshold) & (boxes[:, 0] > -1)
boxes = boxes[expect_boxes, :]
elif isinstance(self.threshold, dict):
category_filtered_boxes = []
for cat_id in np.unique(boxes[:, 0]):
category_boxes = boxes[boxes[:, 0] == cat_id]
category_scores = category_boxes[:, 1]
category_threshold = self.threshold.get(int(cat_id), 0.5)
selected_indices = category_scores > category_threshold
category_filtered_boxes.append(category_boxes[selected_indices])
boxes = np.vstack(category_filtered_boxes) if category_filtered_boxes else np.array([])

if self.layout_postprocess:
filtered_boxes = []
### Layout postprocess for NMS
for cat_id in np.unique(boxes[:, 0]):
category_boxes = boxes[boxes[:, 0] == cat_id]
category_scores = category_boxes[:, 1]
if len(category_boxes) > 0:
nms_indices = non_max_suppression(category_boxes[:, 2:], category_scores, 0.5)
category_boxes = category_boxes[nms_indices]
keep_boxes = []
for i, box in enumerate(category_boxes):
if all(not is_box_mostly_inside(box[2:], other_box[2:]) for j, other_box in enumerate(category_boxes) if i != j):
keep_boxes.append(box)
filtered_boxes.extend(keep_boxes)
boxes = np.array(filtered_boxes)
### Layout postprocess for removing boxes inside image category box
if self.labels and "image" in self.labels:
image_cls_id = self.labels.index('image')
if len(boxes) > 0:
image_boxes = boxes[boxes[:, 0] == image_cls_id]
other_boxes = boxes[boxes[:, 0] != image_cls_id]
to_keep = []
for box in other_boxes:
keep = True
for img_box in image_boxes:
if (box[2] >= img_box[2] and box[3] >= img_box[3] and
box[4] <= img_box[4] and box[5] <= img_box[5]):
keep = False
break
if keep:
to_keep.append(box)
boxes = np.vstack([image_boxes, to_keep]) if to_keep else image_boxes
### Layout postprocess for overlaps
final_boxes = []
while len(boxes) > 0:
current_box = boxes[0]
current_score = current_box[1]
overlaps = [current_box]
non_overlaps = []
for other_box in boxes[1:]:
iou = compute_iou(current_box[2:], other_box[2:])
if iou > 0.95:
if other_box[1] > current_score:
overlaps.append(other_box)
else:
non_overlaps.append(other_box)
best_box = max(overlaps, key=lambda x: x[1])
final_boxes.append(best_box)
boxes = np.array(non_overlaps)
boxes = np.array(final_boxes)

if boxes.shape[1] == 6:
"""For Normal Object Detection"""
boxes = restructured_boxes(boxes, self.labels, img_size)
Expand All @@ -257,7 +369,6 @@ def apply(self, boxes, img_size):
f"The shape of boxes should be 6 or 10, instead of {boxes.shape[1]}"
)
result = {"boxes": boxes}

return result


Expand Down
1 change: 1 addition & 0 deletions paddlex/inference/models/object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def _build_components(self):
DetPostProcess(
threshold=self.config["draw_threshold"],
labels=self.config["label_list"],
layout_postprocess=self.config.get("layout_postprocess", False),
),
]
)
Expand Down
4 changes: 2 additions & 2 deletions paddlex/inference/results/det.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ def draw_box(img, boxes):
Returns:
img (PIL.Image.Image): visualized image
"""
font_size = int(0.024 * int(img.width)) + 2
font_size = int(0.018 * int(img.width)) + 2
font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8")

draw_thickness = int(max(img.size) * 0.005)
draw_thickness = int(max(img.size) * 0.002)
draw = ImageDraw.Draw(img)
label2color = {}
catid2fontcolor = {}
Expand Down

0 comments on commit 4d6b62f

Please sign in to comment.