From a646823804d98ff6b481434941b6805a8f16b01a Mon Sep 17 00:00:00 2001 From: IrineSistiana <49315432+IrineSistiana@users.noreply.github.com> Date: Sat, 16 Jan 2021 12:39:42 +0800 Subject: [PATCH] utils: support fast fallback --- dispatcher/handler/test_utils.go | 11 + .../plugin/executable/fallback/fallback.go | 2 +- dispatcher/utils/executable_sequence.go | 199 ++++++++++++------ dispatcher/utils/executable_sequence_test.go | 95 ++++++++- 4 files changed, 234 insertions(+), 73 deletions(-) diff --git a/dispatcher/handler/test_utils.go b/dispatcher/handler/test_utils.go index 74a1cf6c9..22c5c35af 100644 --- a/dispatcher/handler/test_utils.go +++ b/dispatcher/handler/test_utils.go @@ -20,6 +20,7 @@ package handler import ( "context" "github.com/miekg/dns" + "time" ) // Types and funcs in this file are for testing only @@ -36,11 +37,16 @@ func (d *DummyMatcherPlugin) Match(_ context.Context, _ *Context) (matched bool, type DummyExecutablePlugin struct { *BP + Sleep time.Duration WantR *dns.Msg WantErr error } func (d *DummyExecutablePlugin) Exec(_ context.Context, qCtx *Context) (err error) { + if d.Sleep != 0 { + time.Sleep(d.Sleep) + } + if d.WantErr != nil { return d.WantErr } @@ -52,12 +58,17 @@ func (d *DummyExecutablePlugin) Exec(_ context.Context, qCtx *Context) (err erro type DummyESExecutablePlugin struct { *BP + Sleep time.Duration WantR *dns.Msg WantSkip bool WantErr error } func (d *DummyESExecutablePlugin) ExecES(_ context.Context, qCtx *Context) (earlyStop bool, err error) { + if d.Sleep != 0 { + time.Sleep(d.Sleep) + } + if d.WantErr != nil { return false, d.WantErr } diff --git a/dispatcher/plugin/executable/fallback/fallback.go b/dispatcher/plugin/executable/fallback/fallback.go index e19e4bbea..afc5e7f82 100644 --- a/dispatcher/plugin/executable/fallback/fallback.go +++ b/dispatcher/plugin/executable/fallback/fallback.go @@ -44,7 +44,7 @@ func Init(bp *handler.BP, args interface{}) (p handler.Plugin, err error) { } func newFallback(bp *handler.BP, 
args *Args) (*fallback, error) { - fallbackECS, err := utils.ParseFallbackECS(args.Primary, args.Secondary, args.Threshold, args.StatLength) + fallbackECS, err := utils.ParseFallbackECS(args) if err != nil { return nil, err } diff --git a/dispatcher/utils/executable_sequence.go b/dispatcher/utils/executable_sequence.go index 5cc011f0c..aa76d5520 100644 --- a/dispatcher/utils/executable_sequence.go +++ b/dispatcher/utils/executable_sequence.go @@ -250,30 +250,7 @@ func (p *ParallelECS) execCmd(ctx context.Context, qCtx *handler.Context, logger }() } - for i := 0; i < t; i++ { - select { - case res := <-c: - if res.err != nil { - logger.Warn("sequence failed", qCtx.InfoField(), zap.Int("sequence_index", res.from), zap.Error(res.err)) - continue - } - if res.r == nil { - logger.Debug("sequence returned with an empty response", qCtx.InfoField(), zap.Int("sequence_index", res.from)) - continue - } - - logger.Debug("sequence returned a response", qCtx.InfoField(), zap.Int("sequence_index", res.from)) - qCtx.SetResponse(res.r, res.status) - return nil - - case <-ctx.Done(): - return ctx.Err() - } - } - - // No valid response, all parallel sequences are failed. - qCtx.SetResponse(nil, handler.ContextStatusServerFailed) - return errors.New("no response") + return asyncWait(ctx, qCtx, logger, c, t) } type FallbackConfig struct { @@ -284,12 +261,19 @@ type FallbackConfig struct { StatLength int `yaml:"stat_length"` // default is 10 Threshold int `yaml:"threshold"` // default is 5 + + // FastFallback threshold in milliseconds. Zero means disable fast fallback. + FastFallback int `yaml:"fast_fallback"` + + // AlwaysStandby: secondary should always standby in fast fallback. 
+ AlwaysStandby bool `yaml:"always_standby"` } type FallbackECS struct { - primary *ExecutableCmdSequence - secondary *ExecutableCmdSequence - threshold int + primary *ExecutableCmdSequence + secondary *ExecutableCmdSequence + fastFallbackDuration time.Duration + alwaysStandby bool primaryST *statusTracker } @@ -332,39 +316,40 @@ func (t *statusTracker) update(s uint8) { t.p++ } -func ParseFallbackECS(primary, secondary []interface{}, threshold, statLength int) (*FallbackECS, error) { - if len(primary) == 0 { +func ParseFallbackECS(c *FallbackConfig) (*FallbackECS, error) { + if len(c.Primary) == 0 { return nil, errors.New("primary sequence is empty") } - if len(secondary) == 0 { + if len(c.Secondary) == 0 { return nil, errors.New("secondary sequence is empty") } - primaryECS, err := ParseExecutableCmdSequence(primary) + primaryECS, err := ParseExecutableCmdSequence(c.Primary) if err != nil { return nil, fmt.Errorf("invalid primary sequence: %w", err) } - secondaryECS, err := ParseExecutableCmdSequence(secondary) + secondaryECS, err := ParseExecutableCmdSequence(c.Secondary) if err != nil { return nil, fmt.Errorf("invalid secondary sequence: %w", err) } - if threshold > statLength { - threshold = statLength + if c.Threshold > c.StatLength { + c.Threshold = c.StatLength } - if statLength <= 0 { - statLength = 10 + if c.StatLength <= 0 { + c.StatLength = 10 } - if threshold <= 0 { - threshold = 5 + if c.Threshold <= 0 { + c.Threshold = 5 } return &FallbackECS{ - primary: primaryECS, - secondary: secondaryECS, - threshold: threshold, - primaryST: newStatusTracker(threshold, statLength), + primary: primaryECS, + secondary: secondaryECS, + fastFallbackDuration: time.Duration(c.FastFallback) * time.Millisecond, + alwaysStandby: c.AlwaysStandby, + primaryST: newStatusTracker(c.Threshold, c.Threshold), }, nil } @@ -374,20 +359,28 @@ func (f *FallbackECS) ExecCmd(ctx context.Context, qCtx *handler.Context, logger func (f *FallbackECS) execCmd(ctx context.Context, qCtx 
*handler.Context, logger *zap.Logger) (err error) { if f.primaryST.good() { - qCtxCopy := qCtx.Copy() - err = f.execPrimary(ctx, qCtxCopy, logger) - if err == nil { - err = qCtxCopy.ExecDefer(ctx) + if f.fastFallbackDuration > 0 { + return f.doFastFallback(ctx, qCtx, logger) + } else { + return f.isolateDoPrimary(ctx, qCtx, logger) } - qCtx.SetResponse(qCtxCopy.R(), qCtxCopy.Status()) - return err } logger.Debug("primary is not good", qCtx.InfoField()) return f.doFallback(ctx, qCtx, logger) } -func (f *FallbackECS) execPrimary(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) { +func (f *FallbackECS) isolateDoPrimary(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) { + qCtxCopy := qCtx.Copy() + err = f.doPrimary(ctx, qCtxCopy, logger) + qCtx.SetResponse(qCtxCopy.R(), qCtxCopy.Status()) + return err +} + +func (f *FallbackECS) doPrimary(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) { err = WalkExecutableCmd(ctx, qCtx, logger, f.primary) + if err == nil { + err = qCtx.ExecDefer(ctx) + } if err != nil || qCtx.R() == nil { f.primaryST.update(1) } else { @@ -396,30 +389,94 @@ func (f *FallbackECS) execPrimary(ctx context.Context, qCtx *handler.Context, lo return err } -type fallbackResult struct { - r *dns.Msg - status handler.ContextStatus - err error - from string +func (f *FallbackECS) doFastFallback(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) { + fCtx, cancel := context.WithCancel(ctx) + defer cancel() + + timer := GetTimer(f.fastFallbackDuration) + defer ReleaseTimer(timer) + + c := make(chan *parallelECSResult, 2) // this chan only has nil-err result. + primFailed := make(chan struct{}) // will be closed if primary returns an err. 
+ + qCtxCopyP := qCtx.Copy() + go func() { + err := f.doPrimary(fCtx, qCtxCopyP, logger) + if err != nil { + logger.Warn("primary sequence failed", qCtx.InfoField(), zap.Error(err)) + close(primFailed) + return // do not send this err + } + c <- &parallelECSResult{ + r: qCtxCopyP.R(), + status: qCtxCopyP.Status(), + err: nil, + from: 1, + } + }() + + qCtxCopyS := qCtx.Copy() // TODO: this copy sometime is unnecessary, try to avoid it? + go func() { + if !f.alwaysStandby { // not always standby, wait here. + select { + case <-fCtx.Done(): // primary is done, no needs to exec this. + return + case <-primFailed: // primary failed or timeout, exec now. + case <-timer.C: + } + } + + err := f.doSecondary(fCtx, qCtxCopyS, logger) + if err != nil { + logger.Warn("secondary sequence failed", qCtx.InfoField(), zap.Error(err)) + return + } + res := &parallelECSResult{ + r: qCtxCopyS.R(), + status: qCtxCopyS.Status(), + err: nil, + from: 2, + } + + if f.alwaysStandby { // always standby + select { + case <-fCtx.Done(): + return + case <-primFailed: // only send secondary result when primary is failed. + c <- res + case <-timer.C: // or timeout. + c <- res + } + } else { + c <- res // not always standby, send the result asap. + } + }() + + return asyncWait(ctx, qCtx, logger, c, 2) } -func (f *FallbackECS) doFallback(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) { +func (f *FallbackECS) doSecondary(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) { + err = WalkExecutableCmd(ctx, qCtx, logger, f.secondary) + if err == nil { + err = qCtx.ExecDefer(ctx) + } + return err +} + +func (f *FallbackECS) doFallback(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) error { fCtx, cancel := context.WithCancel(ctx) defer cancel() - c := make(chan *fallbackResult, 2) // buf size is 2, avoid block. + c := make(chan *parallelECSResult, 2) // buf size is 2, avoid block. 
qCtxCopyP := qCtx.Copy() go func() { - err := f.execPrimary(fCtx, qCtxCopyP, logger) - if err == nil { - err = qCtxCopyP.ExecDefer(fCtx) - } - c <- &fallbackResult{ + err := f.doPrimary(fCtx, qCtxCopyP, logger) + c <- &parallelECSResult{ r: qCtxCopyP.R(), status: qCtxCopyP.Status(), err: err, - from: "primary", + from: 1, } }() @@ -429,28 +486,32 @@ func (f *FallbackECS) doFallback(ctx context.Context, qCtx *handler.Context, log if err == nil { err = qCtxCopyS.ExecDefer(fCtx) } - c <- &fallbackResult{ + c <- &parallelECSResult{ r: qCtxCopyS.R(), status: qCtxCopyS.Status(), err: err, - from: "secondary", + from: 2, } }() - for i := 0; i < 2; i++ { + return asyncWait(ctx, qCtx, logger, c, 2) +} + +func asyncWait(ctx context.Context, qCtx *handler.Context, logger *zap.Logger, c chan *parallelECSResult, total int) error { + for i := 0; i < total; i++ { select { case res := <-c: if res.err != nil { - logger.Warn("sequence failed", qCtx.InfoField(), zap.String("sequence", res.from), zap.Error(err)) + logger.Warn("sequence failed", qCtx.InfoField(), zap.Int("sequence", res.from), zap.Error(res.err)) continue } if res.r == nil { - logger.Debug("sequence returned with an empty response ", qCtx.InfoField(), zap.String("sequence", res.from)) + logger.Debug("sequence returned with an empty response", qCtx.InfoField(), zap.Int("sequence", res.from)) continue } - logger.Debug("sequence returned a response", qCtx.InfoField(), zap.String("sequence", res.from)) + logger.Debug("sequence returned a response", qCtx.InfoField(), zap.Int("sequence", res.from)) qCtx.SetResponse(res.r, res.status) return nil @@ -527,7 +588,7 @@ func parseFallbackECS(m map[string]interface{}) (ec ExecutableCmd, err error) { if err != nil { return nil, err } - return ParseFallbackECS(conf.Primary, conf.Secondary, conf.Threshold, conf.StatLength) + return ParseFallbackECS(conf) } func hasKey(m map[string]interface{}, key string) bool { diff --git a/dispatcher/utils/executable_sequence_test.go 
b/dispatcher/utils/executable_sequence_test.go index 7c63174c9..6e2000799 100644 --- a/dispatcher/utils/executable_sequence_test.go +++ b/dispatcher/utils/executable_sequence_test.go @@ -218,14 +218,14 @@ func Test_ParallelECS(t *testing.T) { } } -func Test_FallbackECS(t *testing.T) { +func Test_FallbackECS_fallback(t *testing.T) { handler.PurgePluginRegister() defer handler.PurgePluginRegister() r1 := new(dns.Msg) r2 := new(dns.Msg) p1 := &handler.DummyExecutablePlugin{BP: handler.NewBP("p1", "")} - p2 := &handler.DummyExecutablePlugin{BP: handler.NewBP("p2", ""), WantR: r2} + p2 := &handler.DummyExecutablePlugin{BP: handler.NewBP("p2", "")} handler.MustRegPlugin(p1, true) handler.MustRegPlugin(p2, true) er := errors.New("") @@ -248,8 +248,16 @@ func Test_FallbackECS(t *testing.T) { {"success 2 failed 1", nil, er, nil, nil, nil, true}, // end of fallback, but primary returns an err again {"success 1 failed 2", nil, er, nil, er, nil, true}, // no response } + conf := &FallbackConfig{ + Primary: []interface{}{"p1"}, + Secondary: []interface{}{"p2"}, + StatLength: 2, + Threshold: 3, + FastFallback: 0, + AlwaysStandby: false, + } - fallbackECS, err := ParseFallbackECS([]interface{}{"p1"}, []interface{}{"p2"}, 2, 3) + fallbackECS, err := ParseFallbackECS(conf) if err != nil { t.Fatal(err) } @@ -277,6 +285,87 @@ func Test_FallbackECS(t *testing.T) { } } +func Test_FallbackECS_fast_fallback(t *testing.T) { + handler.PurgePluginRegister() + defer handler.PurgePluginRegister() + + r1 := new(dns.Msg) + r2 := new(dns.Msg) + p1 := &handler.DummyExecutablePlugin{BP: handler.NewBP("p1", "")} + p2 := &handler.DummyExecutablePlugin{BP: handler.NewBP("p2", "")} + handler.MustRegPlugin(p1, true) + handler.MustRegPlugin(p2, true) + er := errors.New("") + + tests := []struct { + name string + r1 *dns.Msg + e1 error + l1 int + r2 *dns.Msg + e2 error + l2 int + alwaysStandby bool + wantLatency int + wantR *dns.Msg + wantErr bool + }{ + {"p succeed", r1, nil, 50, r2, nil, 0, false, 70, 
r1, false}, + {"p failed", nil, er, 0, r2, nil, 0, false, 20, r2, false}, + {"p timeout", r1, nil, 200, r2, nil, 0, false, 120, r2, false}, + {"p timeout, s failed", r1, nil, 200, nil, er, 0, false, 220, r1, false}, + {"all timeout", r1, nil, 400, r2, nil, 400, false, 320, nil, true}, + {"always standby p succeed", r1, nil, 50, r2, nil, 0, true, 70, r1, false}, + {"always standby p failed", nil, er, 50, r2, nil, 50, true, 70, r2, false}, + {"always standby p timeout", r1, nil, 200, r2, nil, 50, true, 120, r2, false}, + {"always standby p timeout, s failed", r1, nil, 200, nil, er, 0, true, 220, r1, false}, + {"always standby all timeout", r1, nil, 400, r2, nil, 400, true, 320, nil, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + conf := &FallbackConfig{ + Primary: []interface{}{"p1"}, + Secondary: []interface{}{"p2"}, + StatLength: 99999, // never trigger the fallback mode + Threshold: 99999, + FastFallback: 100, + AlwaysStandby: tt.alwaysStandby, + } + + fallbackECS, err := ParseFallbackECS(conf) + if err != nil { + t.Fatal(err) + } + + p1.WantR = tt.r1 + p1.WantErr = tt.e1 + p1.Sleep = time.Duration(tt.l1) * time.Millisecond + p2.WantR = tt.r2 + p2.WantErr = tt.e2 + p2.Sleep = time.Duration(tt.l2) * time.Millisecond + + ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond) + defer cancel() + + start := time.Now() + qCtx := handler.NewContext(new(dns.Msg), nil) + err = fallbackECS.execCmd(ctx, qCtx, zap.NewNop()) + if time.Since(start) > time.Millisecond*time.Duration(tt.wantLatency) { + t.Fatalf("execCmd() timeout: latency = %vms, want = %vms", time.Since(start).Milliseconds(), tt.wantLatency) + } + if tt.wantErr != (err != nil) { + t.Fatalf("execCmd() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.wantR != qCtx.R() { + t.Fatalf("execCmd() qCtx.R() = %p, wantR %p", qCtx.R(), tt.wantR) + } + + }) + } +} + func Test_statusTracker(t *testing.T) { tests := []struct { name string