Skip to content

Commit

Permalink
[Feat & Fix & Comment] 오답분석 기능 답변 고도화 및 연관 개념 답변, 서버 로그 추가
Browse files Browse the repository at this point in the history
[Feat & Fix & Comment] 오답분석 기능 답변 고도화 및 연관 개념 답변, 서버 로그 추가
  • Loading branch information
semnisem authored Oct 8, 2024
2 parents 49a0a52 + 98ed29d commit a30e538
Showing 1 changed file with 81 additions and 45 deletions.
126 changes: 81 additions & 45 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import os
from openai import OpenAI
from pymilvus import connections, utility, FieldSchema, CollectionSchema, DataType, Collection
import time
from datetime import datetime

# 로깅 추가
logging.basicConfig(level=logging.INFO)
Expand All @@ -20,10 +22,10 @@
app = FastAPI()

# 클라이언트 생성
s3_client = boto3.client( "s3",
region_name="ap-northeast-2")
s3_client = boto3.client("s3",
region_name="ap-northeast-2",)
ssm_client = boto3.client('ssm',
region_name='ap-northeast-2')
region_name="ap-northeast-2",)

# s3 버킷 연결
try:
Expand Down Expand Up @@ -150,7 +152,7 @@ async def upload_directly(upload_file: UploadFile = File(...)):


@app.get("/analysis/whole")
async def analysis(problem_url = None):
async def analysis(problem_url=None):
""" Curriculum-based Chat Completion API with CLOVA OCR & ChatGPT """
await connect_milvus() # milvus 서버 연결

Expand All @@ -159,11 +161,15 @@ async def analysis(problem_url = None):
else:
problem_text = await ocr(problem_url)

retrieving_result = await retrieve(problem_text)
retrieving_result, subjects, units, concepts = await retrieve(problem_text)
question = await augment(retrieving_result, problem_text)
answer = await generate(question)

return JSONResponse(content={"message": "Problem Analysis Finished Successfully", "answer": answer})
return JSONResponse(content={"message": "Problem Analysis Finished Successfully",
"subject": list(set(subjects)),
"unit": list(set(units)),
"key_concept": list(set(concepts)),
"answer": answer})


@app.get("/analysis/ocr")
Expand All @@ -174,12 +180,13 @@ async def ocr(problem_url: str):
import time
import json

logger.info("Analyzing problem from this image URL: %s", problem_url)
try:
dt3 = datetime.fromtimestamp(time.time())
s3_key = parse_s3_url(problem_url)
img_bytes = download_image_from_s3(s3_key) # download from S3
extension = s3_key.split(".")[-1]
logger.info("Completed Download & Sending Requests... '%s'", s3_key)
dt4 = datetime.fromtimestamp(time.time())
logger.info(f"{dt3}~{dt4}: 이미지 다운로드 완료")

clova_api_url = ssm_client.get_parameter(
Name='/ono/fastapi/CLOVA_API_URL',
Expand Down Expand Up @@ -211,18 +218,18 @@ async def ocr(problem_url: str):
files = [
('file', image_file)
]
logger.info("Processing OCR & Receiving Responses...")

dt3 = datetime.fromtimestamp(time.time())
ocr_response = requests.request("POST", clova_api_url, headers=headers, data=payload, files=files).text
ocr_response_json = json.loads(ocr_response)
logger.info("***** Finished OCR Successfully *****")
dt4 = datetime.fromtimestamp(time.time())
logger.info(f"{dt3}~{dt4}: 이미지 OCR 완료")

infer_texts = []
for image in ocr_response_json["images"]:
for field in image["fields"]:
infer_texts.append(field["inferText"])
result = ' '.join(infer_texts)

return result

except Exception as pe:
Expand Down Expand Up @@ -256,14 +263,14 @@ async def upload_curriculum_txt(upload_file: UploadFile = File(...)):

# Mivlus DB 연결
SERVER = os.getenv('SERVER')
logger.info(f"* log >> 환경변수 SERVER: {SERVER}로 받아왔습니다.")
logger.info(f"* log >> 환경변수를 SERVER({SERVER})로 받아왔습니다.")

MILVUS_HOST = ssm_client.get_parameter(
Name=f'/ono/{SERVER}/fastapi/MILVUS_HOST_NAME',
WithDecryption=False
)['Parameter']['Value']
MILVUS_PORT = 19530
COLLECTION_NAME = 'Math2015Curriculum'
COLLECTION_NAME = 'Curriculum2015'
DIMENSION = 1536
INDEX_TYPE = "IVF_FLAT"

Expand All @@ -272,13 +279,16 @@ async def upload_curriculum_txt(upload_file: UploadFile = File(...)):
async def connect_milvus():
try:
# Milvus 서버 연결
connections.connect(alias="default", host=MILVUS_HOST, port=MILVUS_PORT)
logger.info(f"* log >> Milvus Server is connected to {MILVUS_HOST}:{MILVUS_PORT}")
dt1 = str(datetime.fromtimestamp(time.time()))
connections.connect(alias="default", host=MILVUS_HOST, port=MILVUS_PORT) # server 용
# connections.connect(host="127.0.0.1", port=19530, db="default") # localhost 용
dt2 = str(datetime.fromtimestamp(time.time()))
logger.info(f"{dt1} ~ {dt2}: Milvus 서버 {MILVUS_HOST}:{MILVUS_PORT}에 연결 완료")

# 컬렉션의 스키마 출력
if utility.has_collection(COLLECTION_NAME):
collection = Collection(COLLECTION_NAME)
logger.info("* 존재하는 Collection Schema:")
logger.info(f"* 존재하는 Collection {COLLECTION_NAME} Schema:")
for field in collection.schema.fields:
logger.info(f" - Field Name: {field.name}, Data Type #: {field.dtype}")

Expand All @@ -297,18 +307,16 @@ async def create_milvus():
FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True),
FieldSchema(name='content', dtype=DataType.VARCHAR, max_length=65535),
FieldSchema(name='content_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION),
FieldSchema(name='subject_name', dtype=DataType.VARCHAR, max_length=100), # Meta Data1
FieldSchema(name='unit_name', dtype=DataType.VARCHAR, max_length=100), # Meta Data2
FieldSchema(name='main_concept', dtype=DataType.VARCHAR, max_length=100), # Meta Data3
]
schema = CollectionSchema(fields=fields, description='Math2015Curriculum embedding collection')
schema = CollectionSchema(fields=fields, description='2015 Korean High School Curriculum Collection')
collection = Collection(name=COLLECTION_NAME, schema=schema)
logger.info(f"* log >> Collection [{COLLECTION_NAME}] is created.")
logger.info(f"* log >> New Collection [{COLLECTION_NAME}] is created.")

# 인덱스 생성
# 스칼라 인덱스
collection.create_index(
field_name="id"
)
# 벡터 인덱스
index_params = {
index_params = { # 벡터 인덱스
'index_type': INDEX_TYPE,
'metric_type': 'COSINE',
'params': {
Expand Down Expand Up @@ -342,15 +350,15 @@ def get_embedding(client, text_list):


@app.get("/milvus/insert")
async def insert_curriculum_embeddings():
async def insert_curriculum_embeddings(subject: str):
""" s3에서 교과과정을 읽고 임베딩하여 Milvus에 삽입 """
# Milvus 연결
await connect_milvus()
collection = Collection(COLLECTION_NAME)

# S3 내 커리큘럼 데이터 로드
texts = []
prefix = 'curriculum/science2015/' # 경로
texts, subject_names, unit_names, main_concepts = [], [], [], []
prefix = f'curriculum/{subject}2015/' # 경로
try:
# 버킷에서 파일 목록 가져오기
s3_curriculum_response = s3_client.list_objects_v2(Bucket=BUCKET_NAME, Prefix=prefix)
Expand All @@ -360,20 +368,34 @@ async def insert_curriculum_embeddings():
# S3 객체 가져오기
obj = s3_client.get_object(Bucket=BUCKET_NAME, Key=s3_key)
# 텍스트 읽기
text = obj['Body'].read().decode('utf-8')
data = obj['Body'].read().decode('utf-8')
lines = data.splitlines()
# 메타 데이터 추출
meatdata_lines = [line.strip('#').strip() for line in lines[:3]]
subject_name = meatdata_lines[0]
subject_names.append(subject_name)
unit_name = meatdata_lines[1]
unit_names.append(unit_name)
main_concept = meatdata_lines[2]
main_concepts.append(main_concept)
# 교과과정 내용 추출
text = '\n'.join(lines[3:]).strip()
texts.append(text)
logger.info(f"* log >> read {len(texts)} texts from S3")
except Exception as e:
logger.error(f"Error reading curriculum from S3: {e}")

# 데이터 임베딩
# 교과과정 내용 임베딩
content_embeddings = get_embedding(openai_client, texts)
logger.info(f"* log >> embedding 완료. dimension: {DIMENSION}")

# 데이터 삽입
data = [
texts, # content 필드
content_embeddings # content_embedding 필드
content_embeddings, # content_embedding 필드
subject_names, # subject_name 필드
unit_names, # unit_name 필드
main_concepts # main_concept 필드
]
collection.insert(data)

Expand All @@ -390,66 +412,80 @@ async def retrieve(problem_text: str):

# 검색 테스트
query = problem_text
dt5 = str(datetime.fromtimestamp(time.time()))
query_embeddings = [get_embedding(openai_client, [query])]
if not query_embeddings or query_embeddings[0] is None:
raise ValueError("Embedding generation failed")
logger.info(f"* log >> Query embedding 완료")
dt6 = str(datetime.fromtimestamp(time.time()))
logger.info(f"{dt5} ~ {dt6}: 쿼리 임베딩 완료")

search_params = {
'metric_type': 'COSINE',
'params': {
'probe': 20
},
}
dt5 = str(datetime.fromtimestamp(time.time()))
results = collection.search(
data=query_embeddings[0],
anns_field='content_embedding',
param=search_params,
limit=3,
expr=None,
output_fields=['content']
output_fields=['content', 'subject_name', 'unit_name', 'main_concept']
)
dt6 = str(datetime.fromtimestamp(time.time()))
context = ' '.join([result.entity.get('content') for result in results[0]])
logger.info(f"* log >> context found")

# 결과 확인
'''logger.info(f"* log >> 쿼리 결과")
subjects_list = [result.entity.get('subject_name') for result in results[0]]
unit_list = [result.entity.get('unit_name') for result in results[0]]
main_concept_list = [result.entity.get('main_concept') for result in results[0]]
logger.info(f"{dt5} ~ {dt6}: 검색 완료")
logs = ""
for result in results[0]:
logger.info("\n-------------------------------------------------------------------")
logger.info(f"Score : {result.distance}, \nText : \n{result.entity.get('content')}")'''

return context
logs += ("\n"+f"Score : {result.distance}, \
\nInfo: {result.entity.get('subject_name')}\
> {result.entity.get('unit_name')}\
> {result.entity.get('main_concept')}, \
\nText : {result.entity.get('content')}"+"\n\n")
logger.info(f"* log >> 검색 결과: {logs}")

return context, subjects_list, unit_list, main_concept_list
except Exception as e:
logger.error(f"Error in search: {e}")
raise HTTPException(status_code=500, detail=str(e))

@app.get("/analysis/augmentation")
async def augment(curriculum_context, query):
    """Build the augmented LLM prompt from retrieved curriculum and problem text.

    Args:
        curriculum_context: Curriculum passages retrieved from the vector DB.
        query: The problem text (e.g. extracted via OCR).

    Returns:
        str: instruction prompt + curriculum context + problem text, concatenated.
    """
    # Instruction (Korean): act as a notebook; pick the curriculum relevant to
    # the problem and explain to a high-school student why the problem is easy
    # to get wrong and which concepts/reasoning it requires — concisely.
    # NOTE: the backslash continuations keep this a single string literal, so
    # the continuation lines' leading spaces are part of the prompt text.
    prompt = "너는 공책이야. 내가 준 교육과정 중에 아래 문제와 관련된 교육과정을 골라서 고등학생 한 명에게\
    이 문제를 왜 틀리거나 헷갈릴 수 있으며, 어떤 개념과 사고과정 등이 필요해 보이는지를 \
    이 문제의 의도와 핵심을 짚어가되 너무 길지 않게 정리해서 고등학생에게 보여줘.\n"
    context = f"교과과정은 이렇고 {curriculum_context}"
    passage = f"문제는 이러해 {query}."
    augmented_query = prompt + context + passage
    return augmented_query

@app.get("/analysis/generation")
async def generate(question):
    """Send the augmented question to ChatGPT and return the answer text.

    Args:
        question: The fully augmented prompt string.

    Returns:
        str | None: the model's reply, or None if the API call failed.
    """

    def get_chatgpt_response(client, question, model="gpt-4o-mini"):
        # Single chat-completion round trip; timing is logged around the call.
        try:
            dt7 = str(datetime.fromtimestamp(time.time()))
            gpt_response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "user",
                     "content": question
                     }
                ],
                temperature=0.7
            )
            dt8 = str(datetime.fromtimestamp(time.time()))
            logger.info(f"{dt7} ~ {dt8}: LLM 응답 완료")
            return gpt_response.choices[0].message.content
        except Exception as e:
            # Log at error level (was info) so failures are visible; still
            # degrade gracefully by returning None to the caller.
            logger.error(f"Error during GPT querying: {e}")
            return None

    chatgpt_response = get_chatgpt_response(openai_client, question)
    logger.info(f"* log >> 응답 결과 \n {chatgpt_response}")
    return chatgpt_response

0 comments on commit a30e538

Please sign in to comment.