Skip to content

Commit

Permalink
[Feat] 하이브리드 필기 인식
Browse files Browse the repository at this point in the history
[Feat] 하이브리드 필기 인식
  • Loading branch information
semnisem authored Nov 24, 2024
2 parents e26dc4f + 2ccb36c commit 0ccad2e
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 77 deletions.
169 changes: 94 additions & 75 deletions app/AIProcessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,32 @@ def object_detection(self, img):
logging.info(f'객체 탐지 - {self.indices} 박스에 동그라미 존재')
return bbox

def segment_from_points(self, image, user_points, user_labels, bbox, save_path=None):
'''input_point = np.array(user_points)
input_label = np.array(user_labels)
masks_points, scores, logits = self.predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=False,
)'''
def segment_from_yolo(self, image, bbox, save_path=None):
    """Segment the regions inside YOLO-detected bounding boxes with SAM.

    Runs ``self.sam_model`` once with every detected box as a prompt, then
    ORs the per-box binary masks into a single uint8 mask.

    Args:
        image: source image, in whatever form ``sam_model.predict`` accepts.
        bbox: iterable of bounding boxes forwarded as SAM ``bboxes`` prompts.
        save_path: unused; kept for backward compatibility with callers
            (the debug ``cv2.imwrite`` it once fed was removed).

    Returns:
        np.ndarray: uint8 mask of shape (H, W) — 255 inside segmented
        regions, 0 elsewhere.
    """
    results = self.sam_model.predict(source=image, bboxes=bbox)
    # NOTE(review): assumes results[0].masks.data is an (N, H, W) boolean
    # tensor, one mask per prompted box — confirm against the SAM wrapper.
    mask_boxes = results[0].masks.data

    masks_np = np.zeros(mask_boxes.shape[-2:], dtype=np.uint8)
    for mask in mask_boxes:
        # True -> 255, False -> 0, then drop any leading singleton dim.
        mask_np = mask.cpu().numpy().astype(np.uint8) * 255
        masks_np = cv2.bitwise_or(masks_np, mask_np.squeeze())

    logging.info(f'1차 마스킹 - 바운딩 박스 {len(mask_boxes)}개 세그먼트 완료.')
    return masks_np

def segment_from_user(self, image, user_inputs, bbox, save_path=None):
# filtered_points, filtered_labels = self.remove_points_in_bboxes(user_points, user_labels, bbox)
# logging.info(f'2차 마스킹 - 사용자 입력 필터링: from {len(user_points)}개 to {len(filtered_points)}개')

# results = self.sam_model.predict(source=image, points=filtered_points, labels=filtered_labels)
results = self.sam_model.predict(source=image, bboxes=[user_points])
results = self.sam_model.predict(source=image, bboxes=user_inputs)
mask_points = results[0].masks.data

masks_np = np.zeros(mask_points.shape[-2:], dtype=np.uint8)
Expand All @@ -88,57 +100,55 @@ def segment_from_points(self, image, user_points, user_labels, bbox, save_path=
logging.info(f'2차 마스킹 - 사용자 입력에 대해 {len(mask_points)}개 영역으로 세그먼트 완료.')
return masks_np

def segment_from_boxes(self, image, bbox, save_path=None):
'''
input_boxes = torch.tensor(bbox, device=self.predictor.device)
transformed_boxes = self.predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
masks, _, _ = self.predictor.predict_torch(
point_coords=None,
point_labels=None,
boxes=transformed_boxes,
multimask_output=False,
)
'''
results = self.sam_model.predict(source=image, bboxes=bbox)
mask_boxes = results[0].masks.data
def inpainting(self, image, mask_total, bbox=None):

masks_np = np.zeros(mask_boxes.shape[-2:], dtype=np.uint8)
for i, mask in enumerate(mask_boxes):
mask_np = mask.cpu().numpy().astype(np.uint8) * 255 # True는 255, False는 0으로 변환
mask_np = mask_np.squeeze()
if i in self.indices:
# kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15)) # Adjust size as needed
kernel = np.ones((15, 15), np.uint8)
mask_np = cv2.morphologyEx(mask_np, cv2.MORPH_GRADIENT, kernel)
masks_np = cv2.bitwise_or(masks_np, mask_np)
# cv2.imwrite(f'mask_box{i}.jpg', masks_np)

# cv2.imwrite(save_path, masks_np)
logging.info(f'1차 마스킹 - 바운딩 박스 {len(mask_boxes)}개 세그먼트 완료.')
return masks_np
mask_total = np.zeros(image.shape[:2], dtype=np.uint8)
text_np = np.zeros(image.shape[:2], dtype=np.uint8) # 전체 roi_mask 저장
for b in bbox:
minx, miny, maxx, maxy = map(int, b)
mask_total[miny:maxy, minx:maxx] = 255 # 박스 영역을 255로 채움
roi = image[miny:maxy, minx:maxx] # 해당 이미지만 크롭
roi_mask = cv2.inRange(roi, (0, 0, 0), (45, 45, 45)) # roi 내에서 인쇄된 글씨는 255값
text_np[miny:maxy, minx:maxx] = roi_mask

def inpainting(self, image, mask_total):
# inpainted_image = cv2.inpaint(image.copy(), mask_total, 10, cv2.INPAINT_TELEA)
print(mask_total.shape) # (1893, 1577) with 0 or 255
# print(image.shape) # (1893, 1577, 3)
text_np = cv2.dilate(text_np, np.ones((4, 4), np.uint8), iterations=1)
text_np = cv2.erode(text_np, np.ones((2, 2), np.uint8), iterations=2)
# cv2.imwrite('text_np.jpg', text_np.astype(np.uint8))

inpainted_image = image.copy()
inpainted_image[mask_total == 255] = [255, 255, 255]
final_image = cv2.convertScaleAbs(inpainted_image, alpha=1.5, beta=10)
inpainted_image[text_np == 255] = [30, 30, 30]
final_image = cv2.convertScaleAbs(inpainted_image, alpha=1.5, beta=15)

# inpainted_image = cv2.inpaint(image.copy(), mask_total, 15, cv2.INPAINT_TELEA)
# cv2.imwrite('test_images/inpainted_init.png', inpainted_image)
# cv2.imwrite('test_images/inpainted_final.png', final_image)
logging.info('인페인팅 및 후보정 완료.')
logging.info('Yolo 결과 인페인팅 및 후보정 완료.')

return final_image

def inpaint_from_user(self, image, user_boxes, save_path=None):
    """Whiten every user-drawn box on a copy of the image.

    Args:
        image: BGR image as an (H, W, 3) uint8 array; left unmodified.
        user_boxes: iterable of (minx, miny, maxx, maxy) boxes.
        save_path: unused; retained for interface compatibility.

    Returns:
        tuple: (uint8 mask with 255 inside the boxes, inpainted image copy).
    """
    region_mask = np.zeros(image.shape[:2], dtype=np.uint8)
    for box in user_boxes:
        # Mark the full interior of each box in the mask.
        x0, y0, x1, y1 = (int(v) for v in box)
        region_mask[y0:y1, x0:x1] = 255

    result = image.copy()
    result[region_mask == 255] = [255, 255, 255]

    logging.info(f'2차 마스킹 & 인페인팅 - 사용자 입력 {len(user_boxes)}개 영역 인페인팅 완료.')
    return region_mask, result

def combine_alpha(self, image_rgb):
    """Reattach the stored alpha channel to an RGB image when present.

    Args:
        image_rgb: 3-channel image.

    Returns:
        The RGBA merge of ``image_rgb`` and ``self.alpha_channel`` when an
        alpha channel was captured at load time; otherwise ``image_rgb``
        unchanged.
    """
    if self.alpha_channel is None:
        return image_rgb
    return cv2.merge([image_rgb, self.alpha_channel])

def process(self, img_bytes, user_points, user_labels, extension='jpg'):
def process(self, img_bytes, user_inputs, extension='jpg'):
""" local test 용도 Vs. server test 용도 구분 """
### local 용도
# img_path = img_bytes
Expand All @@ -147,51 +157,60 @@ def process(self, img_bytes, user_points, user_labels, extension='jpg'):
### server 용도
buffer = np.frombuffer(img_bytes, dtype=np.uint8)
image = cv2.imdecode(buffer, cv2.IMREAD_UNCHANGED)
image_output = image.copy()
logging.info(f"이미지 처리 시작 - 사이즈: {image.shape[:2]}")

### ready
#self.load_sam_model()
#self.predictor.set_image(image)
masks_total = np.zeros(image.shape[:2], dtype=np.uint8)

### 1차: Object Detection & Segment by Box
### 1차: Segment by Object Detection
bbox = self.object_detection(image)
if len(bbox) > 0:
masks_by_box = self.segment_from_boxes(image, bbox, save_path=None) # 'test_images/seg_box.png'
masks_total = cv2.bitwise_or(masks_total, masks_by_box)
#logging.info( f"1차 마스킹 후 shape 점검: YOLOv11 감지된 영역 shape: {masks_by_box.shape}, 이미지 영역 shape: {image.shape}") # (1893, 1577, 3) (1893, 1577)
else:
masks_by_box = None
### 2차: points arguments by User & Segment by Points
if len(user_points) > 0:
mask_by_point = self.segment_from_points(image, user_points, user_labels, bbox, save_path=None) # save_path='test_images/seg_points.png'
masks_total = cv2.bitwise_or(masks_total, mask_by_point)
logging.info("***** 1차: 객체 탐지 세그멘테이션 시작 ******")
masks_by_yolo = self.segment_from_yolo(image, bbox, save_path=None) # 'test_images/seg_box.png'
masks_total = cv2.bitwise_or(masks_total, masks_by_yolo)
else:
mask_by_point = None
# cv2.imwrite('test_images/mask_total.png', masks_total)
if isinstance(masks_total, np.ndarray):
image_output = self.inpainting(image, masks_total)
logging.info('최종 마스킹 이미지 생성 완료')
logging.info("***** 1차: 객체 탐지 세그멘테이션 스킵 ******")
masks_by_yolo = None


### 2차: Segment by User Prompt
if len(user_inputs) > 0:
logging.info("***** 2차: 사용자 입력 세그멘테이션 시작 ******")
masks_by_user = self.segment_from_user(image, user_inputs, bbox, save_path=None) # save_path='test_images/seg_points.png'
masks_total = cv2.bitwise_or(masks_total, masks_by_user)
masks_by_user, image_output = self.inpaint_from_user(image, user_inputs, save_path=None)
_, mask_bytes = cv2.imencode("." + extension, image_output)
else:
image_output = image
logging.info('최종 마스킹이 없거나, numpy 형식의 배열이 아님.')
logging.info("***** 2차: 사용자 입력 세그멘테이션 스킵 ******")
masks_by_user = None
mask_bytes = None

if isinstance(masks_total, np.ndarray) and len(bbox) > 0:
image_output = self.inpainting(image_output, masks_total, bbox)
logging.info('***** 인페인팅 수행 완료 ******')

# image_input = self.combine_alpha(image_rgb)
# image_output = self.combine_alpha(image_output)
_, input_bytes = cv2.imencode("." + extension, image)
_, mask_bytes = cv2.imencode("." + extension, masks_total.astype(np.uint8))
_, result_bytes = cv2.imencode("." + extension, image_output)

if mask_by_point is not None and masks_by_box is not None:
_, mask_by_point_bytes = cv2.imencode("." + extension, mask_by_point.astype(np.uint8))
_, mask_by_box_bytes = cv2.imencode("." + extension, masks_by_box.astype(np.uint8))
return io.BytesIO(input_bytes), io.BytesIO(mask_bytes), io.BytesIO(result_bytes), io.BytesIO(mask_by_box_bytes), io.BytesIO(mask_by_point_bytes)
elif mask_by_point is not None:
_, mask_by_point_bytes = cv2.imencode("." + extension, mask_by_point.astype(np.uint8))
return io.BytesIO(input_bytes), io.BytesIO(mask_bytes), io.BytesIO(result_bytes), None, io.BytesIO(mask_by_point_bytes)
elif masks_by_box is not None:
_, mask_by_box_bytes = cv2.imencode("." + extension, masks_by_box.astype(np.uint8))
return io.BytesIO(input_bytes), io.BytesIO(mask_bytes), io.BytesIO(result_bytes), io.BytesIO(mask_by_box_bytes), None
if masks_by_yolo is not None:
mask_by_yolo_img = image.copy()
mask_by_yolo_img[masks_by_yolo == 255] = [0, 0, 0]
_, mask_by_yolo_bytes = cv2.imencode("." + extension, mask_by_yolo_img)
else:
mask_by_yolo_bytes = None
if masks_by_user is not None:
mask_by_user_img = image.copy()
mask_by_user_img[masks_by_user == 255] = [0, 0, 0]
_, mask_by_user_bytes = cv2.imencode("." + extension, mask_by_user_img)
else:
return io.BytesIO(input_bytes), io.BytesIO(mask_bytes), io.BytesIO(result_bytes), None, None
mask_by_user_bytes = None

mask_total_img = image.copy()
mask_total_img[masks_total == 255] = [0, 0, 0]
_, mask_bytes = cv2.imencode("." + extension, mask_total_img)

return (io.BytesIO(input_bytes), io.BytesIO(mask_bytes), io.BytesIO(result_bytes),
io.BytesIO(mask_by_yolo_bytes), io.BytesIO(mask_by_user_bytes))
3 changes: 1 addition & 2 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,7 @@ async def processShape(request: Request):
# aiProcessor = AIProcessor(yolo_path='/Users/semin/models/yolo11_best.pt', sam_path='/Users/semin/models/mobile_sam.pt') # local
aiProcessor = AIProcessor(yolo_path="../models/yolo11_best.pt", sam_path="../models/mobile_sam.pt") # server
img_input_bytes, img_mask_bytes, img_output_bytes, one, two = aiProcessor.process(img_bytes=corrected_img_bytes,
user_points=point_list,
user_labels=label_list)
user_inputs=point_list)
logger.info("AI 필기 제거 프로세스 완료")

upload_image_to_s3(img_input_bytes, paths["input_path"])
Expand Down

0 comments on commit 0ccad2e

Please sign in to comment.