Skip to content

Commit

Permalink
squash v0.25.2
Browse files Browse the repository at this point in the history
  • Loading branch information
IrineSistiana committed Jan 13, 2021
1 parent 391e04b commit 91abc16
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 36 deletions.
16 changes: 16 additions & 0 deletions dispatcher/handler/plugin_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ var (
_ ESExecutable = (*PluginWrapper)(nil)
_ Matcher = (*PluginWrapper)(nil)
_ ContextConnector = (*PluginWrapper)(nil)
_ Service = (*PluginWrapper)(nil)
)

// PluginWrapper wraps the original plugin to avoid extremely frequently
Expand All @@ -38,6 +39,7 @@ type PluginWrapper struct {
se ESExecutable
m Matcher
cc ContextConnector
s Service
}

func newPluginWrapper(gp Plugin) *PluginWrapper {
Expand All @@ -56,6 +58,9 @@ func newPluginWrapper(gp Plugin) *PluginWrapper {
if cc, ok := gp.(ContextConnector); ok {
w.cc = cc
}
if s, ok := gp.(Service); ok {
w.s = s
}

return w
}
Expand Down Expand Up @@ -116,12 +121,21 @@ func (w *PluginWrapper) ExecES(ctx context.Context, qCtx *Context) (earlyStop bo
return earlyStop, nil
}

func (w *PluginWrapper) Shutdown() error {
if w.s == nil {
return fmt.Errorf("plugin tag: %s, type: %s is not a Service", w.p.Tag(), w.p.Type())
}

return w.s.Shutdown()
}

type PluginInterfaceType uint8

const (
PITESExecutable = iota
PITMatcher
PITContextConnector
PITService
)

func (w *PluginWrapper) Is(t PluginInterfaceType) bool {
Expand All @@ -132,6 +146,8 @@ func (w *PluginWrapper) Is(t PluginInterfaceType) bool {
return w.m != nil
case PITContextConnector:
return w.cc != nil
case PITService:
return w.s != nil
default:
panic(fmt.Sprintf("hander: invalid PluginInterfaceType: %d", t))
}
Expand Down
91 changes: 61 additions & 30 deletions dispatcher/handler/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,25 +55,21 @@ func newPluginRegister() *pluginRegister {
// shutdown the old service, it will panic.
func (r *pluginRegister) regPlugin(p Plugin, errIfDup bool) error {
r.Lock()
defer r.Unlock()

tag := p.Tag()
oldWrapper, dup := r.register[tag]
if dup {
if errIfDup {
return fmt.Errorf("plugin tag %s has been registered", tag)
}
mlog.L().Info("overwrite plugin", zap.String("tag", tag))
if service, ok := oldWrapper.GetPlugin().(ServicePlugin); ok {
mlog.L().Info("shutting down old service", zap.String("tag", tag))
if err := service.Shutdown(); err != nil {
panic(fmt.Sprintf("service %s failed to shutdown: %v", tag, err))
}
mlog.L().Info("old service exited", zap.String("tag", tag))
}
}

if dup && errIfDup {
r.Unlock()
return fmt.Errorf("plugin tag %s has been registered", tag)
}
r.register[tag] = newPluginWrapper(p)
r.Unlock()

if dup {
mlog.L().Info("plugin overwritten", zap.String("tag", tag))
r.tryShutdownService(oldWrapper)
}
return nil
}

Expand All @@ -87,13 +83,38 @@ func (r *pluginRegister) getPlugin(tag string) (p *PluginWrapper, err error) {
return p, nil
}

func (r *pluginRegister) getAllPluginTag() []string {
func (r *pluginRegister) delPlugin(tag string) {
r.Lock()
p, ok := r.register[tag]
if !ok {
r.Unlock()
return
}
delete(r.register, tag)
r.Unlock()

r.tryShutdownService(p)
return
}

func (r *pluginRegister) tryShutdownService(oldWrapper *PluginWrapper) {
tag := oldWrapper.GetPlugin().Tag()
if oldWrapper.Is(PITService) {
mlog.L().Info("shutting down old service", zap.String("tag", tag))
if err := oldWrapper.Shutdown(); err != nil {
panic(fmt.Sprintf("service %s failed to shutdown: %v", tag, err))
}
mlog.L().Info("old service exited", zap.String("tag", tag))
}
}

func (r *pluginRegister) getPluginAll() []Plugin {
r.RLock()
defer r.RUnlock()

t := make([]string, 0, len(r.register))
for tag := range r.register {
t = append(t, tag)
t := make([]Plugin, 0, len(r.register))
for _, pw := range r.register {
t = append(t, pw.GetPlugin())
}
return t
}
Expand All @@ -105,7 +126,7 @@ func (r *pluginRegister) purge() {
}

// RegInitFunc registers this plugin type.
// This should only be called in init() of the plugin package.
// This should only be called in init() from the plugin package.
// Duplicate plugin types are not allowed.
func RegInitFunc(pluginType string, initFunc NewPluginFunc, argsType NewArgsFunc) {
_, ok := pluginTypeRegister[pluginType]
Expand All @@ -118,15 +139,6 @@ func RegInitFunc(pluginType string, initFunc NewPluginFunc, argsType NewArgsFunc
}
}

// GetConfigurablePluginTypes returns all plugin types which are configurable.
func GetConfigurablePluginTypes() []string {
b := make([]string, 0, len(pluginTypeRegister))
for typ := range pluginTypeRegister {
b = append(b, typ)
}
return b
}

// InitAndRegPlugin inits and registers this plugin globally.
// This is a help func of NewPlugin + RegPlugin.
func InitAndRegPlugin(c *Config, errIfDup bool) (err error) {
Expand Down Expand Up @@ -178,8 +190,27 @@ func GetPlugin(tag string) (p *PluginWrapper, err error) {
return pluginTagRegister.getPlugin(tag)
}

func GetAllPluginTag() []string {
return pluginTagRegister.getAllPluginTag()
// DelPlugin deletes this plugin.
// If this plugin is a Service, DelPlugin will call Service.Shutdown().
// DelPlugin will panic if Service.Shutdown() returns an err.
func DelPlugin(tag string) {
pluginTagRegister.delPlugin(tag)
}

// GetPluginAll returns all registered plugins.
// This should only be used in test or debug.
func GetPluginAll() []Plugin {
return pluginTagRegister.getPluginAll()
}

// GetConfigurablePluginTypes returns all plugin types which are configurable.
// This should only be used in test or debug.
func GetConfigurablePluginTypes() []string {
b := make([]string, 0, len(pluginTypeRegister))
for typ := range pluginTypeRegister {
b = append(b, typ)
}
return b
}

// PurgePluginRegister should only be used in test.
Expand Down
43 changes: 43 additions & 0 deletions dispatcher/handler/register_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,46 @@ func TestRegPlugin(t *testing.T) {
})
}
}

func Test_pluginRegister_delPlugin(t *testing.T) {
tests := []struct {
name string
p Plugin
tag string
wantPanic bool
}{
{"del matcher", &DummyMatcherPlugin{
BP: NewBP("test", ""),
}, "test", false},
{"del service", &DummyServicePlugin{
BP: NewBP("test", ""),
}, "test", false},
{"del service but panic", &DummyServicePlugin{
BP: NewBP("test", ""),
WantShutdownErr: errors.New(""),
}, "test", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := newPluginRegister()
err := r.regPlugin(tt.p, true)
if err != nil {
t.Fatal(err)
}
if tt.wantPanic {
defer func() {
msg := recover()
if msg == nil {
t.Error("delPlugin not panic")
}
}()
}

r.delPlugin(tt.tag)
if _, err := GetPlugin(tt.tag); err == nil {
t.Error("plugin is not deleted")
}
})

}
}
2 changes: 1 addition & 1 deletion dispatcher/matcher/domain/load_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (m *MixMatcher) LoadFormFile(file string, filterRecord FilterRecordFunc, pa
return nil
}

// LoadFormFileAsV2Matcher loads data from file.
// LoadFormFileAsV2Matcher loads data from a file.
// File can be a text file or a v2ray data file.
// v2ray data file needs to specify the data category by using ':', e.g. 'geosite:cn'
// v2ray data file can also have multiple @attr. e.g. 'geosite:cn@attr1@attr2'.
Expand Down
2 changes: 1 addition & 1 deletion dispatcher/plugin/executable/ecs/ecs.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func newPlugin(bp *handler.BP, args *Args) (p handler.Plugin, err error) {
return ep, nil
}

// Do tries to append ECS to qCtx.Q().
// Exec tries to append ECS to qCtx.Q().
// If an error occurred, Do will just log it.
// Therefore, Do will never return an err.
func (e ecsPlugin) Exec(_ context.Context, qCtx *handler.Context) (err error) {
Expand Down
9 changes: 5 additions & 4 deletions dispatcher/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func LoadCertPool(certs []string) (*x509.CertPool, error) {
// GenerateCertificate generates a ecdsa certificate with given dnsName.
// This should only use in test.
func GenerateCertificate(dnsName string) (cert tls.Certificate, err error) {
key, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return
}
Expand Down Expand Up @@ -261,18 +261,19 @@ func ExchangeParallel(ctx context.Context, qCtx *handler.Context, upstreams []Tr
return nil, errors.New("no upstream is configured")
}
if t == 1 {
r, err = upstreams[0].Exchange(qCtx.Q())
u := upstreams[0]
r, err = u.Exchange(qCtx.Q())
if err != nil {
return nil, err
}
logger.Debug("received response", qCtx.InfoField(), zap.String("from", upstreams[0].Address()))
logger.Debug("received response", qCtx.InfoField(), zap.String("from", u.Address()))
return r, nil
}

c := make(chan *parallelResult, t) // use buf chan to avoid block.
qCopy := qCtx.Q().Copy() // qCtx is not safe for concurrent use.
for _, u := range upstreams {
u := u
qCopy := qCtx.Q().Copy() // it is not safe to use the Q directly.
go func() {
r, err := u.Exchange(qCopy)
c <- &parallelResult{
Expand Down

0 comments on commit 91abc16

Please sign in to comment.