Skip to content

Commit

Permalink
#1 - refactor rag_chain invoke
Browse files Browse the repository at this point in the history
  • Loading branch information
obriensystems committed Sep 2, 2024
1 parent 035a47e commit 86f63fc
Showing 1 changed file with 29 additions and 9 deletions.
38 changes: 29 additions & 9 deletions src/rag/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
logging.basicConfig(level=logging.INFO)

def set_environment_variables():
os.environ["OPENAI_API_KEY"] = ""
os.environ["LANGCHAIN_API_KEY"] = ""
logging.basicConfig("getting OPENAI_API_KEY and LANGCHAIN_API_KEY env variables")
# os.environ["OPENAI_API_KEY"] = ""
# os.environ["LANGCHAIN_API_KEY"] = ""


def initialize_llm():
return ChatOpenAI(model="gpt-4o-mini")
Expand Down Expand Up @@ -51,6 +53,27 @@ def retrieve_documents(vectorstore, query, k=6):
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)


## refactor out rag_chain setup to allow for a non-local retriever from the vectorstore
def setup_rag_chain(retriever, prompt, llm):
"""
Sets up the RAG (Retrieval-Augmented Generation) chain.
Parameters:
- retriever: The retriever object for document retrieval.
- prompt: The prompt object for generating responses.
- llm: The language model object.
Returns:
- The configured RAG chain.
"""
return (
{"context": retriever | format_docs, "question": RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)

def main():
set_environment_variables()
llm = initialize_llm()
Expand All @@ -66,6 +89,7 @@ def main():
logging.info(f"Metadata size: {len(splits[10].metadata)}")

vectorstore = create_vectorstore(splits)
# retrieved_docs are not used
retrieved_docs = retrieve_documents(vectorstore, "What are the approaches to Task Decomposition?")
logging.info(f"Vectorstore retrieved: {len(retrieved_docs)}")

Expand All @@ -75,13 +99,9 @@ def main():
).to_messages()
logging.info(f"Example messages: {example_messages[0].content}")

# retriever needs to be moved up from function scope
rag_chain = (
{"context": retriever | format_docs, "question": RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
# unroll retriever
retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 6})
rag_chain = setup_rag_chain(retriever, prompt, llm)

# Uncomment to use the RAG chain
# for chunk in rag_chain.stream("What is Task Decomposition?"):
Expand Down

0 comments on commit 86f63fc

Please sign in to comment.