diff --git a/frontend/streamlit_app.py b/frontend/streamlit_app.py index 20daff0..10067b8 100644 --- a/frontend/streamlit_app.py +++ b/frontend/streamlit_app.py @@ -8,10 +8,8 @@ from streamlit.runtime.uploaded_file_manager import UploadedFile -# from fastapi import UploadFile import streamlit as st from streamlit_option_menu import option_menu -from streamlit_extras.add_vertical_space import add_vertical_space from streamlit_extras.stylable_container import stylable_container import requests @@ -100,7 +98,7 @@ def initialize_session(refresh_session=False): if "chat_mode" not in st.session_state: st.session_state["chat_mode"] = "" if "selected_page" not in st.session_state: - st.session_state["selected_page"] = 0 + st.session_state["selected_page"] = "questionai" if "redirect_page" not in st.session_state: st.session_state["redirect_page"] = None if "file_uploader_key" not in st.session_state: @@ -110,7 +108,6 @@ def initialize_session(refresh_session=False): def clear_history(): - # st.session_state.messages.clear() initialize_session(refresh_session=True) response = requests.get(os.path.join(API_URL, "clear_history")) st.session_state["redirect_page"] = 0 @@ -128,6 +125,7 @@ def clear_storage(): st.session_state["file_uploader_key"] += 1 st.session_state["url_uploader_key"] += 1 st.session_state["redirect_page"] = 0 + st.session_state["selected_page"] = "questionai" # st.session_state["url_input"]="" clear_history() response = requests.get(os.path.join(API_URL, "clear_storage")) @@ -199,7 +197,6 @@ def post_data_to_backend( response_data = response.json() logging.info(f"upload response data: {response_data}") # st.session_state.counter += 1 - # print(response_data["summary"], st.session_state.counter) post_ai_message_to_chat( response_data.get("summary", "Unknown response"), response_data.get("text_category"), @@ -208,10 +205,7 @@ def post_data_to_backend( response_data.get("used_tokens", 0) ) else: - st.sidebar.error( - f"Error: {response.status_code} - {response}" - # .json().get("msg")}' - ) + st.sidebar.error(f"Error: {response.status_code} - {response}") except FileNotFoundError: st.sidebar.error( "No context is given. Please provide a url or upload a file" @@ -234,19 +228,21 @@ def url_callback(): def display_sidemenu(): st.sidebar.title("Menu") - st.sidebar.markdown( - """ - Please uploade your file or enter a url. Supported file types: - - - txt - as upload - - pdf - as upload - - website - as url - - sqlite - as upload - """ - ) + sidebar_container = st.sidebar.container() + success_message = st.sidebar.empty() + + with sidebar_container: + st.markdown( + """ + Please uploade your file or enter a url. Supported types: + + - txt - as upload + - pdf - as upload + - website - as url + - sqlite - as upload + """ + ) - with st.sidebar.container(): - success_message = st.empty() if st.file_uploader( "dragndrop", type=["txt", "pdf", "sqlite"], @@ -286,129 +282,141 @@ def display_sidemenu(): # st.rerun() if st.button("Clear knowledge base: texts/ urls", use_container_width=True): success_message.success(clear_storage()) - st.experimental_rerun() + st.rerun() @register_page(PAGE_REGISTRY_DICT) def questionai(): with st.container(): + messages = stylable_container( + key="message_container", + css_styles=[ + """ + { + padding: 0.5em; + min-height: 65vh; + overflow-y: scroll; + } + """, + """ + .stMarkdown { + padding-right: 1.5em; + } + """, + ], + ) + for message in st.session_state.messages: - # image = "xxx.png" if message["role"] == "user" else "xxx.png" - with st.chat_message(message["role"]): - st.markdown(message["content"]) + messages.chat_message(message["role"]).write(message["content"]) if prompt := st.chat_input("-> Your Question ..."): st.session_state.messages.append({"role": "user", "content": prompt}) - with st.chat_message("user"): - st.markdown(prompt) + messages.chat_message("user").write(prompt) - with st.spinner("Waiting for response"): + with messages.chat_message("assistant"), st.spinner( + "Waiting for Response ..." + ): payload = { "prompt": prompt, "temperature": st.session_state.temperature, } response = requests.post(os.path.join(API_URL, "qa_text"), json=payload) - with st.chat_message("assistant"): - message_placeholder = st.empty() - ai_answer = "" - if response.status_code == 200: - response_data = response.json() - ai_answer = response_data.get( - "ai_answer", "Unknown response type" - ) - st.session_state.total_tokens.append( - response_data.get("used_tokens", 0) - ) - message_placeholder.markdown(ai_answer) - st.session_state.messages.append( - {"role": "assistant", "content": ai_answer} - ) - else: - st.error(f"Error: {response.status_code} - {response.text}") - message_placeholder.markdown(ai_answer) - add_vertical_space(7) - elif len(st.session_state.messages) == 0: - # cfd = pathlib.Path(__file__).parent - image = Image.open(cfd / "static" / "main_picture3.png") - _, center, _ = st.columns((1, 5, 1)) - center.image( - image, - caption=None, - ) - + ai_answer = "" -@register_page(PAGE_REGISTRY_DICT) -def quizme(): - with st.container(): - if st.session_state["chat_mode"] == "database": - st.markdown( - """Sorry, you are in database mode, no quiz available. - Please upload a text or give a url to a webpage to generate a quiz.""" - ) - else: - st.markdown("### A Quiz for You") - st.session_state.score = 0 - message_placeholder = st.empty() - if st.button("Generate a Quiz"): - response = make_get_request("quiz") if response.status_code == 200: - for question in response.json().get("questions"): - answer_options = [ - question["correct_answer"], - question["wrong_answer_1"], - question["wrong_answer_2"], - ] - random.shuffle(answer_options) - st.session_state["question_data"].append( - { - "question_txt": question["question"], - "correct_answer": question["correct_answer"], - "answer_options": answer_options, - } - ) + response_data = response.json() + ai_answer = response_data.get("ai_answer", "Unknown response type") + st.session_state.total_tokens.append( + response_data.get("used_tokens", 0) + ) + st.write(ai_answer) + st.session_state.messages.append( + {"role": "assistant", "content": ai_answer} + ) else: - message_placeholder.error(response.json().get("detail")) - - for question in st.session_state["question_data"]: - st.markdown(f"##### Question: {question['question_txt']}") - user_answer = st.radio( - "Select an answer:", - ["Please Select an answer:", *question["answer_options"]], - label_visibility="collapsed", - ) - - if user_answer == question["correct_answer"]: - st.session_state.score += 1 + st.error(f"Error: {response.status_code} - {response.text}") - if st.session_state["score"] > 0: - message_placeholder.success( - f"You answered {st.session_state.score} questions correct!" - ) - if not st.session_state["question_data"]: - # cfd = pathlib.Path(__file__).parent - image = Image.open(cfd / "static" / "Hippo.png") - _, center, _ = st.columns((2, 4, 2)) + elif len(st.session_state.messages) == 0: + with messages: + _, center, _ = st.columns((1, 5, 1)) + image = Image.open(cfd / "static" / "main_picture3.png") center.image( image, caption=None, + use_column_width=True, ) +@register_page(PAGE_REGISTRY_DICT) +def quizme(): + if st.session_state["chat_mode"] == "database": + st.markdown( + """Sorry, you are in database mode, no quiz available. + Please upload a text or give a url to a webpage to generate a quiz.""" + ) + else: + st.markdown("### A Quiz for You") + st.session_state.score = 0 + message_placeholder = st.empty() + if st.button("Generate a Quiz"): + response = make_get_request("quiz") + if response.status_code == 200: + for question in response.json().get("questions"): + answer_options = [ + question["correct_answer"], + question["wrong_answer_1"], + question["wrong_answer_2"], + ] + random.shuffle(answer_options) + st.session_state["question_data"].append( + { + "question_txt": question["question"], + "correct_answer": question["correct_answer"], + "answer_options": answer_options, + } + ) + else: + message_placeholder.error(response.json().get("detail")) + + for question in st.session_state["question_data"]: + st.markdown(f"##### Question: {question['question_txt']}") + user_answer = st.radio( + "Select an answer:", + ["Please Select an answer:", *question["answer_options"]], + label_visibility="collapsed", + ) + + if user_answer == question["correct_answer"]: + st.session_state.score += 1 + + if st.session_state["score"] > 0: + message_placeholder.success( + f"You answered {st.session_state.score} questions correct!" + ) + if not st.session_state["question_data"]: + image = Image.open(cfd / "static" / "Hippo.png") + _, center, _ = st.columns([1, 3, 1]) + center.image( + image, + caption=None, + use_column_width=True, + ) + + @register_page(PAGE_REGISTRY_DICT) def statistics(): import pandas as pd - with st.container(): - st.markdown("### Used API Tokens per Request of your Current Session") - st.markdown("(Usage not persisted)") - _, center, _ = st.columns((1, 5, 1)) - chart_data = pd.DataFrame({"Used API Tokens": st.session_state.total_tokens}) - center.bar_chart( - data=chart_data, - color="#D3DCE5", - y="Used API Tokens", - use_container_width=False, - ) + st.markdown("### Used API Tokens per Request of your Current Session") + st.markdown("(Usage not persisted)") + _, center, _ = st.columns((1, 5, 1)) + chart_data = pd.DataFrame({"Used API Tokens": st.session_state.total_tokens}) + center.bar_chart( + data=chart_data, + color="#D3DCE5", + y="Used API Tokens", + use_container_width=False, + ) def post_ai_message_to_chat(message, document_category):