Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: vastly improve chat UI responsiveness by reordering Gradio events #360

Merged
merged 1 commit into from
Oct 4, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 57 additions & 167 deletions libs/ktem/ktem/pages/chat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
import asyncio
import csv
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Optional

import gradio as gr
from filelock import FileLock
from ktem.app import BasePage
from ktem.components import reasonings
from ktem.db.models import Conversation, engine
Expand Down Expand Up @@ -38,6 +34,7 @@
for (var i = 0; i < links.length; i++) {
links[i].onclick = openModal;
}
return [links.length]
}
"""

Expand All @@ -48,19 +45,18 @@ def __init__(self, app):
self._indices_input = []

self.on_building_ui()

self._preview_links = gr.State(value=None)
self._reasoning_type = gr.State(value=None)
self._llm_type = gr.State(value=None)
self._conversation_renamed = gr.State(value=False)
self.info_panel_expanded = gr.State(value=True)
self._info_panel_expanded = gr.State(value=True)

def on_building_ui(self):
with gr.Row():
self.state_chat = gr.State(STATE)
self.state_retrieval_history = gr.State([])
self.state_chat_history = gr.State([])
self.state_plot_history = gr.State([])
self.state_settings = gr.State({})
self.state_info_panel = gr.State("")
self.state_plot_panel = gr.State(None)

with gr.Column(scale=1, elem_id="conv-settings-panel") as self.conv_column:
Expand Down Expand Up @@ -203,37 +199,11 @@ def on_register_events(self):
],
concurrency_limit=20,
show_progress="minimal",
).success(
fn=self.backup_original_info,
inputs=[
self.chat_panel.chatbot,
self._app.settings_state,
self.info_panel,
self.state_chat_history,
],
outputs=[
self.state_chat_history,
self.state_settings,
self.state_info_panel,
],
).then(
fn=self.persist_data_source,
inputs=[
self.chat_control.conversation_id,
self._app.user_id,
self.info_panel,
self.state_plot_panel,
self.state_retrieval_history,
self.state_plot_history,
self.chat_panel.chatbot,
self.state_chat,
]
+ self._indices_input,
outputs=[
self.state_retrieval_history,
self.state_plot_history,
],
concurrency_limit=20,
fn=lambda: True,
inputs=None,
outputs=[self._preview_links],
js=pdfview_js,
).success(
fn=self.check_and_suggest_name_conv,
inputs=self.chat_panel.chatbot,
Expand All @@ -256,7 +226,23 @@ def on_register_events(self):
],
show_progress="hidden",
).then(
fn=None, inputs=None, outputs=None, js=pdfview_js
fn=self.persist_data_source,
inputs=[
self.chat_control.conversation_id,
self._app.user_id,
self.info_panel,
self.state_plot_panel,
self.state_retrieval_history,
self.state_plot_history,
self.chat_panel.chatbot,
self.state_chat,
]
+ self._indices_input,
outputs=[
self.state_retrieval_history,
self.state_plot_history,
],
concurrency_limit=20,
)

self.chat_panel.regen_btn.click(
Expand All @@ -281,23 +267,10 @@ def on_register_events(self):
concurrency_limit=20,
show_progress="minimal",
).then(
fn=self.persist_data_source,
inputs=[
self.chat_control.conversation_id,
self._app.user_id,
self.info_panel,
self.state_plot_panel,
self.state_retrieval_history,
self.state_plot_history,
self.chat_panel.chatbot,
self.state_chat,
]
+ self._indices_input,
outputs=[
self.state_retrieval_history,
self.state_plot_history,
],
concurrency_limit=20,
fn=lambda: True,
inputs=None,
outputs=[self._preview_links],
js=pdfview_js,
).success(
fn=self.check_and_suggest_name_conv,
inputs=self.chat_panel.chatbot,
Expand All @@ -320,37 +293,39 @@ def on_register_events(self):
],
show_progress="hidden",
).then(
fn=None, inputs=None, outputs=None, js=pdfview_js
fn=self.persist_data_source,
inputs=[
self.chat_control.conversation_id,
self._app.user_id,
self.info_panel,
self.state_plot_panel,
self.state_retrieval_history,
self.state_plot_history,
self.chat_panel.chatbot,
self.state_chat,
]
+ self._indices_input,
outputs=[
self.state_retrieval_history,
self.state_plot_history,
],
concurrency_limit=20,
)

self.chat_control.btn_info_expand.click(
fn=lambda is_expanded: (
gr.update(scale=INFO_PANEL_SCALES[is_expanded]),
not is_expanded,
),
inputs=self.info_panel_expanded,
outputs=[self.info_column, self.info_panel_expanded],
inputs=self._info_panel_expanded,
outputs=[self.info_column, self._info_panel_expanded],
)

self.chat_panel.chatbot.like(
fn=self.is_liked,
inputs=[self.chat_control.conversation_id],
outputs=None,
).success(
self.save_log,
inputs=[
self.chat_control.conversation_id,
self.chat_panel.chatbot,
self._app.settings_state,
self.info_panel,
self.state_chat_history,
self.state_settings,
self.state_info_panel,
gr.State(getattr(flowsettings, "KH_APP_DATA_DIR", "logs")),
],
outputs=None,
)

self.chat_control.btn_new.click(
self.chat_control.new_conv,
inputs=self._app.user_id,
Expand Down Expand Up @@ -701,7 +676,15 @@ def is_liked(self, convo_id, liked: gr.LikeData):

def message_selected(self, retrieval_history, plot_history, msg: gr.SelectData):
index = msg.index[0]
return retrieval_history[index], plot_history[index]
try:
retrieval_content, plot_content = (
retrieval_history[index],
plot_history[index],
)
except IndexError:
retrieval_content, plot_content = gr.update(), None

return retrieval_content, plot_content

def create_pipeline(
self,
Expand Down Expand Up @@ -889,96 +872,3 @@ def check_and_suggest_name_conv(self, chat_history):
renamed = True

return new_name, renamed

def backup_original_info(
self, chat_history, settings, info_pannel, original_chat_history
):
original_chat_history.append(chat_history[-1])
return original_chat_history, settings, info_pannel

def save_log(
self,
conversation_id,
chat_history,
settings,
info_panel,
original_chat_history,
original_settings,
original_info_panel,
log_dir,
):
if not Path(log_dir).exists():
Path(log_dir).mkdir(parents=True)

lock = FileLock(Path(log_dir) / ".lock")
# get current date
today = datetime.now()
formatted_date = today.strftime("%d%m%Y_%H")

with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == conversation_id)
result = session.exec(statement).one()

data_source = deepcopy(result.data_source)
likes = data_source.get("likes", [])
if not likes:
return

feedback = likes[-1][-1]
message_index = likes[-1][0]

current_message = chat_history[message_index[0]]
original_message = original_chat_history[message_index[0]]
is_original = all(
[
current_item == original_item
for current_item, original_item in zip(
current_message, original_message
)
]
)

dataframe = [
[
conversation_id,
message_index,
current_message[0],
current_message[1],
chat_history,
settings,
info_panel,
feedback,
is_original,
original_message[1],
original_chat_history,
original_settings,
original_info_panel,
]
]

with lock:
log_file = Path(log_dir) / f"{formatted_date}_log.csv"
is_log_file_exist = log_file.is_file()
with open(log_file, "a") as f:
writer = csv.writer(f)
# write headers
if not is_log_file_exist:
writer.writerow(
[
"Conversation ID",
"Message ID",
"Question",
"Answer",
"Chat History",
"Settings",
"Evidences",
"Feedback",
"Original/ Rewritten",
"Original Answer",
"Original Chat History",
"Original Settings",
"Original Evidences",
]
)

writer.writerows(dataframe)
Loading