Skip to content

Commit

Permalink
HotFix: AI 에러 수정
Browse files Browse the repository at this point in the history
로직 에러
  • Loading branch information
semnisem committed Nov 19, 2024
1 parent 8d85636 commit 8b141d5
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions app/AIProcessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self, yolo_path: str, sam_path: str, device: torch.device = torch.d
self.height = 0
self.width = 0

def remove_points_in_bboxes(point_list, label_list, bbox_list):
def remove_points_in_bboxes(self, point_list, label_list, bbox_list):
def is_point_in_bbox(p, b):
x, y = p
x_min, y_min, x_max, y_max = b
Expand Down Expand Up @@ -61,7 +61,7 @@ def object_detection(self, img):
logging.info(f'1차 마스킹 - 바운딩 박스 생성 완료, {len(bbox)}개')
return bbox

def segment_from_points(self, user_points, user_labels, save_path=None):
def segment_from_points(self, image, user_points, user_labels, bbox, save_path=None):
'''input_point = np.array(user_points)
input_label = np.array(user_labels)
Expand All @@ -73,7 +73,7 @@ def segment_from_points(self, user_points, user_labels, save_path=None):
filtered_points, filtered_labels = self.remove_points_in_bboxes(user_points, user_labels, bbox)
logging.info(f'2차 마스킹 - 사용자 입력 필터링: from {len(user_points)}개 to {len(filtered_points)}개')

results = self.sam_model.predict(points=filtered_points, labels=filtered_labels)
results = self.sam_model.predict(source=image, points=filtered_points, labels=filtered_labels)
mask_points = results[0].masks.data

masks_np = np.zeros(mask_points.shape[-2:], dtype=np.uint8)
Expand All @@ -98,7 +98,8 @@ def segment_from_boxes(self, image, bbox, save_path=None):
multimask_output=False,
)
'''
mask_boxes = self.sam_model.predict(bboxes=bbox)
results = self.sam_model.predict(source=image, bboxes=bbox)
mask_boxes = results[0].masks.data

masks_np = np.zeros(mask_boxes.shape[-2:], dtype=np.uint8)
for mask in mask_boxes:
Expand All @@ -107,7 +108,7 @@ def segment_from_boxes(self, image, bbox, save_path=None):
masks_np = cv2.bitwise_or(masks_np, mask_np)

# cv2.imwrite(save_path, masks_np)
logging.info(f'1차 마스킹 - 바운딩 박스 {mask_boxes}개 세그먼트 완료.')
logging.info(f'1차 마스킹 - 바운딩 박스 {len(mask_boxes)}개 세그먼트 완료.')
return masks_np

def inpainting(self, image, mask_total):
Expand Down Expand Up @@ -144,14 +145,14 @@ def process(self, img_bytes, user_points, user_labels, extension='jpg'):
### 1차: Object Detection & Segment by Box
bbox = self.object_detection(image)
if len(bbox) > 0:
masks_by_box = self.segment_from_boxes(image, bbox, save_path=None) # 'test_images/seg_box.png'
masks_by_box = self.segment_from_boxes(image, bbox, save_path=None) # 'test_images/seg_box.png'
masks_total = cv2.bitwise_or(masks_total, masks_by_box)
#logging.info( f"1차 마스킹 후 shape 점검: YOLOv11 감지된 영역 shape: {masks_by_box.shape}, 이미지 영역 shape: {image.shape}") # (1893, 1577, 3) (1893, 1577)
else:
masks_by_box = None
### 2차: points arguments by User & Segment by Points
if len(user_points) > 0:
mask_by_point = self.segment_from_points(user_points, user_labels, save_path=None) # save_path='test_images/seg_points.png'
mask_by_point = self.segment_from_points(image, user_points, user_labels, bbox, save_path=None) # save_path='test_images/seg_points.png'
masks_total = cv2.bitwise_or(masks_total, mask_by_point)
else:
mask_by_point = None
Expand Down

0 comments on commit 8b141d5

Please sign in to comment.