Skip to content

Commit

Permalink
Use HTTP method constants instead of string literals (#54)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattt authored Mar 11, 2024
1 parent ee68c24 commit 9c783f2
Show file tree
Hide file tree
Showing 10 changed files with 47 additions and 40 deletions.
3 changes: 2 additions & 1 deletion account.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"net/http"
)

type Account struct {
Expand Down Expand Up @@ -34,7 +35,7 @@ func (a *Account) UnmarshalJSON(data []byte) error {
// GetCurrentAccount returns the authenticated user or organization.
func (r *Client) GetCurrentAccount(ctx context.Context) (*Account, error) {
response := &Account{}
err := r.fetch(ctx, "GET", "/account", nil, response)
err := r.fetch(ctx, http.MethodGet, "/account", nil, response)
if err != nil {
return nil, fmt.Errorf("failed to list collections: %w", err)
}
Expand Down
30 changes: 15 additions & 15 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func TestListCollections(t *testing.T) {

func TestGetCollection(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
assert.Equal(t, http.MethodGet, r.Method)
assert.Equal(t, "/collections/super-resolution", r.URL.Path)

collection := &replicate.Collection{
Expand Down Expand Up @@ -248,7 +248,7 @@ func TestGetModel(t *testing.T) {

func TestCreateModel(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
assert.Equal(t, http.MethodPost, r.Method)
assert.Equal(t, "/models", r.URL.Path)

body, err := io.ReadAll(r.Body)
Expand Down Expand Up @@ -413,7 +413,7 @@ func TestGetModelVersion(t *testing.T) {

func TestCreatePrediction(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
assert.Equal(t, http.MethodPost, r.Method)
assert.Equal(t, "/predictions", r.URL.Path)

body, err := io.ReadAll(r.Body)
Expand Down Expand Up @@ -492,7 +492,7 @@ func TestCreatePrediction(t *testing.T) {

func TestCreatePredictionWithDeployment(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
assert.Equal(t, http.MethodPost, r.Method)
assert.Equal(t, "/deployments/owner/name/predictions", r.URL.Path)

body, err := io.ReadAll(r.Body)
Expand Down Expand Up @@ -573,7 +573,7 @@ func TestCreatePredictionWithDeployment(t *testing.T) {

func TestCreatePredictionWithModel(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
assert.Equal(t, http.MethodPost, r.Method)
assert.Equal(t, "/models/owner/model/predictions", r.URL.Path)

body, err := io.ReadAll(r.Body)
Expand Down Expand Up @@ -641,7 +641,7 @@ func TestCreatePredictionWithModel(t *testing.T) {

func TestCancelPrediction(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
assert.Equal(t, http.MethodPost, r.Method)
assert.Equal(t, "/predictions/ufawqhfynnddngldkgtslldrkq/cancel", r.URL.Path)

response := replicate.Prediction{
Expand Down Expand Up @@ -826,7 +826,7 @@ func TestListPredictions(t *testing.T) {

func TestGetPrediction(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
assert.Equal(t, http.MethodGet, r.Method)
assert.Equal(t, "/predictions/ufawqhfynnddngldkgtslldrkq", r.URL.Path)

prediction := &replicate.Prediction{
Expand Down Expand Up @@ -875,7 +875,7 @@ func TestWait(t *testing.T) {

i := 0
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
assert.Equal(t, http.MethodGet, r.Method)
assert.Equal(t, "/predictions/ufawqhfynnddngldkgtslldrkq", r.URL.Path)

prediction := &replicate.Prediction{
Expand Down Expand Up @@ -957,7 +957,7 @@ func TestWaitAsync(t *testing.T) {

i := 0
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
assert.Equal(t, http.MethodGet, r.Method)
assert.Equal(t, "/predictions/ufawqhfynnddngldkgtslldrkq", r.URL.Path)

prediction := &replicate.Prediction{
Expand Down Expand Up @@ -1024,7 +1024,7 @@ func TestWaitAsync(t *testing.T) {

func TestCreateTraining(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
assert.Equal(t, http.MethodPost, r.Method)
assert.Equal(t, "/models/owner/model/versions/632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532/trainings", r.URL.Path)

training := &replicate.Training{
Expand Down Expand Up @@ -1077,7 +1077,7 @@ func TestCreateTraining(t *testing.T) {

func TestGetTraining(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
assert.Equal(t, http.MethodGet, r.Method)
assert.Equal(t, "/trainings/zz4ibbonubfz7carwiefibzgga", r.URL.Path)

training := &replicate.Training{
Expand Down Expand Up @@ -1114,7 +1114,7 @@ func TestGetTraining(t *testing.T) {

func TestCancelTraining(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
assert.Equal(t, http.MethodPost, r.Method)
assert.Equal(t, "/trainings/zz4ibbonubfz7carwiefibzgga/cancel", r.URL.Path)

training := &replicate.Training{
Expand Down Expand Up @@ -1151,7 +1151,7 @@ func TestCancelTraining(t *testing.T) {

func TestListTrainings(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
assert.Equal(t, http.MethodGet, r.Method)
assert.Equal(t, "/trainings", r.URL.Path)

response := &replicate.Page[replicate.Training]{
Expand Down Expand Up @@ -1338,7 +1338,7 @@ func TestStream(t *testing.T) {
mockServer := httptest.NewUnstartedServer(nil)
mockServer.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == "POST" && r.URL.Path == "/predictions":
case r.Method == http.MethodPost && r.URL.Path == "/predictions":
body, err := io.ReadAll(r.Body)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -1373,7 +1373,7 @@ func TestStream(t *testing.T) {

w.WriteHeader(http.StatusCreated)
w.Write(responseBytes)
case r.Method == "GET" && r.URL.Path == "/predictions/ufawqhfynnddngldkgtslldrkq/stream":
case r.Method == http.MethodGet && r.URL.Path == "/predictions/ufawqhfynnddngldkgtslldrkq/stream":
flusher, _ := w.(http.Flusher)
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
Expand Down
5 changes: 3 additions & 2 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"net/http"
)

type Collection struct {
Expand Down Expand Up @@ -33,7 +34,7 @@ func (c *Collection) UnmarshalJSON(data []byte) error {
// ListCollections returns a list of all collections.
func (r *Client) ListCollections(ctx context.Context) (*Page[Collection], error) {
response := &Page[Collection]{}
err := r.fetch(ctx, "GET", "/collections", nil, response)
err := r.fetch(ctx, http.MethodGet, "/collections", nil, response)
if err != nil {
return nil, fmt.Errorf("failed to list collections: %w", err)
}
Expand All @@ -43,7 +44,7 @@ func (r *Client) ListCollections(ctx context.Context) (*Page[Collection], error)
// GetCollection returns a collection by slug.
func (r *Client) GetCollection(ctx context.Context, slug string) (*Collection, error) {
collection := &Collection{}
err := r.fetch(ctx, "GET", fmt.Sprintf("/collections/%s", slug), nil, collection)
err := r.fetch(ctx, http.MethodGet, fmt.Sprintf("/collections/%s", slug), nil, collection)
if err != nil {
return nil, fmt.Errorf("failed to get collection: %w", err)
}
Expand Down
5 changes: 3 additions & 2 deletions deployment.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"net/http"
)

type Deployment struct {
Expand Down Expand Up @@ -48,7 +49,7 @@ func (d *Deployment) UnmarshalJSON(data []byte) error {
func (r *Client) GetDeployment(ctx context.Context, deploymentOwner string, deploymentName string) (*Deployment, error) {
deployment := &Deployment{}
path := fmt.Sprintf("/deployments/%s/%s", deploymentOwner, deploymentName)
err := r.fetch(ctx, "GET", path, nil, deployment)
err := r.fetch(ctx, http.MethodGet, path, nil, deployment)
if err != nil {
return nil, fmt.Errorf("failed to get deployment: %w", err)
}
Expand All @@ -75,7 +76,7 @@ func (r *Client) CreatePredictionWithDeployment(ctx context.Context, deploymentO

prediction := &Prediction{}
path := fmt.Sprintf("/deployments/%s/%s/predictions", deploymentOwner, deploymentName)
err := r.fetch(ctx, "POST", path, data, prediction)
err := r.fetch(ctx, http.MethodPost, path, data, prediction)
if err != nil {
return nil, fmt.Errorf("failed to create prediction: %w", err)
}
Expand Down
8 changes: 4 additions & 4 deletions files.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func (r *Client) createFile(ctx context.Context, reader io.Reader, options Creat
return nil, fmt.Errorf("failed to close writer: %w", err)
}

req, err := r.newRequest(ctx, "POST", "/files", body)
req, err := r.newRequest(ctx, http.MethodPost, "/files", body)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
Expand All @@ -145,7 +145,7 @@ func (r *Client) createFile(ctx context.Context, reader io.Reader, options Creat
// ListFiles lists your files.
func (r *Client) ListFiles(ctx context.Context) (*Page[File], error) {
response := &Page[File]{}
err := r.fetch(ctx, "GET", "/files", nil, response)
err := r.fetch(ctx, http.MethodGet, "/files", nil, response)
if err != nil {
return nil, fmt.Errorf("failed to list files: %w", err)
}
Expand All @@ -156,7 +156,7 @@ func (r *Client) ListFiles(ctx context.Context) (*Page[File], error) {
// GetFile retrieves information about a file.
func (r *Client) GetFile(ctx context.Context, fileID string) (*File, error) {
file := &File{}
err := r.fetch(ctx, "GET", fmt.Sprintf("/files/%s", fileID), nil, file)
err := r.fetch(ctx, http.MethodGet, fmt.Sprintf("/files/%s", fileID), nil, file)
if err != nil {
return nil, fmt.Errorf("failed to get file: %w", err)
}
Expand All @@ -166,7 +166,7 @@ func (r *Client) GetFile(ctx context.Context, fileID string) (*File, error) {

// DeleteFile deletes a file.
func (r *Client) DeleteFile(ctx context.Context, fileID string) error {
err := r.fetch(ctx, "DELETE", fmt.Sprintf("/files/%s", fileID), nil, nil)
err := r.fetch(ctx, http.MethodDelete, fmt.Sprintf("/files/%s", fileID), nil, nil)
if err != nil {
return fmt.Errorf("failed to delete file: %w", err)
}
Expand Down
3 changes: 2 additions & 1 deletion hardware.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"net/http"
)

type Hardware struct {
Expand Down Expand Up @@ -31,7 +32,7 @@ func (h *Hardware) UnmarshalJSON(data []byte) error {
// ListHardware returns a list of available hardware.
func (r *Client) ListHardware(ctx context.Context) (*[]Hardware, error) {
response := &[]Hardware{}
err := r.fetch(ctx, "GET", "/hardware", nil, response)
err := r.fetch(ctx, http.MethodGet, "/hardware", nil, response)
if err != nil {
return nil, fmt.Errorf("failed to list collections: %w", err)
}
Expand Down
12 changes: 6 additions & 6 deletions model.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (m *ModelVersion) UnmarshalJSON(data []byte) error {
// ListModels lists public models.
func (r *Client) ListModels(ctx context.Context) (*Page[Model], error) {
response := &Page[Model]{}
err := r.fetch(ctx, "GET", "/models", nil, response)
err := r.fetch(ctx, http.MethodGet, "/models", nil, response)
if err != nil {
return nil, fmt.Errorf("failed to list models: %w", err)
}
Expand All @@ -86,7 +86,7 @@ func (r *Client) ListModels(ctx context.Context) (*Page[Model], error) {
// GetModel retrieves information about a model.
func (r *Client) GetModel(ctx context.Context, modelOwner string, modelName string) (*Model, error) {
model := &Model{}
err := r.fetch(ctx, "GET", fmt.Sprintf("/models/%s/%s", modelOwner, modelName), nil, model)
err := r.fetch(ctx, http.MethodGet, fmt.Sprintf("/models/%s/%s", modelOwner, modelName), nil, model)
if err != nil {
return nil, fmt.Errorf("failed to get model: %w", err)
}
Expand All @@ -107,7 +107,7 @@ func (r *Client) CreateModel(ctx context.Context, modelOwner string, modelName s
CreateModelOptions: options,
}

err := r.fetch(ctx, "POST", "/models", body, model)
err := r.fetch(ctx, http.MethodPost, "/models", body, model)
if err != nil {
return nil, fmt.Errorf("failed to create model: %w", err)
}
Expand All @@ -117,7 +117,7 @@ func (r *Client) CreateModel(ctx context.Context, modelOwner string, modelName s
// ListModelVersions lists the versions of a model.
func (r *Client) ListModelVersions(ctx context.Context, modelOwner string, modelName string) (*Page[ModelVersion], error) {
response := &Page[ModelVersion]{}
err := r.fetch(ctx, "GET", fmt.Sprintf("/models/%s/%s/versions", modelOwner, modelName), nil, response)
err := r.fetch(ctx, http.MethodGet, fmt.Sprintf("/models/%s/%s/versions", modelOwner, modelName), nil, response)
if err != nil {
return nil, fmt.Errorf("failed to list model versions: %w", err)
}
Expand All @@ -127,7 +127,7 @@ func (r *Client) ListModelVersions(ctx context.Context, modelOwner string, model
// GetModelVersion retrieves a specific version of a model.
func (r *Client) GetModelVersion(ctx context.Context, modelOwner string, modelName string, versionID string) (*ModelVersion, error) {
version := &ModelVersion{}
err := r.fetch(ctx, "GET", fmt.Sprintf("/models/%s/%s/versions/%s", modelOwner, modelName, versionID), nil, version)
err := r.fetch(ctx, http.MethodGet, fmt.Sprintf("/models/%s/%s/versions/%s", modelOwner, modelName, versionID), nil, version)
if err != nil {
return nil, fmt.Errorf("failed to get model version: %w", err)
}
Expand Down Expand Up @@ -161,7 +161,7 @@ func (r *Client) CreatePredictionWithModel(ctx context.Context, modelOwner strin
}

prediction := &Prediction{}
err := r.fetch(ctx, "POST", fmt.Sprintf("/models/%s/%s/predictions", modelOwner, modelName), data, prediction)
err := r.fetch(ctx, http.MethodPost, fmt.Sprintf("/models/%s/%s/predictions", modelOwner, modelName), data, prediction)
if err != nil {
return nil, err
}
Expand Down
3 changes: 2 additions & 1 deletion paginate.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package replicate
import (
"context"
"encoding/json"
"net/http"
)

// Page represents a paginated response from Replicate's API.
Expand Down Expand Up @@ -43,7 +44,7 @@ func Paginate[T any](ctx context.Context, client *Client, initialPage *Page[T])

for nextURL != nil {
page := &Page[T]{}
err := client.fetch(ctx, "GET", *nextURL, nil, page)
err := client.fetch(ctx, http.MethodGet, *nextURL, nil, page)
if err != nil {
errChan <- err
return
Expand Down
9 changes: 5 additions & 4 deletions prediction.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"net/http"
"regexp"
"strings"
)
Expand Down Expand Up @@ -122,7 +123,7 @@ func (r *Client) CreatePrediction(ctx context.Context, version string, input Pre
}

prediction := &Prediction{}
err := r.fetch(ctx, "POST", "/predictions", data, prediction)
err := r.fetch(ctx, http.MethodPost, "/predictions", data, prediction)
if err != nil {
return nil, fmt.Errorf("failed to create prediction: %w", err)
}
Expand All @@ -133,7 +134,7 @@ func (r *Client) CreatePrediction(ctx context.Context, version string, input Pre
// ListPredictions returns a paginated list of predictions.
func (r *Client) ListPredictions(ctx context.Context) (*Page[Prediction], error) {
response := &Page[Prediction]{}
err := r.fetch(ctx, "GET", "/predictions", nil, response)
err := r.fetch(ctx, http.MethodGet, "/predictions", nil, response)
if err != nil {
return nil, fmt.Errorf("failed to list predictions: %w", err)
}
Expand All @@ -143,7 +144,7 @@ func (r *Client) ListPredictions(ctx context.Context) (*Page[Prediction], error)
// GetPrediction retrieves a prediction from the Replicate API by its ID.
func (r *Client) GetPrediction(ctx context.Context, id string) (*Prediction, error) {
prediction := &Prediction{}
err := r.fetch(ctx, "GET", fmt.Sprintf("/predictions/%s", id), nil, prediction)
err := r.fetch(ctx, http.MethodGet, fmt.Sprintf("/predictions/%s", id), nil, prediction)
if err != nil {
return nil, fmt.Errorf("failed to get prediction: %w", err)
}
Expand All @@ -153,7 +154,7 @@ func (r *Client) GetPrediction(ctx context.Context, id string) (*Prediction, err
// CancelPrediction cancels a running prediction by its ID.
func (r *Client) CancelPrediction(ctx context.Context, id string) (*Prediction, error) {
prediction := &Prediction{}
err := r.fetch(ctx, "POST", fmt.Sprintf("/predictions/%s/cancel", id), nil, prediction)
err := r.fetch(ctx, http.MethodPost, fmt.Sprintf("/predictions/%s/cancel", id), nil, prediction)
if err != nil {
return nil, fmt.Errorf("failed to cancel prediction: %w", err)
}
Expand Down
Loading

0 comments on commit 9c783f2

Please sign in to comment.