From b52eefb0e0f57a44733fc23c23649c6579d5d5d8 Mon Sep 17 00:00:00 2001
From: Eric Huang
Date: Fri, 28 Feb 2025 16:11:20 -0800
Subject: [PATCH] fix: RAG with documents

Summary:

This was broken by https://github.com/meta-llama/llama-stack/pull/1015/files#r1975394190

Test Plan:

added e2e test
---
 .../agents/meta_reference/agent_instance.py   |  5 ++
 tests/client-sdk/agents/test_agents.py        | 62 +++++++++++++++++++
 2 files changed, 67 insertions(+)

diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index 4d0d8ed458..be7a6bf3c2 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -533,6 +533,11 @@ async def _run(
         if documents:
             await self.handle_documents(session_id, documents, input_messages, tool_defs)
 
+        session_info = await self.storage.get_session_info(session_id)
+        # if the session has a memory bank id, let the memory tool use it
+        if session_info and session_info.vector_db_id:
+            toolgroup_args[RAG_TOOL_GROUP]["vector_db_ids"].append(session_info.vector_db_id)
+
         output_attachments = []
 
         n_iter = 0
diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index 8f68699b28..f2c6a328ff 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -501,6 +501,68 @@ def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
         assert expected_kw in response.output_message.content.lower()
 
 
+def test_rag_agent_with_attachments(llama_stack_client, agent_config):
+    """Attach documents in one turn, then ask a question that needs them in the next.
+
+    The RAG toolgroup is created with an empty ``vector_db_ids`` list, so this test
+    relies on the agent adding the session's vector DB to the knowledge_search tool
+    arguments before the second turn is answered.
+    """
+    urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
+    documents = [
+        Document(
+            document_id=f"num-{i}",
+            content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
+            mime_type="text/plain",
+            metadata={},
+        )
+        for i, url in enumerate(urls)
+    ]
+    agent_config = {
+        **agent_config,
+        "toolgroups": [
+            dict(
+                name="builtin::rag/knowledge_search",
+                args={
+                    "vector_db_ids": [],
+                },
+            )
+        ],
+    }
+    rag_agent = Agent(llama_stack_client, agent_config)
+    session_id = rag_agent.create_session(f"test-session-{uuid4()}")
+    # each entry is (message content, documents attached to that turn)
+    user_prompts = [
+        (
+            "I am attaching some documentation for Torchtune. Help me answer questions I will ask next.",
+            documents,
+        ),
+        (
+            "Tell me how to use LoRA",
+            None,
+        ),
+    ]
+
+    for content, attachments in user_prompts:
+        response = rag_agent.create_turn(
+            messages=[
+                {
+                    "role": "user",
+                    "content": content,
+                }
+            ],
+            documents=attachments,
+            session_id=session_id,
+            stream=False,
+        )
+
+    # rag is called (checked on the response to the final turn only)
+    tool_execution_step = [step for step in response.steps if step.step_type == "tool_execution"]
+    assert len(tool_execution_step) >= 1
+    assert tool_execution_step[0].tool_calls[0].tool_name == "knowledge_search"
+    assert "lora_rank" in response.output_message.content.lower()
+
+
 def test_rag_and_code_agent(llama_stack_client, agent_config):
     documents = []
     documents.append(