Skip to content

Commit

Permalink
enhance: Add proper aborting of runs
Browse files Browse the repository at this point in the history
Aborting a run is different from "closing" it. Closing a run will result
in an error. Aborting a run will cause it to stop at the the next
available event and not return any error. Instead, the run will have its
text appended with "ABORTED BY USER" and all the chat state will be
preserved.

Signed-off-by: Donnie Adams <[email protected]>
  • Loading branch information
thedadams committed Feb 24, 2025
1 parent eee4337 commit a02ccda
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 2 deletions.
5 changes: 5 additions & 0 deletions gptscript.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,11 @@ func (g *GPTScript) Run(ctx context.Context, toolPath string, opts Options) (*Ru
}).NextChat(ctx, opts.Input)
}

func (g *GPTScript) AbortRun(ctx context.Context, run *Run) error {
_, err := g.runBasicCommand(ctx, "abort/"+run.id, (map[string]any)(nil))
return err
}

type ParseOptions struct {
DisableCache bool
}
Expand Down
139 changes: 137 additions & 2 deletions gptscript_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"strconv"
"strings"
"testing"
"time"

"github.com/getkin/kin-openapi/openapi3"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -134,7 +135,7 @@ func TestListModelsWithDefaultProvider(t *testing.T) {
}
}

func TestAbortRun(t *testing.T) {
func TestCancelRun(t *testing.T) {
tool := ToolDef{Instructions: "What is the capital of the united states?"}

run, err := g.Evaluate(context.Background(), Options{DisableCache: true, IncludeEvents: true}, tool)
Expand All @@ -146,7 +147,7 @@ func TestAbortRun(t *testing.T) {
<-run.Events()

if err := run.Close(); err != nil {
t.Errorf("Error aborting run: %v", err)
t.Errorf("Error canceling run: %v", err)
}

if run.State() != Error {
Expand All @@ -158,6 +159,77 @@ func TestAbortRun(t *testing.T) {
}
}

func TestAbortChatCompletionRun(t *testing.T) {
tool := ToolDef{Instructions: "What is the capital of the united states?"}

run, err := g.Evaluate(context.Background(), Options{DisableCache: true, IncludeEvents: true}, tool)
if err != nil {
t.Errorf("Error executing tool: %v", err)
}

// Abort the run after the first event from the LLM
for e := range run.Events() {
if e.Call != nil && e.Call.Type == EventTypeCallProgress && len(e.Call.Output) > 0 && e.Call.Output[0].Content != "Waiting for model response..." {
break
}
}

if err := g.AbortRun(context.Background(), run); err != nil {
t.Errorf("Error aborting run: %v", err)
}

// Wait for run to stop
for range run.Events() {
continue
}

if run.State() != Finished {
t.Errorf("Unexpected run state: %s", run.State())
}

if out, err := run.Text(); err != nil {
t.Errorf("Error reading output: %v", err)
} else if strings.TrimSpace(out) != "ABORTED BY USER" && !strings.HasSuffix(out, "\nABORTED BY USER") {
t.Errorf("Unexpected output: %s", out)
}
}

func TestAbortCommandRun(t *testing.T) {
tool := ToolDef{Instructions: "#!/usr/bin/env bash\necho Hello, world!\nsleep 5\necho Hello, again!\nsleep 5"}

run, err := g.Evaluate(context.Background(), Options{DisableCache: true, IncludeEvents: true}, tool)
if err != nil {
t.Errorf("Error executing tool: %v", err)
}

// Abort the run after the first event.
for e := range run.Events() {
if e.Call != nil && e.Call.Type == EventTypeChat {
time.Sleep(2 * time.Second)
break
}
}

if err := g.AbortRun(context.Background(), run); err != nil {
t.Errorf("Error aborting run: %v", err)
}

// Wait for run to stop
for range run.Events() {
continue
}

if run.State() != Finished {
t.Errorf("Unexpected run state: %s", run.State())
}

if out, err := run.Text(); err != nil {
t.Errorf("Error reading output: %v", err)
} else if !strings.Contains(out, "Hello, world!") || strings.Contains(out, "Hello, again!") || !strings.HasSuffix(out, "\nABORTED BY USER") {
t.Errorf("Unexpected output: %s", out)
}
}

func TestSimpleEvaluate(t *testing.T) {
tool := ToolDef{Instructions: "What is the capital of the united states?"}

Expand Down Expand Up @@ -844,6 +916,69 @@ func TestToolChat(t *testing.T) {
}
}

func TestAbortChat(t *testing.T) {
tool := ToolDef{
Chat: true,
Instructions: "You are a chat bot. Don't finish the conversation until I say 'bye'.",
Tools: []string{"sys.chat.finish"},
}

run, err := g.Evaluate(context.Background(), Options{DisableCache: true, IncludeEvents: true}, tool)
if err != nil {
t.Fatalf("Error executing tool: %v", err)
}
inputs := []string{
"Tell me a joke.",
"What was my first message?",
}

// Just wait for the chat to start up.
for range run.Events() {
continue
}

for i, input := range inputs {
run, err = run.NextChat(context.Background(), input)
if err != nil {
t.Fatalf("Error sending next input %q: %v", input, err)
}

// Abort the run after the first event from the LLM
for e := range run.Events() {
if e.Call != nil && e.Call.Type == EventTypeCallProgress && len(e.Call.Output) > 0 && e.Call.Output[0].Content != "Waiting for model response..." {
break
}
}

if i == 0 {
if err := g.AbortRun(context.Background(), run); err != nil {
t.Fatalf("Error aborting run: %v", err)
}
}

// Wait for the run to complete
for range run.Events() {
continue
}

out, err := run.Text()
if err != nil {
t.Errorf("Error reading output: %s", run.ErrorOutput())
t.Fatalf("Error reading output: %v", err)
}

if i == 0 {
if strings.TrimSpace(out) != "ABORTED BY USER" && !strings.HasSuffix(out, "\nABORTED BY USER") {
t.Fatalf("Unexpected output: %s", out)
}
} else {
if !strings.Contains(out, "Tell me a joke") {
t.Errorf("Unexpected output: %s", out)
}
}
}
}

func TestFileChat(t *testing.T) {
wd, err := os.Getwd()
if err != nil {
Expand Down
2 changes: 2 additions & 0 deletions run.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type Run struct {
basicCommand bool

program *Program
id string
callsLock sync.RWMutex
calls CallFrames
rawOutput map[string]any
Expand Down Expand Up @@ -400,6 +401,7 @@ func (r *Run) request(ctx context.Context, payload any) (err error) {
if event.Run.Type == EventTypeRunStart {
r.callsLock.Lock()
r.program = &event.Run.Program
r.id = event.Run.ID
r.callsLock.Unlock()
} else if event.Run.Type == EventTypeRunFinish && event.Run.Error != "" {
r.state = Error
Expand Down

0 comments on commit a02ccda

Please sign in to comment.