diff --git a/dispatcher/coremain/run.go b/dispatcher/coremain/run.go index 8546ac99b..9cf3d2d89 100644 --- a/dispatcher/coremain/run.go +++ b/dispatcher/coremain/run.go @@ -21,8 +21,8 @@ import ( "fmt" "github.com/IrineSistiana/mosdns/dispatcher/handler" "github.com/IrineSistiana/mosdns/dispatcher/mlog" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/concurrent_limiter" _ "github.com/IrineSistiana/mosdns/dispatcher/plugin" - "github.com/IrineSistiana/mosdns/dispatcher/utils" "go.uber.org/zap" "go.uber.org/zap/zapcore" "os" @@ -93,7 +93,7 @@ func loadConfig(f string, depth int) { if n < 1 { n = 1 } - pool := utils.NewConcurrentLimiter(n) + pool := concurrent_limiter.NewConcurrentLimiter(n) wg := new(sync.WaitGroup) for i, pluginConfig := range c.Plugin { if len(pluginConfig.Tag) == 0 || len(pluginConfig.Type) == 0 { diff --git a/dispatcher/pkg/concurrent_limiter/client_limiter.go b/dispatcher/pkg/concurrent_limiter/client_limiter.go new file mode 100644 index 000000000..2951e0b47 --- /dev/null +++ b/dispatcher/pkg/concurrent_limiter/client_limiter.go @@ -0,0 +1,69 @@ +// Copyright (C) 2020-2021, IrineSistiana +// +// This file is part of mosdns. +// +// mosdns is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// mosdns is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package concurrent_limiter + +import ( + "github.com/IrineSistiana/mosdns/dispatcher/pkg/concurrent_map" +) + +type ClientQueryLimiter struct { + maxQueries int + m *concurrent_map.ConcurrentMap +} + +func NewClientQueryLimiter(maxQueries int) *ClientQueryLimiter { + return &ClientQueryLimiter{ + maxQueries: maxQueries, + m: concurrent_map.NewConcurrentMap(64), + } +} + +func (l *ClientQueryLimiter) Acquire(key string) bool { + return l.m.TestAndSet(key, l.acquireTestAndSet) +} + +func (l *ClientQueryLimiter) acquireTestAndSet(v interface{}, ok bool) (newV interface{}, wantUpdate, passed bool) { + n := 0 + if ok { + n = v.(int) + } + if n >= l.maxQueries { + return nil, false, false + } + n++ + return n, true, true +} + +func (l *ClientQueryLimiter) doneTestAndSet(v interface{}, ok bool) (newV interface{}, wantUpdate, passed bool) { + if !ok { + panic("ClientQueryLimiter doneTestAndSet: value is not exist") + } + n := v.(int) + n-- + if n < 0 { + panic("ClientQueryLimiter doneTestAndSet: value becomes negative") + } + if n == 0 { + return nil, true, true + } + return n, true, true +} + +func (l *ClientQueryLimiter) Done(key string) { + l.m.TestAndSet(key, l.doneTestAndSet) +} diff --git a/dispatcher/utils/server_handler_test.go b/dispatcher/pkg/concurrent_limiter/client_limiter_test.go similarity index 97% rename from dispatcher/utils/server_handler_test.go rename to dispatcher/pkg/concurrent_limiter/client_limiter_test.go index 517faba97..2c7a779e1 100644 --- a/dispatcher/utils/server_handler_test.go +++ b/dispatcher/pkg/concurrent_limiter/client_limiter_test.go @@ -1,4 +1,4 @@ -package utils +package concurrent_limiter import ( "strconv" diff --git a/dispatcher/pkg/concurrent_limiter/concurrent_limiter.go b/dispatcher/pkg/concurrent_limiter/concurrent_limiter.go new file mode 100644 index 000000000..07de47cba --- /dev/null +++ b/dispatcher/pkg/concurrent_limiter/concurrent_limiter.go @@ -0,0 +1,55 @@ +// Copyright (C) 2020-2021, IrineSistiana +// +// This file is part of mosdns. +// +// mosdns is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// mosdns is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package concurrent_limiter + +import "fmt" + +// ConcurrentLimiter is a soft limiter. +type ConcurrentLimiter struct { + bucket chan struct{} +} + +// NewConcurrentLimiter returns a ConcurrentLimiter, max must > 0. +func NewConcurrentLimiter(max int) *ConcurrentLimiter { + if max <= 0 { + panic(fmt.Sprintf("ConcurrentLimiter: invalid max arg: %d", max)) + } + + bucket := make(chan struct{}, max) + for i := 0; i < max; i++ { + bucket <- struct{}{} + } + + return &ConcurrentLimiter{bucket: bucket} +} + +func (l *ConcurrentLimiter) Wait() <-chan struct{} { + return l.bucket +} + +func (l *ConcurrentLimiter) Done() { + select { + case l.bucket <- struct{}{}: + default: + panic("ConcurrentLimiter: bucket overflow") + } +} + +func (l *ConcurrentLimiter) Available() int { + return len(l.bucket) +} diff --git a/dispatcher/pkg/concurrent_limiter/concurrent_limiter_test.go b/dispatcher/pkg/concurrent_limiter/concurrent_limiter_test.go new file mode 100644 index 000000000..4e56b751f --- /dev/null +++ b/dispatcher/pkg/concurrent_limiter/concurrent_limiter_test.go @@ -0,0 +1,34 @@ +package concurrent_limiter + +import ( + "context" + "sync" + "testing" + "time" +) + +func Test_ConcurrentLimiter_acquire_release(t *testing.T) { + l := NewConcurrentLimiter(500) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) + defer cancel() + + wg := new(sync.WaitGroup) + wg.Add(1000) + for i := 0; i < 1000; i++ { + go func() { + defer wg.Done() + select { + case <-l.Wait(): + time.Sleep(time.Millisecond * 200) + l.Done() + case <-ctx.Done(): + t.Fail() + } + }() + } + + wg.Wait() + if l.Available() != 500 { + t.Fatal("token leaked") + } +} diff --git a/dispatcher/utils/concurrent_lru.go b/dispatcher/pkg/concurrent_lru/concurrent_lru.go similarity index 95% rename from dispatcher/utils/concurrent_lru.go rename to dispatcher/pkg/concurrent_lru/concurrent_lru.go index 6ba85f094..a8672f7aa 100644 --- a/dispatcher/utils/concurrent_lru.go +++ b/dispatcher/pkg/concurrent_lru/concurrent_lru.go @@ -15,9 +15,10 @@ // You should have received a copy of the GNU General Public License // along with this program. If not, see . -package utils +package concurrent_lru import ( + lru2 "github.com/IrineSistiana/mosdns/dispatcher/pkg/lru" "hash/maphash" "sync" ) @@ -40,7 +41,7 @@ func NewConcurrentLRU( for i := range cl.l { cl.l[i] = &shardedLRU{ onGet: onGet, - lru: NewLRU(maxSizePerShard, onEvict), + lru: lru2.NewLRU(maxSizePerShard, onEvict), } } @@ -95,7 +96,7 @@ type shardedLRU struct { onGet func(key string, v interface{}) interface{} sync.Mutex - lru *LRU + lru *lru2.LRU } func (sl *shardedLRU) Add(key string, v interface{}) { diff --git a/dispatcher/utils/concurrent_lru_test.go b/dispatcher/pkg/concurrent_lru/concurrent_lru_test.go similarity index 98% rename from dispatcher/utils/concurrent_lru_test.go rename to dispatcher/pkg/concurrent_lru/concurrent_lru_test.go index 163221efe..2f74595f8 100644 --- a/dispatcher/utils/concurrent_lru_test.go +++ b/dispatcher/pkg/concurrent_lru/concurrent_lru_test.go @@ -1,4 +1,4 @@ -package utils +package concurrent_lru import ( "reflect" diff --git a/dispatcher/utils/concurrent_map.go b/dispatcher/pkg/concurrent_map/concurrent_map.go similarity index 99% rename from dispatcher/utils/concurrent_map.go rename to dispatcher/pkg/concurrent_map/concurrent_map.go index e6bb087f8..68f84c12f 100644 --- a/dispatcher/utils/concurrent_map.go +++ b/dispatcher/pkg/concurrent_map/concurrent_map.go @@ -15,7 +15,7 @@ // You should have received a copy of the GNU General Public License // along with this program. If not, see . -package utils +package concurrent_map import ( "hash/maphash" diff --git a/dispatcher/utils/concurrent_map_test.go b/dispatcher/pkg/concurrent_map/concurrent_map_test.go similarity index 99% rename from dispatcher/utils/concurrent_map_test.go rename to dispatcher/pkg/concurrent_map/concurrent_map_test.go index 3fe1050fb..6f256fafd 100644 --- a/dispatcher/utils/concurrent_map_test.go +++ b/dispatcher/pkg/concurrent_map/concurrent_map_test.go @@ -1,4 +1,4 @@ -package utils +package concurrent_map import ( "strconv" diff --git a/dispatcher/utils/net_io.go b/dispatcher/pkg/dnsutils/net_io.go similarity index 90% rename from dispatcher/utils/net_io.go rename to dispatcher/pkg/dnsutils/net_io.go index e011894e7..3f867a8be 100644 --- a/dispatcher/utils/net_io.go +++ b/dispatcher/pkg/dnsutils/net_io.go @@ -15,11 +15,12 @@ // You should have received a copy of the GNU General Public License // along with this program. If not, see . -package utils +package dnsutils import ( "encoding/binary" "fmt" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/pool" "github.com/miekg/dns" "io" "net" @@ -60,8 +61,8 @@ func (e *IOErr) Unwrap() error { // An io err will be wrapped into an IOErr. // IsIOErr(err) can check and unwrap the inner io err. func ReadUDPMsgFrom(c net.PacketConn, bufSize int) (m *dns.Msg, from net.Addr, n int, err error) { - buf := GetMsgBuf(bufSize) - defer ReleaseMsgBuf(buf) + buf := pool.GetMsgBuf(bufSize) + defer pool.ReleaseMsgBuf(buf) n, from, err = c.ReadFrom(buf) if err != nil { @@ -84,8 +85,8 @@ func ReadUDPMsgFrom(c net.PacketConn, bufSize int) (m *dns.Msg, from net.Addr, n // ReadMsgFromUDP See ReadUDPMsgFrom. func ReadMsgFromUDP(c io.Reader, bufSize int) (m *dns.Msg, n int, err error) { - buf := GetMsgBuf(bufSize) - defer ReleaseMsgBuf(buf) + buf := pool.GetMsgBuf(bufSize) + defer pool.ReleaseMsgBuf(buf) n, err = c.Read(buf) if err != nil { @@ -108,11 +109,11 @@ func ReadMsgFromUDP(c io.Reader, bufSize int) (m *dns.Msg, n int, err error) { // An io err will be wrapped into an IOErr. // IsIOErr(err) can check and unwrap the inner io err. func WriteMsgToUDP(c io.Writer, m *dns.Msg) (n int, err error) { - mRaw, buf, err := PackBuffer(m) + mRaw, buf, err := pool.PackBuffer(m) if err != nil { return 0, err } - defer ReleaseMsgBuf(buf) + defer pool.ReleaseMsgBuf(buf) return WriteRawMsgToUDP(c, mRaw) } @@ -128,11 +129,11 @@ func WriteRawMsgToUDP(c io.Writer, b []byte) (n int, err error) { // WriteUDPMsgTo See WriteMsgToUDP. func WriteUDPMsgTo(m *dns.Msg, c net.PacketConn, to net.Addr) (n int, err error) { - mRaw, buf, err := PackBuffer(m) + mRaw, buf, err := pool.PackBuffer(m) if err != nil { return 0, err } - defer ReleaseMsgBuf(buf) + defer pool.ReleaseMsgBuf(buf) n, err = c.WriteTo(mRaw, to) if err != nil { @@ -161,8 +162,8 @@ func ReadMsgFromTCP(c io.Reader) (m *dns.Msg, n int, err error) { return nil, n, dns.ErrShortRead } - buf := GetMsgBuf(int(length)) - defer ReleaseMsgBuf(buf) + buf := pool.GetMsgBuf(int(length)) + defer pool.ReleaseMsgBuf(buf) n2, err := io.ReadFull(c, buf) n = n + n2 @@ -185,11 +186,11 @@ func ReadMsgFromTCP(c io.Reader) (m *dns.Msg, n int, err error) { // An io err will be wrapped into an IOErr. // IsIOErr(err) can check and unwrap the inner io err. func WriteMsgToTCP(c io.Writer, m *dns.Msg) (n int, err error) { - mRaw, buf, err := PackBuffer(m) + mRaw, buf, err := pool.PackBuffer(m) if err != nil { return 0, err } - defer ReleaseMsgBuf(buf) + defer pool.ReleaseMsgBuf(buf) return WriteRawMsgToTCP(c, mRaw) } @@ -222,5 +223,5 @@ func WriteRawMsgToTCP(c io.Writer, b []byte) (n int, err error) { } var ( - tcpWriteBufPool = NewBytesBufPool(512 + 2) + tcpWriteBufPool = pool.NewBytesBufPool(512 + 2) ) diff --git a/dispatcher/utils/net_io_test.go b/dispatcher/pkg/dnsutils/net_io_test.go similarity index 99% rename from dispatcher/utils/net_io_test.go rename to dispatcher/pkg/dnsutils/net_io_test.go index 7659a6a38..2d23ff20b 100644 --- a/dispatcher/utils/net_io_test.go +++ b/dispatcher/pkg/dnsutils/net_io_test.go @@ -14,7 +14,8 @@ // // You should have received a copy of the GNU General Public License // along with this program. If not, see . -package utils + +package dnsutils import ( "bytes" diff --git a/dispatcher/pkg/dnsutils/ttl.go b/dispatcher/pkg/dnsutils/ttl.go new file mode 100644 index 000000000..f5155ba2c --- /dev/null +++ b/dispatcher/pkg/dnsutils/ttl.go @@ -0,0 +1,83 @@ +// Copyright (C) 2020-2021, IrineSistiana +// +// This file is part of mosdns. +// +// mosdns is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// mosdns is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package dnsutils + +import "github.com/miekg/dns" + +// GetMinimalTTL returns the minimal ttl of this msg. +// If msg m has no record, it returns 0. +func GetMinimalTTL(m *dns.Msg) uint32 { + minTTL := ^uint32(0) + hasRecord := false + for _, section := range [...][]dns.RR{m.Answer, m.Ns, m.Extra} { + for _, rr := range section { + if rr.Header().Rrtype == dns.TypeOPT { + continue // opt record ttl is not ttl. + } + hasRecord = true + ttl := rr.Header().Ttl + if ttl < minTTL { + minTTL = ttl + } + } + } + + if !hasRecord { // no ttl applied + return 0 + } + return minTTL +} + +// SetTTL updates all records' ttl to ttl, except opt record. +func SetTTL(m *dns.Msg, ttl uint32) { + for _, section := range [...][]dns.RR{m.Answer, m.Ns, m.Extra} { + for _, rr := range section { + if rr.Header().Rrtype == dns.TypeOPT { + continue // opt record ttl is not ttl. + } + rr.Header().Ttl = ttl + } + } +} + +func ApplyMaximumTTL(m *dns.Msg, ttl uint32) { + applyTTL(m, ttl, true) +} + +func ApplyMinimalTTL(m *dns.Msg, ttl uint32) { + applyTTL(m, ttl, false) +} + +func applyTTL(m *dns.Msg, ttl uint32, maximum bool) { + for _, section := range [...][]dns.RR{m.Answer, m.Ns, m.Extra} { + for _, rr := range section { + if rr.Header().Rrtype == dns.TypeOPT { + continue // opt record ttl is not ttl. + } + if maximum { + if rr.Header().Ttl > ttl { + rr.Header().Ttl = ttl + } + } else { + if rr.Header().Ttl < ttl { + rr.Header().Ttl = ttl + } + } + } + } +} diff --git a/dispatcher/pkg/executable_seq/executable_cmd.go b/dispatcher/pkg/executable_seq/executable_cmd.go new file mode 100644 index 000000000..0b6456737 --- /dev/null +++ b/dispatcher/pkg/executable_seq/executable_cmd.go @@ -0,0 +1,28 @@ +// Copyright (C) 2020-2021, IrineSistiana +// +// This file is part of mosdns. +// +// mosdns is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) or later version. +// +// mosdns is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package executable_seq + +import ( + "context" + "github.com/IrineSistiana/mosdns/dispatcher/handler" + "go.uber.org/zap" +) + +type ExecutableCmd interface { + ExecCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (goTwo string, earlyStop bool, err error) +} diff --git a/dispatcher/pkg/executable_seq/executable_cmd_sequence.go b/dispatcher/pkg/executable_seq/executable_cmd_sequence.go new file mode 100644 index 000000000..16c5fffc6 --- /dev/null +++ b/dispatcher/pkg/executable_seq/executable_cmd_sequence.go @@ -0,0 +1,257 @@ +// Copyright (C) 2020-2021, IrineSistiana +// +// This file is part of mosdns. +// +// mosdns is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// mosdns is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package executable_seq + +import ( + "context" + "errors" + "fmt" + "github.com/IrineSistiana/mosdns/dispatcher/handler" + "go.uber.org/zap" + "reflect" + "strings" +) + +type executablePluginTag struct { + s string +} + +func (t executablePluginTag) ExecCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (goTwo string, earlyStop bool, err error) { + p, err := handler.GetPlugin(t.s) + if err != nil { + return "", false, err + } + + logger.Debug("exec executable plugin", qCtx.InfoField(), zap.String("exec", t.s)) + earlyStop, err = p.ExecES(ctx, qCtx) + return "", earlyStop, err +} + +type IfBlockConfig struct { + If []string `yaml:"if"` + IfAnd []string `yaml:"if_and"` + Exec []interface{} `yaml:"exec"` + Goto string `yaml:"goto"` +} + +type matcher struct { + tag string + negate bool +} + +func paresMatcher(s []string) []matcher { + m := make([]matcher, 0, len(s)) + for _, tag := range s { + if strings.HasPrefix(tag, "!") { + m = append(m, matcher{tag: strings.TrimPrefix(tag, "!"), negate: true}) + } else { + m = append(m, matcher{tag: tag}) + } + } + return m +} + +type IfBlock struct { + ifMatcher []matcher + ifAndMatcher []matcher + executableCmd ExecutableCmd + goTwo string +} + +func (b *IfBlock) ExecCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (goTwo string, earlyStop bool, err error) { + if len(b.ifMatcher) > 0 { + If, err := ifCondition(ctx, qCtx, logger, b.ifMatcher, false) + if err != nil { + return "", false, err + } + if If == false { + return "", false, nil // if case returns false, skip this block. + } + } + + if len(b.ifAndMatcher) > 0 { + If, err := ifCondition(ctx, qCtx, logger, b.ifAndMatcher, true) + if err != nil { + return "", false, err + } + if If == false { + return "", false, nil + } + } + + // exec + if b.executableCmd != nil { + goTwo, earlyStop, err = b.executableCmd.ExecCmd(ctx, qCtx, logger) + if err != nil { + return "", false, err + } + if len(goTwo) != 0 || earlyStop { + return goTwo, earlyStop, nil + } + } + + // goto + if len(b.goTwo) != 0 { // if block has a goto, return it + return b.goTwo, false, nil + } + + return "", false, nil +} + +func ifCondition(ctx context.Context, qCtx *handler.Context, logger *zap.Logger, p []matcher, isAnd bool) (ok bool, err error) { + if len(p) == 0 { + return false, err + } + + for _, m := range p { + mp, err := handler.GetPlugin(m.tag) + if err != nil { + return false, err + } + matched, err := mp.Match(ctx, qCtx) + if err != nil { + return false, err + } + logger.Debug("exec matcher plugin", qCtx.InfoField(), zap.String("exec", m.tag), zap.Bool("result", matched)) + + res := matched != m.negate + if !isAnd && res == true { + return true, nil // or: if one of the case is true, skip others. + } + if isAnd && res == false { + return false, nil // and: if one of the case is false, skip others. + } + + ok = res + } + return ok, nil +} + +func ParseIfBlock(in map[string]interface{}) (*IfBlock, error) { + c := new(IfBlockConfig) + err := handler.WeakDecode(in, c) + if err != nil { + return nil, err + } + + b := &IfBlock{ + ifMatcher: paresMatcher(c.If), + ifAndMatcher: paresMatcher(c.IfAnd), + goTwo: c.Goto, + } + + if len(c.Exec) != 0 { + ecs, err := ParseExecutableCmdSequence(c.Exec) + if err != nil { + return nil, err + } + b.executableCmd = ecs + } + + return b, nil +} + +type ExecutableCmdSequence struct { + c []ExecutableCmd +} + +func ParseExecutableCmdSequence(in []interface{}) (*ExecutableCmdSequence, error) { + es := &ExecutableCmdSequence{c: make([]ExecutableCmd, 0, len(in))} + for i, v := range in { + ec, err := parseExecutableCmd(v) + if err != nil { + return nil, fmt.Errorf("invalid cmd #%d: %w", i, err) + } + es.c = append(es.c, ec) + } + return es, nil +} + +func parseExecutableCmd(in interface{}) (ExecutableCmd, error) { + switch v := in.(type) { + case string: + return &executablePluginTag{s: v}, nil + case map[string]interface{}: + switch { + case hasKey(v, "if") || hasKey(v, "if_and"): // if block + ec, err := ParseIfBlock(v) + if err != nil { + return nil, fmt.Errorf("invalid if section: %w", err) + } + return ec, nil + case hasKey(v, "parallel"): // parallel + ec, err := parseParallelECS(v) + if err != nil { + return nil, fmt.Errorf("invalid parallel section: %w", err) + } + return ec, nil + case hasKey(v, "primary") || hasKey(v, "secondary"): // fallback + ec, err := parseFallbackECS(v) + if err != nil { + return nil, fmt.Errorf("invalid fallback section: %w", err) + } + return ec, nil + default: + return nil, errors.New("unknown section") + } + default: + return nil, fmt.Errorf("unexpected type: %s", reflect.TypeOf(in).String()) + } +} + +func parseParallelECS(m map[string]interface{}) (ec ExecutableCmd, err error) { + conf := new(ParallelECSConfig) + err = handler.WeakDecode(m, conf) + if err != nil { + return nil, err + } + return ParseParallelECS(conf) +} + +func parseFallbackECS(m map[string]interface{}) (ec ExecutableCmd, err error) { + conf := new(FallbackConfig) + err = handler.WeakDecode(m, conf) + if err != nil { + return nil, err + } + return ParseFallbackECS(conf) +} + +func hasKey(m map[string]interface{}, key string) bool { + _, ok := m[key] + return ok +} + +// ExecCmd executes the sequence. +func (es *ExecutableCmdSequence) ExecCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (goTwo string, earlyStop bool, err error) { + for _, cmd := range es.c { + goTwo, earlyStop, err = cmd.ExecCmd(ctx, qCtx, logger) + if err != nil { + return "", false, err + } + if len(goTwo) != 0 || earlyStop { + return goTwo, earlyStop, nil + } + } + + return "", false, nil +} + +func (es *ExecutableCmdSequence) Len() int { + return len(es.c) +} diff --git a/dispatcher/pkg/executable_seq/executable_cmd_sequence_test.go b/dispatcher/pkg/executable_seq/executable_cmd_sequence_test.go new file mode 100644 index 000000000..5c76095b7 --- /dev/null +++ b/dispatcher/pkg/executable_seq/executable_cmd_sequence_test.go @@ -0,0 +1,166 @@ +package executable_seq + +import ( + "context" + "errors" + "github.com/IrineSistiana/mosdns/dispatcher/handler" + "github.com/miekg/dns" + "go.uber.org/zap" + "gopkg.in/yaml.v3" + "testing" +) + +func Test_ECS(t *testing.T) { + handler.PurgePluginRegister() + defer handler.PurgePluginRegister() + + mErr := errors.New("mErr") + eErr := errors.New("eErr") + + var tests = []struct { + name string + yamlStr string + wantNext string + wantES bool + wantErr error + }{ + {name: "test empty input", yamlStr: ` +exec: +`, + wantNext: "", wantErr: nil}, + {name: "test empty end", yamlStr: ` +exec: +- if: ["!matched",not_matched] # not matched + exec: [exec_err] + goto: goto`, + wantNext: "", wantErr: nil}, + + {name: "test if_and", yamlStr: ` +exec: +- if_and: [matched, not_matched] # not matched + goto: goto1 +- if_and: [matched, not_matched, match_err] # not matched, early stop, no err + goto: goto2 +- if_and: [matched, matched, matched] # matched + goto: goto3 +`, + wantNext: "goto3", wantErr: nil}, + + {name: "test if_and err", yamlStr: ` +exec: +- if_and: [matched, match_err] # err + goto: goto1 +`, + wantNext: "", wantErr: mErr}, + + {name: "test if", yamlStr: ` +exec: +- if: ["!matched", not_matched] # test ! prefix, not matched + goto: goto1 +- if: [matched, match_err] # matched, early stop, no err + exec: + - if: ["!not_matched", not_matched] # matched + goto: goto2 # reached here + goto: goto3 +`, + wantNext: "goto2", wantErr: nil}, + + {name: "test if err", yamlStr: ` +exec: +- if: [not_matched, match_err] # err + goto: goto1 +`, + wantNext: "", wantErr: mErr}, + + {name: "test exec err", yamlStr: ` +exec: +- if: [matched] + exec: exec_err + goto: goto1 +`, + wantNext: "", wantErr: eErr}, + + {name: "test early return in main sequence", yamlStr: ` +exec: +- exec +- exec_skip +- exec_err # skipped, should not reach here. +`, + wantNext: "", wantES: true, wantErr: nil}, + + {name: "test early return in if branch", yamlStr: ` +exec: +- if: [matched] + exec: + - exec_skip + goto: goto1 # skipped, should not reach here. +`, + wantNext: "", wantES: true, wantErr: nil}, + } + + // not_matched + handler.MustRegPlugin(&handler.DummyMatcherPlugin{ + BP: handler.NewBP("not_matched", ""), + Matched: false, + WantErr: nil, + }, true) + + // do something + handler.MustRegPlugin(&handler.DummyExecutablePlugin{ + BP: handler.NewBP("exec", ""), + WantErr: nil, + }, true) + + // do something and skip the following sequence + handler.MustRegPlugin(&handler.DummyESExecutablePlugin{ + BP: handler.NewBP("exec_skip", ""), + WantSkip: true, + }, true) + + // matched + handler.MustRegPlugin(&handler.DummyMatcherPlugin{ + BP: handler.NewBP("matched", ""), + Matched: true, + WantErr: nil, + }, true) + + // plugins should return an err. + handler.MustRegPlugin(&handler.DummyMatcherPlugin{ + BP: handler.NewBP("match_err", ""), + Matched: false, + WantErr: mErr, + }, true) + + handler.MustRegPlugin(&handler.DummyExecutablePlugin{ + BP: handler.NewBP("exec_err", ""), + WantErr: eErr, + }, true) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + args := make(map[string]interface{}, 0) + err := yaml.Unmarshal([]byte(tt.yamlStr), args) + if err != nil { + t.Fatal(err) + } + in, _ := args["exec"].([]interface{}) + ecs, err := ParseExecutableCmdSequence(in) + if err != nil { + t.Fatal(err) + } + + gotNext, gotEarlyStop, err := ecs.ExecCmd(context.Background(), handler.NewContext(new(dns.Msg), nil), zap.NewNop()) + if (err != nil || tt.wantErr != nil) && !errors.Is(err, tt.wantErr) { + t.Errorf("ExecCmd() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotNext != tt.wantNext { + t.Errorf("ExecCmd() gotNext = %v, want %v", gotNext, tt.wantNext) + } + + if gotEarlyStop != tt.wantES { + t.Errorf("ExecCmd() gotEarlyStop = %v, want %v", gotEarlyStop, tt.wantES) + } + }) + } +} diff --git a/dispatcher/pkg/executable_seq/fallback.go b/dispatcher/pkg/executable_seq/fallback.go new file mode 100644 index 000000000..c892ac0b9 --- /dev/null +++ b/dispatcher/pkg/executable_seq/fallback.go @@ -0,0 +1,268 @@ +// Copyright (C) 2020-2021, IrineSistiana +// +// This file is part of mosdns. +// +// mosdns is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) or later version. +// +// mosdns is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package executable_seq + +import ( + "context" + "errors" + "fmt" + "github.com/IrineSistiana/mosdns/dispatcher/handler" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/pool" + "go.uber.org/zap" + "sync" + "time" +) + +type FallbackConfig struct { + // Primary exec sequence, must have at least one element. + Primary []interface{} `yaml:"primary"` + // Secondary exec sequence, must have at least one element. + Secondary []interface{} `yaml:"secondary"` + + StatLength int `yaml:"stat_length"` // An Zero value disables the (normal) fallback. + Threshold int `yaml:"threshold"` + + // FastFallback threshold in milliseconds. Zero means fast fallback is disabled. + FastFallback int `yaml:"fast_fallback"` + + // AlwaysStandby: secondary should always standby in fast fallback. + AlwaysStandby bool `yaml:"always_standby"` +} + +type FallbackECS struct { + primary *ExecutableCmdSequence + secondary *ExecutableCmdSequence + fastFallbackDuration time.Duration + alwaysStandby bool + + primaryST *statusTracker // nil if normal fallback is disabled +} + +type statusTracker struct { + sync.Mutex + threshold int + status []uint8 // 0 means success, !0 means failed + p int +} + +func newStatusTracker(threshold, statLength int) *statusTracker { + return &statusTracker{ + threshold: threshold, + status: make([]uint8, statLength), + } +} + +func (t *statusTracker) good() bool { + t.Lock() + defer t.Unlock() + + var failedSum int + for _, s := range t.status { + if s != 0 { + failedSum++ + } + } + return failedSum < t.threshold +} + +func (t *statusTracker) update(s uint8) { + t.Lock() + defer t.Unlock() + + if t.p >= len(t.status) { + t.p = 0 + } + t.status[t.p] = s + t.p++ +} + +func ParseFallbackECS(c *FallbackConfig) (*FallbackECS, error) { + if len(c.Primary) == 0 { + return nil, errors.New("primary sequence is empty") + } + if len(c.Secondary) == 0 { + return nil, errors.New("secondary sequence is empty") + } + + primaryECS, err := ParseExecutableCmdSequence(c.Primary) + if err != nil { + return nil, fmt.Errorf("invalid primary sequence: %w", err) + } + + secondaryECS, err := ParseExecutableCmdSequence(c.Secondary) + if err != nil { + return nil, fmt.Errorf("invalid secondary sequence: %w", err) + } + + fallbackECS := &FallbackECS{ + primary: primaryECS, + secondary: secondaryECS, + fastFallbackDuration: time.Duration(c.FastFallback) * time.Millisecond, + alwaysStandby: c.AlwaysStandby, + } + + if c.StatLength > 0 { + if c.Threshold > c.StatLength { + c.Threshold = c.StatLength + } + fallbackECS.primaryST = newStatusTracker(c.Threshold, c.StatLength) + } + + return fallbackECS, nil +} + +func (f *FallbackECS) ExecCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (goTwo string, earlyStop bool, err error) { + return "", false, f.execCmd(ctx, qCtx, logger) +} + +func (f *FallbackECS) execCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) { + if f.primaryST == nil || f.primaryST.good() { + if f.fastFallbackDuration > 0 { + return f.doFastFallback(ctx, qCtx, logger) + } else { + return f.isolateDoPrimary(ctx, qCtx, logger) + } + } + logger.Debug("primary is not good", qCtx.InfoField()) + return f.doFallback(ctx, qCtx, logger) +} + +func (f *FallbackECS) isolateDoPrimary(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) { + qCtxCopy := qCtx.Copy() + err = f.doPrimary(ctx, qCtxCopy, logger) + qCtx.SetResponse(qCtxCopy.R(), qCtxCopy.Status()) + return err +} + +func (f *FallbackECS) doPrimary(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) { + err = WalkExecutableCmd(ctx, qCtx, logger, f.primary) + if err == nil { + err = qCtx.ExecDefer(ctx) + } + if f.primaryST != nil { + if err != nil || qCtx.R() == nil { + f.primaryST.update(1) + } else { + f.primaryST.update(0) + } + } + + return err +} + +func (f *FallbackECS) doFastFallback(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) { + fCtx, cancel := context.WithCancel(ctx) + defer cancel() + + timer := pool.GetTimer(f.fastFallbackDuration) + defer pool.ReleaseTimer(timer) + + c := make(chan *parallelECSResult, 2) + primFailed := make(chan struct{}) // will be closed if primary returns an err. + + qCtxCopyP := qCtx.Copy() + go func() { + err := f.doPrimary(fCtx, qCtxCopyP, logger) + if err != nil || qCtxCopyP.R() == nil { + close(primFailed) + } + c <- ¶llelECSResult{ + r: qCtxCopyP.R(), + status: qCtxCopyP.Status(), + err: err, + from: 1, + } + }() + + qCtxCopyS := qCtx.Copy() + go func() { + if !f.alwaysStandby { // not always standby, wait here. + select { + case <-fCtx.Done(): // primary is done, no needs to exec this. + return + case <-primFailed: // primary failed or timeout, exec now. + case <-timer.C: + } + } + + err := f.doSecondary(fCtx, qCtxCopyS, logger) + res := ¶llelECSResult{ + r: qCtxCopyS.R(), + status: qCtxCopyS.Status(), + err: err, + from: 2, + } + + if f.alwaysStandby { // always standby + select { + case <-fCtx.Done(): + return + case <-primFailed: // only send secondary result when primary is failed. + c <- res + case <-timer.C: // or timeout. + c <- res + } + } else { + c <- res // not always standby, send the result asap. + } + }() + + return asyncWait(ctx, qCtx, logger, c, 2) +} + +func (f *FallbackECS) doSecondary(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) { + err = WalkExecutableCmd(ctx, qCtx, logger, f.secondary) + if err == nil { + err = qCtx.ExecDefer(ctx) + } + return err +} + +func (f *FallbackECS) doFallback(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) error { + fCtx, cancel := context.WithCancel(ctx) + defer cancel() + + c := make(chan *parallelECSResult, 2) // buf size is 2, avoid block. + + qCtxCopyP := qCtx.Copy() + go func() { + err := f.doPrimary(fCtx, qCtxCopyP, logger) + c <- ¶llelECSResult{ + r: qCtxCopyP.R(), + status: qCtxCopyP.Status(), + err: err, + from: 1, + } + }() + + qCtxCopyS := qCtx.Copy() + go func() { + err := WalkExecutableCmd(fCtx, qCtxCopyS, logger, f.secondary) + if err == nil { + err = qCtxCopyS.ExecDefer(fCtx) + } + c <- ¶llelECSResult{ + r: qCtxCopyS.R(), + status: qCtxCopyS.Status(), + err: err, + from: 2, + } + }() + + return asyncWait(ctx, qCtx, logger, c, 2) +} diff --git a/dispatcher/utils/executable_sequence_test.go b/dispatcher/pkg/executable_seq/fallback_test.go similarity index 51% rename from dispatcher/utils/executable_sequence_test.go rename to dispatcher/pkg/executable_seq/fallback_test.go index 07c6a50cb..8297163f9 100644 --- a/dispatcher/utils/executable_sequence_test.go +++ b/dispatcher/pkg/executable_seq/fallback_test.go @@ -1,4 +1,21 @@ -package utils +// Copyright (C) 2020-2021, IrineSistiana +// +// This file is part of mosdns. +// +// mosdns is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) or later version. +// +// mosdns is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package executable_seq import ( "context" @@ -6,224 +23,10 @@ import ( "github.com/IrineSistiana/mosdns/dispatcher/handler" "github.com/miekg/dns" "go.uber.org/zap" - "gopkg.in/yaml.v3" "testing" "time" ) -func Test_ECS(t *testing.T) { - handler.PurgePluginRegister() - defer handler.PurgePluginRegister() - - mErr := errors.New("mErr") - eErr := errors.New("eErr") - - var tests = []struct { - name string - yamlStr string - wantNext string - wantES bool - wantErr error - }{ - {name: "test empty end", yamlStr: ` -exec: -- if: ["!matched",not_matched] # not matched - exec: [exec_err] - goto: goto`, - wantNext: "", wantErr: nil}, - - {name: "test if_and", yamlStr: ` -exec: -- if_and: [matched, not_matched] # not matched - goto: goto1 -- if_and: [matched, not_matched, match_err] # not matched, early stop, no err - goto: goto2 -- if_and: [matched, matched, matched] # matched - goto: goto3 -`, - wantNext: "goto3", wantErr: nil}, - - {name: "test if_and err", yamlStr: ` -exec: -- if_and: [matched, match_err] # err - goto: goto1 -`, - wantNext: "", wantErr: mErr}, - - {name: "test if", yamlStr: ` -exec: -- if: ["!matched", not_matched] # test ! prefix, not matched - goto: goto1 -- if: [matched, match_err] # matched, early stop, no err - exec: - - if: ["!not_matched", not_matched] # matched - goto: goto2 # reached here - goto: goto3 -`, - wantNext: "goto2", wantErr: nil}, - - {name: "test if err", yamlStr: ` -exec: -- if: [not_matched, match_err] # err - goto: goto1 -`, - wantNext: "", wantErr: mErr}, - - {name: "test exec err", yamlStr: ` -exec: -- if: [matched] - exec: exec_err - goto: goto1 -`, - wantNext: "", wantErr: eErr}, - - {name: "test early return in main sequence", yamlStr: ` -exec: -- exec -- exec_skip -- exec_err # skipped, should not reach here. -`, - wantNext: "", wantES: true, wantErr: nil}, - - {name: "test early return in if branch", yamlStr: ` -exec: -- if: [matched] - exec: - - exec_skip - goto: goto1 # skipped, should not reach here. -`, - wantNext: "", wantES: true, wantErr: nil}, - } - - // not_matched - handler.MustRegPlugin(&handler.DummyMatcherPlugin{ - BP: handler.NewBP("not_matched", ""), - Matched: false, - WantErr: nil, - }, true) - - // do something - handler.MustRegPlugin(&handler.DummyExecutablePlugin{ - BP: handler.NewBP("exec", ""), - WantErr: nil, - }, true) - - // do something and skip the following sequence - handler.MustRegPlugin(&handler.DummyESExecutablePlugin{ - BP: handler.NewBP("exec_skip", ""), - WantSkip: true, - }, true) - - // matched - handler.MustRegPlugin(&handler.DummyMatcherPlugin{ - BP: handler.NewBP("matched", ""), - Matched: true, - WantErr: nil, - }, true) - - // plugins should return an err. - handler.MustRegPlugin(&handler.DummyMatcherPlugin{ - BP: handler.NewBP("match_err", ""), - Matched: false, - WantErr: mErr, - }, true) - - handler.MustRegPlugin(&handler.DummyExecutablePlugin{ - BP: handler.NewBP("exec_err", ""), - WantErr: eErr, - }, true) - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - args := make(map[string]interface{}, 0) - err := yaml.Unmarshal([]byte(tt.yamlStr), args) - if err != nil { - t.Fatal(err) - } - ecs, err := ParseExecutableCmdSequence(args["exec"].([]interface{})) - if err != nil { - t.Fatal(err) - } - - gotNext, gotEarlyStop, err := ecs.ExecCmd(context.Background(), handler.NewContext(new(dns.Msg), nil), zap.NewNop()) - if (err != nil || tt.wantErr != nil) && !errors.Is(err, tt.wantErr) { - t.Errorf("ExecCmd() error = %v, wantErr %v", err, tt.wantErr) - return - } - if gotNext != tt.wantNext { - t.Errorf("ExecCmd() gotNext = %v, want %v", gotNext, tt.wantNext) - } - - if gotEarlyStop != tt.wantES { - t.Errorf("ExecCmd() gotEarlyStop = %v, want %v", gotEarlyStop, tt.wantES) - } - }) - } -} - -func Test_ParallelECS(t *testing.T) { - handler.PurgePluginRegister() - defer handler.PurgePluginRegister() - - r1 := new(dns.Msg) - r2 := new(dns.Msg) - - er := errors.New("") - tests := []struct { - name string - r1 *dns.Msg - e1 error - r2 *dns.Msg - e2 error - wantR *dns.Msg - wantErr bool - }{ - {"failed #1", nil, er, nil, er, nil, true}, - {"failed #2", nil, nil, nil, nil, nil, true}, - {"p1 response #1", r1, nil, nil, nil, r1, false}, - {"p1 response #2", r1, nil, nil, er, r1, false}, - {"p2 response #1", nil, nil, r2, nil, r2, false}, - {"p2 response #2", nil, er, r2, nil, r2, false}, - } - - parallelECS, err := ParseParallelECS(&ParallelECSConfig{ - Parallel: [][]interface{}{{"p1"}, {"p2"}}, - }) - if err != nil { - t.Fatal(err) - } - - ctx := context.Background() - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - p1 := &handler.DummyExecutablePlugin{ - BP: handler.NewBP("p1", ""), - Sleep: 0, - WantR: tt.r1, - WantErr: tt.e1, - } - p2 := &handler.DummyExecutablePlugin{ - BP: handler.NewBP("p2", ""), - Sleep: 0, - WantR: tt.r2, - WantErr: tt.e2, - } - handler.MustRegPlugin(p1, false) - handler.MustRegPlugin(p2, false) - - qCtx := handler.NewContext(new(dns.Msg), nil) - err := parallelECS.execCmd(ctx, qCtx, zap.NewNop()) - if tt.wantErr != (err != nil) { - t.Fatalf("execCmd() error = %v, wantErr %v", err, tt.wantErr) - } - - if tt.wantR != qCtx.R() { - t.Fatalf("execCmd() qCtx.R() = %p, wantR %p", qCtx.R(), tt.wantR) - } - }) - } -} - func Test_FallbackECS_fallback(t *testing.T) { handler.PurgePluginRegister() defer handler.PurgePluginRegister() diff --git a/dispatcher/pkg/executable_seq/parallel.go b/dispatcher/pkg/executable_seq/parallel.go new file mode 100644 index 000000000..31c654f97 --- /dev/null +++ b/dispatcher/pkg/executable_seq/parallel.go @@ -0,0 +1,107 @@ +// Copyright (C) 2020-2021, IrineSistiana +// +// This file is part of mosdns. +// +// mosdns is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) or later version. +// +// mosdns is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package executable_seq + +import ( + "context" + "fmt" + "github.com/IrineSistiana/mosdns/dispatcher/handler" + "github.com/miekg/dns" + "go.uber.org/zap" + "time" +) + +type ParallelECS struct { + s []*ExecutableCmdSequence + timeout time.Duration +} + +type ParallelECSConfig struct { + Parallel [][]interface{} `yaml:"parallel"` + Timeout uint `yaml:"timeout"` +} + +func ParseParallelECS(c *ParallelECSConfig) (*ParallelECS, error) { + if len(c.Parallel) < 2 { + return nil, fmt.Errorf("parallel needs at least 2 cmd sequences, but got %d", len(c.Parallel)) + } + + ps := make([]*ExecutableCmdSequence, 0, len(c.Parallel)) + for i, subSequence := range c.Parallel { + es, err := ParseExecutableCmdSequence(subSequence) + if err != nil { + return nil, fmt.Errorf("invalid parallel sequence at index %d: %w", i, err) + } + ps = append(ps, es) + } + return &ParallelECS{s: ps, timeout: time.Duration(c.Timeout) * time.Second}, nil +} + +type parallelECSResult struct { + r *dns.Msg + status handler.ContextStatus + err error + from int +} + +func (p *ParallelECS) ExecCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (goTwo string, earlyStop bool, err error) { + return "", false, p.execCmd(ctx, qCtx, logger) +} + +func (p *ParallelECS) execCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) { + + var pCtx context.Context // only valid if p.timeout == 0 + var cancel func() + if p.timeout == 0 { + pCtx, cancel = context.WithCancel(ctx) + defer cancel() + } + + t := len(p.s) + c := make(chan *parallelECSResult, len(p.s)) // use buf chan to avoid block. + + for i, sequence := range p.s { + i := i + sequence := sequence + qCtxCopy := qCtx.Copy() + + go func() { + var ecsCtx context.Context + var ecsCancel func() + if p.timeout == 0 { + ecsCtx = pCtx + } else { + ecsCtx, ecsCancel = context.WithTimeout(context.Background(), p.timeout) + defer ecsCancel() + } + + err := WalkExecutableCmd(ecsCtx, qCtxCopy, logger, sequence) + if err == nil { + err = qCtxCopy.ExecDefer(pCtx) + } + c <- ¶llelECSResult{ + r: qCtxCopy.R(), + status: qCtxCopy.Status(), + err: err, + from: i, + } + }() + } + + return asyncWait(ctx, qCtx, logger, c, t) +} diff --git a/dispatcher/pkg/executable_seq/parallel_test.go b/dispatcher/pkg/executable_seq/parallel_test.go new file mode 100644 index 000000000..2cf0caf5e --- /dev/null +++ b/dispatcher/pkg/executable_seq/parallel_test.go @@ -0,0 +1,90 @@ +// Copyright (C) 2020-2021, IrineSistiana +// +// This file is part of mosdns. +// +// mosdns is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) or later version. +// +// mosdns is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package executable_seq + +import ( + "context" + "errors" + "github.com/IrineSistiana/mosdns/dispatcher/handler" + "github.com/miekg/dns" + "go.uber.org/zap" + "testing" +) + +func Test_ParallelECS(t *testing.T) { + handler.PurgePluginRegister() + defer handler.PurgePluginRegister() + + r1 := new(dns.Msg) + r2 := new(dns.Msg) + + er := errors.New("") + tests := []struct { + name string + r1 *dns.Msg + e1 error + r2 *dns.Msg + e2 error + wantR *dns.Msg + wantErr bool + }{ + {"failed #1", nil, er, nil, er, nil, true}, + {"failed #2", nil, nil, nil, nil, nil, true}, + {"p1 response #1", r1, nil, nil, nil, r1, false}, + {"p1 response #2", r1, nil, nil, er, r1, false}, + {"p2 response #1", nil, nil, r2, nil, r2, false}, + {"p2 response #2", nil, er, r2, nil, r2, false}, + } + + parallelECS, err := ParseParallelECS(&ParallelECSConfig{ + Parallel: [][]interface{}{{"p1"}, {"p2"}}, + }) + if err != nil { + t.Fatal(err) + } + + ctx := context.Background() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p1 := &handler.DummyExecutablePlugin{ + BP: handler.NewBP("p1", ""), + Sleep: 0, + WantR: tt.r1, + WantErr: tt.e1, + } + p2 := &handler.DummyExecutablePlugin{ + BP: handler.NewBP("p2", ""), + Sleep: 0, + WantR: tt.r2, + WantErr: tt.e2, + } + handler.MustRegPlugin(p1, false) + handler.MustRegPlugin(p2, false) + + qCtx := handler.NewContext(new(dns.Msg), nil) + err := parallelECS.execCmd(ctx, qCtx, zap.NewNop()) + if tt.wantErr != (err != nil) { + t.Fatalf("execCmd() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.wantR != qCtx.R() { + t.Fatalf("execCmd() qCtx.R() = %p, wantR %p", qCtx.R(), tt.wantR) + } + }) + } +} diff --git a/dispatcher/pkg/executable_seq/utils.go b/dispatcher/pkg/executable_seq/utils.go new file mode 100644 index 000000000..9f36cc0e1 --- /dev/null +++ b/dispatcher/pkg/executable_seq/utils.go @@ -0,0 +1,73 @@ +// Copyright (C) 2020-2021, IrineSistiana +// +// This file is part of mosdns. +// +// mosdns is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) or later version. +// +// mosdns is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package executable_seq + +import ( + "context" + "errors" + "github.com/IrineSistiana/mosdns/dispatcher/handler" + "go.uber.org/zap" +) + +// WalkExecutableCmd executes the ExecutableCmd, include its `goto`. +// This should only be used in root cmd node. +func WalkExecutableCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger, entry ExecutableCmd) (err error) { + goTwo, _, err := entry.ExecCmd(ctx, qCtx, logger) + if err != nil { + return err + } + + if len(goTwo) != 0 { + logger.Debug("goto plugin", qCtx.InfoField(), zap.String("goto", goTwo)) + p, err := handler.GetPlugin(goTwo) + if err != nil { + return err + } + _, err = p.ExecES(ctx, qCtx) + return err + } + return nil +} + +func asyncWait(ctx context.Context, qCtx *handler.Context, logger *zap.Logger, c chan *parallelECSResult, total int) error { + for i := 0; i < total; i++ { + select { + case res := <-c: + if res.err != nil { + logger.Warn("sequence failed", qCtx.InfoField(), zap.Int("sequence", res.from), zap.Error(res.err)) + continue + } + + if res.r == nil { + logger.Debug("sequence returned with an empty response", qCtx.InfoField(), zap.Int("sequence", res.from)) + continue + } + + logger.Debug("sequence returned a response", qCtx.InfoField(), zap.Int("sequence", res.from)) + qCtx.SetResponse(res.r, res.status) + return nil + + case <-ctx.Done(): + return ctx.Err() + } + } + + // No response + qCtx.SetResponse(nil, handler.ContextStatusServerFailed) + return errors.New("no response") +} diff --git a/dispatcher/utils/load_once.go b/dispatcher/pkg/load_cache/load_cache.go similarity index 79% rename from dispatcher/utils/load_once.go rename to dispatcher/pkg/load_cache/load_cache.go index 89b9eb44b..e3beabc8d 100644 --- a/dispatcher/utils/load_once.go +++ b/dispatcher/pkg/load_cache/load_cache.go @@ -15,7 +15,7 @@ // You should have received a copy of the GNU General Public License // along with this program. If not, see . -package utils +package load_cache import ( "io/ioutil" @@ -24,18 +24,18 @@ import ( "time" ) -type LoadOnceCache struct { +type LoadCache struct { l sync.Mutex cache map[string]interface{} } -func NewCache() *LoadOnceCache { - return &LoadOnceCache{ +func NewCache() *LoadCache { + return &LoadCache{ cache: make(map[string]interface{}), } } -func (c *LoadOnceCache) Put(key string, data interface{}, ttl time.Duration) { +func (c *LoadCache) Put(key string, data interface{}, ttl time.Duration) { if ttl <= 0 { return } @@ -49,14 +49,14 @@ func (c *LoadOnceCache) Put(key string, data interface{}, ttl time.Duration) { time.AfterFunc(ttl, rm) } -func (c *LoadOnceCache) Remove(key string) { +func (c *LoadCache) Remove(key string) { c.l.Lock() defer c.l.Unlock() delete(c.cache, key) } -func (c *LoadOnceCache) Load(key string) (interface{}, bool) { +func (c *LoadCache) Load(key string) (interface{}, bool) { c.l.Lock() defer c.l.Unlock() @@ -64,7 +64,7 @@ func (c *LoadOnceCache) Load(key string) (interface{}, bool) { return data, ok } -func (c *LoadOnceCache) LoadFromCacheOrRawDisk(file string) (interface{}, []byte, error) { +func (c *LoadCache) LoadFromCacheOrRawDisk(file string) (interface{}, []byte, error) { // load from cache data, ok := c.Load(file) if ok { diff --git a/dispatcher/utils/lru.go b/dispatcher/pkg/lru/lru.go similarity index 99% rename from dispatcher/utils/lru.go rename to dispatcher/pkg/lru/lru.go index 52d3a7b2f..6af59ca7f 100644 --- a/dispatcher/utils/lru.go +++ b/dispatcher/pkg/lru/lru.go @@ -15,7 +15,7 @@ // You should have received a copy of the GNU General Public License // along with this program. If not, see . -package utils +package lru import ( "container/list" diff --git a/dispatcher/utils/lru_test.go b/dispatcher/pkg/lru/lru_test.go similarity index 99% rename from dispatcher/utils/lru_test.go rename to dispatcher/pkg/lru/lru_test.go index 30a0b0878..dc3a7c2f6 100644 --- a/dispatcher/utils/lru_test.go +++ b/dispatcher/pkg/lru/lru_test.go @@ -1,4 +1,4 @@ -package utils +package lru import ( "reflect" diff --git a/dispatcher/matcher/domain/interface.go b/dispatcher/pkg/matcher/domain/interface.go similarity index 100% rename from dispatcher/matcher/domain/interface.go rename to dispatcher/pkg/matcher/domain/interface.go diff --git a/dispatcher/matcher/domain/load_helper.go b/dispatcher/pkg/matcher/domain/load_helper.go similarity index 77% rename from dispatcher/matcher/domain/load_helper.go rename to dispatcher/pkg/matcher/domain/load_helper.go index 1bbc32d96..4f69c82f2 100644 --- a/dispatcher/matcher/domain/load_helper.go +++ b/dispatcher/pkg/matcher/domain/load_helper.go @@ -22,8 +22,9 @@ import ( "bytes" "errors" "fmt" - "github.com/IrineSistiana/mosdns/dispatcher/matcher/v2data" - "github.com/IrineSistiana/mosdns/dispatcher/utils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/load_cache" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/v2data" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils" "io" "io/ioutil" "strings" @@ -32,7 +33,7 @@ import ( "github.com/golang/protobuf/proto" ) -var matcherCache = utils.NewCache() +var matcherCache = load_cache.NewCache() const ( cacheTTL = time.Second * 30 @@ -41,35 +42,32 @@ const ( // ProcessAttrFunc processes the additional attributions. The given []string could have a 0 length or is nil. type ProcessAttrFunc func([]string) (v interface{}, accept bool, err error) -// LoadFromFile loads data from the file. -// File can be a text file or a v2ray data file. -// Only MixMatcher can load v2ray data file. -// v2ray data file needs to specify the data category by using ':', e.g. 'geosite.dat:cn' -func LoadFromFile(m Matcher, file string, processAttr ProcessAttrFunc) error { - var err error - if tmp := strings.SplitN(file, ":", 2); len(tmp) == 2 { // is a v2ray data file - mixMatcher, ok := m.(*MixMatcher) - if !ok { - return errors.New("only MixMatcher can load v2ray data file") - } - filePath := tmp[0] - countryCode := tmp[1] - err = mixMatcher.LoadFromDAT(filePath, countryCode, processAttr) - } else { // is a text file - err = LoadFromTextFile(m, file, processAttr) - } - if err != nil { - return err +// Load loads data from a entry. +// If entry begin with "ext:", Load loads the file by using LoadFromFile. +// Else it loads the entry as a text pattern by using LoadFromText. +func Load(m Matcher, entry string, processAttr ProcessAttrFunc) error { + s1, s2, ok := utils.SplitString2(entry, ":") + if ok && s1 == "ext" { + return LoadFromFile(m, s2, processAttr) } + return LoadFromText(m, entry, processAttr) +} +// BatchLoadMatcher loads multiple files using Load. +func BatchLoadMatcher(m Matcher, f []string, processAttr ProcessAttrFunc) error { + for _, file := range f { + err := Load(m, file, processAttr) + if err != nil { + return fmt.Errorf("failed to load file %s: %w", file, err) + } + } return nil } -// LoadFromFileAsV2Matcher loads data from a file. +// LoadFromFile loads data from a file. // v2ray data file can also have multiple @attr. e.g. 'geosite.dat:cn@attr1@attr2'. // Only the record with all of the @attr will be loaded. -// Also see LoadFromFile. -func LoadFromFileAsV2Matcher(m Matcher, file string) error { +func LoadFromFile(m Matcher, file string, processAttr ProcessAttrFunc) error { var err error if tmp := strings.SplitN(file, ":", 2); len(tmp) == 2 { // is a v2ray data file mixMatcher, ok := m.(*MixMatcher) @@ -80,12 +78,19 @@ func LoadFromFileAsV2Matcher(m Matcher, file string) error { tmp := strings.Split(tmp[1], "@") countryCode := tmp[0] wantedAttr := tmp[1:] - processAttr := func(attr []string) (v interface{}, accept bool, err error) { - return nil, mustHaveAttr(attr, wantedAttr), nil + v2ProcessAttr := func(attr []string) (v interface{}, accept bool, err error) { + v2Accept := mustHaveAttr(attr, wantedAttr) + if v2Accept { + if processAttr != nil { + return processAttr(attr) + } + return nil, true, nil + } + return nil, false, nil } - err = mixMatcher.LoadFromDAT(filePath, countryCode, processAttr) + err = mixMatcher.LoadFromDAT(filePath, countryCode, v2ProcessAttr) } else { // is a text file - err = LoadFromTextFile(m, file, nil) + err = LoadFromTextFile(m, file, processAttr) } if err != nil { return err @@ -94,28 +99,6 @@ func LoadFromFileAsV2Matcher(m Matcher, file string) error { return nil } -// BatchLoadMatcher loads multiple files using LoadFromFile -func BatchLoadMatcher(m Matcher, f []string, processAttr ProcessAttrFunc) error { - for _, file := range f { - err := LoadFromFile(m, file, processAttr) - if err != nil { - return fmt.Errorf("failed to load file %s: %w", file, err) - } - } - return nil -} - -// BatchLoadMixMatcherV2Matcher loads multiple files using LoadFromFileAsV2Matcher -func BatchLoadMixMatcherV2Matcher(m Matcher, f []string) error { - for _, file := range f { - err := LoadFromFileAsV2Matcher(m, file) - if err != nil { - return fmt.Errorf("failed to load file %s: %w", file, err) - } - } - return nil -} - func LoadFromTextFile(m Matcher, file string, processAttr ProcessAttrFunc) error { data, err := ioutil.ReadFile(file) if err != nil { diff --git a/dispatcher/matcher/domain/load_helper_test.go b/dispatcher/pkg/matcher/domain/load_helper_test.go similarity index 100% rename from dispatcher/matcher/domain/load_helper_test.go rename to dispatcher/pkg/matcher/domain/load_helper_test.go diff --git a/dispatcher/matcher/domain/matcher.go b/dispatcher/pkg/matcher/domain/matcher.go similarity index 99% rename from dispatcher/matcher/domain/matcher.go rename to dispatcher/pkg/matcher/domain/matcher.go index a0b8bbb67..d1227977a 100644 --- a/dispatcher/matcher/domain/matcher.go +++ b/dispatcher/pkg/matcher/domain/matcher.go @@ -19,7 +19,7 @@ package domain import ( "fmt" - "github.com/IrineSistiana/mosdns/dispatcher/utils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils" "github.com/miekg/dns" "regexp" "strings" diff --git a/dispatcher/matcher/domain/matcher_test.go b/dispatcher/pkg/matcher/domain/matcher_test.go similarity index 100% rename from dispatcher/matcher/domain/matcher_test.go rename to dispatcher/pkg/matcher/domain/matcher_test.go diff --git a/dispatcher/matcher/elem/rr_type.go b/dispatcher/pkg/matcher/elem/rr_type.go similarity index 100% rename from dispatcher/matcher/elem/rr_type.go rename to dispatcher/pkg/matcher/elem/rr_type.go diff --git a/dispatcher/matcher/elem/rr_type_test.go b/dispatcher/pkg/matcher/elem/rr_type_test.go similarity index 100% rename from dispatcher/matcher/elem/rr_type_test.go rename to dispatcher/pkg/matcher/elem/rr_type_test.go diff --git a/dispatcher/matcher/netlist/interface.go b/dispatcher/pkg/matcher/netlist/interface.go similarity index 100% rename from dispatcher/matcher/netlist/interface.go rename to dispatcher/pkg/matcher/netlist/interface.go diff --git a/dispatcher/matcher/netlist/list.go b/dispatcher/pkg/matcher/netlist/list.go similarity index 100% rename from dispatcher/matcher/netlist/list.go rename to dispatcher/pkg/matcher/netlist/list.go diff --git a/dispatcher/matcher/netlist/load_helper.go b/dispatcher/pkg/matcher/netlist/load_helper.go similarity index 50% rename from dispatcher/matcher/netlist/load_helper.go rename to dispatcher/pkg/matcher/netlist/load_helper.go index b75fdb193..c77f64fa8 100644 --- a/dispatcher/matcher/netlist/load_helper.go +++ b/dispatcher/pkg/matcher/netlist/load_helper.go @@ -20,67 +20,49 @@ package netlist import ( "bufio" "bytes" - "errors" "fmt" - "github.com/IrineSistiana/mosdns/dispatcher/matcher/v2data" - "github.com/IrineSistiana/mosdns/dispatcher/utils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/load_cache" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/v2data" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils" "github.com/golang/protobuf/proto" "io" "io/ioutil" - "net" "strings" "time" ) -var matcherCache = utils.NewCache() +var matcherCache = load_cache.NewCache() const ( cacheTTL = time.Second * 30 ) -type MatcherGroup struct { - m []Matcher -} - -func (mg *MatcherGroup) Match(ip net.IP) bool { - for _, m := range mg.m { - if m.Match(ip) { - return true +// BatchLoad is a helper func to load multiple files using Load. +// It might modify the List and causes List unsorted. +func BatchLoad(l *List, entries []string) error { + for _, file := range entries { + err := Load(l, file) + if err != nil { + return fmt.Errorf("failed to load ip file %s: %w", file, err) } } - return false + return nil } -func NewMatcherGroup(m []Matcher) *MatcherGroup { - return &MatcherGroup{m: m} -} - -// BatchLoad is helper func to load multiple files using NewListFromFile. -func BatchLoad(f []string) (m *List, err error) { - if len(f) == 0 { - return nil, errors.New("no file to load") - } - - if len(f) == 1 { - return NewListFromFile(f[0]) +// Load loads data from entry. +// If entry begin with "ext:", Load loads the file by using LoadFromFile. +// Else it loads the entry as a text pattern by using LoadFromText. +func Load(l *List, entry string) error { + s1, s2, ok := utils.SplitString2(entry, ":") + if ok && s1 == "ext" { + return LoadFromFile(l, s2) } - - list := NewList() - for _, file := range f { - l, err := NewListFromFile(file) - if err != nil { - return nil, fmt.Errorf("failed to load ip file %s: %w", file, err) - } - list.Merge(l) - } - - list.Sort() - return list, nil + return LoadFromText(l, entry) } -//NewListFromReader read IP list from a reader. The returned *List is sorted. -func NewListFromReader(reader io.Reader) (*List, error) { - ipNetList := NewList() +// LoadFromReader loads IP list from a reader. +// It might modify the List and causes List unsorted. +func LoadFromReader(l *List, reader io.Reader) error { scanner := bufio.NewScanner(reader) //count how many lines we have read. @@ -88,69 +70,76 @@ func NewListFromReader(reader io.Reader) (*List, error) { for scanner.Scan() { lineCounter++ - s := strings.TrimSpace(utils.BytesToStringUnsafe(scanner.Bytes())) - s = utils.RemoveComment(s, "#") - s = utils.RemoveComment(s, " ") // remove other strings, e.g. 192.168.1.1 str1 str2 - - if len(s) == 0 { - continue - } - - ipNet, err := ParseCIDR(s) + s := utils.BytesToStringUnsafe(scanner.Bytes()) + err := LoadFromText(l, s) if err != nil { - return nil, fmt.Errorf("invalid CIDR format %s in line %d", s, lineCounter) + return fmt.Errorf("invalid data at line #%d: %w", lineCounter, err) } + } + + return nil +} + +// LoadFromText loads an IP from s. +// It might modify the List and causes List unsorted. +func LoadFromText(l *List, s string) error { + s = strings.TrimSpace(s) + s = utils.RemoveComment(s, "#") + s = utils.RemoveComment(s, " ") // remove other strings, e.g. 192.168.1.1 str1 str2 - ipNetList.Append(ipNet) + if len(s) == 0 { + return nil } - ipNetList.Sort() - return ipNetList, nil + ipNet, err := ParseCIDR(s) + if err != nil { + return err + } + l.Append(ipNet) + return nil } -// NewListFromFile loads ip from a text file or a geoip file. +// LoadFromFile loads ip from a text file or a geoip file. // If file contains a ':' and has format like 'geoip:cn', it will be read as a geoip file. -// The returned *List is already been sorted. -func NewListFromFile(file string) (*List, error) { +// It might modify the List and causes List unsorted. +func LoadFromFile(l *List, file string) error { if strings.Contains(file, ":") { tmp := strings.SplitN(file, ":", 2) - return NewListFromDAT(tmp[0], tmp[1]) // file and tag + return LoadFromDAT(l, tmp[0], tmp[1]) // file and tag } else { - return NewListFromTextFile(file) + return LoadFromTextFile(l, file) } } -// NewListFromTextFile reads IP list from a text file. -// The returned *List is already been sorted. -func NewListFromTextFile(file string) (*List, error) { +// LoadFromTextFile reads IP list from a text file. +// It might modify the List and causes List unsorted. +func LoadFromTextFile(l *List, file string) error { b, err := ioutil.ReadFile(file) if err != nil { - return nil, err + return err } - - return NewListFromReader(bytes.NewBuffer(b)) + return LoadFromReader(l, bytes.NewBuffer(b)) } -// NewListFromDAT loads ip from v2ray proto file. -// The returned *List is already been sorted. -func NewListFromDAT(file, tag string) (*List, error) { +// LoadFromDAT loads ip from v2ray proto file. +// It might modify the List and causes List unsorted. +func LoadFromDAT(l *List, file, tag string) error { geoIP, err := LoadGeoIPFromDAT(file, tag) if err != nil { - return nil, err + return err } - return NewListFromV2CIDR(geoIP.GetCidr()) + return LoadFromV2CIDR(l, geoIP.GetCidr()) } -// NewListFromV2CIDR loads ip from v2ray CIDR. -// The returned *List is already been sorted. -func NewListFromV2CIDR(cidr []*v2data.CIDR) (*List, error) { - l := NewList() - l.Grow(len(cidr)) +// LoadFromV2CIDR loads ip from v2ray CIDR. +// It might modify the List and causes List unsorted. +func LoadFromV2CIDR(l *List, cidr []*v2data.CIDR) error { + l.Grow(l.Len() + len(cidr)) for i, e := range cidr { ipv6, err := Conv(e.Ip) if err != nil { - return nil, fmt.Errorf("invalid data ip at index #%d, %w", i, err) + return fmt.Errorf("invalid data ip at index #%d, %w", i, err) } switch len(e.Ip) { case 4: @@ -158,12 +147,10 @@ func NewListFromV2CIDR(cidr []*v2data.CIDR) (*List, error) { case 16: l.Append(NewNet(ipv6, uint(e.Prefix))) default: - return nil, fmt.Errorf("invalid cidr ip length at #%d", i) + return fmt.Errorf("invalid cidr ip length at #%d", i) } } - - l.Sort() - return l, nil + return nil } func LoadGeoIPFromDAT(file, tag string) (*v2data.GeoIP, error) { diff --git a/dispatcher/matcher/netlist/net.go b/dispatcher/pkg/matcher/netlist/net.go similarity index 98% rename from dispatcher/matcher/netlist/net.go rename to dispatcher/pkg/matcher/netlist/net.go index bfcbd6706..19a1ae522 100644 --- a/dispatcher/matcher/netlist/net.go +++ b/dispatcher/pkg/matcher/netlist/net.go @@ -21,7 +21,7 @@ import ( "encoding/binary" "errors" "fmt" - "github.com/IrineSistiana/mosdns/dispatcher/utils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils" "net" "strconv" ) diff --git a/dispatcher/matcher/netlist/netlist_test.go b/dispatcher/pkg/matcher/netlist/netlist_test.go similarity index 96% rename from dispatcher/matcher/netlist/netlist_test.go rename to dispatcher/pkg/matcher/netlist/netlist_test.go index 3e0f26e36..ecf830813 100644 --- a/dispatcher/matcher/netlist/netlist_test.go +++ b/dispatcher/pkg/matcher/netlist/netlist_test.go @@ -55,10 +55,12 @@ var ( ) func TestIPNetList_New_And_Contains(t *testing.T) { - ipNetList, err := NewListFromReader(bytes.NewBufferString(rawList)) + ipNetList := NewList() + err := LoadFromReader(ipNetList, bytes.NewBufferString(rawList)) if err != nil { t.Fatal(err) } + ipNetList.Sort() if ipNetList.Len() != 18 { t.Fatalf("unexpected length %d", ipNetList.Len()) diff --git a/dispatcher/matcher/v2data/data.pb.go b/dispatcher/pkg/matcher/v2data/data.pb.go similarity index 100% rename from dispatcher/matcher/v2data/data.pb.go rename to dispatcher/pkg/matcher/v2data/data.pb.go diff --git a/dispatcher/matcher/v2data/data.proto b/dispatcher/pkg/matcher/v2data/data.proto similarity index 100% rename from dispatcher/matcher/v2data/data.proto rename to dispatcher/pkg/matcher/v2data/data.proto diff --git a/dispatcher/pkg/pool/bytes_buf.go b/dispatcher/pkg/pool/bytes_buf.go new file mode 100644 index 000000000..4d96fd7be --- /dev/null +++ b/dispatcher/pkg/pool/bytes_buf.go @@ -0,0 +1,51 @@ +// Copyright (C) 2020-2021, IrineSistiana +// +// This file is part of mosdns. +// +// mosdns is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// mosdns is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package pool + +import ( + "bytes" + "fmt" + "sync" +) + +type BytesBufPool struct { + p sync.Pool +} + +func NewBytesBufPool(initSize int) *BytesBufPool { + if initSize < 0 { + panic(fmt.Sprintf("utils.NewBytesBufPool: negative init size %d", initSize)) + } + + return &BytesBufPool{ + p: sync.Pool{New: func() interface{} { + b := new(bytes.Buffer) + b.Grow(initSize) + return b + }}, + } +} + +func (p *BytesBufPool) Get() *bytes.Buffer { + return p.p.Get().(*bytes.Buffer) +} + +func (p *BytesBufPool) Release(b *bytes.Buffer) { + b.Reset() + p.p.Put(b) +} diff --git a/dispatcher/utils/msg_buf.go b/dispatcher/pkg/pool/msg_buf.go similarity index 99% rename from dispatcher/utils/msg_buf.go rename to dispatcher/pkg/pool/msg_buf.go index 9992ae801..d67b53c9d 100644 --- a/dispatcher/utils/msg_buf.go +++ b/dispatcher/pkg/pool/msg_buf.go @@ -18,7 +18,7 @@ // This file is a modified version from github.com/xtaci/smux/blob/master/alloc.go f386d90 // license of smux: MIT https://github.com/xtaci/smux/blob/master/LICENSE -package utils +package pool import ( "fmt" diff --git a/dispatcher/utils/timer.go b/dispatcher/pkg/pool/timer.go similarity index 98% rename from dispatcher/utils/timer.go rename to dispatcher/pkg/pool/timer.go index dfc4b13cc..83d60533e 100644 --- a/dispatcher/utils/timer.go +++ b/dispatcher/pkg/pool/timer.go @@ -15,7 +15,7 @@ // You should have received a copy of the GNU General Public License // along with this program. If not, see . -package utils +package pool import ( "sync" diff --git a/dispatcher/utils/server_handler.go b/dispatcher/pkg/server_handler/server_handler.go similarity index 75% rename from dispatcher/utils/server_handler.go rename to dispatcher/pkg/server_handler/server_handler.go index ea16cb69c..e0de06372 100644 --- a/dispatcher/utils/server_handler.go +++ b/dispatcher/pkg/server_handler/server_handler.go @@ -15,11 +15,13 @@ // You should have received a copy of the GNU General Public License // along with this program. If not, see . -package utils +package server_handler import ( "context" "github.com/IrineSistiana/mosdns/dispatcher/handler" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/concurrent_limiter" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/executable_seq" "github.com/miekg/dns" "go.uber.org/zap" "testing" @@ -38,15 +40,15 @@ type ResponseWriter interface { type DefaultServerHandler struct { config *DefaultServerHandlerConfig - limiter *ConcurrentLimiter // if it's nil, means no limit. - clientLimiter *ClientQueryLimiter // if it's nil, means no limit. + limiter *concurrent_limiter.ConcurrentLimiter // if it's nil, means no limit. + clientLimiter *concurrent_limiter.ClientQueryLimiter // if it's nil, means no limit. } type DefaultServerHandlerConfig struct { // Logger is used for logging, it cannot be nil. Logger *zap.Logger // Entry is the entry ExecutablePlugin's tag. This shouldn't be empty. - Entry *ExecutableCmdSequence + Entry *executable_seq.ExecutableCmdSequence // ConcurrentLimit controls the max concurrent queries for the DefaultServerHandler. // If ConcurrentLimit <= 0, means no limit. // When calling DefaultServerHandler.ServeDNS(), if a query exceeds the limit, it will wait on a FIFO queue until @@ -62,22 +64,22 @@ type DefaultServerHandlerConfig struct { ConcurrentLimitPreClient int } -// NewDefaultServerHandler: +// NewDefaultServerHandler // Also see DefaultServerHandler.ServeDNS. func NewDefaultServerHandler(config *DefaultServerHandlerConfig) *DefaultServerHandler { h := &DefaultServerHandler{config: config} if config.ConcurrentLimit > 0 { - h.limiter = NewConcurrentLimiter(config.ConcurrentLimit) + h.limiter = concurrent_limiter.NewConcurrentLimiter(config.ConcurrentLimit) } if config.ConcurrentLimitPreClient > 0 { - h.clientLimiter = NewClientQueryLimiter(config.ConcurrentLimitPreClient) + h.clientLimiter = concurrent_limiter.NewClientQueryLimiter(config.ConcurrentLimitPreClient) } return h } -// ServeDNS: +// ServeDNS // If entry returns an err, a SERVFAIL response will be sent back to client. // If concurrentLimit is reached, the query will block and wait available token until ctx is done. func (h *DefaultServerHandler) ServeDNS(ctx context.Context, qCtx *handler.Context, w ResponseWriter) { @@ -135,7 +137,7 @@ func (h *DefaultServerHandler) ServeDNS(ctx context.Context, qCtx *handler.Conte } func (h *DefaultServerHandler) execEntry(ctx context.Context, qCtx *handler.Context) error { - err := WalkExecutableCmd(ctx, qCtx, h.config.Logger, h.config.Entry) + err := executable_seq.WalkExecutableCmd(ctx, qCtx, h.config.Logger, h.config.Entry) if err != nil { return err } @@ -164,50 +166,3 @@ func (d *DummyServerHandler) ServeDNS(_ context.Context, qCtx *handler.Context, d.T.Errorf("DummyServerHandler: %v", err) } } - -type ClientQueryLimiter struct { - maxQueries int - m *ConcurrentMap -} - -func NewClientQueryLimiter(maxQueries int) *ClientQueryLimiter { - return &ClientQueryLimiter{ - maxQueries: maxQueries, - m: NewConcurrentMap(64), - } -} - -func (l *ClientQueryLimiter) Acquire(key string) bool { - return l.m.TestAndSet(key, l.acquireTestAndSet) -} - -func (l *ClientQueryLimiter) acquireTestAndSet(v interface{}, ok bool) (newV interface{}, wantUpdate, passed bool) { - n := 0 - if ok { - n = v.(int) - } - if n >= l.maxQueries { - return nil, false, false - } - n++ - return n, true, true -} - -func (l *ClientQueryLimiter) doneTestAndSet(v interface{}, ok bool) (newV interface{}, wantUpdate, passed bool) { - if !ok { - panic("ClientQueryLimiter doneTestAndSet: value is not exist") - } - n := v.(int) - n-- - if n < 0 { - panic("ClientQueryLimiter doneTestAndSet: value becomes negative") - } - if n == 0 { - return nil, true, true - } - return n, true, true -} - -func (l *ClientQueryLimiter) Done(key string) { - l.m.TestAndSet(key, l.doneTestAndSet) -} diff --git a/dispatcher/utils/utils.go b/dispatcher/pkg/utils/utils.go similarity index 76% rename from dispatcher/utils/utils.go rename to dispatcher/pkg/utils/utils.go index 0564d27b9..eb6c65f74 100644 --- a/dispatcher/utils/utils.go +++ b/dispatcher/pkg/utils/utils.go @@ -18,7 +18,6 @@ package utils import ( - "bytes" "context" "crypto/ecdsa" "crypto/elliptic" @@ -342,128 +341,3 @@ func ExchangeParallel(ctx context.Context, qCtx *handler.Context, upstreams []Up // all upstreams are failed return nil, errors.New("no response") } - -// GetMinimalTTL returns the minimal ttl of this msg. -// If msg m has no record, it returns 0. -func GetMinimalTTL(m *dns.Msg) uint32 { - minTTL := ^uint32(0) - hasRecord := false - for _, section := range [...][]dns.RR{m.Answer, m.Ns, m.Extra} { - for _, rr := range section { - if rr.Header().Rrtype == dns.TypeOPT { - continue // opt record ttl is not ttl. - } - hasRecord = true - ttl := rr.Header().Ttl - if ttl < minTTL { - minTTL = ttl - } - } - } - - if !hasRecord { // no ttl applied - return 0 - } - return minTTL -} - -// SetTTL updates all records' ttl to ttl, except opt record. -func SetTTL(m *dns.Msg, ttl uint32) { - for _, section := range [...][]dns.RR{m.Answer, m.Ns, m.Extra} { - for _, rr := range section { - if rr.Header().Rrtype == dns.TypeOPT { - continue // opt record ttl is not ttl. - } - rr.Header().Ttl = ttl - } - } -} - -func ApplyMaximumTTL(m *dns.Msg, ttl uint32) { - applyTTL(m, ttl, true) -} - -func ApplyMinimalTTL(m *dns.Msg, ttl uint32) { - applyTTL(m, ttl, false) -} - -func applyTTL(m *dns.Msg, ttl uint32, maximum bool) { - for _, section := range [...][]dns.RR{m.Answer, m.Ns, m.Extra} { - for _, rr := range section { - if rr.Header().Rrtype == dns.TypeOPT { - continue // opt record ttl is not ttl. - } - if maximum { - if rr.Header().Ttl > ttl { - rr.Header().Ttl = ttl - } - } else { - if rr.Header().Ttl < ttl { - rr.Header().Ttl = ttl - } - } - } - } -} - -type BytesBufPool struct { - p sync.Pool -} - -func NewBytesBufPool(initSize int) *BytesBufPool { - if initSize < 0 { - panic(fmt.Sprintf("utils.NewBytesBufPool: negative init size %d", initSize)) - } - - return &BytesBufPool{ - p: sync.Pool{New: func() interface{} { - b := new(bytes.Buffer) - b.Grow(initSize) - return b - }}, - } -} - -func (p *BytesBufPool) Get() *bytes.Buffer { - return p.p.Get().(*bytes.Buffer) -} - -func (p *BytesBufPool) Release(b *bytes.Buffer) { - b.Reset() - p.p.Put(b) -} - -// ConcurrentLimiter -type ConcurrentLimiter struct { - bucket chan struct{} -} - -// NewConcurrentLimiter returns a ConcurrentLimiter, max must > 0. -func NewConcurrentLimiter(max int) *ConcurrentLimiter { - if max <= 0 { - panic(fmt.Sprintf("ConcurrentLimiter: invalid max arg: %d", max)) - } - - bucket := make(chan struct{}, max) - for i := 0; i < max; i++ { - bucket <- struct{}{} - } - - return &ConcurrentLimiter{bucket: bucket} -} - -func (l *ConcurrentLimiter) Wait() <-chan struct{} { - return l.bucket -} - -func (l *ConcurrentLimiter) Done() { - select { - case l.bucket <- struct{}{}: - default: - panic("ConcurrentLimiter: bucket overflow") - } -} - -func (l *ConcurrentLimiter) Available() int { - return len(l.bucket) -} diff --git a/dispatcher/utils/utils_test.go b/dispatcher/pkg/utils/utils_test.go similarity index 90% rename from dispatcher/utils/utils_test.go rename to dispatcher/pkg/utils/utils_test.go index 599763b83..20b944168 100644 --- a/dispatcher/utils/utils_test.go +++ b/dispatcher/pkg/utils/utils_test.go @@ -21,11 +21,10 @@ import ( "context" "errors" "github.com/IrineSistiana/mosdns/dispatcher/handler" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/concurrent_limiter" "github.com/miekg/dns" "reflect" - "sync" "testing" - "time" ) func TestBoolLogic(t *testing.T) { @@ -189,39 +188,13 @@ func Test_NewConcurrentLimiter(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := NewConcurrentLimiter(tt.args.max); got.Available() != tt.wantLen { + if got := concurrent_limiter.NewConcurrentLimiter(tt.args.max); got.Available() != tt.wantLen { t.Errorf("NewConcurrentLimiter() = %v, want %v", got.Available(), tt.wantLen) } }) } } -func Test_ConcurrentLimiter_acquire_release(t *testing.T) { - l := NewConcurrentLimiter(500) - ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) - defer cancel() - - wg := new(sync.WaitGroup) - wg.Add(1000) - for i := 0; i < 1000; i++ { - go func() { - defer wg.Done() - select { - case <-l.Wait(): - time.Sleep(time.Millisecond * 200) - l.Done() - case <-ctx.Done(): - t.Fail() - } - }() - } - - wg.Wait() - if l.Available() != 500 { - t.Fatal("token leaked") - } -} - func TestRemoveComment(t *testing.T) { type args struct { s string diff --git a/dispatcher/plugin/cache/mem_cache.go b/dispatcher/plugin/cache/mem_cache.go index 58928f314..58dc071a3 100644 --- a/dispatcher/plugin/cache/mem_cache.go +++ b/dispatcher/plugin/cache/mem_cache.go @@ -19,7 +19,7 @@ package cache import ( "context" - "github.com/IrineSistiana/mosdns/dispatcher/utils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/concurrent_lru" "github.com/miekg/dns" "sync" "time" @@ -31,7 +31,7 @@ type memCache struct { closeOnce sync.Once closeChan chan struct{} - lru *utils.ConcurrentLRU + lru *concurrent_lru.ConcurrentLRU } type elem struct { @@ -45,7 +45,7 @@ type elem struct { func newMemCache(shardNum, maxSizePerShard int, cleanerInterval time.Duration) *memCache { c := &memCache{ cleanerInterval: cleanerInterval, - lru: utils.NewConcurrentLRU(shardNum, maxSizePerShard, nil, nil), + lru: concurrent_lru.NewConcurrentLRU(shardNum, maxSizePerShard, nil, nil), } if c.cleanerInterval > 0 { diff --git a/dispatcher/plugin/cache/plugin.go b/dispatcher/plugin/cache/plugin.go index b77575757..20d274620 100644 --- a/dispatcher/plugin/cache/plugin.go +++ b/dispatcher/plugin/cache/plugin.go @@ -20,7 +20,8 @@ import ( "context" "fmt" "github.com/IrineSistiana/mosdns/dispatcher/handler" - "github.com/IrineSistiana/mosdns/dispatcher/utils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/dnsutils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils" "github.com/miekg/dns" "go.uber.org/zap" "time" @@ -130,7 +131,7 @@ func (c *cachePlugin) searchAndReply(ctx context.Context, qCtx *handler.Context) if r != nil { // if cache hit c.L().Debug("cache hit", qCtx.InfoField()) r.Id = q.Id - utils.SetTTL(r, uint32(ttl/time.Second)) + dnsutils.SetTTL(r, uint32(ttl/time.Second)) qCtx.SetResponse(r, handler.ContextStatusResponded) return key, true } @@ -158,7 +159,7 @@ func (d *deferCacheStore) Exec(ctx context.Context, qCtx *handler.Context) (err func (d *deferCacheStore) exec(ctx context.Context, qCtx *handler.Context) (err error) { r := qCtx.R() if r != nil && r.Rcode == dns.RcodeSuccess && r.Truncated == false && len(r.Answer) != 0 { - ttl := utils.GetMinimalTTL(r) + ttl := dnsutils.GetMinimalTTL(r) if ttl > maxTTL { ttl = maxTTL } diff --git a/dispatcher/plugin/cache/redis_cache.go b/dispatcher/plugin/cache/redis_cache.go index aee5e7ea5..311f5773f 100644 --- a/dispatcher/plugin/cache/redis_cache.go +++ b/dispatcher/plugin/cache/redis_cache.go @@ -19,7 +19,7 @@ package cache import ( "context" - "github.com/IrineSistiana/mosdns/dispatcher/utils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/pool" "github.com/go-redis/redis/v8" "github.com/miekg/dns" "time" @@ -63,10 +63,10 @@ func (r *redisCache) get(ctx context.Context, key string) (v *dns.Msg, ttl time. } func (r *redisCache) store(ctx context.Context, key string, v *dns.Msg, ttl time.Duration) (err error) { - wireMsg, buf, err := utils.PackBuffer(v) + wireMsg, buf, err := pool.PackBuffer(v) if err != nil { return err } - defer utils.ReleaseMsgBuf(buf) + defer pool.ReleaseMsgBuf(buf) return r.client.Set(ctx, key, wireMsg, ttl).Err() } diff --git a/dispatcher/plugin/executable/ecs/ecs.go b/dispatcher/plugin/executable/ecs/ecs.go index e371ba236..3e6e49ec6 100644 --- a/dispatcher/plugin/executable/ecs/ecs.go +++ b/dispatcher/plugin/executable/ecs/ecs.go @@ -21,7 +21,7 @@ import ( "context" "fmt" "github.com/IrineSistiana/mosdns/dispatcher/handler" - "github.com/IrineSistiana/mosdns/dispatcher/utils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils" "github.com/miekg/dns" "go.uber.org/zap" "net" diff --git a/dispatcher/plugin/executable/ecs/ecs_test.go b/dispatcher/plugin/executable/ecs/ecs_test.go index c1a93b59d..6cf518643 100644 --- a/dispatcher/plugin/executable/ecs/ecs_test.go +++ b/dispatcher/plugin/executable/ecs/ecs_test.go @@ -20,7 +20,7 @@ package ecs import ( "context" "github.com/IrineSistiana/mosdns/dispatcher/handler" - "github.com/IrineSistiana/mosdns/dispatcher/utils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils" "github.com/miekg/dns" "gopkg.in/yaml.v3" "net" diff --git a/dispatcher/plugin/executable/fallback/fallback.go b/dispatcher/plugin/executable/fallback/fallback.go index afc5e7f82..7b022f893 100644 --- a/dispatcher/plugin/executable/fallback/fallback.go +++ b/dispatcher/plugin/executable/fallback/fallback.go @@ -20,7 +20,7 @@ package fallback import ( "context" "github.com/IrineSistiana/mosdns/dispatcher/handler" - "github.com/IrineSistiana/mosdns/dispatcher/utils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/executable_seq" ) const PluginType = "fallback" @@ -34,17 +34,17 @@ var _ handler.ExecutablePlugin = (*fallback)(nil) type fallback struct { *handler.BP - fallbackECS *utils.FallbackECS + fallbackECS *executable_seq.FallbackECS } -type Args = utils.FallbackConfig +type Args = executable_seq.FallbackConfig func Init(bp *handler.BP, args interface{}) (p handler.Plugin, err error) { return newFallback(bp, args.(*Args)) } func newFallback(bp *handler.BP, args *Args) (*fallback, error) { - fallbackECS, err := utils.ParseFallbackECS(args) + fallbackECS, err := executable_seq.ParseFallbackECS(args) if err != nil { return nil, err } @@ -55,5 +55,5 @@ func newFallback(bp *handler.BP, args *Args) (*fallback, error) { } func (f *fallback) Exec(ctx context.Context, qCtx *handler.Context) (err error) { - return utils.WalkExecutableCmd(ctx, qCtx, f.L(), f.fallbackECS) + return executable_seq.WalkExecutableCmd(ctx, qCtx, f.L(), f.fallbackECS) } diff --git a/dispatcher/plugin/executable/fast_forward/fast_forward.go b/dispatcher/plugin/executable/fast_forward/fast_forward.go index 004fd34be..f7dc8e5fe 100644 --- a/dispatcher/plugin/executable/fast_forward/fast_forward.go +++ b/dispatcher/plugin/executable/fast_forward/fast_forward.go @@ -23,7 +23,7 @@ import ( "errors" "fmt" "github.com/IrineSistiana/mosdns/dispatcher/handler" - "github.com/IrineSistiana/mosdns/dispatcher/utils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils" "github.com/miekg/dns" "time" ) diff --git a/dispatcher/plugin/executable/fast_forward/transport.go b/dispatcher/plugin/executable/fast_forward/transport.go index a30d6c2be..6d0168051 100644 --- a/dispatcher/plugin/executable/fast_forward/transport.go +++ b/dispatcher/plugin/executable/fast_forward/transport.go @@ -20,7 +20,7 @@ package fastforward import ( "errors" "fmt" - "github.com/IrineSistiana/mosdns/dispatcher/utils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/pool" "github.com/miekg/dns" "go.uber.org/zap" "io" @@ -143,8 +143,8 @@ func (t *transport) getConn() (conn *dnsConn, reusedConn bool, err error) { panic("Transport getConn: dCall is nil") } - timer := utils.GetTimer(t.readTimeout) - defer utils.ReleaseTimer(timer) + timer := pool.GetTimer(t.readTimeout) + defer pool.ReleaseTimer(timer) select { case <-timer.C: return nil, false, errDialTimeout @@ -234,8 +234,8 @@ func (c *dnsConn) exchange(m *dns.Msg) (r *dns.Msg, err error) { return nil, err } - timer := utils.GetTimer(c.t.readTimeout) - defer utils.ReleaseTimer(timer) + timer := pool.GetTimer(c.t.readTimeout) + defer pool.ReleaseTimer(timer) select { case <-timer.C: diff --git a/dispatcher/plugin/executable/fast_forward/upstream.go b/dispatcher/plugin/executable/fast_forward/upstream.go index 9da0fdefb..7e96d9958 100644 --- a/dispatcher/plugin/executable/fast_forward/upstream.go +++ b/dispatcher/plugin/executable/fast_forward/upstream.go @@ -24,7 +24,8 @@ import ( "errors" "fmt" "github.com/IrineSistiana/mosdns/dispatcher/handler" - "github.com/IrineSistiana/mosdns/dispatcher/utils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/dnsutils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils" "github.com/miekg/dns" "go.uber.org/zap" "golang.org/x/net/http2" @@ -139,9 +140,9 @@ func newFastUpstream(config *UpstreamConfig, logger *zap.Logger, certPool *x509. func() (net.Conn, error) { return u.dialTimeout("udp", dialTimeout) }, - utils.WriteMsgToUDP, + dnsutils.WriteMsgToUDP, func(c io.Reader) (m *dns.Msg, n int, err error) { - return utils.ReadMsgFromUDP(c, utils.IPv4UdpMaxPayload) + return dnsutils.ReadMsgFromUDP(c, dnsutils.IPv4UdpMaxPayload) }, maxConn, time.Second*30, @@ -182,8 +183,8 @@ func newFastUpstream(config *UpstreamConfig, logger *zap.Logger, certPool *x509. u.tcpTransport = newTransport( logger, dialFunc, - utils.WriteMsgToTCP, - utils.ReadMsgFromTCP, + dnsutils.WriteMsgToTCP, + dnsutils.ReadMsgFromTCP, maxConn, idleTimeout, timeout, diff --git a/dispatcher/plugin/executable/fast_forward/upstream_doh.go b/dispatcher/plugin/executable/fast_forward/upstream_doh.go index e7714ae8e..76c0a4a45 100644 --- a/dispatcher/plugin/executable/fast_forward/upstream_doh.go +++ b/dispatcher/plugin/executable/fast_forward/upstream_doh.go @@ -21,23 +21,23 @@ import ( "context" "encoding/base64" "fmt" - "github.com/IrineSistiana/mosdns/dispatcher/utils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/pool" "github.com/miekg/dns" "io" "net/http" ) var ( - bufPool512 = utils.NewBytesBufPool(512) + bufPool512 = pool.NewBytesBufPool(512) ) func (u *fastUpstream) exchangeDoH(q *dns.Msg) (r *dns.Msg, err error) { - rRaw, buf, err := utils.PackBuffer(q) + rRaw, buf, err := pool.PackBuffer(q) if err != nil { return nil, err } - defer utils.ReleaseMsgBuf(buf) + defer pool.ReleaseMsgBuf(buf) // In order to maximize HTTP cache friendliness, DoH clients using media // formats that include the ID field from the DNS message header, such diff --git a/dispatcher/plugin/executable/fast_forward/upstream_test.go b/dispatcher/plugin/executable/fast_forward/upstream_test.go index 17d42d5c0..a03b3da6f 100644 --- a/dispatcher/plugin/executable/fast_forward/upstream_test.go +++ b/dispatcher/plugin/executable/fast_forward/upstream_test.go @@ -23,7 +23,7 @@ import ( "fmt" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/IrineSistiana/mosdns/dispatcher/handler" - "github.com/IrineSistiana/mosdns/dispatcher/utils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils" "github.com/miekg/dns" "go.uber.org/zap" "net" diff --git a/dispatcher/plugin/executable/forward/forward.go b/dispatcher/plugin/executable/forward/forward.go index f78b0d0dc..44fe9efb4 100644 --- a/dispatcher/plugin/executable/forward/forward.go +++ b/dispatcher/plugin/executable/forward/forward.go @@ -24,7 +24,7 @@ import ( "github.com/AdguardTeam/dnsproxy/fastip" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/IrineSistiana/mosdns/dispatcher/handler" - "github.com/IrineSistiana/mosdns/dispatcher/utils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils" "github.com/miekg/dns" "net" "time" diff --git a/dispatcher/plugin/executable/parallel/parallel.go b/dispatcher/plugin/executable/parallel/parallel.go index 08c4cbbad..e654f6486 100644 --- a/dispatcher/plugin/executable/parallel/parallel.go +++ b/dispatcher/plugin/executable/parallel/parallel.go @@ -20,7 +20,7 @@ package parallel import ( "context" "github.com/IrineSistiana/mosdns/dispatcher/handler" - "github.com/IrineSistiana/mosdns/dispatcher/utils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/executable_seq" ) const PluginType = "parallel" @@ -34,17 +34,17 @@ var _ handler.ExecutablePlugin = (*parallel)(nil) type parallel struct { *handler.BP - ps *utils.ParallelECS + ps *executable_seq.ParallelECS } -type Args = utils.ParallelECSConfig +type Args = executable_seq.ParallelECSConfig func Init(bp *handler.BP, args interface{}) (p handler.Plugin, err error) { return newParallel(bp, args.(*Args)) } func newParallel(bp *handler.BP, args *Args) (*parallel, error) { - ps, err := utils.ParseParallelECS(args) + ps, err := executable_seq.ParseParallelECS(args) if err != nil { return nil, err } @@ -56,5 +56,5 @@ func newParallel(bp *handler.BP, args *Args) (*parallel, error) { } func (p *parallel) Exec(ctx context.Context, qCtx *handler.Context) (err error) { - return utils.WalkExecutableCmd(ctx, qCtx, p.L(), p.ps) + return executable_seq.WalkExecutableCmd(ctx, qCtx, p.L(), p.ps) } diff --git a/dispatcher/plugin/executable/sequence/sequence.go b/dispatcher/plugin/executable/sequence/sequence.go index eb33b873f..b2cfffe97 100644 --- a/dispatcher/plugin/executable/sequence/sequence.go +++ b/dispatcher/plugin/executable/sequence/sequence.go @@ -21,7 +21,7 @@ import ( "context" "fmt" "github.com/IrineSistiana/mosdns/dispatcher/handler" - "github.com/IrineSistiana/mosdns/dispatcher/utils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/executable_seq" ) const PluginType = "sequence" @@ -37,7 +37,7 @@ var _ handler.ExecutablePlugin = (*sequenceRouter)(nil) type sequenceRouter struct { *handler.BP - ecs *utils.ExecutableCmdSequence + ecs *executable_seq.ExecutableCmdSequence } type Args struct { @@ -49,7 +49,7 @@ func Init(bp *handler.BP, args interface{}) (p handler.Plugin, err error) { } func newSequencePlugin(bp *handler.BP, args *Args) (*sequenceRouter, error) { - ecs, err := utils.ParseExecutableCmdSequence(args.Exec) + ecs, err := executable_seq.ParseExecutableCmdSequence(args.Exec) if err != nil { return nil, fmt.Errorf("invalid exec squence: %w", err) } @@ -61,7 +61,7 @@ func newSequencePlugin(bp *handler.BP, args *Args) (*sequenceRouter, error) { } func (s *sequenceRouter) Exec(ctx context.Context, qCtx *handler.Context) (err error) { - return utils.WalkExecutableCmd(ctx, qCtx, s.L(), s.ecs) + return executable_seq.WalkExecutableCmd(ctx, qCtx, s.L(), s.ecs) } var _ handler.ExecutablePlugin = (*noop)(nil) diff --git a/dispatcher/plugin/executable/sleep/sleep.go b/dispatcher/plugin/executable/sleep/sleep.go index 803b1f603..c327cfa05 100644 --- a/dispatcher/plugin/executable/sleep/sleep.go +++ b/dispatcher/plugin/executable/sleep/sleep.go @@ -20,7 +20,7 @@ package sleep import ( "context" "github.com/IrineSistiana/mosdns/dispatcher/handler" - "github.com/IrineSistiana/mosdns/dispatcher/utils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/pool" "time" ) @@ -51,8 +51,8 @@ func (s *sleep) sleep(ctx context.Context) (err error) { return } - timer := utils.GetTimer(s.d) - defer utils.ReleaseTimer(timer) + timer := pool.GetTimer(s.d) + defer pool.ReleaseTimer(timer) select { case <-timer.C: case <-ctx.Done(): diff --git a/dispatcher/plugin/executable/ttl/ttl.go b/dispatcher/plugin/executable/ttl/ttl.go index 55cad8ff8..09973da88 100644 --- a/dispatcher/plugin/executable/ttl/ttl.go +++ b/dispatcher/plugin/executable/ttl/ttl.go @@ -20,7 +20,7 @@ package ttl import ( "context" "github.com/IrineSistiana/mosdns/dispatcher/handler" - "github.com/IrineSistiana/mosdns/dispatcher/utils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/dnsutils" "github.com/miekg/dns" ) @@ -64,9 +64,9 @@ func (t ttl) Exec(ctx context.Context, qCtx *handler.Context) (err error) { func (t ttl) exec(r *dns.Msg) { if t.args.MaximumTTL > 0 { - utils.ApplyMaximumTTL(r, t.args.MaximumTTL) + dnsutils.ApplyMaximumTTL(r, t.args.MaximumTTL) } if t.args.MinimalTTL > 0 { - utils.ApplyMinimalTTL(r, t.args.MinimalTTL) + dnsutils.ApplyMinimalTTL(r, t.args.MinimalTTL) } } diff --git a/dispatcher/plugin/hosts/hosts.go b/dispatcher/plugin/hosts/hosts.go index f6f718c74..2153d6c04 100644 --- a/dispatcher/plugin/hosts/hosts.go +++ b/dispatcher/plugin/hosts/hosts.go @@ -19,10 +19,9 @@ package hosts import ( "context" - "errors" "fmt" "github.com/IrineSistiana/mosdns/dispatcher/handler" - "github.com/IrineSistiana/mosdns/dispatcher/matcher/domain" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/domain" "github.com/miekg/dns" "net" ) @@ -59,10 +58,6 @@ var patternTypeMap = map[string]domain.MixMatcherPatternType{ } func newHostsContainer(bp *handler.BP, args *Args) (*hostsContainer, error) { - if len(args.Hosts) == 0 { - return nil, errors.New("no hosts file is configured") - } - mixMatcher := domain.NewMixMatcher() mixMatcher.SetPattenTypeMap(patternTypeMap) err := domain.BatchLoadMatcher(mixMatcher, args.Hosts, parseIP) @@ -90,8 +85,14 @@ func (h *hostsContainer) matchAndSet(qCtx *handler.Context) (matched bool) { if len(qCtx.Q().Question) != 1 { return false } - + if qCtx.Q().Question[0].Qclass != dns.ClassINET { + return false + } typ := qCtx.Q().Question[0].Qtype + if typ != dns.TypeA && typ != dns.TypeAAAA { + return false + } + fqdn := qCtx.Q().Question[0].Name v, ok := h.matcher.Match(fqdn) if !ok { @@ -99,50 +100,46 @@ func (h *hostsContainer) matchAndSet(qCtx *handler.Context) (matched bool) { } record := v.(*ipRecord) - switch typ { - case dns.TypeA: - if len(record.ipv4) != 0 { - r := new(dns.Msg) - r.SetReply(qCtx.Q()) - for _, ip := range record.ipv4 { - ipCopy := make(net.IP, len(ip)) - copy(ipCopy, ip) - rr := &dns.A{ - Hdr: dns.RR_Header{ - Name: fqdn, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: 3600, - }, - A: ipCopy, - } - r.Answer = append(r.Answer, rr) + switch { + case typ == dns.TypeA && len(record.ipv4) > 0: + r := new(dns.Msg) + r.SetReply(qCtx.Q()) + for _, ip := range record.ipv4 { + ipCopy := make(net.IP, len(ip)) + copy(ipCopy, ip) + rr := &dns.A{ + Hdr: dns.RR_Header{ + Name: fqdn, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 3600, + }, + A: ipCopy, } - qCtx.SetResponse(r, handler.ContextStatusResponded) - return true + r.Answer = append(r.Answer, rr) } - - case dns.TypeAAAA: - if len(record.ipv6) != 0 { - r := new(dns.Msg) - r.SetReply(qCtx.Q()) - for _, ip := range record.ipv6 { - ipCopy := make(net.IP, len(ip)) - copy(ipCopy, ip) - rr := &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: fqdn, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - Ttl: 3600, - }, - AAAA: ipCopy, - } - r.Answer = append(r.Answer, rr) + qCtx.SetResponse(r, handler.ContextStatusResponded) + return true + + case typ == dns.TypeAAAA && len(record.ipv6) > 0: + r := new(dns.Msg) + r.SetReply(qCtx.Q()) + for _, ip := range record.ipv6 { + ipCopy := make(net.IP, len(ip)) + copy(ipCopy, ip) + rr := &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: fqdn, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: 3600, + }, + AAAA: ipCopy, } - qCtx.SetResponse(r, handler.ContextStatusResponded) - return true + r.Answer = append(r.Answer, rr) } + qCtx.SetResponse(r, handler.ContextStatusResponded) + return true } return false } diff --git a/dispatcher/plugin/hosts/hosts_test.go b/dispatcher/plugin/hosts/hosts_test.go index 5a30ac527..3678fdab4 100644 --- a/dispatcher/plugin/hosts/hosts_test.go +++ b/dispatcher/plugin/hosts/hosts_test.go @@ -20,7 +20,7 @@ package hosts import ( "bytes" "github.com/IrineSistiana/mosdns/dispatcher/handler" - "github.com/IrineSistiana/mosdns/dispatcher/matcher/domain" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/domain" "github.com/miekg/dns" "net" "testing" diff --git a/dispatcher/plugin/matcher/query_matcher/matcher_group.go b/dispatcher/plugin/matcher/query_matcher/matcher_group.go index 152757680..513d8ac4b 100644 --- a/dispatcher/plugin/matcher/query_matcher/matcher_group.go +++ b/dispatcher/plugin/matcher/query_matcher/matcher_group.go @@ -21,10 +21,10 @@ import ( "context" "fmt" "github.com/IrineSistiana/mosdns/dispatcher/handler" - "github.com/IrineSistiana/mosdns/dispatcher/matcher/domain" - "github.com/IrineSistiana/mosdns/dispatcher/matcher/elem" - "github.com/IrineSistiana/mosdns/dispatcher/matcher/netlist" - "github.com/IrineSistiana/mosdns/dispatcher/utils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/domain" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/elem" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/netlist" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils" ) type clientIPMatcher struct { diff --git a/dispatcher/plugin/matcher/query_matcher/query_matcher.go b/dispatcher/plugin/matcher/query_matcher/query_matcher.go index 61cee17de..3190307d4 100644 --- a/dispatcher/plugin/matcher/query_matcher/query_matcher.go +++ b/dispatcher/plugin/matcher/query_matcher/query_matcher.go @@ -21,10 +21,10 @@ import ( "context" "fmt" "github.com/IrineSistiana/mosdns/dispatcher/handler" - "github.com/IrineSistiana/mosdns/dispatcher/matcher/domain" - "github.com/IrineSistiana/mosdns/dispatcher/matcher/elem" - "github.com/IrineSistiana/mosdns/dispatcher/matcher/netlist" - "github.com/IrineSistiana/mosdns/dispatcher/utils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/domain" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/elem" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/netlist" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils" "github.com/miekg/dns" ) @@ -73,15 +73,17 @@ func newQueryMatcher(bp *handler.BP, args *Args) (m *queryMatcher, err error) { m.args = args if len(args.ClientIP) > 0 { - ipMatcher, err := netlist.BatchLoad(args.ClientIP) + ipMatcher := netlist.NewList() + err := netlist.BatchLoad(ipMatcher, args.ClientIP) if err != nil { return nil, err } + ipMatcher.Sort() m.matcherGroup = append(m.matcherGroup, newClientIPMatcher(ipMatcher)) } if len(args.Domain) > 0 { mixMatcher := domain.NewMixMatcher() - err := domain.BatchLoadMixMatcherV2Matcher(mixMatcher, args.Domain) + err := domain.BatchLoadMatcher(mixMatcher, args.Domain, nil) if err != nil { return nil, err } diff --git a/dispatcher/plugin/matcher/response_matcher/matcher_group.go b/dispatcher/plugin/matcher/response_matcher/matcher_group.go index 4609f6b53..a1e2036ce 100644 --- a/dispatcher/plugin/matcher/response_matcher/matcher_group.go +++ b/dispatcher/plugin/matcher/response_matcher/matcher_group.go @@ -20,9 +20,9 @@ package responsematcher import ( "context" "github.com/IrineSistiana/mosdns/dispatcher/handler" - "github.com/IrineSistiana/mosdns/dispatcher/matcher/domain" - "github.com/IrineSistiana/mosdns/dispatcher/matcher/elem" - "github.com/IrineSistiana/mosdns/dispatcher/matcher/netlist" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/domain" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/elem" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/netlist" "github.com/miekg/dns" "net" ) diff --git a/dispatcher/plugin/matcher/response_matcher/response_matcher.go b/dispatcher/plugin/matcher/response_matcher/response_matcher.go index 1b4a33737..cd2c0a388 100644 --- a/dispatcher/plugin/matcher/response_matcher/response_matcher.go +++ b/dispatcher/plugin/matcher/response_matcher/response_matcher.go @@ -21,10 +21,10 @@ import ( "context" "fmt" "github.com/IrineSistiana/mosdns/dispatcher/handler" - "github.com/IrineSistiana/mosdns/dispatcher/matcher/domain" - "github.com/IrineSistiana/mosdns/dispatcher/matcher/elem" - "github.com/IrineSistiana/mosdns/dispatcher/matcher/netlist" - "github.com/IrineSistiana/mosdns/dispatcher/utils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/domain" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/elem" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/netlist" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils" "github.com/miekg/dns" ) @@ -71,7 +71,7 @@ func newResponseMatcher(bp *handler.BP, args *Args) (m *responseMatcher, err err if len(args.CNAME) > 0 { mixMatcher := domain.NewMixMatcher() - err := domain.BatchLoadMixMatcherV2Matcher(mixMatcher, args.CNAME) + err := domain.BatchLoadMatcher(mixMatcher, args.CNAME, nil) if err != nil { return nil, err } @@ -79,10 +79,12 @@ func newResponseMatcher(bp *handler.BP, args *Args) (m *responseMatcher, err err } if len(args.IP) > 0 { - ipMatcher, err := netlist.BatchLoad(args.IP) + ipMatcher := netlist.NewList() + err := netlist.BatchLoad(ipMatcher, args.IP) if err != nil { return nil, err } + ipMatcher.Sort() m.matcherGroup = append(m.matcherGroup, newResponseIPMatcher(ipMatcher)) } diff --git a/dispatcher/plugin/server/doh.go b/dispatcher/plugin/server/doh.go index 2fef01c82..3e759d1f4 100644 --- a/dispatcher/plugin/server/doh.go +++ b/dispatcher/plugin/server/doh.go @@ -23,7 +23,9 @@ import ( "errors" "fmt" "github.com/IrineSistiana/mosdns/dispatcher/handler" - "github.com/IrineSistiana/mosdns/dispatcher/utils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/dnsutils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/pool" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils" "github.com/miekg/dns" "go.uber.org/zap" "io" @@ -123,8 +125,8 @@ func getMsgFromReq(req *http.Request) (*dns.Msg, error) { if msgSize > dns.MaxMsgSize { return nil, fmt.Errorf("query length %d is too big", msgSize) } - msgBuf := utils.GetMsgBuf(msgSize) - defer utils.ReleaseMsgBuf(msgBuf) + msgBuf := pool.GetMsgBuf(msgSize) + defer pool.ReleaseMsgBuf(msgBuf) strBuf := readBufPool.Get() defer readBufPool.Release(strBuf) @@ -155,12 +157,12 @@ func getMsgFromReq(req *http.Request) (*dns.Msg, error) { return q, nil } -var readBufPool = utils.NewBytesBufPool(512) +var readBufPool = pool.NewBytesBufPool(512) type httpDnsRespWriter struct { httpRespWriter http.ResponseWriter } func (h *httpDnsRespWriter) Write(m *dns.Msg) (n int, err error) { - return utils.WriteMsgToUDP(h.httpRespWriter, m) + return dnsutils.WriteMsgToUDP(h.httpRespWriter, m) } diff --git a/dispatcher/plugin/server/server.go b/dispatcher/plugin/server/server.go index c1e935929..816dbd051 100644 --- a/dispatcher/plugin/server/server.go +++ b/dispatcher/plugin/server/server.go @@ -21,7 +21,9 @@ import ( "errors" "fmt" "github.com/IrineSistiana/mosdns/dispatcher/handler" - "github.com/IrineSistiana/mosdns/dispatcher/utils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/executable_seq" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/server_handler" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils" "io" "sync" "time" @@ -37,7 +39,7 @@ type ServerGroup struct { *handler.BP configs []*Server - handler utils.ServerHandler + handler server_handler.ServerHandler m sync.Mutex activated bool @@ -98,12 +100,12 @@ func newServerPlugin(bp *handler.BP, args *Args) (*ServerGroup, error) { return nil, errors.New("empty entry") } - ecs, err := utils.ParseExecutableCmdSequence(args.Entry) + ecs, err := executable_seq.ParseExecutableCmdSequence(args.Entry) if err != nil { return nil, err } - sh := utils.NewDefaultServerHandler(&utils.DefaultServerHandlerConfig{ + sh := server_handler.NewDefaultServerHandler(&server_handler.DefaultServerHandlerConfig{ Logger: bp.L(), Entry: ecs, ConcurrentLimit: args.MaxConcurrentQueries, @@ -123,7 +125,7 @@ func newServerPlugin(bp *handler.BP, args *Args) (*ServerGroup, error) { return sg, nil } -func NewServerGroup(bp *handler.BP, handler utils.ServerHandler, configs []*Server) *ServerGroup { +func NewServerGroup(bp *handler.BP, handler server_handler.ServerHandler, configs []*Server) *ServerGroup { s := &ServerGroup{ BP: bp, configs: configs, diff --git a/dispatcher/plugin/server/server_test.go b/dispatcher/plugin/server/server_test.go index e98478200..499dd8b12 100644 --- a/dispatcher/plugin/server/server_test.go +++ b/dispatcher/plugin/server/server_test.go @@ -20,7 +20,7 @@ package server import ( "github.com/AdguardTeam/dnsproxy/upstream" "github.com/IrineSistiana/mosdns/dispatcher/handler" - "github.com/IrineSistiana/mosdns/dispatcher/utils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/server_handler" "github.com/miekg/dns" "testing" "time" @@ -46,7 +46,7 @@ func TestUdpServer_ListenAndServe(t *testing.T) { return } func() { - sg := NewServerGroup(handler.NewBP("test", PluginType), &utils.DummyServerHandler{T: t}, []*Server{tt.config}) + sg := NewServerGroup(handler.NewBP("test", PluginType), &server_handler.DummyServerHandler{T: t}, []*Server{tt.config}) if err := sg.Activate(); err != nil { t.Fatal(err) } diff --git a/dispatcher/plugin/server/tcp.go b/dispatcher/plugin/server/tcp.go index ba35d1c0b..1221f2c33 100644 --- a/dispatcher/plugin/server/tcp.go +++ b/dispatcher/plugin/server/tcp.go @@ -23,7 +23,7 @@ import ( "errors" "fmt" "github.com/IrineSistiana/mosdns/dispatcher/handler" - "github.com/IrineSistiana/mosdns/dispatcher/utils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/dnsutils" "github.com/miekg/dns" "go.uber.org/zap" "net" @@ -40,7 +40,7 @@ type tcpResponseWriter struct { func (t *tcpResponseWriter) Write(m *dns.Msg) (n int, err error) { t.c.SetWriteDeadline(time.Now().Add(serverTCPWriteTimeout)) - return utils.WriteMsgToTCP(t.c, m) + return dnsutils.WriteMsgToTCP(t.c, m) } // remainder: startTCP should be called only after ServerGroup is locked. @@ -93,7 +93,7 @@ func (sg *ServerGroup) startTCP(conf *Server, isDoT bool) error { for { c.SetReadDeadline(time.Now().Add(conf.idleTimeout)) - q, _, err := utils.ReadMsgFromTCP(c) + q, _, err := dnsutils.ReadMsgFromTCP(c) if err != nil { return // read err, close the conn } diff --git a/dispatcher/plugin/server/udp.go b/dispatcher/plugin/server/udp.go index cab4ee0fa..f2753f051 100644 --- a/dispatcher/plugin/server/udp.go +++ b/dispatcher/plugin/server/udp.go @@ -21,7 +21,7 @@ import ( "context" "fmt" "github.com/IrineSistiana/mosdns/dispatcher/handler" - "github.com/IrineSistiana/mosdns/dispatcher/utils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/dnsutils" "github.com/miekg/dns" "go.uber.org/zap" "net" @@ -44,7 +44,7 @@ func getMaxSizeFromQuery(m *dns.Msg) int { func (u *udpResponseWriter) Write(m *dns.Msg) (n int, err error) { m.Truncate(u.maxSize) - return utils.WriteUDPMsgTo(m, u.c, u.to) + return dnsutils.WriteUDPMsgTo(m, u.c, u.to) } // remainder: startUDP should be called only after ServerGroup is locked. @@ -63,9 +63,9 @@ func (sg *ServerGroup) startUDP(conf *Server) error { defer cancel() for { - q, from, _, err := utils.ReadUDPMsgFrom(c, utils.IPv4UdpMaxPayload) + q, from, _, err := dnsutils.ReadUDPMsgFrom(c, dnsutils.IPv4UdpMaxPayload) if err != nil { - if ioErr := utils.IsIOErr(err); ioErr != nil { + if ioErr := dnsutils.IsIOErr(err); ioErr != nil { if netErr, ok := ioErr.(net.Error); ok && netErr.Temporary() { // is a temporary net err sg.L().Warn("listener temporary err", zap.Stringer("addr", c.LocalAddr()), zap.Error(err)) time.Sleep(time.Second * 5) diff --git a/dispatcher/utils/executable_sequence.go b/dispatcher/utils/executable_sequence.go deleted file mode 100644 index ca6a02b77..000000000 --- a/dispatcher/utils/executable_sequence.go +++ /dev/null @@ -1,631 +0,0 @@ -// Copyright (C) 2020-2021, IrineSistiana -// -// This file is part of mosdns. -// -// mosdns is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// mosdns is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with this program. If not, see . - -package utils - -import ( - "context" - "errors" - "fmt" - "github.com/IrineSistiana/mosdns/dispatcher/handler" - "github.com/miekg/dns" - "go.uber.org/zap" - "reflect" - "strings" - "sync" - "time" -) - -type ExecutableCmd interface { - ExecCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (goTwo string, earlyStop bool, err error) -} - -type executablePluginTag struct { - s string -} - -func (t executablePluginTag) ExecCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (goTwo string, earlyStop bool, err error) { - p, err := handler.GetPlugin(t.s) - if err != nil { - return "", false, err - } - - logger.Debug("exec executable plugin", qCtx.InfoField(), zap.String("exec", t.s)) - earlyStop, err = p.ExecES(ctx, qCtx) - return "", earlyStop, err -} - -type IfBlockConfig struct { - If []string `yaml:"if"` - IfAnd []string `yaml:"if_and"` - Exec []interface{} `yaml:"exec"` - Goto string `yaml:"goto"` -} - -type matcher struct { - tag string - negate bool -} - -func paresMatcher(s []string) []matcher { - m := make([]matcher, 0, len(s)) - for _, tag := range s { - if strings.HasPrefix(tag, "!") { - m = append(m, matcher{tag: strings.TrimPrefix(tag, "!"), negate: true}) - } else { - m = append(m, matcher{tag: tag}) - } - } - return m -} - -type IfBlock struct { - ifMatcher []matcher - ifAndMatcher []matcher - executableCmd ExecutableCmd - goTwo string -} - -func (b *IfBlock) ExecCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (goTwo string, earlyStop bool, err error) { - if len(b.ifMatcher) > 0 { - If, err := ifCondition(ctx, qCtx, logger, b.ifMatcher, false) - if err != nil { - return "", false, err - } - if If == false { - return "", false, nil // if case returns false, skip this block. - } - } - - if len(b.ifAndMatcher) > 0 { - If, err := ifCondition(ctx, qCtx, logger, b.ifAndMatcher, true) - if err != nil { - return "", false, err - } - if If == false { - return "", false, nil - } - } - - // exec - if b.executableCmd != nil { - goTwo, earlyStop, err = b.executableCmd.ExecCmd(ctx, qCtx, logger) - if err != nil { - return "", false, err - } - if len(goTwo) != 0 || earlyStop { - return goTwo, earlyStop, nil - } - } - - // goto - if len(b.goTwo) != 0 { // if block has a goto, return it - return b.goTwo, false, nil - } - - return "", false, nil -} - -func ifCondition(ctx context.Context, qCtx *handler.Context, logger *zap.Logger, p []matcher, isAnd bool) (ok bool, err error) { - if len(p) == 0 { - return false, err - } - - for _, m := range p { - mp, err := handler.GetPlugin(m.tag) - if err != nil { - return false, err - } - matched, err := mp.Match(ctx, qCtx) - if err != nil { - return false, err - } - logger.Debug("exec matcher plugin", qCtx.InfoField(), zap.String("exec", m.tag), zap.Bool("result", matched)) - - res := matched != m.negate - if !isAnd && res == true { - return true, nil // or: if one of the case is true, skip others. - } - if isAnd && res == false { - return false, nil // and: if one of the case is false, skip others. - } - - ok = res - } - return ok, nil -} - -func ParseIfBlock(in map[string]interface{}) (*IfBlock, error) { - c := new(IfBlockConfig) - err := handler.WeakDecode(in, c) - if err != nil { - return nil, err - } - - b := &IfBlock{ - ifMatcher: paresMatcher(c.If), - ifAndMatcher: paresMatcher(c.IfAnd), - goTwo: c.Goto, - } - - if len(c.Exec) != 0 { - ecs, err := ParseExecutableCmdSequence(c.Exec) - if err != nil { - return nil, err - } - b.executableCmd = ecs - } - - return b, nil -} - -type ParallelECS struct { - s []*ExecutableCmdSequence - timeout time.Duration -} - -type ParallelECSConfig struct { - Parallel [][]interface{} `yaml:"parallel"` - Timeout uint `yaml:"timeout"` -} - -func ParseParallelECS(c *ParallelECSConfig) (*ParallelECS, error) { - if len(c.Parallel) < 2 { - return nil, fmt.Errorf("parallel needs at least 2 cmd sequences, but got %d", len(c.Parallel)) - } - - ps := make([]*ExecutableCmdSequence, 0, len(c.Parallel)) - for i, subSequence := range c.Parallel { - es, err := ParseExecutableCmdSequence(subSequence) - if err != nil { - return nil, fmt.Errorf("invalid parallel sequence at index %d: %w", i, err) - } - ps = append(ps, es) - } - return &ParallelECS{s: ps, timeout: time.Duration(c.Timeout) * time.Second}, nil -} - -type parallelECSResult struct { - r *dns.Msg - status handler.ContextStatus - err error - from int -} - -func (p *ParallelECS) ExecCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (goTwo string, earlyStop bool, err error) { - return "", false, p.execCmd(ctx, qCtx, logger) -} - -func (p *ParallelECS) execCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) { - - var pCtx context.Context // only valid if p.timeout == 0 - var cancel func() - if p.timeout == 0 { - pCtx, cancel = context.WithCancel(ctx) - defer cancel() - } - - t := len(p.s) - c := make(chan *parallelECSResult, len(p.s)) // use buf chan to avoid block. - - for i, sequence := range p.s { - i := i - sequence := sequence - qCtxCopy := qCtx.Copy() - - go func() { - var ecsCtx context.Context - var ecsCancel func() - if p.timeout == 0 { - ecsCtx = pCtx - } else { - ecsCtx, ecsCancel = context.WithTimeout(context.Background(), p.timeout) - defer ecsCancel() - } - - err := WalkExecutableCmd(ecsCtx, qCtxCopy, logger, sequence) - if err == nil { - err = qCtxCopy.ExecDefer(pCtx) - } - c <- ¶llelECSResult{ - r: qCtxCopy.R(), - status: qCtxCopy.Status(), - err: err, - from: i, - } - }() - } - - return asyncWait(ctx, qCtx, logger, c, t) -} - -type FallbackConfig struct { - // Primary exec sequence, must have at least one element. - Primary []interface{} `yaml:"primary"` - // Secondary exec sequence, must have at least one element. - Secondary []interface{} `yaml:"secondary"` - - StatLength int `yaml:"stat_length"` // An Zero value disables the (normal) fallback. - Threshold int `yaml:"threshold"` - - // FastFallback threshold in milliseconds. Zero means fast fallback is disabled. - FastFallback int `yaml:"fast_fallback"` - - // AlwaysStandby: secondary should always standby in fast fallback. - AlwaysStandby bool `yaml:"always_standby"` -} - -type FallbackECS struct { - primary *ExecutableCmdSequence - secondary *ExecutableCmdSequence - fastFallbackDuration time.Duration - alwaysStandby bool - - primaryST *statusTracker // nil if normal fallback is disabled -} - -type statusTracker struct { - sync.Mutex - threshold int - status []uint8 // 0 means success, !0 means failed - p int -} - -func newStatusTracker(threshold, statLength int) *statusTracker { - return &statusTracker{ - threshold: threshold, - status: make([]uint8, statLength), - } -} - -func (t *statusTracker) good() bool { - t.Lock() - defer t.Unlock() - - var failedSum int - for _, s := range t.status { - if s != 0 { - failedSum++ - } - } - return failedSum < t.threshold -} - -func (t *statusTracker) update(s uint8) { - t.Lock() - defer t.Unlock() - - if t.p >= len(t.status) { - t.p = 0 - } - t.status[t.p] = s - t.p++ -} - -func ParseFallbackECS(c *FallbackConfig) (*FallbackECS, error) { - if len(c.Primary) == 0 { - return nil, errors.New("primary sequence is empty") - } - if len(c.Secondary) == 0 { - return nil, errors.New("secondary sequence is empty") - } - - primaryECS, err := ParseExecutableCmdSequence(c.Primary) - if err != nil { - return nil, fmt.Errorf("invalid primary sequence: %w", err) - } - - secondaryECS, err := ParseExecutableCmdSequence(c.Secondary) - if err != nil { - return nil, fmt.Errorf("invalid secondary sequence: %w", err) - } - - fallbackECS := &FallbackECS{ - primary: primaryECS, - secondary: secondaryECS, - fastFallbackDuration: time.Duration(c.FastFallback) * time.Millisecond, - alwaysStandby: c.AlwaysStandby, - } - - if c.StatLength > 0 { - if c.Threshold > c.StatLength { - c.Threshold = c.StatLength - } - fallbackECS.primaryST = newStatusTracker(c.Threshold, c.StatLength) - } - - return fallbackECS, nil -} - -func (f *FallbackECS) ExecCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (goTwo string, earlyStop bool, err error) { - return "", false, f.execCmd(ctx, qCtx, logger) -} - -func (f *FallbackECS) execCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) { - if f.primaryST == nil || f.primaryST.good() { - if f.fastFallbackDuration > 0 { - return f.doFastFallback(ctx, qCtx, logger) - } else { - return f.isolateDoPrimary(ctx, qCtx, logger) - } - } - logger.Debug("primary is not good", qCtx.InfoField()) - return f.doFallback(ctx, qCtx, logger) -} - -func (f *FallbackECS) isolateDoPrimary(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) { - qCtxCopy := qCtx.Copy() - err = f.doPrimary(ctx, qCtxCopy, logger) - qCtx.SetResponse(qCtxCopy.R(), qCtxCopy.Status()) - return err -} - -func (f *FallbackECS) doPrimary(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) { - err = WalkExecutableCmd(ctx, qCtx, logger, f.primary) - if err == nil { - err = qCtx.ExecDefer(ctx) - } - if f.primaryST != nil { - if err != nil || qCtx.R() == nil { - f.primaryST.update(1) - } else { - f.primaryST.update(0) - } - } - - return err -} - -func (f *FallbackECS) doFastFallback(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) { - fCtx, cancel := context.WithCancel(ctx) - defer cancel() - - timer := GetTimer(f.fastFallbackDuration) - defer ReleaseTimer(timer) - - c := make(chan *parallelECSResult, 2) - primFailed := make(chan struct{}) // will be closed if primary returns an err. - - qCtxCopyP := qCtx.Copy() - go func() { - err := f.doPrimary(fCtx, qCtxCopyP, logger) - if err != nil || qCtxCopyP.R() == nil { - close(primFailed) - } - c <- ¶llelECSResult{ - r: qCtxCopyP.R(), - status: qCtxCopyP.Status(), - err: err, - from: 1, - } - }() - - qCtxCopyS := qCtx.Copy() - go func() { - if !f.alwaysStandby { // not always standby, wait here. - select { - case <-fCtx.Done(): // primary is done, no needs to exec this. - return - case <-primFailed: // primary failed or timeout, exec now. - case <-timer.C: - } - } - - err := f.doSecondary(fCtx, qCtxCopyS, logger) - res := ¶llelECSResult{ - r: qCtxCopyS.R(), - status: qCtxCopyS.Status(), - err: err, - from: 2, - } - - if f.alwaysStandby { // always standby - select { - case <-fCtx.Done(): - return - case <-primFailed: // only send secondary result when primary is failed. - c <- res - case <-timer.C: // or timeout. - c <- res - } - } else { - c <- res // not always standby, send the result asap. - } - }() - - return asyncWait(ctx, qCtx, logger, c, 2) -} - -func (f *FallbackECS) doSecondary(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) { - err = WalkExecutableCmd(ctx, qCtx, logger, f.secondary) - if err == nil { - err = qCtx.ExecDefer(ctx) - } - return err -} - -func (f *FallbackECS) doFallback(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) error { - fCtx, cancel := context.WithCancel(ctx) - defer cancel() - - c := make(chan *parallelECSResult, 2) // buf size is 2, avoid block. - - qCtxCopyP := qCtx.Copy() - go func() { - err := f.doPrimary(fCtx, qCtxCopyP, logger) - c <- ¶llelECSResult{ - r: qCtxCopyP.R(), - status: qCtxCopyP.Status(), - err: err, - from: 1, - } - }() - - qCtxCopyS := qCtx.Copy() - go func() { - err := WalkExecutableCmd(fCtx, qCtxCopyS, logger, f.secondary) - if err == nil { - err = qCtxCopyS.ExecDefer(fCtx) - } - c <- ¶llelECSResult{ - r: qCtxCopyS.R(), - status: qCtxCopyS.Status(), - err: err, - from: 2, - } - }() - - return asyncWait(ctx, qCtx, logger, c, 2) -} - -func asyncWait(ctx context.Context, qCtx *handler.Context, logger *zap.Logger, c chan *parallelECSResult, total int) error { - for i := 0; i < total; i++ { - select { - case res := <-c: - if res.err != nil { - logger.Warn("sequence failed", qCtx.InfoField(), zap.Int("sequence", res.from), zap.Error(res.err)) - continue - } - - if res.r == nil { - logger.Debug("sequence returned with an empty response", qCtx.InfoField(), zap.Int("sequence", res.from)) - continue - } - - logger.Debug("sequence returned a response", qCtx.InfoField(), zap.Int("sequence", res.from)) - qCtx.SetResponse(res.r, res.status) - return nil - - case <-ctx.Done(): - return ctx.Err() - } - } - - // No response - qCtx.SetResponse(nil, handler.ContextStatusServerFailed) - return errors.New("no response") -} - -type ExecutableCmdSequence struct { - c []ExecutableCmd -} - -func ParseExecutableCmdSequence(in []interface{}) (*ExecutableCmdSequence, error) { - es := &ExecutableCmdSequence{c: make([]ExecutableCmd, 0, len(in))} - for i, v := range in { - ec, err := parseExecutableCmd(v) - if err != nil { - return nil, fmt.Errorf("invalid cmd #%d: %w", i, err) - } - es.c = append(es.c, ec) - } - return es, nil -} - -func parseExecutableCmd(in interface{}) (ExecutableCmd, error) { - switch v := in.(type) { - case string: - return &executablePluginTag{s: v}, nil - case map[string]interface{}: - switch { - case hasKey(v, "if") || hasKey(v, "if_and"): // if block - ec, err := ParseIfBlock(v) - if err != nil { - return nil, fmt.Errorf("invalid if section: %w", err) - } - return ec, nil - case hasKey(v, "parallel"): // parallel - ec, err := parseParallelECS(v) - if err != nil { - return nil, fmt.Errorf("invalid parallel section: %w", err) - } - return ec, nil - case hasKey(v, "primary"): - ec, err := parseFallbackECS(v) - if err != nil { - return nil, fmt.Errorf("invalid fallback section: %w", err) - } - return ec, nil - default: - return nil, errors.New("unknown section") - } - default: - return nil, fmt.Errorf("unexpected type: %s", reflect.TypeOf(in).String()) - } -} - -func parseParallelECS(m map[string]interface{}) (ec ExecutableCmd, err error) { - conf := new(ParallelECSConfig) - err = handler.WeakDecode(m, conf) - if err != nil { - return nil, err - } - return ParseParallelECS(conf) -} - -func parseFallbackECS(m map[string]interface{}) (ec ExecutableCmd, err error) { - conf := new(FallbackConfig) - err = handler.WeakDecode(m, conf) - if err != nil { - return nil, err - } - return ParseFallbackECS(conf) -} - -func hasKey(m map[string]interface{}, key string) bool { - _, ok := m[key] - return ok -} - -// ExecCmd executes the sequence. -func (es *ExecutableCmdSequence) ExecCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (goTwo string, earlyStop bool, err error) { - for _, cmd := range es.c { - goTwo, earlyStop, err = cmd.ExecCmd(ctx, qCtx, logger) - if err != nil { - return "", false, err - } - if len(goTwo) != 0 || earlyStop { - return goTwo, earlyStop, nil - } - } - - return "", false, nil -} - -func (es *ExecutableCmdSequence) Len() int { - return len(es.c) -} - -// WalkExecutableCmd executes the ExecutableCmd, include its `goto`. -// This should only be used in root cmd node. -func WalkExecutableCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger, entry ExecutableCmd) (err error) { - goTwo, _, err := entry.ExecCmd(ctx, qCtx, logger) - if err != nil { - return err - } - - if len(goTwo) != 0 { - logger.Debug("goto plugin", qCtx.InfoField(), zap.String("goto", goTwo)) - p, err := handler.GetPlugin(goTwo) - if err != nil { - return err - } - _, err = p.ExecES(ctx, qCtx) - return err - } - return nil -} diff --git a/tools/tools.go b/tools/tools.go index 77c0b95bf..ae77b926f 100644 --- a/tools/tools.go +++ b/tools/tools.go @@ -20,11 +20,11 @@ package tools import ( "crypto/tls" "fmt" - "github.com/IrineSistiana/mosdns/dispatcher/matcher/domain" - "github.com/IrineSistiana/mosdns/dispatcher/matcher/netlist" - "github.com/IrineSistiana/mosdns/dispatcher/matcher/v2data" "github.com/IrineSistiana/mosdns/dispatcher/mlog" - "github.com/IrineSistiana/mosdns/dispatcher/utils" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/domain" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/netlist" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/v2data" + "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils" "github.com/miekg/dns" "io" "net" @@ -113,10 +113,12 @@ func ProbServerTimeout(addr string) error { } func BenchIPMatcher(f string) error { - list, err := netlist.NewListFromFile(f) + list := netlist.NewList() + err := netlist.LoadFromFile(list, f) if err != nil { return err } + list.Sort() ip := net.IPv4(8, 8, 8, 8).To4() @@ -135,7 +137,7 @@ func BenchIPMatcher(f string) error { func BenchDomainMatcher(f string) error { matcher := domain.NewMixMatcher() - err := domain.LoadFromFileAsV2Matcher(matcher, f) + err := domain.LoadFromFile(matcher, f, nil) if err != nil { return err }