Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix options menu and chat_input position #53

Merged
merged 1 commit into from
Mar 28, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 123 additions & 115 deletions frontend/streamlit_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@

from streamlit.runtime.uploaded_file_manager import UploadedFile

# from fastapi import UploadFile
import streamlit as st
from streamlit_option_menu import option_menu
from streamlit_extras.add_vertical_space import add_vertical_space
from streamlit_extras.stylable_container import stylable_container
import requests

Expand Down Expand Up @@ -100,7 +98,7 @@ def initialize_session(refresh_session=False):
if "chat_mode" not in st.session_state:
st.session_state["chat_mode"] = ""
if "selected_page" not in st.session_state:
st.session_state["selected_page"] = 0
st.session_state["selected_page"] = "questionai"
if "redirect_page" not in st.session_state:
st.session_state["redirect_page"] = None
if "file_uploader_key" not in st.session_state:
Expand All @@ -110,7 +108,6 @@ def initialize_session(refresh_session=False):


def clear_history():
# st.session_state.messages.clear()
initialize_session(refresh_session=True)
response = requests.get(os.path.join(API_URL, "clear_history"))
st.session_state["redirect_page"] = 0
Expand All @@ -128,6 +125,7 @@ def clear_storage():
st.session_state["file_uploader_key"] += 1
st.session_state["url_uploader_key"] += 1
st.session_state["redirect_page"] = 0
st.session_state["selected_page"] = "questionai"
# st.session_state["url_input"]=""
clear_history()
response = requests.get(os.path.join(API_URL, "clear_storage"))
Expand Down Expand Up @@ -199,7 +197,6 @@ def post_data_to_backend(
response_data = response.json()
logging.info(f"upload response data: {response_data}")
# st.session_state.counter += 1
# print(response_data["summary"], st.session_state.counter)
post_ai_message_to_chat(
response_data.get("summary", "Unknown response"),
response_data.get("text_category"),
Expand All @@ -208,10 +205,7 @@ def post_data_to_backend(
response_data.get("used_tokens", 0)
)
else:
st.sidebar.error(
f"Error: {response.status_code} - {response}"
# .json().get("msg")}'
)
st.sidebar.error(f"Error: {response.status_code} - {response}")
except FileNotFoundError:
st.sidebar.error(
"No context is given. Please provide a url or upload a file"
Expand All @@ -234,19 +228,21 @@ def url_callback():

def display_sidemenu():
st.sidebar.title("Menu")
st.sidebar.markdown(
"""
Please uploade your file or enter a url. Supported file types:

- txt - as upload
- pdf - as upload
- website - as url
- sqlite - as upload
"""
)
sidebar_container = st.sidebar.container()
success_message = st.sidebar.empty()

with sidebar_container:
st.markdown(
"""
Please uploade your file or enter a url. Supported types:

- txt - as upload
- pdf - as upload
- website - as url
- sqlite - as upload
"""
)

with st.sidebar.container():
success_message = st.empty()
if st.file_uploader(
"dragndrop",
type=["txt", "pdf", "sqlite"],
Expand Down Expand Up @@ -286,129 +282,141 @@ def display_sidemenu():
# st.rerun()
if st.button("Clear knowledge base: texts/ urls", use_container_width=True):
success_message.success(clear_storage())
st.experimental_rerun()
st.rerun()


@register_page(PAGE_REGISTRY_DICT)
def questionai():
with st.container():
messages = stylable_container(
key="message_container",
css_styles=[
"""
{
padding: 0.5em;
min-height: 65vh;
overflow-y: scroll;
}
""",
"""
.stMarkdown {
padding-right: 1.5em;
}
""",
],
)

for message in st.session_state.messages:
# image = "xxx.png" if message["role"] == "user" else "xxx.png"
with st.chat_message(message["role"]):
st.markdown(message["content"])
messages.chat_message(message["role"]).write(message["content"])

if prompt := st.chat_input("-> Your Question ..."):
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("user"):
st.markdown(prompt)
messages.chat_message("user").write(prompt)

with st.spinner("Waiting for response"):
with messages.chat_message("assistant"), st.spinner(
"Waiting for Response ..."
):
payload = {
"prompt": prompt,
"temperature": st.session_state.temperature,
}
response = requests.post(os.path.join(API_URL, "qa_text"), json=payload)
with st.chat_message("assistant"):
message_placeholder = st.empty()
ai_answer = ""
if response.status_code == 200:
response_data = response.json()
ai_answer = response_data.get(
"ai_answer", "Unknown response type"
)
st.session_state.total_tokens.append(
response_data.get("used_tokens", 0)
)
message_placeholder.markdown(ai_answer)
st.session_state.messages.append(
{"role": "assistant", "content": ai_answer}
)
else:
st.error(f"Error: {response.status_code} - {response.text}")
message_placeholder.markdown(ai_answer)
add_vertical_space(7)
elif len(st.session_state.messages) == 0:
# cfd = pathlib.Path(__file__).parent
image = Image.open(cfd / "static" / "main_picture3.png")
_, center, _ = st.columns((1, 5, 1))
center.image(
image,
caption=None,
)

ai_answer = ""

@register_page(PAGE_REGISTRY_DICT)
def quizme():
with st.container():
if st.session_state["chat_mode"] == "database":
st.markdown(
"""Sorry, you are in database mode, no quiz available.
Please upload a text or give a url to a webpage to generate a quiz."""
)
else:
st.markdown("### A Quiz for You")
st.session_state.score = 0
message_placeholder = st.empty()
if st.button("Generate a Quiz"):
response = make_get_request("quiz")
if response.status_code == 200:
for question in response.json().get("questions"):
answer_options = [
question["correct_answer"],
question["wrong_answer_1"],
question["wrong_answer_2"],
]
random.shuffle(answer_options)
st.session_state["question_data"].append(
{
"question_txt": question["question"],
"correct_answer": question["correct_answer"],
"answer_options": answer_options,
}
)
response_data = response.json()
ai_answer = response_data.get("ai_answer", "Unknown response type")
st.session_state.total_tokens.append(
response_data.get("used_tokens", 0)
)
st.write(ai_answer)
st.session_state.messages.append(
{"role": "assistant", "content": ai_answer}
)
else:
message_placeholder.error(response.json().get("detail"))

for question in st.session_state["question_data"]:
st.markdown(f"##### Question: {question['question_txt']}")
user_answer = st.radio(
"Select an answer:",
["Please Select an answer:", *question["answer_options"]],
label_visibility="collapsed",
)

if user_answer == question["correct_answer"]:
st.session_state.score += 1
st.error(f"Error: {response.status_code} - {response.text}")

if st.session_state["score"] > 0:
message_placeholder.success(
f"You answered {st.session_state.score} questions correct!"
)
if not st.session_state["question_data"]:
# cfd = pathlib.Path(__file__).parent
image = Image.open(cfd / "static" / "Hippo.png")
_, center, _ = st.columns((2, 4, 2))
elif len(st.session_state.messages) == 0:
with messages:
_, center, _ = st.columns((1, 5, 1))
image = Image.open(cfd / "static" / "main_picture3.png")
center.image(
image,
caption=None,
use_column_width=True,
)


@register_page(PAGE_REGISTRY_DICT)
def quizme():
if st.session_state["chat_mode"] == "database":
st.markdown(
"""Sorry, you are in database mode, no quiz available.
Please upload a text or give a url to a webpage to generate a quiz."""
)
else:
st.markdown("### A Quiz for You")
st.session_state.score = 0
message_placeholder = st.empty()
if st.button("Generate a Quiz"):
response = make_get_request("quiz")
if response.status_code == 200:
for question in response.json().get("questions"):
answer_options = [
question["correct_answer"],
question["wrong_answer_1"],
question["wrong_answer_2"],
]
random.shuffle(answer_options)
st.session_state["question_data"].append(
{
"question_txt": question["question"],
"correct_answer": question["correct_answer"],
"answer_options": answer_options,
}
)
else:
message_placeholder.error(response.json().get("detail"))

for question in st.session_state["question_data"]:
st.markdown(f"##### Question: {question['question_txt']}")
user_answer = st.radio(
"Select an answer:",
["Please Select an answer:", *question["answer_options"]],
label_visibility="collapsed",
)

if user_answer == question["correct_answer"]:
st.session_state.score += 1

if st.session_state["score"] > 0:
message_placeholder.success(
f"You answered {st.session_state.score} questions correct!"
)
if not st.session_state["question_data"]:
image = Image.open(cfd / "static" / "Hippo.png")
_, center, _ = st.columns([1, 3, 1])
center.image(
image,
caption=None,
use_column_width=True,
)


@register_page(PAGE_REGISTRY_DICT)
def statistics():
import pandas as pd

with st.container():
st.markdown("### Used API Tokens per Request of your Current Session")
st.markdown("(Usage not persisted)")
_, center, _ = st.columns((1, 5, 1))
chart_data = pd.DataFrame({"Used API Tokens": st.session_state.total_tokens})
center.bar_chart(
data=chart_data,
color="#D3DCE5",
y="Used API Tokens",
use_container_width=False,
)
st.markdown("### Used API Tokens per Request of your Current Session")
st.markdown("(Usage not persisted)")
_, center, _ = st.columns((1, 5, 1))
chart_data = pd.DataFrame({"Used API Tokens": st.session_state.total_tokens})
center.bar_chart(
data=chart_data,
color="#D3DCE5",
y="Used API Tokens",
use_container_width=False,
)


def post_ai_message_to_chat(message, document_category):
Expand Down
Loading