From 43fbf77ab9ecdd0caec7a57a05f0b6bac9624553 Mon Sep 17 00:00:00 2001 From: dviejokfs Date: Sat, 4 Feb 2023 22:36:57 +0100 Subject: [PATCH] Improve code legibility --- cmd/server/server.go | 43 ++++++++++++++++--------------------------- 1 file changed, 16 insertions(+), 27 deletions(-) diff --git a/cmd/server/server.go b/cmd/server/server.go index c56f72b..7c99296 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -43,11 +43,17 @@ type SessionRegistry struct { sessions map[string]*Session } -func (s Session) cleanup() { +func (s *Session) cleanup() { + defer func() { + s.cleanupDone = true + }() + if s.cleanupDone { + return + } if s.Conn != nil { s.Conn.Close() } - if s.Sess != nil { + if s.Sess != nil && !s.Sess.IsClosed() { s.Sess.Close() } if s.Mux != nil { @@ -93,6 +99,7 @@ func (r *SessionRegistry) find(sni string) *Session { return s } func (c *serverCmd) handleTunnelRequest(mux *vhost.TLSMuxer, conn net.Conn) error { + defer conn.Close() log.Debugf("client %s connected", conn.RemoteAddr().String()) config := yamux.DefaultConfig() // setup session @@ -101,10 +108,10 @@ func (c *serverCmd) handleTunnelRequest(mux *vhost.TLSMuxer, conn net.Conn) erro log.Errorf("failed to create yamux session") return err } + defer sess.Close() // accept connection initialConn, err := sess.Accept() if err != nil { - defer sess.Close() log.Debugf("client %s disconnected", conn.RemoteAddr().String()) log.Errorf("multiplex conn accept failed %v", err) return err @@ -118,8 +125,6 @@ func (c *serverCmd) handleTunnelRequest(mux *vhost.TLSMuxer, conn net.Conn) erro sni := msg.GetTls().GetSni() s := c.sessionRegistry.find(sni) if s != nil { - defer conn.Close() - defer sess.Close() log.Debugf("Session already exists in the registry for %s", sni) err = c.returnResponse(initialConn, messages.TunnelStatus_ALREADY_EXISTS) if err != nil { @@ -132,11 +137,6 @@ func (c *serverCmd) handleTunnelRequest(mux *vhost.TLSMuxer, conn net.Conn) erro if msg != nil { log.Errorf("failed to listen on %s", msg.GetTls().GetSni()) } - if muxListener != nil { - muxListener.Close() - } - defer conn.Close() - defer sess.Close() if strings.Contains(strings.ToLower(err.Error()), "already bound") { err = c.returnResponse(initialConn, messages.TunnelStatus_ALREADY_EXISTS) if err != nil { @@ -152,10 +152,13 @@ func (c *serverCmd) handleTunnelRequest(mux *vhost.TLSMuxer, conn net.Conn) erro Mux: muxListener, Sess: sess, } + defer func() { + session.cleanup() + c.sessionRegistry.delete(sni) + }() err = c.sessionRegistry.store(sni, session) if err != nil { log.Errorf("failed to store session for %s", sni) - session.cleanup() return err } err = c.returnResponse(initialConn, messages.TunnelStatus_OK) @@ -164,20 +167,10 @@ func (c *serverCmd) handleTunnelRequest(mux *vhost.TLSMuxer, conn net.Conn) erro if err != nil { log.Warnf("Failed to send response: %v", err) } - c.sessionRegistry.delete(sni) return err } - defer func() { - c.sessionRegistry.delete(sni) - if r := recover(); r != nil { - log.Infof("Recovered in request dispatcher %v", r) - } - }() go func() { log.Debugf("Checking if session %s is alive", sni) - defer func() { - c.sessionRegistry.delete(sni) - }() for { _, err = sess.Ping() if err != nil { @@ -193,7 +186,6 @@ func (c *serverCmd) handleTunnelRequest(mux *vhost.TLSMuxer, conn net.Conn) erro conn, err := muxListener.Accept() if err != nil { log.Errorf("Error accepting connection", err) - c.sessionRegistry.delete(sni) if strings.Contains(strings.ToLower(err.Error()), "listener closed") { log.Info("listener closed") return errors.New("listener closed") @@ -263,10 +255,6 @@ func (c *serverCmd) run() error { if err != nil { panic(fmt.Errorf("error listening on %s: %w", c.tunnelAddr, err)) } - defer func(muxServer net.Listener) { - log.Warnf("Closing mux server %s", muxServer.Addr()) - _ = muxServer.Close() - }(muxServer) go func() { log.Infof("tunnel listening on %s", c.tunnelAddr) defer func() { @@ -353,7 +341,8 @@ func NewServerCmd() *cobra.Command { } type Session struct { - SNI string `json:"sni"` + cleanupDone bool + SNI string `json:"sni"` //RemoteAddr string `json:"remoteAddr"` //LocalAddr string `json:"localAddr"` Conn net.Conn