diff --git a/client_test.go b/client_test.go index 9e1145c..6d3ebe7 100644 --- a/client_test.go +++ b/client_test.go @@ -1116,6 +1116,56 @@ func TestRun(t *testing.T) { assert.Equal(t, "Hello, world!", output) } +func TestRunWithVersionlessModel(t *testing.T) { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/models/owner/model/predictions": + assert.Equal(t, http.MethodPost, r.Method) + + // Check the request body + var requestBody map[string]interface{} + err := json.NewDecoder(r.Body).Decode(&requestBody) + require.NoError(t, err) + assert.Equal(t, "Hello", requestBody["input"].(map[string]interface{})["prompt"]) + + // Respond with a prediction + prediction := replicate.Prediction{ + ID: "ndufagtsllfynwqhdngldkdrkq", + Model: "owner/model", + Status: replicate.Starting, + } + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(prediction) + case "/predictions/ndufagtsllfynwqhdngldkdrkq": + assert.Equal(t, http.MethodGet, r.Method) + prediction := replicate.Prediction{ + ID: "ndufagtsllfynwqhdngldkdrkq", + Model: "owner/model", + Status: replicate.Succeeded, + Output: "Hello, world!", + } + json.NewEncoder(w).Encode(prediction) + default: + t.Fatalf("Unexpected request to %s", r.URL.Path) + } + })) + defer mockServer.Close() + + client, err := replicate.NewClient( + replicate.WithToken("test-token"), + replicate.WithBaseURL(mockServer.URL), + ) + require.NoError(t, err) + + ctx := context.Background() + input := replicate.PredictionInput{"prompt": "Hello"} + output, err := client.Run(ctx, "owner/model", input, nil) + + require.NoError(t, err) + assert.NotNil(t, output) + assert.Equal(t, "Hello, world!", output) +} + func TestRunReturningModelError(t *testing.T) { mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path {