diff --git a/workspace.go b/workspace.go index d04a2e6..8bd5a6f 100644 --- a/workspace.go +++ b/workspace.go @@ -6,10 +6,13 @@ import ( "encoding/json" "fmt" "os" + "regexp" "strings" "time" ) +var conflictErrParser = regexp.MustCompile(`^.+500 Internal Server Error: conflict: (.+)/([^/]+) \(latest revision: (-?\d+), current revision: (-?\d+)\)$`) + type NotFoundInWorkspaceError struct { id string name string @@ -23,6 +26,29 @@ func newNotFoundInWorkspaceError(id, name string) *NotFoundInWorkspaceError { return &NotFoundInWorkspaceError{id: id, name: name} } +type ConflictInWorkspaceError struct { + ID string + Name string + LatestRevision string + CurrentRevision string +} + +func parsePossibleConflictInWorkspaceError(err error) error { + if err == nil { + return err + } + + matches := conflictErrParser.FindStringSubmatch(err.Error()) + if len(matches) != 5 { + return err + } + return &ConflictInWorkspaceError{ID: matches[1], Name: matches[2], LatestRevision: matches[3], CurrentRevision: matches[4]} +} + +func (e *ConflictInWorkspaceError) Error() string { + return fmt.Sprintf("conflict: %s/%s (latest revision: %s, current revision: %s)", e.ID, e.Name, e.LatestRevision, e.CurrentRevision) +} + func (g *GPTScript) CreateWorkspace(ctx context.Context, providerType string, fromWorkspaces ...string) (string, error) { out, err := g.runBasicCommand(ctx, "workspaces/create", map[string]any{ "providerType": providerType, @@ -123,6 +149,7 @@ func (g *GPTScript) RemoveAll(ctx context.Context, opts ...RemoveAllOptions) err type WriteFileInWorkspaceOptions struct { WorkspaceID string CreateRevision *bool + LatestRevision string } func (g *GPTScript) WriteFileInWorkspace(ctx context.Context, filePath string, contents []byte, opts ...WriteFileInWorkspaceOptions) error { @@ -134,6 +161,9 @@ func (g *GPTScript) WriteFileInWorkspace(ctx context.Context, filePath string, c if o.CreateRevision != nil { opt.CreateRevision = o.CreateRevision } + if o.LatestRevision != "" { + opt.LatestRevision = o.LatestRevision + } } if opt.WorkspaceID == "" { @@ -145,11 +175,12 @@ func (g *GPTScript) WriteFileInWorkspace(ctx context.Context, filePath string, c "contents": base64.StdEncoding.EncodeToString(contents), "filePath": filePath, "createRevision": opt.CreateRevision, + "latestRevision": opt.LatestRevision, "workspaceTool": g.globalOpts.WorkspaceTool, "env": g.globalOpts.Env, }) - return err + return parsePossibleConflictInWorkspaceError(err) } type DeleteFileInWorkspaceOptions struct { diff --git a/workspace_test.go b/workspace_test.go index 601d614..669f2b0 100644 --- a/workspace_test.go +++ b/workspace_test.go @@ -307,6 +307,92 @@ func TestDisableCreateRevisionsForFileInWorkspace(t *testing.T) { } } +func TestConflictsForFileInWorkspace(t *testing.T) { + id, err := g.CreateWorkspace(context.Background(), "directory") + if err != nil { + t.Fatalf("Error creating workspace: %v", err) + } + + t.Cleanup(func() { + err := g.DeleteWorkspace(context.Background(), id) + if err != nil { + t.Errorf("Error deleting workspace: %v", err) + } + }) + + ce := (*ConflictInWorkspaceError)(nil) + // Writing a new file with a non-zero latest revision should fail + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test0"), WriteFileInWorkspaceOptions{WorkspaceID: id, LatestRevision: "1"}) + if err == nil || !errors.As(err, &ce) { + t.Errorf("Expected error writing file with non-zero latest revision: %v", err) + } + + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test0"), WriteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test1"), WriteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + revisions, err := g.ListRevisionsForFileInWorkspace(context.Background(), "test.txt", ListRevisionsForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error reading file: %v", err) + } + + if len(revisions) != 1 { + t.Errorf("Unexpected number of revisions: %d", len(revisions)) + } + + // Writing to the file with the latest revision should succeed + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test2"), WriteFileInWorkspaceOptions{WorkspaceID: id, LatestRevision: revisions[0].RevisionID}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + revisions, err = g.ListRevisionsForFileInWorkspace(context.Background(), "test.txt", ListRevisionsForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error reading file: %v", err) + } + + if len(revisions) != 2 { + t.Errorf("Unexpected number of revisions: %d", len(revisions)) + } + + // Writing to the file with the same revision should fail + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test3"), WriteFileInWorkspaceOptions{WorkspaceID: id, LatestRevision: revisions[0].RevisionID}) + if err == nil || !errors.As(err, &ce) { + t.Errorf("Expected error writing file with same revision: %v", err) + } + + err = g.DeleteRevisionForFileInWorkspace(context.Background(), "test.txt", revisions[1].RevisionID, DeleteRevisionForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error deleting revision for file: %v", err) + } + + revisions, err = g.ListRevisionsForFileInWorkspace(context.Background(), "test.txt", ListRevisionsForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error reading file: %v", err) + } + + if len(revisions) != 1 { + t.Errorf("Unexpected number of revisions: %d", len(revisions)) + } + + // Ensure we can write a new file after deleting the latest revision + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test4"), WriteFileInWorkspaceOptions{WorkspaceID: id, LatestRevision: revisions[0].RevisionID}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + err = g.DeleteFileInWorkspace(context.Background(), "test.txt", DeleteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error deleting file: %v", err) + } +} + func TestLsComplexWorkspace(t *testing.T) { id, err := g.CreateWorkspace(context.Background(), "directory") if err != nil { @@ -690,6 +776,96 @@ func TestRevisionsForFileInWorkspaceS3(t *testing.T) { } } +func TestConflictsForFileInWorkspaceS3(t *testing.T) { + if os.Getenv("AWS_ACCESS_KEY_ID") == "" || os.Getenv("AWS_SECRET_ACCESS_KEY") == "" || os.Getenv("WORKSPACE_PROVIDER_S3_BUCKET") == "" { + t.Skip("Skipping test because AWS credentials are not set") + } + + id, err := g.CreateWorkspace(context.Background(), "s3") + if err != nil { + t.Fatalf("Error creating workspace: %v", err) + } + + t.Cleanup(func() { + err := g.DeleteWorkspace(context.Background(), id) + if err != nil { + t.Errorf("Error deleting workspace: %v", err) + } + }) + + ce := (*ConflictInWorkspaceError)(nil) + // Writing a new file with a non-zero latest revision should fail + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test0"), WriteFileInWorkspaceOptions{WorkspaceID: id, LatestRevision: "1"}) + if err == nil || !errors.As(err, &ce) { + t.Errorf("Expected error writing file with non-zero latest revision: %v", err) + } + + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test0"), WriteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test1"), WriteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + revisions, err := g.ListRevisionsForFileInWorkspace(context.Background(), "test.txt", ListRevisionsForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error reading file: %v", err) + } + + if len(revisions) != 1 { + t.Errorf("Unexpected number of revisions: %d", len(revisions)) + } + + // Writing to the file with the latest revision should succeed + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test2"), WriteFileInWorkspaceOptions{WorkspaceID: id, LatestRevision: revisions[0].RevisionID}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + revisions, err = g.ListRevisionsForFileInWorkspace(context.Background(), "test.txt", ListRevisionsForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error reading file: %v", err) + } + + if len(revisions) != 2 { + t.Errorf("Unexpected number of revisions: %d", len(revisions)) + } + + // Writing to the file with the same revision should fail + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test3"), WriteFileInWorkspaceOptions{WorkspaceID: id, LatestRevision: revisions[0].RevisionID}) + if err == nil || !errors.As(err, &ce) { + t.Errorf("Expected error writing file with same revision: %v", err) + } + + err = g.DeleteRevisionForFileInWorkspace(context.Background(), "test.txt", revisions[1].RevisionID, DeleteRevisionForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error deleting revision for file: %v", err) + } + + revisions, err = g.ListRevisionsForFileInWorkspace(context.Background(), "test.txt", ListRevisionsForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error reading file: %v", err) + } + + if len(revisions) != 1 { + t.Errorf("Unexpected number of revisions: %d", len(revisions)) + } + + // Ensure we can write a new file after deleting the latest revision + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test4"), WriteFileInWorkspaceOptions{WorkspaceID: id, LatestRevision: revisions[0].RevisionID}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + err = g.DeleteFileInWorkspace(context.Background(), "test.txt", DeleteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error deleting file: %v", err) + } +} + func TestDisableCreatingRevisionsForFileInWorkspaceS3(t *testing.T) { if os.Getenv("AWS_ACCESS_KEY_ID") == "" || os.Getenv("AWS_SECRET_ACCESS_KEY") == "" || os.Getenv("WORKSPACE_PROVIDER_S3_BUCKET") == "" { t.Skip("Skipping test because AWS credentials are not set")