Skip to content

Commit

Permalink
[Fix & Optimize] 인페인팅 로직 개선 및 종속성 최적화
Browse files Browse the repository at this point in the history
[Fix & Optimize] 인페인팅 로직 개선 및 종속성 최적화
  • Loading branch information
semnisem authored Nov 25, 2024
2 parents 0ccad2e + dbbff0f commit 2d6ae00
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 45 deletions.
4 changes: 2 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ FROM python:3.12-slim
LABEL org.opencontainers.image.source="https://github.com/AI-SIP/MVP_CV"

# 필요한 시스템 패키지 먼저 설치
RUN apt-get update && apt-get install -y git \
RUN apt-get update && apt-get install -y \
libgl1-mesa-glx \
libglib2.0-0 \
gcc \
python3-dev && \
rm -rf /var/lib/apt/lists/* \
rm -rf /var/lib/apt/lists/*

# 작업 디렉토리 설정
WORKDIR /test
Expand Down
89 changes: 47 additions & 42 deletions app/AIProcessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import cv2
from ultralytics import YOLO
from ultralytics import SAM
from segment_anything import sam_model_registry, SamPredictor
# from segment_anything import sam_model_registry, SamPredictor
import io
import logging

Expand Down Expand Up @@ -50,9 +50,10 @@ def remove_alpha(self, image):
return image_rgb

def load_sam_model(self, model_type="vit_h"):  # load SAM
    """Formerly loaded a Segment Anything (SAM) checkpoint; loading is disabled.

    The `segment_anything` import and its requirements.txt entry were removed
    in this commit, so the three loading lines are kept only as commented-out
    history and the method now just logs that SAM is unavailable.

    NOTE(review): the source span rendered both the pre-commit (active) and
    post-commit (commented-out) lines of a diff; this is the resolved
    post-commit version. `model_type` is unused but kept so existing callers
    keep working.
    """
    # self.sam_model = sam_model_registry[model_type](checkpoint=self.sam_path)
    # self.sam_model.to(self.device)
    # self.predictor = SamPredictor(self.sam_model)
    logging.info("SAM model cannot be loaded")

def object_detection(self, img):
results = self.yolo_model.predict(source=img, imgsz=640, device=self.device,
Expand All @@ -76,9 +77,7 @@ def segment_from_yolo(self, image, bbox, save_path=None):
# kernel = np.ones((15, 15), np.uint8)
# mask_np = cv2.morphologyEx(mask_np, cv2.MORPH_GRADIENT, kernel)'''
masks_np = cv2.bitwise_or(masks_np, mask_np)
# cv2.imwrite(f'mask_box{i}.jpg', masks_np)

# cv2.imwrite(save_path, masks_np)
logging.info(f'1차 마스킹 - 바운딩 박스 {len(mask_boxes)}개 세그먼트 완료.')
return masks_np

Expand All @@ -96,49 +95,50 @@ def segment_from_user(self, image, user_inputs, bbox, save_path=None):
mask_np = mask_np.squeeze()
masks_np = cv2.bitwise_or(masks_np, mask_np)

# cv2.imwrite(save_path, mask_points_uint8)
logging.info(f'2차 마스킹 - 사용자 입력에 대해 {len(mask_points)}개 영역으로 세그먼트 완료.')
return masks_np

def inpainting(self, image, mask_total, bbox=None):

mask_total = np.zeros(image.shape[:2], dtype=np.uint8)
# NOTE(review): this span is a GitHub diff rendering with pre-commit and
# post-commit lines interleaved and all indentation stripped — it is NOT
# straight-line executable code. Lines using `mask_total`, the
# `final_image = cv2.convertScaleAbs(...)` line, and the first
# `return final_image` belong to the removed `inpainting` method; the final
# `return text_np, masks_np, inpainted_image` is the new method's return.
def inpaint_from_yolo(self, image, bbox=None):
# New method: for each YOLO box, record it in a binary mask, extract dark
# printed text inside it, and fill the box with a background color sampled
# just below the box. Returns (text_np, masks_np, inpainted_image).
inpainted_image = image.copy()
masks_np = np.zeros(image.shape[:2], dtype=np.uint8)
text_np = np.zeros(image.shape[:2], dtype=np.uint8)  # stores the union of all roi_masks
for b in bbox:
minx, miny, maxx, maxy = map(int, b)
# NOTE(review): pre-commit line — `mask_total` is undefined in the new method.
mask_total[miny:maxy, minx:maxx] = 255  # fill the box region with 255
roi = image[miny:maxy, minx:maxx]  # crop just this region (pre-commit variant)
roi_mask = cv2.inRange(roi, (0, 0, 0), (45, 45, 45))  # printed text inside the ROI becomes 255 (pre-commit threshold)
masks_np[miny:maxy, minx:maxx] = 255  # fill the box region with 255

roi = image[miny:maxy, minx:maxx]  # copy the box contents to extract only the text
roi_mask = cv2.inRange(roi, (0, 0, 0), (40, 40, 45))  # printed text inside the ROI becomes 255
text_np[miny:maxy, minx:maxx] = roi_mask

if maxy < image.shape[0]-2:  # sample the background color near the box
sample_color = image[maxy + 1, (minx+maxx)//2].tolist()  # pixel just below the bottom-center of the box
else:
sample_color = [255, 255, 255]  # default to white when the box touches the image border
inpainted_image[miny:maxy, minx:maxx] = sample_color  # inpaint (flat fill)
# NOTE(review): log says "rgb" but OpenCV arrays are BGR — confirm intent.
logging.info(f'1차 인페인팅 - yolo 박스 샘플링 컬러 rgb: {sample_color}')

# restoration pass: thicken then thin the text mask to clean it up
text_np = cv2.dilate(text_np, np.ones((4, 4), np.uint8), iterations=1)
text_np = cv2.erode(text_np, np.ones((2, 2), np.uint8), iterations=2)
# cv2.imwrite('text_np.jpg', text_np.astype(np.uint8))

# NOTE(review): the next four lines are the pre-commit tail (white-out of
# mask_total, repaint of text as dark gray, contrast boost into final_image).
inpainted_image = image.copy()
inpainted_image[mask_total == 255] = [255, 255, 255]
inpainted_image[text_np == 255] = [30, 30, 30]
final_image = cv2.convertScaleAbs(inpainted_image, alpha=1.5, beta=15)

# inpainted_image = cv2.inpaint(image.copy(), mask_total, 15, cv2.INPAINT_TELEA)
# cv2.imwrite('test_images/inpainted_init.png', inpainted_image)
# cv2.imwrite('test_images/inpainted_final.png', final_image)
logging.info('Yolo 결과 인페인팅 및 후보정 완료.')
# inpainted_image = cv2.inpaint(inpainted_image, mask_total, 10, cv2.INPAINT_TELEA)

# NOTE(review): first return is pre-commit; second is the post-commit return.
return final_image
return text_np, masks_np, inpainted_image

def inpaint_from_user(self, image, user_boxes, save_path=None):
    """Mask and color-fill user-selected boxes.

    For every box (minx, miny, maxx, maxy) in `user_boxes`: mark the box in a
    binary mask, then overwrite the box pixels with a background color sampled
    just left of the box (plain white when the box touches the left edge).

    NOTE(review): the source span rendered a diff with pre-commit lines
    interleaved (a post-loop `image.copy()` + white-out of the mask, which
    would clobber the per-box sampled fill); this is the resolved post-commit
    version that keeps the sampled fill.

    Args:
        image: H x W x 3 uint8 array, OpenCV-style (presumably BGR — the log
            message says RGB, which looks inconsistent; TODO confirm).
        user_boxes: iterable of 4-tuples (minx, miny, maxx, maxy).
        save_path: unused; kept for backward compatibility with callers.

    Returns:
        (masks_np, inpainted_image): uint8 mask with 255 inside the boxes,
        and a copy of `image` with the boxes filled. `image` is not mutated.
    """
    inpainted_image = image.copy()
    masks_np = np.zeros(image.shape[:2], dtype=np.uint8)  # marks box locations
    for b in user_boxes:
        minx, miny, maxx, maxy = map(int, b)
        masks_np[miny:maxy, minx:maxx] = 255  # fill the box region with 255
        if minx > 2:  # sample the background color near the box
            # pixel just outside the box's left edge, at its vertical midpoint
            # (the original comment claimed "upper-right corner" — it is not)
            sample_color = image[(miny + maxy) // 2, minx - 2].tolist()
        else:
            sample_color = [255, 255, 255]  # default to white at the image border
        inpainted_image[miny:maxy, minx:maxx] = sample_color  # inpaint (flat fill)
        logging.info(f'2차 인페인팅 - user 박스 샘플링 컬러 RGB: {sample_color}')

    # final_image = cv2.convertScaleAbs(inpainted_image, alpha=1.5, beta=10)
    logging.info(f'2차 마스킹 & 인페인팅 - 사용자 입력 {len(user_boxes)}개 영역 인페인팅 완료.')
    return masks_np, inpainted_image

def combine_alpha(self, image_rgb):
Expand Down Expand Up @@ -169,32 +169,34 @@ def process(self, img_bytes, user_inputs, extension='jpg'):
bbox = self.object_detection(image)
if len(bbox) > 0:
logging.info("***** 1차: 객체 탐지 세그멘테이션 시작 ******")
masks_by_yolo = self.segment_from_yolo(image, bbox, save_path=None) # 'test_images/seg_box.png'
# masks_by_yolo = self.segment_from_yolo(image, bbox, save_path=None) # 'test_images/seg_box.png'
recovery, masks_by_yolo, image_output = self.inpaint_from_yolo(image_output, bbox)
masks_total = cv2.bitwise_or(masks_total, masks_by_yolo)
logging.info('***** 1차: 객차 탐지 인페인팅 수행 완료 ******')
# cv2.imwrite('inpainted_yolo_복원.jpg', recovery.astype(np.uint8)) # 인식 및 복원하는 문항
# cv2.imwrite('inpainted_yolo_결과.jpg', image_output) # 1차 마스킹 및 인페인팅 결과
else:
logging.info("***** 1차: 객체 탐지 세그멘테이션 스킵 ******")
masks_by_yolo = None


### 2차: Segment by User Prompt
if len(user_inputs) > 0:
logging.info("***** 2차: 사용자 입력 세그멘테이션 시작 ******")
masks_by_user = self.segment_from_user(image, user_inputs, bbox, save_path=None) # save_path='test_images/seg_points.png'
# masks_by_user = self.segment_from_user(image, user_inputs, bbox, save_path=None) # save_path='test_images/seg_points.png'
masks_by_user, image_output = self.inpaint_from_user(image_output, user_inputs)
masks_total = cv2.bitwise_or(masks_total, masks_by_user)
masks_by_user, image_output = self.inpaint_from_user(image, user_inputs, save_path=None)
_, mask_bytes = cv2.imencode("." + extension, image_output)
logging.info('***** 2차: 사용자 입력 인페인팅 수행 완료 ******')
# cv2.imwrite('inpainted_yolo&user_결과.jpg', image_output)
else:
logging.info("***** 2차: 사용자 입력 세그멘테이션 스킵 ******")
masks_by_user = None
mask_bytes = None

if isinstance(masks_total, np.ndarray) and len(bbox) > 0:
image_output = self.inpainting(image_output, masks_total, bbox)
logging.info('***** 인페인팅 수행 완료 ******')
image_output = cv2.convertScaleAbs(image_output, alpha=1.5, beta=10)
logging.info("***** 필기 제거 결과 후보정 완료 ******")

# 5가지 bytes 생성
_, input_bytes = cv2.imencode("." + extension, image)
_, result_bytes = cv2.imencode("." + extension, image_output)

if masks_by_yolo is not None:
mask_by_yolo_img = image.copy()
mask_by_yolo_img[masks_by_yolo == 255] = [0, 0, 0]
Expand All @@ -207,10 +209,13 @@ def process(self, img_bytes, user_inputs, extension='jpg'):
_, mask_by_user_bytes = cv2.imencode("." + extension, mask_by_user_img)
else:
mask_by_user_bytes = None

mask_total_img = image.copy()
mask_total_img[masks_total == 255] = [0, 0, 0]
_, mask_bytes = cv2.imencode("." + extension, mask_total_img)

# cv2.imwrite('input.jpg', image)
# cv2.imwrite('mask_total.jpg', mask_total_img)
# cv2.imwrite('output.jpg', image_output)

return (io.BytesIO(input_bytes), io.BytesIO(mask_bytes), io.BytesIO(result_bytes),
io.BytesIO(mask_by_yolo_bytes), io.BytesIO(mask_by_user_bytes))
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,3 @@ torch==2.5.1
torchvision==0.20.1
matplotlib==3.9.1.post1
ultralytics==8.3.31
git+https://github.com/facebookresearch/segment-anything.git@dca509fe793f601edb92606367a655c15ac00fdf

0 comments on commit 2d6ae00

Please sign in to comment.