Skip to content

Commit

Permalink
Merge pull request #4 from wandb/annirudh/hedge
Browse files Browse the repository at this point in the history
  • Loading branch information
annirudh authored May 6, 2024
2 parents 34abaa4 + 871d875 commit 82cc1cd
Show file tree
Hide file tree
Showing 2 changed files with 217 additions and 0 deletions.
97 changes: 97 additions & 0 deletions hedge.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package parallel

import (
"context"
"time"
)

type HedgedRequestConfig struct {
delay time.Duration // time to wait before issuing hedged requests
numHedgedRequests int // the maximum permitted number of outstanding hedged requests
}

func (h *HedgedRequestConfig) Apply(opts ...HedgedRequestOpt) {
for _, opt := range opts {
opt(h)
}
}

type HedgedRequestOpt func(config *HedgedRequestConfig)

func WithDelay(delay time.Duration) HedgedRequestOpt {
return func(config *HedgedRequestConfig) {
config.delay = delay
}
}

func WithNumHedgedRequests(numHedgedRequests int) HedgedRequestOpt {
return func(config *HedgedRequestConfig) {
config.numHedgedRequests = numHedgedRequests
}
}

func HedgedRequest[T any](
ctx context.Context,
requester func(context.Context) (T, error),
opts ...HedgedRequestOpt,
) (T, error) {
cfg := HedgedRequestConfig{
delay: 50 * time.Millisecond,
numHedgedRequests: 2,
}
cfg.Apply(opts...)

hedgeSignal := make(chan struct{}) // closed when hedged requests should fire
responses := make(chan T) // unbuffered, we only expect one response
ctx, cancel := context.WithCancelCause(ctx)

group := ErrGroup(Limited(ctx, cfg.numHedgedRequests+1))
for i := 0; i < cfg.numHedgedRequests+1; i++ {
i := i
group.Go(func(ctx context.Context) (rerr error) {
defer func() {
cancel(rerr)
}()

if i == 0 {
// Initial request case: if this does not complete within the hedge delay, we signal the
// hedge requests to fire off.
time.AfterFunc(cfg.delay, func() {
close(hedgeSignal)
})
} else {
// Hedged request case: wait for the go-ahead for hedged requests first.
select {
case <-ctx.Done():
return context.Cause(ctx)
case <-hedgeSignal:
// good to proceed
}
}

res, err := requester(ctx)
if err != nil {
return err
}

select {
case <-ctx.Done():
return context.Cause(ctx)
case responses <- res:
return nil
}
})
}

go func() {
_ = group.Wait()
close(responses)
}()

for response := range responses {
return response, nil
}

var empty T
return empty, context.Cause(ctx)
}
120 changes: 120 additions & 0 deletions hedge_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package parallel

import (
"context"
"errors"
"fmt"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestHedgedRequestBasic(t *testing.T) {
ctx := context.Background()
count := 0
expected := "success"
requester := func(ctx context.Context) (string, error) {
count++
if count == 1 {
return expected, nil
} else {
return "fail", fmt.Errorf("should not trigger hedged request")
}
}

actual, err := HedgedRequest[string](ctx, requester)

require.NoError(t, err)
assert.Equal(t, expected, actual)
}

func TestHedgedRequestHedgingTriggered(t *testing.T) {
ctx := context.Background()
count := 0
delay := 50 * time.Millisecond
expected := "success"
requester := func(ctx context.Context) (string, error) {
count++
if count == 0 {
select {
case <-time.After(2 * delay):
return "fail", fmt.Errorf("original request slow")
case <-ctx.Done():
return "fail", ctx.Err()
}
} else {
return expected, nil
}
}

actual, err := HedgedRequest[string](ctx, requester, WithDelay(delay), WithNumHedgedRequests(1))

require.NoError(t, err)
assert.Equal(t, expected, actual)
}

func TestHedgedRequestMultipleSuccess(t *testing.T) {
ctx := context.Background()
expected := "success"
delay := 5 * time.Millisecond
done := make(chan struct{})
requester := func(ctx context.Context) (string, error) {
// Synchronize on the done channel. The original request and
// all hedged requests will line up and block here.
select {
case <-done:
return expected, nil
case <-ctx.Done():
return "fail", ctx.Err()
}
}

// Wait for 2x the delay to ensure that all hedged requests have
// fired alongside the original request. Then close the done channel
// so the original and hedged requests complete almost simultaneously.
time.AfterFunc(2*delay, func() {
close(done)
})

actual, err := HedgedRequest[string](
ctx,
requester,
WithDelay(delay),
)

require.NoError(t, err)
assert.Equal(t, expected, actual)
}

func TestHedgedRequestErrorPropagation(t *testing.T) {
ctx := context.Background()
expectedErr := errors.New("failure")
requester := func(ctx context.Context) (string, error) {
return "fail", expectedErr
}

_, err := HedgedRequest[string](ctx, requester)

require.ErrorIs(t, err, expectedErr)
}

func TestHedgedRequestCancellation(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
delay := 100 * time.Millisecond
requester := func(ctx context.Context) (string, error) {
<-time.After(delay)
return "fail", fmt.Errorf("context should be canceled")
}

go func() {
// Cancel context before the requester completes
time.Sleep(10 * time.Millisecond)
cancel()
}()

_, err := HedgedRequest[string](ctx, requester)

require.ErrorIs(t, err, context.Canceled)
}

0 comments on commit 82cc1cd

Please sign in to comment.