Skip to content

Commit

Permalink
fix: RAG with documents
Browse files Browse the repository at this point in the history
Summary:
This was broken by https://github.com/meta-llama/llama-stack/pull/1015/files#r1975394190

Test Plan:

added e2e test
  • Loading branch information
ehhuang committed Mar 1, 2025
1 parent 6fa257b commit b52eefb
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,11 @@ async def _run(
if documents:
await self.handle_documents(session_id, documents, input_messages, tool_defs)

session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it
if session_info and session_info.vector_db_id:
toolgroup_args[RAG_TOOL_GROUP]["vector_db_ids"].append(session_info.vector_db_id)

output_attachments = []

n_iter = 0
Expand Down
61 changes: 61 additions & 0 deletions tests/client-sdk/agents/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,67 @@ def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
assert expected_kw in response.output_message.content.lower()


def test_rag_agent_with_attachments(llama_stack_client, agent_config):
urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
documents = [
Document(
document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain",
metadata={},
)
for i, url in enumerate(urls)
]
agent_config = {
**agent_config,
"toolgroups": [
dict(
name="builtin::rag/knowledge_search",
args={
"vector_db_ids": [],
},
)
],
}
rag_agent = Agent(llama_stack_client, agent_config)
session_id = rag_agent.create_session(f"test-session-{uuid4()}")
user_prompts = [
(
"Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
"grouped",
),
]
user_prompts = [
(
"I am attaching some documentation for Torchtune. Help me answer questions I will ask next.",
documents,
),
(
"Tell me how to use LoRA",
None,
),
]

for prompt in user_prompts:
response = rag_agent.create_turn(
messages=[
{
"role": "user",
"content": prompt[0],
}
],
documents=prompt[1],
session_id=session_id,
stream=False,
)

# rag is called
tool_execution_step = [step for step in response.steps if step.step_type == "tool_execution"]
assert len(tool_execution_step) >= 1
assert tool_execution_step[0].tool_calls[0].tool_name == "knowledge_search"
assert "lora_rank" in response.output_message.content.lower()


def test_rag_and_code_agent(llama_stack_client, agent_config):
documents = []
documents.append(
Expand Down

0 comments on commit b52eefb

Please sign in to comment.