Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat:streaming support #13

Merged
merged 2 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions ovos_persona/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import os
from os.path import join, dirname
from typing import Optional, Dict, List, Union
from typing import Optional, Dict, List, Union, Iterable

from ovos_bus_client.client import MessageBusClient
from ovos_bus_client.message import Message, dig_for_message
Expand Down Expand Up @@ -51,6 +51,9 @@ def __repr__(self):
def chat(self, messages: list = None, lang: str = None) -> str:
return self.solvers.chat_completion(messages, lang)

def stream(self, messages: list = None, lang: str = None) -> Iterable[str]:
return self.solvers.stream_completion(messages, lang)


class PersonaService(PipelineStageConfidenceMatcher, OVOSAbstractApplication):

Expand Down Expand Up @@ -143,7 +146,8 @@ def deregister_persona(self, name):
def chatbox_ask(self, prompt: str,
persona: Optional[str] = None,
lang: Optional[str] = None,
message: Message = None) -> Optional[str]:
message: Message = None,
stream: bool = True) -> Iterable[str]:
persona = persona or self.active_persona or self.default_persona
if persona not in self.personas:
LOG.error(f"unknown persona, choose one of {self.personas.keys()}")
Expand All @@ -155,8 +159,12 @@ def chatbox_ask(self, prompt: str,
messages.append({"role": "user", "content": q})
messages.append({"role": "assistant", "content": a})
messages.append({"role": "user", "content": prompt})

return self.personas[persona].chat(messages, lang)
if stream:
yield from self.personas[persona].stream(messages, lang)
else:
ans = self.personas[persona].chat(messages, lang)
if ans:
yield ans

def _build_msg_history(self, message: Message):
sess = SessionManager.get(message)
Expand Down Expand Up @@ -249,7 +257,7 @@ def match_low(self, utterances: List[str], lang: Optional[str] = None,
return IntentHandlerMatch(match_type='persona:query',
match_data={"utterance": utterances[0],
"lang": lang,
"persona": self.active_persona},
"persona": self.active_persona or self.default_persona},
skill_id="persona.openvoiceos",
utterance=utterances[0])

Expand All @@ -276,12 +284,12 @@ def handle_persona_query(self, message):
self.speak_dialog("unknown_persona", {"persona": persona})
return

# TODO - streaming support
ans = self.chatbox_ask(utt, lang=lang, persona=persona)
if not ans:
self.speak_dialog("persona_error")
else:
handled = False
for ans in self.chatbox_ask(utt, lang=lang, persona=persona):
self.speak(ans)
handled = True
if not handled:
self.speak_dialog("persona_error")

def handle_persona_summon(self, message):
persona = message.data["persona"]
Expand All @@ -303,7 +311,8 @@ def handle_persona_release(self, message):
print(b.personas)

print(b.match_low(["what is the speed of light"]))

for ans in b.chatbox_ask("what is the speed of light"):
print(ans)
# The speed of light has a value of about 300 million meters per second
# The telephone was invented by Alexander Graham Bell
# Stephen William Hawking (8 January 1942 – 14 March 2018) was an English theoretical physicist, cosmologist, and author who, at the time of his death, was director of research at the Centre for Theoretical Cosmology at the University of Cambridge.
Expand Down
41 changes: 16 additions & 25 deletions ovos_persona/solvers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Optional, List, Dict
from typing import Optional, List, Dict, Iterable

from ovos_config import Configuration
from ovos_plugin_manager.solvers import find_question_solver_plugins
from ovos_utils.log import LOG
from ovos_utils.messagebus import FakeBus

try:
from ovos_plugin_manager.solvers import find_chat_solver_plugins
from ovos_plugin_manager.templates.solvers import ChatMessageSolver
Expand All @@ -12,6 +13,7 @@
class ChatMessageSolver:
pass


def find_chat_solver_plugins():
return {}

Expand Down Expand Up @@ -60,42 +62,31 @@ def shutdown(self):
pass

def chat_completion(self, messages: List[Dict[str, str]],
lang: Optional[str] = None,
units: Optional[str] = None) -> Optional[str]:
lang: Optional[str] = None,
units: Optional[str] = None) -> Optional[str]:
for module in self.modules:
try:
if isinstance(module, ChatMessageSolver):
ans = module.get_chat_completion(messages=messages,
lang=lang)
ans = module.get_chat_completion(messages=messages, lang=lang, units=units)
else:
LOG.debug(f"{module} does not supported chat history!")
query = messages[-1]["content"]
ans = module.spoken_answer(query, lang=lang)
ans = module.spoken_answer(query, lang=lang, units=units)
if ans:
return ans
except Exception as e:
LOG.error(e)
pass

def spoken_answer(self, query: str,
lang: Optional[str] = None,
units: Optional[str] = None) -> Optional[str]:
"""
Obtain the spoken answer for a given query.

Args:
query (str): The query text.
lang (Optional[str]): Optional language code. Defaults to None.
units (Optional[str]): Optional units for the query. Defaults to None.

Returns:
str: The spoken answer as a text response.
"""
def stream_completion(self, messages: List[Dict[str, str]],
lang: Optional[str] = None,
units: Optional[str] = None) -> Iterable[str]:
for module in self.modules:
try:
ans = module.spoken_answer(query, lang=lang)
if ans:
return ans
if isinstance(module, ChatMessageSolver):
yield from module.stream_chat_utterances(messages=messages, lang=lang, units=units)
else:
LOG.debug(f"{module} does not supported chat history!")
query = messages[-1]["content"]
yield from module.stream_utterances(query, lang=lang, units=units)
except Exception as e:
LOG.error(e)
pass
Loading