Skip to content

Commit

Permalink
Showing 4 changed files with 234 additions and 73 deletions.
11 changes: 11 additions & 0 deletions dispatcher/handler/test_utils.go
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@ package handler
import (
"context"
"github.com/miekg/dns"
"time"
)

// Types and funcs in this file are for testing only
@@ -36,11 +37,16 @@ func (d *DummyMatcherPlugin) Match(_ context.Context, _ *Context) (matched bool,

type DummyExecutablePlugin struct {
*BP
Sleep time.Duration
WantR *dns.Msg
WantErr error
}

func (d *DummyExecutablePlugin) Exec(_ context.Context, qCtx *Context) (err error) {
if d.Sleep != 0 {
time.Sleep(d.Sleep)
}

if d.WantErr != nil {
return d.WantErr
}
@@ -52,12 +58,17 @@ func (d *DummyExecutablePlugin) Exec(_ context.Context, qCtx *Context) (err erro

type DummyESExecutablePlugin struct {
*BP
Sleep time.Duration
WantR *dns.Msg
WantSkip bool
WantErr error
}

func (d *DummyESExecutablePlugin) ExecES(_ context.Context, qCtx *Context) (earlyStop bool, err error) {
if d.Sleep != 0 {
time.Sleep(d.Sleep)
}

if d.WantErr != nil {
return false, d.WantErr
}
2 changes: 1 addition & 1 deletion dispatcher/plugin/executable/fallback/fallback.go
Original file line number Diff line number Diff line change
@@ -44,7 +44,7 @@ func Init(bp *handler.BP, args interface{}) (p handler.Plugin, err error) {
}

func newFallback(bp *handler.BP, args *Args) (*fallback, error) {
fallbackECS, err := utils.ParseFallbackECS(args.Primary, args.Secondary, args.Threshold, args.StatLength)
fallbackECS, err := utils.ParseFallbackECS(args)
if err != nil {
return nil, err
}
199 changes: 130 additions & 69 deletions dispatcher/utils/executable_sequence.go
Original file line number Diff line number Diff line change
@@ -250,30 +250,7 @@ func (p *ParallelECS) execCmd(ctx context.Context, qCtx *handler.Context, logger
}()
}

for i := 0; i < t; i++ {
select {
case res := <-c:
if res.err != nil {
logger.Warn("sequence failed", qCtx.InfoField(), zap.Int("sequence_index", res.from), zap.Error(res.err))
continue
}
if res.r == nil {
logger.Debug("sequence returned with an empty response", qCtx.InfoField(), zap.Int("sequence_index", res.from))
continue
}

logger.Debug("sequence returned a response", qCtx.InfoField(), zap.Int("sequence_index", res.from))
qCtx.SetResponse(res.r, res.status)
return nil

case <-ctx.Done():
return ctx.Err()
}
}

// No valid response, all parallel sequences are failed.
qCtx.SetResponse(nil, handler.ContextStatusServerFailed)
return errors.New("no response")
return asyncWait(ctx, qCtx, logger, c, t)
}

type FallbackConfig struct {
@@ -284,12 +261,19 @@ type FallbackConfig struct {

StatLength int `yaml:"stat_length"` // default is 10
Threshold int `yaml:"threshold"` // default is 5

// FastFallback threshold in milliseconds. Zero means disable fast fallback.
FastFallback int `yaml:"fast_fallback"`

// AlwaysStandby: secondary should always standby in fast fallback.
AlwaysStandby bool `yaml:"always_standby"`
}

type FallbackECS struct {
primary *ExecutableCmdSequence
secondary *ExecutableCmdSequence
threshold int
primary *ExecutableCmdSequence
secondary *ExecutableCmdSequence
fastFallbackDuration time.Duration
alwaysStandby bool

primaryST *statusTracker
}
@@ -332,39 +316,40 @@ func (t *statusTracker) update(s uint8) {
t.p++
}

func ParseFallbackECS(primary, secondary []interface{}, threshold, statLength int) (*FallbackECS, error) {
if len(primary) == 0 {
func ParseFallbackECS(c *FallbackConfig) (*FallbackECS, error) {
if len(c.Primary) == 0 {
return nil, errors.New("primary sequence is empty")
}
if len(secondary) == 0 {
if len(c.Secondary) == 0 {
return nil, errors.New("secondary sequence is empty")
}

primaryECS, err := ParseExecutableCmdSequence(primary)
primaryECS, err := ParseExecutableCmdSequence(c.Primary)
if err != nil {
return nil, fmt.Errorf("invalid primary sequence: %w", err)
}

secondaryECS, err := ParseExecutableCmdSequence(secondary)
secondaryECS, err := ParseExecutableCmdSequence(c.Secondary)
if err != nil {
return nil, fmt.Errorf("invalid secondary sequence: %w", err)
}

if threshold > statLength {
threshold = statLength
if c.Threshold > c.StatLength {
c.Threshold = c.StatLength
}
if statLength <= 0 {
statLength = 10
if c.StatLength <= 0 {
c.StatLength = 10
}
if threshold <= 0 {
threshold = 5
if c.Threshold <= 0 {
c.Threshold = 5
}

return &FallbackECS{
primary: primaryECS,
secondary: secondaryECS,
threshold: threshold,
primaryST: newStatusTracker(threshold, statLength),
primary: primaryECS,
secondary: secondaryECS,
fastFallbackDuration: time.Duration(c.FastFallback) * time.Millisecond,
alwaysStandby: c.AlwaysStandby,
primaryST: newStatusTracker(c.Threshold, c.Threshold),
}, nil
}

@@ -374,20 +359,28 @@ func (f *FallbackECS) ExecCmd(ctx context.Context, qCtx *handler.Context, logger

func (f *FallbackECS) execCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) {
if f.primaryST.good() {
qCtxCopy := qCtx.Copy()
err = f.execPrimary(ctx, qCtxCopy, logger)
if err == nil {
err = qCtxCopy.ExecDefer(ctx)
if f.fastFallbackDuration > 0 {
return f.doFastFallback(ctx, qCtx, logger)
} else {
return f.isolateDoPrimary(ctx, qCtx, logger)
}
qCtx.SetResponse(qCtxCopy.R(), qCtxCopy.Status())
return err
}
logger.Debug("primary is not good", qCtx.InfoField())
return f.doFallback(ctx, qCtx, logger)
}

func (f *FallbackECS) execPrimary(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) {
func (f *FallbackECS) isolateDoPrimary(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) {
qCtxCopy := qCtx.Copy()
err = f.doPrimary(ctx, qCtxCopy, logger)
qCtx.SetResponse(qCtxCopy.R(), qCtxCopy.Status())
return err
}

func (f *FallbackECS) doPrimary(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) {
err = WalkExecutableCmd(ctx, qCtx, logger, f.primary)
if err == nil {
err = qCtx.ExecDefer(ctx)
}
if err != nil || qCtx.R() == nil {
f.primaryST.update(1)
} else {
@@ -396,30 +389,94 @@ func (f *FallbackECS) execPrimary(ctx context.Context, qCtx *handler.Context, lo
return err
}

type fallbackResult struct {
r *dns.Msg
status handler.ContextStatus
err error
from string
func (f *FallbackECS) doFastFallback(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) {
fCtx, cancel := context.WithCancel(ctx)
defer cancel()

timer := GetTimer(f.fastFallbackDuration)
defer ReleaseTimer(timer)

c := make(chan *parallelECSResult, 2) // this chan only has nil-err result.
primFailed := make(chan struct{}) // will be closed if primary returns an err.

qCtxCopyP := qCtx.Copy()
go func() {
err := f.doPrimary(fCtx, qCtxCopyP, logger)
if err != nil {
logger.Warn("primary sequence failed", qCtx.InfoField(), zap.Error(err))
close(primFailed)
return // do not send this err
}
c <- &parallelECSResult{
r: qCtxCopyP.R(),
status: qCtxCopyP.Status(),
err: nil,
from: 1,
}
}()

qCtxCopyS := qCtx.Copy() // TODO: this copy sometime is unnecessary, try to avoid it?
go func() {
if !f.alwaysStandby { // not always standby, wait here.
select {
case <-fCtx.Done(): // primary is done, no needs to exec this.
return
case <-primFailed: // primary failed or timeout, exec now.
case <-timer.C:
}
}

err := f.doSecondary(fCtx, qCtxCopyS, logger)
if err != nil {
logger.Warn("secondary sequence failed", qCtx.InfoField(), zap.Error(err))
return
}
res := &parallelECSResult{
r: qCtxCopyS.R(),
status: qCtxCopyS.Status(),
err: nil,
from: 2,
}

if f.alwaysStandby { // always standby
select {
case <-fCtx.Done():
return
case <-primFailed: // only send secondary result when primary is failed.
c <- res
case <-timer.C: // or timeout.
c <- res
}
} else {
c <- res // not always standby, send the result asap.
}
}()

return asyncWait(ctx, qCtx, logger, c, 2)
}

func (f *FallbackECS) doFallback(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) {
func (f *FallbackECS) doSecondary(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) {
err = WalkExecutableCmd(ctx, qCtx, logger, f.secondary)
if err == nil {
err = qCtx.ExecDefer(ctx)
}
return err
}

func (f *FallbackECS) doFallback(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) error {
fCtx, cancel := context.WithCancel(ctx)
defer cancel()

c := make(chan *fallbackResult, 2) // buf size is 2, avoid block.
c := make(chan *parallelECSResult, 2) // buf size is 2, avoid block.

qCtxCopyP := qCtx.Copy()
go func() {
err := f.execPrimary(fCtx, qCtxCopyP, logger)
if err == nil {
err = qCtxCopyP.ExecDefer(fCtx)
}
c <- &fallbackResult{
err := f.doPrimary(fCtx, qCtxCopyP, logger)
c <- &parallelECSResult{
r: qCtxCopyP.R(),
status: qCtxCopyP.Status(),
err: err,
from: "primary",
from: 1,
}
}()

@@ -429,28 +486,32 @@ func (f *FallbackECS) doFallback(ctx context.Context, qCtx *handler.Context, log
if err == nil {
err = qCtxCopyS.ExecDefer(fCtx)
}
c <- &fallbackResult{
c <- &parallelECSResult{
r: qCtxCopyS.R(),
status: qCtxCopyS.Status(),
err: err,
from: "secondary",
from: 2,
}
}()

for i := 0; i < 2; i++ {
return asyncWait(ctx, qCtx, logger, c, 2)
}

func asyncWait(ctx context.Context, qCtx *handler.Context, logger *zap.Logger, c chan *parallelECSResult, total int) error {
for i := 0; i < total; i++ {
select {
case res := <-c:
if res.err != nil {
logger.Warn("sequence failed", qCtx.InfoField(), zap.String("sequence", res.from), zap.Error(err))
logger.Warn("sequence failed", qCtx.InfoField(), zap.Int("sequence", res.from), zap.Error(res.err))
continue
}

if res.r == nil {
logger.Debug("sequence returned with an empty response ", qCtx.InfoField(), zap.String("sequence", res.from))
logger.Debug("sequence returned with an empty response", qCtx.InfoField(), zap.Int("sequence", res.from))
continue
}

logger.Debug("sequence returned a response", qCtx.InfoField(), zap.String("sequence", res.from))
logger.Debug("sequence returned a response", qCtx.InfoField(), zap.Int("sequence", res.from))
qCtx.SetResponse(res.r, res.status)
return nil

@@ -527,7 +588,7 @@ func parseFallbackECS(m map[string]interface{}) (ec ExecutableCmd, err error) {
if err != nil {
return nil, err
}
return ParseFallbackECS(conf.Primary, conf.Secondary, conf.Threshold, conf.StatLength)
return ParseFallbackECS(conf)
}

func hasKey(m map[string]interface{}, key string) bool {
95 changes: 92 additions & 3 deletions dispatcher/utils/executable_sequence_test.go
Original file line number Diff line number Diff line change
@@ -218,14 +218,14 @@ func Test_ParallelECS(t *testing.T) {
}
}

func Test_FallbackECS(t *testing.T) {
func Test_FallbackECS_fallback(t *testing.T) {
handler.PurgePluginRegister()
defer handler.PurgePluginRegister()

r1 := new(dns.Msg)
r2 := new(dns.Msg)
p1 := &handler.DummyExecutablePlugin{BP: handler.NewBP("p1", "")}
p2 := &handler.DummyExecutablePlugin{BP: handler.NewBP("p2", ""), WantR: r2}
p2 := &handler.DummyExecutablePlugin{BP: handler.NewBP("p2", "")}
handler.MustRegPlugin(p1, true)
handler.MustRegPlugin(p2, true)
er := errors.New("")
@@ -248,8 +248,16 @@ func Test_FallbackECS(t *testing.T) {
{"success 2 failed 1", nil, er, nil, nil, nil, true}, // end of fallback, but primary returns an err again
{"success 1 failed 2", nil, er, nil, er, nil, true}, // no response
}
conf := &FallbackConfig{
Primary: []interface{}{"p1"},
Secondary: []interface{}{"p2"},
StatLength: 2,
Threshold: 3,
FastFallback: 0,
AlwaysStandby: false,
}

fallbackECS, err := ParseFallbackECS([]interface{}{"p1"}, []interface{}{"p2"}, 2, 3)
fallbackECS, err := ParseFallbackECS(conf)
if err != nil {
t.Fatal(err)
}
@@ -277,6 +285,87 @@ func Test_FallbackECS(t *testing.T) {
}
}

func Test_FallbackECS_fast_fallback(t *testing.T) {
handler.PurgePluginRegister()
defer handler.PurgePluginRegister()

r1 := new(dns.Msg)
r2 := new(dns.Msg)
p1 := &handler.DummyExecutablePlugin{BP: handler.NewBP("p1", "")}
p2 := &handler.DummyExecutablePlugin{BP: handler.NewBP("p2", "")}
handler.MustRegPlugin(p1, true)
handler.MustRegPlugin(p2, true)
er := errors.New("")

tests := []struct {
name string
r1 *dns.Msg
e1 error
l1 int
r2 *dns.Msg
e2 error
l2 int
alwaysStandby bool
wantLatency int
wantR *dns.Msg
wantErr bool
}{
{"p succeed", r1, nil, 50, r2, nil, 0, false, 70, r1, false},
{"p failed", nil, er, 0, r2, nil, 0, false, 20, r2, false},
{"p timeout", r1, nil, 200, r2, nil, 0, false, 120, r2, false},
{"p timeout, s failed", r1, nil, 200, nil, er, 0, false, 220, r1, false},
{"all timeout", r1, nil, 400, r2, nil, 400, false, 320, nil, true},
{"always standby p succeed", r1, nil, 50, r2, nil, 0, true, 70, r1, false},
{"always standby p failed", nil, er, 50, r2, nil, 50, true, 70, r2, false},
{"always standby p timeout", r1, nil, 200, r2, nil, 50, true, 120, r2, false},
{"always standby p timeout, s failed", r1, nil, 200, nil, er, 0, true, 220, r1, false},
{"always standby all timeout", r1, nil, 400, r2, nil, 400, true, 320, nil, true},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
conf := &FallbackConfig{
Primary: []interface{}{"p1"},
Secondary: []interface{}{"p2"},
StatLength: 99999, // never trigger the fallback mode
Threshold: 99999,
FastFallback: 100,
AlwaysStandby: tt.alwaysStandby,
}

fallbackECS, err := ParseFallbackECS(conf)
if err != nil {
t.Fatal(err)
}

p1.WantR = tt.r1
p1.WantErr = tt.e1
p1.Sleep = time.Duration(tt.l1) * time.Millisecond
p2.WantR = tt.r2
p2.WantErr = tt.e2
p2.Sleep = time.Duration(tt.l2) * time.Millisecond

ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond)
defer cancel()

start := time.Now()
qCtx := handler.NewContext(new(dns.Msg), nil)
err = fallbackECS.execCmd(ctx, qCtx, zap.NewNop())
if time.Since(start) > time.Millisecond*time.Duration(tt.wantLatency) {
t.Fatalf("execCmd() timeout: latency = %vms, want = %vms", time.Since(start).Milliseconds(), tt.wantLatency)
}
if tt.wantErr != (err != nil) {
t.Fatalf("execCmd() error = %v, wantErr %v", err, tt.wantErr)
}

if tt.wantR != qCtx.R() {
t.Fatalf("execCmd() qCtx.R() = %p, wantR %p", qCtx.R(), tt.wantR)
}

})
}
}

func Test_statusTracker(t *testing.T) {
tests := []struct {
name string

0 comments on commit a646823

Please sign in to comment.