From 8ab3f73df160fd5ca515ce14bda3498dd43aa504 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 11 Mar 2024 03:20:44 -0700 Subject: [PATCH 1/3] Add fields to Prediction metrics --- prediction.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/prediction.go b/prediction.go index f578bc4..707372f 100644 --- a/prediction.go +++ b/prediction.go @@ -26,7 +26,12 @@ type Prediction struct { Error interface{} `json:"error,omitempty"` Logs *string `json:"logs,omitempty"` Metrics *struct { - PredictTime *float64 `json:"predict_time,omitempty"` + PredictTime *float64 `json:"predict_time,omitempty"` + TotalTime *float64 `json:"total_time,omitempty"` + InputTokenCount *int `json:"input_token_count,omitempty"` + OutputTokenCount *int `json:"output_token_count,omitempty"` + TimeToFirstToken *float64 `json:"time_to_first_token,omitempty"` + TokensPerSecond *float64 `json:"tokens_per_second,omitempty"` } `json:"metrics,omitempty"` Webhook *string `json:"webhook,omitempty"` WebhookEventsFilter []WebhookEventType `json:"webhook_events_filter,omitempty"` From af1ef9c084050cd89b431c6173a025df1b201ecf Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 11 Mar 2024 03:22:07 -0700 Subject: [PATCH 2/3] Extract PredictionMetrics into separate struct --- prediction.go | 42 ++++++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/prediction.go b/prediction.go index 707372f..dc5d1e8 100644 --- a/prediction.go +++ b/prediction.go @@ -16,23 +16,16 @@ const ( ) type Prediction struct { - ID string `json:"id"` - Status Status `json:"status"` - Model string `json:"model"` - Version string `json:"version"` - Input PredictionInput `json:"input"` - Output PredictionOutput `json:"output,omitempty"` - Source Source `json:"source"` - Error interface{} `json:"error,omitempty"` - Logs *string `json:"logs,omitempty"` - Metrics *struct { - PredictTime *float64 `json:"predict_time,omitempty"` - TotalTime *float64 `json:"total_time,omitempty"` - InputTokenCount *int `json:"input_token_count,omitempty"` - OutputTokenCount *int `json:"output_token_count,omitempty"` - TimeToFirstToken *float64 `json:"time_to_first_token,omitempty"` - TokensPerSecond *float64 `json:"tokens_per_second,omitempty"` - } `json:"metrics,omitempty"` + ID string `json:"id"` + Status Status `json:"status"` + Model string `json:"model"` + Version string `json:"version"` + Input PredictionInput `json:"input"` + Output PredictionOutput `json:"output,omitempty"` + Source Source `json:"source"` + Error interface{} `json:"error,omitempty"` + Logs *string `json:"logs,omitempty"` + Metrics *PredictionMetrics `json:"metrics,omitempty"` Webhook *string `json:"webhook,omitempty"` WebhookEventsFilter []WebhookEventType `json:"webhook_events_filter,omitempty"` URLs map[string]string `json:"urls,omitempty"` @@ -58,6 +51,18 @@ func (p *Prediction) UnmarshalJSON(data []byte) error { return json.Unmarshal(data, alias) } +type PredictionInput map[string]interface{} +type PredictionOutput interface{} + +type PredictionMetrics struct { + PredictTime *float64 `json:"predict_time,omitempty"` + TotalTime *float64 `json:"total_time,omitempty"` + InputTokenCount *int `json:"input_token_count,omitempty"` + OutputTokenCount *int `json:"output_token_count,omitempty"` + TimeToFirstToken *float64 `json:"time_to_first_token,omitempty"` + TokensPerSecond *float64 `json:"tokens_per_second,omitempty"` +} + type PredictionProgress struct { Percentage float64 Current int @@ -98,9 +103,6 @@ func (p Prediction) Progress() *PredictionProgress { return nil } -type PredictionInput map[string]interface{} -type PredictionOutput interface{} - // CreatePrediction sends a request to the Replicate API to create a prediction. func (r *Client) CreatePrediction(ctx context.Context, version string, input PredictionInput, webhook *Webhook, stream bool) (*Prediction, error) { data := map[string]interface{}{ From 0da33072c0080aa0bc93c635f8b74fdbbdce8735 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 11 Mar 2024 03:36:29 -0700 Subject: [PATCH 3/3] Add test coverage for prediction metrics --- client_test.go | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/client_test.go b/client_test.go index d3385f7..0daaead 100644 --- a/client_test.go +++ b/client_test.go @@ -863,6 +863,23 @@ func TestWait(t *testing.T) { if statuses[i] == replicate.Succeeded { prediction.Output = map[string]interface{}{"text": "Hello, Alice"} + + startedAt := "2022-04-26T22:13:06.324088Z" + prediction.StartedAt = &startedAt + + completedAt := "2022-04-26T22:13:07.224088Z" + prediction.CompletedAt = &completedAt + + predictTime := 0.5 + totalTime := 1.0 + inputTokenCount := 1 + outputTokenCount := 2 + prediction.Metrics = &replicate.PredictionMetrics{ + PredictTime: &predictTime, + TotalTime: &totalTime, + InputTokenCount: &inputTokenCount, + OutputTokenCount: &outputTokenCount, + } } if i < len(statuses)-1 { @@ -901,6 +918,12 @@ func TestWait(t *testing.T) { assert.Equal(t, replicate.Succeeded, prediction.Status) assert.Equal(t, replicate.PredictionInput{"text": "Alice"}, prediction.Input) assert.Equal(t, map[string]interface{}{"text": "Hello, Alice"}, prediction.Output) + assert.Equal(t, "2022-04-26T22:13:06.324088Z", *prediction.StartedAt) + assert.Equal(t, "2022-04-26T22:13:07.224088Z", *prediction.CompletedAt) + assert.Equal(t, 0.5, *prediction.Metrics.PredictTime) + assert.Equal(t, 1.0, *prediction.Metrics.TotalTime) + assert.Equal(t, 1, *prediction.Metrics.InputTokenCount) + assert.Equal(t, 2, *prediction.Metrics.OutputTokenCount) } func TestWaitAsync(t *testing.T) {