From 28a82cd8915d952eb28cadc3ac316bef377a8078 Mon Sep 17 00:00:00 2001 From: marcobebway Date: Wed, 11 Dec 2024 11:49:32 +0100 Subject: [PATCH] chore: use structured output in Kyma RAG reranker --- src/rag/reranker/prompt.py | 4 ---- src/rag/reranker/reranker.py | 9 +++------ tests/unit/rag/reranker/test_reranker.py | 8 ++------ 3 files changed, 5 insertions(+), 16 deletions(-) diff --git a/src/rag/reranker/prompt.py b/src/rag/reranker/prompt.py index b8aeaac8..70a13ac5 100644 --- a/src/rag/reranker/prompt.py +++ b/src/rag/reranker/prompt.py @@ -16,7 +16,6 @@ 2. Compare the queries to each document by considering keywords or semantic meaning. 3. Determine a relevance score for each document with respect to the queries (higher score means more relevant). 4. Avoid documents that are irrelevant to the queries with very low relevance scores. -5. Return documents in the order of relevance in a well structured JSON format. @@ -24,8 +23,5 @@ 2. Use the documents and your expertise to decide what to keep, how to rank, and which documents to remove. 3. Order the documents based on their relevance score to the queries (top document is the most relevant). 4. Restrict the output to the top {limit} most relevant documents to the queries. -5. Return the documents in the order of relevance in the following JSON format: -{format_instructions} -6. Ensure that the response is a well-structured and valid JSON. """ diff --git a/src/rag/reranker/reranker.py b/src/rag/reranker/reranker.py index 255b7457..62226068 100644 --- a/src/rag/reranker/reranker.py +++ b/src/rag/reranker/reranker.py @@ -1,7 +1,6 @@ from typing import Protocol from langchain_core.documents import Document -from langchain_core.output_parsers import PydanticOutputParser from langchain_core.prompts import PromptTemplate from pydantic import BaseModel @@ -45,11 +44,9 @@ class LLMReranker(IReranker): def __init__(self, model: IModel): """Initialize the reranker.""" - reranked_docs_parser = PydanticOutputParser(pydantic_object=RerankedDocs) - prompt = PromptTemplate.from_template(RERANKER_PROMPT_TEMPLATE).partial( - format_instructions=reranked_docs_parser.get_format_instructions() - ) - self.chain = prompt | model.llm | reranked_docs_parser + prompt = PromptTemplate.from_template(RERANKER_PROMPT_TEMPLATE) + self.chain = prompt | model.llm.with_structured_output(RerankedDocs) + logger.info("Reranker initialized") async def arerank( self, diff --git a/tests/unit/rag/reranker/test_reranker.py b/tests/unit/rag/reranker/test_reranker.py index 1240d721..b060a0e7 100644 --- a/tests/unit/rag/reranker/test_reranker.py +++ b/tests/unit/rag/reranker/test_reranker.py @@ -2,7 +2,6 @@ import pytest from langchain_core.documents import Document -from langchain_core.output_parsers import PydanticOutputParser from langchain_core.prompts import PromptTemplate from rag.reranker.prompt import RERANKER_PROMPT_TEMPLATE @@ -44,11 +43,8 @@ def test_init(self): # When reranker = LLMReranker(model=mock_model) - reranked_docs_parser = PydanticOutputParser(pydantic_object=RerankedDocs) - prompt = PromptTemplate.from_template(RERANKER_PROMPT_TEMPLATE).partial( - format_instructions=reranked_docs_parser.get_format_instructions() - ) - expected_chain = prompt | mock_model.llm | reranked_docs_parser + prompt = PromptTemplate.from_template(RERANKER_PROMPT_TEMPLATE) + expected_chain = prompt | mock_model.llm.with_structured_output(RerankedDocs) # Then assert reranker is not None