Skip to content

Commit

Permalink
chore: use structured output in Kyma RAG reranker
Browse files Browse the repository at this point in the history
  • Loading branch information
marcobebway committed Dec 11, 2024
1 parent ebe83ae commit 28a82cd
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 16 deletions.
4 changes: 0 additions & 4 deletions src/rag/reranker/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,12 @@
2. Compare the queries to each document by considering keywords or semantic meaning.
3. Determine a relevance score for each document with respect to the queries (higher score means more relevant).
4. Avoid documents that are irrelevant to the queries with very low relevance scores.
5. Return documents in the order of relevance in a well structured JSON format.
</your-tasks>
<additional-rules>
1. Do not make up or invent any new documents. Only use the documents from the provided ranking.
2. Use the documents and your expertise to decide what to keep, how to rank, and which documents to remove.
3. Order the documents based on their relevance score to the queries (top document is the most relevant).
4. Restrict the output to the top {limit} most relevant documents to the queries.
5. Return the documents in the order of relevance in the following JSON format:
{format_instructions}
6. Ensure that the response is a well-structured and valid JSON.
</additional-rules>
"""
9 changes: 3 additions & 6 deletions src/rag/reranker/reranker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Protocol

from langchain_core.documents import Document
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from pydantic import BaseModel

Expand Down Expand Up @@ -45,11 +44,9 @@ class LLMReranker(IReranker):

def __init__(self, model: IModel):
"""Initialize the reranker."""
reranked_docs_parser = PydanticOutputParser(pydantic_object=RerankedDocs)
prompt = PromptTemplate.from_template(RERANKER_PROMPT_TEMPLATE).partial(
format_instructions=reranked_docs_parser.get_format_instructions()
)
self.chain = prompt | model.llm | reranked_docs_parser
prompt = PromptTemplate.from_template(RERANKER_PROMPT_TEMPLATE)
self.chain = prompt | model.llm.with_structured_output(RerankedDocs)
logger.info("Reranker initialized")

async def arerank(
self,
Expand Down
8 changes: 2 additions & 6 deletions tests/unit/rag/reranker/test_reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import pytest
from langchain_core.documents import Document
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate

from rag.reranker.prompt import RERANKER_PROMPT_TEMPLATE
Expand Down Expand Up @@ -44,11 +43,8 @@ def test_init(self):
# When
reranker = LLMReranker(model=mock_model)

reranked_docs_parser = PydanticOutputParser(pydantic_object=RerankedDocs)
prompt = PromptTemplate.from_template(RERANKER_PROMPT_TEMPLATE).partial(
format_instructions=reranked_docs_parser.get_format_instructions()
)
expected_chain = prompt | mock_model.llm | reranked_docs_parser
prompt = PromptTemplate.from_template(RERANKER_PROMPT_TEMPLATE)
expected_chain = prompt | mock_model.llm.with_structured_output(RerankedDocs)

# Then
assert reranker is not None
Expand Down

0 comments on commit 28a82cd

Please sign in to comment.