Skip to content

Commit

Permalink
Refactor chain invocation and response handling for mypy
Browse files Browse the repository at this point in the history
- Update ainvoke_chain to use RunnableSequence and RunnableConfig
- Modify response handling in supervisor and reranker to use model instantiation
- Adjust type hints and return types for chain invocation utility
  • Loading branch information
muralov committed Feb 5, 2025
1 parent a2a2ef9 commit e835699
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 8 deletions.
3 changes: 2 additions & 1 deletion src/agents/supervisor/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,13 @@ async def _invoke_planner(self, state: SupervisorState) -> Plan:
)
reduces_messages = filter_messages(filtered_messages)

plan: Plan = await ainvoke_chain(
response = await ainvoke_chain(
self._planner_chain,
{
"messages": reduces_messages,
},
)
plan = Plan(**response)
return plan

async def _plan(self, state: SupervisorState) -> dict[str, Any]:
Expand Down
3 changes: 2 additions & 1 deletion src/rag/reranker/reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,15 @@ async def _chain_ainvoke(
"""

# reranking using the LLM model
response: RerankedDocs = await ainvoke_chain(
response_dict = await ainvoke_chain(
self.chain,
{
"documents": format_documents(docs),
"queries": format_queries(queries),
"limit": limit,
},
)
response = RerankedDocs(**response_dict)
# return reranked documents capped at the output limit
reranked_docs = [
Document(page_content=doc.page_content)
Expand Down
10 changes: 5 additions & 5 deletions src/utils/chain.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from typing import Any

from langchain.chains.base import Chain
from langchain.schema.runnable import RunnableConfig, RunnableSequence
from tenacity import (
RetryCallState,
retry,
Expand Down Expand Up @@ -34,11 +34,11 @@ def after_log(retry_state: RetryCallState) -> None:
reraise=True,
)
async def ainvoke_chain(
chain: Chain,
chain: RunnableSequence,
inputs: dict[str, Any] | Any,
*,
config: dict[str, Any] | None = None,
) -> dict[str, Any]:
config: RunnableConfig | None = None,
) -> Any:
"""Invokes a LangChain chain asynchronously.
Retries the LLM calls if they fail with the provided wait strategy.
    Tries up to 3 times, waiting between attempts with increasing delays (2 s, then 5 s).
Expand All @@ -52,7 +52,7 @@ async def ainvoke_chain(
Defaults to None.
Returns:
Dict[str, Any]: The chain execution results
Any: The chain execution results
"""
# Convert single value input to dict if needed
chain_inputs = inputs if isinstance(inputs, dict) else {"input": inputs}
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/utils/test_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest
from langchain.chains.base import Chain
from langchain.schema.runnable import RunnableConfig

from utils.chain import ainvoke_chain

Expand Down Expand Up @@ -80,7 +81,7 @@ def mock_chain():
async def test_ainvoke_chain(
mock_chain,
input_data: dict[str, Any] | str,
config: dict[str, Any] | None,
config: RunnableConfig | None,
mock_response: Any,
expected_chain_input: dict[str, Any],
expected_output: dict[str, Any] | None,
Expand Down

0 comments on commit e835699

Please sign in to comment.