diff --git a/app/AIProcessor.py b/app/AIProcessor.py index d6947a1..da2ab17 100644 --- a/app/AIProcessor.py +++ b/app/AIProcessor.py @@ -3,6 +3,7 @@ import numpy as np import cv2 from ultralytics import YOLO +from ultralytics import SAM from segment_anything import sam_model_registry, SamPredictor import io import logging @@ -14,7 +15,7 @@ def __init__(self, yolo_path: str, sam_path: str, device: torch.device = torch.d self.sam_path = sam_path self.device = device self.yolo_model = YOLO(self.yolo_path) # yolo 로드 - self.sam_model = None + self.sam_model = SAM(self.sam_path) # None self.predictor = None self.alpha_channel = None @@ -61,21 +62,33 @@ def object_detection(self, img): return bbox def segment_from_points(self, user_points, user_labels, save_path=None): - input_point = np.array(user_points) + '''input_point = np.array(user_points) input_label = np.array(user_labels) masks_points, scores, logits = self.predictor.predict( point_coords=input_point, point_labels=input_label, multimask_output=False, - ) + )''' + filtered_points, filtered_labels = self.remove_points_in_bboxes(user_points, user_labels, bbox) + logging.info(f'2차 마스킹 - 사용자 입력 필터링: from {len(user_points)}개 to {len(filtered_points)}개') + + results = self.sam_model.predict(points=filtered_points, labels=filtered_labels) + mask_points = results[0].masks.data + + masks_np = np.zeros(mask_points.shape[-2:], dtype=np.uint8) + for mask in mask_points: + mask_np = mask.cpu().numpy().astype(np.uint8) * 255 + mask_np = mask_np.squeeze() + masks_np = cv2.bitwise_or(masks_np, mask_np) - mask_points_uint8 = (masks_points[0] * 255).astype(np.uint8) + # mask_points_uint8 = (masks_points[0] * 255).astype(np.uint8) # cv2.imwrite(save_path, mask_points_uint8) - logging.info('2차 마스킹 - 사용자 입력 점 세그먼트 완료.') - return mask_points_uint8 + logging.info(f'2차 마스킹 - 사용자 입력에 대해 {len(mask_points)}개 영역으로 세그먼트 완료.') + return masks_np def segment_from_boxes(self, image, bbox, save_path=None): + ''' input_boxes = torch.tensor(bbox, device=self.predictor.device) transformed_boxes = self.predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2]) masks, _, _ = self.predictor.predict_torch( @@ -84,14 +97,17 @@ def segment_from_boxes(self, image, bbox, save_path=None): boxes=transformed_boxes, multimask_output=False, ) - masks_np = np.zeros(masks.shape[-2:], dtype=np.uint8) - for mask in masks: + ''' + mask_boxes = self.sam_model.predict(bboxes=bbox) + + masks_np = np.zeros(mask_boxes.shape[-2:], dtype=np.uint8) + for mask in mask_boxes: mask_np = mask.cpu().numpy().astype(np.uint8) * 255 mask_np = mask_np.squeeze() masks_np = cv2.bitwise_or(masks_np, mask_np) # cv2.imwrite(save_path, masks_np) - logging.info('1차 마스킹 - 바운딩 박스 세그먼트 완료.') + logging.info(f'1차 마스킹 - 바운딩 박스 {mask_boxes}개 세그먼트 완료.') return masks_np def inpainting(self, image, mask_total): @@ -121,8 +137,8 @@ def process(self, img_bytes, user_points, user_labels, extension='jpg'): image = cv2.imdecode(buffer, cv2.IMREAD_UNCHANGED) ### ready - self.load_sam_model() - self.predictor.set_image(image) + #self.load_sam_model() + #self.predictor.set_image(image) masks_total = np.zeros(image.shape[:2], dtype=np.uint8) ### 1차: Object Detection & Segment by Box @@ -130,13 +146,15 @@ def process(self, img_bytes, user_points, user_labels, extension='jpg'): if len(bbox) > 0: masks_by_box = self.segment_from_boxes(image, bbox, save_path=None) # 'test_images/seg_box.png' masks_total = cv2.bitwise_or(masks_total, masks_by_box) - logging.info( - f"1차 마스킹 후 shape 점검: YOLOv11 감지된 영역 shape: {masks_by_box.shape}, 이미지 영역 shape: {image.shape}") # (1893, 1577, 3) (1893, 1577) - + #logging.info( f"1차 마스킹 후 shape 점검: YOLOv11 감지된 영역 shape: {masks_by_box.shape}, 이미지 영역 shape: {image.shape}") # (1893, 1577, 3) (1893, 1577) + else: + masks_by_box = None ### 2차: points arguments by User & Segment by Points if len(user_points) > 0: mask_by_point = self.segment_from_points(user_points, user_labels, save_path=None) # save_path='test_images/seg_points.png' masks_total = cv2.bitwise_or(masks_total, mask_by_point) + else: + mask_by_point = None # cv2.imwrite('test_images/mask_total.png', masks_total) if isinstance(masks_total, np.ndarray): image_output = self.inpainting(image, masks_total) @@ -151,7 +169,17 @@ def process(self, img_bytes, user_points, user_labels, extension='jpg'): _, mask_bytes = cv2.imencode("." + extension, masks_total.astype(np.uint8)) _, result_bytes = cv2.imencode("." + extension, image_output) - # return 0, 0, 0 - return io.BytesIO(input_bytes), io.BytesIO(mask_bytes), io.BytesIO(result_bytes) + if mask_by_point is not None and masks_by_box is not None: + _, mask_by_point_bytes = cv2.imencode("." + extension, mask_by_point.astype(np.uint8)) + _, mask_by_box_bytes = cv2.imencode("." + extension, masks_by_box.astype(np.uint8)) + return io.BytesIO(input_bytes), io.BytesIO(mask_bytes), io.BytesIO(result_bytes), io.BytesIO(mask_by_box_bytes), io.BytesIO(mask_by_point_bytes) + elif mask_by_point is not None: + _, mask_by_point_bytes = cv2.imencode("." + extension, mask_by_point.astype(np.uint8)) + return io.BytesIO(input_bytes), io.BytesIO(mask_bytes), io.BytesIO(result_bytes), None, io.BytesIO(mask_by_point_bytes) + elif masks_by_box is not None: + _, mask_by_box_bytes = cv2.imencode("." + extension, masks_by_box.astype(np.uint8)) + return io.BytesIO(input_bytes), io.BytesIO(mask_bytes), io.BytesIO(result_bytes), io.BytesIO(mask_by_box_bytes), None + else: + return io.BytesIO(input_bytes), io.BytesIO(mask_bytes), io.BytesIO(result_bytes), None, None diff --git a/app/main.py b/app/main.py index cf6b286..bdc21ec 100644 --- a/app/main.py +++ b/app/main.py @@ -66,7 +66,7 @@ def upload_image_to_s3(file_bytes, file_path): s3_client.upload_fileobj(file_bytes, BUCKET_NAME, file_path) -def download_model_from_s3(yolo_path: str = 'models/yolo11_best.pt', sam_path: str = 'models/sam_vit_h_4b8939.pth'): +def download_model_from_s3(yolo_path: str = 'models/yolo11_best.pt', sam_path: str = "models/mobile_sam.pt"): # models/sam_vit_h_4b8939.pth dest_dir = f'../' # 모델을 저장할 컨테이너 내 경로 try: yolo_full_path = dest_dir+yolo_path @@ -82,7 +82,6 @@ def download_model_from_s3(yolo_path: str = 'models/yolo11_best.pt', sam_path: s else: logger.info(f'SAM models already exists at {dest_dir}') - logger.info(f"Files in 'models' Dir: {os.listdir(dest_dir)}") except Exception as e: print(f'Failed to download model: {e}')