Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: use structured output in Kyma RAG reranker #308

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions src/rag/reranker/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,12 @@
2. Compare the queries to each document by considering keywords or semantic meaning.
3. Determine a relevance score for each document with respect to the queries (higher score means more relevant).
4. Avoid documents that are irrelevant to the queries with very low relevance scores.
- 5. Return documents in the order of relevance in a well structured JSON format.
</your-tasks>

<additional-rules>
1. Do not make up or invent any new documents. Only use the documents from the provided ranking.
2. Use the documents and your expertise to decide what to keep, how to rank, and which documents to remove.
3. Order the documents based on their relevance score to the queries (top document is the most relevant).
4. Restrict the output to the top {limit} most relevant documents to the queries.
- 5. Return the documents in the order of relevance in the following JSON format:
- {format_instructions}
- 6. Ensure that the response is a well-structured and valid JSON.
</additional-rules>
"""
9 changes: 3 additions & 6 deletions src/rag/reranker/reranker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Protocol

from langchain_core.documents import Document
- from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from pydantic import BaseModel

Expand Down Expand Up @@ -45,11 +44,9 @@ class LLMReranker(IReranker):

def __init__(self, model: IModel):
"""Initialize the reranker."""
- reranked_docs_parser = PydanticOutputParser(pydantic_object=RerankedDocs)
- prompt = PromptTemplate.from_template(RERANKER_PROMPT_TEMPLATE).partial(
- format_instructions=reranked_docs_parser.get_format_instructions()
- )
- self.chain = prompt | model.llm | reranked_docs_parser
+ prompt = PromptTemplate.from_template(RERANKER_PROMPT_TEMPLATE)
+ self.chain = prompt | model.llm.with_structured_output(RerankedDocs)
logger.info("Reranker initialized")

async def arerank(
self,
Expand Down
8 changes: 2 additions & 6 deletions tests/unit/rag/reranker/test_reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import pytest
from langchain_core.documents import Document
- from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate

from rag.reranker.prompt import RERANKER_PROMPT_TEMPLATE
Expand Down Expand Up @@ -44,11 +43,8 @@ def test_init(self):
# When
reranker = LLMReranker(model=mock_model)

- reranked_docs_parser = PydanticOutputParser(pydantic_object=RerankedDocs)
- prompt = PromptTemplate.from_template(RERANKER_PROMPT_TEMPLATE).partial(
- format_instructions=reranked_docs_parser.get_format_instructions()
- )
- expected_chain = prompt | mock_model.llm | reranked_docs_parser
+ prompt = PromptTemplate.from_template(RERANKER_PROMPT_TEMPLATE)
+ expected_chain = prompt | mock_model.llm.with_structured_output(RerankedDocs)

# Then
assert reranker is not None
Expand Down
Loading