Skip to content

Commit

Permalink
Add WithFileOutput to streaming methods
Browse files Browse the repository at this point in the history
  • Loading branch information
mattt committed Sep 23, 2024
1 parent b745dd9 commit b67cc30
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 13 deletions.
18 changes: 11 additions & 7 deletions run.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,17 @@ func transformOutput(ctx context.Context, value interface{}, client *Client) (in
}
return v, nil
case string:
if strings.HasPrefix(v, "data:") {
return readDataURI(v)
}
if strings.HasPrefix(v, "https:") || strings.HasPrefix(v, "http:") {
return readHTTP(ctx, v, client)
}
return v, nil
return convertStringToFileOutput(ctx, v, client)
}
return value, nil
}

func convertStringToFileOutput(ctx context.Context, value string, client *Client) (interface{}, error) {
if strings.HasPrefix(value, "data:") {
return readDataURI(value)
}
if strings.HasPrefix(value, "https:") || strings.HasPrefix(value, "http:") {
return readHTTP(ctx, value, client)
}
return value, nil
}
Expand Down
45 changes: 39 additions & 6 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,12 @@ func (e *SSEEvent) String() string {
}
}

func (r *Client) Stream(ctx context.Context, identifier string, input PredictionInput, webhook *Webhook) (<-chan SSEEvent, <-chan error) {
func (r *Client) Stream(ctx context.Context, identifier string, input PredictionInput, webhook *Webhook, opts ...RunOption) (<-chan SSEEvent, <-chan error) {
options := runOptions{}
for _, opt := range opts {
opt(&options)
}

sseChan := make(chan SSEEvent, 64)
errChan := make(chan error, 64)

Expand All @@ -119,21 +124,26 @@ func (r *Client) Stream(ctx context.Context, identifier string, input Prediction
return sseChan, errChan
}

r.streamPrediction(ctx, prediction, nil, sseChan, errChan)
r.streamPrediction(ctx, prediction, nil, options, sseChan, errChan)

return sseChan, errChan
}

func (r *Client) StreamPrediction(ctx context.Context, prediction *Prediction) (<-chan SSEEvent, <-chan error) {
func (r *Client) StreamPrediction(ctx context.Context, prediction *Prediction, opts ...RunOption) (<-chan SSEEvent, <-chan error) {
options := runOptions{}
for _, opt := range opts {
opt(&options)
}

sseChan := make(chan SSEEvent, 64)
errChan := make(chan error, 64)

r.streamPrediction(ctx, prediction, nil, sseChan, errChan)
r.streamPrediction(ctx, prediction, nil, options, sseChan, errChan)

return sseChan, errChan
}

func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, lastEvent *SSEEvent, sseChan chan SSEEvent, errChan chan error) {
func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, lastEvent *SSEEvent, options runOptions, sseChan chan SSEEvent, errChan chan error) {
url := prediction.URLs["stream"]
if url == "" {
select {
Expand Down Expand Up @@ -244,6 +254,29 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l
continue
}

if options.useFileOutput && event.Type == SSETypeOutput {
data, err := convertStringToFileOutput(ctx, event.Data, r)
if err != nil {
select {
case errChan <- err:
default:
}
return
}

if file, ok := data.(*FileOutput); ok {
bytes, err := io.ReadAll(file)
if err != nil {
select {
case errChan <- err:
default:
}
return
}
event.Data = string(bytes)
}
}

select {
case sseChan <- *event:
case <-done:
Expand All @@ -267,7 +300,7 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l
if err != nil {
if errors.Is(err, io.EOF) {
// Attempt to reconnect if the connection was closed before the stream was done
r.streamPrediction(ctx, prediction, lastEvent, sseChan, errChan)
r.streamPrediction(ctx, prediction, lastEvent, options, sseChan, errChan)
return
}

Expand Down

0 comments on commit b67cc30

Please sign in to comment.