
Commit

changed to summarization node
mfaizanse committed Jan 10, 2025
1 parent fbbfc00 commit 72f0625
Showing 13 changed files with 213 additions and 268 deletions.
67 changes: 26 additions & 41 deletions src/agents/common/agent.py
@@ -3,30 +3,33 @@
from langchain_core.embeddings import Embeddings
from langchain_core.messages import (
AIMessage,
BaseMessage,
RemoveMessage,
ToolMessage,
)
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.config import RunnableConfig
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode

from agents.common.constants import IS_LAST_STEP, MESSAGES, MY_TASK, OWNER
from agents.common.constants import (
AGENT_MESSAGES,
IS_LAST_STEP,
MESSAGES,
MY_TASK,
)
from agents.common.state import BaseAgentState, SubTaskStatus
from agents.common.utils import filter_messages
from utils.models.factory import IModel


def subtask_selector_edge(state: BaseAgentState) -> Literal["agent", "__end__"]:
"""Function that determines whether to end or call agent."""
"""Function that determines whether to finalize or call agent."""
if state.is_last_step and state.my_task is None:
return "__end__"
return "finalizer"
return "agent"


def agent_edge(state: BaseAgentState) -> Literal["tools", "finalizer"]:
"""Function that determines whether to call tools or finalizer."""
last_message = state.messages[-1]
last_message = state.agent_messages[-1]
if isinstance(last_message, AIMessage) and not last_message.tool_calls:
return "finalizer"
return "tools"
@@ -72,28 +75,11 @@ def agent_node(self) -> Any:
"""Get agent node function."""
return self.graph

def is_internal_message(self, message: BaseMessage) -> bool:
"""Check if the message is an internal message."""
# if the message is a tool call and the owner is the agent, return True.
if (
message.additional_kwargs is not None
and OWNER in message.additional_kwargs
and message.additional_kwargs[OWNER] == self.name
and message.tool_calls # type: ignore
):
return True

# if the message is a tool message and the tool is in the agent's tools, return True.
tool_names = [tool.name for tool in self.tools]
if isinstance(message, ToolMessage) and message.name in tool_names:
return True
return False

def _create_chain(self, system_prompt: str) -> Any:
agent_prompt = ChatPromptTemplate.from_messages(
[
("system", system_prompt),
MessagesPlaceholder(variable_name=MESSAGES),
MessagesPlaceholder(variable_name=AGENT_MESSAGES),
("human", "{query}"),
]
)
@@ -115,7 +101,7 @@ def _subtask_selector_node(self, state: BaseAgentState) -> dict[str, Any]:

return {
IS_LAST_STEP: True,
MESSAGES: [
AGENT_MESSAGES: [
AIMessage(
content="All my subtasks are already completed.",
name=self.name,
@@ -125,9 +111,12 @@

async def _invoke_chain(self, state: BaseAgentState, config: RunnableConfig) -> Any:
inputs = {
MESSAGES: state.messages,
AGENT_MESSAGES: filter_messages(state.agent_messages),
"query": state.my_task.description,
}
if len(state.agent_messages) == 0:
inputs[AGENT_MESSAGES] = filter_messages(state.messages)

response = await self.chain.ainvoke(inputs, config)
return response

@@ -138,7 +127,7 @@ async def _model_node(
response = await self._invoke_chain(state, config)
except Exception as e:
return {
MESSAGES: [
AGENT_MESSAGES: [
AIMessage(
content=f"Sorry, I encountered an error while processing the request. Error: {e}",
name=self.name,
@@ -154,7 +143,7 @@ async def _model_node(
and response.tool_calls
):
return {
MESSAGES: [
AGENT_MESSAGES: [
AIMessage(
content="Sorry, I need more steps to process the request.",
name=self.name,
@@ -163,20 +152,14 @@ async def _model_node(
}

response.additional_kwargs["owner"] = self.name
return {MESSAGES: [response]}
return {AGENT_MESSAGES: [response]}

def _finalizer_node(self, state: BaseAgentState, config: RunnableConfig) -> Any:
"""Finalizer node will mark the task as completed and clean-up extra messages."""
state.my_task.complete()
"""Finalizer node will mark the task as completed."""
if state.my_task is not None:
state.my_task.complete()
# clean all agent messages to avoid populating the checkpoint with unnecessary messages.
return {
MESSAGES: [
RemoveMessage(id=m.id) # type: ignore
for m in state.messages
if self.is_internal_message(m)
],
MY_TASK: None,
}
return {MESSAGES: [state.agent_messages[-1]]}

def _build_graph(self, state_class: type) -> Any:
# Define a new graph
@@ -185,7 +168,9 @@ def _build_graph(self, state_class: type) -> Any:
# Define nodes with async awareness
workflow.add_node("subtask_selector", self._subtask_selector_node)
workflow.add_node("agent", self._model_node)
workflow.add_node("tools", ToolNode(self.tools))
workflow.add_node(
"tools", ToolNode(tools=self.tools, messages_key=AGENT_MESSAGES)
)
workflow.add_node("finalizer", self._finalizer_node)

# Set the entrypoint: ENTRY --> subtask_selector
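The agent.py changes above move the agent's working history out of the shared `messages` channel and into a private `agent_messages` channel: the prompt now reads from AGENT_MESSAGES, ToolNode is pointed at that key, and the finalizer copies only the last agent message back into `messages`. Below is a minimal sketch of that wiring; it is not the repository's code, and the names `DemoState` and `echo` are illustrative assumptions.

```python
# Minimal sketch (not the repository's code) of the pattern introduced above:
# a subgraph keeps its tool-calling scratchpad in a private "agent_messages"
# channel, and ToolNode is pointed at that channel via messages_key.
from collections.abc import Sequence
from typing import Annotated

from langchain_core.messages import BaseMessage
from langchain_core.tools import tool
from langgraph.graph import StateGraph, add_messages
from langgraph.prebuilt import ToolNode
from pydantic import BaseModel


@tool
def echo(query: str) -> str:
    """Toy tool that echoes its input."""
    return f"result for: {query}"


class DemoState(BaseModel):
    # Shared history that the parent graph (and checkpoint) sees.
    messages: Annotated[Sequence[BaseMessage], add_messages] = []
    # Private scratchpad for the agent's tool-calling loop.
    agent_messages: Annotated[Sequence[BaseMessage], add_messages] = []


workflow = StateGraph(DemoState)
# ToolNode reads tool calls from, and appends ToolMessages to,
# state.agent_messages instead of the default "messages" key.
workflow.add_node("tools", ToolNode(tools=[echo], messages_key="agent_messages"))
```

Because only the finalizer returns `{MESSAGES: [state.agent_messages[-1]]}`, intermediate tool calls and tool results stay inside the subgraph and never reach the parent graph's checkpoint.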
6 changes: 6 additions & 0 deletions src/agents/common/constants.py
@@ -2,6 +2,8 @@

COMMON = "Common"

SUMMARIZATION = "Summarization"

FINALIZER = "Finalizer"

EXIT = "Exit"
@@ -12,6 +14,10 @@

MESSAGES = "messages"

MESSAGES_SUMMARY = "messages_summary"

AGENT_MESSAGES = "agent_messages"

ERROR = "error"

NEXT = "next"
16 changes: 12 additions & 4 deletions src/agents/common/state.py
@@ -2,12 +2,12 @@
from enum import Enum
from typing import Annotated, Literal

from langchain_core.messages import BaseMessage
from langchain_core.messages import BaseMessage, SystemMessage
from langgraph.graph import add_messages
from langgraph.managed import IsLastStep
from pydantic import BaseModel, Field

from agents.common.constants import COMMON, K8S_AGENT, K8S_CLIENT, KYMA_AGENT
from agents.reducer.reducers import new_default_summarization_reducer
from services.k8s import IK8sClient


@@ -92,12 +92,19 @@ class CompanionState(BaseModel):
default=None,
)

messages: Annotated[Sequence[BaseMessage], new_default_summarization_reducer()]
messages: Annotated[Sequence[BaseMessage], add_messages]
messages_summary: str = ""
next: str | None = None
subtasks: list[SubTask] | None = []
error: str | None = None
k8s_client: IK8sClient | None = None

def get_messages_including_summary(self) -> list[BaseMessage]:
"""Get messages including the summary message."""
if self.messages_summary:
return [SystemMessage(content=self.messages_summary)] + self.messages
return self.messages

def all_tasks_completed(self) -> bool:
"""Check if all the sub-tasks are completed."""
return all(task.status == SubTaskStatus.COMPLETED for task in self.subtasks)
@@ -110,11 +117,12 @@ class Config:
class BaseAgentState(BaseModel):
"""Base state for KymaAgent and KubernetesAgent agents (subgraphs)."""

messages: Annotated[Sequence[BaseMessage], new_default_summarization_reducer()]
messages: Annotated[Sequence[BaseMessage], add_messages]
subtasks: list[SubTask] | None = []
k8s_client: IK8sClient

# Subgraph private fields
agent_messages: Annotated[Sequence[BaseMessage], add_messages]
my_task: SubTask | None = None
is_last_step: IsLastStep

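CompanionState now keeps a running `messages_summary` alongside the trimmed `messages`, and `get_messages_including_summary()` prepends that summary as a SystemMessage. Here is a standalone illustration of the same pattern; the helper name and sample values are assumptions made for this example, not part of the diff.

```python
# Standalone illustration of the "summary as leading SystemMessage" pattern
# used by CompanionState.get_messages_including_summary() above.
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage


def with_summary(summary: str, messages: list[BaseMessage]) -> list[BaseMessage]:
    """Prepend the running summary, if any, so the model sees compressed history first."""
    if summary:
        return [SystemMessage(content=summary)] + messages
    return messages


history = with_summary(
    "The user is debugging a crash-looping deployment in namespace 'demo'.",
    [HumanMessage(content="What did the last restart log say?")],
)
# history == [SystemMessage(...summary...), HumanMessage(...)]
```

In graph.py below, `_invoke_common_node` passes exactly this combined list to the common chain, so the model sees the compressed history first and the retained messages after it.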
16 changes: 16 additions & 0 deletions src/agents/common/utils.py
@@ -1,10 +1,12 @@
from collections.abc import Sequence
from typing import Any, Literal

import tiktoken
from gen_ai_hub.proxy.langchain import ChatOpenAI
from langchain.agents import AgentExecutor, OpenAIFunctionsAgent
from langchain_core.messages import AIMessage, BaseMessage, SystemMessage, ToolMessage
from langchain_core.prompts import MessagesPlaceholder
from langgraph.graph.message import Messages

from agents.common.constants import (
CONTINUE,
@@ -18,6 +20,7 @@
)
from agents.common.state import CompanionState, SubTask, SubTaskStatus
from utils.logging import get_logger
from utils.models.factory import ModelType

logger = get_logger(__name__)

@@ -123,3 +126,16 @@ def create_node_output(
SUBTASKS: subtasks,
ERROR: error,
}


def compute_string_token_count(text: str, model_type: ModelType) -> int:
"""Returns the token count of the string."""
return len(tiktoken.encoding_for_model(model_type).encode(text=text))


def compute_messages_token_count(msgs: Messages, model_type: ModelType) -> int:
"""Returns the token count of the messages."""
tokens_per_msg = (
compute_string_token_count(str(msg.content), model_type) for msg in msgs
)
return sum(tokens_per_msg)
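The new helpers wrap tiktoken for token counting. A hedged usage example follows; the plain model name "gpt-4o" stands in for the ModelType value the real code passes, and it assumes the installed tiktoken release recognizes that name.

```python
# Hedged usage example of the token-counting idea above; "gpt-4o" and the
# sample messages are illustrative assumptions.
import tiktoken
from langchain_core.messages import AIMessage, HumanMessage


def count_tokens(text: str, model_name: str = "gpt-4o") -> int:
    """Same idea as compute_string_token_count, with a plain model-name string."""
    return len(tiktoken.encoding_for_model(model_name).encode(text))


msgs = [
    HumanMessage(content="Why is my Kyma function failing?"),
    AIMessage(content="Let me check the function status and the pod logs."),
]
total = sum(count_tokens(str(m.content)) for m in msgs)
print(total)  # the summarization node compares a total like this against its limits
```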
68 changes: 63 additions & 5 deletions src/agents/graph.py
@@ -11,12 +11,13 @@
AIMessage,
BaseMessage,
HumanMessage,
RemoveMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableSequence
from langchain_core.runnables import RunnableConfig, RunnableSequence
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.constants import END
from langgraph.graph import StateGraph
@@ -26,17 +27,24 @@
from agents.common.constants import (
COMMON,
MESSAGES,
MESSAGES_SUMMARY,
SUMMARIZATION,
)
from agents.common.data import Message
from agents.common.state import CompanionState, Plan, SubTask, UserInput
from agents.k8s.agent import K8S_AGENT, KubernetesAgent
from agents.kyma.agent import KYMA_AGENT, KymaAgent
from agents.prompts import COMMON_QUESTION_PROMPT
from agents.summarization.summarization import Summarization
from agents.supervisor.agent import SUPERVISOR, SupervisorAgent
from services.k8s import IK8sClient
from utils.langfuse import handler
from utils.logging import get_logger
from utils.models.factory import IModel, ModelType
from utils.settings import (
SUMMARIZATION_TOKEN_LOWER_LIMIT,
SUMMARIZATION_TOKEN_UPPER_LIMIT,
)

logger = get_logger(__name__)

@@ -49,7 +57,13 @@ class CustomJSONEncoder(json.JSONEncoder):

def default(self, obj): # noqa D102
if isinstance(
obj, AIMessage | HumanMessage | SystemMessage | ToolMessage | SubTask
obj,
RemoveMessage
| AIMessage
| HumanMessage
| SystemMessage
| ToolMessage
| SubTask,
):
return obj.__dict__
elif isinstance(obj, IK8sClient):
@@ -102,6 +116,13 @@ def __init__(
members=[KYMA_AGENT, K8S_AGENT, COMMON],
)

self.summarization = Summarization(
model=gpt_4o_mini,
tokenizer_model_type=ModelType.GPT4O,
token_lower_limit=SUMMARIZATION_TOKEN_LOWER_LIMIT,
token_upper_limit=SUMMARIZATION_TOKEN_UPPER_LIMIT,
)

self.members = [self.kyma_agent.name, self.k8s_agent.name, COMMON]
self._common_chain = self._create_common_chain(cast(IModel, gpt_4o_mini))
self.graph = self._build_graph()
@@ -123,7 +144,7 @@ async def _invoke_common_node(self, state: CompanionState, subtask: str) -> str:
"""Invoke the common node."""
response = await self._common_chain.ainvoke(
{
"messages": state.messages,
"messages": state.get_messages_including_summary(),
"query": subtask,
},
)
@@ -166,6 +187,41 @@ async def _common_node(self, state: CompanionState) -> dict[str, Any]:
]
}

async def _summarization_node(
self, state: CompanionState, config: RunnableConfig
) -> dict[str, Any]:
"""Summarization node to summarize the conversation."""
all_messages = [SystemMessage(content=state.messages_summary)] + state.messages

token_count = self.summarization.get_messages_token_count(all_messages)
if token_count <= self.summarization.get_token_upper_limit():
return {
MESSAGES: [],
}

# keep only the newest messages that fit within the token limit.
latest_messages = self.summarization.filter_messages_by_token_limit(
all_messages
)

if len(latest_messages) == len(all_messages):
return {
MESSAGES: [],
}

# summarize the remaining old messages
msgs_for_summarization = all_messages[: -len(latest_messages)]
summary = self.summarization.get_summary(msgs_for_summarization, config)

# remove excluded messages from state.
msgs_to_remove = state.messages[: -len(latest_messages)]
delete_messages = [RemoveMessage(id=m.id) for m in msgs_to_remove]

return {
MESSAGES_SUMMARY: summary,
MESSAGES: delete_messages,
}

def _build_graph(self) -> CompiledGraph:
"""Create the companion parent graph."""

@@ -177,14 +233,16 @@ def _build_graph(self) -> CompiledGraph:
workflow.add_node(KYMA_AGENT, self.kyma_agent.agent_node())
workflow.add_node(K8S_AGENT, self.k8s_agent.agent_node())
workflow.add_node(COMMON, self._common_node)
workflow.add_node(SUMMARIZATION, self._summarization_node)

# Set the entrypoint: ENTRY --> supervisor
workflow.set_entry_point(SUPERVISOR)

# Define the edges: (KymaAgent | KubernetesAgent | Common) --> supervisor
# The agents ALWAYS "report back" to the supervisor.
# The agents ALWAYS "report back" to the supervisor through the summarization node.
for member in self.members:
workflow.add_edge(member, SUPERVISOR)
workflow.add_edge(member, SUMMARIZATION)
workflow.add_edge(SUMMARIZATION, SUPERVISOR)

# The supervisor dynamically populates the "next" field in the graph.
conditional_map: dict[Hashable, str] = {k: k for k in self.members + [END]}
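The new `_summarization_node` now sits on every member -> supervisor edge: when the combined history (summary plus messages) exceeds the upper token limit, it keeps the newest messages that still fit, summarizes the older ones into `messages_summary`, and returns `RemoveMessage` entries so the reducer drops them from the checkpoint. The `Summarization` class itself is not part of this diff, so the sketch below restates only the decision logic with stand-in helpers and limits; every name and value in it is an assumption, not the repository's implementation.

```python
# Minimal sketch of the trimming decision in _summarization_node. summarize(),
# count_tokens(), and the two limits are stand-ins for the Summarization class
# and the SUMMARIZATION_TOKEN_* settings.
from langchain_core.messages import BaseMessage, RemoveMessage

TOKEN_UPPER_LIMIT = 3000  # assumed value; the real limits come from settings
TOKEN_LOWER_LIMIT = 2000  # assumed value


def summarize(msgs: list[BaseMessage]) -> str:
    """Stand-in for the LLM call that compresses the older messages."""
    return f"Summary of {len(msgs)} earlier messages."


def count_tokens(msg: BaseMessage) -> int:
    """Stand-in token counter; the real code uses tiktoken (see utils.py above)."""
    return len(str(msg.content)) // 4


def summarization_update(messages: list[BaseMessage]) -> dict:
    """Return the state update the node would produce.

    The real node also prepends the existing summary as a SystemMessage before
    counting; that detail is omitted to keep the sketch short.
    """
    counts = [count_tokens(m) for m in messages]
    if sum(counts) <= TOKEN_UPPER_LIMIT:
        return {"messages": []}  # under budget: keep everything, no new summary

    # Keep the newest messages that still fit under the lower limit.
    kept, running = 0, 0
    for count in reversed(counts):
        if running + count > TOKEN_LOWER_LIMIT:
            break
        running += count
        kept += 1

    if kept == len(messages):  # nothing would actually be dropped
        return {"messages": []}

    old = messages if kept == 0 else messages[:-kept]
    return {
        # New running summary covering the dropped messages.
        "messages_summary": summarize(old),
        # RemoveMessage markers tell the add_messages reducer to delete them
        # (message ids are set once messages have been persisted).
        "messages": [RemoveMessage(id=m.id) for m in old],
    }
```

With the member -> SUMMARIZATION -> SUPERVISOR edges added in `_build_graph`, this trimming runs after every agent turn rather than once per conversation.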
(The remaining 8 changed files are not shown here.)
