Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Filter k most confident masks #720

Merged
merged 10 commits into from
Sep 25, 2024
13 changes: 10 additions & 3 deletions aaaaaa/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,14 +294,21 @@ def detection(w: Widgets, n: int, is_img2img: bool):
visible=True,
elem_id=eid("ad_confidence"),
)
w.ad_mask_k_largest = gr.Slider(
label="Mask only the top k largest (0 to disable)" + suffix(n),
w.ad_mask_filter_method = gr.Radio(
choices=["Area", "Confidence"],
value="Area",
label="Method to filter top k masks by (confidence or area)",
visible=True,
elem_id=eid("ad_mask_filter_method"),
)
w.ad_mask_k = gr.Slider(
label="Mask only the top k (0 to disable)" + suffix(n),
minimum=0,
maximum=10,
step=1,
value=0,
visible=True,
elem_id=eid("ad_mask_k_largest"),
elem_id=eid("ad_mask_k"),
)

with gr.Column(variant="compact"):
Expand Down
12 changes: 9 additions & 3 deletions adetailer/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid):
ad_prompt: str = ""
ad_negative_prompt: str = ""
ad_confidence: confloat(ge=0.0, le=1.0) = 0.3
ad_mask_k_largest: NonNegativeInt = 0
ad_mask_filter_method: Literal["Area", "Confidence"] = "Area"
ad_mask_k: NonNegativeInt = 0
ad_mask_min_ratio: confloat(ge=0.0, le=1.0) = 0.0
ad_mask_max_ratio: confloat(ge=0.0, le=1.0) = 1.0
ad_dilate_erode: int = 4
Expand Down Expand Up @@ -131,7 +132,11 @@ def extra_params(self, suffix: str = "") -> dict[str, Any]:
ppop("ADetailer prompt")
ppop("ADetailer negative prompt")
p.pop("ADetailer tab enable", None) # always pop
ppop("ADetailer mask only top k largest", cond=0)
ppop(
"ADetailer mask only top k",
["ADetailer mask only top k", "ADetailer method to decide top k masks"],
cond=0,
)
ppop("ADetailer mask min ratio", cond=0.0)
ppop("ADetailer mask max ratio", cond=1.0)
ppop("ADetailer x offset", cond=0)
Expand Down Expand Up @@ -217,7 +222,8 @@ def need_skip(self) -> bool:
("ad_prompt", "ADetailer prompt"),
("ad_negative_prompt", "ADetailer negative prompt"),
("ad_confidence", "ADetailer confidence"),
("ad_mask_k_largest", "ADetailer mask only top k largest"),
("ad_mask_filter_method", "ADetailer method to decide top k masks"),
("ad_mask_k", "ADetailer mask only top k"),
("ad_mask_min_ratio", "ADetailer mask min ratio"),
("ad_mask_max_ratio", "ADetailer mask max ratio"),
("ad_x_offset", "ADetailer x offset"),
Expand Down
1 change: 1 addition & 0 deletions adetailer/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
class PredictOutput(Generic[T]):
bboxes: list[list[T]] = field(default_factory=list)
masks: list[Image.Image] = field(default_factory=list)
confidences: list[float] = field(default_factory=list)
preview: Optional[Image.Image] = None


Expand Down
23 changes: 23 additions & 0 deletions adetailer/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def filter_by_ratio(
idx = [i for i in range(items) if is_in_ratio(pred.bboxes[i], low, high, orig_area)]
pred.bboxes = [pred.bboxes[i] for i in idx]
pred.masks = [pred.masks[i] for i in idx]
pred.confidences = [pred.confidences[i] for i in idx]
return pred


Expand All @@ -236,9 +237,31 @@ def filter_k_largest(pred: PredictOutput[T], k: int = 0) -> PredictOutput[T]:
idx = idx[::-1]
pred.bboxes = [pred.bboxes[i] for i in idx]
pred.masks = [pred.masks[i] for i in idx]
pred.confidences = [pred.confidences[i] for i in idx]
return pred


def filter_k_most_confident(pred: PredictOutput[T], k: int = 0) -> PredictOutput[T]:
if not pred.bboxes or not pred.confidences or k == 0:
return pred
idx = np.argsort(pred.confidences)[-k:]
idx = idx[::-1]
pred.bboxes = [pred.bboxes[i] for i in idx]
pred.masks = [pred.masks[i] for i in idx]
pred.confidences = [pred.confidences[i] for i in idx]
return pred


def filter_k_by(
pred: PredictOutput[T], k: int = 0, by: str = "Area"
) -> PredictOutput[T]:
if by == "Area":
return filter_k_largest(pred, k)
if by == "Confidence":
return filter_k_most_confident(pred, k)
raise RuntimeError


# Merge / Invert
def mask_merge(masks: list[Image.Image]) -> list[Image.Image]:
arrs = [np.array(m) for m in masks]
Expand Down
7 changes: 5 additions & 2 deletions adetailer/mediapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def mediapipe_face_detection(
preview_array = img_array.copy()

bboxes = []
confidences = []
for detection in pred.detections:
draw_util.draw_detection(preview_array, detection)

Expand All @@ -63,12 +64,15 @@ def mediapipe_face_detection(
x2 = x1 + w
y2 = y1 + h

confidences.append(detection.score)
bboxes.append([x1, y1, x2, y2])

masks = create_mask_from_bbox(bboxes, image.size)
preview = Image.fromarray(preview_array)

return PredictOutput(bboxes=bboxes, masks=masks, preview=preview)
return PredictOutput(
bboxes=bboxes, masks=masks, confidences=confidences, preview=preview
)


def mediapipe_face_mesh(
Expand Down Expand Up @@ -141,7 +145,6 @@ def mediapipe_face_mesh_eyes_only(

preview = image.copy()
masks = []

for landmarks in pred.multi_face_landmarks:
points = np.array(
[[land.x * w, land.y * h] for land in landmarks.landmark], dtype=int
Expand Down
7 changes: 6 additions & 1 deletion adetailer/ultralytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,16 @@ def ultralytics_predict(
masks = create_mask_from_bbox(bboxes, image.size)
else:
masks = mask_to_pil(pred[0].masks.data, image.size)

confidences = pred[0].boxes.conf.cpu().numpy().tolist()

preview = pred[0].plot()
preview = cv2.cvtColor(preview, cv2.COLOR_BGR2RGB)
preview = Image.fromarray(preview)

return PredictOutput(bboxes=bboxes, masks=masks, preview=preview)
return PredictOutput(
bboxes=bboxes, masks=masks, confidences=confidences, preview=preview
)


def apply_classes(model: YOLO | YOLOWorld, model_path: str | Path, classes: str):
Expand Down
4 changes: 2 additions & 2 deletions scripts/!adetailer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from adetailer.common import PredictOutput, ensure_pil_image, safe_mkdir
from adetailer.mask import (
filter_by_ratio,
filter_k_largest,
filter_k_by,
has_intersection,
is_all_black,
mask_preprocess,
Expand Down Expand Up @@ -596,7 +596,7 @@ def pred_preprocessing(self, p, pred: PredictOutput, args: ADetailerArgs):
pred = filter_by_ratio(
pred, low=args.ad_mask_min_ratio, high=args.ad_mask_max_ratio
)
pred = filter_k_largest(pred, k=args.ad_mask_k_largest)
pred = filter_k_by(pred, k=args.ad_mask_k, by=args.ad_mask_filter_method)
pred = self.sort_bboxes(pred)
masks = mask_preprocess(
pred.masks,
Expand Down
Loading