Skip to content

Commit

Permalink
common chainlit codes moved to another file
Browse files Browse the repository at this point in the history
  • Loading branch information
ilkersigirci committed Jan 4, 2025
1 parent 4da434d commit 0d6455e
Show file tree
Hide file tree
Showing 4 changed files with 219 additions and 117 deletions.
2 changes: 1 addition & 1 deletion src/podflix/graph/podcast_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async def generate(state: AgentState) -> AgentState:
response = await chain.ainvoke({"context": context, "question": question})

return {
"messages": [*state["messages"], AIMessage(content=response)],
"messages": [AIMessage(content=response)],
}


Expand Down
82 changes: 22 additions & 60 deletions src/podflix/gui/audio.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import webbrowser # noqa: F401
from dataclasses import dataclass # noqa: F401
from pathlib import Path
from uuid import uuid4

import chainlit as cl
import chainlit.data as cl_data
from chainlit.data.sql_alchemy import SQLAlchemyDataLayer
from chainlit.user import PersistedUser, User
from chainlit.types import ThreadDict
from langchain.schema.runnable.config import RunnableConfig
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.messages import AIMessageChunk
Expand All @@ -16,11 +15,12 @@

from podflix.db.db_factory import DBInterfaceFactory
from podflix.graph.podcast_rag import compiled_graph
from podflix.utils.general import (
check_lf_credentials,
get_lf_session_url,
get_lf_traces_url,
from podflix.utils.chainlit_ui import (
create_message_history_from_db_thread,
set_extra_user_session_params,
simple_auth_callback,
)
from podflix.utils.general import get_lf_traces_url
from podflix.utils.model import transcribe_audio_file

cl_data._data_layer = SQLAlchemyDataLayer(
Expand All @@ -30,49 +30,15 @@
)


Chainlit_User_Type = User | PersistedUser


# TODO: Set starters based on audio file
# @dataclass
# class StartQuestions:
# label: str
# message: str
# icon: str | None = None


# mock_starters = [
# StartQuestions(
# label="Start",
# message="Start the conversation",
# icon="🚀",
# ),
# StartQuestions(
# label="Middle",
# message="Middle the conversation",
# icon="🚀",
# ),
# ]

# @cl.set_starters
# async def set_starters() -> list[cl.Starter]:
# return [
# cl.Starter(
# label=mock_starter.label,
# message=mock_starter.message,
# icon=mock_starter.icon,
# )
# for mock_starter in mock_starters
# ]
# pass


@cl.password_auth_callback
def auth_callback(username: str, password: str):
if (username, password) == ("admin", "admin"):
return cl.User(
identifier="admin", metadata={"role": "admin", "provider": "credentials"}
)
return None
return simple_auth_callback(username, password)


# @cl.action_callback("Detailed Traces")
Expand All @@ -90,24 +56,7 @@ def auth_callback(username: str, password: str):

@cl.on_chat_start
async def on_chat_start():
session_id = str(uuid4())
chainlit_user: Chainlit_User_Type = cl.user_session.get("user")
chainlit_user_id = chainlit_user.identifier
message_history = ChatMessageHistory()

check_lf_credentials()
lf_cb_handler = LangfuseCallbackHandler(
user_id=chainlit_user_id,
session_id=session_id,
)

cl.user_session.set("lf_cb_handler", lf_cb_handler)
cl.user_session.set("session_id", session_id)
cl.user_session.set("message_history", message_history)

langfuse_session_url = get_lf_session_url(session_id=session_id)

logger.debug(f"Langfuse Session URL: {langfuse_session_url}")
set_extra_user_session_params()

system_message = cl.Message(
content=" ",
Expand Down Expand Up @@ -151,6 +100,19 @@ async def on_chat_start():
await system_message.update()


@cl.on_chat_resume
def setup_chat_resume(thread: ThreadDict):
thread["metadata"] = {}
message_history = create_message_history_from_db_thread(thread=thread)

set_extra_user_session_params(
user_id=thread["userIdentifier"], message_history=message_history
)

# TODO: Set audio_text from the db
# cl.user_session.set("audio_text", audio_text)


@cl.on_message
async def on_message(msg: cl.Message):
lf_cb_handler: LangfuseCallbackHandler = cl.user_session.get("lf_cb_handler")
Expand Down
85 changes: 29 additions & 56 deletions src/podflix/gui/mock.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
import webbrowser
from dataclasses import dataclass
from uuid import uuid4

import chainlit as cl
import chainlit.data as cl_data
from chainlit.data.sql_alchemy import SQLAlchemyDataLayer
from chainlit.types import ThreadDict
from chainlit.user import PersistedUser, User
from langchain.schema.runnable.config import RunnableConfig
from langchain_core.messages import AIMessageChunk, HumanMessage
from langfuse.callback import CallbackHandler as LangfuseCallbackHandler
from literalai.helper import utc_now
from loguru import logger

from podflix.db.db_factory import DBInterfaceFactory
from podflix.graph.mock import compiled_graph
from podflix.utils.general import (
check_lf_credentials,
get_lf_session_url,
get_lf_traces_url,
from podflix.utils.chainlit_ui import (
create_message_history_from_db_thread,
set_extra_user_session_params,
simple_auth_callback,
)
from podflix.utils.general import get_lf_traces_url

cl_data._data_layer = SQLAlchemyDataLayer(
DBInterfaceFactory.create().async_connection(),
Expand All @@ -30,46 +29,25 @@
Chainlit_User_Type = User | PersistedUser


@dataclass
class StartQuestions:
label: str
message: str
icon: str | None = None


mock_starters = [
StartQuestions(
label="Start",
message="Start the conversation",
icon="🚀",
),
StartQuestions(
label="Middle",
message="Middle the conversation",
icon="🚀",
),
]


@cl.set_starters
async def set_starters() -> list[cl.Starter]:
return [
cl.Starter(
label=mock_starter.label,
message=mock_starter.message,
icon=mock_starter.icon,
)
for mock_starter in mock_starters
]
# @cl.set_starters
# async def set_starters() -> list[cl.Starter]:
# mock_starters = [
# StartQuestions(
# label="Start",
# message="Start the conversation",
# icon="🚀",
# ),
# StartQuestions(
# label="Middle",
# message="Middle the conversation",
# icon="🚀",
# ),
# ]


@cl.password_auth_callback
def auth_callback(username: str, password: str):
if (username, password) == ("admin", "admin"):
return cl.User(
identifier="admin", metadata={"role": "admin", "provider": "credentials"}
)
return None
return simple_auth_callback(username, password)


@cl.action_callback("Detailed Traces")
Expand All @@ -82,22 +60,17 @@ async def on_action(action: cl.Action):

@cl.on_chat_start
async def on_chat_start():
session_id = str(uuid4())
chainlit_user: Chainlit_User_Type = cl.user_session.get("user")
chainlit_user_id = chainlit_user.identifier

check_lf_credentials()
lf_cb_handler = LangfuseCallbackHandler(
user_id=chainlit_user_id,
session_id=session_id,
)
set_extra_user_session_params()

cl.user_session.set("lf_cb_handler", lf_cb_handler)
cl.user_session.set("session_id", session_id)

langfuse_session_url = get_lf_session_url(session_id=session_id)
@cl.on_chat_resume
def setup_chat_resume(thread: ThreadDict):
thread["metadata"] = {}
message_history = create_message_history_from_db_thread(thread=thread)

logger.debug(f"Langfuse Session URL: {langfuse_session_url}")
set_extra_user_session_params(
user_id=thread["userIdentifier"], message_history=message_history
)


@cl.on_message
Expand Down
Loading

0 comments on commit 0d6455e

Please sign in to comment.