Skip to content

Commit

Permalink
fix options menu, put chat_messages and input into container to preve…
Browse files Browse the repository at this point in the history
…nt from interference with options menu and set container height according to view port height
  • Loading branch information
mt7180 committed Mar 28, 2024
1 parent 7ff410d commit 37b8858
Showing 1 changed file with 123 additions and 115 deletions.
238 changes: 123 additions & 115 deletions frontend/streamlit_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@

from streamlit.runtime.uploaded_file_manager import UploadedFile

# from fastapi import UploadFile
import streamlit as st
from streamlit_option_menu import option_menu
from streamlit_extras.add_vertical_space import add_vertical_space
from streamlit_extras.stylable_container import stylable_container
import requests

Expand Down Expand Up @@ -100,7 +98,7 @@ def initialize_session(refresh_session=False):
if "chat_mode" not in st.session_state:
st.session_state["chat_mode"] = ""
if "selected_page" not in st.session_state:
st.session_state["selected_page"] = 0
st.session_state["selected_page"] = "questionai"
if "redirect_page" not in st.session_state:
st.session_state["redirect_page"] = None
if "file_uploader_key" not in st.session_state:
Expand All @@ -110,7 +108,6 @@ def initialize_session(refresh_session=False):


def clear_history():
# st.session_state.messages.clear()
initialize_session(refresh_session=True)
response = requests.get(os.path.join(API_URL, "clear_history"))
st.session_state["redirect_page"] = 0
Expand All @@ -128,6 +125,7 @@ def clear_storage():
st.session_state["file_uploader_key"] += 1
st.session_state["url_uploader_key"] += 1
st.session_state["redirect_page"] = 0
st.session_state["selected_page"] = "questionai"
# st.session_state["url_input"]=""
clear_history()
response = requests.get(os.path.join(API_URL, "clear_storage"))
Expand Down Expand Up @@ -199,7 +197,6 @@ def post_data_to_backend(
response_data = response.json()
logging.info(f"upload response data: {response_data}")
# st.session_state.counter += 1
# print(response_data["summary"], st.session_state.counter)
post_ai_message_to_chat(
response_data.get("summary", "Unknown response"),
response_data.get("text_category"),
Expand All @@ -208,10 +205,7 @@ def post_data_to_backend(
response_data.get("used_tokens", 0)
)
else:
st.sidebar.error(
f"Error: {response.status_code} - {response}"
# .json().get("msg")}'
)
st.sidebar.error(f"Error: {response.status_code} - {response}")
except FileNotFoundError:
st.sidebar.error(
"No context is given. Please provide a url or upload a file"
Expand All @@ -234,19 +228,21 @@ def url_callback():

def display_sidemenu():
st.sidebar.title("Menu")
st.sidebar.markdown(
"""
Please uploade your file or enter a url. Supported file types:
- txt - as upload
- pdf - as upload
- website - as url
- sqlite - as upload
"""
)
sidebar_container = st.sidebar.container()
success_message = st.sidebar.empty()

with sidebar_container:
st.markdown(
"""
Please uploade your file or enter a url. Supported types:
- txt - as upload
- pdf - as upload
- website - as url
- sqlite - as upload
"""
)

with st.sidebar.container():
success_message = st.empty()
if st.file_uploader(
"dragndrop",
type=["txt", "pdf", "sqlite"],
Expand Down Expand Up @@ -286,129 +282,141 @@ def display_sidemenu():
# st.rerun()
if st.button("Clear knowledge base: texts/ urls", use_container_width=True):
success_message.success(clear_storage())
st.experimental_rerun()
st.rerun()


@register_page(PAGE_REGISTRY_DICT)
def questionai():
with st.container():
messages = stylable_container(
key="message_container",
css_styles=[
"""
{
padding: 0.5em;
min-height: 65vh;
overflow-y: scroll;
}
""",
"""
.stMarkdown {
padding-right: 1.5em;
}
""",
],
)

for message in st.session_state.messages:
# image = "xxx.png" if message["role"] == "user" else "xxx.png"
with st.chat_message(message["role"]):
st.markdown(message["content"])
messages.chat_message(message["role"]).write(message["content"])

if prompt := st.chat_input("-> Your Question ..."):
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("user"):
st.markdown(prompt)
messages.chat_message("user").write(prompt)

with st.spinner("Waiting for response"):
with messages.chat_message("assistant"), st.spinner(
"Waiting for Response ..."
):
payload = {
"prompt": prompt,
"temperature": st.session_state.temperature,
}
response = requests.post(os.path.join(API_URL, "qa_text"), json=payload)
with st.chat_message("assistant"):
message_placeholder = st.empty()
ai_answer = ""
if response.status_code == 200:
response_data = response.json()
ai_answer = response_data.get(
"ai_answer", "Unknown response type"
)
st.session_state.total_tokens.append(
response_data.get("used_tokens", 0)
)
message_placeholder.markdown(ai_answer)
st.session_state.messages.append(
{"role": "assistant", "content": ai_answer}
)
else:
st.error(f"Error: {response.status_code} - {response.text}")
message_placeholder.markdown(ai_answer)
add_vertical_space(7)
elif len(st.session_state.messages) == 0:
# cfd = pathlib.Path(__file__).parent
image = Image.open(cfd / "static" / "main_picture3.png")
_, center, _ = st.columns((1, 5, 1))
center.image(
image,
caption=None,
)

ai_answer = ""

@register_page(PAGE_REGISTRY_DICT)
def quizme():
with st.container():
if st.session_state["chat_mode"] == "database":
st.markdown(
"""Sorry, you are in database mode, no quiz available.
Please upload a text or give a url to a webpage to generate a quiz."""
)
else:
st.markdown("### A Quiz for You")
st.session_state.score = 0
message_placeholder = st.empty()
if st.button("Generate a Quiz"):
response = make_get_request("quiz")
if response.status_code == 200:
for question in response.json().get("questions"):
answer_options = [
question["correct_answer"],
question["wrong_answer_1"],
question["wrong_answer_2"],
]
random.shuffle(answer_options)
st.session_state["question_data"].append(
{
"question_txt": question["question"],
"correct_answer": question["correct_answer"],
"answer_options": answer_options,
}
)
response_data = response.json()
ai_answer = response_data.get("ai_answer", "Unknown response type")
st.session_state.total_tokens.append(
response_data.get("used_tokens", 0)
)
st.write(ai_answer)
st.session_state.messages.append(
{"role": "assistant", "content": ai_answer}
)
else:
message_placeholder.error(response.json().get("detail"))

for question in st.session_state["question_data"]:
st.markdown(f"##### Question: {question['question_txt']}")
user_answer = st.radio(
"Select an answer:",
["Please Select an answer:", *question["answer_options"]],
label_visibility="collapsed",
)

if user_answer == question["correct_answer"]:
st.session_state.score += 1
st.error(f"Error: {response.status_code} - {response.text}")

if st.session_state["score"] > 0:
message_placeholder.success(
f"You answered {st.session_state.score} questions correct!"
)
if not st.session_state["question_data"]:
# cfd = pathlib.Path(__file__).parent
image = Image.open(cfd / "static" / "Hippo.png")
_, center, _ = st.columns((2, 4, 2))
elif len(st.session_state.messages) == 0:
with messages:
_, center, _ = st.columns((1, 5, 1))
image = Image.open(cfd / "static" / "main_picture3.png")
center.image(
image,
caption=None,
use_column_width=True,
)


@register_page(PAGE_REGISTRY_DICT)
def quizme():
if st.session_state["chat_mode"] == "database":
st.markdown(
"""Sorry, you are in database mode, no quiz available.
Please upload a text or give a url to a webpage to generate a quiz."""
)
else:
st.markdown("### A Quiz for You")
st.session_state.score = 0
message_placeholder = st.empty()
if st.button("Generate a Quiz"):
response = make_get_request("quiz")
if response.status_code == 200:
for question in response.json().get("questions"):
answer_options = [
question["correct_answer"],
question["wrong_answer_1"],
question["wrong_answer_2"],
]
random.shuffle(answer_options)
st.session_state["question_data"].append(
{
"question_txt": question["question"],
"correct_answer": question["correct_answer"],
"answer_options": answer_options,
}
)
else:
message_placeholder.error(response.json().get("detail"))

for question in st.session_state["question_data"]:
st.markdown(f"##### Question: {question['question_txt']}")
user_answer = st.radio(
"Select an answer:",
["Please Select an answer:", *question["answer_options"]],
label_visibility="collapsed",
)

if user_answer == question["correct_answer"]:
st.session_state.score += 1

if st.session_state["score"] > 0:
message_placeholder.success(
f"You answered {st.session_state.score} questions correct!"
)
if not st.session_state["question_data"]:
image = Image.open(cfd / "static" / "Hippo.png")
_, center, _ = st.columns([1, 3, 1])
center.image(
image,
caption=None,
use_column_width=True,
)


@register_page(PAGE_REGISTRY_DICT)
def statistics():
import pandas as pd

with st.container():
st.markdown("### Used API Tokens per Request of your Current Session")
st.markdown("(Usage not persisted)")
_, center, _ = st.columns((1, 5, 1))
chart_data = pd.DataFrame({"Used API Tokens": st.session_state.total_tokens})
center.bar_chart(
data=chart_data,
color="#D3DCE5",
y="Used API Tokens",
use_container_width=False,
)
st.markdown("### Used API Tokens per Request of your Current Session")
st.markdown("(Usage not persisted)")
_, center, _ = st.columns((1, 5, 1))
chart_data = pd.DataFrame({"Used API Tokens": st.session_state.total_tokens})
center.bar_chart(
data=chart_data,
color="#D3DCE5",
y="Used API Tokens",
use_container_width=False,
)


def post_ai_message_to_chat(message, document_category):
Expand Down

0 comments on commit 37b8858

Please sign in to comment.