From a02ccda04803d4889dd1eafe7bd401a69d557c64 Mon Sep 17 00:00:00 2001 From: Donnie Adams Date: Mon, 24 Feb 2025 11:32:19 -0500 Subject: [PATCH] enhance: Add proper aborting of runs Aborting a run is different from "closing" it. Closing a run will result in an error. Aborting a run will cause it to stop at the the next available event and not return any error. Instead, the run will have its text appended with "ABORTED BY USER" and all the chat state will be preserved. Signed-off-by: Donnie Adams --- gptscript.go | 5 ++ gptscript_test.go | 139 +++++++++++++++++++++++++++++++++++++++++++++- run.go | 2 + 3 files changed, 144 insertions(+), 2 deletions(-) diff --git a/gptscript.go b/gptscript.go index 28a46a2..8ca9747 100644 --- a/gptscript.go +++ b/gptscript.go @@ -170,6 +170,11 @@ func (g *GPTScript) Run(ctx context.Context, toolPath string, opts Options) (*Ru }).NextChat(ctx, opts.Input) } +func (g *GPTScript) AbortRun(ctx context.Context, run *Run) error { + _, err := g.runBasicCommand(ctx, "abort/"+run.id, (map[string]any)(nil)) + return err +} + type ParseOptions struct { DisableCache bool } diff --git a/gptscript_test.go b/gptscript_test.go index 8f299c8..476492d 100644 --- a/gptscript_test.go +++ b/gptscript_test.go @@ -11,6 +11,7 @@ import ( "strconv" "strings" "testing" + "time" "github.com/getkin/kin-openapi/openapi3" "github.com/stretchr/testify/require" @@ -134,7 +135,7 @@ func TestListModelsWithDefaultProvider(t *testing.T) { } } -func TestAbortRun(t *testing.T) { +func TestCancelRun(t *testing.T) { tool := ToolDef{Instructions: "What is the capital of the united states?"} run, err := g.Evaluate(context.Background(), Options{DisableCache: true, IncludeEvents: true}, tool) @@ -146,7 +147,7 @@ func TestAbortRun(t *testing.T) { <-run.Events() if err := run.Close(); err != nil { - t.Errorf("Error aborting run: %v", err) + t.Errorf("Error canceling run: %v", err) } if run.State() != Error { @@ -158,6 +159,77 @@ func TestAbortRun(t *testing.T) { } } +func TestAbortChatCompletionRun(t *testing.T) { + tool := ToolDef{Instructions: "What is the capital of the united states?"} + + run, err := g.Evaluate(context.Background(), Options{DisableCache: true, IncludeEvents: true}, tool) + if err != nil { + t.Errorf("Error executing tool: %v", err) + } + + // Abort the run after the first event from the LLM + for e := range run.Events() { + if e.Call != nil && e.Call.Type == EventTypeCallProgress && len(e.Call.Output) > 0 && e.Call.Output[0].Content != "Waiting for model response..." { + break + } + } + + if err := g.AbortRun(context.Background(), run); err != nil { + t.Errorf("Error aborting run: %v", err) + } + + // Wait for run to stop + for range run.Events() { + continue + } + + if run.State() != Finished { + t.Errorf("Unexpected run state: %s", run.State()) + } + + if out, err := run.Text(); err != nil { + t.Errorf("Error reading output: %v", err) + } else if strings.TrimSpace(out) != "ABORTED BY USER" && !strings.HasSuffix(out, "\nABORTED BY USER") { + t.Errorf("Unexpected output: %s", out) + } +} + +func TestAbortCommandRun(t *testing.T) { + tool := ToolDef{Instructions: "#!/usr/bin/env bash\necho Hello, world!\nsleep 5\necho Hello, again!\nsleep 5"} + + run, err := g.Evaluate(context.Background(), Options{DisableCache: true, IncludeEvents: true}, tool) + if err != nil { + t.Errorf("Error executing tool: %v", err) + } + + // Abort the run after the first event. + for e := range run.Events() { + if e.Call != nil && e.Call.Type == EventTypeChat { + time.Sleep(2 * time.Second) + break + } + } + + if err := g.AbortRun(context.Background(), run); err != nil { + t.Errorf("Error aborting run: %v", err) + } + + // Wait for run to stop + for range run.Events() { + continue + } + + if run.State() != Finished { + t.Errorf("Unexpected run state: %s", run.State()) + } + + if out, err := run.Text(); err != nil { + t.Errorf("Error reading output: %v", err) + } else if !strings.Contains(out, "Hello, world!") || strings.Contains(out, "Hello, again!") || !strings.HasSuffix(out, "\nABORTED BY USER") { + t.Errorf("Unexpected output: %s", out) + } +} + func TestSimpleEvaluate(t *testing.T) { tool := ToolDef{Instructions: "What is the capital of the united states?"} @@ -844,6 +916,69 @@ func TestToolChat(t *testing.T) { } } +func TestAbortChat(t *testing.T) { + tool := ToolDef{ + Chat: true, + Instructions: "You are a chat bot. Don't finish the conversation until I say 'bye'.", + Tools: []string{"sys.chat.finish"}, + } + + run, err := g.Evaluate(context.Background(), Options{DisableCache: true, IncludeEvents: true}, tool) + if err != nil { + t.Fatalf("Error executing tool: %v", err) + } + inputs := []string{ + "Tell me a joke.", + "What was my first message?", + } + + // Just wait for the chat to start up. + for range run.Events() { + continue + } + + for i, input := range inputs { + run, err = run.NextChat(context.Background(), input) + if err != nil { + t.Fatalf("Error sending next input %q: %v", input, err) + } + + // Abort the run after the first event from the LLM + for e := range run.Events() { + if e.Call != nil && e.Call.Type == EventTypeCallProgress && len(e.Call.Output) > 0 && e.Call.Output[0].Content != "Waiting for model response..." { + break + } + } + + if i == 0 { + if err := g.AbortRun(context.Background(), run); err != nil { + t.Fatalf("Error aborting run: %v", err) + } + } + + // Wait for the run to complete + for range run.Events() { + continue + } + + out, err := run.Text() + if err != nil { + t.Errorf("Error reading output: %s", run.ErrorOutput()) + t.Fatalf("Error reading output: %v", err) + } + + if i == 0 { + if strings.TrimSpace(out) != "ABORTED BY USER" && !strings.HasSuffix(out, "\nABORTED BY USER") { + t.Fatalf("Unexpected output: %s", out) + } + } else { + if !strings.Contains(out, "Tell me a joke") { + t.Errorf("Unexpected output: %s", out) + } + } + } +} + func TestFileChat(t *testing.T) { wd, err := os.Getwd() if err != nil { diff --git a/run.go b/run.go index 52ad6e4..558b388 100644 --- a/run.go +++ b/run.go @@ -37,6 +37,7 @@ type Run struct { basicCommand bool program *Program + id string callsLock sync.RWMutex calls CallFrames rawOutput map[string]any @@ -400,6 +401,7 @@ func (r *Run) request(ctx context.Context, payload any) (err error) { if event.Run.Type == EventTypeRunStart { r.callsLock.Lock() r.program = &event.Run.Program + r.id = event.Run.ID r.callsLock.Unlock() } else if event.Run.Type == EventTypeRunFinish && event.Run.Error != "" { r.state = Error