diff --git a/frontend/src/components/shortcuts/renderShortcut.tsx b/frontend/src/components/shortcuts/renderShortcut.tsx index 3e6e1757ab8..e86333e0c3e 100644 --- a/frontend/src/components/shortcuts/renderShortcut.tsx +++ b/frontend/src/components/shortcuts/renderShortcut.tsx @@ -108,6 +108,15 @@ function prettyPrintHotkey(key: string): [label: string, symbol?: string] { return [lowerKey]; } +export function getSymbol(key: string) { + const platform = isPlatformMac() ? "mac" : "default"; + const keyData = KEY_MAPPINGS[key.toLowerCase()]; + if (keyData) { + return keyData.symbols[platform] || keyData.symbols.default; + } + return key; +} + interface KeyData { symbols: { mac?: string; @@ -190,6 +199,10 @@ const KEY_MAPPINGS: Record = { symbols: { mac: "↘", default: "End" }, label: "End", }, + mod: { + symbols: { mac: "⌘", windows: "⊞ Win", default: "Ctrl" }, + label: "Control", + }, }; function capitalize(str: string) { diff --git a/frontend/src/core/codemirror/cm.ts b/frontend/src/core/codemirror/cm.ts index b85ed8e985b..663d10cb040 100644 --- a/frontend/src/core/codemirror/cm.ts +++ b/frontend/src/core/codemirror/cm.ts @@ -56,7 +56,9 @@ import { historyCompartment } from "./editing/extensions"; import { goToDefinitionBundle } from "./go-to-definition/extension"; import type { HotkeyProvider } from "../hotkeys/hotkeys"; import { lightTheme } from "./theme/light"; - +import { promptPlugin } from "./prompt/prompt"; +import { requestEditCompletion } from "./prompt/request"; +import { getCurrentLanguageAdapter } from "./language/commands"; export interface CodeMirrorSetupOpts { cellId: CellId; showPlaceholder: boolean; @@ -79,6 +81,7 @@ export const setupCodeMirror = (opts: CodeMirrorSetupOpts): Extension[] => { cellCodeCallbacks, keymapConfig, hotkeys, + enableAI, } = opts; return [ @@ -91,6 +94,20 @@ export const setupCodeMirror = (opts: CodeMirrorSetupOpts): Extension[] => { basicBundle(opts), // Underline cmd+clickable placeholder goToDefinitionBundle(), + // AI prompt edit + enableAI + ? promptPlugin({ + complete: (req) => { + return requestEditCompletion({ + prompt: req.prompt, + selection: req.selection, + codeBefore: req.codeBefore, + code: req.editorView.state.doc.toString(), + language: getCurrentLanguageAdapter(req.editorView), + }); + }, + }) + : [], ]; }; diff --git a/frontend/src/core/codemirror/prompt/complete.ts b/frontend/src/core/codemirror/prompt/complete.ts new file mode 100644 index 00000000000..6075a188b84 --- /dev/null +++ b/frontend/src/core/codemirror/prompt/complete.ts @@ -0,0 +1,51 @@ +/* Copyright 2024 Marimo. All rights reserved. */ + +import { API } from "@/core/network/api"; +import { asURL } from "@/utils/url"; +import type { LanguageAdapterType } from "../language/types"; +import { getCodes } from "../copilot/getCodes"; + +/** + * Request to edit code with AI + */ +export async function requestEditCompletion(opts: { + prompt: string; + selection: string; + code: string; + codeBefore: string; + language: LanguageAdapterType; +}): Promise { + const currentCode = opts.code; + + const otherCodes = getCodes(currentCode); + // Other code to include is the codeBefore and the other codes + const includeOtherCode = `${opts.codeBefore}\n${otherCodes}`; + + const response = await fetch(asURL("api/ai/completion").toString(), { + method: "POST", + headers: API.headers(), + body: JSON.stringify({ + prompt: opts.prompt, + code: opts.selection, + includeOtherCode: includeOtherCode, + language: opts.language, + }), + }); + + const reader = response.body?.getReader(); + if (!reader) { + throw new Error("Failed to get response reader"); + } + + let result = ""; + // eslint-disable-next-line no-constant-condition + while (true) { + const { done, value } = await reader.read(); + if (done) { + break; + } + result += new TextDecoder().decode(value); + } + + return result; +} diff --git a/frontend/src/core/codemirror/prompt/prompt.ts b/frontend/src/core/codemirror/prompt/prompt.ts new file mode 100644 index 00000000000..50ce1e79e35 --- /dev/null +++ b/frontend/src/core/codemirror/prompt/prompt.ts @@ -0,0 +1,523 @@ +/* Copyright 2024 Marimo. All rights reserved. */ +import "./styles.css"; +import { + EditorView, + Decoration, + type DecorationSet, + ViewPlugin, + type ViewUpdate, + WidgetType, + keymap, +} from "@codemirror/view"; +import { + StateField, + StateEffect, + EditorSelection, + type Range, + Prec, + type Extension, +} from "@codemirror/state"; +import { getSymbol } from "../../../components/shortcuts/renderShortcut"; +import { Logger } from "@/utils/Logger"; +import { toast } from "@/components/ui/use-toast"; + +export type CompleteFunction = (opts: { + prompt: string; + editorView: EditorView; + selection: string; + codeBefore: string; + codeAfter: string; +}) => Promise; + +/** + * The prompt extension creates a plugin for CodeMirror that adds AI-assisted editing functionality. + * + * Key features and behaviors: + * 1. Cmd+L Tooltip: + * - Appears when text is selected. + * - Hidden when: + * a) No text is selected + * b) Editing instructions input is open + * c) A completion is pending (not yet accepted or declined) + * + * 2. Editing Instructions Input: + * - Opened with Cmd+L when text is selected + * - Automatically focused when opened + * - Closed with Escape key + * + * 3. Completions: + * - Displayed after submitting editing instructions + * - Can be accepted with Cmd+Y or the accept button + * - Can be rejected with Cmd+U or the reject button + * + * 4. States: + * - tooltipState: Manages visibility of Cmd+L tooltip + * - inputState: Manages visibility and position of editing instructions input + * - completionState: Manages the current completion (if any) + * + * This plugin ensures that only one of these elements (tooltip, input, or completion) + * is visible at any given time, providing a clear and focused user experience. + */ +export function promptPlugin(opts: { + complete: CompleteFunction; +}): Extension[] { + const { complete } = opts; + return [ + tooltipState, + inputState, + completionState, + loadingState, + selectionPlugin, + keymap.of([ + { + key: "Mod-l", + run: showInputPrompt, + }, + ]), + Prec.highest([ + keymap.of([ + { key: "Mod-y", run: acceptCompletion }, + { key: "Mod-u", run: rejectCompletion }, + ]), + ]), + EditorView.updateListener.of((update) => { + if (update.selectionSet) { + const { from, to } = update.state.selection.main; + const inputStateValue = update.state.field(inputState); + const completionStateValue = update.state.field(completionState); + // Only show tooltip if there's a selection and no input or completion is active + update.view.dispatch({ + effects: showTooltip.of( + from !== to && !inputStateValue.show && !completionStateValue, + ), + }); + } + }), + EditorView.decorations.of((view) => { + const inputStateValue = view.state.field(inputState); + const completionStateValue = view.state.field(completionState); + const decorations: Array> = []; + + if (inputStateValue.show) { + for (let i = inputStateValue.from; i < inputStateValue.to; i++) { + decorations.push( + Decoration.line({ class: "cm-ai-selection" }).range(i), + ); + if (i === inputStateValue.from) { + decorations.push( + Decoration.widget({ + widget: new InputWidget(complete), + side: -1, + }).range(inputStateValue.from), + ); + } + } + } + + if (completionStateValue) { + decorations.push( + Decoration.widget({ + widget: new OldCodeWidget(view, completionStateValue.oldCode), + side: -1, + }).range(completionStateValue.from), + Decoration.mark({ + class: "cm-new-code-line", + }).range(completionStateValue.from, completionStateValue.to), + ); + } + + return Decoration.set(decorations); + }), + ]; +} + +// Singleton tooltip element +let tooltip: HTMLDivElement; + +// State effect to show/hide the tooltip +const showTooltip = StateEffect.define(); + +// State field to manage the tooltip visibility +const tooltipState = StateField.define({ + create() { + return false; + }, + update(value, tr) { + for (const e of tr.effects) { + if (e.is(showTooltip)) { + return e.value; + } + } + return value; + }, +}); + +// View plugin to handle selection changes +const selectionPlugin = ViewPlugin.fromClass( + class { + decorations: DecorationSet; + + constructor(view: EditorView) { + this.decorations = this.createDecorations(view); + } + + update(update: ViewUpdate) { + if ( + update.selectionSet || + update.docChanged || + update.viewportChanged || + update.transactions.some((tr) => + tr.effects.some((e) => e.is(showTooltip)), + ) + ) { + this.decorations = this.createDecorations(update.view); + } + } + + createDecorations(view: EditorView) { + const { from, to } = view.state.selection.main; + const inputStateValue = view.state.field(inputState); + const completionStateValue = view.state.field(completionState); + const tooltipStateValue = view.state.field(tooltipState); + const doc = view.state.doc; + + // Hide tooltip if there's no selection, input is open, completion is pending, or tooltipState is false + if ( + from === to || + inputStateValue.show || + completionStateValue || + !tooltipStateValue || + from < 0 || + to > doc.length + ) { + return Decoration.none; + } + + // Adjust selection to exclude empty lines at the start and end + let adjustedFrom = from; + let adjustedTo = to; + + while ( + adjustedFrom < adjustedTo && + doc.lineAt(adjustedFrom).length === 0 + ) { + adjustedFrom = doc.lineAt(adjustedFrom + 1).from; + } + while (adjustedTo > adjustedFrom && doc.lineAt(adjustedTo).length === 0) { + adjustedTo = doc.lineAt(adjustedTo - 1).to; + } + + // If the adjusted selection is empty, don't show the tooltip + if (adjustedFrom === adjustedTo) { + return Decoration.none; + } + + if (!tooltip) { + tooltip = document.createElement("div"); + tooltip.className = "cm-tooltip cm-ai-tooltip"; + tooltip.innerHTML = `Edit ${getSymbol("mod") ?? "Ctrl"} + L`; + tooltip.style.cursor = "pointer"; + tooltip.addEventListener("click", (evt) => { + evt.stopPropagation(); + showInputPrompt(view); + }); + } + + return Decoration.set([ + Decoration.widget({ + widget: new (class extends WidgetType { + toDOM() { + return tooltip; + } + override ignoreEvent() { + return true; + } + })(), + side: -1, + }).range(adjustedFrom), + ]); + } + }, + { + decorations: (v) => v.decorations, + }, +); + +const showInput = StateEffect.define<{ + show: boolean; + from: number; + to: number; +}>(); + +const inputState = StateField.define<{ + show: boolean; + from: number; + to: number; +}>({ + create() { + return { show: false, from: 0, to: 0 }; + }, + update(value, tr) { + for (const e of tr.effects) { + if (e.is(showInput)) { + return e.value; + } + } + return value; + }, +}); + +const showCompletion = StateEffect.define<{ + from: number; + to: number; + oldCode: string; + newCode: string; +} | null>(); + +const completionState = StateField.define<{ + from: number; + to: number; + oldCode: string; + newCode: string; +} | null>({ + create() { + return null; + }, + update(value, tr) { + for (const e of tr.effects) { + if (e.is(showCompletion)) { + return e.value; + } + } + return value; + }, +}); + +// Add a new state effect and state field for loading status +const setLoading = StateEffect.define(); + +const loadingState = StateField.define({ + create() { + return false; + }, + update(value, tr) { + for (const e of tr.effects) { + if (e.is(setLoading)) { + return e.value; + } + } + return value; + }, +}); + +function showInputPrompt(view: EditorView) { + const { state } = view; + const selection = state.selection.main; + if (selection.from !== selection.to) { + const doc = state.doc; + const fromLine = doc.lineAt(selection.from); + const toLine = doc.lineAt(selection.to); + const docLength = doc.length; + + // Ensure the selection is within the document bounds + const safeFrom = Math.max(0, Math.min(fromLine.from, docLength)); + const safeTo = Math.max(0, Math.min(toLine.to, docLength)); + + view.dispatch({ + effects: [ + showInput.of({ + show: true, + from: safeFrom, + to: safeTo, + }), + showTooltip.of(false), // Hide the tooltip + ], + selection: EditorSelection.cursor(safeFrom), + }); + return true; + } + return false; +} + +function acceptCompletion(view: EditorView) { + const completionStateValue = view.state.field(completionState); + if (completionStateValue) { + view.dispatch({ + effects: [ + showCompletion.of(null), + showInput.of({ show: false, from: 0, to: 0 }), + setLoading.of(false), + ], + }); + return true; + } + return false; +} + +function rejectCompletion(view: EditorView) { + const completionStateValue = view.state.field(completionState); + if (completionStateValue) { + view.dispatch({ + changes: { + from: completionStateValue.from, + to: completionStateValue.to, + insert: completionStateValue.oldCode, + }, + effects: [ + showCompletion.of(null), + showInput.of({ show: false, from: 0, to: 0 }), + setLoading.of(false), + ], + }); + return true; + } + return false; +} + +// Update the OldCodeWidget class +class OldCodeWidget extends WidgetType { + constructor( + private view: EditorView, + private oldCode: string, + ) { + super(); + } + toDOM() { + const container = document.createElement("div"); + container.className = "cm-old-code-container"; + + const oldCodeEl = document.createElement("div"); + oldCodeEl.className = "cm-old-code"; + oldCodeEl.textContent = this.oldCode; + container.append(oldCodeEl); + + const buttonsContainer = document.createElement("div"); + buttonsContainer.className = "cm-floating-buttons"; + + const modSymbol = getSymbol("mod") ?? "Ctrl"; + + const acceptButton = document.createElement("button"); + acceptButton.textContent = `${modSymbol} Y`; + acceptButton.className = "cm-floating-button cm-floating-accept"; + acceptButton.addEventListener("click", () => acceptCompletion(this.view)); + + const rejectButton = document.createElement("button"); + rejectButton.textContent = `${modSymbol} U`; + rejectButton.className = "cm-floating-button cm-floating-reject"; + rejectButton.addEventListener("click", () => rejectCompletion(this.view)); + + buttonsContainer.append(acceptButton); + buttonsContainer.append(rejectButton); + container.append(buttonsContainer); + + return container; + } +} + +// Input widget +class InputWidget extends WidgetType { + constructor(private complete: CompleteFunction) { + super(); + } + + toDOM(view: EditorView) { + const inputContainer = document.createElement("div"); + inputContainer.className = "cm-ai-input-container"; + + const input = document.createElement("input"); + input.className = "cm-ai-input"; + input.placeholder = "Editing instructions..."; + + const loadingIndicator = document.createElement("div"); + loadingIndicator.classList.add("cm-ai-loading-indicator"); + loadingIndicator.textContent = "Loading"; + const isLoading = view.state.field(loadingState); + + const helpInfo = document.createElement("div"); + helpInfo.className = "cm-ai-help-info"; + helpInfo.textContent = "Esc to close"; + + if (isLoading) { + helpInfo.classList.add("hidden"); + } else { + loadingIndicator.classList.add("hidden"); + } + + // Set up a timeout to focus the input after it's been added to the DOM + setTimeout(() => input.focus(), 0); + + input.addEventListener("keydown", async (e) => { + if (e.key === "Enter") { + const state = view.state.field(inputState); + const oldCode = view.state.sliceDoc(state.from, state.to); + const codeBefore = view.state.sliceDoc(0, state.from); + const codeAfter = view.state.sliceDoc(state.to); + + // Show loading indicator + view.dispatch({ effects: setLoading.of(true) }); + loadingIndicator.classList.remove("hidden"); + helpInfo.classList.add("hidden"); + input.disabled = true; + + try { + const result = await this.complete({ + prompt: input.value, + selection: oldCode, + codeBefore: codeBefore, + codeAfter: codeAfter, + editorView: view, + }); + + if (!view.state.field(inputState).show) { + return; + } + + view.dispatch({ + changes: { from: state.from, to: state.to, insert: result }, + effects: [ + showInput.of({ show: false, from: state.from, to: state.to }), + showCompletion.of({ + from: state.from, + to: state.from + result.length, + oldCode, + newCode: result, + }), + setLoading.of(false), + ], + selection: EditorSelection.cursor(state.to), + }); + } catch (error) { + Logger.error("Completion error:", error); + toast({ + title: "Error", + description: + "An error occurred while processing your request. Please try again.", + variant: "danger", + }); + } finally { + // Hide loading indicator + loadingIndicator.classList.add("hidden"); + helpInfo.classList.remove("hidden"); + input.disabled = false; + } + + // Refocus the editor after the prompt returns + view.focus(); + } else if (e.key === "Escape") { + // Close the input when Escape is pressed + view.dispatch({ + effects: [ + showInput.of({ show: false, from: 0, to: 0 }), + setLoading.of(false), + ], + }); + // Refocus the editor after closing the input + view.focus(); + } + }); + + inputContainer.append(input, loadingIndicator, helpInfo); + + return inputContainer; + } +} diff --git a/frontend/src/core/codemirror/prompt/request.ts b/frontend/src/core/codemirror/prompt/request.ts new file mode 100644 index 00000000000..da45d279d61 --- /dev/null +++ b/frontend/src/core/codemirror/prompt/request.ts @@ -0,0 +1,56 @@ +/* Copyright 2024 Marimo. All rights reserved. */ + +import { API } from "@/core/network/api"; +import { asURL } from "@/utils/url"; +import type { LanguageAdapterType } from "../language/types"; +import { getCodes } from "../copilot/getCodes"; + +/** + * Request to edit code with AI + */ +export async function requestEditCompletion(opts: { + prompt: string; + selection: string; + code: string; + codeBefore: string; + language: LanguageAdapterType; +}): Promise { + const currentCode = opts.code; + + const otherCodes = getCodes(currentCode); + // Other code to include is the codeBefore and the other codes + const includeOtherCode = `${opts.codeBefore}\n${otherCodes}`; + + const prompt = ` + I would like to edit the selected code. Specifically, I want to change the following: + ${opts.prompt} + `; + + const response = await fetch(asURL("api/ai/completion").toString(), { + method: "POST", + headers: API.headers(), + body: JSON.stringify({ + prompt: prompt, + code: opts.selection, + includeOtherCode: includeOtherCode, + language: opts.language, + }), + }); + + const reader = response.body?.getReader(); + if (!reader) { + throw new Error("Failed to get response reader"); + } + + let result = ""; + // eslint-disable-next-line no-constant-condition + while (true) { + const { done, value } = await reader.read(); + if (done) { + break; + } + result += new TextDecoder().decode(value); + } + + return result; +} diff --git a/frontend/src/core/codemirror/prompt/styles.css b/frontend/src/core/codemirror/prompt/styles.css new file mode 100644 index 00000000000..44094f5e78c --- /dev/null +++ b/frontend/src/core/codemirror/prompt/styles.css @@ -0,0 +1,182 @@ +.cm-tooltip.cm-ai-tooltip { + user-select: none; + pointer-events: none; + font-family: sans-serif; + position: absolute; + cursor: pointer; + right: 0; + padding: 5px; + border-radius: 3px; + font-size: 12px; + pointer-events: none; + background-color: var(--slate-1); + + @apply bg-background border shadow-sm; + + .hotkey { + @apply text-muted-foreground; + } +} + +.cm-ai-input-container { + display: flex; + flex-direction: column; + gap: 4px; + width: calc(100% + 7px); + padding: 5px 5px; + margin: 0 -6px; + background-color: var(--slate-2); +} + +.cm-ai-input { + display: block; + width: 100%; + padding: 5px 10px; + border: 2px solid var(--sky-7); + border-radius: 5px; + font-size: 12px; + + @apply shadow-sm; +} + +.cm-ai-help-info { + font-size: 10px; + padding-left: 12px; + color: var(--slate-11); +} + +.cm-line.cm-ai-selection { + background-color: var(--slate-3); +} + +.cm-line:has(.cm-new-code-line) { + background-color: var(--grass-3); +} + +.cm-old-code-container { + background-color: #ffebee; + padding: 5px 0; + position: relative; + z-index: 1; + + &::before { + content: ''; + position: absolute; + top: 0; + left: -8px; + width: 8px; + height: 100%; + background-color: var(--red-3); + } + + &::after { + content: ''; + position: absolute; + bottom: 0; + right: -8px; + width: 8px; + height: 100%; + background-color: var(--red-3); + } +} + +.cm-new-code-line { + background-color: #e8f5e9; +} + +.cm-code-button { + position: absolute; + right: 5px; + top: 50%; + transform: translateY(-50%); + background: none; + border: none; + cursor: pointer; + font-size: 16px; +} + +.cm-old-code-container { + background-color: var(--red-3); + position: relative; + display: flex; + width: 100%; + align-items: center; +} + +.cm-floating-buttons { + font-family: sans-serif; + position: absolute; + bottom: 0; + right: 0; + display: flex; +} + +.cm-floating-button { + font-family: sans-serif; + padding: 2px 5px; + font-size: 10px; + cursor: pointer; + font-weight: 700; +} + +.cm-floating-accept { + background-color: var(--grass-9); + border-top-left-radius: 5px; + border-bottom-left-radius: 5px; + opacity: 0.8; + color: white; + + &:hover { + opacity: 1; + } +} + +.cm-floating-reject { + background-color: var(--red-9); + color: white; + border-top-right-radius: 5px; + opacity: 0.8; + border-bottom-right-radius: 5px; + + &:hover { + opacity: 1; + } +} + +.cm-ai-loading-indicator { + font-style: italic; + font-size: 10px; + padding-left: 12px; + color: var(--slate-11); + opacity: 0; + transition: opacity 0.3s ease-in-out; +} + +.cm-ai-loading-indicator::after { + content: ''; + display: inline-block; + animation: ellipsis-pulse 1.5s steps(4, end) infinite; +} + +@keyframes ellipsis-pulse { + 0% { + content: '.'; + } + 25% { + content: '..'; + } + 50% { + content: '...'; + } + 75% { + content: ''; + } +} + +.cm-ai-loading-indicator:not(:empty) { + opacity: 1; +} + +.cm-ai-input:disabled { + opacity: 0.5; +} diff --git a/frontend/src/stories/codemirror-prompt.stories.tsx b/frontend/src/stories/codemirror-prompt.stories.tsx new file mode 100644 index 00000000000..3550d7cbd90 --- /dev/null +++ b/frontend/src/stories/codemirror-prompt.stories.tsx @@ -0,0 +1,38 @@ +/* Copyright 2024 Marimo. All rights reserved. */ +import ReactCodemirror, { basicSetup } from "@uiw/react-codemirror"; +import { promptPlugin } from "../core/codemirror/prompt/prompt"; +import { python } from "@codemirror/lang-python"; +import type { Meta, StoryObj } from "@storybook/react"; +import React from "react"; + +const CodeMirrorPrompt: React.FC = () => { + return ( + int: + return a + b + +mo.ui.button(label="Click me") + `} + extensions={[ + basicSetup(), + python(), + promptPlugin(async (selection, code) => { + return "def sub(a: int, b: int) -> int:\n return a - b"; + }), + ]} + /> + ); +}; + +const meta: Meta = { + title: "CodeMirror/Prompt", + component: CodeMirrorPrompt, +}; + +export default meta; +type Story = StoryObj; + +export const Default: Story = {};