Skip to content

Commit

Permalink
Merge pull request #53 from mt7180/52-fix-streamlit-frontend-with-new…
Browse files Browse the repository at this point in the history
…-streamlit-version

fix options menu and chat_input iposition
  • Loading branch information
mt7180 authored Mar 28, 2024
2 parents 7ff410d + 37b8858 commit 294296f
Showing 1 changed file with 123 additions and 115 deletions.
238 changes: 123 additions & 115 deletions frontend/streamlit_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@

from streamlit.runtime.uploaded_file_manager import UploadedFile

# from fastapi import UploadFile
import streamlit as st
from streamlit_option_menu import option_menu
from streamlit_extras.add_vertical_space import add_vertical_space
from streamlit_extras.stylable_container import stylable_container
import requests

Expand Down Expand Up @@ -100,7 +98,7 @@ def initialize_session(refresh_session=False):
if "chat_mode" not in st.session_state:
st.session_state["chat_mode"] = ""
if "selected_page" not in st.session_state:
st.session_state["selected_page"] = 0
st.session_state["selected_page"] = "questionai"
if "redirect_page" not in st.session_state:
st.session_state["redirect_page"] = None
if "file_uploader_key" not in st.session_state:
Expand All @@ -110,7 +108,6 @@ def initialize_session(refresh_session=False):


def clear_history():
# st.session_state.messages.clear()
initialize_session(refresh_session=True)
response = requests.get(os.path.join(API_URL, "clear_history"))
st.session_state["redirect_page"] = 0
Expand All @@ -128,6 +125,7 @@ def clear_storage():
st.session_state["file_uploader_key"] += 1
st.session_state["url_uploader_key"] += 1
st.session_state["redirect_page"] = 0
st.session_state["selected_page"] = "questionai"
# st.session_state["url_input"]=""
clear_history()
response = requests.get(os.path.join(API_URL, "clear_storage"))
Expand Down Expand Up @@ -199,7 +197,6 @@ def post_data_to_backend(
response_data = response.json()
logging.info(f"upload response data: {response_data}")
# st.session_state.counter += 1
# print(response_data["summary"], st.session_state.counter)
post_ai_message_to_chat(
response_data.get("summary", "Unknown response"),
response_data.get("text_category"),
Expand All @@ -208,10 +205,7 @@ def post_data_to_backend(
response_data.get("used_tokens", 0)
)
else:
st.sidebar.error(
f"Error: {response.status_code} - {response}"
# .json().get("msg")}'
)
st.sidebar.error(f"Error: {response.status_code} - {response}")
except FileNotFoundError:
st.sidebar.error(
"No context is given. Please provide a url or upload a file"
Expand All @@ -234,19 +228,21 @@ def url_callback():

def display_sidemenu():
st.sidebar.title("Menu")
st.sidebar.markdown(
"""
Please uploade your file or enter a url. Supported file types:
- txt - as upload
- pdf - as upload
- website - as url
- sqlite - as upload
"""
)
sidebar_container = st.sidebar.container()
success_message = st.sidebar.empty()

with sidebar_container:
st.markdown(
"""
Please uploade your file or enter a url. Supported types:
- txt - as upload
- pdf - as upload
- website - as url
- sqlite - as upload
"""
)

with st.sidebar.container():
success_message = st.empty()
if st.file_uploader(
"dragndrop",
type=["txt", "pdf", "sqlite"],
Expand Down Expand Up @@ -286,129 +282,141 @@ def display_sidemenu():
# st.rerun()
if st.button("Clear knowledge base: texts/ urls", use_container_width=True):
success_message.success(clear_storage())
st.experimental_rerun()
st.rerun()


@register_page(PAGE_REGISTRY_DICT)
def questionai():
with st.container():
messages = stylable_container(
key="message_container",
css_styles=[
"""
{
padding: 0.5em;
min-height: 65vh;
overflow-y: scroll;
}
""",
"""
.stMarkdown {
padding-right: 1.5em;
}
""",
],
)

for message in st.session_state.messages:
# image = "xxx.png" if message["role"] == "user" else "xxx.png"
with st.chat_message(message["role"]):
st.markdown(message["content"])
messages.chat_message(message["role"]).write(message["content"])

if prompt := st.chat_input("-> Your Question ..."):
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("user"):
st.markdown(prompt)
messages.chat_message("user").write(prompt)

with st.spinner("Waiting for response"):
with messages.chat_message("assistant"), st.spinner(
"Waiting for Response ..."
):
payload = {
"prompt": prompt,
"temperature": st.session_state.temperature,
}
response = requests.post(os.path.join(API_URL, "qa_text"), json=payload)
with st.chat_message("assistant"):
message_placeholder = st.empty()
ai_answer = ""
if response.status_code == 200:
response_data = response.json()
ai_answer = response_data.get(
"ai_answer", "Unknown response type"
)
st.session_state.total_tokens.append(
response_data.get("used_tokens", 0)
)
message_placeholder.markdown(ai_answer)
st.session_state.messages.append(
{"role": "assistant", "content": ai_answer}
)
else:
st.error(f"Error: {response.status_code} - {response.text}")
message_placeholder.markdown(ai_answer)
add_vertical_space(7)
elif len(st.session_state.messages) == 0:
# cfd = pathlib.Path(__file__).parent
image = Image.open(cfd / "static" / "main_picture3.png")
_, center, _ = st.columns((1, 5, 1))
center.image(
image,
caption=None,
)

ai_answer = ""

@register_page(PAGE_REGISTRY_DICT)
def quizme():
with st.container():
if st.session_state["chat_mode"] == "database":
st.markdown(
"""Sorry, you are in database mode, no quiz available.
Please upload a text or give a url to a webpage to generate a quiz."""
)
else:
st.markdown("### A Quiz for You")
st.session_state.score = 0
message_placeholder = st.empty()
if st.button("Generate a Quiz"):
response = make_get_request("quiz")
if response.status_code == 200:
for question in response.json().get("questions"):
answer_options = [
question["correct_answer"],
question["wrong_answer_1"],
question["wrong_answer_2"],
]
random.shuffle(answer_options)
st.session_state["question_data"].append(
{
"question_txt": question["question"],
"correct_answer": question["correct_answer"],
"answer_options": answer_options,
}
)
response_data = response.json()
ai_answer = response_data.get("ai_answer", "Unknown response type")
st.session_state.total_tokens.append(
response_data.get("used_tokens", 0)
)
st.write(ai_answer)
st.session_state.messages.append(
{"role": "assistant", "content": ai_answer}
)
else:
message_placeholder.error(response.json().get("detail"))

for question in st.session_state["question_data"]:
st.markdown(f"##### Question: {question['question_txt']}")
user_answer = st.radio(
"Select an answer:",
["Please Select an answer:", *question["answer_options"]],
label_visibility="collapsed",
)

if user_answer == question["correct_answer"]:
st.session_state.score += 1
st.error(f"Error: {response.status_code} - {response.text}")

if st.session_state["score"] > 0:
message_placeholder.success(
f"You answered {st.session_state.score} questions correct!"
)
if not st.session_state["question_data"]:
# cfd = pathlib.Path(__file__).parent
image = Image.open(cfd / "static" / "Hippo.png")
_, center, _ = st.columns((2, 4, 2))
elif len(st.session_state.messages) == 0:
with messages:
_, center, _ = st.columns((1, 5, 1))
image = Image.open(cfd / "static" / "main_picture3.png")
center.image(
image,
caption=None,
use_column_width=True,
)


@register_page(PAGE_REGISTRY_DICT)
def quizme():
if st.session_state["chat_mode"] == "database":
st.markdown(
"""Sorry, you are in database mode, no quiz available.
Please upload a text or give a url to a webpage to generate a quiz."""
)
else:
st.markdown("### A Quiz for You")
st.session_state.score = 0
message_placeholder = st.empty()
if st.button("Generate a Quiz"):
response = make_get_request("quiz")
if response.status_code == 200:
for question in response.json().get("questions"):
answer_options = [
question["correct_answer"],
question["wrong_answer_1"],
question["wrong_answer_2"],
]
random.shuffle(answer_options)
st.session_state["question_data"].append(
{
"question_txt": question["question"],
"correct_answer": question["correct_answer"],
"answer_options": answer_options,
}
)
else:
message_placeholder.error(response.json().get("detail"))

for question in st.session_state["question_data"]:
st.markdown(f"##### Question: {question['question_txt']}")
user_answer = st.radio(
"Select an answer:",
["Please Select an answer:", *question["answer_options"]],
label_visibility="collapsed",
)

if user_answer == question["correct_answer"]:
st.session_state.score += 1

if st.session_state["score"] > 0:
message_placeholder.success(
f"You answered {st.session_state.score} questions correct!"
)
if not st.session_state["question_data"]:
image = Image.open(cfd / "static" / "Hippo.png")
_, center, _ = st.columns([1, 3, 1])
center.image(
image,
caption=None,
use_column_width=True,
)


@register_page(PAGE_REGISTRY_DICT)
def statistics():
import pandas as pd

with st.container():
st.markdown("### Used API Tokens per Request of your Current Session")
st.markdown("(Usage not persisted)")
_, center, _ = st.columns((1, 5, 1))
chart_data = pd.DataFrame({"Used API Tokens": st.session_state.total_tokens})
center.bar_chart(
data=chart_data,
color="#D3DCE5",
y="Used API Tokens",
use_container_width=False,
)
st.markdown("### Used API Tokens per Request of your Current Session")
st.markdown("(Usage not persisted)")
_, center, _ = st.columns((1, 5, 1))
chart_data = pd.DataFrame({"Used API Tokens": st.session_state.total_tokens})
center.bar_chart(
data=chart_data,
color="#D3DCE5",
y="Used API Tokens",
use_container_width=False,
)


def post_ai_message_to_chat(message, document_category):
Expand Down

0 comments on commit 294296f

Please sign in to comment.