Feature: Add feedback reaction button and add chat history v2 #112

Merged 3 commits on Dec 21, 2024
7 changes: 4 additions & 3 deletions frontend/requirements-test.txt
@@ -1,7 +1,7 @@
streamlit==1.37.0
streamlit==1.40.2
requests==2.32.3
requests-oauthlib==2.0.0
Pillow==10.3.0
Pillow==11.0.0
pytz==2024.1
google-auth==2.30.0
google-auth-httplib2==0.2.0
@@ -13,4 +13,5 @@ flask==3.0.3
types-pytz==2024.1.0.20240417
types-requests==2.32.0.20240622
pre-commit==3.7.1
ruff==0.5.1
ruff==0.5.1
mypy==1.10.1
4 changes: 2 additions & 2 deletions frontend/requirements.txt
@@ -1,7 +1,7 @@
streamlit==1.37.0
streamlit==1.40.2
requests==2.32.3
requests-oauthlib==2.0.0
Pillow==10.3.0
Pillow==11.0.0
pytz==2024.1
google-auth==2.30.0
google-auth-httplib2==0.2.0
209 changes: 141 additions & 68 deletions frontend/streamlit_app.py
@@ -6,7 +6,11 @@
import os
import ast
from PIL import Image
from utils.feedback import show_feedback_form
from utils.feedback import (
show_feedback_form,
submit_feedback_to_google_sheet,
get_git_commit_hash,
)
from dotenv import load_dotenv
from typing import Callable, Any

@@ -22,6 +26,24 @@ def wrapper(*args: Any, **kwargs: Any) -> tuple[Any, float]:
return wrapper


def translate_chat_history_to_api(chat_history, max_pairs=4):
api_format = []
relevant_history = [
msg for msg in chat_history[1:] if msg["role"] in ["user", "ai"]
]

i = len(relevant_history) - 1
while i > 0 and len(api_format) < max_pairs:
ai_msg = relevant_history[i]
user_msg = relevant_history[i - 1]
if ai_msg["role"] == "ai" and user_msg["role"] == "user":
api_format.insert(0, {"User": user_msg["content"], "AI": ai_msg["content"]})
i -= 2
else:
i -= 1
return api_format
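
Reviewer note: a minimal sketch (not part of the diff) of how the new translate_chat_history_to_api helper reshapes the Streamlit session history into the chat_history field sent with each request. The conversation contents and the import path are illustrative assumptions; it presumes the helper above is importable, e.g. with frontend/ on PYTHONPATH, or pasted into the snippet.

from streamlit_app import translate_chat_history_to_api  # helper added in this PR; or paste it into scope

# Session history as streamlit_app.py builds it: a greeting first, then
# alternating user/ai turns (contents are illustrative only).
chat_history = [
    {"role": "ai", "content": "How can I help you with OpenROAD?"},
    {"role": "user", "content": "What is OpenROAD?"},
    {"role": "ai", "content": "An open-source RTL-to-GDSII flow."},
    {"role": "user", "content": "How do I run floorplanning?"},
    {"role": "ai", "content": "Use the initialize_floorplan command."},
]

# Keeps only the newest complete user/AI pairs (up to max_pairs), oldest first:
# [{'User': 'What is OpenROAD?', 'AI': 'An open-source RTL-to-GDSII flow.'},
#  {'User': 'How do I run floorplanning?', 'AI': 'Use the initialize_floorplan command.'}]
print(translate_chat_history_to_api(chat_history))

# response_generator() then sends the pairs alongside the new query:
payload = {
    "query": "How do I run placement?",
    "list_sources": True,
    "list_context": True,
    "chat_history": translate_chat_history_to_api(chat_history),
}

Capping at max_pairs=4 keeps the request payload bounded while still passing recent conversational context to the endpoint.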


@measure_response_time
def response_generator(user_input: str) -> tuple[str, str] | tuple[None, None]:
"""
@@ -34,74 +56,47 @@ def response_generator(user_input: str) -> tuple[str, str] | tuple[None, None]:
- tuple: Contains the AI response and sources.
"""
url = f"{st.session_state.base_url}{st.session_state.selected_endpoint}"

headers = {"accept": "application/json", "Content-Type": "application/json"}

payload = {"query": user_input, "list_sources": True, "list_context": True}

chat_history = translate_chat_history_to_api(st.session_state.chat_history)
payload = {
"query": user_input,
"list_sources": True,
"list_context": True,
"chat_history": chat_history,
}
try:
response = requests.post(url, headers=headers, json=payload)
response.raise_for_status()

try:
data = response.json()
if not isinstance(data, dict):
st.error("Invalid response format")
return None, None
except ValueError:
st.error("Failed to decode JSON response")
data = response.json()
if not isinstance(data, dict):
st.error("Invalid response format")
return None, None

sources = data.get("sources", "")
st.session_state.metadata[user_input] = {
"sources": sources,
"context": data.get("context", ""),
}

return data.get("response", ""), sources

except requests.exceptions.RequestException as e:
st.error(f"Request failed: {e}")
return None, None


def fetch_endpoints() -> tuple[str, list[str]]:
base_url = os.getenv("CHAT_ENDPOINT", "http://localhost:8000")
url = f"{base_url}/chains/listAll"
try:
response = requests.get(url)
response.raise_for_status()
endpoints = response.json()
return base_url, endpoints
except requests.exceptions.RequestException as e:
st.error(f"Failed to fetch endpoints: {e}")
return base_url, []


def main() -> None:
load_dotenv()

img = Image.open("assets/or_logo.png")
st.set_page_config(page_title="OR Assistant", page_icon=img)

deployment_time = datetime.datetime.now(pytz.timezone("UTC"))
st.info(f"Deployment time: {deployment_time.strftime('%m/%d/%Y %H:%M')} UTC")
st.info(f'Deployment time: {deployment_time.strftime("%m/%d/%Y %H:%M")} UTC')

st.title("OR Assistant")

base_url, endpoints = fetch_endpoints()

selected_endpoint = st.selectbox(
"Select preferred endpoint",
options=endpoints,
index=0,
format_func=lambda x: x.split("/")[-1].capitalize(),
)
base_url = os.getenv("CHAT_ENDPOINT", "http://localhost:8000")
selected_endpoint = "/graphs/agent-retriever"

if "selected_endpoint" not in st.session_state:
st.session_state.selected_endpoint = selected_endpoint
else:
st.session_state.selected_endpoint = selected_endpoint

if "base_url" not in st.session_state:
st.session_state.base_url = base_url
@@ -115,6 +110,8 @@ def main() -> None:
st.session_state.chat_history = []
if "metadata" not in st.session_state:
st.session_state.metadata = {}
if "sources" not in st.session_state:
st.session_state.sources = {}

if not st.session_state.chat_history:
st.session_state.chat_history.append(
@@ -124,10 +121,42 @@
}
)

for message in st.session_state.chat_history:
for idx, message in enumerate(st.session_state.chat_history):
with st.chat_message(message["role"]):
st.markdown(message["content"])

if message["role"] == "ai" and idx > 0:
user_message = st.session_state.chat_history[idx - 1]
if user_message["role"] == "user":
user_input = user_message["content"]
sources = st.session_state.sources.get(user_input)
with st.expander("Sources:"):
try:
if sources:
if isinstance(sources, str):
cleaned_sources = sources.replace("{", "[").replace(
"}", "]"
)
parsed_sources = ast.literal_eval(cleaned_sources)
else:
parsed_sources = sources
if (
isinstance(parsed_sources, (list, set))
and parsed_sources
):
sources_list = "\n".join(
f"- [{link}]({link})"
for link in parsed_sources
if link.strip()
)
st.markdown(sources_list)
else:
st.markdown("No Sources Attached.")
else:
st.markdown("No Sources Attached.")
except (ValueError, SyntaxError) as e:
st.markdown(f"Failed to parse sources: {e}")
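Reviewer note: the sources value is parsed identically here and in the live-response branch further below — a brace-wrapped string is rewritten to list syntax and then evaluated with ast.literal_eval. A small standalone sketch of that round trip, using a made-up sources string:

import ast

sources = '{"https://openroad.readthedocs.io/", "https://example.com/faq"}'  # illustrative value
cleaned = sources.replace("{", "[").replace("}", "]")  # '{"a", "b"}' -> '["a", "b"]'
parsed = ast.literal_eval(cleaned)                     # -> list of URL strings
links = "\n".join(f"- [{link}]({link})" for link in parsed if link.strip())
print(links)  # the markdown bullet list that st.markdown renders in the expander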

user_input = st.chat_input("Enter your queries ...")

if user_input:
@@ -146,62 +175,69 @@
):
response, sources = response_tuple
if response is not None:
response_buffer = ""
response_buffer = response

with st.chat_message("ai"):
message_placeholder = st.empty()

response_buffer = ""
for chunk in response.split(" "):
response_buffer += chunk + " "
if chunk.endswith("\n"):
response_buffer += " "
message_placeholder.markdown(response_buffer)
time.sleep(0.05)

message_placeholder.markdown(response_buffer)

# Display response time
response_time_text = (
f"Response Time: {response_time / 1000:.2f} seconds"
)
response_time_colored = f":{'green' if response_time < 5000 else 'orange' if response_time < 10000 else 'red'}[{response_time_text}]"
st.markdown(response_time_colored)
if response_time < 5000:
color = "green"
elif response_time < 10000:
color = "orange"
else:
color = "red"
st.markdown(f":{color}[{response_time_text}]")

st.session_state.chat_history.append(
{
"content": response_buffer,
"role": "ai",
}
)

if sources:
with st.expander("Sources:"):
try:
st.session_state.sources[user_input] = sources

with st.expander("Sources:"):
try:
if sources:
if isinstance(sources, str):
cleaned_sources = sources.replace("{", "[").replace(
"}", "]"
)
parsed_sources = ast.literal_eval(cleaned_sources)
else:
parsed_sources = sources
if isinstance(parsed_sources, (list, set)):
if (
isinstance(parsed_sources, (list, set))
and parsed_sources
):
sources_list = "\n".join(
f"- [{link}]({link})"
for link in parsed_sources
if link.strip()
)
st.markdown(sources_list)
else:
st.markdown("No valid sources found.")
except (ValueError, SyntaxError) as e:
st.markdown(f"Failed to parse sources: {e}")
else:
st.error("Invalid response from the API")

st.markdown("No Sources Attached.")
else:
st.markdown("No Sources Attached.")
except (ValueError, SyntaxError) as e:
st.markdown(f"Failed to parse sources: {e}")
else:
st.error("Invalid response from the API")

# Reaction buttons and feedback form
question_dict = {
interaction["content"]: i
for i, interaction in enumerate(st.session_state.chat_history)
if interaction["role"] == "user"
}

if question_dict and os.getenv("FEEDBACK_SHEET_ID"):
if "feedback_button" not in st.session_state:
st.session_state.feedback_button = False
@@ -212,10 +248,47 @@ def update_state() -> None:
"""
st.session_state.feedback_button = True

if (
st.button("Feedback", on_click=update_state)
or st.session_state.feedback_button
):
# Display reaction buttons
col1, col2, col3 = st.columns([1, 1, 2])
with col1:
thumbs_up = st.button("👍", key="thumbs_up")
with col2:
thumbs_down = st.button("👎", key="thumbs_down")
with col3:
feedback_clicked = st.button("Feedback", on_click=update_state)

# Handle thumbs up and thumbs down reactions
if thumbs_up or thumbs_down:
try:
selected_question = st.session_state.chat_history[-2][
"content"
] # Last user question
gen_ans = st.session_state.chat_history[-1][
"content"
] # Last AI response
sources = st.session_state.metadata.get(selected_question, {}).get(
"sources", ["N/A"]
)
context = st.session_state.metadata.get(selected_question, {}).get(
"context", ["N/A"]
)
reaction = "upvote" if thumbs_up else "downvote"

submit_feedback_to_google_sheet(
question=selected_question,
answer=gen_ans,
sources=sources if isinstance(sources, list) else [sources],
context=context if isinstance(context, list) else [context],
issue="", # Leave issue blank
version=os.getenv("RAG_VERSION", get_git_commit_hash()),
reaction=reaction, # Pass the reaction
)
st.success("Thank you for your feedback!")
except Exception as e:
st.error(f"Failed to submit feedback: {e}")
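Reviewer note: utils/feedback.py is not included in the hunks shown here, so the exact interface of submit_feedback_to_google_sheet is an assumption; the stub below only mirrors the keyword arguments used at the call site above, to make explicit what a 👍/👎 reaction submits. All values in the example call are illustrative.

def submit_feedback_to_google_sheet(
    question: str,
    answer: str,
    sources: list[str],
    context: list[str],
    issue: str,
    version: str,
    reaction: str,
) -> None:
    # Hypothetical stand-in: the real helper appends a row to the Google Sheet
    # identified by FEEDBACK_SHEET_ID; here we only print what would be sent.
    print([question, answer, sources, context, issue, version, reaction])

submit_feedback_to_google_sheet(
    question="How do I run floorplanning?",          # last user question
    answer="Use the initialize_floorplan command.",  # last AI response
    sources=["https://example.com/docs"],
    context=["N/A"],
    issue="",               # left blank for reactions, as in the diff
    version="local-dev",    # the app uses os.getenv("RAG_VERSION", get_git_commit_hash())
    reaction="upvote",      # "upvote" for 👍, "downvote" for 👎
)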

# Feedback form logic
if feedback_clicked or st.session_state.feedback_button:
try:
show_feedback_form(
question_dict,