diff --git a/dispatcher/handler/plugin_wrapper.go b/dispatcher/handler/plugin_wrapper.go index fca264c45..87b7e03db 100644 --- a/dispatcher/handler/plugin_wrapper.go +++ b/dispatcher/handler/plugin_wrapper.go @@ -26,6 +26,7 @@ var ( _ ESExecutable = (*PluginWrapper)(nil) _ Matcher = (*PluginWrapper)(nil) _ ContextConnector = (*PluginWrapper)(nil) + _ Service = (*PluginWrapper)(nil) ) // PluginWrapper wraps the original plugin to avoid extremely frequently @@ -38,6 +39,7 @@ type PluginWrapper struct { se ESExecutable m Matcher cc ContextConnector + s Service } func newPluginWrapper(gp Plugin) *PluginWrapper { @@ -56,6 +58,9 @@ func newPluginWrapper(gp Plugin) *PluginWrapper { if cc, ok := gp.(ContextConnector); ok { w.cc = cc } + if s, ok := gp.(Service); ok { + w.s = s + } return w } @@ -116,12 +121,21 @@ func (w *PluginWrapper) ExecES(ctx context.Context, qCtx *Context) (earlyStop bo return earlyStop, nil } +func (w *PluginWrapper) Shutdown() error { + if w.s == nil { + return fmt.Errorf("plugin tag: %s, type: %s is not a Service", w.p.Tag(), w.p.Type()) + } + + return w.s.Shutdown() +} + type PluginInterfaceType uint8 const ( PITESExecutable = iota PITMatcher PITContextConnector + PITService ) func (w *PluginWrapper) Is(t PluginInterfaceType) bool { @@ -132,6 +146,8 @@ func (w *PluginWrapper) Is(t PluginInterfaceType) bool { return w.m != nil case PITContextConnector: return w.cc != nil + case PITService: + return w.s != nil default: panic(fmt.Sprintf("hander: invalid PluginInterfaceType: %d", t)) } diff --git a/dispatcher/handler/register.go b/dispatcher/handler/register.go index 7261fe2ec..ebf9d80a9 100644 --- a/dispatcher/handler/register.go +++ b/dispatcher/handler/register.go @@ -55,25 +55,21 @@ func newPluginRegister() *pluginRegister { // shutdown the old service, it will panic. func (r *pluginRegister) regPlugin(p Plugin, errIfDup bool) error { r.Lock() - defer r.Unlock() tag := p.Tag() oldWrapper, dup := r.register[tag] - if dup { - if errIfDup { - return fmt.Errorf("plugin tag %s has been registered", tag) - } - mlog.L().Info("overwrite plugin", zap.String("tag", tag)) - if service, ok := oldWrapper.GetPlugin().(ServicePlugin); ok { - mlog.L().Info("shutting down old service", zap.String("tag", tag)) - if err := service.Shutdown(); err != nil { - panic(fmt.Sprintf("service %s failed to shutdown: %v", tag, err)) - } - mlog.L().Info("old service exited", zap.String("tag", tag)) - } - } + if dup && errIfDup { + r.Unlock() + return fmt.Errorf("plugin tag %s has been registered", tag) + } r.register[tag] = newPluginWrapper(p) + r.Unlock() + + if dup { + mlog.L().Info("plugin overwritten", zap.String("tag", tag)) + r.tryShutdownService(oldWrapper) + } return nil } @@ -87,13 +83,38 @@ func (r *pluginRegister) getPlugin(tag string) (p *PluginWrapper, err error) { return p, nil } -func (r *pluginRegister) getAllPluginTag() []string { +func (r *pluginRegister) delPlugin(tag string) { + r.Lock() + p, ok := r.register[tag] + if !ok { + r.Unlock() + return + } + delete(r.register, tag) + r.Unlock() + + r.tryShutdownService(p) + return +} + +func (r *pluginRegister) tryShutdownService(oldWrapper *PluginWrapper) { + tag := oldWrapper.GetPlugin().Tag() + if oldWrapper.Is(PITService) { + mlog.L().Info("shutting down old service", zap.String("tag", tag)) + if err := oldWrapper.Shutdown(); err != nil { + panic(fmt.Sprintf("service %s failed to shutdown: %v", tag, err)) + } + mlog.L().Info("old service exited", zap.String("tag", tag)) + } +} + +func (r *pluginRegister) getPluginAll() []Plugin { r.RLock() defer r.RUnlock() - t := make([]string, 0, len(r.register)) - for tag := range r.register { - t = append(t, tag) + t := make([]Plugin, 0, len(r.register)) + for _, pw := range r.register { + t = append(t, pw.GetPlugin()) } return t } @@ -105,7 +126,7 @@ func (r *pluginRegister) purge() { } // RegInitFunc registers this plugin type. -// This should only be called in init() of the plugin package. +// This should only be called in init() from the plugin package. // Duplicate plugin types are not allowed. func RegInitFunc(pluginType string, initFunc NewPluginFunc, argsType NewArgsFunc) { _, ok := pluginTypeRegister[pluginType] @@ -118,15 +139,6 @@ func RegInitFunc(pluginType string, initFunc NewPluginFunc, argsType NewArgsFunc } } -// GetConfigurablePluginTypes returns all plugin types which are configurable. -func GetConfigurablePluginTypes() []string { - b := make([]string, 0, len(pluginTypeRegister)) - for typ := range pluginTypeRegister { - b = append(b, typ) - } - return b -} - // InitAndRegPlugin inits and registers this plugin globally. // This is a help func of NewPlugin + RegPlugin. func InitAndRegPlugin(c *Config, errIfDup bool) (err error) { @@ -178,8 +190,27 @@ func GetPlugin(tag string) (p *PluginWrapper, err error) { return pluginTagRegister.getPlugin(tag) } -func GetAllPluginTag() []string { - return pluginTagRegister.getAllPluginTag() +// DelPlugin deletes this plugin. +// If this plugin is a Service, DelPlugin will call Service.Shutdown(). +// DelPlugin will panic if Service.Shutdown() returns an err. +func DelPlugin(tag string) { + pluginTagRegister.delPlugin(tag) +} + +// GetPluginAll returns all registered plugins. +// This should only be used in test or debug. +func GetPluginAll() []Plugin { + return pluginTagRegister.getPluginAll() +} + +// GetConfigurablePluginTypes returns all plugin types which are configurable. +// This should only be used in test or debug. +func GetConfigurablePluginTypes() []string { + b := make([]string, 0, len(pluginTypeRegister)) + for typ := range pluginTypeRegister { + b = append(b, typ) + } + return b } // PurgePluginRegister should only be used in test. diff --git a/dispatcher/handler/register_test.go b/dispatcher/handler/register_test.go index e73507e0d..90b7233f8 100644 --- a/dispatcher/handler/register_test.go +++ b/dispatcher/handler/register_test.go @@ -56,3 +56,46 @@ func TestRegPlugin(t *testing.T) { }) } } + +func Test_pluginRegister_delPlugin(t *testing.T) { + tests := []struct { + name string + p Plugin + tag string + wantPanic bool + }{ + {"del matcher", &DummyMatcherPlugin{ + BP: NewBP("test", ""), + }, "test", false}, + {"del service", &DummyServicePlugin{ + BP: NewBP("test", ""), + }, "test", false}, + {"del service but panic", &DummyServicePlugin{ + BP: NewBP("test", ""), + WantShutdownErr: errors.New(""), + }, "test", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := newPluginRegister() + err := r.regPlugin(tt.p, true) + if err != nil { + t.Fatal(err) + } + if tt.wantPanic { + defer func() { + msg := recover() + if msg == nil { + t.Error("delPlugin not panic") + } + }() + } + + r.delPlugin(tt.tag) + if _, err := GetPlugin(tt.tag); err == nil { + t.Error("plugin is not deleted") + } + }) + + } +} diff --git a/dispatcher/matcher/domain/load_helper.go b/dispatcher/matcher/domain/load_helper.go index f9c69a033..01f669485 100644 --- a/dispatcher/matcher/domain/load_helper.go +++ b/dispatcher/matcher/domain/load_helper.go @@ -61,7 +61,7 @@ func (m *MixMatcher) LoadFormFile(file string, filterRecord FilterRecordFunc, pa return nil } -// LoadFormFileAsV2Matcher loads data from file. +// LoadFormFileAsV2Matcher loads data from a file. // File can be a text file or a v2ray data file. // v2ray data file needs to specify the data category by using ':', e.g. 'geosite:cn' // v2ray data file can also have multiple @attr. e.g. 'geosite:cn@attr1@attr2'. diff --git a/dispatcher/plugin/executable/ecs/ecs.go b/dispatcher/plugin/executable/ecs/ecs.go index 16f45595a..02535b66b 100644 --- a/dispatcher/plugin/executable/ecs/ecs.go +++ b/dispatcher/plugin/executable/ecs/ecs.go @@ -102,7 +102,7 @@ func newPlugin(bp *handler.BP, args *Args) (p handler.Plugin, err error) { return ep, nil } -// Do tries to append ECS to qCtx.Q(). +// Exec tries to append ECS to qCtx.Q(). // If an error occurred, Do will just log it. // Therefore, Do will never return an err. func (e ecsPlugin) Exec(_ context.Context, qCtx *handler.Context) (err error) { diff --git a/dispatcher/utils/utils.go b/dispatcher/utils/utils.go index c00f1ef49..7dc370ce8 100644 --- a/dispatcher/utils/utils.go +++ b/dispatcher/utils/utils.go @@ -141,7 +141,7 @@ func LoadCertPool(certs []string) (*x509.CertPool, error) { // GenerateCertificate generates a ecdsa certificate with given dnsName. // This should only use in test. func GenerateCertificate(dnsName string) (cert tls.Certificate, err error) { - key, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { return } @@ -261,18 +261,19 @@ func ExchangeParallel(ctx context.Context, qCtx *handler.Context, upstreams []Tr return nil, errors.New("no upstream is configured") } if t == 1 { - r, err = upstreams[0].Exchange(qCtx.Q()) + u := upstreams[0] + r, err = u.Exchange(qCtx.Q()) if err != nil { return nil, err } - logger.Debug("received response", qCtx.InfoField(), zap.String("from", upstreams[0].Address())) + logger.Debug("received response", qCtx.InfoField(), zap.String("from", u.Address())) return r, nil } c := make(chan *parallelResult, t) // use buf chan to avoid block. + qCopy := qCtx.Q().Copy() // qCtx is not safe for concurrent use. for _, u := range upstreams { u := u - qCopy := qCtx.Q().Copy() // it is not safe to use the Q directly. go func() { r, err := u.Exchange(qCopy) c <- ¶llelResult{