Adds Chat History from langchain-redis
bsbodden committed Sep 29, 2024
1 parent e77be32 commit 01b70c9
Showing 4 changed files with 201 additions and 149 deletions.
233 changes: 143 additions & 90 deletions demos/chat_with_pdf.py
@@ -5,23 +5,25 @@

 import gradio as gr
 from dotenv import load_dotenv
-from langchain.chains import RetrievalQA, create_retrieval_chain
-from langchain.chains.combine_documents import create_stuff_documents_chain
-from langchain.prompts import PromptTemplate
+from gradio_modal import Modal
+from langchain.chains import RetrievalQA
+from langchain.memory import ConversationBufferMemory
 from langchain_community.callbacks import get_openai_callback
-from langchain_core.runnables import RunnableSequence
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.runnables import RunnablePassthrough, RunnableSequence
 from langchain_openai import ChatOpenAI, OpenAI, OpenAIEmbeddings
-from langchain_redis import RedisVectorStore
+from langchain_redis import RedisChatMessageHistory, RedisVectorStore
 from ragas.integrations.langchain import EvaluatorChain
 from ragas.metrics import answer_relevancy, faithfulness
 from redisvl.extensions.llmcache import SemanticCache
 from redisvl.utils.rerank import CohereReranker, HFCrossEncoderReranker

 from shared_components.cached_llm import CachedLLM
+from shared_components.llm_utils import openai_models
 from shared_components.pdf_utils import (process_file, render_file,
                                          render_first_page)
 from shared_components.theme_management import load_theme
-from shared_components.llm_utils import openai_models

 load_dotenv()

@@ -103,6 +105,9 @@ def __init__(self) -> None:
         # LLM settings
         self.llm_temperature = 0.7

+        # Initialize LLM
+        self.update_llm()
+
         # Initialize RAGAS evaluator chains
         self.faithfulness_chain = EvaluatorChain(metric=faithfulness)
         self.answer_rel_chain = EvaluatorChain(metric=answer_relevancy)
@@ -113,6 +118,12 @@ def __init__(self) -> None:
             redis_url=self.redis_url,
             distance_threshold=self.distance_threshold,
         )
+
+        # Chat History
+        self.use_chat_history = False
+        self.chat_history = None
+        print(f"DEBUG: Initial chat history state - use_chat_history: {self.use_chat_history}, chat_history: {self.chat_history}")
+
         self.ensure_index_created()

     def ensure_index_created(self):
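
For context, a minimal sketch of the langchain-redis chat history API this commit starts using. The method and attribute names (add_user_message, add_ai_message, messages, clear) are the ones exercised in the diff; the session id and Redis URL are illustrative:

```python
# Minimal sketch of RedisChatMessageHistory as used in this commit.
# The session_id and redis_url values are illustrative placeholders.
from langchain_redis import RedisChatMessageHistory

history = RedisChatMessageHistory(
    session_id="chat_with_pdf",          # same id the diff uses
    redis_url="redis://localhost:6379",  # assumed local Redis
)

history.add_user_message("What does the PDF say about indexing?")
history.add_ai_message("It covers HNSW and FLAT index types.")

print(len(history.messages))  # messages is a list of BaseMessage objects
history.clear()               # drops the stored conversation
```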
@@ -128,9 +139,7 @@ def __call__(self, file: str, chunk_size: int, chunking_technique: str) -> Any:
         return self.chain

     def build_chain(self, file: str):
-        print(
-            f"DEBUG: Starting build_chain for file: {file.name} with chunk size: {self.chunk_size} and {self.chunking_technique}"
-        )
+        print(f"DEBUG: Starting build_chain for file: {file.name} with chunk size: {self.chunk_size} and {self.chunking_technique}")
         documents, file_name = process_file(
             file, self.chunk_size, self.chunking_technique
         )
@@ -139,52 +148,92 @@ def build_chain(self, file: str):
         ).rstrip("_")

         print(f"DEBUG: Creating vector store with index name: {index_name}")
+        # Load embeddings model
         embeddings = OpenAIEmbeddings(api_key=self.openai_api_key)
-        vector_store = RedisVectorStore.from_documents(
+        self.vector_store = RedisVectorStore.from_documents(
             documents,
             embeddings,
             redis_url=self.redis_url,
             index_name=index_name,
         )

         # Create the retriever with the initial top_k value
-        self.vector_store = vector_store.as_retriever(search_kwargs={"k": self.top_k})
+        retriever = self.vector_store.as_retriever(search_kwargs={"k": self.top_k})

         # Configure the LLM
-        self.llm = ChatOpenAI(
-            model=self.selected_model,
-            temperature=self.llm_temperature,
-            api_key=self.openai_api_key,
-        )
+        self.update_llm()

-        # Create a prompt template
-        prompt = PromptTemplate.from_template(
-            """You are a helpful assistant. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
-
-Context: {context}
-
-Question: {input}
-Helpful Answer:"""
-        )
-
-        # Create the retrieval chain
-        self.qa_chain = RetrievalQA.from_chain_type(
-            self.cached_llm,  # Use cached_llm instead of llm
-            retriever=self.vector_store,
-            return_source_documents=True,
-        )
-
-        return self.qa_chain
+        # Create a formatting function for documents
+        def format_docs(docs):
+            return "\n\n".join(doc.page_content for doc in docs)
+
+        # Create a custom prompt template
+        prompt = ChatPromptTemplate.from_messages([
+            ("system", "You are a helpful AI assistant. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer."),
+            ("human", "Context: {context}\n\nQuestion: {question}"),
+            ("human", "Helpful Answer:")
+        ])
+
+        # Create the RAG chain
+        rag_chain = (
+            {
+                "context": retriever | format_docs,
+                "question": RunnablePassthrough()
+            }
+            | prompt
+            | self.cached_llm
+            | StrOutputParser()
+        )
+
+        print("DEBUG: RAG chain created successfully")
+        return rag_chain
+
+    def update_chat_history(self, use_chat_history: bool):
+        print(f"DEBUG: Updating chat history. use_chat_history: {use_chat_history}")
+        self.use_chat_history = use_chat_history
+        if self.use_chat_history:
+            if self.chat_history is None:
+                self.chat_history = RedisChatMessageHistory(session_id="chat_with_pdf", redis_url=self.redis_url)
+                print("DEBUG: Created new RedisChatMessageHistory")
+            else:
+                print("DEBUG: Using existing RedisChatMessageHistory")
+            print(f"DEBUG: Current chat history length: {len(self.chat_history.messages)}")
+        else:
+            if self.chat_history:
+                print(f"DEBUG: Clearing chat history. Current length: {len(self.chat_history.messages)}")
+                self.chat_history.clear()
+                print("DEBUG: Cleared existing chat history")
+            self.chat_history = None
+        print(f"DEBUG: Chat history setting updated to {self.use_chat_history}")
+        return self.use_chat_history
+
+    def get_chat_history(self):
+        if self.chat_history and self.use_chat_history:
+            messages = self.chat_history.messages
+            print(f"DEBUG: Retrieved {len(messages)} messages from chat history")
+            formatted_history = []
+            for msg in messages:
+                if msg.type == 'human':
+                    formatted_history.append(f"👤 **Human**: {msg.content}\n")
+                elif msg.type == 'ai':
+                    formatted_history.append(f"🤖 **AI**: {msg.content}\n")
+            return "\n".join(formatted_history)
+        return "No chat history available."
+
+    def update_llm(self):
+        print("DEBUG: Updating LLM")
+        if self.llm is None:
+            print("DEBUG: self.llm is None, initializing new LLM")
+            self.llm = ChatOpenAI(
+                model=self.selected_model,
+                temperature=self.llm_temperature,
+                api_key=self.openai_api_key,
+            )
+
+        if self.use_semantic_cache:
+            print("DEBUG: Using semantic cache")
+            self.cached_llm = CachedLLM(self.llm, self.llmcache)
+        else:
+            print("DEBUG: Not using semantic cache")
+            self.cached_llm = self.llm
+
+        print(f"DEBUG: Updated LLM type: {type(self.cached_llm)}")

     def update_chain(self):
         if self.vector_store:
             self.qa_chain = RetrievalQA.from_chain_type(
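
build_chain now returns a plain LCEL pipeline instead of a RetrievalQA chain, so callers pass the question string directly and get a string back (see chain.invoke(query) further down). A hedged sketch of the new call shape, with a hypothetical file handle:

```python
# Sketch: the chain built above maps a question string straight to an
# answer string ({context} is filled by retriever | format_docs and
# StrOutputParser unwraps the chat message). uploaded_pdf is hypothetical.
chain = app.build_chain(uploaded_pdf)
answer = chain.invoke("What is this document about?")
print(answer)  # a plain str, not a {"result": ...} dict
```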
@@ -300,7 +349,7 @@ def evaluate_response(self, query, result):
         eval_input = {
             "question": query,
             "answer": result["result"],
-            "contexts": [doc.page_content for doc in result["source_documents"]],
+            "contexts": [],
         }

         try:
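
With the chain returning a bare string, no source documents reach the evaluator any more, hence the empty contexts list. For reference, a sketch of how these evaluator chains are typically invoked, assuming the call convention of ragas's LangChain integration (a dict in, a dict keyed by metric name out):

```python
# Sketch: invoking the RAGAS evaluator chains built in __init__.
# Assumes EvaluatorChain is callable with this dict and returns a dict
# keyed by the metric name ("faithfulness" / "answer_relevancy").
eval_input = {
    "question": "What does chapter 2 cover?",        # illustrative
    "answer": "Chapter 2 covers vector indexing.",   # illustrative
    "contexts": [],  # empty: faithfulness has nothing to ground against
}
faithfulness_score = app.faithfulness_chain(eval_input)["faithfulness"]
relevancy_score = app.answer_rel_chain(eval_input)["answer_relevancy"]
```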
@@ -335,6 +384,7 @@ def get_response(
     top_k,
     llm_model,
     llm_temperature,
+    use_chat_history,
 ):
     if not file:
         raise gr.Error(message="Upload a PDF")
Expand All @@ -349,37 +399,16 @@ def get_response(
if app.llm_temperature != llm_temperature:
app.update_temperature(llm_temperature)

# Check if the semantic cache setting has changed
if app.use_semantic_cache != use_semantic_cache:
app.update_semantic_cache(use_semantic_cache)
app.chain = app(file, app.chunk_size, app.chunking_technique) # Rebuild the chain
app.use_semantic_cache = use_semantic_cache

app.use_reranker = use_reranker
app.reranker_type = reranker_type

chain = app.chain
start_time = time.time()

print(f"DEBUG: Invoking chain with query: {query}")
print(f"DEBUG: use_chat_history: {app.use_chat_history}")

with get_openai_callback() as cb:
result = chain.invoke({"query": query})
result = chain.invoke(query)
end_time = time.time()

# Apply re-ranking if enabled
rerank_info = None
if app.use_reranker:
print(f"DEBUG: Reranking with {reranker_type}")
reranked_docs, rerank_info, original_results = app.rerank_results(
query, result["source_documents"]
)
if reranked_docs:
result["source_documents"] = reranked_docs
else:
print("DEBUG: Re-ranking produced no results")
else:
print("DEBUG: Re-ranking skipped")

is_cache_hit = app.get_last_cache_status()

if not is_cache_hit:
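
The cache-hit check above leans on the CachedLLM wrapper from shared_components, whose implementation is not part of this diff. A plausible sketch of such a wrapper over redisvl's SemanticCache follows; check and store match redisvl's documented API, everything else is assumed:

```python
# Hypothetical sketch of a semantic-cache wrapper; the real CachedLLM
# lives in shared_components/cached_llm.py and is not shown in this diff.
class SketchCachedLLM:
    def __init__(self, llm, llmcache):
        self.llm = llm            # e.g. ChatOpenAI
        self.llmcache = llmcache  # redisvl SemanticCache
        self._last_hit = False    # what get_last_cache_status() might report

    def invoke(self, prompt: str) -> str:
        hits = self.llmcache.check(prompt=prompt)  # vector-similarity lookup
        self._last_hit = bool(hits)
        if hits:
            return hits[0]["response"]             # cache hit: skip the LLM call
        response = self.llm.invoke(prompt).content
        self.llmcache.store(prompt=prompt, response=response)
        return response
```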
Expand All @@ -397,47 +426,33 @@ def get_response(

elapsed_time = end_time - start_time

answer = result["result"]
app.chat_history += [(query, answer)]

# Prepare reranking feedback
rerank_feedback = ""
if rerank_info:
original_order = rerank_info["original_order"]
reranked_order = rerank_info["reranked_order"]
reranked_scores = rerank_info["reranked_scores"]

# Check if the order changed
order_changed = original_order != reranked_order
answer = result # The result is now directly the answer string

if order_changed:
rerank_feedback = (
f"ReRanking changed document order. Top score: {reranked_scores[0]:.4f}"
)
else:
rerank_feedback = "ReRanking did not change document order."
if app.use_chat_history and app.chat_history is not None:
app.chat_history.add_user_message(query)
app.chat_history.add_ai_message(answer)
print(f"DEBUG: Added to chat history. Current length: {len(app.chat_history.messages)}")
print(f"DEBUG: Last message in history: {app.chat_history.messages[-1].content[:50]}...")
else:
print("DEBUG: Chat history not updated (disabled or None)")

# Prepare the initial output without RAGAS evaluation
# Prepare the output
if is_cache_hit:
initial_output = f"⏱️ | Cache: {elapsed_time:.2f} SEC | COST $0.00 \n\n{rerank_feedback}\n\nRAGAS Evaluation: In progress..."
output = f"⏱️ | Cache: {elapsed_time:.2f} SEC | COST $0.00"
else:
tokens_per_sec = num_tokens / elapsed_time if elapsed_time > 0 else 0
initial_output = f"⏱️ | LLM: {elapsed_time:.2f} SEC | {tokens_per_sec:.2f} TOKENS/SEC | {num_tokens} TOKENS | COST ${total_cost:.4f}\n\n{rerank_feedback}\n\nRAGAS Evaluation: In progress..."
output = f"⏱️ | LLM: {elapsed_time:.2f} SEC | {tokens_per_sec:.2f} TOKENS/SEC | {num_tokens} TOKENS | COST ${total_cost:.4f}"

# Yield the response and initial output
# Yield the response and output
for char in answer:
history[-1][-1] += char
yield history, "", initial_output
yield history, "", output

# Perform RAGAS evaluation after yielding the response
feedback = perform_ragas_evaluation(query, result)
feedback = perform_ragas_evaluation(query, {"result": answer})

# Prepare the final output with RAGAS evaluation
if is_cache_hit:
final_output = f"⏱️ | Cache: {elapsed_time:.2f} SEC | COST $0.00 \n\n{rerank_feedback}\n\n{feedback}"
else:
tokens_per_sec = num_tokens / elapsed_time if elapsed_time > 0 else 0
final_output = f"⏱️ | LLM: {elapsed_time:.2f} SEC | {tokens_per_sec:.2f} TOKENS/SEC | {num_tokens} TOKENS | COST ${total_cost:.4f}\n\n{rerank_feedback}\n\n{feedback}"
final_output = f"{output}\n\n{feedback}"

# Yield one last time to update with RAGAS evaluation results
yield history, "", final_output
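
History is appended after the fact rather than threaded through the chain, so earlier turns do not influence the answer. If they should, LangChain's RunnableWithMessageHistory could wrap the same chain. A sketch of that alternative wiring (not what this commit does), assuming the prompt gains a MessagesPlaceholder and the chain accepts dict input:

```python
# Alternative wiring (not used in this commit): let LangChain replay the
# Redis-backed history into the prompt instead of appending manually.
# Assumes the prompt is extended with MessagesPlaceholder("history") and
# the chain is reshaped to take {"question": ...} dict input.
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_redis import RedisChatMessageHistory

chain_with_history = RunnableWithMessageHistory(
    rag_chain,  # the LCEL chain from build_chain, reshaped for dict input
    lambda session_id: RedisChatMessageHistory(
        session_id=session_id, redis_url=app.redis_url
    ),
    input_messages_key="question",
    history_messages_key="history",
)
answer = chain_with_history.invoke(
    {"question": query},
    config={"configurable": {"session_id": "chat_with_pdf"}},
)
```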
Expand All @@ -464,6 +479,24 @@ def render_first(file, chunk_size, chunking_technique):

return image, []

# Connect the show_history_btn to the display_chat_history function and show the modal
def show_history():
print(f"DEBUG: show_history called. use_chat_history: {app.use_chat_history}, chat_history: {app.chat_history}")
if app.use_chat_history and app.chat_history is not None:
messages = app.chat_history.messages
print(f"DEBUG: Retrieved {len(messages)} messages from chat history")
formatted_history = []
for msg in messages:
if msg.type == 'human':
formatted_history.append(f"👤 **Human**: {msg.content}\n")
elif msg.type == 'ai':
formatted_history.append(f"🤖 **AI**: {msg.content}\n")
history = "\n".join(formatted_history)
else:
history = "No chat history available."

print(f"DEBUG: Formatted chat history: {history[:100]}...")
return history, gr.update(visible=True)

def reset_app():
app.chat_history = []
@@ -526,7 +559,7 @@ def reset_app():
                 maximum=1.0,
                 value=app.distance_threshold,
                 step=0.01,
-                label="Semantic Cache Distance Threshold",
+                label="Distance Threshold",
             )

             with gr.Row():
Expand All @@ -538,6 +571,10 @@ def reset_app():
interactive=True,
)

with gr.Row():
use_chat_history = gr.Checkbox(label="Use Chat History", value=app.use_chat_history)
show_history_btn = gr.Button("Show Chat History")

# Right Half
with gr.Column(scale=6):
show_img = gr.Image(label="Uploaded PDF")
@@ -565,6 +602,11 @@ def reset_app():
             )
             reset_btn = gr.Button("Reset", elem_id="reset-btn")

+    # Add Modal for chat history here, outside of any column or row
+    with Modal(visible=False) as history_modal:
+        # gr.Markdown("Chat History")
+        history_display = gr.Markdown("No chat history available.")
+
     btn.upload(
         fn=render_first,
         inputs=[btn, chunk_size, chunking_technique],
@@ -600,3 +642,14 @@ def reset_app():
         inputs=None,
         outputs=[chatbot, show_img, txt, feedback_markdown],
     )
+
+    use_chat_history.change(
+        fn=app.update_chat_history,
+        inputs=[use_chat_history],
+        outputs=[]
+    )
+
+    show_history_btn.click(
+        fn=show_history,
+        outputs=[history_display, history_modal]
+    )
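
For reference, the modal wiring above in self-contained form: a button handler returns both the Markdown content and gr.update(visible=True) to reveal the gradio_modal Modal. Component names here are illustrative:

```python
# Minimal standalone sketch of the show-history wiring used above.
import gradio as gr
from gradio_modal import Modal

def load_history():
    text = "👤 **Human**: hello\n\n🤖 **AI**: hi there"  # placeholder content
    return text, gr.update(visible=True)                # second output opens the modal

with gr.Blocks() as demo:
    open_btn = gr.Button("Show Chat History")
    with Modal(visible=False) as modal:
        body = gr.Markdown("No chat history available.")
    open_btn.click(fn=load_history, outputs=[body, modal])

demo.launch()
```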