Fix: segmentation model change - lightweighting
Lightweighting
semnisem committed Nov 19, 2024
1 parent ca972d4 commit 52e9eda
Showing 2 changed files with 45 additions and 18 deletions.
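In short, the commit drops the full SAM ViT-H checkpoint (loaded through segment_anything's SamPredictor) in favour of the much smaller MobileSAM weights loaded through ultralytics' SAM wrapper. Below is a minimal sketch of the two loading paths, assuming local weight files and an ultralytics build that ships the SAM wrapper; the checkpoint sizes are approximate and not part of the diff.

# Sketch only: the heavyweight loading path being removed vs. the lightweight one being added.
import torch
from ultralytics import SAM
from segment_anything import sam_model_registry, SamPredictor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Before: SAM ViT-H (checkpoint of roughly 2.4 GB) wrapped in SamPredictor,
# previously set up lazily via load_sam_model().
sam_vit_h = sam_model_registry["vit_h"](checkpoint="models/sam_vit_h_4b8939.pth").to(device)
predictor = SamPredictor(sam_vit_h)

# After: MobileSAM (checkpoint of roughly 40 MB) through the ultralytics SAM wrapper,
# now created eagerly in AIProcessor.__init__.
mobile_sam = SAM("models/mobile_sam.pt")

Both wrappers accept point and box prompts, which is what lets the rest of AIProcessor switch over with relatively small changes.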
60 changes: 44 additions & 16 deletions app/AIProcessor.py
@@ -3,6 +3,7 @@
import numpy as np
import cv2
from ultralytics import YOLO
from ultralytics import SAM
from segment_anything import sam_model_registry, SamPredictor
import io
import logging
@@ -14,7 +15,7 @@ def __init__(self, yolo_path: str, sam_path: str, device: torch.device = torch.d
self.sam_path = sam_path
self.device = device
self.yolo_model = YOLO(self.yolo_path) # load YOLO
self.sam_model = None
self.sam_model = SAM(self.sam_path) # None
self.predictor = None

self.alpha_channel = None
@@ -61,21 +62,33 @@ def object_detection(self, img):
return bbox

def segment_from_points(self, image, user_points, user_labels, bbox, save_path=None):
input_point = np.array(user_points)
'''input_point = np.array(user_points)
input_label = np.array(user_labels)
masks_points, scores, logits = self.predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=False,
)
)'''
filtered_points, filtered_labels = self.remove_points_in_bboxes(user_points, user_labels, bbox)
logging.info(f'Second-pass masking - filtered user input points: from {len(user_points)} to {len(filtered_points)}')

results = self.sam_model.predict(image, points=filtered_points, labels=filtered_labels)
mask_points = results[0].masks.data

masks_np = np.zeros(mask_points.shape[-2:], dtype=np.uint8)
for mask in mask_points:
mask_np = mask.cpu().numpy().astype(np.uint8) * 255
mask_np = mask_np.squeeze()
masks_np = cv2.bitwise_or(masks_np, mask_np)

mask_points_uint8 = (masks_points[0] * 255).astype(np.uint8)
# mask_points_uint8 = (masks_points[0] * 255).astype(np.uint8)
# cv2.imwrite(save_path, mask_points_uint8)
logging.info('Second-pass masking - segmentation from user points complete.')
return mask_points_uint8
logging.info(f'Second-pass masking - segmented {len(mask_points)} regions from the user points.')
return masks_np
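
The helper remove_points_in_bboxes that the new code calls is not part of this diff. For orientation only, here is a hypothetical sketch of such a filter, assuming bbox entries are [x1, y1, x2, y2] boxes from the YOLO pass; the repository's actual implementation may differ.

def remove_points_in_bboxes(user_points, user_labels, bbox):
    # Drop user points that already fall inside a detected box; those regions
    # are covered by the first-pass (box-prompted) segmentation.
    filtered_points, filtered_labels = [], []
    for (x, y), label in zip(user_points, user_labels):
        inside = any(x1 <= x <= x2 and y1 <= y <= y2 for x1, y1, x2, y2 in bbox)
        if not inside:
            filtered_points.append([x, y])
            filtered_labels.append(label)
    return filtered_points, filtered_labels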

def segment_from_boxes(self, image, bbox, save_path=None):
'''
input_boxes = torch.tensor(bbox, device=self.predictor.device)
transformed_boxes = self.predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
masks, _, _ = self.predictor.predict_torch(
@@ -84,14 +97,17 @@ def segment_from_boxes(self, image, bbox, save_path=None):
boxes=transformed_boxes,
multimask_output=False,
)
masks_np = np.zeros(masks.shape[-2:], dtype=np.uint8)
for mask in masks:
'''
results = self.sam_model.predict(image, bboxes=bbox)
mask_boxes = results[0].masks.data

masks_np = np.zeros(mask_boxes.shape[-2:], dtype=np.uint8)
for mask in mask_boxes:
mask_np = mask.cpu().numpy().astype(np.uint8) * 255
mask_np = mask_np.squeeze()
masks_np = cv2.bitwise_or(masks_np, mask_np)

# cv2.imwrite(save_path, masks_np)
logging.info('First-pass masking - bounding-box segmentation complete.')
logging.info(f'First-pass masking - segmented {len(mask_boxes)} bounding-box regions.')
return masks_np

def inpainting(self, image, mask_total):
@@ -121,22 +137,24 @@ def process(self, img_bytes, user_points, user_labels, extension='jpg'):
image = cv2.imdecode(buffer, cv2.IMREAD_UNCHANGED)

### ready
self.load_sam_model()
self.predictor.set_image(image)
#self.load_sam_model()
#self.predictor.set_image(image)
masks_total = np.zeros(image.shape[:2], dtype=np.uint8)

### Pass 1: Object Detection & Segment by Box
bbox = self.object_detection(image)
if len(bbox) > 0:
masks_by_box = self.segment_from_boxes(image, bbox, save_path=None) # 'test_images/seg_box.png'
masks_total = cv2.bitwise_or(masks_total, masks_by_box)
logging.info(
f"Shape check after first-pass masking: YOLOv11-detected region shape: {masks_by_box.shape}, image shape: {image.shape}") # (1893, 1577, 3) (1893, 1577)

#logging.info( f"Shape check after first-pass masking: YOLOv11-detected region shape: {masks_by_box.shape}, image shape: {image.shape}") # (1893, 1577, 3) (1893, 1577)
else:
masks_by_box = None
### Pass 2: points provided by the user & Segment by Points
if len(user_points) > 0:
mask_by_point = self.segment_from_points(image, user_points, user_labels, bbox, save_path=None) # save_path='test_images/seg_points.png'
masks_total = cv2.bitwise_or(masks_total, mask_by_point)
else:
mask_by_point = None
# cv2.imwrite('test_images/mask_total.png', masks_total)
if isinstance(masks_total, np.ndarray):
image_output = self.inpainting(image, masks_total)
@@ -151,7 +169,17 @@ def process(self, img_bytes, user_points, user_labels, extension='jpg'):
_, mask_bytes = cv2.imencode("." + extension, masks_total.astype(np.uint8))
_, result_bytes = cv2.imencode("." + extension, image_output)

# return 0, 0, 0
return io.BytesIO(input_bytes), io.BytesIO(mask_bytes), io.BytesIO(result_bytes)
if mask_by_point is not None and masks_by_box is not None:
_, mask_by_point_bytes = cv2.imencode("." + extension, mask_by_point.astype(np.uint8))
_, mask_by_box_bytes = cv2.imencode("." + extension, masks_by_box.astype(np.uint8))
return io.BytesIO(input_bytes), io.BytesIO(mask_bytes), io.BytesIO(result_bytes), io.BytesIO(mask_by_box_bytes), io.BytesIO(mask_by_point_bytes)
elif mask_by_point is not None:
_, mask_by_point_bytes = cv2.imencode("." + extension, mask_by_point.astype(np.uint8))
return io.BytesIO(input_bytes), io.BytesIO(mask_bytes), io.BytesIO(result_bytes), None, io.BytesIO(mask_by_point_bytes)
elif masks_by_box is not None:
_, mask_by_box_bytes = cv2.imencode("." + extension, masks_by_box.astype(np.uint8))
return io.BytesIO(input_bytes), io.BytesIO(mask_bytes), io.BytesIO(result_bytes), io.BytesIO(mask_by_box_bytes), None
else:
return io.BytesIO(input_bytes), io.BytesIO(mask_bytes), io.BytesIO(result_bytes), None, None
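
Taken together, the reworked process() now drives both masking passes through the same lightweight SAM model and inpaints the combined mask. Here is a condensed, standalone sketch of that flow; the file names, the example point prompt, and the cv2.inpaint call are assumptions for illustration, since the diff does not show inpainting()'s body.

import cv2
import numpy as np
from ultralytics import YOLO, SAM

yolo = YOLO("models/yolo11_best.pt")
sam = SAM("models/mobile_sam.pt")

image = cv2.imread("input.jpg")  # assumed local test image
bbox = yolo(image)[0].boxes.xyxy.cpu().numpy().tolist()

masks_total = np.zeros(image.shape[:2], dtype=np.uint8)

# Pass 1: segment every YOLO detection by its bounding box.
if bbox:
    for m in sam.predict(image, bboxes=bbox)[0].masks.data:
        masks_total = cv2.bitwise_or(masks_total, (m.cpu().numpy() * 255).astype(np.uint8))

# Pass 2: segment user-supplied foreground points (label 1 = foreground).
user_points, user_labels = [[120, 200]], [1]  # assumed example prompt
if user_points:
    for m in sam.predict(image, points=user_points, labels=user_labels)[0].masks.data:
        masks_total = cv2.bitwise_or(masks_total, (m.cpu().numpy() * 255).astype(np.uint8))

# One simple possibility for the inpainting() step, whose body is outside this diff.
result = cv2.inpaint(image, masks_total, 3, cv2.INPAINT_TELEA)

Like the committed code, this sketch assumes the SAM masks come back at the original image resolution, so they can be OR-combined directly with the image-sized mask.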


3 changes: 1 addition & 2 deletions app/main.py
@@ -66,7 +66,7 @@ def upload_image_to_s3(file_bytes, file_path):
s3_client.upload_fileobj(file_bytes, BUCKET_NAME, file_path)


def download_model_from_s3(yolo_path: str = 'models/yolo11_best.pt', sam_path: str = 'models/sam_vit_h_4b8939.pth'):
def download_model_from_s3(yolo_path: str = 'models/yolo11_best.pt', sam_path: str = "models/mobile_sam.pt"): # models/sam_vit_h_4b8939.pth
dest_dir = f'../' # path inside the container where the models are saved
try:
yolo_full_path = dest_dir+yolo_path
@@ -82,7 +82,6 @@ def download_model_from_s3(yolo_path: str = 'models/yolo11_best.pt', sam_path: s
else:
logger.info(f'SAM model already exists at {dest_dir}')


logger.info(f"Files in 'models' Dir: {os.listdir(dest_dir)}")
except Exception as e:
print(f'Failed to download model: {e}')
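
For context, download_model_from_s3 now defaults to the MobileSAM weight file instead of the ViT-H checkpoint. Below is a minimal, hypothetical sketch of the download-if-missing pattern it relies on, assuming a boto3 S3 client and that the S3 keys mirror the local 'models/...' paths; the bucket name and helper are illustrative, not the repository's code.

import os
import boto3

s3_client = boto3.client("s3")
BUCKET_NAME = "my-model-bucket"  # assumed; the real bucket name comes from configuration

def download_if_missing(key: str, dest_dir: str = "../") -> str:
    """Download an S3 object into dest_dir unless it already exists locally."""
    local_path = dest_dir + key
    if not os.path.exists(local_path):
        os.makedirs(os.path.dirname(local_path), exist_ok=True)
        s3_client.download_file(BUCKET_NAME, key, local_path)
    return local_path

yolo_path = download_if_missing("models/yolo11_best.pt")
sam_path = download_if_missing("models/mobile_sam.pt")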
