Skip to content

Commit

Permalink
Update structure
Browse files Browse the repository at this point in the history
  • Loading branch information
dviejokfs committed Feb 4, 2023
1 parent ec33e78 commit e40418b
Show file tree
Hide file tree
Showing 7 changed files with 522 additions and 129 deletions.
170 changes: 79 additions & 91 deletions cmd/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@ package server
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/kfsoftware/getout/proxy"
"github.com/kfsoftware/go-vhost"
"io"
"github.com/pkg/errors"
"net"
"strings"
"sync"
"time"

"github.com/hashicorp/yamux"
"github.com/kfsoftware/getout/log"
"github.com/kfsoftware/getout/pkg/messages"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
)

Expand All @@ -29,7 +30,7 @@ func (c *serverCmd) validate() error {

func (c *serverCmd) returnResponse(initialConn net.Conn, status messages.TunnelStatus) error {
tunnelResponse := &messages.TunnelResponse{Status: status}
log.Debug().Msgf("Returning response to client: %s", status)
log.Debugf("Returning response to client: %s", status)
err := messages.WriteMsg(initialConn, tunnelResponse)
if err != nil {
return err
Expand Down Expand Up @@ -58,7 +59,7 @@ func (r *SessionRegistry) store(sni string, s *Session) error {
defer r.Unlock()
_, ok := r.sessions[sni]
if ok {
log.Warn().Msgf("Session already exists for %s", sni)
log.Warnf("Session already exists for %s", sni)
return fmt.Errorf("session already exists for %s", sni)
}
r.sessions[sni] = s
Expand All @@ -74,7 +75,14 @@ func (r *SessionRegistry) delete(sni string) {
delete(r.sessions, sni)
}
}

func (r *SessionRegistry) cleanupDeadSessions() {
for sni, session := range r.sessions {
if session.Sess.IsClosed() {
log.Debugf("Session closed for %s", sni)
r.delete(sni)
}
}
}
func (r *SessionRegistry) find(sni string) *Session {
r.RLock()
defer r.RUnlock()
Expand All @@ -85,20 +93,20 @@ func (r *SessionRegistry) find(sni string) *Session {
return s
}
func (c *serverCmd) handleTunnelRequest(mux *vhost.TLSMuxer, conn net.Conn) error {
log.Trace().Msgf("client %s connected", conn.RemoteAddr().String())
log.Debugf("client %s connected", conn.RemoteAddr().String())
config := yamux.DefaultConfig()
// setup session
sess, err := yamux.Server(conn, config)
if err != nil {
log.Err(err).Msg("failed to create yamux session")
log.Errorf("failed to create yamux session")
return err
}
// accept connection
initialConn, err := sess.Accept()
if err != nil {
defer sess.Close()
log.Trace().Msgf("client %s disconnected", conn.RemoteAddr().String())
log.Error().Msgf("multiplex conn accept failed %v", err)
log.Debugf("client %s disconnected", conn.RemoteAddr().String())
log.Errorf("multiplex conn accept failed %v", err)
return err
}
defer initialConn.Close()
Expand All @@ -112,17 +120,17 @@ func (c *serverCmd) handleTunnelRequest(mux *vhost.TLSMuxer, conn net.Conn) erro
if s != nil {
defer conn.Close()
defer sess.Close()
log.Trace().Msgf("Session already exists in the registry for %s", sni)
log.Debugf("Session already exists in the registry for %s", sni)
err = c.returnResponse(initialConn, messages.TunnelStatus_ALREADY_EXISTS)
if err != nil {
log.Warn().Msgf("Failed to send response: %v", err)
log.Warnf("Failed to send response: %v", err)
}
return err
}
muxListener, err := c.startMuxListener(mux, initialConn, sni)
if err != nil {
if msg != nil {
log.Err(err).Msgf("failed to listen on %s", msg.GetTls().GetSni())
log.Errorf("failed to listen on %s", msg.GetTls().GetSni())
}
if muxListener != nil {
muxListener.Close()
Expand All @@ -132,7 +140,7 @@ func (c *serverCmd) handleTunnelRequest(mux *vhost.TLSMuxer, conn net.Conn) erro
if strings.Contains(strings.ToLower(err.Error()), "already bound") {
err = c.returnResponse(initialConn, messages.TunnelStatus_ALREADY_EXISTS)
if err != nil {
log.Warn().Msgf("Failed to send response: %v", err)
log.Warnf("Failed to send response: %v", err)
}
return err
}
Expand All @@ -146,103 +154,81 @@ func (c *serverCmd) handleTunnelRequest(mux *vhost.TLSMuxer, conn net.Conn) erro
}
err = c.sessionRegistry.store(sni, session)
if err != nil {
log.Err(err).Msgf("failed to store session for %s", sni)
log.Errorf("failed to store session for %s", sni)
session.cleanup()
return err
}
err = c.returnResponse(initialConn, messages.TunnelStatus_OK)
if err != nil {
err = c.returnResponse(initialConn, messages.TunnelStatus_ERROR)
if err != nil {
log.Warn().Msgf("Failed to send response: %v", err)
log.Warnf("Failed to send response: %v", err)
}
c.sessionRegistry.delete(sni)
return err
}
go func(ml net.Listener) {
defer func() {
c.sessionRegistry.delete(sni)
if r := recover(); r != nil {
log.Info().Msgf("Recovered in request dispatcher %v", r)
}
}()
for {
conn, err := ml.Accept()
if err != nil {
log.Err(err).Msg("Error accepting connection")
c.sessionRegistry.delete(sni)
if strings.Contains(strings.ToLower(err.Error()), "listener closed") {
log.Info().Msg("listener closed")
return
}
continue
}
destConn, err := sess.Open()
if err != nil {
_ = conn.Close()
log.Warn().Msgf("Connection closed")
continue
}
transfer := func(side string, dst, src net.Conn) {
log.Trace().Msgf("proxing %s -> %s", src.RemoteAddr(), dst.RemoteAddr())
tStart := time.Now()

n, err := io.Copy(dst, src)
if err != nil {
log.Error().Msgf("%s: copy error: %s", side, err)
}

if err := src.Close(); err != nil {
log.Trace().Msgf("%s: close error: %s", side, err)
}

// not for yamux streams, but for client to local server connections
if d, ok := dst.(*net.TCPConn); ok {
if err := d.CloseWrite(); err != nil {
log.Trace().Msgf("%s: closeWrite error: %s", side, err)
}
if err := d.CloseRead(); err != nil {
log.Trace().Msgf("%s: closeRead error: %s", side, err)
}
}
log.Trace().Msgf("done proxing %s -> %s: %d bytes in %s", src.RemoteAddr(), dst.RemoteAddr(), n, time.Since(tStart))
}
go transfer("remote to local", conn, destConn)
go transfer("local to remote", destConn, conn)
defer func() {
c.sessionRegistry.delete(sni)
if r := recover(); r != nil {
log.Infof("Recovered in request dispatcher %v", r)
}
}(muxListener)
}()
go func() {
log.Debug().Msgf("Checking if session %s is alive", sni)
log.Debugf("Checking if session %s is alive", sni)
defer func() {
c.sessionRegistry.delete(sni)
}()
for {
_, err = sess.Ping()
if err != nil {
log.Info().Msgf("Session %s inactive, removing it: %v", sni, err)
log.Infof("Session %s inactive, removing it: %v", sni, err)
c.sessionRegistry.delete(sni)
break
}
time.Sleep(2 * time.Second)
continue
}
}()
for {
conn, err := muxListener.Accept()
if err != nil {
log.Errorf("Error accepting connection", err)
c.sessionRegistry.delete(sni)
if strings.Contains(strings.ToLower(err.Error()), "listener closed") {
log.Info("listener closed")
return errors.New("listener closed")
}
continue
}
destConn, err := sess.Open()
if err != nil {
_ = conn.Close()
log.Warnf("Connection closed")
continue
}
p := proxy.New(
conn,
destConn,
)
p.Start()
}

return nil
}

func (c *serverCmd) startMuxListener(mux *vhost.TLSMuxer, initialConn net.Conn, sni string) (net.Listener, error) {
log.Debug().Msgf("Received request for %v", sni)
log.Debugf("Received request for %v", sni)
muxListener, err := mux.Listen(sni)
if err != nil {
log.Err(err).Msgf("failed to listen on %s", sni)
log.Errorf("failed to listen on %s", sni)
if strings.Contains(strings.ToLower(err.Error()), "already bound") {
err = mux.Del(sni)
if err != nil {
log.Err(err).Msgf("failed to delete mux %s", sni)
log.Errorf("failed to delete mux %s", sni)
}
respErr := c.returnResponse(initialConn, messages.TunnelStatus_ALREADY_EXISTS)
if respErr != nil {
log.Warn().Msgf("Failed to send response: %v", err)
log.Warnf("Failed to send response: %v", err)
return muxListener, respErr
}
return muxListener, err
Expand All @@ -253,7 +239,7 @@ func (c *serverCmd) startMuxListener(mux *vhost.TLSMuxer, initialConn net.Conn,
if err != nil {
err = c.returnResponse(initialConn, messages.TunnelStatus_ERROR)
if err != nil {
log.Warn().Msgf("Failed to send response: %v", err)
log.Warnf("Failed to send response: %v", err)
}
c.sessionRegistry.delete(sni)
return nil, err
Expand All @@ -270,37 +256,39 @@ func (c *serverCmd) run() error {
// start multiplexing on it
mux, err := vhost.NewTLSMuxer(l, muxTimeout)
if err != nil {
log.Err(err).Msg("failed to create muxer")
log.Errorf("failed to create muxer")
}
log.Debug().Msgf("Starting server %s", c.addr)
log.Debugf("Starting server %s", c.addr)
muxServer, err := net.Listen("tcp", c.tunnelAddr)
if err != nil {
panic(fmt.Errorf("error listening on %s: %w", c.tunnelAddr, err))
}
defer func(muxServer net.Listener) {
log.Warn().Msgf("Closing mux server %s", muxServer.Addr())
log.Warnf("Closing mux server %s", muxServer.Addr())
_ = muxServer.Close()
}(muxServer)
go func() {
log.Info().Msgf("tunnel listening on %s", c.tunnelAddr)
log.Infof("tunnel listening on %s", c.tunnelAddr)
defer func() {
if r := recover(); r != nil {
log.Info().Msgf("tunnel listener closed %v", r)
log.Infof("tunnel listener closed %v", r)
}
}()
for {
conn, err := muxServer.Accept()
if err != nil {
log.Warn().Msgf("Connection closed")
log.Warnf("Connection closed")
return
}
log.Trace().Msgf("Accepted connection from %s", conn.RemoteAddr())
err = c.handleTunnelRequest(mux, conn)
if err != nil {
log.Warn().Msgf("Failed to handle tunnel request: %v", err)
}
log.Debugf("Accepted connection from %s", conn.RemoteAddr())
go func() {
err = c.handleTunnelRequest(mux, conn)
if err != nil {
log.Warnf("Failed to handle tunnel request: %v", err)
}
}()
}
log.Info().Msg("tunnel server closed")
log.Info("tunnel server closed")
}()
go func() {
r := gin.Default()
Expand All @@ -311,24 +299,24 @@ func (c *serverCmd) run() error {
c.sessionRegistry.delete(c1.Param("sni"))
c1.JSON(200, c.sessionRegistry.sessions)
})
log.Info().Msgf("admin server listening on %s", c.adminAddr)
log.Infof("admin server listening on %s", c.adminAddr)
err := r.Run(c.adminAddr)
if err != nil {
log.Error().Msgf("failed to listen on address: %s %v", c.adminAddr, err)
log.Errorf("failed to listen on address: %s %v", c.adminAddr, err)
}
}()
go func() {
for {
conn, err := mux.NextError()
switch err.(type) {
case vhost.BadRequest:
log.Trace().Msgf("got a bad request!")
log.Debugf("got a bad request!")
case vhost.NotFound:
log.Trace().Msgf("got a connection for an unknown vhost")
log.Debugf("got a connection for an unknown vhost")
case vhost.Closed:
log.Trace().Msgf("closed conn: %s", err)
log.Debugf("closed conn: %s", err)
default:
log.Trace().Msgf("Server error")
log.Debugf("Server error")
}

if conn != nil {
Expand Down
Loading

0 comments on commit e40418b

Please sign in to comment.