diff --git a/cmd/vfkit/main.go b/cmd/vfkit/main.go index 3dbb6524..0cff38d3 100644 --- a/cmd/vfkit/main.go +++ b/cmd/vfkit/main.go @@ -37,6 +37,8 @@ import ( restvf "github.com/crc-org/vfkit/pkg/rest/vf" "github.com/crc-org/vfkit/pkg/vf" log "github.com/sirupsen/logrus" + + "github.com/crc-org/vfkit/pkg/util" ) func newLegacyBootloader(opts *cmdline.Options) config.Bootloader { @@ -121,6 +123,8 @@ func runVFKit(vmConfig *config.VirtualMachine, opts *cmdline.Options) error { runtime.LockOSThread() defer runtime.UnlockOSThread() + util.SetupExitSignalHandling() + gpuDevs := vmConfig.VirtioGPUDevices() if opts.UseGUI && len(gpuDevs) > 0 { gpuDevs[0].UsesGUI = true @@ -239,6 +243,11 @@ func startIgnitionProvisionerServer(ignitionReader io.Reader, ignitionSocketPath if err != nil { return err } + + util.RegisterExitHandler(func() { + os.Remove(ignitionSocketPath) + }) + defer func() { if err := listener.Close(); err != nil { log.Error(err) diff --git a/pkg/util/exithandler.go b/pkg/util/exithandler.go new file mode 100644 index 00000000..692eaaa3 --- /dev/null +++ b/pkg/util/exithandler.go @@ -0,0 +1,44 @@ +package util + +import ( + "log" + "os" + "os/signal" + "syscall" +) + +var exitHandlers []func() + +// RegisterExitHandler appends a func Exit handler to the list of handlers. +// The handlers will be invoked when vfkit receives a termination or interruption signal +// +// This method is useful when a caller wishes to execute a func before a shutdown. +func RegisterExitHandler(handler func()) { + exitHandlers = append(exitHandlers, handler) +} + +// SetupExitSignalHandling sets up a signal channel to listen for termination or interruption signals. +// When one of these signals is received, all the registered exit handlers will be invoked, just +// before terminating the program. +func SetupExitSignalHandling() { + setupExitSignalHandling(true) +} + +// setupExitSignalHandling sets up a signal channel to listen for termination or interruption signals. +// When one of these signals is received, all the registered exit handlers will be invoked. +// It is possible to prevent the program from exiting by setting the doExit param to false (used for testing) +func setupExitSignalHandling(doExit bool) { + sigChan := make(chan os.Signal, 2) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGINT) + go func() { + for sig := range sigChan { + log.Printf("captured %v, calling exit handlers and exiting..", sig) + for _, handler := range exitHandlers { + handler() + } + if doExit { + os.Exit(1) + } + } + }() +} diff --git a/pkg/util/exithandler_test.go b/pkg/util/exithandler_test.go new file mode 100644 index 00000000..97c92d58 --- /dev/null +++ b/pkg/util/exithandler_test.go @@ -0,0 +1,29 @@ +package util + +import ( + "syscall" + "testing" + "time" +) + +func TestExitHandlerCalled(t *testing.T) { + setupExitSignalHandling(false) + + ch := make(chan struct{}) + RegisterExitHandler(func() { + close(ch) + }) + + err := syscall.Kill(syscall.Getpid(), syscall.SIGINT) + + if err != nil { + t.Errorf("failed at sending SIGINT signal") + } + + select { + case <-ch: + // exit handler was called + case <-time.After(5 * time.Second): + t.Errorf("Exit handler not called - timed out") + } +} diff --git a/pkg/vf/virtionet.go b/pkg/vf/virtionet.go index f6e22159..526f2ceb 100644 --- a/pkg/vf/virtionet.go +++ b/pkg/vf/virtionet.go @@ -5,11 +5,11 @@ import ( "math/rand" "net" "os" - "os/signal" "path/filepath" "syscall" "github.com/crc-org/vfkit/pkg/config" + "github.com/crc-org/vfkit/pkg/util" "github.com/Code-Hex/vz/v3" log "github.com/sirupsen/logrus" @@ -102,7 +102,7 @@ func (dev *VirtioNet) connectUnixPath() error { dev.Socket = fd dev.localAddr = &localAddr dev.UnixSocketPath = "" - registerExitHandler(func() { _ = dev.Shutdown() }) + util.RegisterExitHandler(func() { _ = dev.Shutdown() }) return nil } @@ -173,15 +173,3 @@ func (dev *VirtioNet) Shutdown() error { return nil } - -func registerExitHandler(handler func()) { - sigChan := make(chan os.Signal, 2) - signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGINT) - go func() { - for sig := range sigChan { - log.Printf("captured %v, calling exit handlers and exiting..", sig) - handler() - os.Exit(1) - } - }() -}