Commit

converse
Ashad001 committed Jul 28, 2024
1 parent a54be20 commit 2e3d505
Showing 7 changed files with 73 additions and 22 deletions.
Binary file modified backend/chroma_db/chroma.sqlite3
95 changes: 73 additions & 22 deletions backend/src/main.py
@@ -5,13 +5,18 @@
from langchain_groq import ChatGroq
from langchain import hub
from langchain_chroma import Chroma
from langchain.prompts import ChatPromptTemplate
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.document_loaders import PyPDFLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain.chains import create_history_aware_retriever, create_retrieval_chain


# Load environment variables
load_dotenv()
@@ -21,8 +26,8 @@
MODEL_NAME = "BAAI/bge-small-en"
MODEL_KWARGS = {"device": "cpu"}
ENCODE_KWARGS = {"normalize_embeddings": True}
CHUNK_SIZE = 500
CHUNK_OVERLAP = 50
CHUNK_SIZE = 1500
CHUNK_OVERLAP = 250
DATA_DIR = "./data/files"
CHROMA_DB_DIR = "./chroma_db"
METADATA_FILE = os.path.join(CHROMA_DB_DIR, "metadata.txt")
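# Sketch (assumed wiring, not part of this commit; the actual splitter call
# site is collapsed in this diff): the enlarged CHUNK_SIZE / CHUNK_OVERLAP
# trade retrieval granularity for more context in each retrieved chunk.
#   splitter = RecursiveCharacterTextSplitter(
#       chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
#   )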
@@ -83,33 +88,79 @@ def get_vectorstore():
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)

def create_rag_chain(retriever):
memory = ConversationBufferWindowMemory(k=2, memory_key="chat_history", return_messages=True)

prompt_template = """You are a helpful AI assistant. Use the following pieces of context to answer the human's question. If you don't know the answer, just say that you don't know, don't try to make up an answer.
Context: {context}
Question: {question}
Chat History:
{chat_history}
"""
chat_prompt_template = ChatPromptTemplate.from_template(
        prompt_template, input_variables=["context", "question", "chat_history"]
    )
def create_rag_chain(retriever):
contextualize_q_system_prompt = (
"Given a chat history and the latest user question "
"which might reference context in the chat history, "
"formulate a standalone question which can be understood "
"without the chat history. Do NOT answer the question, "
"just reformulate it if needed and otherwise return it as is."
)
contextualize_q_prompt = ChatPromptTemplate.from_messages(
[
("system", contextualize_q_system_prompt),
MessagesPlaceholder("chat_history"),
("human", "{input}"),
]
)
history_aware_retriever = create_history_aware_retriever(
llm, retriever, contextualize_q_prompt
)
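    # create_history_aware_retriever chains the LLM with the prompt above to
    # rewrite the latest question into a standalone query (or pass it through
    # unchanged when there is no chat history yet), then hands the rewritten
    # query to the underlying retriever.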

return (
{"context": retriever | format_docs, "question": RunnablePassthrough()}
| chat_prompt_template
| llm
        | StrOutputParser()
    )

system_prompt = (
"You are an assistant for question-answering tasks. "
"Use the following pieces of retrieved context to answer "
"the question. If you don't know the answer, say that you "
"don't know. Use three sentences maximum and keep the "
"answer concise."
"\n\n"
"{context}"
)
qa_prompt = ChatPromptTemplate.from_messages(
[
("system", system_prompt),
MessagesPlaceholder("chat_history"),
("human", "{input}"),
]
)
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
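    # create_stuff_documents_chain formats every retrieved document into the
    # prompt's {context} slot and answers in a single LLM call.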

rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
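    # The retrieval chain returns a dict with "input", "context", and
    # "answer" keys, which is why query() below reads response['answer'].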


store = {}
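    # In-process map of session_id -> ChatMessageHistory. It lives in this
    # function's closure, so histories persist only as long as the chain
    # object returned below.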

def get_session_history(session_id: str) -> BaseChatMessageHistory:
if session_id not in store:
store[session_id] = ChatMessageHistory()
return store[session_id]


conversational_rag_chain = RunnableWithMessageHistory(
rag_chain,
get_session_history,
input_messages_key="input",
history_messages_key="chat_history",
output_messages_key="answer",
)
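    # RunnableWithMessageHistory loads the session's messages into
    # "chat_history" before each call and appends the new "input" / "answer"
    # pair to that history afterwards.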

return conversational_rag_chain

def setup_vectorstore():
return get_vectorstore().as_retriever()

def query(retriever, question):
rag_chain = create_rag_chain(retriever)
response = rag_chain.invoke(question)
return response
response = rag_chain.invoke(
{"input": question},
        config={
"configurable": {
"session_id": "scorp123"
},
}
)
return response['answer']
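
# A hypothetical two-turn session (illustrative only). Both calls share the
# hardcoded "scorp123" session id, but each query() call rebuilds the chain,
# and with it the closure-scoped store, so earlier turns are not visible to
# later ones:
#   query(retriever, "What does the uploaded PDF conclude?")
#   query(retriever, "Summarize that conclusion in one sentence.")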

def main():
retriever = setup_vectorstore()
