-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4 from wandb/annirudh/hedge
- Loading branch information
Showing
2 changed files
with
217 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
package parallel | ||
|
||
import ( | ||
"context" | ||
"time" | ||
) | ||
|
||
type HedgedRequestConfig struct { | ||
delay time.Duration // time to wait before issuing hedged requests | ||
numHedgedRequests int // the maximum permitted number of outstanding hedged requests | ||
} | ||
|
||
func (h *HedgedRequestConfig) Apply(opts ...HedgedRequestOpt) { | ||
for _, opt := range opts { | ||
opt(h) | ||
} | ||
} | ||
|
||
type HedgedRequestOpt func(config *HedgedRequestConfig) | ||
|
||
func WithDelay(delay time.Duration) HedgedRequestOpt { | ||
return func(config *HedgedRequestConfig) { | ||
config.delay = delay | ||
} | ||
} | ||
|
||
func WithNumHedgedRequests(numHedgedRequests int) HedgedRequestOpt { | ||
return func(config *HedgedRequestConfig) { | ||
config.numHedgedRequests = numHedgedRequests | ||
} | ||
} | ||
|
||
func HedgedRequest[T any]( | ||
ctx context.Context, | ||
requester func(context.Context) (T, error), | ||
opts ...HedgedRequestOpt, | ||
) (T, error) { | ||
cfg := HedgedRequestConfig{ | ||
delay: 50 * time.Millisecond, | ||
numHedgedRequests: 2, | ||
} | ||
cfg.Apply(opts...) | ||
|
||
hedgeSignal := make(chan struct{}) // closed when hedged requests should fire | ||
responses := make(chan T) // unbuffered, we only expect one response | ||
ctx, cancel := context.WithCancelCause(ctx) | ||
|
||
group := ErrGroup(Limited(ctx, cfg.numHedgedRequests+1)) | ||
for i := 0; i < cfg.numHedgedRequests+1; i++ { | ||
i := i | ||
group.Go(func(ctx context.Context) (rerr error) { | ||
defer func() { | ||
cancel(rerr) | ||
}() | ||
|
||
if i == 0 { | ||
// Initial request case: if this does not complete within the hedge delay, we signal the | ||
// hedge requests to fire off. | ||
time.AfterFunc(cfg.delay, func() { | ||
close(hedgeSignal) | ||
}) | ||
} else { | ||
// Hedged request case: wait for the go-ahead for hedged requests first. | ||
select { | ||
case <-ctx.Done(): | ||
return context.Cause(ctx) | ||
case <-hedgeSignal: | ||
// good to proceed | ||
} | ||
} | ||
|
||
res, err := requester(ctx) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
select { | ||
case <-ctx.Done(): | ||
return context.Cause(ctx) | ||
case responses <- res: | ||
return nil | ||
} | ||
}) | ||
} | ||
|
||
go func() { | ||
_ = group.Wait() | ||
close(responses) | ||
}() | ||
|
||
for response := range responses { | ||
return response, nil | ||
} | ||
|
||
var empty T | ||
return empty, context.Cause(ctx) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
package parallel | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestHedgedRequestBasic(t *testing.T) { | ||
ctx := context.Background() | ||
count := 0 | ||
expected := "success" | ||
requester := func(ctx context.Context) (string, error) { | ||
count++ | ||
if count == 1 { | ||
return expected, nil | ||
} else { | ||
return "fail", fmt.Errorf("should not trigger hedged request") | ||
} | ||
} | ||
|
||
actual, err := HedgedRequest[string](ctx, requester) | ||
|
||
require.NoError(t, err) | ||
assert.Equal(t, expected, actual) | ||
} | ||
|
||
func TestHedgedRequestHedgingTriggered(t *testing.T) { | ||
ctx := context.Background() | ||
count := 0 | ||
delay := 50 * time.Millisecond | ||
expected := "success" | ||
requester := func(ctx context.Context) (string, error) { | ||
count++ | ||
if count == 0 { | ||
select { | ||
case <-time.After(2 * delay): | ||
return "fail", fmt.Errorf("original request slow") | ||
case <-ctx.Done(): | ||
return "fail", ctx.Err() | ||
} | ||
} else { | ||
return expected, nil | ||
} | ||
} | ||
|
||
actual, err := HedgedRequest[string](ctx, requester, WithDelay(delay), WithNumHedgedRequests(1)) | ||
|
||
require.NoError(t, err) | ||
assert.Equal(t, expected, actual) | ||
} | ||
|
||
func TestHedgedRequestMultipleSuccess(t *testing.T) { | ||
ctx := context.Background() | ||
expected := "success" | ||
delay := 5 * time.Millisecond | ||
done := make(chan struct{}) | ||
requester := func(ctx context.Context) (string, error) { | ||
// Synchronize on the done channel. The original request and | ||
// all hedged requests will line up and block here. | ||
select { | ||
case <-done: | ||
return expected, nil | ||
case <-ctx.Done(): | ||
return "fail", ctx.Err() | ||
} | ||
} | ||
|
||
// Wait for 2x the delay to ensure that all hedged requests have | ||
// fired alongside the original request. Then close the done channel | ||
// so the original and hedged requests complete almost simultaneously. | ||
time.AfterFunc(2*delay, func() { | ||
close(done) | ||
}) | ||
|
||
actual, err := HedgedRequest[string]( | ||
ctx, | ||
requester, | ||
WithDelay(delay), | ||
) | ||
|
||
require.NoError(t, err) | ||
assert.Equal(t, expected, actual) | ||
} | ||
|
||
func TestHedgedRequestErrorPropagation(t *testing.T) { | ||
ctx := context.Background() | ||
expectedErr := errors.New("failure") | ||
requester := func(ctx context.Context) (string, error) { | ||
return "fail", expectedErr | ||
} | ||
|
||
_, err := HedgedRequest[string](ctx, requester) | ||
|
||
require.ErrorIs(t, err, expectedErr) | ||
} | ||
|
||
func TestHedgedRequestCancellation(t *testing.T) { | ||
ctx, cancel := context.WithCancel(context.Background()) | ||
delay := 100 * time.Millisecond | ||
requester := func(ctx context.Context) (string, error) { | ||
<-time.After(delay) | ||
return "fail", fmt.Errorf("context should be canceled") | ||
} | ||
|
||
go func() { | ||
// Cancel context before the requester completes | ||
time.Sleep(10 * time.Millisecond) | ||
cancel() | ||
}() | ||
|
||
_, err := HedgedRequest[string](ctx, requester) | ||
|
||
require.ErrorIs(t, err, context.Canceled) | ||
} |