[Feat] AI-based handwriting recognition feature #27

Merged · 17 commits · Nov 20, 2024
6 changes: 3 additions & 3 deletions Dockerfile
@@ -5,12 +5,12 @@ FROM python:3.12-slim
LABEL org.opencontainers.image.source="https://github.com/AI-SIP/MVP_CV"

# Install required system packages first
-RUN apt-get update && apt-get install -y --no-install-recommends \
+RUN apt-get update && apt-get install -y git \
libgl1-mesa-glx \
libglib2.0-0 \
gcc \
python3-dev && \
rm -rf /var/lib/apt/lists/*

# Set the working directory
WORKDIR /test
@@ -26,4 +26,4 @@ COPY ./app /test/app/
WORKDIR /test/app

# Command to run when the container starts
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
CMD ["bash", "-c", "mkdir -p /test/models && uvicorn main:app --host 0.0.0.0 --port 8000"]
197 changes: 197 additions & 0 deletions app/AIProcessor.py
@@ -0,0 +1,197 @@
import torch
import torchvision
import numpy as np
import cv2
from ultralytics import YOLO
from ultralytics import SAM
from segment_anything import sam_model_registry, SamPredictor
import io
import logging


class AIProcessor:
def __init__(self, yolo_path: str, sam_path: str, device: torch.device = torch.device("cpu")):
self.yolo_path = yolo_path
self.sam_path = sam_path
self.device = device
        self.yolo_model = YOLO(self.yolo_path)  # load the YOLO detector
        self.sam_model = SAM(self.sam_path)  # load the SAM model (Ultralytics wrapper)
self.predictor = None
self.indices = None
self.alpha_channel = None
self.size = 0
self.height = 0
self.width = 0

def remove_points_in_bboxes(self, point_list, label_list, bbox_list):
def is_point_in_bbox(p, b):
x, y = p
x_min, y_min, x_max, y_max = b
return x_min <= x <= x_max and y_min <= y <= y_max

filtered_p = []
filtered_l = []
for p, l in zip(point_list, label_list):
if not any(is_point_in_bbox(p, b) for b in bbox_list):
filtered_p.append(p)
filtered_l.append(l)

return filtered_p, filtered_l

def remove_alpha(self, image):
self.height, self.width = image.shape[:2]
self.size = self.height * self.width
        if image.ndim == 3 and image.shape[2] == 4:  # RGBA: split off the alpha channel
self.alpha_channel = image[:, :, 3]
image_rgb = image[:, :, :3]
else:
image_rgb = image

return image_rgb

    def load_sam_model(self, model_type="vit_h"):  # load the original segment_anything SAM
self.sam_model = sam_model_registry[model_type](checkpoint=self.sam_path)
self.sam_model.to(self.device)
self.predictor = SamPredictor(self.sam_model)

def object_detection(self, img):
results = self.yolo_model.predict(source=img, imgsz=640, device=self.device,
iou=0.3, conf=0.3)
bbox = results[0].boxes.xyxy.tolist()
        self.indices = [index for index, value in enumerate(results[0].boxes.cls) if value == 1.0]  # class 1.0 = circle
        logging.info(f'Object detection - circles found in boxes {self.indices}')
return bbox

def segment_from_points(self, image, user_points, user_labels, bbox, save_path=None):
'''input_point = np.array(user_points)
input_label = np.array(user_labels)

masks_points, scores, logits = self.predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=False,
)'''
# filtered_points, filtered_labels = self.remove_points_in_bboxes(user_points, user_labels, bbox)
        # logging.info(f'Second-pass masking - user input filtered from {len(user_points)} to {len(filtered_points)} points')

# results = self.sam_model.predict(source=image, points=filtered_points, labels=filtered_labels)
results = self.sam_model.predict(source=image, bboxes=[user_points])
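        # NOTE: user_points is forwarded through the `bboxes` argument, so each user prompt is
        # treated as a box prompt; the point/label path sketched above remains commented out.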
mask_points = results[0].masks.data

masks_np = np.zeros(mask_points.shape[-2:], dtype=np.uint8)
for mask in mask_points:
mask_np = mask.cpu().numpy().astype(np.uint8) * 255
mask_np = mask_np.squeeze()
masks_np = cv2.bitwise_or(masks_np, mask_np)

# cv2.imwrite(save_path, mask_points_uint8)
        logging.info(f'Second-pass masking - segmented {len(mask_points)} regions from user input.')
return masks_np

def segment_from_boxes(self, image, bbox, save_path=None):
'''
input_boxes = torch.tensor(bbox, device=self.predictor.device)
transformed_boxes = self.predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
masks, _, _ = self.predictor.predict_torch(
point_coords=None,
point_labels=None,
boxes=transformed_boxes,
multimask_output=False,
)
'''
results = self.sam_model.predict(source=image, bboxes=bbox)
mask_boxes = results[0].masks.data

masks_np = np.zeros(mask_boxes.shape[-2:], dtype=np.uint8)
for i, mask in enumerate(mask_boxes):
            mask_np = mask.cpu().numpy().astype(np.uint8) * 255  # convert True to 255, False to 0
mask_np = mask_np.squeeze()
if i in self.indices:
# kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (13, 13)) # Adjust size as needed
kernel = np.ones((11, 11), np.uint8)
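                # MORPH_GRADIENT (dilation minus erosion) keeps only a thin outline of the
                # circle mask, so the drawn stroke is removed without blanking what it encloses.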
mask_np = cv2.morphologyEx(mask_np, cv2.MORPH_GRADIENT, kernel)
masks_np = cv2.bitwise_or(masks_np, mask_np)
# cv2.imwrite(f'mask_box{i}.jpg', masks_np)

# cv2.imwrite(save_path, masks_np)
        logging.info(f'First-pass masking - segmented {len(mask_boxes)} bounding boxes.')
return masks_np

def inpainting(self, image, mask_total):
# inpainted_image = cv2.inpaint(image.copy(), mask_total, 10, cv2.INPAINT_TELEA)
        logging.debug(f'mask_total shape: {mask_total.shape}')  # e.g. (1893, 1577) with 0 or 255
        # print(image.shape)  # (1893, 1577, 3)

inpainted_image = image.copy()
inpainted_image[mask_total == 255] = [255, 255, 255]
final_image = cv2.convertScaleAbs(inpainted_image, alpha=1.5, beta=10)
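        # convertScaleAbs computes saturate(alpha * pixel + beta): contrast x1.5, brightness +10.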
# cv2.imwrite('test_images/inpainted_init.png', inpainted_image)
# cv2.imwrite('test_images/inpainted_final.png', final_image)
        logging.info('Inpainting and post-correction complete.')

return final_image

def combine_alpha(self, image_rgb):
if self.alpha_channel is not None: # RGBA
image_rgba = cv2.merge([image_rgb, self.alpha_channel])
return image_rgba
else:
return image_rgb

def process(self, img_bytes, user_points, user_labels, extension='jpg'):
""" local test 용도 Vs. server test 용도 구분 """
### local 용도
# img_path = img_bytes
# image = cv2.imread(img_path, cv2.IMREAD_COLOR)

        ### for server use
buffer = np.frombuffer(img_bytes, dtype=np.uint8)
image = cv2.imdecode(buffer, cv2.IMREAD_UNCHANGED)

### ready
#self.load_sam_model()
#self.predictor.set_image(image)
masks_total = np.zeros(image.shape[:2], dtype=np.uint8)

        ### Pass 1: object detection & segmentation by box
bbox = self.object_detection(image)
if len(bbox) > 0:
masks_by_box = self.segment_from_boxes(image, bbox, save_path=None) # 'test_images/seg_box.png'
masks_total = cv2.bitwise_or(masks_total, masks_by_box)
            # logging.info(f"Shape check after first-pass masking: YOLO mask shape {masks_by_box.shape}, image shape {image.shape}")  # (1893, 1577, 3) (1893, 1577)
else:
masks_by_box = None
        ### Pass 2: user-supplied points & segmentation by point
if len(user_points) > 0:
mask_by_point = self.segment_from_points(image, user_points, user_labels, bbox, save_path=None) # save_path='test_images/seg_points.png'
masks_total = cv2.bitwise_or(masks_total, mask_by_point)
else:
mask_by_point = None
# cv2.imwrite('test_images/mask_total.png', masks_total)
        if isinstance(masks_total, np.ndarray):
            image_output = self.inpainting(image, masks_total)
            logging.info('Final masked image generated')
        else:
            image_output = image
            logging.info('No final mask, or the mask is not a numpy array.')

# image_input = self.combine_alpha(image_rgb)
# image_output = self.combine_alpha(image_output)
_, input_bytes = cv2.imencode("." + extension, image)
_, mask_bytes = cv2.imencode("." + extension, masks_total.astype(np.uint8))
_, result_bytes = cv2.imencode("." + extension, image_output)

if mask_by_point is not None and masks_by_box is not None:
_, mask_by_point_bytes = cv2.imencode("." + extension, mask_by_point.astype(np.uint8))
_, mask_by_box_bytes = cv2.imencode("." + extension, masks_by_box.astype(np.uint8))
return io.BytesIO(input_bytes), io.BytesIO(mask_bytes), io.BytesIO(result_bytes), io.BytesIO(mask_by_box_bytes), io.BytesIO(mask_by_point_bytes)
elif mask_by_point is not None:
_, mask_by_point_bytes = cv2.imencode("." + extension, mask_by_point.astype(np.uint8))
return io.BytesIO(input_bytes), io.BytesIO(mask_bytes), io.BytesIO(result_bytes), None, io.BytesIO(mask_by_point_bytes)
elif masks_by_box is not None:
_, mask_by_box_bytes = cv2.imencode("." + extension, masks_by_box.astype(np.uint8))
return io.BytesIO(input_bytes), io.BytesIO(mask_bytes), io.BytesIO(result_bytes), io.BytesIO(mask_by_box_bytes), None
else:
return io.BytesIO(input_bytes), io.BytesIO(mask_bytes), io.BytesIO(result_bytes), None, None


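For reference, here is a minimal usage sketch of the class above, run outside FastAPI. The model paths mirror the ones main.py downloads from S3; the sample image name and empty prompt lists are placeholders, not part of this PR.

import torch

from AIProcessor import AIProcessor

processor = AIProcessor(yolo_path="../models/yolo11_best.pt",
                        sam_path="../models/mobile_sam.pt",
                        device=torch.device("cpu"))

with open("sample.jpg", "rb") as f:  # placeholder test image
    img_bytes = f.read()

# With no user points, only the YOLO-box pass contributes to the final mask.
input_io, mask_io, result_io, box_mask_io, point_mask_io = processor.process(
    img_bytes, user_points=[], user_labels=[], extension="jpg")

with open("result.jpg", "wb") as f:
    f.write(result_io.getvalue())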
85 changes: 81 additions & 4 deletions app/main.py
@@ -7,6 +7,7 @@
from starlette import status
from starlette.responses import StreamingResponse, JSONResponse
from ColorRemover import ColorRemover
from AIProcessor import AIProcessor
import ImageFunctions as ImageManager
import logging
import os
@@ -39,8 +40,10 @@ def create_file_path(obj_path, extension):
file_id = uuid4() # 각 클라이언트마다 고유한 파일 ID 생성
dir_path = obj_path.rsplit('/', 1)[0]
paths = {"input_path": f"{dir_path}/{file_id}.input.{extension}",
"output_path": f"{dir_path}/{file_id}.output.{extension}",
"mask_path": f"{dir_path}/{file_id}.mask.{extension}",
"output_path": f"{dir_path}/{file_id}.output.{extension}",
"one": f"{dir_path}/{file_id}.mask_b.{extension}",
"two": f"{dir_path}/{file_id}.mask_p.{extension}",
"extension": extension}
return paths

@@ -65,10 +68,84 @@ def upload_image_to_s3(file_bytes, file_path):
s3_client.upload_fileobj(file_bytes, BUCKET_NAME, file_path)


def download_model_from_s3(yolo_path: str = 'models/yolo11_best.pt', sam_path: str = "models/mobile_sam.pt"):  # models/sam_vit_h_4b8939.pth
    dest_dir = '../'  # directory inside the container where models are stored
    try:
        yolo_full_path = dest_dir + yolo_path
        sam_full_path = dest_dir + sam_path
        if not os.path.exists(yolo_full_path):
            s3_client.download_file(BUCKET_NAME, yolo_path, yolo_full_path)
            logger.info(f'YOLO model downloaded successfully to {yolo_full_path}')
        else:
            logger.info(f'YOLO model already exists at {yolo_full_path}')
        if not os.path.exists(sam_full_path):
            s3_client.download_file(BUCKET_NAME, sam_path, sam_full_path)
            logger.info(f'SAM model downloaded successfully to {sam_full_path}')
        else:
            logger.info(f'SAM model already exists at {sam_full_path}')

        logger.info(f"Files in models dir: {os.listdir(dest_dir + 'models')}")
    except Exception as e:
        logger.error(f'Failed to download model: {e}')


@app.get("/", status_code=status.HTTP_200_OK)
def greeting():
return JSONResponse(content={"message": "Hello! Let's start image processing"})
return JSONResponse(content={"message": "Hello! Welcome to OnO's FastAPI Server!"})


@app.get("/load-models", status_code=status.HTTP_200_OK)
async def get_models():
try:
download_model_from_s3()
except Exception as e:
logger.error("Error with Download & Saving AIs: %s", e)
raise HTTPException(status_code=500, detail="Error with Download & Saving AIs")

download_model_from_s3()  # run the download synchronously at import time; calling the async get_models() here would only create an un-awaited coroutine

@app.post("/process-shape")
async def processShape(request: Request):
""" AI handwriting detection & Telea Algorithm-based inpainting """
data = await request.json()
full_url = data['fullUrl']
point_list = data.get('points')
label_list = data.get('labels') # value or None
logger.info(f"사용자 입력 포인트: {point_list}")
logger.info(f"사용자 입력 라벨: {label_list}")

try:
s3_key = parse_s3_url(full_url)
paths = create_file_path(s3_key, s3_key.split(".")[-1])
img_bytes = download_image_from_s3(s3_key) # download from S3
corrected_img_bytes = ImageManager.correct_rotation(img_bytes, paths['extension'])
logger.info(f"시용자 입력 이미지({s3_key}) 다운로드 및 전처리 완료")

# aiProcessor = AIProcessor(yolo_path='/Users/semin/models/yolo11_best.pt', sam_path='/Users/semin/models/mobile_sam.pt') # local
aiProcessor = AIProcessor(yolo_path="../models/yolo11_best.pt", sam_path="../models/mobile_sam.pt") # server
img_input_bytes, img_mask_bytes, img_output_bytes, one, two = aiProcessor.process(img_bytes=corrected_img_bytes,
user_points=point_list,
user_labels=label_list)
logger.info("AI 필기 제거 프로세스 완료")

upload_image_to_s3(img_input_bytes, paths["input_path"])
upload_image_to_s3(img_mask_bytes, paths["mask_path"])
upload_image_to_s3(img_output_bytes, paths["output_path"])
if one is not None:
upload_image_to_s3(one, paths["one"])
if two is not None:
upload_image_to_s3(two, paths["two"])

logger.info("AI 필기 제거 결과 이미지 업로드 완료")
return JSONResponse(content={"message": "File processed successfully", "path": paths})

except KeyError as e:
raise HTTPException(status_code=400, detail=f"Missing key: {e.args[0]}")
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e))
except Exception as e:
logger.error("Error during processing: %s", e)
raise HTTPException(status_code=500, detail="Error processing the image.")

@app.post("/process-color")
async def processColor(request: Request):
@@ -458,8 +535,8 @@ async def retrieve(problem_text: str):
async def augment(curriculum_context, query):
    prompt = ("You are a notebook that helps high-school students shore up their weak points through the problems they got wrong. \
        Referring to the curriculum, use the core intent of the missed problem to point out its potentially confusing elements, \
-       find the important concept the student seems to have missed and explain it in 4 lines or fewer, \
-       then apply that concept to write a short solution (about 4 lines) organized around the key points. \
+       relate the problem to the important concept a student would likely miss if they got it wrong and explain that concept in 4 lines or fewer, \
+       then apply that concept to write a short solution (about 4 lines) focused on the core principle. \
        If the missed problem seems unrelated to the curriculum, just ignore the curriculum. \n\n\n")
    passage = f"Missed problem: {query} \n\n\n"
    context = f"Curriculum: {curriculum_context} \n\n\n"
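For context, a minimal sketch of a client call to the new /process-shape endpoint; the host, image URL, and coordinates below are placeholders, not part of this PR.

import requests

resp = requests.post(
    "http://localhost:8000/process-shape",  # assumed local deployment
    json={
        "fullUrl": "https://your-bucket.s3.amazonaws.com/problems/example.jpg",
        "points": [[120, 340], [480, 610]],  # optional user prompts
        "labels": [1, 1],                    # labels paired with the points
    },
)
resp.raise_for_status()
print(resp.json())  # {"message": "File processed successfully", "path": {...}}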
7 changes: 6 additions & 1 deletion requirements.txt
@@ -44,4 +44,9 @@ botocore==1.34.141
requests==2.32.3
charset-normalizer==3.3.2
openai==1.46.0
-pymilvus==2.4.6
+pymilvus==2.4.6
torch==2.5.1
torchvision==0.20.1
matplotlib==3.9.1.post1
ultralytics==8.3.31
git+https://github.com/facebookresearch/segment-anything.git@dca509fe793f601edb92606367a655c15ac00fdf
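Note: segment-anything is pinned to a specific upstream commit and installed straight from GitHub, which is why the Dockerfile change above adds git to the image.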