Skip to content

Commit

Permalink
Add fields for other prediction metrics (#53)
Browse files Browse the repository at this point in the history
* Add fields to Prediction metrics

* Extract PredictionMetrics into separate struct

* Add test coverage for prediction metrics
  • Loading branch information
mattt authored Mar 11, 2024
1 parent 39eb976 commit 442cf0a
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 15 deletions.
23 changes: 23 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,23 @@ func TestWait(t *testing.T) {

if statuses[i] == replicate.Succeeded {
prediction.Output = map[string]interface{}{"text": "Hello, Alice"}

startedAt := "2022-04-26T22:13:06.324088Z"
prediction.StartedAt = &startedAt

completedAt := "2022-04-26T22:13:07.224088Z"
prediction.CompletedAt = &completedAt

predictTime := 0.5
totalTime := 1.0
inputTokenCount := 1
outputTokenCount := 2
prediction.Metrics = &replicate.PredictionMetrics{
PredictTime: &predictTime,
TotalTime: &totalTime,
InputTokenCount: &inputTokenCount,
OutputTokenCount: &outputTokenCount,
}
}

if i < len(statuses)-1 {
Expand Down Expand Up @@ -901,6 +918,12 @@ func TestWait(t *testing.T) {
assert.Equal(t, replicate.Succeeded, prediction.Status)
assert.Equal(t, replicate.PredictionInput{"text": "Alice"}, prediction.Input)
assert.Equal(t, map[string]interface{}{"text": "Hello, Alice"}, prediction.Output)
assert.Equal(t, "2022-04-26T22:13:06.324088Z", *prediction.StartedAt)
assert.Equal(t, "2022-04-26T22:13:07.224088Z", *prediction.CompletedAt)
assert.Equal(t, 0.5, *prediction.Metrics.PredictTime)
assert.Equal(t, 1.0, *prediction.Metrics.TotalTime)
assert.Equal(t, 1, *prediction.Metrics.InputTokenCount)
assert.Equal(t, 2, *prediction.Metrics.OutputTokenCount)
}

func TestWaitAsync(t *testing.T) {
Expand Down
37 changes: 22 additions & 15 deletions prediction.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,16 @@ const (
)

type Prediction struct {
ID string `json:"id"`
Status Status `json:"status"`
Model string `json:"model"`
Version string `json:"version"`
Input PredictionInput `json:"input"`
Output PredictionOutput `json:"output,omitempty"`
Source Source `json:"source"`
Error interface{} `json:"error,omitempty"`
Logs *string `json:"logs,omitempty"`
Metrics *struct {
PredictTime *float64 `json:"predict_time,omitempty"`
} `json:"metrics,omitempty"`
ID string `json:"id"`
Status Status `json:"status"`
Model string `json:"model"`
Version string `json:"version"`
Input PredictionInput `json:"input"`
Output PredictionOutput `json:"output,omitempty"`
Source Source `json:"source"`
Error interface{} `json:"error,omitempty"`
Logs *string `json:"logs,omitempty"`
Metrics *PredictionMetrics `json:"metrics,omitempty"`
Webhook *string `json:"webhook,omitempty"`
WebhookEventsFilter []WebhookEventType `json:"webhook_events_filter,omitempty"`
URLs map[string]string `json:"urls,omitempty"`
Expand All @@ -53,6 +51,18 @@ func (p *Prediction) UnmarshalJSON(data []byte) error {
return json.Unmarshal(data, alias)
}

type PredictionInput map[string]interface{}
type PredictionOutput interface{}

type PredictionMetrics struct {
PredictTime *float64 `json:"predict_time,omitempty"`
TotalTime *float64 `json:"total_time,omitempty"`
InputTokenCount *int `json:"input_token_count,omitempty"`
OutputTokenCount *int `json:"output_token_count,omitempty"`
TimeToFirstToken *float64 `json:"time_to_first_token,omitempty"`
TokensPerSecond *float64 `json:"tokens_per_second,omitempty"`
}

type PredictionProgress struct {
Percentage float64
Current int
Expand Down Expand Up @@ -93,9 +103,6 @@ func (p Prediction) Progress() *PredictionProgress {
return nil
}

type PredictionInput map[string]interface{}
type PredictionOutput interface{}

// CreatePrediction sends a request to the Replicate API to create a prediction.
func (r *Client) CreatePrediction(ctx context.Context, version string, input PredictionInput, webhook *Webhook, stream bool) (*Prediction, error) {
data := map[string]interface{}{
Expand Down

0 comments on commit 442cf0a

Please sign in to comment.