diff --git a/client/core/socks.go b/client/core/socks.go index fe95337664..81d639d89b 100644 --- a/client/core/socks.go +++ b/client/core/socks.go @@ -31,6 +31,7 @@ import ( "github.com/bishopfox/sliver/protobuf/rpcpb" "github.com/bishopfox/sliver/protobuf/sliverpb" "github.com/bishopfox/sliver/util/leaky" + "golang.org/x/time/rate" ) var ( @@ -210,6 +211,8 @@ const leakyBufSize = 4108 // data.len(2) + hmacsha1(10) + data(4096) var leakyBuf = leaky.NewLeakyBuf(2048, leakyBufSize) func connect(conn net.Conn, stream rpcpb.SliverRPC_SocksProxyClient, frame *sliverpb.SocksData) { + // Client Rate Limiter: 10 operations per second, burst of 1 + limiter := rate.NewLimiter(rate.Limit(10), 1) SocksConnPool.Store(frame.TunnelID, conn) @@ -241,6 +244,11 @@ func connect(conn net.Conn, stream rpcpb.SliverRPC_SocksProxyClient, frame *sliv return } if n > 0 { + if err := limiter.Wait(context.Background()); err != nil { + log.Printf("[socks] rate limiter error: %s", err) + return + } + frame.Data = buff[:n] frame.Sequence = ToImplantSequence log.Printf("[socks] (User to Client) to Server to agent Data Sequence %d , Data Size %d \n", ToImplantSequence, len(frame.Data)) diff --git a/implant/sliver/handlers/tunnel_handlers/socks_handler.go b/implant/sliver/handlers/tunnel_handlers/socks_handler.go index 7e9f2e278b..0575da81fa 100644 --- a/implant/sliver/handlers/tunnel_handlers/socks_handler.go +++ b/implant/sliver/handlers/tunnel_handlers/socks_handler.go @@ -36,16 +36,29 @@ import ( "google.golang.org/protobuf/proto" ) +const ( + inactivityCheckInterval = 4 * time.Second + inactivityTimeout = 15 * time.Second +) + type socksTunnelPool struct { - tunnels *sync.Map // map[uint64]chan []byte + tunnels *sync.Map // map[uint64]chan []byte + lastActivity *sync.Map // map[uint64]time.Time } var socksTunnels = socksTunnelPool{ - tunnels: &sync.Map{}, + tunnels: &sync.Map{}, + lastActivity: &sync.Map{}, } var socksServer *socks5.Server +// Initialize socks server +func init() { + socksServer = socks5.NewServer() + socksTunnels.startCleanupMonitor() +} + func SocksReqHandler(envelope *sliverpb.Envelope, connection *transports.Connection) { socksData := &sliverpb.SocksData{} err := proto.Unmarshal(envelope.Data, socksData) @@ -55,9 +68,26 @@ func SocksReqHandler(envelope *sliverpb.Envelope, connection *transports.Connect // {{end}} return } + time.Sleep(10 * time.Millisecond) // Necessary delay + + // Check early to see if this is a close request from server + if socksData.CloseConn { + if tunnel, ok := socksTunnels.tunnels.LoadAndDelete(socksData.TunnelID); ok { + if ch, ok := tunnel.(chan []byte); ok { + close(ch) + } + } + socksTunnels.lastActivity.Delete(socksData.TunnelID) + return + } + if socksData.Data == nil { return } + + // Record activity as soon as we get data for this tunnel + socksTunnels.recordActivity(socksData.TunnelID) + // {{if .Config.Debug}} log.Printf("[socks] User to Client to (server to implant) Data Sequence %d, Data Size %d\n", socksData.Sequence, len(socksData.Data)) // {{end}} @@ -70,8 +100,6 @@ func SocksReqHandler(envelope *sliverpb.Envelope, connection *transports.Connect socksServer = socks5.NewServer( socks5.WithAuthMethods([]socks5.Authenticator{auth}), ) - } else { - socksServer = socks5.NewServer() } // {{if .Config.Debug}} @@ -80,7 +108,7 @@ func SocksReqHandler(envelope *sliverpb.Envelope, connection *transports.Connect // init tunnel if tunnel, ok := socksTunnels.tunnels.Load(socksData.TunnelID); !ok { - tunnelChan := make(chan []byte, 10) + tunnelChan := make(chan []byte, 100) // Buffered channel for 100 messages socksTunnels.tunnels.Store(socksData.TunnelID, tunnelChan) tunnelChan <- socksData.Data err := socksServer.ServeConn(&socks{stream: socksData, conn: connection}) @@ -88,9 +116,12 @@ func SocksReqHandler(envelope *sliverpb.Envelope, connection *transports.Connect // {{if .Config.Debug}} log.Printf("[socks] Failed to serve connection: %v", err) // {{end}} + // Cleanup on serve failure + socksTunnels.tunnels.Delete(socksData.TunnelID) return } } else { + // Will block when channel is full tunnel.(chan []byte) <- socksData.Data } } @@ -105,16 +136,22 @@ type socks struct { } func (s *socks) Read(b []byte) (n int, err error) { + time.Sleep(10 * time.Millisecond) // Necessary delay + channel, ok := socksTunnels.tunnels.Load(s.stream.TunnelID) if !ok { return 0, errors.New("[socks] invalid tunnel id") } + socksTunnels.recordActivity(s.stream.TunnelID) data := <-channel.(chan []byte) return copy(b, data), nil } func (s *socks) Write(b []byte) (n int, err error) { + time.Sleep(10 * time.Millisecond) // Necessary delay + + socksTunnels.recordActivity(s.stream.TunnelID) data, err := proto.Marshal(&sliverpb.SocksData{ TunnelID: s.stream.TunnelID, Data: b, @@ -136,12 +173,15 @@ func (s *socks) Write(b []byte) (n int, err error) { } func (s *socks) Close() error { + time.Sleep(10 * time.Millisecond) // Necessary delay + channel, ok := socksTunnels.tunnels.LoadAndDelete(s.stream.TunnelID) if !ok { return errors.New("[socks] can't close unknown channel") } close(channel.(chan []byte)) + // Signal to server that we need to close this tunnel data, err := proto.Marshal(&sliverpb.SocksData{ TunnelID: s.stream.TunnelID, CloseConn: true, @@ -181,3 +221,38 @@ func (c *socks) SetReadDeadline(t time.Time) error { func (c *socks) SetWriteDeadline(t time.Time) error { return nil } + +func (s *socksTunnelPool) recordActivity(tunnelID uint64) { + s.lastActivity.Store(tunnelID, time.Now()) +} + +// Periodically check for inactive tunnels and clean up +func (s *socksTunnelPool) startCleanupMonitor() { + go func() { + ticker := time.NewTicker(inactivityCheckInterval) + defer ticker.Stop() + + for range ticker.C { + s.tunnels.Range(func(key, value interface{}) bool { + tunnelID := key.(uint64) + lastActivityI, exists := s.lastActivity.Load(tunnelID) + if !exists { + // If no activity record exists, create one + s.recordActivity(tunnelID) + return true + } + + lastActivity := lastActivityI.(time.Time) + if time.Since(lastActivity) > inactivityTimeout { + // Clean up the inactive tunnel + if ch, ok := value.(chan []byte); ok { + close(ch) + } + s.tunnels.Delete(tunnelID) + s.lastActivity.Delete(tunnelID) + } + return true + }) + } + }() +} diff --git a/server/rpc/rpc-socks.go b/server/rpc/rpc-socks.go index 7074b1cf33..689ff99d9f 100644 --- a/server/rpc/rpc-socks.go +++ b/server/rpc/rpc-socks.go @@ -20,8 +20,11 @@ package rpc import ( "context" + "fmt" "io" "sync" + "sync/atomic" + "time" "github.com/bishopfox/sliver/protobuf/commonpb" "github.com/bishopfox/sliver/protobuf/rpcpb" @@ -32,15 +35,16 @@ import ( var ( // SessionID->Tunnels[TunnelID]->Tunnel->Cache map[uint64]*sliverpb.SocksData{} - toImplantCacheSocks = socksDataCache{mutex: &sync.RWMutex{}, cache: map[uint64]map[uint64]*sliverpb.SocksData{}} + toImplantCacheSocks = socksDataCache{mutex: &sync.RWMutex{}, cache: map[uint64]map[uint64]*sliverpb.SocksData{}, lastActivity: map[uint64]time.Time{}} // SessionID->Tunnels[TunnelID]->Tunnel->Cache - fromImplantCacheSocks = socksDataCache{mutex: &sync.RWMutex{}, cache: map[uint64]map[uint64]*sliverpb.SocksData{}} + fromImplantCacheSocks = socksDataCache{mutex: &sync.RWMutex{}, cache: map[uint64]map[uint64]*sliverpb.SocksData{}, lastActivity: map[uint64]time.Time{}} ) type socksDataCache struct { - mutex *sync.RWMutex - cache map[uint64]map[uint64]*sliverpb.SocksData + mutex *sync.RWMutex + cache map[uint64]map[uint64]*sliverpb.SocksData + lastActivity map[uint64]time.Time } func (c *socksDataCache) Add(tunnelID uint64, sequence uint64, tunnelData *sliverpb.SocksData) { @@ -72,6 +76,7 @@ func (c *socksDataCache) DeleteTun(tunnelID uint64) { defer c.mutex.Unlock() delete(c.cache, tunnelID) + delete(c.lastActivity, tunnelID) } func (c *socksDataCache) DeleteSeq(tunnelID uint64, sequence uint64) { @@ -85,69 +90,299 @@ func (c *socksDataCache) DeleteSeq(tunnelID uint64, sequence uint64) { delete(c.cache[tunnelID], sequence) } +func (c *socksDataCache) recordActivity(tunnelID uint64) { + c.mutex.Lock() + defer c.mutex.Unlock() + c.lastActivity[tunnelID] = time.Now() +} + // Socks - Open an in-band port forward + +const ( + writeTimeout = 5 * time.Second + batchSize = 100 // Maximum number of sequences to batch + inactivityCheckInterval = 4 * time.Second + inactivityTimeout = 15 * time.Second + ToImplantTickerInterval = 10 * time.Millisecond // data going towards implant is usually smaller request data + ToClientTickerInterval = 5 * time.Millisecond // data coming back from implant is usually larger response data +) + func (s *Server) SocksProxy(stream rpcpb.SliverRPC_SocksProxyServer) error { + errChan := make(chan error, 2) + defer close(errChan) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + connDone := make(chan struct{}) + defer close(connDone) + + // Track all goroutines spawned for this session + var wg sync.WaitGroup + defer wg.Wait() + + // Track all tunnels created for this session + activeTunnels := make(map[uint64]bool) + var tunnelMutex sync.Mutex + + // Cleanup all tunnels on SocksProxy closure + defer func() { + tunnelMutex.Lock() + for tunnelID := range activeTunnels { + if tunnel := core.SocksTunnels.Get(tunnelID); tunnel != nil { + rpcLog.Infof("Cleaning up tunnel %d on proxy closure", tunnelID) + close(tunnel.FromImplant) + tunnel.Client = nil + s.CloseSocks(context.Background(), &sliverpb.Socks{TunnelID: tunnelID}) + } + } + tunnelMutex.Unlock() + }() + for { + select { + case err := <-errChan: + rpcLog.Errorf("SocksProxy error: %v", err) + return err + default: + } + fromClient, err := stream.Recv() if err == io.EOF { - break + return nil } - //fmt.Println("Send Agent 1 ",fromClient.TunnelID,len(fromClient.Data)) if err != nil { rpcLog.Warnf("Error on stream recv %s", err) return err } - tunnelLog.Debugf("Tunnel %d: From client %d byte(s)", - fromClient.TunnelID, len(fromClient.Data)) - socks := core.SocksTunnels.Get(fromClient.TunnelID) - if socks == nil { - return nil - } - if socks.Client == nil { - socks.Client = stream // Bind client to tunnel - // Send Client - go func() { - for tunnelData := range socks.FromImplant { - fromImplantCacheSocks.Add(fromClient.TunnelID, tunnelData.Sequence, tunnelData) + tunnelMutex.Lock() + activeTunnels[fromClient.TunnelID] = true // Mark this as an active tunnel + tunnelMutex.Unlock() - for recv, ok := fromImplantCacheSocks.Get(fromClient.TunnelID, socks.FromImplantSequence); ok; recv, ok = fromImplantCacheSocks.Get(fromClient.TunnelID, socks.FromImplantSequence) { - rpcLog.Debugf("[socks] agent to (Server To Client) Data Sequence %d , Data Size %d ,Data %v\n", socks.FromImplantSequence, len(recv.Data), recv.Data) - socks.Client.Send(&sliverpb.SocksData{ - CloseConn: recv.CloseConn, - TunnelID: recv.TunnelID, - Data: recv.Data, - }) + tunnel := core.SocksTunnels.Get(fromClient.TunnelID) + if tunnel == nil { + continue + } - fromImplantCacheSocks.DeleteSeq(fromClient.TunnelID, socks.FromImplantSequence) - socks.FromImplantSequence++ + if tunnel.Client == nil { + tunnel.Client = stream + tunnel.FromImplant = make(chan *sliverpb.SocksData, 100) // Buffered channel for 100 messages + + // Monitor tunnel goroutines for inactivity and cleanup + wg.Add(1) + go func(tunnelID uint64) { + defer wg.Done() + defer func() { + if r := recover(); r != nil { + rpcLog.Errorf("Recovered from panic in monitor: %v", r) + errChan <- fmt.Errorf("monitor goroutine panic: %v", r) + cancel() // Cancel context in case of a panic } + }() + + ticker := time.NewTicker(inactivityCheckInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-connDone: + return + case <-ticker.C: + tunnel := core.SocksTunnels.Get(tunnelID) + if tunnel == nil || tunnel.Client == nil { + return + } + session := core.Sessions.Get(tunnel.SessionID) + + // Check both caches for activity + toImplantCacheSocks.mutex.RLock() + fromImplantCacheSocks.mutex.RLock() + toLastActivity := toImplantCacheSocks.lastActivity[tunnelID] + fromLastActivity := fromImplantCacheSocks.lastActivity[tunnelID] + toImplantCacheSocks.mutex.RUnlock() + fromImplantCacheSocks.mutex.RUnlock() + + // Clean up goroutine if both directions have hit the idle threshold + if time.Since(toLastActivity) > inactivityTimeout && time.Since(fromLastActivity) > inactivityTimeout { + s.CloseSocks(context.Background(), &sliverpb.Socks{TunnelID: tunnelID}) + return + } + + // Clean up goroutine if the client has disconnected early + if tunnel.Client == nil || session == nil { + s.CloseSocks(context.Background(), &sliverpb.Socks{TunnelID: tunnelID}) + return + } + } } - }() - } - - // Send Agent - go func() { - toImplantCacheSocks.Add(fromClient.TunnelID, fromClient.Sequence, fromClient) + }(fromClient.TunnelID) - for recv, ok := toImplantCacheSocks.Get(fromClient.TunnelID, socks.ToImplantSequence); ok; recv, ok = toImplantCacheSocks.Get(fromClient.TunnelID, socks.ToImplantSequence) { - rpcLog.Debugf("[socks] Client to (Server To Agent) Data Sequence %d , Data Size %d \n", socks.ToImplantSequence, len(fromClient.Data)) - data, _ := proto.Marshal(recv) + // Send Client + wg.Add(1) + go func() { + defer wg.Done() + defer func() { + if r := recover(); r != nil { + rpcLog.Errorf("Recovered from panic in client sender: %v", r) + errChan <- fmt.Errorf("client sender panic: %v", r) + cancel() // Cancel context in case of a panic + } + }() + + pendingData := make(map[uint64]*sliverpb.SocksData) + ticker := time.NewTicker(ToClientTickerInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-connDone: + return + case tunnelData, ok := <-tunnel.FromImplant: + if !ok { + return + } + + // Check if implant is requesting to close the tunnel + if tunnelData.CloseConn { + // Clean up the tunnel + s.CloseSocks(context.Background(), &sliverpb.Socks{TunnelID: fromClient.TunnelID}) + return + } + + sequence := tunnelData.Sequence + fromImplantCacheSocks.Add(fromClient.TunnelID, sequence, tunnelData) + pendingData[sequence] = tunnelData + fromImplantCacheSocks.recordActivity(fromClient.TunnelID) + + case <-ticker.C: + if tunnel.Client == nil { + return + } + if len(pendingData) == 0 { + continue + } + + expectedSequence := atomic.LoadUint64(&tunnel.FromImplantSequence) + processed := 0 + + // Perform Batching + for i := 0; i < batchSize && processed < len(pendingData); i++ { + data, exists := pendingData[expectedSequence] + if !exists { + break // Stop batching if we don't have the next expected sequence + } + + func() { + defer func() { + if r := recover(); r != nil { + errChan <- fmt.Errorf("Client sender panic: %v", r) + } + }() + + err := stream.Send(&sliverpb.SocksData{ + CloseConn: data.CloseConn, + TunnelID: data.TunnelID, + Data: data.Data, + }) + + if err != nil { + rpcLog.Errorf("Send error: %v", err) + return + } + + delete(pendingData, expectedSequence) + fromImplantCacheSocks.DeleteSeq(fromClient.TunnelID, expectedSequence) + atomic.AddUint64(&tunnel.FromImplantSequence, 1) + expectedSequence++ + processed++ + }() + } - session := core.Sessions.Get(fromClient.Request.SessionID) - session.Connection.Send <- &sliverpb.Envelope{ - Type: sliverpb.MsgSocksData, - Data: data, + } } + }() - toImplantCacheSocks.DeleteSeq(fromClient.TunnelID, socks.ToImplantSequence) - socks.ToImplantSequence++ - } + // Send Agent + wg.Add(1) + go func() { + defer wg.Done() + defer func() { + if r := recover(); r != nil { + rpcLog.Errorf("Recovered from panic in agent sender: %v", r) + errChan <- fmt.Errorf("agent sender panic: %v", r) + cancel() // Cancel context in case of a panic + } + }() + + pendingData := make(map[uint64]*sliverpb.SocksData) + ticker := time.NewTicker(ToImplantTickerInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-connDone: + return + case <-ticker.C: + if tunnel.Client == nil { + return + } + sequence := atomic.LoadUint64(&tunnel.ToImplantSequence) + + func() { + defer func() { + if r := recover(); r != nil { + rpcLog.Errorf("Recovered from processing panic: %v", r) + } + }() + + for { + recv, ok := toImplantCacheSocks.Get(fromClient.TunnelID, sequence) + if !ok { + break + } + + session := core.Sessions.Get(fromClient.Request.SessionID) + if session == nil { + rpcLog.Error("Session not found") + break + } + + data, err := proto.Marshal(recv) + if err != nil { + rpcLog.Errorf("Failed to marshal data: %v", err) + continue + } + + select { + case session.Connection.Send <- &sliverpb.Envelope{ + Type: sliverpb.MsgSocksData, + Data: data, + }: + toImplantCacheSocks.DeleteSeq(fromClient.TunnelID, sequence) + atomic.AddUint64(&tunnel.ToImplantSequence, 1) + sequence++ + case <-time.After(writeTimeout): + rpcLog.Error("Write timeout to implant") + pendingData[sequence] = recv + break + } + } + }() + } + } + }() + } - }() + toImplantCacheSocks.Add(fromClient.TunnelID, fromClient.Sequence, fromClient) } - return nil } // CreateSocks5 - Create requests we close a Socks @@ -169,11 +404,48 @@ func (s *Server) CreateSocks(ctx context.Context, req *sliverpb.Socks) (*sliverp // CloseSocks - Client requests we close a Socks func (s *Server) CloseSocks(ctx context.Context, req *sliverpb.Socks) (*commonpb.Empty, error) { - err := core.SocksTunnels.Close(req.TunnelID) + defer func() { + if r := recover(); r != nil { + rpcLog.Errorf("Recovered from panic in CloseSocks for tunnel %d: %v", req.TunnelID, r) + } + }() + + tunnel := core.SocksTunnels.Get(req.TunnelID) + if tunnel != nil { + // Signal close to implant first + if session := core.Sessions.Get(tunnel.SessionID); session != nil { + data, _ := proto.Marshal(&sliverpb.SocksData{ + TunnelID: req.TunnelID, + CloseConn: true, + }) + session.Connection.Send <- &sliverpb.Envelope{ + Type: sliverpb.MsgSocksData, + Data: data, + } + } + time.Sleep(100 * time.Millisecond) // Delay to allow close message to be sent + tunnel.Client = nil // Cleanup the tunnel + if tunnel.FromImplant != nil { + select { + case _, ok := <-tunnel.FromImplant: + if ok { + close(tunnel.FromImplant) + } + default: + close(tunnel.FromImplant) + } + tunnel.FromImplant = nil + } + } + + // Clean up caches toImplantCacheSocks.DeleteTun(req.TunnelID) fromImplantCacheSocks.DeleteTun(req.TunnelID) - if err != nil { - return nil, err + + // Remove from core tunnels last + if err := core.SocksTunnels.Close(req.TunnelID); err != nil { + rpcLog.Errorf("Error closing tunnel %d: %v", req.TunnelID, err) } + return &commonpb.Empty{}, nil }