fix: mo.ui.chat config casing (#3690)
Fixes #3690

Use the correct (snake_case) casing for mo.ui.chat config keys.
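In practice this means the keys of the config dict passed to mo.ui.chat are snake_case end to end: they are defined by the backend and, as the new comment in types.ts notes, are not modified when sent to the frontend. Previously the frontend schema and UI looked them up in camelCase (maxTokens, topP, ...), so user-supplied values could be ignored in favor of the schema defaults. A minimal usage sketch, mirroring the smoke test updated below (assumes an OpenAI API key is configured):

```python
import marimo as mo

# Config keys are snake_case; before this fix the frontend expected
# camelCase (e.g. maxTokens) and would not pick these values up.
chat = mo.ui.chat(
    mo.ai.llm.openai(
        "gpt-4-turbo", system_message="You are a helpful data scientist"
    ),
    show_configuration_controls=True,
    config={"max_tokens": 20},
)
```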
mscolnick authored Feb 4, 2025
1 parent 58098bc commit 39e72f5
Showing 4 changed files with 44 additions and 39 deletions.
10 changes: 5 additions & 5 deletions frontend/src/plugins/impl/chat/ChatPlugin.tsx
@@ -20,12 +20,12 @@ export const ChatPlugin = createPlugin<ChatMessage[]>("marimo-chatbot")
       showConfigurationControls: z.boolean(),
       maxHeight: z.number().optional(),
       config: z.object({
-        maxTokens: z.number().default(100),
+        max_tokens: z.number().default(100),
         temperature: z.number().default(0.5),
-        topP: z.number().default(1),
-        topK: z.number().default(40),
-        frequencyPenalty: z.number().default(0),
-        presencePenalty: z.number().default(0),
+        top_p: z.number().default(1),
+        top_k: z.number().default(40),
+        frequency_penalty: z.number().default(0),
+        presence_penalty: z.number().default(0),
       }),
       allowAttachments: z.union([z.boolean(), z.string().array()]),
     }),
20 changes: 10 additions & 10 deletions frontend/src/plugins/impl/chat/chat-ui.tsx
@@ -100,12 +100,12 @@ export const Chatbot: React.FC<Props> = (props) => {
           attachments: m.experimental_attachments,
         })),
         config: {
-          max_tokens: config.maxTokens,
+          max_tokens: config.max_tokens,
           temperature: config.temperature,
-          top_p: config.topP,
-          top_k: config.topK,
-          frequency_penalty: config.frequencyPenalty,
-          presence_penalty: config.presencePenalty,
+          top_p: config.top_p,
+          top_k: config.top_k,
+          frequency_penalty: config.frequency_penalty,
+          presence_penalty: config.presence_penalty,
         },
       });
       return new Response(response);
@@ -439,7 +439,7 @@ const configDescriptions: Record<
   keyof ChatConfig,
   { min: number; max: number; description: string; step?: number }
 > = {
-  maxTokens: {
+  max_tokens: {
     min: 1,
     max: 4096,
     description: "Maximum number of tokens to generate",
@@ -450,24 +450,24 @@ const configDescriptions: Record<
     step: 0.1,
     description: "Controls randomness (0: deterministic, 2: very random)",
   },
-  topP: {
+  top_p: {
     min: 0,
     max: 1,
     step: 0.1,
     description: "Nucleus sampling: probability mass to consider",
   },
-  topK: {
+  top_k: {
     min: 1,
     max: 100,
     description:
       "Top-k sampling: number of highest probability tokens to consider",
   },
-  frequencyPenalty: {
+  frequency_penalty: {
     min: -2,
     max: 2,
     description: "Penalizes frequent tokens (-2: favor, 2: avoid)",
   },
-  presencePenalty: {
+  presence_penalty: {
     min: -2,
     max: 2,
     description: "Penalizes new tokens (-2: favor, 2: avoid)",
14 changes: 9 additions & 5 deletions frontend/src/plugins/impl/chat/types.ts
@@ -25,11 +25,15 @@ export interface SendMessageRequest {
   };
 }
 
+/**
+ * These are snake_case because they come from the backend,
+ * and are not modified when sent to the frontend.
+ */
 export interface ChatConfig {
-  maxTokens: number;
+  max_tokens: number;
   temperature: number;
-  topP: number;
-  topK: number;
-  frequencyPenalty: number;
-  presencePenalty: number;
+  top_p: number;
+  top_k: number;
+  frequency_penalty: number;
+  presence_penalty: number;
 }
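For reference, ChatConfig above maps one-to-one onto the dict accepted by mo.ui.chat's config parameter. A hedged sketch of a fully specified config, using the defaults from the Zod schema in ChatPlugin.tsx and the ranges documented in configDescriptions in chat-ui.tsx (the temperature bounds are inferred from its description text and are an assumption):

```python
# All ChatConfig keys with the frontend schema defaults; the ranges in
# the comments come from configDescriptions in chat-ui.tsx.
full_config = {
    "max_tokens": 100,       # 1 to 4096
    "temperature": 0.5,      # 0: deterministic, 2: very random (assumed bounds)
    "top_p": 1,              # 0 to 1, nucleus sampling
    "top_k": 40,             # 1 to 100, top-k sampling
    "frequency_penalty": 0,  # -2 to 2
    "presence_penalty": 0,   # -2 to 2
}

# Hypothetical usage: mo.ui.chat(model, config=full_config)
```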
39 changes: 20 additions & 19 deletions marimo/_smoke_tests/chat/chatbot.py
@@ -11,35 +11,36 @@
 
 import marimo
 
-__generated_with = "0.9.14"
+__generated_with = "0.11.0"
 app = marimo.App(width="medium")
 
 
 @app.cell(hide_code=True)
-def __():
+def _():
     import marimo as mo
     return (mo,)
 
 
 @app.cell(hide_code=True)
-def __(mo):
+def _(mo):
     mo.md(r"""# Built-in chatbots""")
     return
 
 
 @app.cell(hide_code=True)
-def __(mo):
+def _(mo):
     mo.md(r"""## OpenAI""")
     return
 
 
 @app.cell
-def __(mo):
+def _(mo):
     mo.ui.chat(
         mo.ai.llm.openai(
             "gpt-4-turbo", system_message="You are a helpful data scientist"
         ),
         show_configuration_controls=True,
+        config={"max_tokens": 20},
         prompts=[
             "Tell me a joke",
             "What is the meaning of life?",
@@ -50,13 +51,13 @@ def __(mo):
 
 
 @app.cell(hide_code=True)
-def __(mo):
+def _(mo):
     mo.md(r"""## Anthropic""")
     return
 
 
 @app.cell
-def __(mo):
+def _(mo):
     mo.ui.chat(
         mo.ai.llm.anthropic("claude-3-5-sonnet-20240620"),
         show_configuration_controls=True,
@@ -70,13 +71,13 @@ def __(mo):
 
 
 @app.cell(hide_code=True)
-def __(mo):
+def _(mo):
     mo.md(r"""## Google Gemini""")
     return
 
 
 @app.cell
-def __(mo):
+def _(mo):
     mo.ui.chat(
         mo.ai.llm.google("gemini-1.5-pro-001"),
         show_configuration_controls=True,
@@ -90,13 +91,13 @@ def __(mo):
 
 
 @app.cell(hide_code=True)
-def __(mo):
+def _(mo):
     mo.md(r"""# Custom chatbots""")
     return
 
 
 @app.cell(hide_code=True)
-def __(mo):
+def _(mo):
     import os
 
     os_key = os.environ.get("OPENAI_API_KEY")
@@ -106,13 +107,13 @@ def __(mo):
 
 
 @app.cell
-def __(input_key, os_key):
+def _(input_key, os_key):
     openai_key = os_key or input_key.value
     return (openai_key,)
 
 
 @app.cell(hide_code=True)
-def __(mo, openai_key):
+def _(mo, openai_key):
     # Initialize a client
     mo.stop(
         not openai_key,
@@ -128,13 +129,13 @@ def __(mo, openai_key):
 
 
 @app.cell(hide_code=True)
-def __(mo):
+def _(mo):
     mo.md(r"""## Simple""")
     return
 
 
 @app.cell
-def __(client, ell, mo):
+def _(client, ell, mo):
     @ell.simple("gpt-4o-mini-2024-07-18", client=client)
     def _my_model(prompt):
         """You are an annoying little brother, whatever I say, be sassy with your response"""
@@ -146,13 +147,13 @@ def _my_model(prompt):
 
 
 @app.cell(hide_code=True)
-def __(mo):
+def _(mo):
     mo.md(r"""## Complex""")
     return
 
 
 @app.cell
-def __():
+def _():
     # Grab a dataset for the chatbot conversation, we will use the cars dataset
 
     from vega_datasets import data
@@ -162,7 +163,7 @@ def __():
 
 
 @app.cell
-def __(cars, client, ell):
+def _(cars, client, ell):
     from pydantic import BaseModel, Field
 
 
@@ -236,7 +237,7 @@ def chat_bot(message_history):
 
 
 @app.cell
-def __(cars, get_sample_prompts, mo, my_complex_model):
+def _(cars, get_sample_prompts, mo, my_complex_model):
     prompts = get_sample_prompts(cars).parsed.prompts
     mo.ui.chat(
         my_complex_model,
