diff --git a/app/main.py b/app/main.py index eaf2cf9..ac0ba4f 100644 --- a/app/main.py +++ b/app/main.py @@ -12,6 +12,8 @@ import os from openai import OpenAI from pymilvus import connections, utility, FieldSchema, CollectionSchema, DataType, Collection +import time +from datetime import datetime # 로깅 추가 logging.basicConfig(level=logging.INFO) @@ -20,10 +22,10 @@ app = FastAPI() # 클라이언트 생성 -s3_client = boto3.client( "s3", - region_name="ap-northeast-2") +s3_client = boto3.client("s3", + region_name="ap-northeast-2",) ssm_client = boto3.client('ssm', - region_name='ap-northeast-2') + region_name="ap-northeast-2",) # s3 버킷 연결 try: @@ -150,7 +152,7 @@ async def upload_directly(upload_file: UploadFile = File(...)): @app.get("/analysis/whole") -async def analysis(problem_url = None): +async def analysis(problem_url=None): """ Curriculum-based Chat Completion API with CLOVA OCR & ChatGPT """ await connect_milvus() # milvus 서버 연결 @@ -159,11 +161,15 @@ async def analysis(problem_url = None): else: problem_text = await ocr(problem_url) - retrieving_result = await retrieve(problem_text) + retrieving_result, subjects, units, concepts = await retrieve(problem_text) question = await augment(retrieving_result, problem_text) answer = await generate(question) - return JSONResponse(content={"message": "Problem Analysis Finished Successfully", "answer": answer}) + return JSONResponse(content={"message": "Problem Analysis Finished Successfully", + "subject": list(set(subjects)), + "unit": list(set(units)), + "key_concept": list(set(concepts)), + "answer": answer}) @app.get("/analysis/ocr") @@ -174,12 +180,13 @@ async def ocr(problem_url: str): import time import json - logger.info("Analyzing problem from this image URL: %s", problem_url) try: + dt3 = datetime.fromtimestamp(time.time()) s3_key = parse_s3_url(problem_url) img_bytes = download_image_from_s3(s3_key) # download from S3 extension = s3_key.split(".")[-1] - logger.info("Completed Download & Sending Requests... '%s'", s3_key) + dt4 = datetime.fromtimestamp(time.time()) + logger.info(f"{dt3}~{dt4}: 이미지 다운로드 완료") clova_api_url = ssm_client.get_parameter( Name='/ono/fastapi/CLOVA_API_URL', @@ -211,18 +218,18 @@ async def ocr(problem_url: str): files = [ ('file', image_file) ] - logger.info("Processing OCR & Receiving Responses...") + dt3 = datetime.fromtimestamp(time.time()) ocr_response = requests.request("POST", clova_api_url, headers=headers, data=payload, files=files).text ocr_response_json = json.loads(ocr_response) - logger.info("***** Finished OCR Successfully *****") + dt4 = datetime.fromtimestamp(time.time()) + logger.info(f"{dt3}~{dt4}: 이미지 OCR 완료") infer_texts = [] for image in ocr_response_json["images"]: for field in image["fields"]: infer_texts.append(field["inferText"]) result = ' '.join(infer_texts) - return result except Exception as pe: @@ -256,14 +263,14 @@ async def upload_curriculum_txt(upload_file: UploadFile = File(...)): # Mivlus DB 연결 SERVER = os.getenv('SERVER') -logger.info(f"* log >> 환경변수 SERVER: {SERVER}로 받아왔습니다.") +logger.info(f"* log >> 환경변수를 SERVER({SERVER})로 받아왔습니다.") MILVUS_HOST = ssm_client.get_parameter( Name=f'/ono/{SERVER}/fastapi/MILVUS_HOST_NAME', WithDecryption=False )['Parameter']['Value'] MILVUS_PORT = 19530 -COLLECTION_NAME = 'Math2015Curriculum' +COLLECTION_NAME = 'Curriculum2015' DIMENSION = 1536 INDEX_TYPE = "IVF_FLAT" @@ -272,13 +279,16 @@ async def upload_curriculum_txt(upload_file: UploadFile = File(...)): async def connect_milvus(): try: # Milvus 서버 연결 - connections.connect(alias="default", host=MILVUS_HOST, port=MILVUS_PORT) - logger.info(f"* log >> Milvus Server is connected to {MILVUS_HOST}:{MILVUS_PORT}") + dt1 = str(datetime.fromtimestamp(time.time())) + connections.connect(alias="default", host=MILVUS_HOST, port=MILVUS_PORT) # server 용 + # connections.connect(host="127.0.0.1", port=19530, db="default") # localhost 용 + dt2 = str(datetime.fromtimestamp(time.time())) + logger.info(f"{dt1} ~ {dt2}: Milvus 서버 {MILVUS_HOST}:{MILVUS_PORT}에 연결 완료") # 컬렉션의 스키마 출력 if utility.has_collection(COLLECTION_NAME): collection = Collection(COLLECTION_NAME) - logger.info("* 존재하는 Collection Schema:") + logger.info(f"* 존재하는 Collection {COLLECTION_NAME} Schema:") for field in collection.schema.fields: logger.info(f" - Field Name: {field.name}, Data Type #: {field.dtype}") @@ -297,18 +307,16 @@ async def create_milvus(): FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True), FieldSchema(name='content', dtype=DataType.VARCHAR, max_length=65535), FieldSchema(name='content_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION), + FieldSchema(name='subject_name', dtype=DataType.VARCHAR, max_length=100), # Meta Data1 + FieldSchema(name='unit_name', dtype=DataType.VARCHAR, max_length=100), # Meta Data2 + FieldSchema(name='main_concept', dtype=DataType.VARCHAR, max_length=100), # Meta Data3 ] - schema = CollectionSchema(fields=fields, description='Math2015Curriculum embedding collection') + schema = CollectionSchema(fields=fields, description='2015 Korean High School Curriculum Collection') collection = Collection(name=COLLECTION_NAME, schema=schema) - logger.info(f"* log >> Collection [{COLLECTION_NAME}] is created.") + logger.info(f"* log >> New Collection [{COLLECTION_NAME}] is created.") # 인덱스 생성 - # 스칼라 인덱스 - collection.create_index( - field_name="id" - ) - # 벡터 인덱스 - index_params = { + index_params = { # 벡터 인덱스 'index_type': INDEX_TYPE, 'metric_type': 'COSINE', 'params': { @@ -342,15 +350,15 @@ def get_embedding(client, text_list): @app.get("/milvus/insert") -async def insert_curriculum_embeddings(): +async def insert_curriculum_embeddings(subject: str): """ s3에서 교과과정을 읽고 임베딩하여 Milvus에 삽입 """ # Milvus 연결 await connect_milvus() collection = Collection(COLLECTION_NAME) # S3 내 커리큘럼 데이터 로드 - texts = [] - prefix = 'curriculum/science2015/' # 경로 + texts, subject_names, unit_names, main_concepts = [], [], [], [] + prefix = f'curriculum/{subject}2015/' # 경로 try: # 버킷에서 파일 목록 가져오기 s3_curriculum_response = s3_client.list_objects_v2(Bucket=BUCKET_NAME, Prefix=prefix) @@ -360,20 +368,34 @@ async def insert_curriculum_embeddings(): # S3 객체 가져오기 obj = s3_client.get_object(Bucket=BUCKET_NAME, Key=s3_key) # 텍스트 읽기 - text = obj['Body'].read().decode('utf-8') + data = obj['Body'].read().decode('utf-8') + lines = data.splitlines() + # 메타 데이터 추출 + meatdata_lines = [line.strip('#').strip() for line in lines[:3]] + subject_name = meatdata_lines[0] + subject_names.append(subject_name) + unit_name = meatdata_lines[1] + unit_names.append(unit_name) + main_concept = meatdata_lines[2] + main_concepts.append(main_concept) + # 교과과정 내용 추출 + text = '\n'.join(lines[3:]).strip() texts.append(text) logger.info(f"* log >> read {len(texts)} texts from S3") except Exception as e: logger.error(f"Error reading curriculum from S3: {e}") - # 데이터 임베딩 + # 교과과정 내용 임베딩 content_embeddings = get_embedding(openai_client, texts) logger.info(f"* log >> embedding 완료. dimension: {DIMENSION}") # 데이터 삽입 data = [ texts, # content 필드 - content_embeddings # content_embedding 필드 + content_embeddings, # content_embedding 필드 + subject_names, # subject_name 필드 + unit_names, # unit_name 필드 + main_concepts # main_concept 필드 ] collection.insert(data) @@ -390,10 +412,12 @@ async def retrieve(problem_text: str): # 검색 테스트 query = problem_text + dt5 = str(datetime.fromtimestamp(time.time())) query_embeddings = [get_embedding(openai_client, [query])] if not query_embeddings or query_embeddings[0] is None: raise ValueError("Embedding generation failed") - logger.info(f"* log >> Query embedding 완료") + dt6 = str(datetime.fromtimestamp(time.time())) + logger.info(f"{dt5} ~ {dt6}: 쿼리 임베딩 완료") search_params = { 'metric_type': 'COSINE', @@ -401,33 +425,42 @@ async def retrieve(problem_text: str): 'probe': 20 }, } + dt5 = str(datetime.fromtimestamp(time.time())) results = collection.search( data=query_embeddings[0], anns_field='content_embedding', param=search_params, limit=3, expr=None, - output_fields=['content'] + output_fields=['content', 'subject_name', 'unit_name', 'main_concept'] ) + dt6 = str(datetime.fromtimestamp(time.time())) context = ' '.join([result.entity.get('content') for result in results[0]]) - logger.info(f"* log >> context found") - - # 결과 확인 - '''logger.info(f"* log >> 쿼리 결과") + subjects_list = [result.entity.get('subject_name') for result in results[0]] + unit_list = [result.entity.get('unit_name') for result in results[0]] + main_concept_list = [result.entity.get('main_concept') for result in results[0]] + logger.info(f"{dt5} ~ {dt6}: 검색 완료") + logs = "" for result in results[0]: - logger.info("\n-------------------------------------------------------------------") - logger.info(f"Score : {result.distance}, \nText : \n{result.entity.get('content')}")''' - - return context + logs += ("\n"+f"Score : {result.distance}, \ + \nInfo: {result.entity.get('subject_name')}\ + > {result.entity.get('unit_name')}\ + > {result.entity.get('main_concept')}, \ + \nText : {result.entity.get('content')}"+"\n\n") + logger.info(f"* log >> 검색 결과: {logs}") + + return context, subjects_list, unit_list, main_concept_list except Exception as e: logger.error(f"Error in search: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.get("/analysis/augmentation") async def augment(curriculum_context, query): - prompt = "교과과정에 기반하여 이 문제에 필요한 개념을 말해줘. 응답은 자연어처럼 제공해줘. \n" - context = curriculum_context - passage = query + prompt = "너는 공책이야. 내가 준 교육과정 중에 아래 문제와 관련된 교육과정을 골라서 고등학생 한 명에게\ + 이 문제를 왜 틀리거나 헷갈릴 수 있으며, 어떤 개념과 사고과정 등이 필요해 보이는지를 \ + 이 문제의 의도와 핵심을 짚어가되 너무 길지 않게 정리해서 고등학생에게 보여줘.\n" + context = f"교과과정은 이렇고 {curriculum_context}" + passage = f"문제는 이러해 {query}." augmented_query = prompt + context + passage return augmented_query @@ -435,6 +468,7 @@ async def augment(curriculum_context, query): async def generate(question): def get_chatgpt_response(client, question, model="gpt-4o-mini"): try: + dt7 = str(datetime.fromtimestamp(time.time())) gpt_response = client.chat.completions.create( model=model, messages=[ @@ -442,14 +476,16 @@ def get_chatgpt_response(client, question, model="gpt-4o-mini"): "content": question } ], - temperature=0.5 + temperature=0.7 ) + dt8 = str(datetime.fromtimestamp(time.time())) + logger.info(f"{dt7} ~ {dt8}: LLM 응답 완료") return gpt_response.choices[0].message.content except Exception as e: logger.info(f"Error during GPT querying: {e}") return None chatgpt_response = get_chatgpt_response(openai_client, question) - logger.info(f"* log >> ChatGPT Response: {chatgpt_response}") + logger.info(f"* log >> 응답 결과 \n {chatgpt_response}") return chatgpt_response