Commit

Close #1621. Add a normalizer function argument to Chat.append_message_stream()

cpsievert committed Aug 21, 2024
1 parent c97f09b commit eb8a1ff
Showing 1 changed file with 34 additions and 5 deletions.
shiny/ui/_chat.py (39 changes: 34 additions & 5 deletions)
@@ -507,7 +507,12 @@ async def append_message(self, message: Any) -> None:
         await self._append_message(message)
 
     async def _append_message(
-        self, message: Any, *, chunk: ChunkOption = False, stream_id: str | None = None
+        self,
+        message: Any,
+        *,
+        chunk: ChunkOption = False,
+        stream_id: str | None = None,
+        normalizer: Callable[[object], str] | None = None,
     ) -> None:
         # If we're currently in a stream, handle other messages (outside the stream) later
         if not self._can_append_message(stream_id):
@@ -519,6 +524,15 @@ async def _append_message(
         if chunk == "end":
             self._current_stream_id = None
 
+        # Apply the user-provided normalizer, if any
+        if normalizer is not None:
+            res = normalizer(message)
+            if not isinstance(res, str):
+                raise ValueError(
+                    f"Normalizer function must return a string, got {type(res)}"
+                )
+            message = {"content": res, "role": "assistant"}
+
         if chunk is False:
             msg = normalize_message(message)
             chunk_content = None
@@ -539,7 +553,11 @@ async def _append_message(
         msg = self._store_message(msg, chunk=chunk)
         await self._send_append_message(msg, chunk=chunk)
 
-    async def append_message_stream(self, message: Iterable[Any] | AsyncIterable[Any]):
+    async def append_message_stream(
+        self,
+        message: Iterable[Any] | AsyncIterable[Any],
+        normalizer: Callable[[object], str] | None = None,
+    ) -> None:
         """
         Append a message as a stream of message chunks.
 
@@ -550,6 +568,11 @@ async def append_message_stream(self, message: Iterable[Any] | AsyncIterable[Any
             message chunk formats are supported, including a string, a dictionary with
             `content` and `role` keys, or a relevant chat completion object from
             platforms like OpenAI, Anthropic, Ollama, and others.
+        normalizer
+            A function to apply to each message chunk (i.e., each item of the `message`
+            iterator) before appending it to the chat. This is useful for handling
+            response formats that `Chat` may not natively support. The function
+            should take a message chunk and return a string.
 
         Note
         ----
@@ -562,7 +585,7 @@ async def append_message_stream(self, message: Iterable[Any] | AsyncIterable[Any
         # Run the stream in the background to get non-blocking behavior
         @reactive.extended_task
         async def _stream_task():
-            await self._append_message_stream(message)
+            await self._append_message_stream(message, normalizer)
 
         _stream_task()
 
@@ -582,15 +605,21 @@ async def _handle_error():
         ctx.on_invalidate(_handle_error.destroy)
         self._effects.append(_handle_error)
 
-    async def _append_message_stream(self, message: AsyncIterable[Any]):
+    async def _append_message_stream(
+        self,
+        message: AsyncIterable[Any],
+        normalizer: Callable[[object], str] | None = None,
+    ) -> None:
         id = _utils.private_random_id()
 
         empty = ChatMessage(content="", role="assistant")
         await self._append_message(empty, chunk="start", stream_id=id)
 
         try:
             async for msg in message:
-                await self._append_message(msg, chunk=True, stream_id=id)
+                await self._append_message(
+                    msg, chunk=True, stream_id=id, normalizer=normalizer
+                )
         finally:
             await self._append_message(empty, chunk="end", stream_id=id)
             await self._flush_pending_messages()
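For illustration, here is a minimal sketch of how the new normalizer argument might be used from a Shiny Express app. The fake_stream generator and its {"text": ...} chunk shape are invented for this example; they stand in for any streaming response format that Chat does not natively recognize. The normalizer is called once per chunk and must return a str, otherwise _append_message raises a ValueError.

    from shiny.express import ui

    chat = ui.Chat(id="chat")
    chat.ui()

    # Hypothetical stream whose chunks Chat can't interpret on its own:
    # each item is a dict with a "text" key rather than a string or a
    # {"content": ..., "role": ...} dict.
    async def fake_stream():
        for word in ["Hello", ", ", "world", "!"]:
            yield {"text": word}

    @chat.on_user_submit
    async def _():
        # The normalizer maps each chunk to the string Chat expects;
        # returning anything other than a str raises a ValueError.
        await chat.append_message_stream(
            fake_stream(),
            normalizer=lambda chunk: chunk["text"],
        )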
