
Commit 2eb2a88

feat: integrate chat functionality with ChatTask and ChatTaskWorker for enhanced user interaction

provos committed Jan 30, 2025
1 parent a58ab6a commit 2eb2a88
Showing 4 changed files with 65 additions and 13 deletions.
33 changes: 24 additions & 9 deletions examples/deepsearch/deepsearch/deepsearch.py
@@ -32,7 +32,7 @@
 from llm_interface import ListResponse, llm_from_config
 from session import SessionManager
 
-from planai import ProvenanceChain, Task, TaskWorker
+from planai import ChatTask, ProvenanceChain, Task, TaskWorker
 from planai.utils import setup_logging
 
 app = Flask(__name__)
@@ -51,7 +51,8 @@
 graph_thread = None
 should_stop = False
 graph = None
-entry_worker = None
+entry_worker: TaskWorker = None  # Executes the whole plan
+chat_worker: TaskWorker = None  # Allows the user to just chat with the AI assistant
 
 # Add new global settings variable
 current_settings = {
@@ -156,13 +157,13 @@ def start_graph_thread(
     provider: str = "ollama", model: str = "llama2", host: str = "localhost:11434"
 ):
     """Modified to use current settings."""
-    global graph_thread, graph, entry_worker, debug_saver, current_settings
+    global graph_thread, graph, entry_worker, chat_worker, debug_saver, current_settings
 
     # Update current settings
     current_settings["provider"] = provider
     current_settings["model"] = model
 
-    graph, entry_worker = setup_graph(
+    graph, entry_worker, chat_worker = setup_graph(
         provider=provider, model=model, host=host, notify=notify
     )
 
@@ -272,7 +273,7 @@ def handle_abort(data):
 @socketio.on("chat_message")
 def handle_message(data):
     session_id = data.get("session_id")
-    message = data.get("message")
+    messages = data.get("messages", [])  # Get full message history
 
     if not session_id or session_manager.get_session(session_id) is None:
         print(f"Invalid session ID: {session_id}")
@@ -286,16 +287,33 @@ def handle_message(data):
         emit("error", "Session ID does not match current connection")
         return
 
-    print(f'Received message: "{message}" from session: {session_id}')
+    print(f'Received messages: "{messages}" from session: {session_id}')
 
     # Capture the request.sid (SocketIO SID)
     sid = request.sid
     current_metadata = {"session_id": session_id, "sid": sid}
 
+    # Update session timestamp on activity
+    session_manager.update_session_timestamp(session_id)
+
     # record that we started
     session_metadata = session_manager.metadata(session_id)
     session_metadata["started"] = True
 
+    # XXX - this is a hack
+    if len(messages) > 1:
+        global chat_worker
+        task = ChatTask(messages=messages)
+        provenance = chat_worker.add_work(
+            task,
+            metadata={"session_id": session_id, "sid": sid},
+        )
+        session_metadata["provenance"] = provenance
+        return
+
+    # there is only one message:
+    message = messages[0].get("content", "")
+
     # If in replay mode, trigger replay with current session
     global debug_saver
     if debug_saver:
@@ -305,9 +323,6 @@
     else:
         debug_saver.save_prompt(session_id, message.strip())
 
-    # Update session timestamp on activity
-    session_manager.update_session_timestamp(session_id)
-
     def wrapped_notify_planai(*args, **kwargs):
         return notify_planai(*args, **kwargs)
 
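To make the handler change concrete, here is a rough sketch of the payloads the "chat_message" event carries before and after this commit. The session ID is invented, and the role/content shape mirrors the frontend change at the bottom of this commit; a single-entry history still triggers the full research plan, while a longer history is routed to the new chat path above.

# Illustrative payloads only; the session ID is made up.

# Before this commit: one prompt string.
old_payload = {"session_id": "abc123", "message": "What is PlanAI?"}

# After this commit: the full chat history as role/content pairs.
new_payload = {
    "session_id": "abc123",
    "messages": [
        {"role": "user", "content": "What is PlanAI?"},
        {"role": "assistant", "content": "PlanAI is a graph-based AI workflow framework."},
        {"role": "user", "content": "How does this commit change it?"},
    ],
}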
34 changes: 32 additions & 2 deletions examples/deepsearch/deepsearch/graph.py
@@ -16,6 +16,8 @@
 
 from planai import (
     CachedLLMTaskWorker,
+    ChatMessage,
+    ChatTaskWorker,
     Graph,
     InitialTaskWorker,
     JoinedTaskWorker,
@@ -315,6 +317,17 @@ def post_process(self, response: FinalWriteup, input_task: Response):
     )
 
 
+class UserChat(ChatTaskWorker):
+    pass
+
+
+class ChatAdapter(TaskWorker):
+    output_types: List[Type[Task]] = [Response]
+
+    def consume_work(self, task: ChatMessage):
+        self.publish_work(Response(response_type="final", message=task.content), task)
+
+
 class ResponsePublisher(TaskWorker):
     """Re-iterates the response to the user, so that we can use a sink to notify the user on thinking updates"""
 
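The two chat types used here appear only indirectly in this diff: ChatTask is constructed with a messages list in deepsearch.py, and ChatMessage exposes a content field consumed by ChatAdapter. A guessed sketch of their shapes, not PlanAI's actual definitions:

# Guessed shapes only -- the real definitions ship with the planai package.
from typing import List

from planai import Task


class ChatMessage(Task):  # guessed: a single role/content turn
    role: str
    content: str


class ChatTask(Task):  # guessed: the full history handed to a ChatTaskWorker
    messages: List[ChatMessage] = []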
@@ -329,14 +342,23 @@ def setup_graph(
     model: str = "llama3.3:latest",
     host: str = "localhost:11434",
     notify: Optional[Callable[Dict[str, Any], None]] = None,
-) -> Tuple[Graph, TaskWorker]:
+) -> Tuple[Graph, TaskWorker, TaskWorker]:
     llm = llm_from_config(
         provider=provider,
         model_name=model,
         host=host,
         use_cache=False,
     )
 
+    llm_chat = llm_from_config(
+        provider=provider,
+        model_name=model,
+        host=host,
+        use_cache=False,
+    )
+    llm_chat.support_json_mode = False
+    llm_chat.support_structured_outputs = False
+
     graph = Graph(name="Plan Graph")
     plan_worker = PlanWorker(llm=llm)
     search_worker = SearchCreator(llm=llm)
@@ -345,6 +367,10 @@
     analysis_worker = SearchSummarizer(llm=llm)
     analysis_joiner = AnalysisJoiner()
     final_narrative_worker = FinalNarrativeWorker(llm=llm)
+
+    chat_worker = UserChat(llm=llm_chat)
+    chat_adapter = ChatAdapter()
+
     response_publisher = ResponsePublisher()
     graph.add_workers(
         plan_worker,
@@ -355,16 +381,20 @@
         analysis_worker,
         analysis_joiner,
         final_narrative_worker,
+        chat_worker,
+        chat_adapter,
     )
     graph.set_dependency(plan_worker, response_publisher)
     graph.set_dependency(plan_worker, search_worker).next(split_worker).next(
         search_fetch_worker
     ).next(analysis_worker).next(analysis_joiner).next(final_narrative_worker).next(
         response_publisher
     )
+    graph.set_dependency(chat_worker, chat_adapter).next(response_publisher)
     graph.set_entry(plan_worker)
+    graph.set_entry(chat_worker)
     graph.set_sink(response_publisher, Response, notify=notify)
 
     # limit the amount of LLM calls we will do in parallel
     graph.set_max_parallel_tasks(LLMTaskWorker, 2 if provider == "ollama" else 6)
-    return graph, plan_worker
+    return graph, plan_worker, chat_worker
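Putting the wiring together: setup_graph now hands back two entry points. A minimal usage sketch follows, with invented metadata values; the ChatTask and add_work calls mirror the handler change in deepsearch.py above, and graph execution itself is started elsewhere (the graph thread in deepsearch.py).

# Usage sketch; the metadata values are invented.
graph, entry_worker, chat_worker = setup_graph(
    provider="ollama",
    model="llama3.3:latest",
    host="localhost:11434",
    notify=print,
)

# Research requests still enter through the plan worker (entry_worker);
# multi-turn conversations go straight to the chat path:
task = ChatTask(messages=[{"role": "user", "content": "hello"}])
provenance = chat_worker.add_work(
    task,
    metadata={"session_id": "abc123", "sid": "socket-1"},
)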
2 changes: 1 addition & 1 deletion examples/deepsearch/frontend/src/app.css
@@ -144,7 +144,7 @@
 }
 
 .audio-player-wrapper {
-  @apply fixed bottom-0 left-16 right-0 bg-white dark:bg-gray-800 border-t border-gray-200 dark:border-gray-700 p-4; /* Adjust left position */
+  @apply fixed bottom-0 left-16 right-0 bg-white dark:bg-gray-800 border-t border-gray-200 dark:border-gray-700 p-4;
 }
 
 .audio-player {
Expand Up @@ -140,9 +140,16 @@ Outgoing Events (sent):
};
messages = [...messages, userMessage];
// Convert messages to simplified format for backend
const messageHistory = messages.map(msg => ({
role: msg.role,
content: msg.content
}));
sessionState.socket?.emit('chat_message', {
session_id: sessionState.sessionId,
message: message
messages: messageHistory
});
}
