chore: implement retry for plan LLM call #365

Open · wants to merge 5 commits into base: main
930 changes: 917 additions & 13 deletions poetry.lock

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion pyproject.toml
@@ -34,10 +34,11 @@ python-decouple = "^3.8"
redis = "^5.0.8"
requests = "^2.32.3"
scrubadub = {extras = ["all"], version = "^2.0.1"}
tenacity = "^9.0.0"
tiktoken = "^0.7.0"

[tool.poetry.group.test.dependencies]
deepeval = "^2.1.2"
deepeval = "^2.3.2"
fakeredis = "^2.23.3"
prettytable = "^3.10.2"
pytest = "^8.2.2"
@@ -68,6 +69,9 @@ pythonpath = [
testpaths = [
"tests",
]
env_files = [
".env.test"
]

[tool.poe.tasks]
lint = "ruff check ."
18 changes: 15 additions & 3 deletions src/agents/common/agent.py
@@ -12,6 +12,7 @@
from agents.common.constants import (
AGENT_MESSAGES,
AGENT_MESSAGES_SUMMARY,
ERROR,
IS_LAST_STEP,
MESSAGES,
MY_TASK,
@@ -20,12 +21,16 @@
from agents.common.state import BaseAgentState, SubTaskStatus
from agents.common.utils import filter_messages
from agents.summarization.summarization import Summarization
from utils.chain import ainvoke_chain
from utils.logging import get_logger
from utils.models.factory import IModel, ModelType
from utils.settings import (
SUMMARIZATION_TOKEN_LOWER_LIMIT,
SUMMARIZATION_TOKEN_UPPER_LIMIT,
)

logger = get_logger(__name__)


def subtask_selector_edge(state: BaseAgentState) -> Literal["agent", "finalizer"]:
"""Function that determines whether to finalize or call agent."""
@@ -126,7 +131,11 @@ async def _invoke_chain(self, state: BaseAgentState, config: RunnableConfig) ->
    if len(state.agent_messages) == 0:
        inputs[AGENT_MESSAGES] = filter_messages(state.messages)

    response = await self.chain.ainvoke(inputs, config)
    response = await ainvoke_chain(
        self.chain,
        inputs,
        config=config,
    )
    return response
async def _model_node(
@@ -135,13 +144,16 @@ async def _model_node(
    try:
        response = await self._invoke_chain(state, config)
    except Exception as e:
        error_message = f"An error occurred while processing the request: {e}"
        logger.error(error_message)
        return {
            AGENT_MESSAGES: [
                AIMessage(
                    content=f"Sorry, I encountered an error while processing the request. Error: {e}",
                    content=f"Sorry, {error_message}",
                    name=self.name,
                )
            ]
            ],
            ERROR: error_message,
        }

# if the recursive limit is reached and the response is a tool call, return a message.
6 changes: 6 additions & 0 deletions src/agents/common/exceptions.py
@@ -0,0 +1,6 @@
class SubtasksMissingError(Exception):
    """Exception raised when no subtasks are created for the given query."""

    def __init__(self, query: str):
        self.query = query
        super().__init__(f"Subtasks are missing for the given query: {query}")
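
A quick sanity check of the new exception's message format (hypothetical usage; the sample query is illustrative, not taken from this PR):

from agents.common.exceptions import SubtasksMissingError

err = SubtasksMissingError("How do I expose a Kyma function?")
# The original query stays available for upstream handlers and logging.
assert err.query == "How do I expose a Kyma function?"
assert str(err) == "Subtasks are missing for the given query: How do I expose a Kyma function?"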
4 changes: 3 additions & 1 deletion src/agents/graph.py
@@ -38,6 +38,7 @@
from agents.summarization.summarization import Summarization
from agents.supervisor.agent import SUPERVISOR, SupervisorAgent
from services.k8s import IK8sClient
from utils.chain import ainvoke_chain
from utils.logging import get_logger
from utils.models.factory import IModel, ModelType
from utils.settings import (
@@ -147,7 +148,8 @@ def _create_common_chain(model: IModel) -> RunnableSequence:

async def _invoke_common_node(self, state: CompanionState, subtask: str) -> str:
"""Invoke the common node."""
response = await self._common_chain.ainvoke(
response = await ainvoke_chain(
self._common_chain,
{
"messages": state.get_messages_including_summary(),
"query": subtask,
4 changes: 3 additions & 1 deletion src/agents/summarization/summarization.py
@@ -15,6 +15,7 @@

from agents.common.utils import compute_messages_token_count, compute_string_token_count
from agents.summarization.prompts import MESSAGES_SUMMARIZATION_PROMPT
from utils.chain import ainvoke_chain
from utils.models.factory import IModel, ModelType


@@ -88,7 +89,8 @@ async def get_summary(
    if len(messages) == 0:
        return ""

    res = await self._chain.ainvoke(
    res = await ainvoke_chain(
        self._chain,
        {"messages": messages},
        config=config,
    )
74 changes: 35 additions & 39 deletions src/agents/supervisor/agent.py
@@ -12,13 +12,15 @@

from agents.common.constants import (
COMMON,
ERROR,
FINALIZER,
K8S_AGENT,
KYMA_AGENT,
MESSAGES,
NEXT,
PLANNER,
)
from agents.common.exceptions import SubtasksMissingError
from agents.common.response_converter import IResponseConverter, ResponseConverter
from agents.common.state import Plan
from agents.common.utils import create_node_output, filter_messages
@@ -29,6 +31,7 @@
PLANNER_SYSTEM_PROMPT,
)
from agents.supervisor.state import SupervisorState
from utils.chain import ainvoke_chain
from utils.filter_messages import (
filter_messages_via_checks,
is_ai_message,
@@ -146,7 +149,7 @@ def _create_planner_chain(self, model: IModel) -> RunnableSequence:
    return self.planner_prompt | model.llm.with_structured_output(Plan)  # type: ignore

async def _invoke_planner(self, state: SupervisorState) -> Plan:
    """Invoke the planner."""
    """Invoke the planner with retry logic using tenacity."""

    filtered_messages = filter_messages_via_checks(
        state.messages,
@@ -159,8 +162,9 @@ async def _invoke_planner(self, state: SupervisorState) -> Plan:
    )
    reduces_messages = filter_messages(filtered_messages)

    plan: Plan = await self._planner_chain.ainvoke(
        input={
    plan: Plan = await ainvoke_chain(
        self._planner_chain,
        {
            "messages": reduces_messages,
        },
    )
@@ -187,23 +191,16 @@ async def _plan(self, state: SupervisorState) -> dict[str, Any]:

        # if the Planner did not respond directly but also failed to create any subtasks, raise an exception
        if not plan.subtasks:
            raise Exception(
                f"No subtasks are created for the given query: {state.messages[-1].content}"
            )
            raise SubtasksMissingError(str(state.messages[-1].content))
        # return the plan with the subtasks to be dispatched by the Router
        return create_node_output(
            next=ROUTER,
            subtasks=plan.subtasks,
        )
    except Exception as e:
        logger.error(f"Error in planning: {e}")
    except Exception:
        logger.exception("Error in planning")
        return {
            MESSAGES: [
                AIMessage(
                    content=f"Sorry, I encountered an error while processing the request. Error: {e}",
                    name=PLANNER,
                )
            ]
            ERROR: "Unexpected error while processing the request. Please try again later.",
        }

def _final_response_chain(self, state: SupervisorState) -> RunnableSequence:
@@ -226,40 +223,39 @@ async def _generate_final_response(self, state: SupervisorState) -> dict[str, Any]:

final_response_chain = self._final_response_chain(state)

    try:
        final_response = await final_response_chain.ainvoke(
            {"messages": state.messages},
        )
        return {
            MESSAGES: [
                AIMessage(
                    content=final_response.content,
                    name=FINALIZER,
                )
            ],
            NEXT: END,
        }
    except Exception as e:
        logger.error(f"Error in generating final response: {e}")
        return {
            MESSAGES: [
                AIMessage(
                    content=f"Sorry, I encountered an error while processing the request. Error: {e}",
                    name=FINALIZER,
                )
            ]
        }
    final_response = await ainvoke_chain(
        final_response_chain,
        {"messages": state.messages},
    )

    return {
        MESSAGES: [
            AIMessage(
                content=final_response.content,
                name=FINALIZER,
            )
        ],
        NEXT: END,
    }

async def _get_converted_final_response(
    self, state: SupervisorState
) -> dict[str, Any]:
    """Convert the generated final response."""

    final_response = await self._generate_final_response(state)

    return self.response_converter.convert_final_response(final_response)
async def _get_converted_final_response(
    self, state: SupervisorState
) -> dict[str, Any]:
    """Convert the generated final response."""
    try:
        final_response = await self._generate_final_response(state)
        return self.response_converter.convert_final_response(final_response)
    except Exception:
        logger.exception("Error in generating final response")
        return {
            MESSAGES: [
                AIMessage(
                    content="Sorry, I encountered an error while processing the request. Try again later.",
                    name=FINALIZER,
                )
            ]
        }

def _build_graph(self) -> CompiledGraph:
# Define a new graph.
workflow = StateGraph(SupervisorState)
6 changes: 4 additions & 2 deletions src/rag/generator.py
@@ -5,6 +5,7 @@
from langchain_core.prompts import PromptTemplate

from rag.prompts import GENERATOR_PROMPT
from utils.chain import ainvoke_chain
from utils.logging import get_logger
from utils.models.factory import IModel

@@ -33,8 +34,9 @@ async def agenerate(self, relevant_docs: list[Document], query: str) -> str:
    # Convert Document objects to a list of their page_content
    docs_content = "\n\n".join(doc.page_content for doc in relevant_docs)
    try:
        response = await self.rag_chain.ainvoke(
            {"context": docs_content, "query": query}
        response = await ainvoke_chain(
            self.rag_chain,
            {"context": docs_content, "query": query},
        )
    except Exception as e:
        logger.exception(f"Error generating response for query: {query}")
6 changes: 5 additions & 1 deletion src/rag/query_generator.py
@@ -8,6 +8,7 @@
QUERY_GENERATOR_PROMPT_TEMPLATE,
)
from utils import logging
from utils.chain import ainvoke_chain
from utils.models.factory import IModel

logger = logging.get_logger(__name__)
@@ -57,7 +58,10 @@ def _create_chain(self) -> Any:
async def agenerate_queries(self, query: str) -> Queries:
    """Generate multiple queries based on the input query."""
    try:
        queries = await self._chain.ainvoke({"query": query})
        queries = await ainvoke_chain(
            self._chain,
            {"query": query},
        )
        return cast(Queries, queries)
    except Exception:
        logger.exception(
6 changes: 4 additions & 2 deletions src/rag/reranker/reranker.py
@@ -7,6 +7,7 @@
from rag.reranker.prompt import RERANKER_PROMPT_TEMPLATE
from rag.reranker.rrf import get_relevant_documents
from rag.reranker.utils import document_to_str
from utils.chain import ainvoke_chain
from utils.logging import get_logger
from utils.models.factory import IModel

@@ -92,12 +93,13 @@ async def _chain_ainvoke(
"""

# reranking using the LLM model
response: RerankedDocs = await self.chain.ainvoke(
response: RerankedDocs = await ainvoke_chain(
self.chain,
{
"documents": format_documents(docs),
"queries": format_queries(queries),
"limit": limit,
}
},
)
# return reranked documents capped at the output limit
reranked_docs = [
1 change: 0 additions & 1 deletion src/services/conversation.py
@@ -133,5 +133,4 @@ async def handle_request(
    async for chunk in self._companion_graph.astream(
        conversation_id, message, k8s_client
    ):
        logger.debug(f"Sending chunk: {chunk}")
        yield chunk.encode()
68 changes: 68 additions & 0 deletions src/utils/chain.py
@@ -0,0 +1,68 @@
import logging
from typing import Any

from langchain.schema.runnable import RunnableConfig, RunnableSequence
from tenacity import (
    RetryCallState,
    retry,
    stop_after_attempt,
    wait_incrementing,
)

logger = logging.getLogger(__name__)


def after_log(retry_state: RetryCallState) -> None:
    """Log retry attempts with appropriate log levels.

    Args:
        retry_state (RetryCallState): Current state of the retry operation
    """
    # Attempt numbers start at 1, so log the first failed attempt at INFO
    # and subsequent attempts at WARNING.
    loglevel = logging.INFO if retry_state.attempt_number <= 1 else logging.WARNING
    logger.log(
        loglevel,
        "Retrying %s: attempt %s",
        f"{retry_state.fn.__module__}.{retry_state.fn.__name__}",
        retry_state.attempt_number,
    )


@retry(
    stop=stop_after_attempt(3),
    wait=wait_incrementing(start=2, increment=3),
    after=after_log,
    reraise=True,
)
async def ainvoke_chain(
    chain: RunnableSequence,
    inputs: dict[str, Any] | Any,
    *,
    config: RunnableConfig | None = None,
) -> Any:
    """Invokes a LangChain chain asynchronously.

    Retries the LLM call if it fails, using the configured wait strategy:
    up to 3 attempts with incrementing waits between them (2 s, then 5 s).
    Each retry is logged via `after_log`; if the final attempt fails, the
    original error is re-raised.

    Args:
        chain (RunnableSequence): The LangChain chain to invoke
        inputs (dict[str, Any] | Any): Input parameters for the chain. Can be either
            a dictionary of inputs or a single value that will be wrapped in a dict
            with the key "input"
        config (RunnableConfig | None, optional): Additional configuration for chain
            execution. Defaults to None.

    Returns:
        Any: The chain execution result
    """
    # Convert a single-value input to a dict if needed.
    chain_inputs = inputs if isinstance(inputs, dict) else {"input": inputs}

    logger.debug(f"Invoking chain with inputs: {chain_inputs}")

    result = await chain.ainvoke(
        input=chain_inputs,
        config=config,
    )

    logger.debug(f"Chain execution completed. Result: {result}")
    return result
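
A minimal, self-contained sketch of how this retry policy behaves, reusing the same stop and wait settings; flaky_llm_call is a hypothetical stand-in for a chain invocation and is not part of this PR:

import asyncio

from tenacity import retry, stop_after_attempt, wait_incrementing

attempts = 0


@retry(
    stop=stop_after_attempt(3),
    wait=wait_incrementing(start=2, increment=3),
    reraise=True,
)
async def flaky_llm_call() -> str:
    """Fail twice, then succeed, like a transient LLM outage."""
    global attempts
    attempts += 1
    if attempts < 3:
        raise RuntimeError("transient LLM error")
    return "ok"


# Attempt 1 fails, wait 2 s; attempt 2 fails, wait 5 s; attempt 3 returns "ok".
print(asyncio.run(flaky_llm_call()))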