From d3d779afadc4ba6b033d8bbe8ccd1fcd6ac63b00 Mon Sep 17 00:00:00 2001
From: Mansur Uralov
Date: Wed, 5 Feb 2025 13:17:18 +0100
Subject: [PATCH] Refactor chain invocation and response handling for mypy

- Update ainvoke_chain to use RunnableSequence and RunnableConfig
- Modify response handling in supervisor and reranker to use model instantiation
- Adjust type hints and return types for chain invocation utility
---
 src/agents/supervisor/agent.py |  3 ++-
 src/rag/reranker/reranker.py   |  3 ++-
 src/utils/chain.py             | 10 +++++-----
 tests/unit/utils/test_chain.py |  3 ++-
 4 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/src/agents/supervisor/agent.py b/src/agents/supervisor/agent.py
index 640528ad..fd145c4b 100644
--- a/src/agents/supervisor/agent.py
+++ b/src/agents/supervisor/agent.py
@@ -162,12 +162,13 @@ async def _invoke_planner(self, state: SupervisorState) -> Plan:
         )
         reduces_messages = filter_messages(filtered_messages)
 
-        plan: Plan = await ainvoke_chain(
+        response = await ainvoke_chain(
             self._planner_chain,
             {
                 "messages": reduces_messages,
             },
         )
+        plan = Plan(**response)
         return plan
 
     async def _plan(self, state: SupervisorState) -> dict[str, Any]:
diff --git a/src/rag/reranker/reranker.py b/src/rag/reranker/reranker.py
index 710bc0cf..82861ee1 100644
--- a/src/rag/reranker/reranker.py
+++ b/src/rag/reranker/reranker.py
@@ -93,7 +93,7 @@ async def _chain_ainvoke(
         """
 
         # reranking using the LLM model
-        response: RerankedDocs = await ainvoke_chain(
+        response_dict = await ainvoke_chain(
             self.chain,
             {
                 "documents": format_documents(docs),
@@ -101,6 +101,7 @@ async def _chain_ainvoke(
                 "limit": limit,
             },
         )
+        response = RerankedDocs(**response_dict)
         # return reranked documents capped at the output limit
         reranked_docs = [
             Document(page_content=doc.page_content)
diff --git a/src/utils/chain.py b/src/utils/chain.py
index 88685740..5a9f3214 100644
--- a/src/utils/chain.py
+++ b/src/utils/chain.py
@@ -1,7 +1,7 @@
 import logging
 from typing import Any
 
-from langchain.chains.base import Chain
+from langchain.schema.runnable import RunnableConfig, RunnableSequence
 from tenacity import (
     RetryCallState,
     retry,
@@ -34,11 +34,11 @@ def after_log(retry_state: RetryCallState) -> None:
     reraise=True,
 )
 async def ainvoke_chain(
-    chain: Chain,
+    chain: RunnableSequence,
     inputs: dict[str, Any] | Any,
     *,
-    config: dict[str, Any] | None = None,
-) -> dict[str, Any]:
+    config: RunnableConfig | None = None,
+) -> Any:
     """Invokes a LangChain chain asynchronously.
     Retries the LLM calls if they fail with the provided wait strategy.
     Tries 3 times, waits 2 seconds between attempts, i.e. 2, 5.
@@ -52,7 +52,7 @@ async def ainvoke_chain(
             Defaults to None.
 
     Returns:
-        Dict[str, Any]: The chain execution results
+        Any: The chain execution results
     """
     # Convert single value input to dict if needed
     chain_inputs = inputs if isinstance(inputs, dict) else {"input": inputs}
diff --git a/tests/unit/utils/test_chain.py b/tests/unit/utils/test_chain.py
index 4f906a6f..6d2a667d 100644
--- a/tests/unit/utils/test_chain.py
+++ b/tests/unit/utils/test_chain.py
@@ -3,6 +3,7 @@
 
 import pytest
 from langchain.chains.base import Chain
+from langchain.schema.runnable import RunnableConfig
 
 from utils.chain import ainvoke_chain
 
@@ -80,7 +81,7 @@
 async def test_ainvoke_chain(
     mock_chain,
     input_data: dict[str, Any] | str,
-    config: dict[str, Any] | None,
+    config: RunnableConfig | None,
     mock_response: Any,
     expected_chain_input: dict[str, Any],
     expected_output: dict[str, Any] | None,