diff --git a/dispatcher/coremain/run.go b/dispatcher/coremain/run.go
index 8546ac99b..9cf3d2d89 100644
--- a/dispatcher/coremain/run.go
+++ b/dispatcher/coremain/run.go
@@ -21,8 +21,8 @@ import (
"fmt"
"github.com/IrineSistiana/mosdns/dispatcher/handler"
"github.com/IrineSistiana/mosdns/dispatcher/mlog"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/concurrent_limiter"
_ "github.com/IrineSistiana/mosdns/dispatcher/plugin"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"os"
@@ -93,7 +93,7 @@ func loadConfig(f string, depth int) {
if n < 1 {
n = 1
}
- pool := utils.NewConcurrentLimiter(n)
+ pool := concurrent_limiter.NewConcurrentLimiter(n)
wg := new(sync.WaitGroup)
for i, pluginConfig := range c.Plugin {
if len(pluginConfig.Tag) == 0 || len(pluginConfig.Type) == 0 {
diff --git a/dispatcher/pkg/concurrent_limiter/client_limiter.go b/dispatcher/pkg/concurrent_limiter/client_limiter.go
new file mode 100644
index 000000000..2951e0b47
--- /dev/null
+++ b/dispatcher/pkg/concurrent_limiter/client_limiter.go
@@ -0,0 +1,69 @@
+// Copyright (C) 2020-2021, IrineSistiana
+//
+// This file is part of mosdns.
+//
+// mosdns is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// mosdns is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see .
+
+package concurrent_limiter
+
+import (
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/concurrent_map"
+)
+
+type ClientQueryLimiter struct {
+ maxQueries int
+ m *concurrent_map.ConcurrentMap
+}
+
+func NewClientQueryLimiter(maxQueries int) *ClientQueryLimiter {
+ return &ClientQueryLimiter{
+ maxQueries: maxQueries,
+ m: concurrent_map.NewConcurrentMap(64),
+ }
+}
+
+func (l *ClientQueryLimiter) Acquire(key string) bool {
+ return l.m.TestAndSet(key, l.acquireTestAndSet)
+}
+
+func (l *ClientQueryLimiter) acquireTestAndSet(v interface{}, ok bool) (newV interface{}, wantUpdate, passed bool) {
+ n := 0
+ if ok {
+ n = v.(int)
+ }
+ if n >= l.maxQueries {
+ return nil, false, false
+ }
+ n++
+ return n, true, true
+}
+
+func (l *ClientQueryLimiter) doneTestAndSet(v interface{}, ok bool) (newV interface{}, wantUpdate, passed bool) {
+ if !ok {
+ panic("ClientQueryLimiter doneTestAndSet: value is not exist")
+ }
+ n := v.(int)
+ n--
+ if n < 0 {
+ panic("ClientQueryLimiter doneTestAndSet: value becomes negative")
+ }
+ if n == 0 {
+ return nil, true, true
+ }
+ return n, true, true
+}
+
+func (l *ClientQueryLimiter) Done(key string) {
+ l.m.TestAndSet(key, l.doneTestAndSet)
+}
diff --git a/dispatcher/utils/server_handler_test.go b/dispatcher/pkg/concurrent_limiter/client_limiter_test.go
similarity index 97%
rename from dispatcher/utils/server_handler_test.go
rename to dispatcher/pkg/concurrent_limiter/client_limiter_test.go
index 517faba97..2c7a779e1 100644
--- a/dispatcher/utils/server_handler_test.go
+++ b/dispatcher/pkg/concurrent_limiter/client_limiter_test.go
@@ -1,4 +1,4 @@
-package utils
+package concurrent_limiter
import (
"strconv"
diff --git a/dispatcher/pkg/concurrent_limiter/concurrent_limiter.go b/dispatcher/pkg/concurrent_limiter/concurrent_limiter.go
new file mode 100644
index 000000000..07de47cba
--- /dev/null
+++ b/dispatcher/pkg/concurrent_limiter/concurrent_limiter.go
@@ -0,0 +1,55 @@
+// Copyright (C) 2020-2021, IrineSistiana
+//
+// This file is part of mosdns.
+//
+// mosdns is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// mosdns is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see .
+
+package concurrent_limiter
+
+import "fmt"
+
+// ConcurrentLimiter is a soft limiter.
+type ConcurrentLimiter struct {
+ bucket chan struct{}
+}
+
+// NewConcurrentLimiter returns a ConcurrentLimiter, max must > 0.
+func NewConcurrentLimiter(max int) *ConcurrentLimiter {
+ if max <= 0 {
+ panic(fmt.Sprintf("ConcurrentLimiter: invalid max arg: %d", max))
+ }
+
+ bucket := make(chan struct{}, max)
+ for i := 0; i < max; i++ {
+ bucket <- struct{}{}
+ }
+
+ return &ConcurrentLimiter{bucket: bucket}
+}
+
+func (l *ConcurrentLimiter) Wait() <-chan struct{} {
+ return l.bucket
+}
+
+func (l *ConcurrentLimiter) Done() {
+ select {
+ case l.bucket <- struct{}{}:
+ default:
+ panic("ConcurrentLimiter: bucket overflow")
+ }
+}
+
+func (l *ConcurrentLimiter) Available() int {
+ return len(l.bucket)
+}
diff --git a/dispatcher/pkg/concurrent_limiter/concurrent_limiter_test.go b/dispatcher/pkg/concurrent_limiter/concurrent_limiter_test.go
new file mode 100644
index 000000000..4e56b751f
--- /dev/null
+++ b/dispatcher/pkg/concurrent_limiter/concurrent_limiter_test.go
@@ -0,0 +1,34 @@
+package concurrent_limiter
+
+import (
+ "context"
+ "sync"
+ "testing"
+ "time"
+)
+
+func Test_ConcurrentLimiter_acquire_release(t *testing.T) {
+ l := NewConcurrentLimiter(500)
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
+ defer cancel()
+
+ wg := new(sync.WaitGroup)
+ wg.Add(1000)
+ for i := 0; i < 1000; i++ {
+ go func() {
+ defer wg.Done()
+ select {
+ case <-l.Wait():
+ time.Sleep(time.Millisecond * 200)
+ l.Done()
+ case <-ctx.Done():
+ t.Fail()
+ }
+ }()
+ }
+
+ wg.Wait()
+ if l.Available() != 500 {
+ t.Fatal("token leaked")
+ }
+}
diff --git a/dispatcher/utils/concurrent_lru.go b/dispatcher/pkg/concurrent_lru/concurrent_lru.go
similarity index 95%
rename from dispatcher/utils/concurrent_lru.go
rename to dispatcher/pkg/concurrent_lru/concurrent_lru.go
index 6ba85f094..a8672f7aa 100644
--- a/dispatcher/utils/concurrent_lru.go
+++ b/dispatcher/pkg/concurrent_lru/concurrent_lru.go
@@ -15,9 +15,10 @@
// You should have received a copy of the GNU General Public License
// along with this program. If not, see .
-package utils
+package concurrent_lru
import (
+ lru2 "github.com/IrineSistiana/mosdns/dispatcher/pkg/lru"
"hash/maphash"
"sync"
)
@@ -40,7 +41,7 @@ func NewConcurrentLRU(
for i := range cl.l {
cl.l[i] = &shardedLRU{
onGet: onGet,
- lru: NewLRU(maxSizePerShard, onEvict),
+ lru: lru2.NewLRU(maxSizePerShard, onEvict),
}
}
@@ -95,7 +96,7 @@ type shardedLRU struct {
onGet func(key string, v interface{}) interface{}
sync.Mutex
- lru *LRU
+ lru *lru2.LRU
}
func (sl *shardedLRU) Add(key string, v interface{}) {
diff --git a/dispatcher/utils/concurrent_lru_test.go b/dispatcher/pkg/concurrent_lru/concurrent_lru_test.go
similarity index 98%
rename from dispatcher/utils/concurrent_lru_test.go
rename to dispatcher/pkg/concurrent_lru/concurrent_lru_test.go
index 163221efe..2f74595f8 100644
--- a/dispatcher/utils/concurrent_lru_test.go
+++ b/dispatcher/pkg/concurrent_lru/concurrent_lru_test.go
@@ -1,4 +1,4 @@
-package utils
+package concurrent_lru
import (
"reflect"
diff --git a/dispatcher/utils/concurrent_map.go b/dispatcher/pkg/concurrent_map/concurrent_map.go
similarity index 99%
rename from dispatcher/utils/concurrent_map.go
rename to dispatcher/pkg/concurrent_map/concurrent_map.go
index e6bb087f8..68f84c12f 100644
--- a/dispatcher/utils/concurrent_map.go
+++ b/dispatcher/pkg/concurrent_map/concurrent_map.go
@@ -15,7 +15,7 @@
// You should have received a copy of the GNU General Public License
// along with this program. If not, see .
-package utils
+package concurrent_map
import (
"hash/maphash"
diff --git a/dispatcher/utils/concurrent_map_test.go b/dispatcher/pkg/concurrent_map/concurrent_map_test.go
similarity index 99%
rename from dispatcher/utils/concurrent_map_test.go
rename to dispatcher/pkg/concurrent_map/concurrent_map_test.go
index 3fe1050fb..6f256fafd 100644
--- a/dispatcher/utils/concurrent_map_test.go
+++ b/dispatcher/pkg/concurrent_map/concurrent_map_test.go
@@ -1,4 +1,4 @@
-package utils
+package concurrent_map
import (
"strconv"
diff --git a/dispatcher/utils/net_io.go b/dispatcher/pkg/dnsutils/net_io.go
similarity index 90%
rename from dispatcher/utils/net_io.go
rename to dispatcher/pkg/dnsutils/net_io.go
index e011894e7..3f867a8be 100644
--- a/dispatcher/utils/net_io.go
+++ b/dispatcher/pkg/dnsutils/net_io.go
@@ -15,11 +15,12 @@
// You should have received a copy of the GNU General Public License
// along with this program. If not, see .
-package utils
+package dnsutils
import (
"encoding/binary"
"fmt"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/pool"
"github.com/miekg/dns"
"io"
"net"
@@ -60,8 +61,8 @@ func (e *IOErr) Unwrap() error {
// An io err will be wrapped into an IOErr.
// IsIOErr(err) can check and unwrap the inner io err.
func ReadUDPMsgFrom(c net.PacketConn, bufSize int) (m *dns.Msg, from net.Addr, n int, err error) {
- buf := GetMsgBuf(bufSize)
- defer ReleaseMsgBuf(buf)
+ buf := pool.GetMsgBuf(bufSize)
+ defer pool.ReleaseMsgBuf(buf)
n, from, err = c.ReadFrom(buf)
if err != nil {
@@ -84,8 +85,8 @@ func ReadUDPMsgFrom(c net.PacketConn, bufSize int) (m *dns.Msg, from net.Addr, n
// ReadMsgFromUDP See ReadUDPMsgFrom.
func ReadMsgFromUDP(c io.Reader, bufSize int) (m *dns.Msg, n int, err error) {
- buf := GetMsgBuf(bufSize)
- defer ReleaseMsgBuf(buf)
+ buf := pool.GetMsgBuf(bufSize)
+ defer pool.ReleaseMsgBuf(buf)
n, err = c.Read(buf)
if err != nil {
@@ -108,11 +109,11 @@ func ReadMsgFromUDP(c io.Reader, bufSize int) (m *dns.Msg, n int, err error) {
// An io err will be wrapped into an IOErr.
// IsIOErr(err) can check and unwrap the inner io err.
func WriteMsgToUDP(c io.Writer, m *dns.Msg) (n int, err error) {
- mRaw, buf, err := PackBuffer(m)
+ mRaw, buf, err := pool.PackBuffer(m)
if err != nil {
return 0, err
}
- defer ReleaseMsgBuf(buf)
+ defer pool.ReleaseMsgBuf(buf)
return WriteRawMsgToUDP(c, mRaw)
}
@@ -128,11 +129,11 @@ func WriteRawMsgToUDP(c io.Writer, b []byte) (n int, err error) {
// WriteUDPMsgTo See WriteMsgToUDP.
func WriteUDPMsgTo(m *dns.Msg, c net.PacketConn, to net.Addr) (n int, err error) {
- mRaw, buf, err := PackBuffer(m)
+ mRaw, buf, err := pool.PackBuffer(m)
if err != nil {
return 0, err
}
- defer ReleaseMsgBuf(buf)
+ defer pool.ReleaseMsgBuf(buf)
n, err = c.WriteTo(mRaw, to)
if err != nil {
@@ -161,8 +162,8 @@ func ReadMsgFromTCP(c io.Reader) (m *dns.Msg, n int, err error) {
return nil, n, dns.ErrShortRead
}
- buf := GetMsgBuf(int(length))
- defer ReleaseMsgBuf(buf)
+ buf := pool.GetMsgBuf(int(length))
+ defer pool.ReleaseMsgBuf(buf)
n2, err := io.ReadFull(c, buf)
n = n + n2
@@ -185,11 +186,11 @@ func ReadMsgFromTCP(c io.Reader) (m *dns.Msg, n int, err error) {
// An io err will be wrapped into an IOErr.
// IsIOErr(err) can check and unwrap the inner io err.
func WriteMsgToTCP(c io.Writer, m *dns.Msg) (n int, err error) {
- mRaw, buf, err := PackBuffer(m)
+ mRaw, buf, err := pool.PackBuffer(m)
if err != nil {
return 0, err
}
- defer ReleaseMsgBuf(buf)
+ defer pool.ReleaseMsgBuf(buf)
return WriteRawMsgToTCP(c, mRaw)
}
@@ -222,5 +223,5 @@ func WriteRawMsgToTCP(c io.Writer, b []byte) (n int, err error) {
}
var (
- tcpWriteBufPool = NewBytesBufPool(512 + 2)
+ tcpWriteBufPool = pool.NewBytesBufPool(512 + 2)
)
diff --git a/dispatcher/utils/net_io_test.go b/dispatcher/pkg/dnsutils/net_io_test.go
similarity index 99%
rename from dispatcher/utils/net_io_test.go
rename to dispatcher/pkg/dnsutils/net_io_test.go
index 7659a6a38..2d23ff20b 100644
--- a/dispatcher/utils/net_io_test.go
+++ b/dispatcher/pkg/dnsutils/net_io_test.go
@@ -14,7 +14,8 @@
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see .
-package utils
+
+package dnsutils
import (
"bytes"
diff --git a/dispatcher/pkg/dnsutils/ttl.go b/dispatcher/pkg/dnsutils/ttl.go
new file mode 100644
index 000000000..f5155ba2c
--- /dev/null
+++ b/dispatcher/pkg/dnsutils/ttl.go
@@ -0,0 +1,83 @@
+// Copyright (C) 2020-2021, IrineSistiana
+//
+// This file is part of mosdns.
+//
+// mosdns is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// mosdns is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see .
+
+package dnsutils
+
+import "github.com/miekg/dns"
+
+// GetMinimalTTL returns the minimal ttl of this msg.
+// If msg m has no record, it returns 0.
+func GetMinimalTTL(m *dns.Msg) uint32 {
+ minTTL := ^uint32(0)
+ hasRecord := false
+ for _, section := range [...][]dns.RR{m.Answer, m.Ns, m.Extra} {
+ for _, rr := range section {
+ if rr.Header().Rrtype == dns.TypeOPT {
+ continue // opt record ttl is not ttl.
+ }
+ hasRecord = true
+ ttl := rr.Header().Ttl
+ if ttl < minTTL {
+ minTTL = ttl
+ }
+ }
+ }
+
+ if !hasRecord { // no ttl applied
+ return 0
+ }
+ return minTTL
+}
+
+// SetTTL updates all records' ttl to ttl, except opt record.
+func SetTTL(m *dns.Msg, ttl uint32) {
+ for _, section := range [...][]dns.RR{m.Answer, m.Ns, m.Extra} {
+ for _, rr := range section {
+ if rr.Header().Rrtype == dns.TypeOPT {
+ continue // opt record ttl is not ttl.
+ }
+ rr.Header().Ttl = ttl
+ }
+ }
+}
+
+func ApplyMaximumTTL(m *dns.Msg, ttl uint32) {
+ applyTTL(m, ttl, true)
+}
+
+func ApplyMinimalTTL(m *dns.Msg, ttl uint32) {
+ applyTTL(m, ttl, false)
+}
+
+func applyTTL(m *dns.Msg, ttl uint32, maximum bool) {
+ for _, section := range [...][]dns.RR{m.Answer, m.Ns, m.Extra} {
+ for _, rr := range section {
+ if rr.Header().Rrtype == dns.TypeOPT {
+ continue // opt record ttl is not ttl.
+ }
+ if maximum {
+ if rr.Header().Ttl > ttl {
+ rr.Header().Ttl = ttl
+ }
+ } else {
+ if rr.Header().Ttl < ttl {
+ rr.Header().Ttl = ttl
+ }
+ }
+ }
+ }
+}
diff --git a/dispatcher/pkg/executable_seq/executable_cmd.go b/dispatcher/pkg/executable_seq/executable_cmd.go
new file mode 100644
index 000000000..0b6456737
--- /dev/null
+++ b/dispatcher/pkg/executable_seq/executable_cmd.go
@@ -0,0 +1,28 @@
+// Copyright (C) 2020-2021, IrineSistiana
+//
+// This file is part of mosdns.
+//
+// mosdns is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) or later version.
+//
+// mosdns is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see .
+
+package executable_seq
+
+import (
+ "context"
+ "github.com/IrineSistiana/mosdns/dispatcher/handler"
+ "go.uber.org/zap"
+)
+
+type ExecutableCmd interface {
+ ExecCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (goTwo string, earlyStop bool, err error)
+}
diff --git a/dispatcher/pkg/executable_seq/executable_cmd_sequence.go b/dispatcher/pkg/executable_seq/executable_cmd_sequence.go
new file mode 100644
index 000000000..16c5fffc6
--- /dev/null
+++ b/dispatcher/pkg/executable_seq/executable_cmd_sequence.go
@@ -0,0 +1,257 @@
+// Copyright (C) 2020-2021, IrineSistiana
+//
+// This file is part of mosdns.
+//
+// mosdns is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// mosdns is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see .
+
+package executable_seq
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "github.com/IrineSistiana/mosdns/dispatcher/handler"
+ "go.uber.org/zap"
+ "reflect"
+ "strings"
+)
+
+type executablePluginTag struct {
+ s string
+}
+
+func (t executablePluginTag) ExecCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (goTwo string, earlyStop bool, err error) {
+ p, err := handler.GetPlugin(t.s)
+ if err != nil {
+ return "", false, err
+ }
+
+ logger.Debug("exec executable plugin", qCtx.InfoField(), zap.String("exec", t.s))
+ earlyStop, err = p.ExecES(ctx, qCtx)
+ return "", earlyStop, err
+}
+
+type IfBlockConfig struct {
+ If []string `yaml:"if"`
+ IfAnd []string `yaml:"if_and"`
+ Exec []interface{} `yaml:"exec"`
+ Goto string `yaml:"goto"`
+}
+
+type matcher struct {
+ tag string
+ negate bool
+}
+
+func paresMatcher(s []string) []matcher {
+ m := make([]matcher, 0, len(s))
+ for _, tag := range s {
+ if strings.HasPrefix(tag, "!") {
+ m = append(m, matcher{tag: strings.TrimPrefix(tag, "!"), negate: true})
+ } else {
+ m = append(m, matcher{tag: tag})
+ }
+ }
+ return m
+}
+
+type IfBlock struct {
+ ifMatcher []matcher
+ ifAndMatcher []matcher
+ executableCmd ExecutableCmd
+ goTwo string
+}
+
+func (b *IfBlock) ExecCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (goTwo string, earlyStop bool, err error) {
+ if len(b.ifMatcher) > 0 {
+ If, err := ifCondition(ctx, qCtx, logger, b.ifMatcher, false)
+ if err != nil {
+ return "", false, err
+ }
+ if If == false {
+ return "", false, nil // if case returns false, skip this block.
+ }
+ }
+
+ if len(b.ifAndMatcher) > 0 {
+ If, err := ifCondition(ctx, qCtx, logger, b.ifAndMatcher, true)
+ if err != nil {
+ return "", false, err
+ }
+ if If == false {
+ return "", false, nil
+ }
+ }
+
+ // exec
+ if b.executableCmd != nil {
+ goTwo, earlyStop, err = b.executableCmd.ExecCmd(ctx, qCtx, logger)
+ if err != nil {
+ return "", false, err
+ }
+ if len(goTwo) != 0 || earlyStop {
+ return goTwo, earlyStop, nil
+ }
+ }
+
+ // goto
+ if len(b.goTwo) != 0 { // if block has a goto, return it
+ return b.goTwo, false, nil
+ }
+
+ return "", false, nil
+}
+
+func ifCondition(ctx context.Context, qCtx *handler.Context, logger *zap.Logger, p []matcher, isAnd bool) (ok bool, err error) {
+ if len(p) == 0 {
+ return false, err
+ }
+
+ for _, m := range p {
+ mp, err := handler.GetPlugin(m.tag)
+ if err != nil {
+ return false, err
+ }
+ matched, err := mp.Match(ctx, qCtx)
+ if err != nil {
+ return false, err
+ }
+ logger.Debug("exec matcher plugin", qCtx.InfoField(), zap.String("exec", m.tag), zap.Bool("result", matched))
+
+ res := matched != m.negate
+ if !isAnd && res == true {
+ return true, nil // or: if one of the case is true, skip others.
+ }
+ if isAnd && res == false {
+ return false, nil // and: if one of the case is false, skip others.
+ }
+
+ ok = res
+ }
+ return ok, nil
+}
+
+func ParseIfBlock(in map[string]interface{}) (*IfBlock, error) {
+ c := new(IfBlockConfig)
+ err := handler.WeakDecode(in, c)
+ if err != nil {
+ return nil, err
+ }
+
+ b := &IfBlock{
+ ifMatcher: paresMatcher(c.If),
+ ifAndMatcher: paresMatcher(c.IfAnd),
+ goTwo: c.Goto,
+ }
+
+ if len(c.Exec) != 0 {
+ ecs, err := ParseExecutableCmdSequence(c.Exec)
+ if err != nil {
+ return nil, err
+ }
+ b.executableCmd = ecs
+ }
+
+ return b, nil
+}
+
+type ExecutableCmdSequence struct {
+ c []ExecutableCmd
+}
+
+func ParseExecutableCmdSequence(in []interface{}) (*ExecutableCmdSequence, error) {
+ es := &ExecutableCmdSequence{c: make([]ExecutableCmd, 0, len(in))}
+ for i, v := range in {
+ ec, err := parseExecutableCmd(v)
+ if err != nil {
+ return nil, fmt.Errorf("invalid cmd #%d: %w", i, err)
+ }
+ es.c = append(es.c, ec)
+ }
+ return es, nil
+}
+
+func parseExecutableCmd(in interface{}) (ExecutableCmd, error) {
+ switch v := in.(type) {
+ case string:
+ return &executablePluginTag{s: v}, nil
+ case map[string]interface{}:
+ switch {
+ case hasKey(v, "if") || hasKey(v, "if_and"): // if block
+ ec, err := ParseIfBlock(v)
+ if err != nil {
+ return nil, fmt.Errorf("invalid if section: %w", err)
+ }
+ return ec, nil
+ case hasKey(v, "parallel"): // parallel
+ ec, err := parseParallelECS(v)
+ if err != nil {
+ return nil, fmt.Errorf("invalid parallel section: %w", err)
+ }
+ return ec, nil
+ case hasKey(v, "primary") || hasKey(v, "secondary"): // fallback
+ ec, err := parseFallbackECS(v)
+ if err != nil {
+ return nil, fmt.Errorf("invalid fallback section: %w", err)
+ }
+ return ec, nil
+ default:
+ return nil, errors.New("unknown section")
+ }
+ default:
+ return nil, fmt.Errorf("unexpected type: %s", reflect.TypeOf(in).String())
+ }
+}
+
+func parseParallelECS(m map[string]interface{}) (ec ExecutableCmd, err error) {
+ conf := new(ParallelECSConfig)
+ err = handler.WeakDecode(m, conf)
+ if err != nil {
+ return nil, err
+ }
+ return ParseParallelECS(conf)
+}
+
+func parseFallbackECS(m map[string]interface{}) (ec ExecutableCmd, err error) {
+ conf := new(FallbackConfig)
+ err = handler.WeakDecode(m, conf)
+ if err != nil {
+ return nil, err
+ }
+ return ParseFallbackECS(conf)
+}
+
+func hasKey(m map[string]interface{}, key string) bool {
+ _, ok := m[key]
+ return ok
+}
+
+// ExecCmd executes the sequence.
+func (es *ExecutableCmdSequence) ExecCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (goTwo string, earlyStop bool, err error) {
+ for _, cmd := range es.c {
+ goTwo, earlyStop, err = cmd.ExecCmd(ctx, qCtx, logger)
+ if err != nil {
+ return "", false, err
+ }
+ if len(goTwo) != 0 || earlyStop {
+ return goTwo, earlyStop, nil
+ }
+ }
+
+ return "", false, nil
+}
+
+func (es *ExecutableCmdSequence) Len() int {
+ return len(es.c)
+}
diff --git a/dispatcher/pkg/executable_seq/executable_cmd_sequence_test.go b/dispatcher/pkg/executable_seq/executable_cmd_sequence_test.go
new file mode 100644
index 000000000..5c76095b7
--- /dev/null
+++ b/dispatcher/pkg/executable_seq/executable_cmd_sequence_test.go
@@ -0,0 +1,166 @@
+package executable_seq
+
+import (
+ "context"
+ "errors"
+ "github.com/IrineSistiana/mosdns/dispatcher/handler"
+ "github.com/miekg/dns"
+ "go.uber.org/zap"
+ "gopkg.in/yaml.v3"
+ "testing"
+)
+
+func Test_ECS(t *testing.T) {
+ handler.PurgePluginRegister()
+ defer handler.PurgePluginRegister()
+
+ mErr := errors.New("mErr")
+ eErr := errors.New("eErr")
+
+ var tests = []struct {
+ name string
+ yamlStr string
+ wantNext string
+ wantES bool
+ wantErr error
+ }{
+ {name: "test empty input", yamlStr: `
+exec:
+`,
+ wantNext: "", wantErr: nil},
+ {name: "test empty end", yamlStr: `
+exec:
+- if: ["!matched",not_matched] # not matched
+ exec: [exec_err]
+ goto: goto`,
+ wantNext: "", wantErr: nil},
+
+ {name: "test if_and", yamlStr: `
+exec:
+- if_and: [matched, not_matched] # not matched
+ goto: goto1
+- if_and: [matched, not_matched, match_err] # not matched, early stop, no err
+ goto: goto2
+- if_and: [matched, matched, matched] # matched
+ goto: goto3
+`,
+ wantNext: "goto3", wantErr: nil},
+
+ {name: "test if_and err", yamlStr: `
+exec:
+- if_and: [matched, match_err] # err
+ goto: goto1
+`,
+ wantNext: "", wantErr: mErr},
+
+ {name: "test if", yamlStr: `
+exec:
+- if: ["!matched", not_matched] # test ! prefix, not matched
+ goto: goto1
+- if: [matched, match_err] # matched, early stop, no err
+ exec:
+ - if: ["!not_matched", not_matched] # matched
+ goto: goto2 # reached here
+ goto: goto3
+`,
+ wantNext: "goto2", wantErr: nil},
+
+ {name: "test if err", yamlStr: `
+exec:
+- if: [not_matched, match_err] # err
+ goto: goto1
+`,
+ wantNext: "", wantErr: mErr},
+
+ {name: "test exec err", yamlStr: `
+exec:
+- if: [matched]
+ exec: exec_err
+ goto: goto1
+`,
+ wantNext: "", wantErr: eErr},
+
+ {name: "test early return in main sequence", yamlStr: `
+exec:
+- exec
+- exec_skip
+- exec_err # skipped, should not reach here.
+`,
+ wantNext: "", wantES: true, wantErr: nil},
+
+ {name: "test early return in if branch", yamlStr: `
+exec:
+- if: [matched]
+ exec:
+ - exec_skip
+ goto: goto1 # skipped, should not reach here.
+`,
+ wantNext: "", wantES: true, wantErr: nil},
+ }
+
+ // not_matched
+ handler.MustRegPlugin(&handler.DummyMatcherPlugin{
+ BP: handler.NewBP("not_matched", ""),
+ Matched: false,
+ WantErr: nil,
+ }, true)
+
+ // do something
+ handler.MustRegPlugin(&handler.DummyExecutablePlugin{
+ BP: handler.NewBP("exec", ""),
+ WantErr: nil,
+ }, true)
+
+ // do something and skip the following sequence
+ handler.MustRegPlugin(&handler.DummyESExecutablePlugin{
+ BP: handler.NewBP("exec_skip", ""),
+ WantSkip: true,
+ }, true)
+
+ // matched
+ handler.MustRegPlugin(&handler.DummyMatcherPlugin{
+ BP: handler.NewBP("matched", ""),
+ Matched: true,
+ WantErr: nil,
+ }, true)
+
+ // plugins should return an err.
+ handler.MustRegPlugin(&handler.DummyMatcherPlugin{
+ BP: handler.NewBP("match_err", ""),
+ Matched: false,
+ WantErr: mErr,
+ }, true)
+
+ handler.MustRegPlugin(&handler.DummyExecutablePlugin{
+ BP: handler.NewBP("exec_err", ""),
+ WantErr: eErr,
+ }, true)
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ args := make(map[string]interface{}, 0)
+ err := yaml.Unmarshal([]byte(tt.yamlStr), args)
+ if err != nil {
+ t.Fatal(err)
+ }
+ in, _ := args["exec"].([]interface{})
+ ecs, err := ParseExecutableCmdSequence(in)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ gotNext, gotEarlyStop, err := ecs.ExecCmd(context.Background(), handler.NewContext(new(dns.Msg), nil), zap.NewNop())
+ if (err != nil || tt.wantErr != nil) && !errors.Is(err, tt.wantErr) {
+ t.Errorf("ExecCmd() error = %v, wantErr %v", err, tt.wantErr)
+ return
+ }
+ if gotNext != tt.wantNext {
+ t.Errorf("ExecCmd() gotNext = %v, want %v", gotNext, tt.wantNext)
+ }
+
+ if gotEarlyStop != tt.wantES {
+ t.Errorf("ExecCmd() gotEarlyStop = %v, want %v", gotEarlyStop, tt.wantES)
+ }
+ })
+ }
+}
diff --git a/dispatcher/pkg/executable_seq/fallback.go b/dispatcher/pkg/executable_seq/fallback.go
new file mode 100644
index 000000000..c892ac0b9
--- /dev/null
+++ b/dispatcher/pkg/executable_seq/fallback.go
@@ -0,0 +1,268 @@
+// Copyright (C) 2020-2021, IrineSistiana
+//
+// This file is part of mosdns.
+//
+// mosdns is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) or later version.
+//
+// mosdns is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see .
+
+package executable_seq
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "github.com/IrineSistiana/mosdns/dispatcher/handler"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/pool"
+ "go.uber.org/zap"
+ "sync"
+ "time"
+)
+
+type FallbackConfig struct {
+ // Primary exec sequence, must have at least one element.
+ Primary []interface{} `yaml:"primary"`
+ // Secondary exec sequence, must have at least one element.
+ Secondary []interface{} `yaml:"secondary"`
+
+ StatLength int `yaml:"stat_length"` // An Zero value disables the (normal) fallback.
+ Threshold int `yaml:"threshold"`
+
+ // FastFallback threshold in milliseconds. Zero means fast fallback is disabled.
+ FastFallback int `yaml:"fast_fallback"`
+
+ // AlwaysStandby: secondary should always standby in fast fallback.
+ AlwaysStandby bool `yaml:"always_standby"`
+}
+
+type FallbackECS struct {
+ primary *ExecutableCmdSequence
+ secondary *ExecutableCmdSequence
+ fastFallbackDuration time.Duration
+ alwaysStandby bool
+
+ primaryST *statusTracker // nil if normal fallback is disabled
+}
+
+type statusTracker struct {
+ sync.Mutex
+ threshold int
+ status []uint8 // 0 means success, !0 means failed
+ p int
+}
+
+func newStatusTracker(threshold, statLength int) *statusTracker {
+ return &statusTracker{
+ threshold: threshold,
+ status: make([]uint8, statLength),
+ }
+}
+
+func (t *statusTracker) good() bool {
+ t.Lock()
+ defer t.Unlock()
+
+ var failedSum int
+ for _, s := range t.status {
+ if s != 0 {
+ failedSum++
+ }
+ }
+ return failedSum < t.threshold
+}
+
+func (t *statusTracker) update(s uint8) {
+ t.Lock()
+ defer t.Unlock()
+
+ if t.p >= len(t.status) {
+ t.p = 0
+ }
+ t.status[t.p] = s
+ t.p++
+}
+
+func ParseFallbackECS(c *FallbackConfig) (*FallbackECS, error) {
+ if len(c.Primary) == 0 {
+ return nil, errors.New("primary sequence is empty")
+ }
+ if len(c.Secondary) == 0 {
+ return nil, errors.New("secondary sequence is empty")
+ }
+
+ primaryECS, err := ParseExecutableCmdSequence(c.Primary)
+ if err != nil {
+ return nil, fmt.Errorf("invalid primary sequence: %w", err)
+ }
+
+ secondaryECS, err := ParseExecutableCmdSequence(c.Secondary)
+ if err != nil {
+ return nil, fmt.Errorf("invalid secondary sequence: %w", err)
+ }
+
+ fallbackECS := &FallbackECS{
+ primary: primaryECS,
+ secondary: secondaryECS,
+ fastFallbackDuration: time.Duration(c.FastFallback) * time.Millisecond,
+ alwaysStandby: c.AlwaysStandby,
+ }
+
+ if c.StatLength > 0 {
+ if c.Threshold > c.StatLength {
+ c.Threshold = c.StatLength
+ }
+ fallbackECS.primaryST = newStatusTracker(c.Threshold, c.StatLength)
+ }
+
+ return fallbackECS, nil
+}
+
+func (f *FallbackECS) ExecCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (goTwo string, earlyStop bool, err error) {
+ return "", false, f.execCmd(ctx, qCtx, logger)
+}
+
+func (f *FallbackECS) execCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) {
+ if f.primaryST == nil || f.primaryST.good() {
+ if f.fastFallbackDuration > 0 {
+ return f.doFastFallback(ctx, qCtx, logger)
+ } else {
+ return f.isolateDoPrimary(ctx, qCtx, logger)
+ }
+ }
+ logger.Debug("primary is not good", qCtx.InfoField())
+ return f.doFallback(ctx, qCtx, logger)
+}
+
+func (f *FallbackECS) isolateDoPrimary(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) {
+ qCtxCopy := qCtx.Copy()
+ err = f.doPrimary(ctx, qCtxCopy, logger)
+ qCtx.SetResponse(qCtxCopy.R(), qCtxCopy.Status())
+ return err
+}
+
+func (f *FallbackECS) doPrimary(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) {
+ err = WalkExecutableCmd(ctx, qCtx, logger, f.primary)
+ if err == nil {
+ err = qCtx.ExecDefer(ctx)
+ }
+ if f.primaryST != nil {
+ if err != nil || qCtx.R() == nil {
+ f.primaryST.update(1)
+ } else {
+ f.primaryST.update(0)
+ }
+ }
+
+ return err
+}
+
+func (f *FallbackECS) doFastFallback(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) {
+ fCtx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ timer := pool.GetTimer(f.fastFallbackDuration)
+ defer pool.ReleaseTimer(timer)
+
+ c := make(chan *parallelECSResult, 2)
+ primFailed := make(chan struct{}) // will be closed if primary returns an err.
+
+ qCtxCopyP := qCtx.Copy()
+ go func() {
+ err := f.doPrimary(fCtx, qCtxCopyP, logger)
+ if err != nil || qCtxCopyP.R() == nil {
+ close(primFailed)
+ }
+ c <- ¶llelECSResult{
+ r: qCtxCopyP.R(),
+ status: qCtxCopyP.Status(),
+ err: err,
+ from: 1,
+ }
+ }()
+
+ qCtxCopyS := qCtx.Copy()
+ go func() {
+ if !f.alwaysStandby { // not always standby, wait here.
+ select {
+ case <-fCtx.Done(): // primary is done, no needs to exec this.
+ return
+ case <-primFailed: // primary failed or timeout, exec now.
+ case <-timer.C:
+ }
+ }
+
+ err := f.doSecondary(fCtx, qCtxCopyS, logger)
+ res := ¶llelECSResult{
+ r: qCtxCopyS.R(),
+ status: qCtxCopyS.Status(),
+ err: err,
+ from: 2,
+ }
+
+ if f.alwaysStandby { // always standby
+ select {
+ case <-fCtx.Done():
+ return
+ case <-primFailed: // only send secondary result when primary is failed.
+ c <- res
+ case <-timer.C: // or timeout.
+ c <- res
+ }
+ } else {
+ c <- res // not always standby, send the result asap.
+ }
+ }()
+
+ return asyncWait(ctx, qCtx, logger, c, 2)
+}
+
+func (f *FallbackECS) doSecondary(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) {
+ err = WalkExecutableCmd(ctx, qCtx, logger, f.secondary)
+ if err == nil {
+ err = qCtx.ExecDefer(ctx)
+ }
+ return err
+}
+
+func (f *FallbackECS) doFallback(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) error {
+ fCtx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ c := make(chan *parallelECSResult, 2) // buf size is 2, avoid block.
+
+ qCtxCopyP := qCtx.Copy()
+ go func() {
+ err := f.doPrimary(fCtx, qCtxCopyP, logger)
+ c <- ¶llelECSResult{
+ r: qCtxCopyP.R(),
+ status: qCtxCopyP.Status(),
+ err: err,
+ from: 1,
+ }
+ }()
+
+ qCtxCopyS := qCtx.Copy()
+ go func() {
+ err := WalkExecutableCmd(fCtx, qCtxCopyS, logger, f.secondary)
+ if err == nil {
+ err = qCtxCopyS.ExecDefer(fCtx)
+ }
+ c <- ¶llelECSResult{
+ r: qCtxCopyS.R(),
+ status: qCtxCopyS.Status(),
+ err: err,
+ from: 2,
+ }
+ }()
+
+ return asyncWait(ctx, qCtx, logger, c, 2)
+}
diff --git a/dispatcher/utils/executable_sequence_test.go b/dispatcher/pkg/executable_seq/fallback_test.go
similarity index 51%
rename from dispatcher/utils/executable_sequence_test.go
rename to dispatcher/pkg/executable_seq/fallback_test.go
index 07c6a50cb..8297163f9 100644
--- a/dispatcher/utils/executable_sequence_test.go
+++ b/dispatcher/pkg/executable_seq/fallback_test.go
@@ -1,4 +1,21 @@
-package utils
+// Copyright (C) 2020-2021, IrineSistiana
+//
+// This file is part of mosdns.
+//
+// mosdns is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) or later version.
+//
+// mosdns is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see .
+
+package executable_seq
import (
"context"
@@ -6,224 +23,10 @@ import (
"github.com/IrineSistiana/mosdns/dispatcher/handler"
"github.com/miekg/dns"
"go.uber.org/zap"
- "gopkg.in/yaml.v3"
"testing"
"time"
)
-func Test_ECS(t *testing.T) {
- handler.PurgePluginRegister()
- defer handler.PurgePluginRegister()
-
- mErr := errors.New("mErr")
- eErr := errors.New("eErr")
-
- var tests = []struct {
- name string
- yamlStr string
- wantNext string
- wantES bool
- wantErr error
- }{
- {name: "test empty end", yamlStr: `
-exec:
-- if: ["!matched",not_matched] # not matched
- exec: [exec_err]
- goto: goto`,
- wantNext: "", wantErr: nil},
-
- {name: "test if_and", yamlStr: `
-exec:
-- if_and: [matched, not_matched] # not matched
- goto: goto1
-- if_and: [matched, not_matched, match_err] # not matched, early stop, no err
- goto: goto2
-- if_and: [matched, matched, matched] # matched
- goto: goto3
-`,
- wantNext: "goto3", wantErr: nil},
-
- {name: "test if_and err", yamlStr: `
-exec:
-- if_and: [matched, match_err] # err
- goto: goto1
-`,
- wantNext: "", wantErr: mErr},
-
- {name: "test if", yamlStr: `
-exec:
-- if: ["!matched", not_matched] # test ! prefix, not matched
- goto: goto1
-- if: [matched, match_err] # matched, early stop, no err
- exec:
- - if: ["!not_matched", not_matched] # matched
- goto: goto2 # reached here
- goto: goto3
-`,
- wantNext: "goto2", wantErr: nil},
-
- {name: "test if err", yamlStr: `
-exec:
-- if: [not_matched, match_err] # err
- goto: goto1
-`,
- wantNext: "", wantErr: mErr},
-
- {name: "test exec err", yamlStr: `
-exec:
-- if: [matched]
- exec: exec_err
- goto: goto1
-`,
- wantNext: "", wantErr: eErr},
-
- {name: "test early return in main sequence", yamlStr: `
-exec:
-- exec
-- exec_skip
-- exec_err # skipped, should not reach here.
-`,
- wantNext: "", wantES: true, wantErr: nil},
-
- {name: "test early return in if branch", yamlStr: `
-exec:
-- if: [matched]
- exec:
- - exec_skip
- goto: goto1 # skipped, should not reach here.
-`,
- wantNext: "", wantES: true, wantErr: nil},
- }
-
- // not_matched
- handler.MustRegPlugin(&handler.DummyMatcherPlugin{
- BP: handler.NewBP("not_matched", ""),
- Matched: false,
- WantErr: nil,
- }, true)
-
- // do something
- handler.MustRegPlugin(&handler.DummyExecutablePlugin{
- BP: handler.NewBP("exec", ""),
- WantErr: nil,
- }, true)
-
- // do something and skip the following sequence
- handler.MustRegPlugin(&handler.DummyESExecutablePlugin{
- BP: handler.NewBP("exec_skip", ""),
- WantSkip: true,
- }, true)
-
- // matched
- handler.MustRegPlugin(&handler.DummyMatcherPlugin{
- BP: handler.NewBP("matched", ""),
- Matched: true,
- WantErr: nil,
- }, true)
-
- // plugins should return an err.
- handler.MustRegPlugin(&handler.DummyMatcherPlugin{
- BP: handler.NewBP("match_err", ""),
- Matched: false,
- WantErr: mErr,
- }, true)
-
- handler.MustRegPlugin(&handler.DummyExecutablePlugin{
- BP: handler.NewBP("exec_err", ""),
- WantErr: eErr,
- }, true)
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- args := make(map[string]interface{}, 0)
- err := yaml.Unmarshal([]byte(tt.yamlStr), args)
- if err != nil {
- t.Fatal(err)
- }
- ecs, err := ParseExecutableCmdSequence(args["exec"].([]interface{}))
- if err != nil {
- t.Fatal(err)
- }
-
- gotNext, gotEarlyStop, err := ecs.ExecCmd(context.Background(), handler.NewContext(new(dns.Msg), nil), zap.NewNop())
- if (err != nil || tt.wantErr != nil) && !errors.Is(err, tt.wantErr) {
- t.Errorf("ExecCmd() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if gotNext != tt.wantNext {
- t.Errorf("ExecCmd() gotNext = %v, want %v", gotNext, tt.wantNext)
- }
-
- if gotEarlyStop != tt.wantES {
- t.Errorf("ExecCmd() gotEarlyStop = %v, want %v", gotEarlyStop, tt.wantES)
- }
- })
- }
-}
-
-func Test_ParallelECS(t *testing.T) {
- handler.PurgePluginRegister()
- defer handler.PurgePluginRegister()
-
- r1 := new(dns.Msg)
- r2 := new(dns.Msg)
-
- er := errors.New("")
- tests := []struct {
- name string
- r1 *dns.Msg
- e1 error
- r2 *dns.Msg
- e2 error
- wantR *dns.Msg
- wantErr bool
- }{
- {"failed #1", nil, er, nil, er, nil, true},
- {"failed #2", nil, nil, nil, nil, nil, true},
- {"p1 response #1", r1, nil, nil, nil, r1, false},
- {"p1 response #2", r1, nil, nil, er, r1, false},
- {"p2 response #1", nil, nil, r2, nil, r2, false},
- {"p2 response #2", nil, er, r2, nil, r2, false},
- }
-
- parallelECS, err := ParseParallelECS(&ParallelECSConfig{
- Parallel: [][]interface{}{{"p1"}, {"p2"}},
- })
- if err != nil {
- t.Fatal(err)
- }
-
- ctx := context.Background()
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- p1 := &handler.DummyExecutablePlugin{
- BP: handler.NewBP("p1", ""),
- Sleep: 0,
- WantR: tt.r1,
- WantErr: tt.e1,
- }
- p2 := &handler.DummyExecutablePlugin{
- BP: handler.NewBP("p2", ""),
- Sleep: 0,
- WantR: tt.r2,
- WantErr: tt.e2,
- }
- handler.MustRegPlugin(p1, false)
- handler.MustRegPlugin(p2, false)
-
- qCtx := handler.NewContext(new(dns.Msg), nil)
- err := parallelECS.execCmd(ctx, qCtx, zap.NewNop())
- if tt.wantErr != (err != nil) {
- t.Fatalf("execCmd() error = %v, wantErr %v", err, tt.wantErr)
- }
-
- if tt.wantR != qCtx.R() {
- t.Fatalf("execCmd() qCtx.R() = %p, wantR %p", qCtx.R(), tt.wantR)
- }
- })
- }
-}
-
func Test_FallbackECS_fallback(t *testing.T) {
handler.PurgePluginRegister()
defer handler.PurgePluginRegister()
diff --git a/dispatcher/pkg/executable_seq/parallel.go b/dispatcher/pkg/executable_seq/parallel.go
new file mode 100644
index 000000000..31c654f97
--- /dev/null
+++ b/dispatcher/pkg/executable_seq/parallel.go
@@ -0,0 +1,107 @@
+// Copyright (C) 2020-2021, IrineSistiana
+//
+// This file is part of mosdns.
+//
+// mosdns is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) or later version.
+//
+// mosdns is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see .
+
+package executable_seq
+
+import (
+ "context"
+ "fmt"
+ "github.com/IrineSistiana/mosdns/dispatcher/handler"
+ "github.com/miekg/dns"
+ "go.uber.org/zap"
+ "time"
+)
+
+type ParallelECS struct {
+ s []*ExecutableCmdSequence
+ timeout time.Duration
+}
+
+type ParallelECSConfig struct {
+ Parallel [][]interface{} `yaml:"parallel"`
+ Timeout uint `yaml:"timeout"`
+}
+
+func ParseParallelECS(c *ParallelECSConfig) (*ParallelECS, error) {
+ if len(c.Parallel) < 2 {
+ return nil, fmt.Errorf("parallel needs at least 2 cmd sequences, but got %d", len(c.Parallel))
+ }
+
+ ps := make([]*ExecutableCmdSequence, 0, len(c.Parallel))
+ for i, subSequence := range c.Parallel {
+ es, err := ParseExecutableCmdSequence(subSequence)
+ if err != nil {
+ return nil, fmt.Errorf("invalid parallel sequence at index %d: %w", i, err)
+ }
+ ps = append(ps, es)
+ }
+ return &ParallelECS{s: ps, timeout: time.Duration(c.Timeout) * time.Second}, nil
+}
+
+type parallelECSResult struct {
+ r *dns.Msg
+ status handler.ContextStatus
+ err error
+ from int
+}
+
+func (p *ParallelECS) ExecCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (goTwo string, earlyStop bool, err error) {
+ return "", false, p.execCmd(ctx, qCtx, logger)
+}
+
+func (p *ParallelECS) execCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) {
+
+ var pCtx context.Context // only valid if p.timeout == 0
+ var cancel func()
+ if p.timeout == 0 {
+ pCtx, cancel = context.WithCancel(ctx)
+ defer cancel()
+ }
+
+ t := len(p.s)
+ c := make(chan *parallelECSResult, len(p.s)) // use buf chan to avoid block.
+
+ for i, sequence := range p.s {
+ i := i
+ sequence := sequence
+ qCtxCopy := qCtx.Copy()
+
+ go func() {
+ var ecsCtx context.Context
+ var ecsCancel func()
+ if p.timeout == 0 {
+ ecsCtx = pCtx
+ } else {
+ ecsCtx, ecsCancel = context.WithTimeout(context.Background(), p.timeout)
+ defer ecsCancel()
+ }
+
+ err := WalkExecutableCmd(ecsCtx, qCtxCopy, logger, sequence)
+ if err == nil {
+ err = qCtxCopy.ExecDefer(pCtx)
+ }
+ c <- ¶llelECSResult{
+ r: qCtxCopy.R(),
+ status: qCtxCopy.Status(),
+ err: err,
+ from: i,
+ }
+ }()
+ }
+
+ return asyncWait(ctx, qCtx, logger, c, t)
+}
diff --git a/dispatcher/pkg/executable_seq/parallel_test.go b/dispatcher/pkg/executable_seq/parallel_test.go
new file mode 100644
index 000000000..2cf0caf5e
--- /dev/null
+++ b/dispatcher/pkg/executable_seq/parallel_test.go
@@ -0,0 +1,90 @@
+// Copyright (C) 2020-2021, IrineSistiana
+//
+// This file is part of mosdns.
+//
+// mosdns is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) or later version.
+//
+// mosdns is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see .
+
+package executable_seq
+
+import (
+ "context"
+ "errors"
+ "github.com/IrineSistiana/mosdns/dispatcher/handler"
+ "github.com/miekg/dns"
+ "go.uber.org/zap"
+ "testing"
+)
+
+func Test_ParallelECS(t *testing.T) {
+ handler.PurgePluginRegister()
+ defer handler.PurgePluginRegister()
+
+ r1 := new(dns.Msg)
+ r2 := new(dns.Msg)
+
+ er := errors.New("")
+ tests := []struct {
+ name string
+ r1 *dns.Msg
+ e1 error
+ r2 *dns.Msg
+ e2 error
+ wantR *dns.Msg
+ wantErr bool
+ }{
+ {"failed #1", nil, er, nil, er, nil, true},
+ {"failed #2", nil, nil, nil, nil, nil, true},
+ {"p1 response #1", r1, nil, nil, nil, r1, false},
+ {"p1 response #2", r1, nil, nil, er, r1, false},
+ {"p2 response #1", nil, nil, r2, nil, r2, false},
+ {"p2 response #2", nil, er, r2, nil, r2, false},
+ }
+
+ parallelECS, err := ParseParallelECS(&ParallelECSConfig{
+ Parallel: [][]interface{}{{"p1"}, {"p2"}},
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ ctx := context.Background()
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ p1 := &handler.DummyExecutablePlugin{
+ BP: handler.NewBP("p1", ""),
+ Sleep: 0,
+ WantR: tt.r1,
+ WantErr: tt.e1,
+ }
+ p2 := &handler.DummyExecutablePlugin{
+ BP: handler.NewBP("p2", ""),
+ Sleep: 0,
+ WantR: tt.r2,
+ WantErr: tt.e2,
+ }
+ handler.MustRegPlugin(p1, false)
+ handler.MustRegPlugin(p2, false)
+
+ qCtx := handler.NewContext(new(dns.Msg), nil)
+ err := parallelECS.execCmd(ctx, qCtx, zap.NewNop())
+ if tt.wantErr != (err != nil) {
+ t.Fatalf("execCmd() error = %v, wantErr %v", err, tt.wantErr)
+ }
+
+ if tt.wantR != qCtx.R() {
+ t.Fatalf("execCmd() qCtx.R() = %p, wantR %p", qCtx.R(), tt.wantR)
+ }
+ })
+ }
+}
diff --git a/dispatcher/pkg/executable_seq/utils.go b/dispatcher/pkg/executable_seq/utils.go
new file mode 100644
index 000000000..9f36cc0e1
--- /dev/null
+++ b/dispatcher/pkg/executable_seq/utils.go
@@ -0,0 +1,73 @@
+// Copyright (C) 2020-2021, IrineSistiana
+//
+// This file is part of mosdns.
+//
+// mosdns is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) or later version.
+//
+// mosdns is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see .
+
+package executable_seq
+
+import (
+ "context"
+ "errors"
+ "github.com/IrineSistiana/mosdns/dispatcher/handler"
+ "go.uber.org/zap"
+)
+
+// WalkExecutableCmd executes the ExecutableCmd, include its `goto`.
+// This should only be used in root cmd node.
+func WalkExecutableCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger, entry ExecutableCmd) (err error) {
+ goTwo, _, err := entry.ExecCmd(ctx, qCtx, logger)
+ if err != nil {
+ return err
+ }
+
+ if len(goTwo) != 0 {
+ logger.Debug("goto plugin", qCtx.InfoField(), zap.String("goto", goTwo))
+ p, err := handler.GetPlugin(goTwo)
+ if err != nil {
+ return err
+ }
+ _, err = p.ExecES(ctx, qCtx)
+ return err
+ }
+ return nil
+}
+
+func asyncWait(ctx context.Context, qCtx *handler.Context, logger *zap.Logger, c chan *parallelECSResult, total int) error {
+ for i := 0; i < total; i++ {
+ select {
+ case res := <-c:
+ if res.err != nil {
+ logger.Warn("sequence failed", qCtx.InfoField(), zap.Int("sequence", res.from), zap.Error(res.err))
+ continue
+ }
+
+ if res.r == nil {
+ logger.Debug("sequence returned with an empty response", qCtx.InfoField(), zap.Int("sequence", res.from))
+ continue
+ }
+
+ logger.Debug("sequence returned a response", qCtx.InfoField(), zap.Int("sequence", res.from))
+ qCtx.SetResponse(res.r, res.status)
+ return nil
+
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+ }
+
+ // No response
+ qCtx.SetResponse(nil, handler.ContextStatusServerFailed)
+ return errors.New("no response")
+}
diff --git a/dispatcher/utils/load_once.go b/dispatcher/pkg/load_cache/load_cache.go
similarity index 79%
rename from dispatcher/utils/load_once.go
rename to dispatcher/pkg/load_cache/load_cache.go
index 89b9eb44b..e3beabc8d 100644
--- a/dispatcher/utils/load_once.go
+++ b/dispatcher/pkg/load_cache/load_cache.go
@@ -15,7 +15,7 @@
// You should have received a copy of the GNU General Public License
// along with this program. If not, see .
-package utils
+package load_cache
import (
"io/ioutil"
@@ -24,18 +24,18 @@ import (
"time"
)
-type LoadOnceCache struct {
+type LoadCache struct {
l sync.Mutex
cache map[string]interface{}
}
-func NewCache() *LoadOnceCache {
- return &LoadOnceCache{
+func NewCache() *LoadCache {
+ return &LoadCache{
cache: make(map[string]interface{}),
}
}
-func (c *LoadOnceCache) Put(key string, data interface{}, ttl time.Duration) {
+func (c *LoadCache) Put(key string, data interface{}, ttl time.Duration) {
if ttl <= 0 {
return
}
@@ -49,14 +49,14 @@ func (c *LoadOnceCache) Put(key string, data interface{}, ttl time.Duration) {
time.AfterFunc(ttl, rm)
}
-func (c *LoadOnceCache) Remove(key string) {
+func (c *LoadCache) Remove(key string) {
c.l.Lock()
defer c.l.Unlock()
delete(c.cache, key)
}
-func (c *LoadOnceCache) Load(key string) (interface{}, bool) {
+func (c *LoadCache) Load(key string) (interface{}, bool) {
c.l.Lock()
defer c.l.Unlock()
@@ -64,7 +64,7 @@ func (c *LoadOnceCache) Load(key string) (interface{}, bool) {
return data, ok
}
-func (c *LoadOnceCache) LoadFromCacheOrRawDisk(file string) (interface{}, []byte, error) {
+func (c *LoadCache) LoadFromCacheOrRawDisk(file string) (interface{}, []byte, error) {
// load from cache
data, ok := c.Load(file)
if ok {
diff --git a/dispatcher/utils/lru.go b/dispatcher/pkg/lru/lru.go
similarity index 99%
rename from dispatcher/utils/lru.go
rename to dispatcher/pkg/lru/lru.go
index 52d3a7b2f..6af59ca7f 100644
--- a/dispatcher/utils/lru.go
+++ b/dispatcher/pkg/lru/lru.go
@@ -15,7 +15,7 @@
// You should have received a copy of the GNU General Public License
// along with this program. If not, see .
-package utils
+package lru
import (
"container/list"
diff --git a/dispatcher/utils/lru_test.go b/dispatcher/pkg/lru/lru_test.go
similarity index 99%
rename from dispatcher/utils/lru_test.go
rename to dispatcher/pkg/lru/lru_test.go
index 30a0b0878..dc3a7c2f6 100644
--- a/dispatcher/utils/lru_test.go
+++ b/dispatcher/pkg/lru/lru_test.go
@@ -1,4 +1,4 @@
-package utils
+package lru
import (
"reflect"
diff --git a/dispatcher/matcher/domain/interface.go b/dispatcher/pkg/matcher/domain/interface.go
similarity index 100%
rename from dispatcher/matcher/domain/interface.go
rename to dispatcher/pkg/matcher/domain/interface.go
diff --git a/dispatcher/matcher/domain/load_helper.go b/dispatcher/pkg/matcher/domain/load_helper.go
similarity index 77%
rename from dispatcher/matcher/domain/load_helper.go
rename to dispatcher/pkg/matcher/domain/load_helper.go
index 1bbc32d96..4f69c82f2 100644
--- a/dispatcher/matcher/domain/load_helper.go
+++ b/dispatcher/pkg/matcher/domain/load_helper.go
@@ -22,8 +22,9 @@ import (
"bytes"
"errors"
"fmt"
- "github.com/IrineSistiana/mosdns/dispatcher/matcher/v2data"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/load_cache"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/v2data"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils"
"io"
"io/ioutil"
"strings"
@@ -32,7 +33,7 @@ import (
"github.com/golang/protobuf/proto"
)
-var matcherCache = utils.NewCache()
+var matcherCache = load_cache.NewCache()
const (
cacheTTL = time.Second * 30
@@ -41,35 +42,32 @@ const (
// ProcessAttrFunc processes the additional attributions. The given []string could have a 0 length or is nil.
type ProcessAttrFunc func([]string) (v interface{}, accept bool, err error)
-// LoadFromFile loads data from the file.
-// File can be a text file or a v2ray data file.
-// Only MixMatcher can load v2ray data file.
-// v2ray data file needs to specify the data category by using ':', e.g. 'geosite.dat:cn'
-func LoadFromFile(m Matcher, file string, processAttr ProcessAttrFunc) error {
- var err error
- if tmp := strings.SplitN(file, ":", 2); len(tmp) == 2 { // is a v2ray data file
- mixMatcher, ok := m.(*MixMatcher)
- if !ok {
- return errors.New("only MixMatcher can load v2ray data file")
- }
- filePath := tmp[0]
- countryCode := tmp[1]
- err = mixMatcher.LoadFromDAT(filePath, countryCode, processAttr)
- } else { // is a text file
- err = LoadFromTextFile(m, file, processAttr)
- }
- if err != nil {
- return err
+// Load loads data from a entry.
+// If entry begin with "ext:", Load loads the file by using LoadFromFile.
+// Else it loads the entry as a text pattern by using LoadFromText.
+func Load(m Matcher, entry string, processAttr ProcessAttrFunc) error {
+ s1, s2, ok := utils.SplitString2(entry, ":")
+ if ok && s1 == "ext" {
+ return LoadFromFile(m, s2, processAttr)
}
+ return LoadFromText(m, entry, processAttr)
+}
+// BatchLoadMatcher loads multiple files using Load.
+func BatchLoadMatcher(m Matcher, f []string, processAttr ProcessAttrFunc) error {
+ for _, file := range f {
+ err := Load(m, file, processAttr)
+ if err != nil {
+ return fmt.Errorf("failed to load file %s: %w", file, err)
+ }
+ }
return nil
}
-// LoadFromFileAsV2Matcher loads data from a file.
+// LoadFromFile loads data from a file.
// v2ray data file can also have multiple @attr. e.g. 'geosite.dat:cn@attr1@attr2'.
// Only the record with all of the @attr will be loaded.
-// Also see LoadFromFile.
-func LoadFromFileAsV2Matcher(m Matcher, file string) error {
+func LoadFromFile(m Matcher, file string, processAttr ProcessAttrFunc) error {
var err error
if tmp := strings.SplitN(file, ":", 2); len(tmp) == 2 { // is a v2ray data file
mixMatcher, ok := m.(*MixMatcher)
@@ -80,12 +78,19 @@ func LoadFromFileAsV2Matcher(m Matcher, file string) error {
tmp := strings.Split(tmp[1], "@")
countryCode := tmp[0]
wantedAttr := tmp[1:]
- processAttr := func(attr []string) (v interface{}, accept bool, err error) {
- return nil, mustHaveAttr(attr, wantedAttr), nil
+ v2ProcessAttr := func(attr []string) (v interface{}, accept bool, err error) {
+ v2Accept := mustHaveAttr(attr, wantedAttr)
+ if v2Accept {
+ if processAttr != nil {
+ return processAttr(attr)
+ }
+ return nil, true, nil
+ }
+ return nil, false, nil
}
- err = mixMatcher.LoadFromDAT(filePath, countryCode, processAttr)
+ err = mixMatcher.LoadFromDAT(filePath, countryCode, v2ProcessAttr)
} else { // is a text file
- err = LoadFromTextFile(m, file, nil)
+ err = LoadFromTextFile(m, file, processAttr)
}
if err != nil {
return err
@@ -94,28 +99,6 @@ func LoadFromFileAsV2Matcher(m Matcher, file string) error {
return nil
}
-// BatchLoadMatcher loads multiple files using LoadFromFile
-func BatchLoadMatcher(m Matcher, f []string, processAttr ProcessAttrFunc) error {
- for _, file := range f {
- err := LoadFromFile(m, file, processAttr)
- if err != nil {
- return fmt.Errorf("failed to load file %s: %w", file, err)
- }
- }
- return nil
-}
-
-// BatchLoadMixMatcherV2Matcher loads multiple files using LoadFromFileAsV2Matcher
-func BatchLoadMixMatcherV2Matcher(m Matcher, f []string) error {
- for _, file := range f {
- err := LoadFromFileAsV2Matcher(m, file)
- if err != nil {
- return fmt.Errorf("failed to load file %s: %w", file, err)
- }
- }
- return nil
-}
-
func LoadFromTextFile(m Matcher, file string, processAttr ProcessAttrFunc) error {
data, err := ioutil.ReadFile(file)
if err != nil {
diff --git a/dispatcher/matcher/domain/load_helper_test.go b/dispatcher/pkg/matcher/domain/load_helper_test.go
similarity index 100%
rename from dispatcher/matcher/domain/load_helper_test.go
rename to dispatcher/pkg/matcher/domain/load_helper_test.go
diff --git a/dispatcher/matcher/domain/matcher.go b/dispatcher/pkg/matcher/domain/matcher.go
similarity index 99%
rename from dispatcher/matcher/domain/matcher.go
rename to dispatcher/pkg/matcher/domain/matcher.go
index a0b8bbb67..d1227977a 100644
--- a/dispatcher/matcher/domain/matcher.go
+++ b/dispatcher/pkg/matcher/domain/matcher.go
@@ -19,7 +19,7 @@ package domain
import (
"fmt"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils"
"github.com/miekg/dns"
"regexp"
"strings"
diff --git a/dispatcher/matcher/domain/matcher_test.go b/dispatcher/pkg/matcher/domain/matcher_test.go
similarity index 100%
rename from dispatcher/matcher/domain/matcher_test.go
rename to dispatcher/pkg/matcher/domain/matcher_test.go
diff --git a/dispatcher/matcher/elem/rr_type.go b/dispatcher/pkg/matcher/elem/rr_type.go
similarity index 100%
rename from dispatcher/matcher/elem/rr_type.go
rename to dispatcher/pkg/matcher/elem/rr_type.go
diff --git a/dispatcher/matcher/elem/rr_type_test.go b/dispatcher/pkg/matcher/elem/rr_type_test.go
similarity index 100%
rename from dispatcher/matcher/elem/rr_type_test.go
rename to dispatcher/pkg/matcher/elem/rr_type_test.go
diff --git a/dispatcher/matcher/netlist/interface.go b/dispatcher/pkg/matcher/netlist/interface.go
similarity index 100%
rename from dispatcher/matcher/netlist/interface.go
rename to dispatcher/pkg/matcher/netlist/interface.go
diff --git a/dispatcher/matcher/netlist/list.go b/dispatcher/pkg/matcher/netlist/list.go
similarity index 100%
rename from dispatcher/matcher/netlist/list.go
rename to dispatcher/pkg/matcher/netlist/list.go
diff --git a/dispatcher/matcher/netlist/load_helper.go b/dispatcher/pkg/matcher/netlist/load_helper.go
similarity index 50%
rename from dispatcher/matcher/netlist/load_helper.go
rename to dispatcher/pkg/matcher/netlist/load_helper.go
index b75fdb193..c77f64fa8 100644
--- a/dispatcher/matcher/netlist/load_helper.go
+++ b/dispatcher/pkg/matcher/netlist/load_helper.go
@@ -20,67 +20,49 @@ package netlist
import (
"bufio"
"bytes"
- "errors"
"fmt"
- "github.com/IrineSistiana/mosdns/dispatcher/matcher/v2data"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/load_cache"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/v2data"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils"
"github.com/golang/protobuf/proto"
"io"
"io/ioutil"
- "net"
"strings"
"time"
)
-var matcherCache = utils.NewCache()
+var matcherCache = load_cache.NewCache()
const (
cacheTTL = time.Second * 30
)
-type MatcherGroup struct {
- m []Matcher
-}
-
-func (mg *MatcherGroup) Match(ip net.IP) bool {
- for _, m := range mg.m {
- if m.Match(ip) {
- return true
+// BatchLoad is a helper func to load multiple files using Load.
+// It might modify the List and causes List unsorted.
+func BatchLoad(l *List, entries []string) error {
+ for _, file := range entries {
+ err := Load(l, file)
+ if err != nil {
+ return fmt.Errorf("failed to load ip file %s: %w", file, err)
}
}
- return false
+ return nil
}
-func NewMatcherGroup(m []Matcher) *MatcherGroup {
- return &MatcherGroup{m: m}
-}
-
-// BatchLoad is helper func to load multiple files using NewListFromFile.
-func BatchLoad(f []string) (m *List, err error) {
- if len(f) == 0 {
- return nil, errors.New("no file to load")
- }
-
- if len(f) == 1 {
- return NewListFromFile(f[0])
+// Load loads data from entry.
+// If entry begin with "ext:", Load loads the file by using LoadFromFile.
+// Else it loads the entry as a text pattern by using LoadFromText.
+func Load(l *List, entry string) error {
+ s1, s2, ok := utils.SplitString2(entry, ":")
+ if ok && s1 == "ext" {
+ return LoadFromFile(l, s2)
}
-
- list := NewList()
- for _, file := range f {
- l, err := NewListFromFile(file)
- if err != nil {
- return nil, fmt.Errorf("failed to load ip file %s: %w", file, err)
- }
- list.Merge(l)
- }
-
- list.Sort()
- return list, nil
+ return LoadFromText(l, entry)
}
-//NewListFromReader read IP list from a reader. The returned *List is sorted.
-func NewListFromReader(reader io.Reader) (*List, error) {
- ipNetList := NewList()
+// LoadFromReader loads IP list from a reader.
+// It might modify the List and causes List unsorted.
+func LoadFromReader(l *List, reader io.Reader) error {
scanner := bufio.NewScanner(reader)
//count how many lines we have read.
@@ -88,69 +70,76 @@ func NewListFromReader(reader io.Reader) (*List, error) {
for scanner.Scan() {
lineCounter++
- s := strings.TrimSpace(utils.BytesToStringUnsafe(scanner.Bytes()))
- s = utils.RemoveComment(s, "#")
- s = utils.RemoveComment(s, " ") // remove other strings, e.g. 192.168.1.1 str1 str2
-
- if len(s) == 0 {
- continue
- }
-
- ipNet, err := ParseCIDR(s)
+ s := utils.BytesToStringUnsafe(scanner.Bytes())
+ err := LoadFromText(l, s)
if err != nil {
- return nil, fmt.Errorf("invalid CIDR format %s in line %d", s, lineCounter)
+ return fmt.Errorf("invalid data at line #%d: %w", lineCounter, err)
}
+ }
+
+ return nil
+}
+
+// LoadFromText loads an IP from s.
+// It might modify the List and causes List unsorted.
+func LoadFromText(l *List, s string) error {
+ s = strings.TrimSpace(s)
+ s = utils.RemoveComment(s, "#")
+ s = utils.RemoveComment(s, " ") // remove other strings, e.g. 192.168.1.1 str1 str2
- ipNetList.Append(ipNet)
+ if len(s) == 0 {
+ return nil
}
- ipNetList.Sort()
- return ipNetList, nil
+ ipNet, err := ParseCIDR(s)
+ if err != nil {
+ return err
+ }
+ l.Append(ipNet)
+ return nil
}
-// NewListFromFile loads ip from a text file or a geoip file.
+// LoadFromFile loads ip from a text file or a geoip file.
// If file contains a ':' and has format like 'geoip:cn', it will be read as a geoip file.
-// The returned *List is already been sorted.
-func NewListFromFile(file string) (*List, error) {
+// It might modify the List and causes List unsorted.
+func LoadFromFile(l *List, file string) error {
if strings.Contains(file, ":") {
tmp := strings.SplitN(file, ":", 2)
- return NewListFromDAT(tmp[0], tmp[1]) // file and tag
+ return LoadFromDAT(l, tmp[0], tmp[1]) // file and tag
} else {
- return NewListFromTextFile(file)
+ return LoadFromTextFile(l, file)
}
}
-// NewListFromTextFile reads IP list from a text file.
-// The returned *List is already been sorted.
-func NewListFromTextFile(file string) (*List, error) {
+// LoadFromTextFile reads IP list from a text file.
+// It might modify the List and causes List unsorted.
+func LoadFromTextFile(l *List, file string) error {
b, err := ioutil.ReadFile(file)
if err != nil {
- return nil, err
+ return err
}
-
- return NewListFromReader(bytes.NewBuffer(b))
+ return LoadFromReader(l, bytes.NewBuffer(b))
}
-// NewListFromDAT loads ip from v2ray proto file.
-// The returned *List is already been sorted.
-func NewListFromDAT(file, tag string) (*List, error) {
+// LoadFromDAT loads ip from v2ray proto file.
+// It might modify the List and causes List unsorted.
+func LoadFromDAT(l *List, file, tag string) error {
geoIP, err := LoadGeoIPFromDAT(file, tag)
if err != nil {
- return nil, err
+ return err
}
- return NewListFromV2CIDR(geoIP.GetCidr())
+ return LoadFromV2CIDR(l, geoIP.GetCidr())
}
-// NewListFromV2CIDR loads ip from v2ray CIDR.
-// The returned *List is already been sorted.
-func NewListFromV2CIDR(cidr []*v2data.CIDR) (*List, error) {
- l := NewList()
- l.Grow(len(cidr))
+// LoadFromV2CIDR loads ip from v2ray CIDR.
+// It might modify the List and causes List unsorted.
+func LoadFromV2CIDR(l *List, cidr []*v2data.CIDR) error {
+ l.Grow(l.Len() + len(cidr))
for i, e := range cidr {
ipv6, err := Conv(e.Ip)
if err != nil {
- return nil, fmt.Errorf("invalid data ip at index #%d, %w", i, err)
+ return fmt.Errorf("invalid data ip at index #%d, %w", i, err)
}
switch len(e.Ip) {
case 4:
@@ -158,12 +147,10 @@ func NewListFromV2CIDR(cidr []*v2data.CIDR) (*List, error) {
case 16:
l.Append(NewNet(ipv6, uint(e.Prefix)))
default:
- return nil, fmt.Errorf("invalid cidr ip length at #%d", i)
+ return fmt.Errorf("invalid cidr ip length at #%d", i)
}
}
-
- l.Sort()
- return l, nil
+ return nil
}
func LoadGeoIPFromDAT(file, tag string) (*v2data.GeoIP, error) {
diff --git a/dispatcher/matcher/netlist/net.go b/dispatcher/pkg/matcher/netlist/net.go
similarity index 98%
rename from dispatcher/matcher/netlist/net.go
rename to dispatcher/pkg/matcher/netlist/net.go
index bfcbd6706..19a1ae522 100644
--- a/dispatcher/matcher/netlist/net.go
+++ b/dispatcher/pkg/matcher/netlist/net.go
@@ -21,7 +21,7 @@ import (
"encoding/binary"
"errors"
"fmt"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils"
"net"
"strconv"
)
diff --git a/dispatcher/matcher/netlist/netlist_test.go b/dispatcher/pkg/matcher/netlist/netlist_test.go
similarity index 96%
rename from dispatcher/matcher/netlist/netlist_test.go
rename to dispatcher/pkg/matcher/netlist/netlist_test.go
index 3e0f26e36..ecf830813 100644
--- a/dispatcher/matcher/netlist/netlist_test.go
+++ b/dispatcher/pkg/matcher/netlist/netlist_test.go
@@ -55,10 +55,12 @@ var (
)
func TestIPNetList_New_And_Contains(t *testing.T) {
- ipNetList, err := NewListFromReader(bytes.NewBufferString(rawList))
+ ipNetList := NewList()
+ err := LoadFromReader(ipNetList, bytes.NewBufferString(rawList))
if err != nil {
t.Fatal(err)
}
+ ipNetList.Sort()
if ipNetList.Len() != 18 {
t.Fatalf("unexpected length %d", ipNetList.Len())
diff --git a/dispatcher/matcher/v2data/data.pb.go b/dispatcher/pkg/matcher/v2data/data.pb.go
similarity index 100%
rename from dispatcher/matcher/v2data/data.pb.go
rename to dispatcher/pkg/matcher/v2data/data.pb.go
diff --git a/dispatcher/matcher/v2data/data.proto b/dispatcher/pkg/matcher/v2data/data.proto
similarity index 100%
rename from dispatcher/matcher/v2data/data.proto
rename to dispatcher/pkg/matcher/v2data/data.proto
diff --git a/dispatcher/pkg/pool/bytes_buf.go b/dispatcher/pkg/pool/bytes_buf.go
new file mode 100644
index 000000000..4d96fd7be
--- /dev/null
+++ b/dispatcher/pkg/pool/bytes_buf.go
@@ -0,0 +1,51 @@
+// Copyright (C) 2020-2021, IrineSistiana
+//
+// This file is part of mosdns.
+//
+// mosdns is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// mosdns is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see .
+
+package pool
+
+import (
+ "bytes"
+ "fmt"
+ "sync"
+)
+
+type BytesBufPool struct {
+ p sync.Pool
+}
+
+func NewBytesBufPool(initSize int) *BytesBufPool {
+ if initSize < 0 {
+ panic(fmt.Sprintf("utils.NewBytesBufPool: negative init size %d", initSize))
+ }
+
+ return &BytesBufPool{
+ p: sync.Pool{New: func() interface{} {
+ b := new(bytes.Buffer)
+ b.Grow(initSize)
+ return b
+ }},
+ }
+}
+
+func (p *BytesBufPool) Get() *bytes.Buffer {
+ return p.p.Get().(*bytes.Buffer)
+}
+
+func (p *BytesBufPool) Release(b *bytes.Buffer) {
+ b.Reset()
+ p.p.Put(b)
+}
diff --git a/dispatcher/utils/msg_buf.go b/dispatcher/pkg/pool/msg_buf.go
similarity index 99%
rename from dispatcher/utils/msg_buf.go
rename to dispatcher/pkg/pool/msg_buf.go
index 9992ae801..d67b53c9d 100644
--- a/dispatcher/utils/msg_buf.go
+++ b/dispatcher/pkg/pool/msg_buf.go
@@ -18,7 +18,7 @@
// This file is a modified version from github.com/xtaci/smux/blob/master/alloc.go f386d90
// license of smux: MIT https://github.com/xtaci/smux/blob/master/LICENSE
-package utils
+package pool
import (
"fmt"
diff --git a/dispatcher/utils/timer.go b/dispatcher/pkg/pool/timer.go
similarity index 98%
rename from dispatcher/utils/timer.go
rename to dispatcher/pkg/pool/timer.go
index dfc4b13cc..83d60533e 100644
--- a/dispatcher/utils/timer.go
+++ b/dispatcher/pkg/pool/timer.go
@@ -15,7 +15,7 @@
// You should have received a copy of the GNU General Public License
// along with this program. If not, see .
-package utils
+package pool
import (
"sync"
diff --git a/dispatcher/utils/server_handler.go b/dispatcher/pkg/server_handler/server_handler.go
similarity index 75%
rename from dispatcher/utils/server_handler.go
rename to dispatcher/pkg/server_handler/server_handler.go
index ea16cb69c..e0de06372 100644
--- a/dispatcher/utils/server_handler.go
+++ b/dispatcher/pkg/server_handler/server_handler.go
@@ -15,11 +15,13 @@
// You should have received a copy of the GNU General Public License
// along with this program. If not, see .
-package utils
+package server_handler
import (
"context"
"github.com/IrineSistiana/mosdns/dispatcher/handler"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/concurrent_limiter"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/executable_seq"
"github.com/miekg/dns"
"go.uber.org/zap"
"testing"
@@ -38,15 +40,15 @@ type ResponseWriter interface {
type DefaultServerHandler struct {
config *DefaultServerHandlerConfig
- limiter *ConcurrentLimiter // if it's nil, means no limit.
- clientLimiter *ClientQueryLimiter // if it's nil, means no limit.
+ limiter *concurrent_limiter.ConcurrentLimiter // if it's nil, means no limit.
+ clientLimiter *concurrent_limiter.ClientQueryLimiter // if it's nil, means no limit.
}
type DefaultServerHandlerConfig struct {
// Logger is used for logging, it cannot be nil.
Logger *zap.Logger
// Entry is the entry ExecutablePlugin's tag. This shouldn't be empty.
- Entry *ExecutableCmdSequence
+ Entry *executable_seq.ExecutableCmdSequence
// ConcurrentLimit controls the max concurrent queries for the DefaultServerHandler.
// If ConcurrentLimit <= 0, means no limit.
// When calling DefaultServerHandler.ServeDNS(), if a query exceeds the limit, it will wait on a FIFO queue until
@@ -62,22 +64,22 @@ type DefaultServerHandlerConfig struct {
ConcurrentLimitPreClient int
}
-// NewDefaultServerHandler:
+// NewDefaultServerHandler
// Also see DefaultServerHandler.ServeDNS.
func NewDefaultServerHandler(config *DefaultServerHandlerConfig) *DefaultServerHandler {
h := &DefaultServerHandler{config: config}
if config.ConcurrentLimit > 0 {
- h.limiter = NewConcurrentLimiter(config.ConcurrentLimit)
+ h.limiter = concurrent_limiter.NewConcurrentLimiter(config.ConcurrentLimit)
}
if config.ConcurrentLimitPreClient > 0 {
- h.clientLimiter = NewClientQueryLimiter(config.ConcurrentLimitPreClient)
+ h.clientLimiter = concurrent_limiter.NewClientQueryLimiter(config.ConcurrentLimitPreClient)
}
return h
}
-// ServeDNS:
+// ServeDNS
// If entry returns an err, a SERVFAIL response will be sent back to client.
// If concurrentLimit is reached, the query will block and wait available token until ctx is done.
func (h *DefaultServerHandler) ServeDNS(ctx context.Context, qCtx *handler.Context, w ResponseWriter) {
@@ -135,7 +137,7 @@ func (h *DefaultServerHandler) ServeDNS(ctx context.Context, qCtx *handler.Conte
}
func (h *DefaultServerHandler) execEntry(ctx context.Context, qCtx *handler.Context) error {
- err := WalkExecutableCmd(ctx, qCtx, h.config.Logger, h.config.Entry)
+ err := executable_seq.WalkExecutableCmd(ctx, qCtx, h.config.Logger, h.config.Entry)
if err != nil {
return err
}
@@ -164,50 +166,3 @@ func (d *DummyServerHandler) ServeDNS(_ context.Context, qCtx *handler.Context,
d.T.Errorf("DummyServerHandler: %v", err)
}
}
-
-type ClientQueryLimiter struct {
- maxQueries int
- m *ConcurrentMap
-}
-
-func NewClientQueryLimiter(maxQueries int) *ClientQueryLimiter {
- return &ClientQueryLimiter{
- maxQueries: maxQueries,
- m: NewConcurrentMap(64),
- }
-}
-
-func (l *ClientQueryLimiter) Acquire(key string) bool {
- return l.m.TestAndSet(key, l.acquireTestAndSet)
-}
-
-func (l *ClientQueryLimiter) acquireTestAndSet(v interface{}, ok bool) (newV interface{}, wantUpdate, passed bool) {
- n := 0
- if ok {
- n = v.(int)
- }
- if n >= l.maxQueries {
- return nil, false, false
- }
- n++
- return n, true, true
-}
-
-func (l *ClientQueryLimiter) doneTestAndSet(v interface{}, ok bool) (newV interface{}, wantUpdate, passed bool) {
- if !ok {
- panic("ClientQueryLimiter doneTestAndSet: value is not exist")
- }
- n := v.(int)
- n--
- if n < 0 {
- panic("ClientQueryLimiter doneTestAndSet: value becomes negative")
- }
- if n == 0 {
- return nil, true, true
- }
- return n, true, true
-}
-
-func (l *ClientQueryLimiter) Done(key string) {
- l.m.TestAndSet(key, l.doneTestAndSet)
-}
diff --git a/dispatcher/utils/utils.go b/dispatcher/pkg/utils/utils.go
similarity index 76%
rename from dispatcher/utils/utils.go
rename to dispatcher/pkg/utils/utils.go
index 0564d27b9..eb6c65f74 100644
--- a/dispatcher/utils/utils.go
+++ b/dispatcher/pkg/utils/utils.go
@@ -18,7 +18,6 @@
package utils
import (
- "bytes"
"context"
"crypto/ecdsa"
"crypto/elliptic"
@@ -342,128 +341,3 @@ func ExchangeParallel(ctx context.Context, qCtx *handler.Context, upstreams []Up
// all upstreams are failed
return nil, errors.New("no response")
}
-
-// GetMinimalTTL returns the minimal ttl of this msg.
-// If msg m has no record, it returns 0.
-func GetMinimalTTL(m *dns.Msg) uint32 {
- minTTL := ^uint32(0)
- hasRecord := false
- for _, section := range [...][]dns.RR{m.Answer, m.Ns, m.Extra} {
- for _, rr := range section {
- if rr.Header().Rrtype == dns.TypeOPT {
- continue // opt record ttl is not ttl.
- }
- hasRecord = true
- ttl := rr.Header().Ttl
- if ttl < minTTL {
- minTTL = ttl
- }
- }
- }
-
- if !hasRecord { // no ttl applied
- return 0
- }
- return minTTL
-}
-
-// SetTTL updates all records' ttl to ttl, except opt record.
-func SetTTL(m *dns.Msg, ttl uint32) {
- for _, section := range [...][]dns.RR{m.Answer, m.Ns, m.Extra} {
- for _, rr := range section {
- if rr.Header().Rrtype == dns.TypeOPT {
- continue // opt record ttl is not ttl.
- }
- rr.Header().Ttl = ttl
- }
- }
-}
-
-func ApplyMaximumTTL(m *dns.Msg, ttl uint32) {
- applyTTL(m, ttl, true)
-}
-
-func ApplyMinimalTTL(m *dns.Msg, ttl uint32) {
- applyTTL(m, ttl, false)
-}
-
-func applyTTL(m *dns.Msg, ttl uint32, maximum bool) {
- for _, section := range [...][]dns.RR{m.Answer, m.Ns, m.Extra} {
- for _, rr := range section {
- if rr.Header().Rrtype == dns.TypeOPT {
- continue // opt record ttl is not ttl.
- }
- if maximum {
- if rr.Header().Ttl > ttl {
- rr.Header().Ttl = ttl
- }
- } else {
- if rr.Header().Ttl < ttl {
- rr.Header().Ttl = ttl
- }
- }
- }
- }
-}
-
-type BytesBufPool struct {
- p sync.Pool
-}
-
-func NewBytesBufPool(initSize int) *BytesBufPool {
- if initSize < 0 {
- panic(fmt.Sprintf("utils.NewBytesBufPool: negative init size %d", initSize))
- }
-
- return &BytesBufPool{
- p: sync.Pool{New: func() interface{} {
- b := new(bytes.Buffer)
- b.Grow(initSize)
- return b
- }},
- }
-}
-
-func (p *BytesBufPool) Get() *bytes.Buffer {
- return p.p.Get().(*bytes.Buffer)
-}
-
-func (p *BytesBufPool) Release(b *bytes.Buffer) {
- b.Reset()
- p.p.Put(b)
-}
-
-// ConcurrentLimiter
-type ConcurrentLimiter struct {
- bucket chan struct{}
-}
-
-// NewConcurrentLimiter returns a ConcurrentLimiter, max must > 0.
-func NewConcurrentLimiter(max int) *ConcurrentLimiter {
- if max <= 0 {
- panic(fmt.Sprintf("ConcurrentLimiter: invalid max arg: %d", max))
- }
-
- bucket := make(chan struct{}, max)
- for i := 0; i < max; i++ {
- bucket <- struct{}{}
- }
-
- return &ConcurrentLimiter{bucket: bucket}
-}
-
-func (l *ConcurrentLimiter) Wait() <-chan struct{} {
- return l.bucket
-}
-
-func (l *ConcurrentLimiter) Done() {
- select {
- case l.bucket <- struct{}{}:
- default:
- panic("ConcurrentLimiter: bucket overflow")
- }
-}
-
-func (l *ConcurrentLimiter) Available() int {
- return len(l.bucket)
-}
diff --git a/dispatcher/utils/utils_test.go b/dispatcher/pkg/utils/utils_test.go
similarity index 90%
rename from dispatcher/utils/utils_test.go
rename to dispatcher/pkg/utils/utils_test.go
index 599763b83..20b944168 100644
--- a/dispatcher/utils/utils_test.go
+++ b/dispatcher/pkg/utils/utils_test.go
@@ -21,11 +21,10 @@ import (
"context"
"errors"
"github.com/IrineSistiana/mosdns/dispatcher/handler"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/concurrent_limiter"
"github.com/miekg/dns"
"reflect"
- "sync"
"testing"
- "time"
)
func TestBoolLogic(t *testing.T) {
@@ -189,39 +188,13 @@ func Test_NewConcurrentLimiter(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- if got := NewConcurrentLimiter(tt.args.max); got.Available() != tt.wantLen {
+ if got := concurrent_limiter.NewConcurrentLimiter(tt.args.max); got.Available() != tt.wantLen {
t.Errorf("NewConcurrentLimiter() = %v, want %v", got.Available(), tt.wantLen)
}
})
}
}
-func Test_ConcurrentLimiter_acquire_release(t *testing.T) {
- l := NewConcurrentLimiter(500)
- ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
- defer cancel()
-
- wg := new(sync.WaitGroup)
- wg.Add(1000)
- for i := 0; i < 1000; i++ {
- go func() {
- defer wg.Done()
- select {
- case <-l.Wait():
- time.Sleep(time.Millisecond * 200)
- l.Done()
- case <-ctx.Done():
- t.Fail()
- }
- }()
- }
-
- wg.Wait()
- if l.Available() != 500 {
- t.Fatal("token leaked")
- }
-}
-
func TestRemoveComment(t *testing.T) {
type args struct {
s string
diff --git a/dispatcher/plugin/cache/mem_cache.go b/dispatcher/plugin/cache/mem_cache.go
index 58928f314..58dc071a3 100644
--- a/dispatcher/plugin/cache/mem_cache.go
+++ b/dispatcher/plugin/cache/mem_cache.go
@@ -19,7 +19,7 @@ package cache
import (
"context"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/concurrent_lru"
"github.com/miekg/dns"
"sync"
"time"
@@ -31,7 +31,7 @@ type memCache struct {
closeOnce sync.Once
closeChan chan struct{}
- lru *utils.ConcurrentLRU
+ lru *concurrent_lru.ConcurrentLRU
}
type elem struct {
@@ -45,7 +45,7 @@ type elem struct {
func newMemCache(shardNum, maxSizePerShard int, cleanerInterval time.Duration) *memCache {
c := &memCache{
cleanerInterval: cleanerInterval,
- lru: utils.NewConcurrentLRU(shardNum, maxSizePerShard, nil, nil),
+ lru: concurrent_lru.NewConcurrentLRU(shardNum, maxSizePerShard, nil, nil),
}
if c.cleanerInterval > 0 {
diff --git a/dispatcher/plugin/cache/plugin.go b/dispatcher/plugin/cache/plugin.go
index b77575757..20d274620 100644
--- a/dispatcher/plugin/cache/plugin.go
+++ b/dispatcher/plugin/cache/plugin.go
@@ -20,7 +20,8 @@ import (
"context"
"fmt"
"github.com/IrineSistiana/mosdns/dispatcher/handler"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/dnsutils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils"
"github.com/miekg/dns"
"go.uber.org/zap"
"time"
@@ -130,7 +131,7 @@ func (c *cachePlugin) searchAndReply(ctx context.Context, qCtx *handler.Context)
if r != nil { // if cache hit
c.L().Debug("cache hit", qCtx.InfoField())
r.Id = q.Id
- utils.SetTTL(r, uint32(ttl/time.Second))
+ dnsutils.SetTTL(r, uint32(ttl/time.Second))
qCtx.SetResponse(r, handler.ContextStatusResponded)
return key, true
}
@@ -158,7 +159,7 @@ func (d *deferCacheStore) Exec(ctx context.Context, qCtx *handler.Context) (err
func (d *deferCacheStore) exec(ctx context.Context, qCtx *handler.Context) (err error) {
r := qCtx.R()
if r != nil && r.Rcode == dns.RcodeSuccess && r.Truncated == false && len(r.Answer) != 0 {
- ttl := utils.GetMinimalTTL(r)
+ ttl := dnsutils.GetMinimalTTL(r)
if ttl > maxTTL {
ttl = maxTTL
}
diff --git a/dispatcher/plugin/cache/redis_cache.go b/dispatcher/plugin/cache/redis_cache.go
index aee5e7ea5..311f5773f 100644
--- a/dispatcher/plugin/cache/redis_cache.go
+++ b/dispatcher/plugin/cache/redis_cache.go
@@ -19,7 +19,7 @@ package cache
import (
"context"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/pool"
"github.com/go-redis/redis/v8"
"github.com/miekg/dns"
"time"
@@ -63,10 +63,10 @@ func (r *redisCache) get(ctx context.Context, key string) (v *dns.Msg, ttl time.
}
func (r *redisCache) store(ctx context.Context, key string, v *dns.Msg, ttl time.Duration) (err error) {
- wireMsg, buf, err := utils.PackBuffer(v)
+ wireMsg, buf, err := pool.PackBuffer(v)
if err != nil {
return err
}
- defer utils.ReleaseMsgBuf(buf)
+ defer pool.ReleaseMsgBuf(buf)
return r.client.Set(ctx, key, wireMsg, ttl).Err()
}
diff --git a/dispatcher/plugin/executable/ecs/ecs.go b/dispatcher/plugin/executable/ecs/ecs.go
index e371ba236..3e6e49ec6 100644
--- a/dispatcher/plugin/executable/ecs/ecs.go
+++ b/dispatcher/plugin/executable/ecs/ecs.go
@@ -21,7 +21,7 @@ import (
"context"
"fmt"
"github.com/IrineSistiana/mosdns/dispatcher/handler"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils"
"github.com/miekg/dns"
"go.uber.org/zap"
"net"
diff --git a/dispatcher/plugin/executable/ecs/ecs_test.go b/dispatcher/plugin/executable/ecs/ecs_test.go
index c1a93b59d..6cf518643 100644
--- a/dispatcher/plugin/executable/ecs/ecs_test.go
+++ b/dispatcher/plugin/executable/ecs/ecs_test.go
@@ -20,7 +20,7 @@ package ecs
import (
"context"
"github.com/IrineSistiana/mosdns/dispatcher/handler"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils"
"github.com/miekg/dns"
"gopkg.in/yaml.v3"
"net"
diff --git a/dispatcher/plugin/executable/fallback/fallback.go b/dispatcher/plugin/executable/fallback/fallback.go
index afc5e7f82..7b022f893 100644
--- a/dispatcher/plugin/executable/fallback/fallback.go
+++ b/dispatcher/plugin/executable/fallback/fallback.go
@@ -20,7 +20,7 @@ package fallback
import (
"context"
"github.com/IrineSistiana/mosdns/dispatcher/handler"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/executable_seq"
)
const PluginType = "fallback"
@@ -34,17 +34,17 @@ var _ handler.ExecutablePlugin = (*fallback)(nil)
type fallback struct {
*handler.BP
- fallbackECS *utils.FallbackECS
+ fallbackECS *executable_seq.FallbackECS
}
-type Args = utils.FallbackConfig
+type Args = executable_seq.FallbackConfig
func Init(bp *handler.BP, args interface{}) (p handler.Plugin, err error) {
return newFallback(bp, args.(*Args))
}
func newFallback(bp *handler.BP, args *Args) (*fallback, error) {
- fallbackECS, err := utils.ParseFallbackECS(args)
+ fallbackECS, err := executable_seq.ParseFallbackECS(args)
if err != nil {
return nil, err
}
@@ -55,5 +55,5 @@ func newFallback(bp *handler.BP, args *Args) (*fallback, error) {
}
func (f *fallback) Exec(ctx context.Context, qCtx *handler.Context) (err error) {
- return utils.WalkExecutableCmd(ctx, qCtx, f.L(), f.fallbackECS)
+ return executable_seq.WalkExecutableCmd(ctx, qCtx, f.L(), f.fallbackECS)
}
diff --git a/dispatcher/plugin/executable/fast_forward/fast_forward.go b/dispatcher/plugin/executable/fast_forward/fast_forward.go
index 004fd34be..f7dc8e5fe 100644
--- a/dispatcher/plugin/executable/fast_forward/fast_forward.go
+++ b/dispatcher/plugin/executable/fast_forward/fast_forward.go
@@ -23,7 +23,7 @@ import (
"errors"
"fmt"
"github.com/IrineSistiana/mosdns/dispatcher/handler"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils"
"github.com/miekg/dns"
"time"
)
diff --git a/dispatcher/plugin/executable/fast_forward/transport.go b/dispatcher/plugin/executable/fast_forward/transport.go
index a30d6c2be..6d0168051 100644
--- a/dispatcher/plugin/executable/fast_forward/transport.go
+++ b/dispatcher/plugin/executable/fast_forward/transport.go
@@ -20,7 +20,7 @@ package fastforward
import (
"errors"
"fmt"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/pool"
"github.com/miekg/dns"
"go.uber.org/zap"
"io"
@@ -143,8 +143,8 @@ func (t *transport) getConn() (conn *dnsConn, reusedConn bool, err error) {
panic("Transport getConn: dCall is nil")
}
- timer := utils.GetTimer(t.readTimeout)
- defer utils.ReleaseTimer(timer)
+ timer := pool.GetTimer(t.readTimeout)
+ defer pool.ReleaseTimer(timer)
select {
case <-timer.C:
return nil, false, errDialTimeout
@@ -234,8 +234,8 @@ func (c *dnsConn) exchange(m *dns.Msg) (r *dns.Msg, err error) {
return nil, err
}
- timer := utils.GetTimer(c.t.readTimeout)
- defer utils.ReleaseTimer(timer)
+ timer := pool.GetTimer(c.t.readTimeout)
+ defer pool.ReleaseTimer(timer)
select {
case <-timer.C:
diff --git a/dispatcher/plugin/executable/fast_forward/upstream.go b/dispatcher/plugin/executable/fast_forward/upstream.go
index 9da0fdefb..7e96d9958 100644
--- a/dispatcher/plugin/executable/fast_forward/upstream.go
+++ b/dispatcher/plugin/executable/fast_forward/upstream.go
@@ -24,7 +24,8 @@ import (
"errors"
"fmt"
"github.com/IrineSistiana/mosdns/dispatcher/handler"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/dnsutils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils"
"github.com/miekg/dns"
"go.uber.org/zap"
"golang.org/x/net/http2"
@@ -139,9 +140,9 @@ func newFastUpstream(config *UpstreamConfig, logger *zap.Logger, certPool *x509.
func() (net.Conn, error) {
return u.dialTimeout("udp", dialTimeout)
},
- utils.WriteMsgToUDP,
+ dnsutils.WriteMsgToUDP,
func(c io.Reader) (m *dns.Msg, n int, err error) {
- return utils.ReadMsgFromUDP(c, utils.IPv4UdpMaxPayload)
+ return dnsutils.ReadMsgFromUDP(c, dnsutils.IPv4UdpMaxPayload)
},
maxConn,
time.Second*30,
@@ -182,8 +183,8 @@ func newFastUpstream(config *UpstreamConfig, logger *zap.Logger, certPool *x509.
u.tcpTransport = newTransport(
logger,
dialFunc,
- utils.WriteMsgToTCP,
- utils.ReadMsgFromTCP,
+ dnsutils.WriteMsgToTCP,
+ dnsutils.ReadMsgFromTCP,
maxConn,
idleTimeout,
timeout,
diff --git a/dispatcher/plugin/executable/fast_forward/upstream_doh.go b/dispatcher/plugin/executable/fast_forward/upstream_doh.go
index e7714ae8e..76c0a4a45 100644
--- a/dispatcher/plugin/executable/fast_forward/upstream_doh.go
+++ b/dispatcher/plugin/executable/fast_forward/upstream_doh.go
@@ -21,23 +21,23 @@ import (
"context"
"encoding/base64"
"fmt"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/pool"
"github.com/miekg/dns"
"io"
"net/http"
)
var (
- bufPool512 = utils.NewBytesBufPool(512)
+ bufPool512 = pool.NewBytesBufPool(512)
)
func (u *fastUpstream) exchangeDoH(q *dns.Msg) (r *dns.Msg, err error) {
- rRaw, buf, err := utils.PackBuffer(q)
+ rRaw, buf, err := pool.PackBuffer(q)
if err != nil {
return nil, err
}
- defer utils.ReleaseMsgBuf(buf)
+ defer pool.ReleaseMsgBuf(buf)
// In order to maximize HTTP cache friendliness, DoH clients using media
// formats that include the ID field from the DNS message header, such
diff --git a/dispatcher/plugin/executable/fast_forward/upstream_test.go b/dispatcher/plugin/executable/fast_forward/upstream_test.go
index 17d42d5c0..a03b3da6f 100644
--- a/dispatcher/plugin/executable/fast_forward/upstream_test.go
+++ b/dispatcher/plugin/executable/fast_forward/upstream_test.go
@@ -23,7 +23,7 @@ import (
"fmt"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/IrineSistiana/mosdns/dispatcher/handler"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils"
"github.com/miekg/dns"
"go.uber.org/zap"
"net"
diff --git a/dispatcher/plugin/executable/forward/forward.go b/dispatcher/plugin/executable/forward/forward.go
index f78b0d0dc..44fe9efb4 100644
--- a/dispatcher/plugin/executable/forward/forward.go
+++ b/dispatcher/plugin/executable/forward/forward.go
@@ -24,7 +24,7 @@ import (
"github.com/AdguardTeam/dnsproxy/fastip"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/IrineSistiana/mosdns/dispatcher/handler"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils"
"github.com/miekg/dns"
"net"
"time"
diff --git a/dispatcher/plugin/executable/parallel/parallel.go b/dispatcher/plugin/executable/parallel/parallel.go
index 08c4cbbad..e654f6486 100644
--- a/dispatcher/plugin/executable/parallel/parallel.go
+++ b/dispatcher/plugin/executable/parallel/parallel.go
@@ -20,7 +20,7 @@ package parallel
import (
"context"
"github.com/IrineSistiana/mosdns/dispatcher/handler"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/executable_seq"
)
const PluginType = "parallel"
@@ -34,17 +34,17 @@ var _ handler.ExecutablePlugin = (*parallel)(nil)
type parallel struct {
*handler.BP
- ps *utils.ParallelECS
+ ps *executable_seq.ParallelECS
}
-type Args = utils.ParallelECSConfig
+type Args = executable_seq.ParallelECSConfig
func Init(bp *handler.BP, args interface{}) (p handler.Plugin, err error) {
return newParallel(bp, args.(*Args))
}
func newParallel(bp *handler.BP, args *Args) (*parallel, error) {
- ps, err := utils.ParseParallelECS(args)
+ ps, err := executable_seq.ParseParallelECS(args)
if err != nil {
return nil, err
}
@@ -56,5 +56,5 @@ func newParallel(bp *handler.BP, args *Args) (*parallel, error) {
}
func (p *parallel) Exec(ctx context.Context, qCtx *handler.Context) (err error) {
- return utils.WalkExecutableCmd(ctx, qCtx, p.L(), p.ps)
+ return executable_seq.WalkExecutableCmd(ctx, qCtx, p.L(), p.ps)
}
diff --git a/dispatcher/plugin/executable/sequence/sequence.go b/dispatcher/plugin/executable/sequence/sequence.go
index eb33b873f..b2cfffe97 100644
--- a/dispatcher/plugin/executable/sequence/sequence.go
+++ b/dispatcher/plugin/executable/sequence/sequence.go
@@ -21,7 +21,7 @@ import (
"context"
"fmt"
"github.com/IrineSistiana/mosdns/dispatcher/handler"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/executable_seq"
)
const PluginType = "sequence"
@@ -37,7 +37,7 @@ var _ handler.ExecutablePlugin = (*sequenceRouter)(nil)
type sequenceRouter struct {
*handler.BP
- ecs *utils.ExecutableCmdSequence
+ ecs *executable_seq.ExecutableCmdSequence
}
type Args struct {
@@ -49,7 +49,7 @@ func Init(bp *handler.BP, args interface{}) (p handler.Plugin, err error) {
}
func newSequencePlugin(bp *handler.BP, args *Args) (*sequenceRouter, error) {
- ecs, err := utils.ParseExecutableCmdSequence(args.Exec)
+ ecs, err := executable_seq.ParseExecutableCmdSequence(args.Exec)
if err != nil {
return nil, fmt.Errorf("invalid exec squence: %w", err)
}
@@ -61,7 +61,7 @@ func newSequencePlugin(bp *handler.BP, args *Args) (*sequenceRouter, error) {
}
func (s *sequenceRouter) Exec(ctx context.Context, qCtx *handler.Context) (err error) {
- return utils.WalkExecutableCmd(ctx, qCtx, s.L(), s.ecs)
+ return executable_seq.WalkExecutableCmd(ctx, qCtx, s.L(), s.ecs)
}
var _ handler.ExecutablePlugin = (*noop)(nil)
diff --git a/dispatcher/plugin/executable/sleep/sleep.go b/dispatcher/plugin/executable/sleep/sleep.go
index 803b1f603..c327cfa05 100644
--- a/dispatcher/plugin/executable/sleep/sleep.go
+++ b/dispatcher/plugin/executable/sleep/sleep.go
@@ -20,7 +20,7 @@ package sleep
import (
"context"
"github.com/IrineSistiana/mosdns/dispatcher/handler"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/pool"
"time"
)
@@ -51,8 +51,8 @@ func (s *sleep) sleep(ctx context.Context) (err error) {
return
}
- timer := utils.GetTimer(s.d)
- defer utils.ReleaseTimer(timer)
+ timer := pool.GetTimer(s.d)
+ defer pool.ReleaseTimer(timer)
select {
case <-timer.C:
case <-ctx.Done():
diff --git a/dispatcher/plugin/executable/ttl/ttl.go b/dispatcher/plugin/executable/ttl/ttl.go
index 55cad8ff8..09973da88 100644
--- a/dispatcher/plugin/executable/ttl/ttl.go
+++ b/dispatcher/plugin/executable/ttl/ttl.go
@@ -20,7 +20,7 @@ package ttl
import (
"context"
"github.com/IrineSistiana/mosdns/dispatcher/handler"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/dnsutils"
"github.com/miekg/dns"
)
@@ -64,9 +64,9 @@ func (t ttl) Exec(ctx context.Context, qCtx *handler.Context) (err error) {
func (t ttl) exec(r *dns.Msg) {
if t.args.MaximumTTL > 0 {
- utils.ApplyMaximumTTL(r, t.args.MaximumTTL)
+ dnsutils.ApplyMaximumTTL(r, t.args.MaximumTTL)
}
if t.args.MinimalTTL > 0 {
- utils.ApplyMinimalTTL(r, t.args.MinimalTTL)
+ dnsutils.ApplyMinimalTTL(r, t.args.MinimalTTL)
}
}
diff --git a/dispatcher/plugin/hosts/hosts.go b/dispatcher/plugin/hosts/hosts.go
index f6f718c74..2153d6c04 100644
--- a/dispatcher/plugin/hosts/hosts.go
+++ b/dispatcher/plugin/hosts/hosts.go
@@ -19,10 +19,9 @@ package hosts
import (
"context"
- "errors"
"fmt"
"github.com/IrineSistiana/mosdns/dispatcher/handler"
- "github.com/IrineSistiana/mosdns/dispatcher/matcher/domain"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/domain"
"github.com/miekg/dns"
"net"
)
@@ -59,10 +58,6 @@ var patternTypeMap = map[string]domain.MixMatcherPatternType{
}
func newHostsContainer(bp *handler.BP, args *Args) (*hostsContainer, error) {
- if len(args.Hosts) == 0 {
- return nil, errors.New("no hosts file is configured")
- }
-
mixMatcher := domain.NewMixMatcher()
mixMatcher.SetPattenTypeMap(patternTypeMap)
err := domain.BatchLoadMatcher(mixMatcher, args.Hosts, parseIP)
@@ -90,8 +85,14 @@ func (h *hostsContainer) matchAndSet(qCtx *handler.Context) (matched bool) {
if len(qCtx.Q().Question) != 1 {
return false
}
-
+ if qCtx.Q().Question[0].Qclass != dns.ClassINET {
+ return false
+ }
typ := qCtx.Q().Question[0].Qtype
+ if typ != dns.TypeA && typ != dns.TypeAAAA {
+ return false
+ }
+
fqdn := qCtx.Q().Question[0].Name
v, ok := h.matcher.Match(fqdn)
if !ok {
@@ -99,50 +100,46 @@ func (h *hostsContainer) matchAndSet(qCtx *handler.Context) (matched bool) {
}
record := v.(*ipRecord)
- switch typ {
- case dns.TypeA:
- if len(record.ipv4) != 0 {
- r := new(dns.Msg)
- r.SetReply(qCtx.Q())
- for _, ip := range record.ipv4 {
- ipCopy := make(net.IP, len(ip))
- copy(ipCopy, ip)
- rr := &dns.A{
- Hdr: dns.RR_Header{
- Name: fqdn,
- Rrtype: dns.TypeA,
- Class: dns.ClassINET,
- Ttl: 3600,
- },
- A: ipCopy,
- }
- r.Answer = append(r.Answer, rr)
+ switch {
+ case typ == dns.TypeA && len(record.ipv4) > 0:
+ r := new(dns.Msg)
+ r.SetReply(qCtx.Q())
+ for _, ip := range record.ipv4 {
+ ipCopy := make(net.IP, len(ip))
+ copy(ipCopy, ip)
+ rr := &dns.A{
+ Hdr: dns.RR_Header{
+ Name: fqdn,
+ Rrtype: dns.TypeA,
+ Class: dns.ClassINET,
+ Ttl: 3600,
+ },
+ A: ipCopy,
}
- qCtx.SetResponse(r, handler.ContextStatusResponded)
- return true
+ r.Answer = append(r.Answer, rr)
}
-
- case dns.TypeAAAA:
- if len(record.ipv6) != 0 {
- r := new(dns.Msg)
- r.SetReply(qCtx.Q())
- for _, ip := range record.ipv6 {
- ipCopy := make(net.IP, len(ip))
- copy(ipCopy, ip)
- rr := &dns.AAAA{
- Hdr: dns.RR_Header{
- Name: fqdn,
- Rrtype: dns.TypeAAAA,
- Class: dns.ClassINET,
- Ttl: 3600,
- },
- AAAA: ipCopy,
- }
- r.Answer = append(r.Answer, rr)
+ qCtx.SetResponse(r, handler.ContextStatusResponded)
+ return true
+
+ case typ == dns.TypeAAAA && len(record.ipv6) > 0:
+ r := new(dns.Msg)
+ r.SetReply(qCtx.Q())
+ for _, ip := range record.ipv6 {
+ ipCopy := make(net.IP, len(ip))
+ copy(ipCopy, ip)
+ rr := &dns.AAAA{
+ Hdr: dns.RR_Header{
+ Name: fqdn,
+ Rrtype: dns.TypeAAAA,
+ Class: dns.ClassINET,
+ Ttl: 3600,
+ },
+ AAAA: ipCopy,
}
- qCtx.SetResponse(r, handler.ContextStatusResponded)
- return true
+ r.Answer = append(r.Answer, rr)
}
+ qCtx.SetResponse(r, handler.ContextStatusResponded)
+ return true
}
return false
}
diff --git a/dispatcher/plugin/hosts/hosts_test.go b/dispatcher/plugin/hosts/hosts_test.go
index 5a30ac527..3678fdab4 100644
--- a/dispatcher/plugin/hosts/hosts_test.go
+++ b/dispatcher/plugin/hosts/hosts_test.go
@@ -20,7 +20,7 @@ package hosts
import (
"bytes"
"github.com/IrineSistiana/mosdns/dispatcher/handler"
- "github.com/IrineSistiana/mosdns/dispatcher/matcher/domain"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/domain"
"github.com/miekg/dns"
"net"
"testing"
diff --git a/dispatcher/plugin/matcher/query_matcher/matcher_group.go b/dispatcher/plugin/matcher/query_matcher/matcher_group.go
index 152757680..513d8ac4b 100644
--- a/dispatcher/plugin/matcher/query_matcher/matcher_group.go
+++ b/dispatcher/plugin/matcher/query_matcher/matcher_group.go
@@ -21,10 +21,10 @@ import (
"context"
"fmt"
"github.com/IrineSistiana/mosdns/dispatcher/handler"
- "github.com/IrineSistiana/mosdns/dispatcher/matcher/domain"
- "github.com/IrineSistiana/mosdns/dispatcher/matcher/elem"
- "github.com/IrineSistiana/mosdns/dispatcher/matcher/netlist"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/domain"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/elem"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/netlist"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils"
)
type clientIPMatcher struct {
diff --git a/dispatcher/plugin/matcher/query_matcher/query_matcher.go b/dispatcher/plugin/matcher/query_matcher/query_matcher.go
index 61cee17de..3190307d4 100644
--- a/dispatcher/plugin/matcher/query_matcher/query_matcher.go
+++ b/dispatcher/plugin/matcher/query_matcher/query_matcher.go
@@ -21,10 +21,10 @@ import (
"context"
"fmt"
"github.com/IrineSistiana/mosdns/dispatcher/handler"
- "github.com/IrineSistiana/mosdns/dispatcher/matcher/domain"
- "github.com/IrineSistiana/mosdns/dispatcher/matcher/elem"
- "github.com/IrineSistiana/mosdns/dispatcher/matcher/netlist"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/domain"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/elem"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/netlist"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils"
"github.com/miekg/dns"
)
@@ -73,15 +73,17 @@ func newQueryMatcher(bp *handler.BP, args *Args) (m *queryMatcher, err error) {
m.args = args
if len(args.ClientIP) > 0 {
- ipMatcher, err := netlist.BatchLoad(args.ClientIP)
+ ipMatcher := netlist.NewList()
+ err := netlist.BatchLoad(ipMatcher, args.ClientIP)
if err != nil {
return nil, err
}
+ ipMatcher.Sort()
m.matcherGroup = append(m.matcherGroup, newClientIPMatcher(ipMatcher))
}
if len(args.Domain) > 0 {
mixMatcher := domain.NewMixMatcher()
- err := domain.BatchLoadMixMatcherV2Matcher(mixMatcher, args.Domain)
+ err := domain.BatchLoadMatcher(mixMatcher, args.Domain, nil)
if err != nil {
return nil, err
}
diff --git a/dispatcher/plugin/matcher/response_matcher/matcher_group.go b/dispatcher/plugin/matcher/response_matcher/matcher_group.go
index 4609f6b53..a1e2036ce 100644
--- a/dispatcher/plugin/matcher/response_matcher/matcher_group.go
+++ b/dispatcher/plugin/matcher/response_matcher/matcher_group.go
@@ -20,9 +20,9 @@ package responsematcher
import (
"context"
"github.com/IrineSistiana/mosdns/dispatcher/handler"
- "github.com/IrineSistiana/mosdns/dispatcher/matcher/domain"
- "github.com/IrineSistiana/mosdns/dispatcher/matcher/elem"
- "github.com/IrineSistiana/mosdns/dispatcher/matcher/netlist"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/domain"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/elem"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/netlist"
"github.com/miekg/dns"
"net"
)
diff --git a/dispatcher/plugin/matcher/response_matcher/response_matcher.go b/dispatcher/plugin/matcher/response_matcher/response_matcher.go
index 1b4a33737..cd2c0a388 100644
--- a/dispatcher/plugin/matcher/response_matcher/response_matcher.go
+++ b/dispatcher/plugin/matcher/response_matcher/response_matcher.go
@@ -21,10 +21,10 @@ import (
"context"
"fmt"
"github.com/IrineSistiana/mosdns/dispatcher/handler"
- "github.com/IrineSistiana/mosdns/dispatcher/matcher/domain"
- "github.com/IrineSistiana/mosdns/dispatcher/matcher/elem"
- "github.com/IrineSistiana/mosdns/dispatcher/matcher/netlist"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/domain"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/elem"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/netlist"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils"
"github.com/miekg/dns"
)
@@ -71,7 +71,7 @@ func newResponseMatcher(bp *handler.BP, args *Args) (m *responseMatcher, err err
if len(args.CNAME) > 0 {
mixMatcher := domain.NewMixMatcher()
- err := domain.BatchLoadMixMatcherV2Matcher(mixMatcher, args.CNAME)
+ err := domain.BatchLoadMatcher(mixMatcher, args.CNAME, nil)
if err != nil {
return nil, err
}
@@ -79,10 +79,12 @@ func newResponseMatcher(bp *handler.BP, args *Args) (m *responseMatcher, err err
}
if len(args.IP) > 0 {
- ipMatcher, err := netlist.BatchLoad(args.IP)
+ ipMatcher := netlist.NewList()
+ err := netlist.BatchLoad(ipMatcher, args.IP)
if err != nil {
return nil, err
}
+ ipMatcher.Sort()
m.matcherGroup = append(m.matcherGroup, newResponseIPMatcher(ipMatcher))
}
diff --git a/dispatcher/plugin/server/doh.go b/dispatcher/plugin/server/doh.go
index 2fef01c82..3e759d1f4 100644
--- a/dispatcher/plugin/server/doh.go
+++ b/dispatcher/plugin/server/doh.go
@@ -23,7 +23,9 @@ import (
"errors"
"fmt"
"github.com/IrineSistiana/mosdns/dispatcher/handler"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/dnsutils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/pool"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils"
"github.com/miekg/dns"
"go.uber.org/zap"
"io"
@@ -123,8 +125,8 @@ func getMsgFromReq(req *http.Request) (*dns.Msg, error) {
if msgSize > dns.MaxMsgSize {
return nil, fmt.Errorf("query length %d is too big", msgSize)
}
- msgBuf := utils.GetMsgBuf(msgSize)
- defer utils.ReleaseMsgBuf(msgBuf)
+ msgBuf := pool.GetMsgBuf(msgSize)
+ defer pool.ReleaseMsgBuf(msgBuf)
strBuf := readBufPool.Get()
defer readBufPool.Release(strBuf)
@@ -155,12 +157,12 @@ func getMsgFromReq(req *http.Request) (*dns.Msg, error) {
return q, nil
}
-var readBufPool = utils.NewBytesBufPool(512)
+var readBufPool = pool.NewBytesBufPool(512)
type httpDnsRespWriter struct {
httpRespWriter http.ResponseWriter
}
func (h *httpDnsRespWriter) Write(m *dns.Msg) (n int, err error) {
- return utils.WriteMsgToUDP(h.httpRespWriter, m)
+ return dnsutils.WriteMsgToUDP(h.httpRespWriter, m)
}
diff --git a/dispatcher/plugin/server/server.go b/dispatcher/plugin/server/server.go
index c1e935929..816dbd051 100644
--- a/dispatcher/plugin/server/server.go
+++ b/dispatcher/plugin/server/server.go
@@ -21,7 +21,9 @@ import (
"errors"
"fmt"
"github.com/IrineSistiana/mosdns/dispatcher/handler"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/executable_seq"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/server_handler"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils"
"io"
"sync"
"time"
@@ -37,7 +39,7 @@ type ServerGroup struct {
*handler.BP
configs []*Server
- handler utils.ServerHandler
+ handler server_handler.ServerHandler
m sync.Mutex
activated bool
@@ -98,12 +100,12 @@ func newServerPlugin(bp *handler.BP, args *Args) (*ServerGroup, error) {
return nil, errors.New("empty entry")
}
- ecs, err := utils.ParseExecutableCmdSequence(args.Entry)
+ ecs, err := executable_seq.ParseExecutableCmdSequence(args.Entry)
if err != nil {
return nil, err
}
- sh := utils.NewDefaultServerHandler(&utils.DefaultServerHandlerConfig{
+ sh := server_handler.NewDefaultServerHandler(&server_handler.DefaultServerHandlerConfig{
Logger: bp.L(),
Entry: ecs,
ConcurrentLimit: args.MaxConcurrentQueries,
@@ -123,7 +125,7 @@ func newServerPlugin(bp *handler.BP, args *Args) (*ServerGroup, error) {
return sg, nil
}
-func NewServerGroup(bp *handler.BP, handler utils.ServerHandler, configs []*Server) *ServerGroup {
+func NewServerGroup(bp *handler.BP, handler server_handler.ServerHandler, configs []*Server) *ServerGroup {
s := &ServerGroup{
BP: bp,
configs: configs,
diff --git a/dispatcher/plugin/server/server_test.go b/dispatcher/plugin/server/server_test.go
index e98478200..499dd8b12 100644
--- a/dispatcher/plugin/server/server_test.go
+++ b/dispatcher/plugin/server/server_test.go
@@ -20,7 +20,7 @@ package server
import (
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/IrineSistiana/mosdns/dispatcher/handler"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/server_handler"
"github.com/miekg/dns"
"testing"
"time"
@@ -46,7 +46,7 @@ func TestUdpServer_ListenAndServe(t *testing.T) {
return
}
func() {
- sg := NewServerGroup(handler.NewBP("test", PluginType), &utils.DummyServerHandler{T: t}, []*Server{tt.config})
+ sg := NewServerGroup(handler.NewBP("test", PluginType), &server_handler.DummyServerHandler{T: t}, []*Server{tt.config})
if err := sg.Activate(); err != nil {
t.Fatal(err)
}
diff --git a/dispatcher/plugin/server/tcp.go b/dispatcher/plugin/server/tcp.go
index ba35d1c0b..1221f2c33 100644
--- a/dispatcher/plugin/server/tcp.go
+++ b/dispatcher/plugin/server/tcp.go
@@ -23,7 +23,7 @@ import (
"errors"
"fmt"
"github.com/IrineSistiana/mosdns/dispatcher/handler"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/dnsutils"
"github.com/miekg/dns"
"go.uber.org/zap"
"net"
@@ -40,7 +40,7 @@ type tcpResponseWriter struct {
func (t *tcpResponseWriter) Write(m *dns.Msg) (n int, err error) {
t.c.SetWriteDeadline(time.Now().Add(serverTCPWriteTimeout))
- return utils.WriteMsgToTCP(t.c, m)
+ return dnsutils.WriteMsgToTCP(t.c, m)
}
// remainder: startTCP should be called only after ServerGroup is locked.
@@ -93,7 +93,7 @@ func (sg *ServerGroup) startTCP(conf *Server, isDoT bool) error {
for {
c.SetReadDeadline(time.Now().Add(conf.idleTimeout))
- q, _, err := utils.ReadMsgFromTCP(c)
+ q, _, err := dnsutils.ReadMsgFromTCP(c)
if err != nil {
return // read err, close the conn
}
diff --git a/dispatcher/plugin/server/udp.go b/dispatcher/plugin/server/udp.go
index cab4ee0fa..f2753f051 100644
--- a/dispatcher/plugin/server/udp.go
+++ b/dispatcher/plugin/server/udp.go
@@ -21,7 +21,7 @@ import (
"context"
"fmt"
"github.com/IrineSistiana/mosdns/dispatcher/handler"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/dnsutils"
"github.com/miekg/dns"
"go.uber.org/zap"
"net"
@@ -44,7 +44,7 @@ func getMaxSizeFromQuery(m *dns.Msg) int {
func (u *udpResponseWriter) Write(m *dns.Msg) (n int, err error) {
m.Truncate(u.maxSize)
- return utils.WriteUDPMsgTo(m, u.c, u.to)
+ return dnsutils.WriteUDPMsgTo(m, u.c, u.to)
}
// remainder: startUDP should be called only after ServerGroup is locked.
@@ -63,9 +63,9 @@ func (sg *ServerGroup) startUDP(conf *Server) error {
defer cancel()
for {
- q, from, _, err := utils.ReadUDPMsgFrom(c, utils.IPv4UdpMaxPayload)
+ q, from, _, err := dnsutils.ReadUDPMsgFrom(c, dnsutils.IPv4UdpMaxPayload)
if err != nil {
- if ioErr := utils.IsIOErr(err); ioErr != nil {
+ if ioErr := dnsutils.IsIOErr(err); ioErr != nil {
if netErr, ok := ioErr.(net.Error); ok && netErr.Temporary() { // is a temporary net err
sg.L().Warn("listener temporary err", zap.Stringer("addr", c.LocalAddr()), zap.Error(err))
time.Sleep(time.Second * 5)
diff --git a/dispatcher/utils/executable_sequence.go b/dispatcher/utils/executable_sequence.go
deleted file mode 100644
index ca6a02b77..000000000
--- a/dispatcher/utils/executable_sequence.go
+++ /dev/null
@@ -1,631 +0,0 @@
-// Copyright (C) 2020-2021, IrineSistiana
-//
-// This file is part of mosdns.
-//
-// mosdns is free software: you can redistribute it and/or modify
-// it under the terms of the GNU General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// mosdns is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU General Public License for more details.
-//
-// You should have received a copy of the GNU General Public License
-// along with this program. If not, see .
-
-package utils
-
-import (
- "context"
- "errors"
- "fmt"
- "github.com/IrineSistiana/mosdns/dispatcher/handler"
- "github.com/miekg/dns"
- "go.uber.org/zap"
- "reflect"
- "strings"
- "sync"
- "time"
-)
-
-type ExecutableCmd interface {
- ExecCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (goTwo string, earlyStop bool, err error)
-}
-
-type executablePluginTag struct {
- s string
-}
-
-func (t executablePluginTag) ExecCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (goTwo string, earlyStop bool, err error) {
- p, err := handler.GetPlugin(t.s)
- if err != nil {
- return "", false, err
- }
-
- logger.Debug("exec executable plugin", qCtx.InfoField(), zap.String("exec", t.s))
- earlyStop, err = p.ExecES(ctx, qCtx)
- return "", earlyStop, err
-}
-
-type IfBlockConfig struct {
- If []string `yaml:"if"`
- IfAnd []string `yaml:"if_and"`
- Exec []interface{} `yaml:"exec"`
- Goto string `yaml:"goto"`
-}
-
-type matcher struct {
- tag string
- negate bool
-}
-
-func paresMatcher(s []string) []matcher {
- m := make([]matcher, 0, len(s))
- for _, tag := range s {
- if strings.HasPrefix(tag, "!") {
- m = append(m, matcher{tag: strings.TrimPrefix(tag, "!"), negate: true})
- } else {
- m = append(m, matcher{tag: tag})
- }
- }
- return m
-}
-
-type IfBlock struct {
- ifMatcher []matcher
- ifAndMatcher []matcher
- executableCmd ExecutableCmd
- goTwo string
-}
-
-func (b *IfBlock) ExecCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (goTwo string, earlyStop bool, err error) {
- if len(b.ifMatcher) > 0 {
- If, err := ifCondition(ctx, qCtx, logger, b.ifMatcher, false)
- if err != nil {
- return "", false, err
- }
- if If == false {
- return "", false, nil // if case returns false, skip this block.
- }
- }
-
- if len(b.ifAndMatcher) > 0 {
- If, err := ifCondition(ctx, qCtx, logger, b.ifAndMatcher, true)
- if err != nil {
- return "", false, err
- }
- if If == false {
- return "", false, nil
- }
- }
-
- // exec
- if b.executableCmd != nil {
- goTwo, earlyStop, err = b.executableCmd.ExecCmd(ctx, qCtx, logger)
- if err != nil {
- return "", false, err
- }
- if len(goTwo) != 0 || earlyStop {
- return goTwo, earlyStop, nil
- }
- }
-
- // goto
- if len(b.goTwo) != 0 { // if block has a goto, return it
- return b.goTwo, false, nil
- }
-
- return "", false, nil
-}
-
-func ifCondition(ctx context.Context, qCtx *handler.Context, logger *zap.Logger, p []matcher, isAnd bool) (ok bool, err error) {
- if len(p) == 0 {
- return false, err
- }
-
- for _, m := range p {
- mp, err := handler.GetPlugin(m.tag)
- if err != nil {
- return false, err
- }
- matched, err := mp.Match(ctx, qCtx)
- if err != nil {
- return false, err
- }
- logger.Debug("exec matcher plugin", qCtx.InfoField(), zap.String("exec", m.tag), zap.Bool("result", matched))
-
- res := matched != m.negate
- if !isAnd && res == true {
- return true, nil // or: if one of the case is true, skip others.
- }
- if isAnd && res == false {
- return false, nil // and: if one of the case is false, skip others.
- }
-
- ok = res
- }
- return ok, nil
-}
-
-func ParseIfBlock(in map[string]interface{}) (*IfBlock, error) {
- c := new(IfBlockConfig)
- err := handler.WeakDecode(in, c)
- if err != nil {
- return nil, err
- }
-
- b := &IfBlock{
- ifMatcher: paresMatcher(c.If),
- ifAndMatcher: paresMatcher(c.IfAnd),
- goTwo: c.Goto,
- }
-
- if len(c.Exec) != 0 {
- ecs, err := ParseExecutableCmdSequence(c.Exec)
- if err != nil {
- return nil, err
- }
- b.executableCmd = ecs
- }
-
- return b, nil
-}
-
-type ParallelECS struct {
- s []*ExecutableCmdSequence
- timeout time.Duration
-}
-
-type ParallelECSConfig struct {
- Parallel [][]interface{} `yaml:"parallel"`
- Timeout uint `yaml:"timeout"`
-}
-
-func ParseParallelECS(c *ParallelECSConfig) (*ParallelECS, error) {
- if len(c.Parallel) < 2 {
- return nil, fmt.Errorf("parallel needs at least 2 cmd sequences, but got %d", len(c.Parallel))
- }
-
- ps := make([]*ExecutableCmdSequence, 0, len(c.Parallel))
- for i, subSequence := range c.Parallel {
- es, err := ParseExecutableCmdSequence(subSequence)
- if err != nil {
- return nil, fmt.Errorf("invalid parallel sequence at index %d: %w", i, err)
- }
- ps = append(ps, es)
- }
- return &ParallelECS{s: ps, timeout: time.Duration(c.Timeout) * time.Second}, nil
-}
-
-type parallelECSResult struct {
- r *dns.Msg
- status handler.ContextStatus
- err error
- from int
-}
-
-func (p *ParallelECS) ExecCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (goTwo string, earlyStop bool, err error) {
- return "", false, p.execCmd(ctx, qCtx, logger)
-}
-
-func (p *ParallelECS) execCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) {
-
- var pCtx context.Context // only valid if p.timeout == 0
- var cancel func()
- if p.timeout == 0 {
- pCtx, cancel = context.WithCancel(ctx)
- defer cancel()
- }
-
- t := len(p.s)
- c := make(chan *parallelECSResult, len(p.s)) // use buf chan to avoid block.
-
- for i, sequence := range p.s {
- i := i
- sequence := sequence
- qCtxCopy := qCtx.Copy()
-
- go func() {
- var ecsCtx context.Context
- var ecsCancel func()
- if p.timeout == 0 {
- ecsCtx = pCtx
- } else {
- ecsCtx, ecsCancel = context.WithTimeout(context.Background(), p.timeout)
- defer ecsCancel()
- }
-
- err := WalkExecutableCmd(ecsCtx, qCtxCopy, logger, sequence)
- if err == nil {
- err = qCtxCopy.ExecDefer(pCtx)
- }
- c <- ¶llelECSResult{
- r: qCtxCopy.R(),
- status: qCtxCopy.Status(),
- err: err,
- from: i,
- }
- }()
- }
-
- return asyncWait(ctx, qCtx, logger, c, t)
-}
-
-type FallbackConfig struct {
- // Primary exec sequence, must have at least one element.
- Primary []interface{} `yaml:"primary"`
- // Secondary exec sequence, must have at least one element.
- Secondary []interface{} `yaml:"secondary"`
-
- StatLength int `yaml:"stat_length"` // An Zero value disables the (normal) fallback.
- Threshold int `yaml:"threshold"`
-
- // FastFallback threshold in milliseconds. Zero means fast fallback is disabled.
- FastFallback int `yaml:"fast_fallback"`
-
- // AlwaysStandby: secondary should always standby in fast fallback.
- AlwaysStandby bool `yaml:"always_standby"`
-}
-
-type FallbackECS struct {
- primary *ExecutableCmdSequence
- secondary *ExecutableCmdSequence
- fastFallbackDuration time.Duration
- alwaysStandby bool
-
- primaryST *statusTracker // nil if normal fallback is disabled
-}
-
-type statusTracker struct {
- sync.Mutex
- threshold int
- status []uint8 // 0 means success, !0 means failed
- p int
-}
-
-func newStatusTracker(threshold, statLength int) *statusTracker {
- return &statusTracker{
- threshold: threshold,
- status: make([]uint8, statLength),
- }
-}
-
-func (t *statusTracker) good() bool {
- t.Lock()
- defer t.Unlock()
-
- var failedSum int
- for _, s := range t.status {
- if s != 0 {
- failedSum++
- }
- }
- return failedSum < t.threshold
-}
-
-func (t *statusTracker) update(s uint8) {
- t.Lock()
- defer t.Unlock()
-
- if t.p >= len(t.status) {
- t.p = 0
- }
- t.status[t.p] = s
- t.p++
-}
-
-func ParseFallbackECS(c *FallbackConfig) (*FallbackECS, error) {
- if len(c.Primary) == 0 {
- return nil, errors.New("primary sequence is empty")
- }
- if len(c.Secondary) == 0 {
- return nil, errors.New("secondary sequence is empty")
- }
-
- primaryECS, err := ParseExecutableCmdSequence(c.Primary)
- if err != nil {
- return nil, fmt.Errorf("invalid primary sequence: %w", err)
- }
-
- secondaryECS, err := ParseExecutableCmdSequence(c.Secondary)
- if err != nil {
- return nil, fmt.Errorf("invalid secondary sequence: %w", err)
- }
-
- fallbackECS := &FallbackECS{
- primary: primaryECS,
- secondary: secondaryECS,
- fastFallbackDuration: time.Duration(c.FastFallback) * time.Millisecond,
- alwaysStandby: c.AlwaysStandby,
- }
-
- if c.StatLength > 0 {
- if c.Threshold > c.StatLength {
- c.Threshold = c.StatLength
- }
- fallbackECS.primaryST = newStatusTracker(c.Threshold, c.StatLength)
- }
-
- return fallbackECS, nil
-}
-
-func (f *FallbackECS) ExecCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (goTwo string, earlyStop bool, err error) {
- return "", false, f.execCmd(ctx, qCtx, logger)
-}
-
-func (f *FallbackECS) execCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) {
- if f.primaryST == nil || f.primaryST.good() {
- if f.fastFallbackDuration > 0 {
- return f.doFastFallback(ctx, qCtx, logger)
- } else {
- return f.isolateDoPrimary(ctx, qCtx, logger)
- }
- }
- logger.Debug("primary is not good", qCtx.InfoField())
- return f.doFallback(ctx, qCtx, logger)
-}
-
-func (f *FallbackECS) isolateDoPrimary(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) {
- qCtxCopy := qCtx.Copy()
- err = f.doPrimary(ctx, qCtxCopy, logger)
- qCtx.SetResponse(qCtxCopy.R(), qCtxCopy.Status())
- return err
-}
-
-func (f *FallbackECS) doPrimary(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) {
- err = WalkExecutableCmd(ctx, qCtx, logger, f.primary)
- if err == nil {
- err = qCtx.ExecDefer(ctx)
- }
- if f.primaryST != nil {
- if err != nil || qCtx.R() == nil {
- f.primaryST.update(1)
- } else {
- f.primaryST.update(0)
- }
- }
-
- return err
-}
-
-func (f *FallbackECS) doFastFallback(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) {
- fCtx, cancel := context.WithCancel(ctx)
- defer cancel()
-
- timer := GetTimer(f.fastFallbackDuration)
- defer ReleaseTimer(timer)
-
- c := make(chan *parallelECSResult, 2)
- primFailed := make(chan struct{}) // will be closed if primary returns an err.
-
- qCtxCopyP := qCtx.Copy()
- go func() {
- err := f.doPrimary(fCtx, qCtxCopyP, logger)
- if err != nil || qCtxCopyP.R() == nil {
- close(primFailed)
- }
- c <- ¶llelECSResult{
- r: qCtxCopyP.R(),
- status: qCtxCopyP.Status(),
- err: err,
- from: 1,
- }
- }()
-
- qCtxCopyS := qCtx.Copy()
- go func() {
- if !f.alwaysStandby { // not always standby, wait here.
- select {
- case <-fCtx.Done(): // primary is done, no needs to exec this.
- return
- case <-primFailed: // primary failed or timeout, exec now.
- case <-timer.C:
- }
- }
-
- err := f.doSecondary(fCtx, qCtxCopyS, logger)
- res := ¶llelECSResult{
- r: qCtxCopyS.R(),
- status: qCtxCopyS.Status(),
- err: err,
- from: 2,
- }
-
- if f.alwaysStandby { // always standby
- select {
- case <-fCtx.Done():
- return
- case <-primFailed: // only send secondary result when primary is failed.
- c <- res
- case <-timer.C: // or timeout.
- c <- res
- }
- } else {
- c <- res // not always standby, send the result asap.
- }
- }()
-
- return asyncWait(ctx, qCtx, logger, c, 2)
-}
-
-func (f *FallbackECS) doSecondary(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) {
- err = WalkExecutableCmd(ctx, qCtx, logger, f.secondary)
- if err == nil {
- err = qCtx.ExecDefer(ctx)
- }
- return err
-}
-
-func (f *FallbackECS) doFallback(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) error {
- fCtx, cancel := context.WithCancel(ctx)
- defer cancel()
-
- c := make(chan *parallelECSResult, 2) // buf size is 2, avoid block.
-
- qCtxCopyP := qCtx.Copy()
- go func() {
- err := f.doPrimary(fCtx, qCtxCopyP, logger)
- c <- ¶llelECSResult{
- r: qCtxCopyP.R(),
- status: qCtxCopyP.Status(),
- err: err,
- from: 1,
- }
- }()
-
- qCtxCopyS := qCtx.Copy()
- go func() {
- err := WalkExecutableCmd(fCtx, qCtxCopyS, logger, f.secondary)
- if err == nil {
- err = qCtxCopyS.ExecDefer(fCtx)
- }
- c <- ¶llelECSResult{
- r: qCtxCopyS.R(),
- status: qCtxCopyS.Status(),
- err: err,
- from: 2,
- }
- }()
-
- return asyncWait(ctx, qCtx, logger, c, 2)
-}
-
-func asyncWait(ctx context.Context, qCtx *handler.Context, logger *zap.Logger, c chan *parallelECSResult, total int) error {
- for i := 0; i < total; i++ {
- select {
- case res := <-c:
- if res.err != nil {
- logger.Warn("sequence failed", qCtx.InfoField(), zap.Int("sequence", res.from), zap.Error(res.err))
- continue
- }
-
- if res.r == nil {
- logger.Debug("sequence returned with an empty response", qCtx.InfoField(), zap.Int("sequence", res.from))
- continue
- }
-
- logger.Debug("sequence returned a response", qCtx.InfoField(), zap.Int("sequence", res.from))
- qCtx.SetResponse(res.r, res.status)
- return nil
-
- case <-ctx.Done():
- return ctx.Err()
- }
- }
-
- // No response
- qCtx.SetResponse(nil, handler.ContextStatusServerFailed)
- return errors.New("no response")
-}
-
-type ExecutableCmdSequence struct {
- c []ExecutableCmd
-}
-
-func ParseExecutableCmdSequence(in []interface{}) (*ExecutableCmdSequence, error) {
- es := &ExecutableCmdSequence{c: make([]ExecutableCmd, 0, len(in))}
- for i, v := range in {
- ec, err := parseExecutableCmd(v)
- if err != nil {
- return nil, fmt.Errorf("invalid cmd #%d: %w", i, err)
- }
- es.c = append(es.c, ec)
- }
- return es, nil
-}
-
-func parseExecutableCmd(in interface{}) (ExecutableCmd, error) {
- switch v := in.(type) {
- case string:
- return &executablePluginTag{s: v}, nil
- case map[string]interface{}:
- switch {
- case hasKey(v, "if") || hasKey(v, "if_and"): // if block
- ec, err := ParseIfBlock(v)
- if err != nil {
- return nil, fmt.Errorf("invalid if section: %w", err)
- }
- return ec, nil
- case hasKey(v, "parallel"): // parallel
- ec, err := parseParallelECS(v)
- if err != nil {
- return nil, fmt.Errorf("invalid parallel section: %w", err)
- }
- return ec, nil
- case hasKey(v, "primary"):
- ec, err := parseFallbackECS(v)
- if err != nil {
- return nil, fmt.Errorf("invalid fallback section: %w", err)
- }
- return ec, nil
- default:
- return nil, errors.New("unknown section")
- }
- default:
- return nil, fmt.Errorf("unexpected type: %s", reflect.TypeOf(in).String())
- }
-}
-
-func parseParallelECS(m map[string]interface{}) (ec ExecutableCmd, err error) {
- conf := new(ParallelECSConfig)
- err = handler.WeakDecode(m, conf)
- if err != nil {
- return nil, err
- }
- return ParseParallelECS(conf)
-}
-
-func parseFallbackECS(m map[string]interface{}) (ec ExecutableCmd, err error) {
- conf := new(FallbackConfig)
- err = handler.WeakDecode(m, conf)
- if err != nil {
- return nil, err
- }
- return ParseFallbackECS(conf)
-}
-
-func hasKey(m map[string]interface{}, key string) bool {
- _, ok := m[key]
- return ok
-}
-
-// ExecCmd executes the sequence.
-func (es *ExecutableCmdSequence) ExecCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (goTwo string, earlyStop bool, err error) {
- for _, cmd := range es.c {
- goTwo, earlyStop, err = cmd.ExecCmd(ctx, qCtx, logger)
- if err != nil {
- return "", false, err
- }
- if len(goTwo) != 0 || earlyStop {
- return goTwo, earlyStop, nil
- }
- }
-
- return "", false, nil
-}
-
-func (es *ExecutableCmdSequence) Len() int {
- return len(es.c)
-}
-
-// WalkExecutableCmd executes the ExecutableCmd, include its `goto`.
-// This should only be used in root cmd node.
-func WalkExecutableCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger, entry ExecutableCmd) (err error) {
- goTwo, _, err := entry.ExecCmd(ctx, qCtx, logger)
- if err != nil {
- return err
- }
-
- if len(goTwo) != 0 {
- logger.Debug("goto plugin", qCtx.InfoField(), zap.String("goto", goTwo))
- p, err := handler.GetPlugin(goTwo)
- if err != nil {
- return err
- }
- _, err = p.ExecES(ctx, qCtx)
- return err
- }
- return nil
-}
diff --git a/tools/tools.go b/tools/tools.go
index 77c0b95bf..ae77b926f 100644
--- a/tools/tools.go
+++ b/tools/tools.go
@@ -20,11 +20,11 @@ package tools
import (
"crypto/tls"
"fmt"
- "github.com/IrineSistiana/mosdns/dispatcher/matcher/domain"
- "github.com/IrineSistiana/mosdns/dispatcher/matcher/netlist"
- "github.com/IrineSistiana/mosdns/dispatcher/matcher/v2data"
"github.com/IrineSistiana/mosdns/dispatcher/mlog"
- "github.com/IrineSistiana/mosdns/dispatcher/utils"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/domain"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/netlist"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/matcher/v2data"
+ "github.com/IrineSistiana/mosdns/dispatcher/pkg/utils"
"github.com/miekg/dns"
"io"
"net"
@@ -113,10 +113,12 @@ func ProbServerTimeout(addr string) error {
}
func BenchIPMatcher(f string) error {
- list, err := netlist.NewListFromFile(f)
+ list := netlist.NewList()
+ err := netlist.LoadFromFile(list, f)
if err != nil {
return err
}
+ list.Sort()
ip := net.IPv4(8, 8, 8, 8).To4()
@@ -135,7 +137,7 @@ func BenchIPMatcher(f string) error {
func BenchDomainMatcher(f string) error {
matcher := domain.NewMixMatcher()
- err := domain.LoadFromFileAsV2Matcher(matcher, f)
+ err := domain.LoadFromFile(matcher, f, nil)
if err != nil {
return err
}