diff --git a/stream.go b/stream.go index 6a09863..e91b379 100644 --- a/stream.go +++ b/stream.go @@ -12,6 +12,7 @@ import ( "unicode/utf8" "github.com/launchdarkly/eventsource" + "github.com/replicate/replicate-go/streaming" "github.com/vincent-petithory/dataurl" "golang.org/x/sync/errgroup" ) @@ -143,13 +144,13 @@ func (r *Client) StreamPrediction(ctx context.Context, prediction *Prediction) ( return sseChan, errChan } -func (r *Client) StreamPredictionFiles(ctx context.Context, prediction *Prediction) (<-chan io.ReadCloser, error) { +func (r *Client) StreamPredictionFiles(ctx context.Context, prediction *Prediction) (<-chan streaming.File, error) { url := prediction.URLs["stream"] if url == "" { return nil, errors.New("streaming not supported or not enabled for this prediction") } - ch := make(chan io.ReadCloser) + ch := make(chan streaming.File) go r.streamFilesTo(ctx, ch, url, "") return ch, nil @@ -220,28 +221,63 @@ func (r *Client) streamTextTo(ctx context.Context, writer *io.PipeWriter, url st } } -type errorReader struct { +type dataURL struct { + url string +} + +var _ streaming.File = &dataURL{} + +func (d *dataURL) Body(_ context.Context) (io.ReadCloser, error) { + data, err := dataurl.DecodeString(d.url) + + if err != nil { + return nil, err + } + + return io.NopCloser(bytes.NewReader(data.Data)), nil +} + +type httpURL struct { + c *http.Client + url string +} + +var _ streaming.File = &httpURL{} + +func (h *httpURL) Body(ctx context.Context) (io.ReadCloser, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, h.url, nil) + if err != nil { + return nil, err + } + resp, err := h.c.Do(req) + if err != nil { + return nil, err + } + return resp.Body, nil +} + +type errWrapper struct { err error } -var _ io.ReadCloser = &errorReader{} +var _ streaming.File = &errWrapper{} -func errReader(err error) io.ReadCloser { - return &errorReader{err: err} +func fileError(err error) streaming.File { + return &errWrapper{err: err} } -func (e *errorReader) Read(p []byte) (int, error) { - return 0, e.err +func (e *errWrapper) Body(_ context.Context) (io.ReadCloser, error) { + return nil, e.err } -func (e *errorReader) Close() error { +func (e *errWrapper) Close() error { return nil } -func (r *Client) streamFilesTo(ctx context.Context, out chan<- io.ReadCloser, url string, lastEventID string) { +func (r *Client) streamFilesTo(ctx context.Context, out chan<- streaming.File, url string, lastEventID string) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { - out <- errReader(err) + out <- fileError(err) close(out) return } @@ -255,13 +291,13 @@ func (r *Client) streamFilesTo(ctx context.Context, out chan<- io.ReadCloser, ur resp, err := r.c.Do(req) if err != nil { - out <- errReader(fmt.Errorf("failed to send request: %w", err)) + out <- fileError(fmt.Errorf("failed to send request: %w", err)) close(out) return } if resp.StatusCode != http.StatusOK { - out <- errReader(fmt.Errorf("received invalid status code: %d", resp.StatusCode)) + out <- fileError(fmt.Errorf("received invalid status code: %d", resp.StatusCode)) close(out) return } @@ -275,7 +311,7 @@ func (r *Client) streamFilesTo(ctx context.Context, out chan<- io.ReadCloser, ur r.streamFilesTo(ctx, out, url, lastEventID) return } - out <- errReader(fmt.Errorf("Failed to get token: %w", err)) + out <- fileError(fmt.Errorf("Failed to get token: %w", err)) close(out) return } @@ -283,29 +319,9 @@ func (r *Client) streamFilesTo(ctx context.Context, out chan<- io.ReadCloser, ur switch event.Event() { case SSETypeOutput: if strings.HasPrefix(event.Data(), "data:") { - data, err := dataurl.DecodeString(event.Data()) - - if err != nil { - out <- errReader(err) - close(out) - return - } - - out <- io.NopCloser(bytes.NewReader(data.Data)) + out <- &dataURL{url: event.Data()} } else if strings.HasPrefix(event.Data(), "http") { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, event.Data(), nil) - if err != nil { - out <- errReader(err) - close(out) - return - } - resp, err := r.c.Do(req) - if err != nil { - out <- errReader(err) - close(out) - return - } - out <- resp.Body + out <- &httpURL{c: r.c, url: event.Data()} } case SSETypeDone: close(out) @@ -313,7 +329,7 @@ func (r *Client) streamFilesTo(ctx context.Context, out chan<- io.ReadCloser, ur case SSETypeLogs: // TODO default: - out <- errReader(fmt.Errorf("unknown event type %s", event.Event())) + out <- fileError(fmt.Errorf("unknown event type %s", event.Event())) return } } diff --git a/stream_test.go b/stream_test.go index 91cb5a9..7d75467 100644 --- a/stream_test.go +++ b/stream_test.go @@ -140,40 +140,46 @@ event: done assert.NoError(t, err) - var f1, f2, f3 io.Reader + var body io.Reader // first file is a data URI select { - case f1 = <-files: - require.NotNil(t, f1) + case file := <-files: + require.NotNil(t, file) + body, err = file.Body(ctx) + require.NoError(t, err) case <-time.After(time.Second): assert.Fail(t, "Timed out waiting for file") return } - content1, err := io.ReadAll(f1) + content1, err := io.ReadAll(body) assert.NoError(t, err) assert.Equal(t, "banana", string(content1)) // second file is a base64'd data URI select { - case f2 = <-files: - require.NotNil(t, f2) + case file := <-files: + require.NotNil(t, file) + body, err = file.Body(ctx) + require.NoError(t, err) case <-time.After(time.Second): assert.Fail(t, "Timed out waiting for file") return } - content2, err := io.ReadAll(f2) + content2, err := io.ReadAll(body) assert.NoError(t, err) assert.Equal(t, "apple", string(content2)) // third file is an http URI select { - case f3 = <-files: - require.NotNil(t, f3) + case file := <-files: + require.NotNil(t, file) + body, err = file.Body(ctx) + require.NoError(t, err) case <-time.After(time.Second): assert.Fail(t, "Timed out waiting for file") return } - content3, err := io.ReadAll(f3) + content3, err := io.ReadAll(body) assert.NoError(t, err) assert.Equal(t, "mango\n", string(content3)) } diff --git a/streaming/file.go b/streaming/file.go new file mode 100644 index 0000000..9b3fc18 --- /dev/null +++ b/streaming/file.go @@ -0,0 +1,16 @@ +package streaming + +import ( + "context" + "io" +) + +// File represents a file output from a model over an SSE stream. On the wire, +// it might be a data URL or a regular http URL. File abstracts over this and +// provides a way to get the data regardless of the implementation. +type File interface { + // Body fetches the content of the file. If there are any errors, the + // io.ReadCloser will be nil. It is the caller's responsibility to close + // the io.ReadCloser. + Body(ctx context.Context) (io.ReadCloser, error) +}