diff --git a/hedge.go b/hedge.go
new file mode 100644
index 0000000..eb114b6
--- /dev/null
+++ b/hedge.go
@@ -0,0 +1,117 @@
+package parallel
+
+import (
+	"context"
+	"time"
+)
+
+// HedgedRequestConfig holds the tunable parameters for HedgedRequest.
+type HedgedRequestConfig struct {
+	delay             time.Duration // time to wait before issuing hedged requests
+	numHedgedRequests int           // the maximum permitted number of outstanding hedged requests
+}
+
+// Apply applies the given options to the config, in order.
+func (h *HedgedRequestConfig) Apply(opts ...HedgedRequestOpt) {
+	for _, opt := range opts {
+		opt(h)
+	}
+}
+
+// HedgedRequestOpt customizes a HedgedRequestConfig.
+type HedgedRequestOpt func(config *HedgedRequestConfig)
+
+// WithDelay sets how long to wait for the initial request before firing
+// the hedged requests.
+func WithDelay(delay time.Duration) HedgedRequestOpt {
+	return func(config *HedgedRequestConfig) {
+		config.delay = delay
+	}
+}
+
+// WithNumHedgedRequests sets the maximum number of hedged requests issued
+// in addition to the initial request. Negative values are treated as zero.
+func WithNumHedgedRequests(numHedgedRequests int) HedgedRequestOpt {
+	return func(config *HedgedRequestConfig) {
+		config.numHedgedRequests = numHedgedRequests
+	}
+}
+
+// HedgedRequest invokes requester immediately and, if it has not completed
+// within the configured delay, races it against up to numHedgedRequests
+// additional invocations. The first result wins and cancels the rest.
+//
+// NOTE: the first completion of ANY kind — success or error — cancels the
+// outstanding requests (fail-fast), so an early failure also cancels
+// in-flight hedges.
+func HedgedRequest[T any](
+	ctx context.Context,
+	requester func(context.Context) (T, error),
+	opts ...HedgedRequestOpt,
+) (T, error) {
+	cfg := HedgedRequestConfig{
+		delay:             50 * time.Millisecond,
+		numHedgedRequests: 2,
+	}
+	cfg.Apply(opts...)
+	if cfg.numHedgedRequests < 0 {
+		// Guard against nonsensical configuration: fall back to issuing
+		// only the initial request rather than issuing none at all.
+		cfg.numHedgedRequests = 0
+	}
+
+	hedgeSignal := make(chan struct{}) // closed when hedged requests should fire
+	responses := make(chan T)          // unbuffered, we only expect one response
+	ctx, cancel := context.WithCancelCause(ctx)
+
+	group := ErrGroup(Limited(ctx, cfg.numHedgedRequests+1))
+	for i := 0; i < cfg.numHedgedRequests+1; i++ {
+		i := i
+		group.Go(func(ctx context.Context) (rerr error) {
+			defer func() {
+				cancel(rerr)
+			}()
+
+			if i == 0 {
+				// Initial request case: if this does not complete within the
+				// hedge delay, we signal the hedge requests to fire off. Stop
+				// the timer on return so hedgeSignal is not closed (and no
+				// timer goroutine runs) after the call has already finished.
+				timer := time.AfterFunc(cfg.delay, func() {
+					close(hedgeSignal)
+				})
+				defer timer.Stop()
+			} else {
+				// Hedged request case: wait for the go-ahead for hedged requests first.
+				select {
+				case <-ctx.Done():
+					return context.Cause(ctx)
+				case <-hedgeSignal:
+					// good to proceed
+				}
+			}
+
+			res, err := requester(ctx)
+			if err != nil {
+				return err
+			}
+
+			select {
+			case <-ctx.Done():
+				return context.Cause(ctx)
+			case responses <- res:
+				return nil
+			}
+		})
+	}
+
+	go func() {
+		// Wait for all workers, then close responses so the receiver below
+		// unblocks even when every request failed.
+		_ = group.Wait()
+		close(responses)
+	}()
+
+	for response := range responses {
+		return response, nil
+	}
+
+	// No response was produced; surface the first recorded failure.
+	var empty T
+	return empty, context.Cause(ctx)
+}
diff --git a/hedge_test.go b/hedge_test.go
new file mode 100644
index 0000000..936f781
--- /dev/null
+++ b/hedge_test.go
@@ -0,0 +1,120 @@
+package parallel
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestHedgedRequestBasic(t *testing.T) {
+	ctx := context.Background()
+	// count is shared between the original and any hedged requester
+	// goroutines, so it must be accessed atomically.
+	var count atomic.Int32
+	expected := "success"
+	requester := func(ctx context.Context) (string, error) {
+		if count.Add(1) == 1 {
+			return expected, nil
+		}
+		return "fail", fmt.Errorf("should not trigger hedged request")
+	}
+
+	actual, err := HedgedRequest[string](ctx, requester)
+
+	require.NoError(t, err)
+	assert.Equal(t, expected, actual)
+}
+
+func TestHedgedRequestHedgingTriggered(t *testing.T) {
+	ctx := context.Background()
+	var count atomic.Int32
+	delay := 50 * time.Millisecond
+	expected := "success"
+	requester := func(ctx context.Context) (string, error) {
+		// The first (original) request is deliberately slower than the
+		// hedge delay; the hedged request succeeds immediately.
+		if count.Add(1) == 1 {
+			select {
+			case <-time.After(2 * delay):
+				return "fail", fmt.Errorf("original request slow")
+			case <-ctx.Done():
+				return "fail", ctx.Err()
+			}
+		}
+		return expected, nil
+	}
+
+	actual, err := HedgedRequest[string](ctx, requester, WithDelay(delay), WithNumHedgedRequests(1))
+
+	require.NoError(t, err)
+	assert.Equal(t, expected, actual)
+}
+
+func TestHedgedRequestMultipleSuccess(t *testing.T) {
+	ctx := context.Background()
+	expected := "success"
+	delay := 5 * time.Millisecond
+	done := make(chan struct{})
+	requester := func(ctx context.Context) (string, error) {
+		// Synchronize on the done channel. The original request and
+		// all hedged requests will line up and block here.
+		select {
+		case <-done:
+			return expected, nil
+		case <-ctx.Done():
+			return "fail", ctx.Err()
+		}
+	}
+
+	// Wait for 2x the delay to ensure that all hedged requests have
+	// fired alongside the original request. Then close the done channel
+	// so the original and hedged requests complete almost simultaneously.
+	time.AfterFunc(2*delay, func() {
+		close(done)
+	})
+
+	actual, err := HedgedRequest[string](
+		ctx,
+		requester,
+		WithDelay(delay),
+	)
+
+	require.NoError(t, err)
+	assert.Equal(t, expected, actual)
+}
+
+func TestHedgedRequestErrorPropagation(t *testing.T) {
+	ctx := context.Background()
+	expectedErr := errors.New("failure")
+	requester := func(ctx context.Context) (string, error) {
+		return "fail", expectedErr
+	}
+
+	_, err := HedgedRequest[string](ctx, requester)
+
+	require.ErrorIs(t, err, expectedErr)
+}
+
+func TestHedgedRequestCancellation(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	delay := 100 * time.Millisecond
+	requester := func(ctx context.Context) (string, error) {
+		// Intentionally ignores ctx so that only the outer cancellation
+		// can end the call early.
+		<-time.After(delay)
+		return "fail", fmt.Errorf("context should be canceled")
+	}
+
+	go func() {
+		// Cancel context before the requester completes
+		time.Sleep(10 * time.Millisecond)
+		cancel()
+	}()
+
+	_, err := HedgedRequest[string](ctx, requester)
+
+	require.ErrorIs(t, err, context.Canceled)
+}