diff --git a/account.go b/account.go index e1b7499..7017d03 100644 --- a/account.go +++ b/account.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" ) type Account struct { @@ -34,7 +35,7 @@ func (a *Account) UnmarshalJSON(data []byte) error { // GetCurrentAccount returns the authenticated user or organization. func (r *Client) GetCurrentAccount(ctx context.Context) (*Account, error) { response := &Account{} - err := r.fetch(ctx, "GET", "/account", nil, response) + err := r.fetch(ctx, http.MethodGet, "/account", nil, response) if err != nil { return nil, fmt.Errorf("failed to list collections: %w", err) } diff --git a/client_test.go b/client_test.go index b72f12b..d0fe446 100644 --- a/client_test.go +++ b/client_test.go @@ -124,7 +124,7 @@ func TestListCollections(t *testing.T) { func TestGetCollection(t *testing.T) { mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "GET", r.Method) + assert.Equal(t, http.MethodGet, r.Method) assert.Equal(t, "/collections/super-resolution", r.URL.Path) collection := &replicate.Collection{ @@ -248,7 +248,7 @@ func TestGetModel(t *testing.T) { func TestCreateModel(t *testing.T) { mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "POST", r.Method) + assert.Equal(t, http.MethodPost, r.Method) assert.Equal(t, "/models", r.URL.Path) body, err := io.ReadAll(r.Body) @@ -413,7 +413,7 @@ func TestGetModelVersion(t *testing.T) { func TestCreatePrediction(t *testing.T) { mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "POST", r.Method) + assert.Equal(t, http.MethodPost, r.Method) assert.Equal(t, "/predictions", r.URL.Path) body, err := io.ReadAll(r.Body) @@ -492,7 +492,7 @@ func TestCreatePrediction(t *testing.T) { func TestCreatePredictionWithDeployment(t *testing.T) { mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "POST", r.Method) + assert.Equal(t, http.MethodPost, r.Method) assert.Equal(t, "/deployments/owner/name/predictions", r.URL.Path) body, err := io.ReadAll(r.Body) @@ -573,7 +573,7 @@ func TestCreatePredictionWithDeployment(t *testing.T) { func TestCreatePredictionWithModel(t *testing.T) { mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "POST", r.Method) + assert.Equal(t, http.MethodPost, r.Method) assert.Equal(t, "/models/owner/model/predictions", r.URL.Path) body, err := io.ReadAll(r.Body) @@ -641,7 +641,7 @@ func TestCreatePredictionWithModel(t *testing.T) { func TestCancelPrediction(t *testing.T) { mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "POST", r.Method) + assert.Equal(t, http.MethodPost, r.Method) assert.Equal(t, "/predictions/ufawqhfynnddngldkgtslldrkq/cancel", r.URL.Path) response := replicate.Prediction{ @@ -826,7 +826,7 @@ func TestListPredictions(t *testing.T) { func TestGetPrediction(t *testing.T) { mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "GET", r.Method) + assert.Equal(t, http.MethodGet, r.Method) assert.Equal(t, "/predictions/ufawqhfynnddngldkgtslldrkq", r.URL.Path) prediction := &replicate.Prediction{ @@ -875,7 +875,7 @@ func TestWait(t *testing.T) { i := 0 mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "GET", r.Method) + assert.Equal(t, http.MethodGet, r.Method) assert.Equal(t, "/predictions/ufawqhfynnddngldkgtslldrkq", r.URL.Path) prediction := &replicate.Prediction{ @@ -957,7 +957,7 @@ func TestWaitAsync(t *testing.T) { i := 0 mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "GET", r.Method) + assert.Equal(t, http.MethodGet, r.Method) assert.Equal(t, "/predictions/ufawqhfynnddngldkgtslldrkq", r.URL.Path) prediction := &replicate.Prediction{ @@ -1024,7 +1024,7 @@ func TestWaitAsync(t *testing.T) { func TestCreateTraining(t *testing.T) { mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "POST", r.Method) + assert.Equal(t, http.MethodPost, r.Method) assert.Equal(t, "/models/owner/model/versions/632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532/trainings", r.URL.Path) training := &replicate.Training{ @@ -1077,7 +1077,7 @@ func TestCreateTraining(t *testing.T) { func TestGetTraining(t *testing.T) { mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "GET", r.Method) + assert.Equal(t, http.MethodGet, r.Method) assert.Equal(t, "/trainings/zz4ibbonubfz7carwiefibzgga", r.URL.Path) training := &replicate.Training{ @@ -1114,7 +1114,7 @@ func TestGetTraining(t *testing.T) { func TestCancelTraining(t *testing.T) { mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "POST", r.Method) + assert.Equal(t, http.MethodPost, r.Method) assert.Equal(t, "/trainings/zz4ibbonubfz7carwiefibzgga/cancel", r.URL.Path) training := &replicate.Training{ @@ -1151,7 +1151,7 @@ func TestCancelTraining(t *testing.T) { func TestListTrainings(t *testing.T) { mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "GET", r.Method) + assert.Equal(t, http.MethodGet, r.Method) assert.Equal(t, "/trainings", r.URL.Path) response := &replicate.Page[replicate.Training]{ @@ -1338,7 +1338,7 @@ func TestStream(t *testing.T) { mockServer := httptest.NewUnstartedServer(nil) mockServer.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { - case r.Method == "POST" && r.URL.Path == "/predictions": + case r.Method == http.MethodPost && r.URL.Path == "/predictions": body, err := io.ReadAll(r.Body) if err != nil { t.Fatal(err) @@ -1373,7 +1373,7 @@ func TestStream(t *testing.T) { w.WriteHeader(http.StatusCreated) w.Write(responseBytes) - case r.Method == "GET" && r.URL.Path == "/predictions/ufawqhfynnddngldkgtslldrkq/stream": + case r.Method == http.MethodGet && r.URL.Path == "/predictions/ufawqhfynnddngldkgtslldrkq/stream": flusher, _ := w.(http.Flusher) w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") diff --git a/collection.go b/collection.go index de14e1d..33abc5e 100644 --- a/collection.go +++ b/collection.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" ) type Collection struct { @@ -33,7 +34,7 @@ func (c *Collection) UnmarshalJSON(data []byte) error { // ListCollections returns a list of all collections. func (r *Client) ListCollections(ctx context.Context) (*Page[Collection], error) { response := &Page[Collection]{} - err := r.fetch(ctx, "GET", "/collections", nil, response) + err := r.fetch(ctx, http.MethodGet, "/collections", nil, response) if err != nil { return nil, fmt.Errorf("failed to list collections: %w", err) } @@ -43,7 +44,7 @@ func (r *Client) ListCollections(ctx context.Context) (*Page[Collection], error) // GetCollection returns a collection by slug. func (r *Client) GetCollection(ctx context.Context, slug string) (*Collection, error) { collection := &Collection{} - err := r.fetch(ctx, "GET", fmt.Sprintf("/collections/%s", slug), nil, collection) + err := r.fetch(ctx, http.MethodGet, fmt.Sprintf("/collections/%s", slug), nil, collection) if err != nil { return nil, fmt.Errorf("failed to get collection: %w", err) } diff --git a/deployment.go b/deployment.go index 3b9d304..c2b17a4 100644 --- a/deployment.go +++ b/deployment.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" ) type Deployment struct { @@ -48,7 +49,7 @@ func (d *Deployment) UnmarshalJSON(data []byte) error { func (r *Client) GetDeployment(ctx context.Context, deploymentOwner string, deploymentName string) (*Deployment, error) { deployment := &Deployment{} path := fmt.Sprintf("/deployments/%s/%s", deploymentOwner, deploymentName) - err := r.fetch(ctx, "GET", path, nil, deployment) + err := r.fetch(ctx, http.MethodGet, path, nil, deployment) if err != nil { return nil, fmt.Errorf("failed to get deployment: %w", err) } @@ -75,7 +76,7 @@ func (r *Client) CreatePredictionWithDeployment(ctx context.Context, deploymentO prediction := &Prediction{} path := fmt.Sprintf("/deployments/%s/%s/predictions", deploymentOwner, deploymentName) - err := r.fetch(ctx, "POST", path, data, prediction) + err := r.fetch(ctx, http.MethodPost, path, data, prediction) if err != nil { return nil, fmt.Errorf("failed to create prediction: %w", err) } diff --git a/files.go b/files.go index 4bdc3ff..b2921a2 100644 --- a/files.go +++ b/files.go @@ -127,7 +127,7 @@ func (r *Client) createFile(ctx context.Context, reader io.Reader, options Creat return nil, fmt.Errorf("failed to close writer: %w", err) } - req, err := r.newRequest(ctx, "POST", "/files", body) + req, err := r.newRequest(ctx, http.MethodPost, "/files", body) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } @@ -145,7 +145,7 @@ func (r *Client) createFile(ctx context.Context, reader io.Reader, options Creat // ListFiles lists your files. func (r *Client) ListFiles(ctx context.Context) (*Page[File], error) { response := &Page[File]{} - err := r.fetch(ctx, "GET", "/files", nil, response) + err := r.fetch(ctx, http.MethodGet, "/files", nil, response) if err != nil { return nil, fmt.Errorf("failed to list files: %w", err) } @@ -156,7 +156,7 @@ func (r *Client) ListFiles(ctx context.Context) (*Page[File], error) { // GetFile retrieves information about a file. func (r *Client) GetFile(ctx context.Context, fileID string) (*File, error) { file := &File{} - err := r.fetch(ctx, "GET", fmt.Sprintf("/files/%s", fileID), nil, file) + err := r.fetch(ctx, http.MethodGet, fmt.Sprintf("/files/%s", fileID), nil, file) if err != nil { return nil, fmt.Errorf("failed to get file: %w", err) } @@ -166,7 +166,7 @@ func (r *Client) GetFile(ctx context.Context, fileID string) (*File, error) { // DeleteFile deletes a file. func (r *Client) DeleteFile(ctx context.Context, fileID string) error { - err := r.fetch(ctx, "DELETE", fmt.Sprintf("/files/%s", fileID), nil, nil) + err := r.fetch(ctx, http.MethodDelete, fmt.Sprintf("/files/%s", fileID), nil, nil) if err != nil { return fmt.Errorf("failed to delete file: %w", err) } diff --git a/hardware.go b/hardware.go index a418054..28e0275 100644 --- a/hardware.go +++ b/hardware.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" ) type Hardware struct { @@ -31,7 +32,7 @@ func (h *Hardware) UnmarshalJSON(data []byte) error { // ListHardware returns a list of available hardware. func (r *Client) ListHardware(ctx context.Context) (*[]Hardware, error) { response := &[]Hardware{} - err := r.fetch(ctx, "GET", "/hardware", nil, response) + err := r.fetch(ctx, http.MethodGet, "/hardware", nil, response) if err != nil { return nil, fmt.Errorf("failed to list collections: %w", err) } diff --git a/model.go b/model.go index 29631fc..6bac8ff 100644 --- a/model.go +++ b/model.go @@ -76,7 +76,7 @@ func (m *ModelVersion) UnmarshalJSON(data []byte) error { // ListModels lists public models. func (r *Client) ListModels(ctx context.Context) (*Page[Model], error) { response := &Page[Model]{} - err := r.fetch(ctx, "GET", "/models", nil, response) + err := r.fetch(ctx, http.MethodGet, "/models", nil, response) if err != nil { return nil, fmt.Errorf("failed to list models: %w", err) } @@ -86,7 +86,7 @@ func (r *Client) ListModels(ctx context.Context) (*Page[Model], error) { // GetModel retrieves information about a model. func (r *Client) GetModel(ctx context.Context, modelOwner string, modelName string) (*Model, error) { model := &Model{} - err := r.fetch(ctx, "GET", fmt.Sprintf("/models/%s/%s", modelOwner, modelName), nil, model) + err := r.fetch(ctx, http.MethodGet, fmt.Sprintf("/models/%s/%s", modelOwner, modelName), nil, model) if err != nil { return nil, fmt.Errorf("failed to get model: %w", err) } @@ -107,7 +107,7 @@ func (r *Client) CreateModel(ctx context.Context, modelOwner string, modelName s CreateModelOptions: options, } - err := r.fetch(ctx, "POST", "/models", body, model) + err := r.fetch(ctx, http.MethodPost, "/models", body, model) if err != nil { return nil, fmt.Errorf("failed to create model: %w", err) } @@ -117,7 +117,7 @@ func (r *Client) CreateModel(ctx context.Context, modelOwner string, modelName s // ListModelVersions lists the versions of a model. func (r *Client) ListModelVersions(ctx context.Context, modelOwner string, modelName string) (*Page[ModelVersion], error) { response := &Page[ModelVersion]{} - err := r.fetch(ctx, "GET", fmt.Sprintf("/models/%s/%s/versions", modelOwner, modelName), nil, response) + err := r.fetch(ctx, http.MethodGet, fmt.Sprintf("/models/%s/%s/versions", modelOwner, modelName), nil, response) if err != nil { return nil, fmt.Errorf("failed to list model versions: %w", err) } @@ -127,7 +127,7 @@ func (r *Client) ListModelVersions(ctx context.Context, modelOwner string, model // GetModelVersion retrieves a specific version of a model. func (r *Client) GetModelVersion(ctx context.Context, modelOwner string, modelName string, versionID string) (*ModelVersion, error) { version := &ModelVersion{} - err := r.fetch(ctx, "GET", fmt.Sprintf("/models/%s/%s/versions/%s", modelOwner, modelName, versionID), nil, version) + err := r.fetch(ctx, http.MethodGet, fmt.Sprintf("/models/%s/%s/versions/%s", modelOwner, modelName, versionID), nil, version) if err != nil { return nil, fmt.Errorf("failed to get model version: %w", err) } @@ -161,7 +161,7 @@ func (r *Client) CreatePredictionWithModel(ctx context.Context, modelOwner strin } prediction := &Prediction{} - err := r.fetch(ctx, "POST", fmt.Sprintf("/models/%s/%s/predictions", modelOwner, modelName), data, prediction) + err := r.fetch(ctx, http.MethodPost, fmt.Sprintf("/models/%s/%s/predictions", modelOwner, modelName), data, prediction) if err != nil { return nil, err } diff --git a/paginate.go b/paginate.go index 2e04de4..8f88fc2 100644 --- a/paginate.go +++ b/paginate.go @@ -3,6 +3,7 @@ package replicate import ( "context" "encoding/json" + "net/http" ) // Page represents a paginated response from Replicate's API. @@ -43,7 +44,7 @@ func Paginate[T any](ctx context.Context, client *Client, initialPage *Page[T]) for nextURL != nil { page := &Page[T]{} - err := client.fetch(ctx, "GET", *nextURL, nil, page) + err := client.fetch(ctx, http.MethodGet, *nextURL, nil, page) if err != nil { errChan <- err return diff --git a/prediction.go b/prediction.go index dc5d1e8..c03e95b 100644 --- a/prediction.go +++ b/prediction.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" "regexp" "strings" ) @@ -122,7 +123,7 @@ func (r *Client) CreatePrediction(ctx context.Context, version string, input Pre } prediction := &Prediction{} - err := r.fetch(ctx, "POST", "/predictions", data, prediction) + err := r.fetch(ctx, http.MethodPost, "/predictions", data, prediction) if err != nil { return nil, fmt.Errorf("failed to create prediction: %w", err) } @@ -133,7 +134,7 @@ func (r *Client) CreatePrediction(ctx context.Context, version string, input Pre // ListPredictions returns a paginated list of predictions. func (r *Client) ListPredictions(ctx context.Context) (*Page[Prediction], error) { response := &Page[Prediction]{} - err := r.fetch(ctx, "GET", "/predictions", nil, response) + err := r.fetch(ctx, http.MethodGet, "/predictions", nil, response) if err != nil { return nil, fmt.Errorf("failed to list predictions: %w", err) } @@ -143,7 +144,7 @@ func (r *Client) ListPredictions(ctx context.Context) (*Page[Prediction], error) // GetPrediction retrieves a prediction from the Replicate API by its ID. func (r *Client) GetPrediction(ctx context.Context, id string) (*Prediction, error) { prediction := &Prediction{} - err := r.fetch(ctx, "GET", fmt.Sprintf("/predictions/%s", id), nil, prediction) + err := r.fetch(ctx, http.MethodGet, fmt.Sprintf("/predictions/%s", id), nil, prediction) if err != nil { return nil, fmt.Errorf("failed to get prediction: %w", err) } @@ -153,7 +154,7 @@ func (r *Client) GetPrediction(ctx context.Context, id string) (*Prediction, err // CancelPrediction cancels a running prediction by its ID. func (r *Client) CancelPrediction(ctx context.Context, id string) (*Prediction, error) { prediction := &Prediction{} - err := r.fetch(ctx, "POST", fmt.Sprintf("/predictions/%s/cancel", id), nil, prediction) + err := r.fetch(ctx, http.MethodPost, fmt.Sprintf("/predictions/%s/cancel", id), nil, prediction) if err != nil { return nil, fmt.Errorf("failed to cancel prediction: %w", err) } diff --git a/training.go b/training.go index 2722061..938bec5 100644 --- a/training.go +++ b/training.go @@ -3,6 +3,7 @@ package replicate import ( "context" "fmt" + "net/http" ) type Training Prediction @@ -25,7 +26,7 @@ func (r *Client) CreateTraining(ctx context.Context, modelOwner string, modelNam training := &Training{} path := fmt.Sprintf("/models/%s/%s/versions/%s/trainings", modelOwner, modelName, version) - err := r.fetch(ctx, "POST", path, data, training) + err := r.fetch(ctx, http.MethodPost, path, data, training) if err != nil { return nil, fmt.Errorf("failed to create training: %w", err) } @@ -36,7 +37,7 @@ func (r *Client) CreateTraining(ctx context.Context, modelOwner string, modelNam // ListTrainings returns a list of trainings. func (r *Client) ListTrainings(ctx context.Context) (*Page[Training], error) { response := &Page[Training]{} - err := r.fetch(ctx, "GET", "/trainings", nil, response) + err := r.fetch(ctx, http.MethodGet, "/trainings", nil, response) if err != nil { return nil, fmt.Errorf("failed to list trainings: %w", err) } @@ -46,7 +47,7 @@ func (r *Client) ListTrainings(ctx context.Context) (*Page[Training], error) { // GetTraining sends a request to the Replicate API to get a training. func (r *Client) GetTraining(ctx context.Context, trainingID string) (*Training, error) { training := &Training{} - err := r.fetch(ctx, "GET", fmt.Sprintf("/trainings/%s", trainingID), nil, training) + err := r.fetch(ctx, http.MethodGet, fmt.Sprintf("/trainings/%s", trainingID), nil, training) if err != nil { return nil, fmt.Errorf("failed to get training: %w", err) } @@ -57,7 +58,7 @@ func (r *Client) GetTraining(ctx context.Context, trainingID string) (*Training, // CancelTraining sends a request to the Replicate API to cancel a training. func (r *Client) CancelTraining(ctx context.Context, trainingID string) (*Training, error) { training := &Training{} - err := r.fetch(ctx, "POST", fmt.Sprintf("/trainings/%s/cancel", trainingID), nil, training) + err := r.fetch(ctx, http.MethodPost, fmt.Sprintf("/trainings/%s/cancel", trainingID), nil, training) if err != nil { return nil, fmt.Errorf("failed to cancel training: %w", err) }