Skip to content

Commit

Permalink
Merge pull request #37 from crestalnetwork/feat/prompt-append
Browse files Browse the repository at this point in the history
Feat: prompt append
  • Loading branch information
taiyangc authored Jan 8, 2025
2 parents 3aefefd + 21311a3 commit 8ea3ccf
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 34 deletions.
104 changes: 73 additions & 31 deletions app/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,17 @@
from langchain_core.messages import (
HumanMessage,
)
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.postgres import PostgresSaver
from langgraph.graph.graph import CompiledGraph
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm.exc import NoResultFound
from sqlmodel import Session
from sqlmodel import Session, select

from abstracts.engine import AgentMessageInput
from abstracts.graph import AgentState
from app.config.config import config
from app.core.graph import create_agent
from app.models.agent import Agent
Expand All @@ -44,6 +46,33 @@
agents: dict[str, CompiledGraph] = {}


def agent_prompt(agent: Agent) -> str:
prompt = ""
if config.system_prompt:
prompt += config.system_prompt + "\n\n"
if agent.name:
prompt += f"Your name is {agent.name}.\n\n"
if agent.prompt:
prompt += agent.prompt
elif agent.cdp_enabled:
prompt += (
"You are a helpful agent that can interact onchain using the Coinbase Developer Platform AgentKit. "
"You are empowered to interact onchain using your tools. If you ever need funds, you can request "
"them from the faucet if you are on network ID 'base-sepolia'. If not, you can provide your wallet "
"details and request funds from the user. Before executing your first action, get the wallet details "
"to see what network you're on. If there is a 5XX (internal) HTTP error code, ask the user to try "
"again later. If someone asks you to do something you can't do with your currently available tools, "
"you must say so, and encourage them to implement it themselves using the CDP SDK + Agentkit, "
"recommend they go to docs.cdp.coinbase.com for more information. Be concise and helpful with your "
"responses. Refrain from restating your tools' descriptions unless it is explicitly requested."
)
if agent.cdp_enabled:
prompt += """\n\nWallet addresses are public information. If someone asks for your default wallet,
current wallet, personal wallet, crypto wallet, or wallet public address, don't use any address in message history,
you must use the "get_wallet_details" tool to retrieve your wallet address every time.\n\n"""
return prompt


def initialize_agent(aid):
"""Initialize an AI agent with specified configuration and tools.
Expand All @@ -68,7 +97,7 @@ def initialize_agent(aid):
with Session(engine) as db:
# get the agent from the database
try:
agent: Agent = db.query(Agent).filter(Agent.id == aid).one()
agent: Agent = db.exec(select(Agent).filter(Agent.id == aid)).one()
except NoResultFound:
# Handle the case where the user is not found
raise HTTPException(status_code=404, detail="Agent not found")
Expand All @@ -83,27 +112,6 @@ def initialize_agent(aid):
# ==== Store buffered conversation history in memory.
memory = PostgresSaver(get_coon())

# ==== Set up prompt
prompt = ""
if config.system_prompt:
prompt += config.system_prompt + "\n\n"
if agent.name:
prompt += f"Your name is {agent.name}.\n\n"
if agent.prompt:
prompt += agent.prompt
elif agent.cdp_enabled:
prompt += (
"You are a helpful agent that can interact onchain using the Coinbase Developer Platform AgentKit. "
"You are empowered to interact onchain using your tools. If you ever need funds, you can request "
"them from the faucet if you are on network ID 'base-sepolia'. If not, you can provide your wallet "
"details and request funds from the user. Before executing your first action, get the wallet details "
"to see what network you're on. If there is a 5XX (internal) HTTP error code, ask the user to try "
"again later. If someone asks you to do something you can't do with your currently available tools, "
"you must say so, and encourage them to implement it themselves using the CDP SDK + Agentkit, "
"recommend they go to docs.cdp.coinbase.com for more information. Be concise and helpful with your "
"responses. Refrain from restating your tools' descriptions unless it is explicitly requested."
)

# ==== Load skills
tools: list[BaseTool] = []

Expand All @@ -126,10 +134,6 @@ def initialize_agent(aid):
# Initialize CDP Agentkit Toolkit and get tools.
cdp_toolkit = CdpToolkit.from_cdp_agentkit_wrapper(agentkit)
tools.extend(cdp_toolkit.get_tools())
# add prompt
prompt += """\n\nWallet addresses are public information. If someone asks for your default wallet,
current wallet, personal wallet, crypto wallet, or wallet public address, don't use any address in message history,
you must use the "get_wallet_details" tool to retrieve your wallet address every time.\n\n"""

# Twitter skills
if (
Expand Down Expand Up @@ -163,12 +167,25 @@ def initialize_agent(aid):
for tool in tools:
logger.info(f"[{aid}] loaded tool: {tool.name}")

# finally, setup the system prompt
prompt = agent_prompt(agent)
prompt_array = [
("system", prompt),
("placeholder", "{messages}"),
]
if agent.prompt_append:
prompt_array.append(("system", agent.prompt_append))
prompt_temp = ChatPromptTemplate.from_messages(prompt_array)

def formatted_prompt(state: AgentState):
return prompt_temp.invoke({"messages": state["messages"]})

# Create ReAct Agent using the LLM and CDP Agentkit tools.
agents[aid] = create_agent(
llm,
tools=tools,
checkpointer=memory,
state_modifier=prompt,
state_modifier=formatted_prompt,
debug=config.debug,
)

Expand All @@ -186,8 +203,9 @@ def execute_agent(
Args:
aid (str): Agent ID
prompt (str): Input prompt for the agent
message (AgentMessageInput): Input message for the agent
thread_id (str): Thread ID for the agent execution
debug (bool): Enable debug mode
Returns:
list[str]: Formatted response lines including timing information
Expand All @@ -201,14 +219,14 @@ def execute_agent(
]
"""
stream_config = {"configurable": {"thread_id": thread_id}}
resp_debug = []
resp_debug = [f"Thread ID: {thread_id}\n\n-------------------\n"]
resp = []
start = time.perf_counter()
last = start

# user input
resp_debug.append(
f"[ Input: ]\n\n {message.text}\n\n{'\n'.join(message.images)}\n\n-------------------\n"
f"[ Input: ]\n\n {message.text}\n{'\n'.join(message.images)}\n-------------------\n"
)

# cold start
Expand All @@ -231,6 +249,29 @@ def execute_agent(
for image_url in message.images
]
)
# debug prompt
if debug:
# get the agent from the database
engine = get_engine()
with Session(engine) as db:
try:
agent: Agent = db.exec(select(Agent).filter(Agent.id == aid)).one()
except NoResultFound:
# Handle the case where the user is not found
raise HTTPException(status_code=404, detail="Agent not found")
except SQLAlchemyError as e:
# Handle other SQLAlchemy-related errors
logger.error(e)
raise HTTPException(status_code=500, detail=str(e))
resp_debug_append = "\n===================\n\n[ system ]\n"
resp_debug_append += agent_prompt(agent)
snap = executor.get_state(stream_config)
if snap.values and "messages" in snap.values:
for msg in snap.values["messages"]:
resp_debug_append += f"[ {msg.type} ]\n{msg.content}\n\n"
if agent.prompt_append:
resp_debug_append += "[ system ]\n"
resp_debug_append += agent.prompt_append
# run
for chunk in executor.stream(
{"messages": [HumanMessage(content=content)]}, stream_config
Expand Down Expand Up @@ -258,6 +299,7 @@ def execute_agent(
total_time = time.perf_counter() - start
resp_debug.append(f"Total time cost: {total_time:.3f} seconds")
if debug:
resp_debug.append(resp_debug_append)
return resp_debug
else:
return resp
10 changes: 7 additions & 3 deletions app/entrypoints/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def chat(
aid: str = Path(..., description="instance id"),
q: str = Query(None, description="Query string"),
debug: bool = Query(None, description="Enable debug mode"),
thread: str = Query(None, description="Thread ID for conversation tracking"),
db: Session = Depends(get_db),
):
"""Chat with an AI agent.
Expand All @@ -39,6 +40,7 @@ def chat(
aid: Agent ID
q: User's input query
debug: Enable debug mode
thread: Thread ID for conversation tracking
db: Database session
Returns:
Expand All @@ -63,11 +65,13 @@ def chat(
),
)

# get thread_id from request ip
thread_id = f"{aid}-{request.client.host}"
# get thread_id from query or request ip
thread_id = (
f"{aid}-{thread}" if thread is not None else f"{aid}-{request.client.host}"
)
logger.debug(f"thread id: {thread_id}")

debug = debug if debug is not None else config.debug
debug = debug if debug is not None else config.debug_resp

# Execute agent and get response
resp = execute_agent(aid, AgentMessageInput(text=q), thread_id, debug=debug)
Expand Down
1 change: 1 addition & 0 deletions app/models/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class Agent(SQLModel, table=True):
name: Optional[str] = Field(default=None)
model: str = Field(default="gpt-4o-mini")
prompt: Optional[str]
prompt_append: Optional[str]
# autonomous mode
autonomous_enabled: bool = Field(default=False)
autonomous_minutes: Optional[int]
Expand Down

0 comments on commit 8ea3ccf

Please sign in to comment.