diff --git a/Dockerfile b/Dockerfile
index 610fab7..91a6904 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -5,12 +5,12 @@ FROM python:3.12-slim
 LABEL org.opencontainers.image.source="https://github.com/AI-SIP/MVP_CV"
 
 # 필요한 시스템 패키지 먼저 설치
-RUN apt-get update && apt-get install -y --no-install-recommends \
+RUN apt-get update && apt-get install -y git \
     libgl1-mesa-glx \
     libglib2.0-0 \
     gcc \
     python3-dev && \
     rm -rf /var/lib/apt/lists/*
 
 # 작업 디렉토리 설정
 WORKDIR /test
@@ -26,4 +26,4 @@ COPY ./app /test/app/
 WORKDIR /test/app
 
 # 컨테이너 실행 시 실행할 명령어
-CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
\ No newline at end of file
+CMD ["bash", "-c", "mkdir -p /test/models && uvicorn main:app --host 0.0.0.0 --port 8000"]
\ No newline at end of file
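Reviewer note: `git` is needed only for the pip VCS install of `segment-anything`, and adding it does not require dropping `--no-install-recommends`. The model directory could also be created at build time so the CMD can stay in exec form. A possible sketch, not part of the diff:

```dockerfile
# Keep the slim install while adding git (assumes no "recommended" packages are actually needed).
RUN apt-get update && apt-get install -y --no-install-recommends \
    git \
    libgl1-mesa-glx \
    libglib2.0-0 \
    gcc \
    python3-dev && \
    rm -rf /var/lib/apt/lists/*

# Create the model directory in the image instead of at container start.
RUN mkdir -p /test/models

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
```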
diff --git a/app/AIProcessor.py b/app/AIProcessor.py
new file mode 100644
index 0000000..fc80705
--- /dev/null
+++ b/app/AIProcessor.py
@@ -0,0 +1,197 @@
+import torch
+import torchvision
+import numpy as np
+import cv2
+from ultralytics import YOLO
+from ultralytics import SAM
+from segment_anything import sam_model_registry, SamPredictor
+import io
+import logging
+
+
+class AIProcessor:
+    def __init__(self, yolo_path: str, sam_path: str, device: torch.device = torch.device("cpu")):
+        self.yolo_path = yolo_path
+        self.sam_path = sam_path
+        self.device = device
+        self.yolo_model = YOLO(self.yolo_path)  # load YOLO weights
+        self.sam_model = SAM(self.sam_path)  # load ultralytics SAM weights
+        self.predictor = None
+        self.indices = None
+        self.alpha_channel = None
+        self.size = 0
+        self.height = 0
+        self.width = 0
+
+    def remove_points_in_bboxes(self, point_list, label_list, bbox_list):
+        # Drop user points that already fall inside a detected bounding box.
+        def is_point_in_bbox(p, b):
+            x, y = p
+            x_min, y_min, x_max, y_max = b
+            return x_min <= x <= x_max and y_min <= y <= y_max
+
+        filtered_p = []
+        filtered_l = []
+        for p, l in zip(point_list, label_list):
+            if not any(is_point_in_bbox(p, b) for b in bbox_list):
+                filtered_p.append(p)
+                filtered_l.append(l)
+
+        return filtered_p, filtered_l
+
+    def remove_alpha(self, image):
+        self.height, self.width = image.shape[:2]
+        self.size = self.height * self.width
+        if image.shape[2] == 4:  # RGBA: keep the alpha channel aside
+            self.alpha_channel = image[:, :, 3]
+            image_rgb = image[:, :, :3]
+        else:
+            image_rgb = image
+
+        return image_rgb
+
+    def load_sam_model(self, model_type="vit_h"):  # load the original segment_anything checkpoint
+        self.sam_model = sam_model_registry[model_type](checkpoint=self.sam_path)
+        self.sam_model.to(self.device)
+        self.predictor = SamPredictor(self.sam_model)
+
+    def object_detection(self, img):
+        results = self.yolo_model.predict(source=img, imgsz=640, device=self.device,
+                                          iou=0.3, conf=0.3)
+        bbox = results[0].boxes.xyxy.tolist()
+        self.indices = [index for index, value in enumerate(results[0].boxes.cls) if value == 1.0]
+        logging.info(f'Object detection - boxes {self.indices} contain circles')
+        return bbox
+
+    def segment_from_points(self, image, user_points, user_labels, bbox, save_path=None):
+        '''input_point = np.array(user_points)
+        input_label = np.array(user_labels)
+
+        masks_points, scores, logits = self.predictor.predict(
+            point_coords=input_point,
+            point_labels=input_label,
+            multimask_output=False,
+        )'''
+        # filtered_points, filtered_labels = self.remove_points_in_bboxes(user_points, user_labels, bbox)
+        # logging.info(f'Second-pass masking - user input filtered: from {len(user_points)} to {len(filtered_points)} points')
+
+        # results = self.sam_model.predict(source=image, points=filtered_points, labels=filtered_labels)
+        # NOTE: the user "points" are currently forwarded to SAM as box prompts.
+        results = self.sam_model.predict(source=image, bboxes=[user_points])
+        mask_points = results[0].masks.data
+
+        masks_np = np.zeros(mask_points.shape[-2:], dtype=np.uint8)
+        for mask in mask_points:
+            mask_np = mask.cpu().numpy().astype(np.uint8) * 255
+            mask_np = mask_np.squeeze()
+            masks_np = cv2.bitwise_or(masks_np, mask_np)
+
+        # cv2.imwrite(save_path, mask_points_uint8)
+        logging.info(f'Second-pass masking - segmented {len(mask_points)} regions from the user input.')
+        return masks_np
+
+    def segment_from_boxes(self, image, bbox, save_path=None):
+        '''
+        input_boxes = torch.tensor(bbox, device=self.predictor.device)
+        transformed_boxes = self.predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
+        masks, _, _ = self.predictor.predict_torch(
+            point_coords=None,
+            point_labels=None,
+            boxes=transformed_boxes,
+            multimask_output=False,
+        )
+        '''
+        results = self.sam_model.predict(source=image, bboxes=bbox)
+        mask_boxes = results[0].masks.data
+
+        masks_np = np.zeros(mask_boxes.shape[-2:], dtype=np.uint8)
+        for i, mask in enumerate(mask_boxes):
+            mask_np = mask.cpu().numpy().astype(np.uint8) * 255  # True -> 255, False -> 0
+            mask_np = mask_np.squeeze()
+            if i in self.indices:
+                # kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (13, 13))  # Adjust size as needed
+                kernel = np.ones((11, 11), np.uint8)
+                # keep only the outline of detected circles so the content inside is preserved
+                mask_np = cv2.morphologyEx(mask_np, cv2.MORPH_GRADIENT, kernel)
+            masks_np = cv2.bitwise_or(masks_np, mask_np)
+            # cv2.imwrite(f'mask_box{i}.jpg', masks_np)
+
+        # cv2.imwrite(save_path, masks_np)
+        logging.info(f'First-pass masking - segmented {len(mask_boxes)} bounding boxes.')
+        return masks_np
+
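+    # Note: the "inpainting" below is a simple white-out rather than true inpainting:
+    # masked pixels are painted white and cv2.convertScaleAbs applies a mild
+    # brightness/contrast correction (the cv2.inpaint TELEA call stays commented out).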
+    def inpainting(self, image, mask_total):
+        # inpainted_image = cv2.inpaint(image.copy(), mask_total, 10, cv2.INPAINT_TELEA)
+        logging.debug(f'mask_total shape: {mask_total.shape}')  # e.g. (1893, 1577) with 0 or 255
+        # print(image.shape)  # (1893, 1577, 3)
+
+        inpainted_image = image.copy()
+        inpainted_image[mask_total == 255] = [255, 255, 255]
+        final_image = cv2.convertScaleAbs(inpainted_image, alpha=1.5, beta=10)
+        # cv2.imwrite('test_images/inpainted_init.png', inpainted_image)
+        # cv2.imwrite('test_images/inpainted_final.png', final_image)
+        logging.info('Inpainting and post-correction finished.')
+
+        return final_image
+
+    def combine_alpha(self, image_rgb):
+        if self.alpha_channel is not None:  # RGBA
+            image_rgba = cv2.merge([image_rgb, self.alpha_channel])
+            return image_rgba
+        else:
+            return image_rgb
+
+    def process(self, img_bytes, user_points, user_labels, extension='jpg'):
+        """ Distinguishes local-test usage from server-test usage """
+        ### for local use
+        # img_path = img_bytes
+        # image = cv2.imread(img_path, cv2.IMREAD_COLOR)
+
+        ### for server use
+        buffer = np.frombuffer(img_bytes, dtype=np.uint8)
+        image = cv2.imdecode(buffer, cv2.IMREAD_UNCHANGED)
+
+        ### ready
+        # self.load_sam_model()
+        # self.predictor.set_image(image)
+        masks_total = np.zeros(image.shape[:2], dtype=np.uint8)
+
+        ### first pass: object detection & segmentation by box
+        bbox = self.object_detection(image)
+        if len(bbox) > 0:
+            masks_by_box = self.segment_from_boxes(image, bbox, save_path=None)  # 'test_images/seg_box.png'
+            masks_total = cv2.bitwise_or(masks_total, masks_by_box)
+            # logging.info(f"Shape check after first-pass masking: YOLOv11 mask shape: {masks_by_box.shape}, image shape: {image.shape}")  # (1893, 1577, 3) (1893, 1577)
+        else:
+            masks_by_box = None
+
+        ### second pass: user-supplied points & segmentation by points
+        if len(user_points) > 0:
+            mask_by_point = self.segment_from_points(image, user_points, user_labels, bbox, save_path=None)  # save_path='test_images/seg_points.png'
+            masks_total = cv2.bitwise_or(masks_total, mask_by_point)
+        else:
+            mask_by_point = None
+
+        # cv2.imwrite('test_images/mask_total.png', masks_total)
+        if isinstance(masks_total, np.ndarray):
+            image_output = self.inpainting(image, masks_total)
+            logging.info('Final masked image created.')
+        else:
+            image_output = image
+            logging.info('No final mask, or it is not a numpy array.')
+
+        # image_input = self.combine_alpha(image_rgb)
+        # image_output = self.combine_alpha(image_output)
+        _, input_bytes = cv2.imencode("." + extension, image)
+        _, mask_bytes = cv2.imencode("." + extension, masks_total.astype(np.uint8))
+        _, result_bytes = cv2.imencode("." + extension, image_output)
+
+        if mask_by_point is not None and masks_by_box is not None:
+            _, mask_by_point_bytes = cv2.imencode("." + extension, mask_by_point.astype(np.uint8))
+            _, mask_by_box_bytes = cv2.imencode("." + extension, masks_by_box.astype(np.uint8))
+            return io.BytesIO(input_bytes), io.BytesIO(mask_bytes), io.BytesIO(result_bytes), io.BytesIO(mask_by_box_bytes), io.BytesIO(mask_by_point_bytes)
+        elif mask_by_point is not None:
+            _, mask_by_point_bytes = cv2.imencode("." + extension, mask_by_point.astype(np.uint8))
+            return io.BytesIO(input_bytes), io.BytesIO(mask_bytes), io.BytesIO(result_bytes), None, io.BytesIO(mask_by_point_bytes)
+        elif masks_by_box is not None:
+            _, mask_by_box_bytes = cv2.imencode("." + extension, masks_by_box.astype(np.uint8))
+            return io.BytesIO(input_bytes), io.BytesIO(mask_bytes), io.BytesIO(result_bytes), io.BytesIO(mask_by_box_bytes), None
+        else:
+            return io.BytesIO(input_bytes), io.BytesIO(mask_bytes), io.BytesIO(result_bytes), None, None
+
+
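For context on how the new class is meant to be driven (main.py below does the same with S3 bytes), a minimal local smoke-test sketch; file names are placeholders and it assumes the YOLO/SAM weights are already on disk:

```python
# Hypothetical local smoke test for AIProcessor; paths and file names are examples only.
from AIProcessor import AIProcessor

processor = AIProcessor(yolo_path="models/yolo11_best.pt", sam_path="models/mobile_sam.pt")

with open("sample_page.jpg", "rb") as f:
    img_bytes = f.read()

# An empty point list skips the user-point branch, so only YOLO detection,
# box-prompted SAM segmentation and the white-out step run.
input_io, mask_io, result_io, mask_box_io, mask_point_io = processor.process(
    img_bytes, user_points=[], user_labels=[])

with open("result.jpg", "wb") as f:
    f.write(result_io.getvalue())
```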
diff --git a/app/main.py b/app/main.py
index 5abd3dd..98eeb41 100644
--- a/app/main.py
+++ b/app/main.py
@@ -7,6 +7,7 @@ from starlette import status
 from starlette.responses import StreamingResponse, JSONResponse
 from ColorRemover import ColorRemover
+from AIProcessor import AIProcessor
 import ImageFunctions as ImageManager
 import logging
 import os
@@ -39,8 +40,10 @@ def create_file_path(obj_path, extension):
     file_id = uuid4()  # 각 클라이언트마다 고유한 파일 ID 생성
     dir_path = obj_path.rsplit('/', 1)[0]
     paths = {"input_path": f"{dir_path}/{file_id}.input.{extension}",
-             "output_path": f"{dir_path}/{file_id}.output.{extension}",
              "mask_path": f"{dir_path}/{file_id}.mask.{extension}",
+             "output_path": f"{dir_path}/{file_id}.output.{extension}",
+             "one": f"{dir_path}/{file_id}.mask_b.{extension}",
+             "two": f"{dir_path}/{file_id}.mask_p.{extension}",
              "extension": extension}
     return paths
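For reference, the dict returned by the updated `create_file_path` now has this shape (illustrative values); `one`/`two` later receive the box-derived and point-derived masks, so more descriptive key names such as `mask_box_path`/`mask_point_path` might read better:

```python
# Illustrative shape of the paths dict (values depend on the S3 key and extension).
{
    "input_path":  "<dir>/<uuid>.input.jpg",
    "mask_path":   "<dir>/<uuid>.mask.jpg",
    "output_path": "<dir>/<uuid>.output.jpg",
    "one":         "<dir>/<uuid>.mask_b.jpg",   # mask built from YOLO boxes
    "two":         "<dir>/<uuid>.mask_p.jpg",   # mask built from user points
    "extension":   "jpg",
}
```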
@@ -65,10 +68,84 @@ def upload_image_to_s3(file_bytes, file_path):
     s3_client.upload_fileobj(file_bytes, BUCKET_NAME, file_path)
 
 
+def download_model_from_s3(yolo_path: str = 'models/yolo11_best.pt', sam_path: str = "models/mobile_sam.pt"):  # models/sam_vit_h_4b8939.pth
+    dest_dir = '../'  # directory inside the container where the models are stored
+    try:
+        yolo_full_path = dest_dir + yolo_path
+        sam_full_path = dest_dir + sam_path
+        if not os.path.exists(yolo_full_path):
+            s3_client.download_file(BUCKET_NAME, yolo_path, yolo_full_path)
+            logger.info(f'YOLOv11 model downloaded successfully to {yolo_full_path}')
+        else:
+            logger.info(f'YOLOv11 already exists at {yolo_full_path}')
+        if not os.path.exists(sam_full_path):
+            s3_client.download_file(BUCKET_NAME, sam_path, sam_full_path)
+            logger.info(f'SAM model downloaded successfully to {sam_full_path}')
+        else:
+            logger.info(f'SAM model already exists at {sam_full_path}')
+
+        logger.info(f"Files in {dest_dir}: {os.listdir(dest_dir)}")
+    except Exception as e:
+        logger.error(f'Failed to download model: {e}')
+
+
 @app.get("/", status_code=status.HTTP_200_OK)
 def greeting():
-    return JSONResponse(content={"message": "Hello! Let's start image processing"})
+    return JSONResponse(content={"message": "Hello! Welcome to OnO's FastAPI Server!"})
+
+
+@app.get("/load-models", status_code=status.HTTP_200_OK)
+async def get_models():
+    try:
+        download_model_from_s3()
+        return JSONResponse(content={"message": "Models are ready"})
+    except Exception as e:
+        logger.error("Error with Download & Saving AIs: %s", e)
+        raise HTTPException(status_code=500, detail="Error with Download & Saving AIs")
+
+
+download_model_from_s3()  # fetch the weights once at import time; calling the async get_models() without await would never run
+
+
+@app.post("/process-shape")
+async def processShape(request: Request):
+    """ AI handwriting detection & Telea Algorithm-based inpainting """
+    data = await request.json()
+    full_url = data['fullUrl']
+    point_list = data.get('points')
+    label_list = data.get('labels')  # value or None
+    logger.info(f"user input points: {point_list}")
+    logger.info(f"user input labels: {label_list}")
+
+    try:
+        s3_key = parse_s3_url(full_url)
+        paths = create_file_path(s3_key, s3_key.split(".")[-1])
+        img_bytes = download_image_from_s3(s3_key)  # download from S3
+        corrected_img_bytes = ImageManager.correct_rotation(img_bytes, paths['extension'])
+        logger.info(f"User image ({s3_key}) downloaded and preprocessed")
+
+        # aiProcessor = AIProcessor(yolo_path='/Users/semin/models/yolo11_best.pt', sam_path='/Users/semin/models/mobile_sam.pt')  # local
+        aiProcessor = AIProcessor(yolo_path="../models/yolo11_best.pt", sam_path="../models/mobile_sam.pt")  # server
+        img_input_bytes, img_mask_bytes, img_output_bytes, one, two = aiProcessor.process(img_bytes=corrected_img_bytes,
+                                                                                          user_points=point_list,
+                                                                                          user_labels=label_list)
+        logger.info("AI handwriting-removal process finished")
+
+        upload_image_to_s3(img_input_bytes, paths["input_path"])
+        upload_image_to_s3(img_mask_bytes, paths["mask_path"])
+        upload_image_to_s3(img_output_bytes, paths["output_path"])
+        if one is not None:
+            upload_image_to_s3(one, paths["one"])
+        if two is not None:
+            upload_image_to_s3(two, paths["two"])
+
+        logger.info("AI handwriting-removal result images uploaded")
+        return JSONResponse(content={"message": "File processed successfully", "path": paths})
+
+    except KeyError as e:
+        raise HTTPException(status_code=400, detail=f"Missing key: {e.args[0]}")
+    except ValueError as e:
+        raise HTTPException(status_code=422, detail=str(e))
+    except Exception as e:
+        logger.error("Error during processing: %s", e)
+        raise HTTPException(status_code=500, detail="Error processing the image.")
+
 
 @app.post("/process-color")
 async def processColor(request: Request):
@@ -458,8 +535,8 @@ async def retrieve(problem_text: str):
 async def augment(curriculum_context, query):
     prompt = ("너는 고등학생의 오답 문제를 통해 약점을 보완해주는 공책이야. \
               교육과정을 참고해서 오답 문제 핵심 의도를 바탕으로 문제에서 헷갈릴만한 요소, \
-              학생이 놓친 것 같은 중요한 개념을 찾아 그 개념에 대해 4줄 이내로 설명해주고, \
-              그 개념을 적용해서 풀이를 요점에 따라서 짧게(4줄 내외) 작성해줘. \
+              이 문제를 틀렸다면 놓칠 것 같은 중요한 개념을 연관지어 그 개념에 대해 4줄 이내로 설명해주고, \
+              그 개념을 적용해서 풀이를 핵심적 원리에 집중하여 짧게(4줄 내외) 작성해줘. \
              만약 오답 문제와 교과과정이 관련이 없다고 판단되면, 교육과정은 참고하지 않으면 돼. \n\n\n")
     passage = f"오답 문제 : {query} \n\n\n"
     context = f"교과과정 : {curriculum_context} \n\n\n"
diff --git a/requirements.txt b/requirements.txt
index 2abd4e0..230ff78 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -44,4 +44,9 @@ botocore==1.34.141
 requests==2.32.3
 charset-normalizer==3.3.2
 openai==1.46.0
-pymilvus==2.4.6
\ No newline at end of file
+pymilvus==2.4.6
+torch==2.5.1
+torchvision==0.20.1
+matplotlib==3.9.1.post1
+ultralytics==8.3.31
+git+https://github.com/facebookresearch/segment-anything.git@dca509fe793f601edb92606367a655c15ac00fdf
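Finally, a hedged sketch of how a client might exercise the new `/process-shape` route once the server is up; the URL, bucket and coordinates are placeholders, and the exact points/labels format depends on the frontend (the current code forwards `points` to SAM as box prompts):

```python
# Example request to /process-shape (all values are placeholders).
import requests

payload = {
    "fullUrl": "https://<bucket>.s3.amazonaws.com/<dir>/example.jpg",
    "points": [[120, 200, 480, 560]],  # treated as box prompts by the current AIProcessor
    "labels": [1],
}
resp = requests.post("http://localhost:8000/process-shape", json=payload)
print(resp.status_code, resp.json())  # expects {"message": "File processed successfully", "path": {...}}
```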