diff --git a/shiny/ui/_chat.py b/shiny/ui/_chat.py
index 99a399109..98f1defad 100644
--- a/shiny/ui/_chat.py
+++ b/shiny/ui/_chat.py
@@ -507,7 +507,12 @@ async def append_message(self, message: Any) -> None:
         await self._append_message(message)
 
     async def _append_message(
-        self, message: Any, *, chunk: ChunkOption = False, stream_id: str | None = None
+        self,
+        message: Any,
+        *,
+        chunk: ChunkOption = False,
+        stream_id: str | None = None,
+        normalizer: Callable[[object], str] | None = None,
     ) -> None:
         # If currently we're in a stream, handle other messages (outside the stream) later
         if not self._can_append_message(stream_id):
@@ -519,6 +524,15 @@ async def _append_message(
         if chunk == "end":
             self._current_stream_id = None
 
+        # Apply the user provided normalizer, if any
+        if normalizer is not None:
+            res = normalizer(message)
+            if not isinstance(res, str):
+                raise ValueError(
+                    f"Normalizer function must return a string, got {type(res)}"
+                )
+            message = {"content": res, "role": "assistant"}
+
         if chunk is False:
             msg = normalize_message(message)
             chunk_content = None
@@ -539,7 +553,11 @@ async def _append_message(
         msg = self._store_message(msg, chunk=chunk)
         await self._send_append_message(msg, chunk=chunk)
 
-    async def append_message_stream(self, message: Iterable[Any] | AsyncIterable[Any]):
+    async def append_message_stream(
+        self,
+        message: Iterable[Any] | AsyncIterable[Any],
+        normalizer: Callable[[object], str] | None = None,
+    ) -> None:
         """
         Append a message as a stream of message chunks.
 
@@ -550,6 +568,11 @@ async def append_message_stream(self, message: Iterable[Any] | AsyncIterable[Any
             message chunk formats are supported, including a string, a dictionary with
             `content` and `role` keys, or a relevant chat completion object from
             platforms like OpenAI, Anthropic, Ollama, and others.
+        normalizer
+            A function to apply to each message chunk (i.e., each item of the `message`
+            iterator) before appending it to the chat. This is useful for handling
+            response formats that `Chat` may not already natively support. The function
+            should take a message chunk and return a string.
 
         Note
         ----
@@ -562,7 +585,7 @@ async def append_message_stream(self, message: Iterable[Any] | AsyncIterable[Any
         # Run the stream in the background to get non-blocking behavior
         @reactive.extended_task
         async def _stream_task():
-            await self._append_message_stream(message)
+            await self._append_message_stream(message, normalizer)
 
         _stream_task()
 
@@ -582,7 +605,11 @@ async def _handle_error():
             ctx.on_invalidate(_handle_error.destroy)
             self._effects.append(_handle_error)
 
-    async def _append_message_stream(self, message: AsyncIterable[Any]):
+    async def _append_message_stream(
+        self,
+        message: AsyncIterable[Any],
+        normalizer: Callable[[object], str] | None = None,
+    ) -> None:
         id = _utils.private_random_id()
 
         empty = ChatMessage(content="", role="assistant")
@@ -590,7 +617,9 @@ async def _append_message_stream(self, message: AsyncIterable[Any]):
 
         try:
             async for msg in message:
-                await self._append_message(msg, chunk=True, stream_id=id)
+                await self._append_message(
+                    msg, chunk=True, stream_id=id, normalizer=normalizer
+                )
         finally:
             await self._append_message(empty, chunk="end", stream_id=id)
             await self._flush_pending_messages()
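
A minimal usage sketch of the new `normalizer` hook (not part of the patch): the stream source and the `{"delta": {"text": ...}}` chunk shape below are hypothetical stand-ins for a provider response format that `Chat` doesn't natively support.

```python
from shiny.express import ui

chat = ui.Chat(id="chat")
chat.ui()


async def fake_stream():
    # Hypothetical stand-in for a provider whose chunk format
    # Chat doesn't recognize out of the box.
    for word in ["Hello", ", ", "world", "!"]:
        yield {"delta": {"text": word}}


def normalize_chunk(chunk: object) -> str:
    # Pull the text out of the hypothetical {"delta": {"text": ...}} shape.
    # Per the patch, _append_message raises ValueError if this returns a non-str.
    assert isinstance(chunk, dict)
    return chunk["delta"]["text"]


@chat.on_user_submit
async def _():
    # Each chunk of fake_stream() is passed through normalize_chunk()
    # before being appended to the chat.
    await chat.append_message_stream(fake_stream(), normalizer=normalize_chunk)
```

Since `normalizer` is applied per chunk inside `_append_message`, the same hook works for both one-shot and streamed messages without the caller having to pre-transform the iterable.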