Skip to content

Commit

Permalink
Add test coverage for running a versionless model
Browse files Browse the repository at this point in the history
  • Loading branch information
mattt committed Oct 2, 2024
1 parent 030989b commit e52be91
Showing 1 changed file with 50 additions and 0 deletions.
50 changes: 50 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1116,6 +1116,56 @@ func TestRun(t *testing.T) {
assert.Equal(t, "Hello, world!", output)
}

func TestRunWithVersionlessModel(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/models/owner/model/predictions":
assert.Equal(t, http.MethodPost, r.Method)

// Check the request body
var requestBody map[string]interface{}
err := json.NewDecoder(r.Body).Decode(&requestBody)
require.NoError(t, err)
assert.Equal(t, "Hello", requestBody["input"].(map[string]interface{})["prompt"])

// Respond with a prediction
prediction := replicate.Prediction{
ID: "ndufagtsllfynwqhdngldkdrkq",
Model: "owner/model",
Status: replicate.Starting,
}
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(prediction)
case "/predictions/ndufagtsllfynwqhdngldkdrkq":
assert.Equal(t, http.MethodGet, r.Method)
prediction := replicate.Prediction{
ID: "ndufagtsllfynwqhdngldkdrkq",
Model: "owner/model",
Status: replicate.Succeeded,
Output: "Hello, world!",
}
json.NewEncoder(w).Encode(prediction)
default:
t.Fatalf("Unexpected request to %s", r.URL.Path)
}
}))
defer mockServer.Close()

client, err := replicate.NewClient(
replicate.WithToken("test-token"),
replicate.WithBaseURL(mockServer.URL),
)
require.NoError(t, err)

ctx := context.Background()
input := replicate.PredictionInput{"prompt": "Hello"}
output, err := client.Run(ctx, "owner/model", input, nil)

require.NoError(t, err)
assert.NotNil(t, output)
assert.Equal(t, "Hello, world!", output)
}

func TestRunReturningModelError(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
Expand Down

0 comments on commit e52be91

Please sign in to comment.