Skip to content

Commit

Permalink
Fix: 추가 - 필기제거 강도조절 옵션
Browse files Browse the repository at this point in the history
  • Loading branch information
semnisem committed Nov 26, 2024
1 parent dbbff0f commit 25b012c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
10 changes: 7 additions & 3 deletions app/AIProcessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(self, yolo_path: str, sam_path: str, device: torch.device = torch.d
self.size = 0
self.height = 0
self.width = 0
+ self.upper_rgb = (40, 40, 45)

def remove_points_in_bboxes(self, point_list, label_list, bbox_list):
def is_point_in_bbox(p, b):
Expand Down Expand Up @@ -60,7 +61,7 @@ def object_detection(self, img):
iou=0.3, conf=0.3)
bbox = results[0].boxes.xyxy.tolist()
self.indices = [index for index, value in enumerate(results[0].boxes.cls) if value == 1.0]
- logging.info(f'객체 탐지 - {self.indices} 박스에 동그라미 존재')
+ logging.info(f'객체 탐지 - {len(bbox)}개 결과 중, 동그란 마킹 {len(self.indices)}개 ({self.indices})')
return bbox

def segment_from_yolo(self, image, bbox, save_path=None):
Expand Down Expand Up @@ -107,7 +108,7 @@ def inpaint_from_yolo(self, image, bbox=None):
masks_np[miny:maxy, minx:maxx] = 255 # 박스 영역을 255로 채움

roi = image[miny:maxy, minx:maxx] # 박스 내 복사하고 텍스트만 추출하기
- roi_mask = cv2.inRange(roi, (0, 0, 0), (40, 40, 45)) # roi 내에서 인쇄된 글씨는 255값
+ roi_mask = cv2.inRange(roi, (0, 0, 0), self.upper_rgb) # roi 내에서 인쇄된 글씨는 255값
text_np[miny:maxy, minx:maxx] = roi_mask

if maxy < image.shape[0]-2: # 박스 근처 BG 컬러 샘플링
Expand Down Expand Up @@ -148,7 +149,7 @@ def combine_alpha(self, image_rgb):
else:
return image_rgb

- def process(self, img_bytes, user_inputs, extension='jpg'):
+ def process(self, img_bytes, user_inputs, user_intensity, extension='jpg'):
""" local test 용도 Vs. server test 용도 구분 """
### local 용도
# img_path = img_bytes
Expand All @@ -163,6 +164,9 @@ def process(self, img_bytes, user_inputs, extension='jpg'):
### ready
#self.load_sam_model()
#self.predictor.set_image(image)
+ if user_intensity is not None and len(user_intensity) == 3:
+ self.upper_rgb = (user_intensity[0], user_intensity[1], user_intensity[2])
+ logging.info(f"인식할 text 색상 범위: (0, 0, 0) ~ {self.upper_rgb}")
masks_total = np.zeros(image.shape[:2], dtype=np.uint8)

### 1차: Segment by Object Detection
Expand Down
7 changes: 4 additions & 3 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ async def processShape(request: Request):
data = await request.json()
full_url = data['fullUrl']
point_list = data.get('points')
- label_list = data.get('labels') # value or None
+ intensity = data.get('intensity') # value or None
logger.info(f"사용자 입력 포인트: {point_list}")
- logger.info(f"사용자 입력 라벨: {label_list}")
+ logger.info(f"사용자 입력 강도: {intensity}")

try:
s3_key = parse_s3_url(full_url)
Expand All @@ -126,7 +126,8 @@ async def processShape(request: Request):
# aiProcessor = AIProcessor(yolo_path='/Users/semin/models/yolo11_best.pt', sam_path='/Users/semin/models/mobile_sam.pt') # local
aiProcessor = AIProcessor(yolo_path="../models/yolo11_best.pt", sam_path="../models/mobile_sam.pt") # server
img_input_bytes, img_mask_bytes, img_output_bytes, one, two = aiProcessor.process(img_bytes=corrected_img_bytes,
- user_inputs=point_list)
+ user_inputs=point_list,
+ user_intensity=intensity)
logger.info("AI 필기 제거 프로세스 완료")

upload_image_to_s3(img_input_bytes, paths["input_path"])
Expand Down

0 comments on commit 25b012c

Please sign in to comment.