Commit: [Feat] AI-based handwriting recognition feature
semnisem authored Nov 20, 2024
2 parents d318ba8 + 91d37ce commit 6ed2e98
Showing 4 changed files with 287 additions and 8 deletions.
6 changes: 3 additions & 3 deletions Dockerfile
@@ -5,12 +5,12 @@ FROM python:3.12-slim
LABEL org.opencontainers.image.source="https://github.com/AI-SIP/MVP_CV"

# Install the required system packages first
-RUN apt-get update && apt-get install -y --no-install-recommends \
+RUN apt-get update && apt-get install -y git \
libgl1-mesa-glx \
libglib2.0-0 \
gcc \
python3-dev && \
-    rm -rf /var/lib/apt/lists/*
+    rm -rf /var/lib/apt/lists/* \

# Set the working directory
WORKDIR /test
@@ -26,4 +26,4 @@ COPY ./app /test/app/
WORKDIR /test/app

# Command to run when the container starts
-CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
+CMD ["bash", "-c", "mkdir -p /test/models && uvicorn main:app --host 0.0.0.0 --port 8000"]
197 changes: 197 additions & 0 deletions app/AIProcessor.py
@@ -0,0 +1,197 @@
import torch
import torchvision
import numpy as np
import cv2
from ultralytics import YOLO
from ultralytics import SAM
from segment_anything import sam_model_registry, SamPredictor
import io
import logging


class AIProcessor:
    def __init__(self, yolo_path: str, sam_path: str, device: torch.device = torch.device("cpu")):
        self.yolo_path = yolo_path
        self.sam_path = sam_path
        self.device = device
        self.yolo_model = YOLO(self.yolo_path)  # load YOLO
        self.sam_model = SAM(self.sam_path)  # load SAM (ultralytics wrapper)
        self.predictor = None
        self.indices = None
        self.alpha_channel = None
        self.size = 0
        self.height = 0
        self.width = 0

    def remove_points_in_bboxes(self, point_list, label_list, bbox_list):
        def is_point_in_bbox(p, b):
            x, y = p
            x_min, y_min, x_max, y_max = b
            return x_min <= x <= x_max and y_min <= y <= y_max

        filtered_p = []
        filtered_l = []
        for p, l in zip(point_list, label_list):
            if not any(is_point_in_bbox(p, b) for b in bbox_list):
                filtered_p.append(p)
                filtered_l.append(l)

        return filtered_p, filtered_l

    def remove_alpha(self, image):
        self.height, self.width = image.shape[:2]
        self.size = self.height * self.width
        if image.shape[2] == 4:  # RGBA: split off the alpha channel
            self.alpha_channel = image[:, :, 3]
            image_rgb = image[:, :, :3]
        else:
            image_rgb = image

        return image_rgb

    def load_sam_model(self, model_type="vit_h"):  # load the original SAM predictor
        self.sam_model = sam_model_registry[model_type](checkpoint=self.sam_path)
        self.sam_model.to(self.device)
        self.predictor = SamPredictor(self.sam_model)

    def object_detection(self, img):
        results = self.yolo_model.predict(source=img, imgsz=640, device=self.device,
                                          iou=0.3, conf=0.3)
        bbox = results[0].boxes.xyxy.tolist()
        self.indices = [index for index, value in enumerate(results[0].boxes.cls) if value == 1.0]
        logging.info(f'Object detection - circles found in boxes {self.indices}')
        return bbox

    def segment_from_points(self, image, user_points, user_labels, bbox, save_path=None):
        '''input_point = np.array(user_points)
        input_label = np.array(user_labels)
        masks_points, scores, logits = self.predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=False,
        )'''
        # filtered_points, filtered_labels = self.remove_points_in_bboxes(user_points, user_labels, bbox)
        # logging.info(f'Second-pass masking - filtered user input: from {len(user_points)} to {len(filtered_points)} points')

        # results = self.sam_model.predict(source=image, points=filtered_points, labels=filtered_labels)
        results = self.sam_model.predict(source=image, bboxes=[user_points])
        mask_points = results[0].masks.data

        masks_np = np.zeros(mask_points.shape[-2:], dtype=np.uint8)
        for mask in mask_points:
            mask_np = mask.cpu().numpy().astype(np.uint8) * 255
            mask_np = mask_np.squeeze()
            masks_np = cv2.bitwise_or(masks_np, mask_np)

        # cv2.imwrite(save_path, mask_points_uint8)
        logging.info(f'Second-pass masking - segmented {len(mask_points)} regions from user input.')
        return masks_np

    def segment_from_boxes(self, image, bbox, save_path=None):
        '''
        input_boxes = torch.tensor(bbox, device=self.predictor.device)
        transformed_boxes = self.predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
        masks, _, _ = self.predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )
        '''
        results = self.sam_model.predict(source=image, bboxes=bbox)
        mask_boxes = results[0].masks.data

        masks_np = np.zeros(mask_boxes.shape[-2:], dtype=np.uint8)
        for i, mask in enumerate(mask_boxes):
            mask_np = mask.cpu().numpy().astype(np.uint8) * 255  # convert True to 255, False to 0
            mask_np = mask_np.squeeze()
            if i in self.indices:
                # kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (13, 13))  # adjust size as needed
                kernel = np.ones((11, 11), np.uint8)
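                # MORPH_GRADIENT (dilation minus erosion) keeps only the boundary band of the
                # mask, so for circle-class detections just the drawn ring is erased.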
                mask_np = cv2.morphologyEx(mask_np, cv2.MORPH_GRADIENT, kernel)
            masks_np = cv2.bitwise_or(masks_np, mask_np)
            # cv2.imwrite(f'mask_box{i}.jpg', masks_np)

        # cv2.imwrite(save_path, masks_np)
        logging.info(f'First-pass masking - segmented {len(mask_boxes)} bounding boxes.')
        return masks_np

    def inpainting(self, image, mask_total):
        # inpainted_image = cv2.inpaint(image.copy(), mask_total, 10, cv2.INPAINT_TELEA)
        logging.debug(mask_total.shape)  # (1893, 1577) with 0 or 255
        # print(image.shape)  # (1893, 1577, 3)

        inpainted_image = image.copy()
        inpainted_image[mask_total == 255] = [255, 255, 255]
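        # convertScaleAbs computes alpha * pixel + beta with uint8 saturation,
        # i.e. a mild contrast/brightness boost to wash out leftover pencil traces.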
        final_image = cv2.convertScaleAbs(inpainted_image, alpha=1.5, beta=10)
        # cv2.imwrite('test_images/inpainted_init.png', inpainted_image)
        # cv2.imwrite('test_images/inpainted_final.png', final_image)
        logging.info('Inpainting and post-correction complete.')

        return final_image

    def combine_alpha(self, image_rgb):
        if self.alpha_channel is not None:  # RGBA
            image_rgba = cv2.merge([image_rgb, self.alpha_channel])
            return image_rgba
        else:
            return image_rgb

    def process(self, img_bytes, user_points, user_labels, extension='jpg'):
        """ Distinguishes local-test use vs. server-test use """
        ### for local use
        # img_path = img_bytes
        # image = cv2.imread(img_path, cv2.IMREAD_COLOR)

        ### for server use
        buffer = np.frombuffer(img_bytes, dtype=np.uint8)
        image = cv2.imdecode(buffer, cv2.IMREAD_UNCHANGED)

        ### ready
        # self.load_sam_model()
        # self.predictor.set_image(image)
        masks_total = np.zeros(image.shape[:2], dtype=np.uint8)

        ### Pass 1: object detection & segmentation by box
        bbox = self.object_detection(image)
        if len(bbox) > 0:
            masks_by_box = self.segment_from_boxes(image, bbox, save_path=None)  # 'test_images/seg_box.png'
            masks_total = cv2.bitwise_or(masks_total, masks_by_box)
            # logging.info(f"Shape check after first-pass masking: YOLOv11 mask shape {masks_by_box.shape}, image shape {image.shape}")  # (1893, 1577, 3) (1893, 1577)
        else:
            masks_by_box = None
        ### Pass 2: user-supplied points & segmentation by points
        if len(user_points) > 0:
            mask_by_point = self.segment_from_points(image, user_points, user_labels, bbox, save_path=None)  # save_path='test_images/seg_points.png'
            masks_total = cv2.bitwise_or(masks_total, mask_by_point)
        else:
            mask_by_point = None
        # cv2.imwrite('test_images/mask_total.png', masks_total)
        if isinstance(masks_total, np.ndarray):
            image_output = self.inpainting(image, masks_total)
            logging.info('Final masked image generated')
        else:
            image_output = image
            logging.info('No final mask, or the mask is not a numpy array.')

        # image_input = self.combine_alpha(image_rgb)
        # image_output = self.combine_alpha(image_output)
        _, input_bytes = cv2.imencode("." + extension, image)
        _, mask_bytes = cv2.imencode("." + extension, masks_total.astype(np.uint8))
        _, result_bytes = cv2.imencode("." + extension, image_output)

        if mask_by_point is not None and masks_by_box is not None:
            _, mask_by_point_bytes = cv2.imencode("." + extension, mask_by_point.astype(np.uint8))
            _, mask_by_box_bytes = cv2.imencode("." + extension, masks_by_box.astype(np.uint8))
            return io.BytesIO(input_bytes), io.BytesIO(mask_bytes), io.BytesIO(result_bytes), io.BytesIO(mask_by_box_bytes), io.BytesIO(mask_by_point_bytes)
        elif mask_by_point is not None:
            _, mask_by_point_bytes = cv2.imencode("." + extension, mask_by_point.astype(np.uint8))
            return io.BytesIO(input_bytes), io.BytesIO(mask_bytes), io.BytesIO(result_bytes), None, io.BytesIO(mask_by_point_bytes)
        elif masks_by_box is not None:
            _, mask_by_box_bytes = cv2.imencode("." + extension, masks_by_box.astype(np.uint8))
            return io.BytesIO(input_bytes), io.BytesIO(mask_bytes), io.BytesIO(result_bytes), io.BytesIO(mask_by_box_bytes), None
        else:
            return io.BytesIO(input_bytes), io.BytesIO(mask_bytes), io.BytesIO(result_bytes), None, None
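
For reference, a minimal sketch of how this class might be exercised locally; the image file names and model paths are illustrative assumptions, not part of this commit.

# Hypothetical local usage sketch - file names and model paths are assumptions.
import torch
from AIProcessor import AIProcessor

processor = AIProcessor(yolo_path="../models/yolo11_best.pt",
                        sam_path="../models/mobile_sam.pt",
                        device=torch.device("cpu"))

with open("sample_page.jpg", "rb") as f:  # any local test image
    img_bytes = f.read()

# With no user points, only the YOLO-box pass contributes to the final mask.
input_io, mask_io, result_io, box_io, point_io = processor.process(
    img_bytes=img_bytes, user_points=[], user_labels=[])

with open("result.jpg", "wb") as f:
    f.write(result_io.getvalue())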


85 changes: 81 additions & 4 deletions app/main.py
@@ -7,6 +7,7 @@
from starlette import status
from starlette.responses import StreamingResponse, JSONResponse
from ColorRemover import ColorRemover
+from AIProcessor import AIProcessor
import ImageFunctions as ImageManager
import logging
import os
@@ -39,8 +40,10 @@ def create_file_path(obj_path, extension):
    file_id = uuid4()  # generate a unique file ID for each client
    dir_path = obj_path.rsplit('/', 1)[0]
    paths = {"input_path": f"{dir_path}/{file_id}.input.{extension}",
-            "output_path": f"{dir_path}/{file_id}.output.{extension}",
             "mask_path": f"{dir_path}/{file_id}.mask.{extension}",
+            "output_path": f"{dir_path}/{file_id}.output.{extension}",
+            "one": f"{dir_path}/{file_id}.mask_b.{extension}",
+            "two": f"{dir_path}/{file_id}.mask_p.{extension}",
             "extension": extension}
    return paths

@@ -65,10 +68,84 @@ def upload_image_to_s3(file_bytes, file_path):
    s3_client.upload_fileobj(file_bytes, BUCKET_NAME, file_path)


def download_model_from_s3(yolo_path: str = 'models/yolo11_best.pt', sam_path: str = "models/mobile_sam.pt"):  # models/sam_vit_h_4b8939.pth
    dest_dir = '../'  # path inside the container where the models are stored
    try:
        yolo_full_path = dest_dir + yolo_path
        sam_full_path = dest_dir + sam_path
        if not os.path.exists(yolo_full_path):
            s3_client.download_file(BUCKET_NAME, yolo_path, yolo_full_path)
            logger.info(f'YOLOv11 model downloaded successfully to {yolo_full_path}')
        else:
            logger.info(f'YOLOv11 already exists at {yolo_full_path}')
        if not os.path.exists(sam_full_path):
            s3_client.download_file(BUCKET_NAME, sam_path, sam_full_path)
            logger.info(f'SAM model downloaded successfully to {sam_full_path}')
        else:
            logger.info(f'SAM model already exists at {sam_full_path}')

        logger.info(f"Files in 'models' dir: {os.listdir(dest_dir + 'models')}")
    except Exception as e:
        logger.error(f'Failed to download model: {e}')


@app.get("/", status_code=status.HTTP_200_OK)
def greeting():
return JSONResponse(content={"message": "Hello! Let's start image processing"})
return JSONResponse(content={"message": "Hello! Welcome to OnO's FastAPI Server!"})


@app.get("/load-models", status_code=status.HTTP_200_OK)
async def get_models():
try:
download_model_from_s3()
except Exception as e:
logger.error("Error with Download & Saving AIs: %s", e)
raise HTTPException(status_code=500, detail="Error with Download & Saving AIs")

get_models()

@app.post("/process-shape")
async def processShape(request: Request):
""" AI handwriting detection & Telea Algorithm-based inpainting """
data = await request.json()
full_url = data['fullUrl']
point_list = data.get('points')
label_list = data.get('labels') # value or None
logger.info(f"사용자 입력 포인트: {point_list}")
logger.info(f"사용자 입력 라벨: {label_list}")

try:
s3_key = parse_s3_url(full_url)
paths = create_file_path(s3_key, s3_key.split(".")[-1])
img_bytes = download_image_from_s3(s3_key) # download from S3
corrected_img_bytes = ImageManager.correct_rotation(img_bytes, paths['extension'])
logger.info(f"시용자 입력 이미지({s3_key}) 다운로드 및 전처리 완료")

# aiProcessor = AIProcessor(yolo_path='/Users/semin/models/yolo11_best.pt', sam_path='/Users/semin/models/mobile_sam.pt') # local
aiProcessor = AIProcessor(yolo_path="../models/yolo11_best.pt", sam_path="../models/mobile_sam.pt") # server
img_input_bytes, img_mask_bytes, img_output_bytes, one, two = aiProcessor.process(img_bytes=corrected_img_bytes,
user_points=point_list,
user_labels=label_list)
logger.info("AI 필기 제거 프로세스 완료")

upload_image_to_s3(img_input_bytes, paths["input_path"])
upload_image_to_s3(img_mask_bytes, paths["mask_path"])
upload_image_to_s3(img_output_bytes, paths["output_path"])
if one is not None:
upload_image_to_s3(one, paths["one"])
if two is not None:
upload_image_to_s3(two, paths["two"])

logger.info("AI 필기 제거 결과 이미지 업로드 완료")
return JSONResponse(content={"message": "File processed successfully", "path": paths})

except KeyError as e:
raise HTTPException(status_code=400, detail=f"Missing key: {e.args[0]}")
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e))
except Exception as e:
logger.error("Error during processing: %s", e)
raise HTTPException(status_code=500, detail="Error processing the image.")
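
A hypothetical client call for this endpoint could look like the sketch below; the host, S3 URL, and coordinates are assumptions, and note that the handler currently forwards the points list to SAM as a single bounding box.

# Hypothetical request to /process-shape - URL and coordinates are assumptions.
import requests

payload = {
    "fullUrl": "https://my-bucket.s3.amazonaws.com/problems/sample.jpg",
    "points": [120, 340, 480, 620],  # passed through to SAM as one bbox (bboxes=[points])
    "labels": [1],
}
resp = requests.post("http://localhost:8000/process-shape", json=payload)
resp.raise_for_status()
print(resp.json()["path"])  # S3 keys of the input/mask/output artifacts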

@app.post("/process-color")
async def processColor(request: Request):
@@ -458,8 +535,8 @@ async def retrieve(problem_text: str):
async def augment(curriculum_context, query):
    prompt = ("너는 고등학생의 오답 문제를 통해 약점을 보완해주는 공책이야. \
              교육과정을 참고해서 오답 문제 핵심 의도를 바탕으로 문제에서 헷갈릴만한 요소, \
-             학생이 놓친 것 같은 중요한 개념을 찾아 그 개념에 대해 4줄 이내로 설명해주고, \
-             그 개념을 적용해서 풀이를 요점에 따라서 짧게(4줄 내외) 작성해줘. \
+             이 문제를 틀렸다면 놓칠 것 같은 중요한 개념을 연관지어 그 개념에 대해 4줄 이내로 설명해주고, \
+             그 개념을 적용해서 풀이를 핵심적 원리에 집중하여 짧게(4줄 내외) 작성해줘. \
              만약 오답 문제와 교과과정이 관련이 없다고 판단되면, 교육과정은 참고하지 않으면 돼. \n\n\n")
    passage = f"오답 문제 : {query} \n\n\n"
    context = f"교과과정 : {curriculum_context} \n\n\n"
7 changes: 6 additions & 1 deletion requirements.txt
@@ -44,4 +44,9 @@ botocore==1.34.141
requests==2.32.3
charset-normalizer==3.3.2
openai==1.46.0
-pymilvus==2.4.6
+pymilvus==2.4.6
+torch==2.5.1
+torchvision==0.20.1
+matplotlib==3.9.1.post1
+ultralytics==8.3.31
+git+https://github.com/facebookresearch/segment-anything.git@dca509fe793f601edb92606367a655c15ac00fdf
