-
Notifications
You must be signed in to change notification settings - Fork 379
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
improvement: pretty gemini responses in output, streaming / non-strea…
…ming (#3705)
- Loading branch information
Showing
4 changed files
with
203 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Copyright 2024 Marimo. All rights reserved. | ||
from __future__ import annotations | ||
|
||
from marimo._messaging.mimetypes import KnownMimeType | ||
from marimo._output import md | ||
from marimo._output.formatters.formatter_factory import FormatterFactory | ||
from marimo._runtime import output | ||
|
||
|
||
class GoogleAiFormatter(FormatterFactory): | ||
@staticmethod | ||
def package_name() -> str: | ||
return "google" | ||
|
||
def register(self) -> None: | ||
try: | ||
import google.generativeai as genai # type: ignore | ||
except (ImportError, ModuleNotFoundError): | ||
return | ||
|
||
from marimo._output import formatting | ||
|
||
@formatting.formatter(genai.types.GenerateContentResponse) | ||
def _show_response( | ||
response: genai.types.GenerateContentResponse, | ||
) -> tuple[KnownMimeType, str]: | ||
if hasattr(response, "_iterator") and response._iterator is None: | ||
return ("text/html", md.md(response.text).text) | ||
else: | ||
# Streaming response | ||
total_text = "" | ||
for chunk in response: | ||
total_text += chunk.text | ||
output.replace( | ||
md.md(_ensure_closing_code_fence(total_text)) | ||
) | ||
return ("text/html", md.md(total_text).text) | ||
|
||
|
||
def _ensure_closing_code_fence(text: str) -> str: | ||
"""Ensure text has an even number of code fences | ||
If text ends with an unclosed code fence, add a closing fence. | ||
Handles nested code fences by checking if the last fence is an opening one. | ||
""" | ||
# Split by code fences to track nesting | ||
parts = text.split("```") | ||
# If odd number of parts, we have an unclosed fence | ||
# parts = ["before", "code", "between", "more code"] -> 4 parts = 3 fences | ||
# parts = ["before", "code", "between", "more code", "after"] -> 5 parts = 4 fences | ||
if len(parts) > 1 and len(parts) % 2 == 0: | ||
return text + "\n```" | ||
return text |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import marimo | ||
|
||
__generated_with = "0.11.0" | ||
app = marimo.App(width="medium") | ||
|
||
|
||
@app.cell | ||
def _(): | ||
import marimo as mo | ||
import os | ||
return mo, os | ||
|
||
|
||
@app.cell(hide_code=True) | ||
def _(mo): | ||
mo.md(r"""## Setup""") | ||
return | ||
|
||
|
||
@app.cell(hide_code=True) | ||
def _(mo, os): | ||
api_token = mo.ui.text( | ||
label="Gemini API Token", value=os.environ.get("GEMINI_API_KEY") or "" | ||
) | ||
api_token | ||
return (api_token,) | ||
|
||
|
||
@app.cell(hide_code=True) | ||
def _(api_token, mo): | ||
mo.callout("Missing API Key", kind="danger") if not api_token.value else None | ||
return | ||
|
||
|
||
@app.cell(hide_code=True) | ||
def _(mo): | ||
tools = mo.ui.multiselect(["code_execution"], value=[], label="Tools") | ||
tools | ||
return (tools,) | ||
|
||
|
||
@app.cell(hide_code=True) | ||
def _(mo): | ||
mo.md("""## Run some queries""") | ||
return | ||
|
||
|
||
@app.cell(hide_code=True) | ||
def _(mo): | ||
mo.md("""### Streaming""") | ||
return | ||
|
||
|
||
@app.cell | ||
def _(api_token, tools): | ||
import google.generativeai as genai | ||
|
||
genai.configure(api_key=api_token.value) | ||
model = genai.GenerativeModel( | ||
model_name="gemini-2.0-flash-exp", tools=tools.value | ||
) | ||
|
||
model.generate_content( | ||
"Create a function that takes a list of numbers and returns the sum of all the numbers in the list.", | ||
stream=True, | ||
) | ||
return genai, model | ||
|
||
|
||
@app.cell(hide_code=True) | ||
def _(mo): | ||
mo.md(r"""### Non-streaming""") | ||
return | ||
|
||
|
||
@app.cell | ||
def _(model): | ||
model.generate_content( | ||
"Create a function that takes a list of numbers and returns the sum of all the numbers in the list." | ||
) | ||
return | ||
|
||
|
||
if __name__ == "__main__": | ||
app.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from __future__ import annotations | ||
|
||
from unittest.mock import MagicMock, call, patch | ||
|
||
import pytest | ||
|
||
from marimo._dependencies.dependencies import DependencyManager | ||
from marimo._output.formatters.ai_formatters import GoogleAiFormatter | ||
from marimo._output.formatting import get_formatter | ||
from marimo._output.hypertext import Html | ||
|
||
|
||
@pytest.mark.skipif( | ||
not DependencyManager.google_ai.has(), reason="Google AI is not installed" | ||
) | ||
@patch( | ||
"marimo._output.formatters.ai_formatters.md.md", | ||
) | ||
def test_register_with_dummy_google(mock_md: MagicMock): | ||
GoogleAiFormatter().register() | ||
import google.generativeai as genai | ||
|
||
mock_md.side_effect = lambda x: Html(f"<md>{x}</md>") | ||
|
||
mock_response = MagicMock(genai.types.GenerateContentResponse) | ||
mock_response.text = "# Hello" | ||
mock_response._iterator = None | ||
mock_response.candidates = [ | ||
MagicMock(content="# Hello", index=0, finish_reason="STOP") | ||
] | ||
formatter = get_formatter(mock_response) | ||
assert formatter is not None | ||
result = formatter(mock_response) | ||
assert result == ( | ||
"text/html", | ||
"<md># Hello</md>", | ||
) | ||
# Verify md.md was called with the normal response text | ||
mock_md.assert_called_with("# Hello") | ||
|
||
# Streaming response | ||
mock_response._iterator = iter( | ||
[ | ||
MagicMock( | ||
text="```python\ndef foo():\n", | ||
index=0, | ||
finish_reason="INCOMPLETE", | ||
), | ||
MagicMock(text=" pass\n```", index=1, finish_reason="STOP"), | ||
] | ||
) | ||
mock_response.__iter__ = lambda self: iter(self._iterator) | ||
|
||
# Reset the mock before the streaming call | ||
mock_md.reset_mock() | ||
result = formatter(mock_response) | ||
# Check that md.md was called 3 times | ||
assert mock_md.call_count == 3 | ||
assert mock_md.call_args_list == [ | ||
call("```python\ndef foo():\n\n```"), # First chunk | ||
call("```python\ndef foo():\n pass\n```"), # First + second chunk | ||
call("```python\ndef foo():\n pass\n```"), # End result | ||
] |