Skip to content

Commit

Permalink
improvement: pretty gemini responses in output, streaming / non-strea…
Browse files Browse the repository at this point in the history
…ming (#3705)
  • Loading branch information
mscolnick authored Feb 6, 2025
1 parent 9ca9634 commit e504729
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 0 deletions.
53 changes: 53 additions & 0 deletions marimo/_output/formatters/ai_formatters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2024 Marimo. All rights reserved.
from __future__ import annotations

from marimo._messaging.mimetypes import KnownMimeType
from marimo._output import md
from marimo._output.formatters.formatter_factory import FormatterFactory
from marimo._runtime import output


class GoogleAiFormatter(FormatterFactory):
@staticmethod
def package_name() -> str:
return "google"

def register(self) -> None:
try:
import google.generativeai as genai # type: ignore
except (ImportError, ModuleNotFoundError):
return

from marimo._output import formatting

@formatting.formatter(genai.types.GenerateContentResponse)
def _show_response(
response: genai.types.GenerateContentResponse,
) -> tuple[KnownMimeType, str]:
if hasattr(response, "_iterator") and response._iterator is None:
return ("text/html", md.md(response.text).text)
else:
# Streaming response
total_text = ""
for chunk in response:
total_text += chunk.text
output.replace(
md.md(_ensure_closing_code_fence(total_text))
)
return ("text/html", md.md(total_text).text)


def _ensure_closing_code_fence(text: str) -> str:
"""Ensure text has an even number of code fences
If text ends with an unclosed code fence, add a closing fence.
Handles nested code fences by checking if the last fence is an opening one.
"""
# Split by code fences to track nesting
parts = text.split("```")
# If odd number of parts, we have an unclosed fence
# parts = ["before", "code", "between", "more code"] -> 4 parts = 3 fences
# parts = ["before", "code", "between", "more code", "after"] -> 5 parts = 4 fences
if len(parts) > 1 and len(parts) % 2 == 0:
return text + "\n```"
return text
2 changes: 2 additions & 0 deletions marimo/_output/formatters/formatters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Callable, Sequence

from marimo._config.config import Theme
from marimo._output.formatters.ai_formatters import GoogleAiFormatter
from marimo._output.formatters.altair_formatters import AltairFormatter
from marimo._output.formatters.anywidget_formatters import AnyWidgetFormatter
from marimo._output.formatters.arviz_formatters import ArviZFormatter
Expand Down Expand Up @@ -55,6 +56,7 @@
SympyFormatter.package_name(): SympyFormatter(),
PyechartsFormatter.package_name(): PyechartsFormatter(),
PanelFormatter.package_name(): PanelFormatter(),
GoogleAiFormatter.package_name(): GoogleAiFormatter(),
}

# Formatters for builtin types and other things that don't require a
Expand Down
85 changes: 85 additions & 0 deletions marimo/_smoke_tests/ai/gemini_responses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import marimo

__generated_with = "0.11.0"
app = marimo.App(width="medium")


@app.cell
def _():
import marimo as mo
import os
return mo, os


@app.cell(hide_code=True)
def _(mo):
mo.md(r"""## Setup""")
return


@app.cell(hide_code=True)
def _(mo, os):
api_token = mo.ui.text(
label="Gemini API Token", value=os.environ.get("GEMINI_API_KEY") or ""
)
api_token
return (api_token,)


@app.cell(hide_code=True)
def _(api_token, mo):
mo.callout("Missing API Key", kind="danger") if not api_token.value else None
return


@app.cell(hide_code=True)
def _(mo):
tools = mo.ui.multiselect(["code_execution"], value=[], label="Tools")
tools
return (tools,)


@app.cell(hide_code=True)
def _(mo):
mo.md("""## Run some queries""")
return


@app.cell(hide_code=True)
def _(mo):
mo.md("""### Streaming""")
return


@app.cell
def _(api_token, tools):
import google.generativeai as genai

genai.configure(api_key=api_token.value)
model = genai.GenerativeModel(
model_name="gemini-2.0-flash-exp", tools=tools.value
)

model.generate_content(
"Create a function that takes a list of numbers and returns the sum of all the numbers in the list.",
stream=True,
)
return genai, model


@app.cell(hide_code=True)
def _(mo):
mo.md(r"""### Non-streaming""")
return


@app.cell
def _(model):
model.generate_content(
"Create a function that takes a list of numbers and returns the sum of all the numbers in the list."
)
return


if __name__ == "__main__":
app.run()
63 changes: 63 additions & 0 deletions tests/_output/formatters/test_ai_formatters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from __future__ import annotations

from unittest.mock import MagicMock, call, patch

import pytest

from marimo._dependencies.dependencies import DependencyManager
from marimo._output.formatters.ai_formatters import GoogleAiFormatter
from marimo._output.formatting import get_formatter
from marimo._output.hypertext import Html


@pytest.mark.skipif(
not DependencyManager.google_ai.has(), reason="Google AI is not installed"
)
@patch(
"marimo._output.formatters.ai_formatters.md.md",
)
def test_register_with_dummy_google(mock_md: MagicMock):
GoogleAiFormatter().register()
import google.generativeai as genai

mock_md.side_effect = lambda x: Html(f"<md>{x}</md>")

mock_response = MagicMock(genai.types.GenerateContentResponse)
mock_response.text = "# Hello"
mock_response._iterator = None
mock_response.candidates = [
MagicMock(content="# Hello", index=0, finish_reason="STOP")
]
formatter = get_formatter(mock_response)
assert formatter is not None
result = formatter(mock_response)
assert result == (
"text/html",
"<md># Hello</md>",
)
# Verify md.md was called with the normal response text
mock_md.assert_called_with("# Hello")

# Streaming response
mock_response._iterator = iter(
[
MagicMock(
text="```python\ndef foo():\n",
index=0,
finish_reason="INCOMPLETE",
),
MagicMock(text=" pass\n```", index=1, finish_reason="STOP"),
]
)
mock_response.__iter__ = lambda self: iter(self._iterator)

# Reset the mock before the streaming call
mock_md.reset_mock()
result = formatter(mock_response)
# Check that md.md was called 3 times
assert mock_md.call_count == 3
assert mock_md.call_args_list == [
call("```python\ndef foo():\n\n```"), # First chunk
call("```python\ndef foo():\n pass\n```"), # First + second chunk
call("```python\ndef foo():\n pass\n```"), # End result
]

0 comments on commit e504729

Please sign in to comment.