Skip to content

Commit

Permalink
introduce streaming.File interface
Browse files Browse the repository at this point in the history
  • Loading branch information
philandstuff committed Oct 10, 2024
1 parent a4b1858 commit 965e0ef
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 47 deletions.
90 changes: 53 additions & 37 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"unicode/utf8"

"github.com/launchdarkly/eventsource"
"github.com/replicate/replicate-go/streaming"
"github.com/vincent-petithory/dataurl"
"golang.org/x/sync/errgroup"
)
Expand Down Expand Up @@ -143,13 +144,13 @@ func (r *Client) StreamPrediction(ctx context.Context, prediction *Prediction) (
return sseChan, errChan
}

func (r *Client) StreamPredictionFiles(ctx context.Context, prediction *Prediction) (<-chan io.ReadCloser, error) {
func (r *Client) StreamPredictionFiles(ctx context.Context, prediction *Prediction) (<-chan streaming.File, error) {
url := prediction.URLs["stream"]
if url == "" {
return nil, errors.New("streaming not supported or not enabled for this prediction")
}

ch := make(chan io.ReadCloser)
ch := make(chan streaming.File)

go r.streamFilesTo(ctx, ch, url, "")
return ch, nil
Expand Down Expand Up @@ -220,28 +221,63 @@ func (r *Client) streamTextTo(ctx context.Context, writer *io.PipeWriter, url st
}
}

type errorReader struct {
// dataURL is a streaming.File whose content is carried inline in a "data:"
// URL string rather than fetched from a remote location.
type dataURL struct {
// url is the raw data URL text; it is decoded each time Body is called.
url string
}

var _ streaming.File = &dataURL{}

// Body decodes the data URL held by d and returns the decoded payload wrapped
// in an in-memory reader whose Close is a no-op. The context argument is
// ignored. If decoding fails, the reader is nil and the decoding error is
// returned unchanged.
func (d *dataURL) Body(_ context.Context) (io.ReadCloser, error) {
	decoded, err := dataurl.DecodeString(d.url)
	if err != nil {
		return nil, err
	}

	payload := bytes.NewReader(decoded.Data)
	return io.NopCloser(payload), nil
}

// httpURL is a streaming.File whose content lives at an HTTP URL and is
// downloaded when Body is called.
type httpURL struct {
// c is the HTTP client used to perform the download.
c *http.Client
// url is the address the file content is fetched from.
url string
}

var _ streaming.File = &httpURL{}

// Body issues a GET request for h.url using h.c and returns the response body.
//
// ctx is attached to the request, so cancelling it aborts the download. On
// success the caller owns the returned io.ReadCloser and must close it. On
// failure the returned reader is nil; failures include an invalid URL, a
// transport error from the client, and any response whose status code is
// outside the 2xx range.
func (h *httpURL) Body(ctx context.Context) (io.ReadCloser, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, h.url, nil)
	if err != nil {
		return nil, err
	}
	resp, err := h.c.Do(req)
	if err != nil {
		return nil, err
	}
	// http.Client.Do does not treat non-2xx responses as errors. Without this
	// check an error page (404, 500, ...) would be handed to the caller as if
	// it were the file's content.
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		// The body is being discarded, so a Close error carries no useful
		// information beyond the status error returned below.
		_ = resp.Body.Close()
		return nil, fmt.Errorf("received invalid status code: %d", resp.StatusCode)
	}
	return resp.Body, nil
}

// errWrapper is a streaming.File that carries an error instead of content, so
// that failures can be delivered on the same channel as files.
type errWrapper struct {
// err is the error reported by Body.
err error
}

var _ io.ReadCloser = &errorReader{}
var _ streaming.File = &errWrapper{}

func errReader(err error) io.ReadCloser {
return &errorReader{err: err}
// fileError packages err as a streaming.File whose Body call reports err.
func fileError(err error) streaming.File {
	wrapped := new(errWrapper)
	wrapped.err = err
	return wrapped
}

func (e *errorReader) Read(p []byte) (int, error) {
return 0, e.err
// Body never yields content: it returns a nil reader together with the
// wrapped error. The context argument is ignored.
func (e *errWrapper) Body(_ context.Context) (io.ReadCloser, error) {
return nil, e.err
}

func (e *errorReader) Close() error {
// Close does nothing and always returns nil.
// NOTE(review): the streaming.File interface declares only Body, so this
// method looks like a leftover from the earlier io.ReadCloser-based error
// type — confirm whether it is still needed.
func (e *errWrapper) Close() error {
return nil
}

func (r *Client) streamFilesTo(ctx context.Context, out chan<- io.ReadCloser, url string, lastEventID string) {
func (r *Client) streamFilesTo(ctx context.Context, out chan<- streaming.File, url string, lastEventID string) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
out <- errReader(err)
out <- fileError(err)
close(out)
return
}
Expand All @@ -255,13 +291,13 @@ func (r *Client) streamFilesTo(ctx context.Context, out chan<- io.ReadCloser, ur

resp, err := r.c.Do(req)
if err != nil {
out <- errReader(fmt.Errorf("failed to send request: %w", err))
out <- fileError(fmt.Errorf("failed to send request: %w", err))
close(out)
return
}

if resp.StatusCode != http.StatusOK {
out <- errReader(fmt.Errorf("received invalid status code: %d", resp.StatusCode))
out <- fileError(fmt.Errorf("received invalid status code: %d", resp.StatusCode))
close(out)
return
}
Expand All @@ -275,45 +311,25 @@ func (r *Client) streamFilesTo(ctx context.Context, out chan<- io.ReadCloser, ur
r.streamFilesTo(ctx, out, url, lastEventID)
return
}
out <- errReader(fmt.Errorf("Failed to get token: %w", err))
out <- fileError(fmt.Errorf("Failed to get token: %w", err))
close(out)
return
}
lastEventID = event.Id()
switch event.Event() {
case SSETypeOutput:
if strings.HasPrefix(event.Data(), "data:") {
data, err := dataurl.DecodeString(event.Data())

if err != nil {
out <- errReader(err)
close(out)
return
}

out <- io.NopCloser(bytes.NewReader(data.Data))
out <- &dataURL{url: event.Data()}
} else if strings.HasPrefix(event.Data(), "http") {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, event.Data(), nil)
if err != nil {
out <- errReader(err)
close(out)
return
}
resp, err := r.c.Do(req)
if err != nil {
out <- errReader(err)
close(out)
return
}
out <- resp.Body
out <- &httpURL{c: r.c, url: event.Data()}
}
case SSETypeDone:
close(out)
return
case SSETypeLogs:
// TODO
default:
out <- errReader(fmt.Errorf("unknown event type %s", event.Event()))
out <- fileError(fmt.Errorf("unknown event type %s", event.Event()))
return
}
}
Expand Down
26 changes: 16 additions & 10 deletions stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,40 +140,46 @@ event: done

assert.NoError(t, err)

var f1, f2, f3 io.Reader
var body io.Reader
// first file is a data URI
select {
case f1 = <-files:
require.NotNil(t, f1)
case file := <-files:
require.NotNil(t, file)
body, err = file.Body(ctx)
require.NoError(t, err)
case <-time.After(time.Second):
assert.Fail(t, "Timed out waiting for file")
return
}
content1, err := io.ReadAll(f1)
content1, err := io.ReadAll(body)
assert.NoError(t, err)
assert.Equal(t, "banana", string(content1))

// second file is a base64'd data URI
select {
case f2 = <-files:
require.NotNil(t, f2)
case file := <-files:
require.NotNil(t, file)
body, err = file.Body(ctx)
require.NoError(t, err)
case <-time.After(time.Second):
assert.Fail(t, "Timed out waiting for file")
return
}
content2, err := io.ReadAll(f2)
content2, err := io.ReadAll(body)
assert.NoError(t, err)
assert.Equal(t, "apple", string(content2))

// third file is an http URI
select {
case f3 = <-files:
require.NotNil(t, f3)
case file := <-files:
require.NotNil(t, file)
body, err = file.Body(ctx)
require.NoError(t, err)
case <-time.After(time.Second):
assert.Fail(t, "Timed out waiting for file")
return
}
content3, err := io.ReadAll(f3)
content3, err := io.ReadAll(body)
assert.NoError(t, err)
assert.Equal(t, "mango\n", string(content3))
}
16 changes: 16 additions & 0 deletions streaming/file.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package streaming

import (
"context"
"io"
)

// File represents a file output from a model over an SSE stream. On the wire,
// it might be a data URL or a regular http URL. File abstracts over this and
// provides a way to get the data regardless of the implementation.
type File interface {
// Body fetches the content of the file and returns it as a stream.
//
// ctx is handed to the implementation; whether it is consulted depends on
// how the content is obtained. If there are any errors, the returned
// io.ReadCloser will be nil and the error describes the failure. On
// success, it is the caller's responsibility to close the io.ReadCloser.
Body(ctx context.Context) (io.ReadCloser, error)
}

0 comments on commit 965e0ef

Please sign in to comment.