Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SOCKS5 Stability Improvements for HTTP/mTLS Transports #1807

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions client/core/socks.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/bishopfox/sliver/protobuf/rpcpb"
"github.com/bishopfox/sliver/protobuf/sliverpb"
"github.com/bishopfox/sliver/util/leaky"
"golang.org/x/time/rate"
)

var (
Expand Down Expand Up @@ -210,6 +211,8 @@ const leakyBufSize = 4108 // data.len(2) + hmacsha1(10) + data(4096)
var leakyBuf = leaky.NewLeakyBuf(2048, leakyBufSize)

func connect(conn net.Conn, stream rpcpb.SliverRPC_SocksProxyClient, frame *sliverpb.SocksData) {
// Client Rate Limiter: 10 operations per second, burst of 1
limiter := rate.NewLimiter(rate.Limit(10), 1)

SocksConnPool.Store(frame.TunnelID, conn)

Expand Down Expand Up @@ -241,6 +244,11 @@ func connect(conn net.Conn, stream rpcpb.SliverRPC_SocksProxyClient, frame *sliv
return
}
if n > 0 {
if err := limiter.Wait(context.Background()); err != nil {
log.Printf("[socks] rate limiter error: %s", err)
return
}

frame.Data = buff[:n]
frame.Sequence = ToImplantSequence
log.Printf("[socks] (User to Client) to Server to agent Data Sequence %d , Data Size %d \n", ToImplantSequence, len(frame.Data))
Expand Down
85 changes: 80 additions & 5 deletions implant/sliver/handlers/tunnel_handlers/socks_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,29 @@ import (
"google.golang.org/protobuf/proto"
)

const (
inactivityCheckInterval = 4 * time.Second
inactivityTimeout = 15 * time.Second
)

type socksTunnelPool struct {
tunnels *sync.Map // map[uint64]chan []byte
tunnels *sync.Map // map[uint64]chan []byte
lastActivity *sync.Map // map[uint64]time.Time
}

var socksTunnels = socksTunnelPool{
tunnels: &sync.Map{},
tunnels: &sync.Map{},
lastActivity: &sync.Map{},
}

var socksServer *socks5.Server

// Initialize socks server
func init() {
socksServer = socks5.NewServer()
socksTunnels.startCleanupMonitor()
}

func SocksReqHandler(envelope *sliverpb.Envelope, connection *transports.Connection) {
socksData := &sliverpb.SocksData{}
err := proto.Unmarshal(envelope.Data, socksData)
Expand All @@ -55,9 +68,26 @@ func SocksReqHandler(envelope *sliverpb.Envelope, connection *transports.Connect
// {{end}}
return
}
time.Sleep(10 * time.Millisecond) // Necessary delay

// Check early to see if this is a close request from server
if socksData.CloseConn {
if tunnel, ok := socksTunnels.tunnels.LoadAndDelete(socksData.TunnelID); ok {
if ch, ok := tunnel.(chan []byte); ok {
close(ch)
}
}
socksTunnels.lastActivity.Delete(socksData.TunnelID)
return
}

if socksData.Data == nil {
return
}

// Record activity as soon as we get data for this tunnel
socksTunnels.recordActivity(socksData.TunnelID)

// {{if .Config.Debug}}
log.Printf("[socks] User to Client to (server to implant) Data Sequence %d, Data Size %d\n", socksData.Sequence, len(socksData.Data))
// {{end}}
Expand All @@ -70,8 +100,6 @@ func SocksReqHandler(envelope *sliverpb.Envelope, connection *transports.Connect
socksServer = socks5.NewServer(
socks5.WithAuthMethods([]socks5.Authenticator{auth}),
)
} else {
socksServer = socks5.NewServer()
}

// {{if .Config.Debug}}
Expand All @@ -80,17 +108,20 @@ func SocksReqHandler(envelope *sliverpb.Envelope, connection *transports.Connect

// init tunnel
if tunnel, ok := socksTunnels.tunnels.Load(socksData.TunnelID); !ok {
tunnelChan := make(chan []byte, 10)
tunnelChan := make(chan []byte, 100) // Buffered channel for 100 messages
socksTunnels.tunnels.Store(socksData.TunnelID, tunnelChan)
tunnelChan <- socksData.Data
err := socksServer.ServeConn(&socks{stream: socksData, conn: connection})
if err != nil {
// {{if .Config.Debug}}
log.Printf("[socks] Failed to serve connection: %v", err)
// {{end}}
// Cleanup on serve failure
socksTunnels.tunnels.Delete(socksData.TunnelID)
return
}
} else {
// Will block when channel is full
tunnel.(chan []byte) <- socksData.Data
}
}
Expand All @@ -105,16 +136,22 @@ type socks struct {
}

func (s *socks) Read(b []byte) (n int, err error) {
time.Sleep(10 * time.Millisecond) // Necessary delay

channel, ok := socksTunnels.tunnels.Load(s.stream.TunnelID)
if !ok {
return 0, errors.New("[socks] invalid tunnel id")
}

socksTunnels.recordActivity(s.stream.TunnelID)
data := <-channel.(chan []byte)
return copy(b, data), nil
}

func (s *socks) Write(b []byte) (n int, err error) {
time.Sleep(10 * time.Millisecond) // Necessary delay

socksTunnels.recordActivity(s.stream.TunnelID)
data, err := proto.Marshal(&sliverpb.SocksData{
TunnelID: s.stream.TunnelID,
Data: b,
Expand All @@ -136,12 +173,15 @@ func (s *socks) Write(b []byte) (n int, err error) {
}

func (s *socks) Close() error {
time.Sleep(10 * time.Millisecond) // Necessary delay

channel, ok := socksTunnels.tunnels.LoadAndDelete(s.stream.TunnelID)
if !ok {
return errors.New("[socks] can't close unknown channel")
}
close(channel.(chan []byte))

// Signal to server that we need to close this tunnel
data, err := proto.Marshal(&sliverpb.SocksData{
TunnelID: s.stream.TunnelID,
CloseConn: true,
Expand Down Expand Up @@ -181,3 +221,38 @@ func (c *socks) SetReadDeadline(t time.Time) error {
func (c *socks) SetWriteDeadline(t time.Time) error {
return nil
}

func (s *socksTunnelPool) recordActivity(tunnelID uint64) {
s.lastActivity.Store(tunnelID, time.Now())
}

// Periodically check for inactive tunnels and clean up
func (s *socksTunnelPool) startCleanupMonitor() {
go func() {
ticker := time.NewTicker(inactivityCheckInterval)
defer ticker.Stop()

for range ticker.C {
s.tunnels.Range(func(key, value interface{}) bool {
tunnelID := key.(uint64)
lastActivityI, exists := s.lastActivity.Load(tunnelID)
if !exists {
// If no activity record exists, create one
s.recordActivity(tunnelID)
return true
}

lastActivity := lastActivityI.(time.Time)
if time.Since(lastActivity) > inactivityTimeout {
// Clean up the inactive tunnel
if ch, ok := value.(chan []byte); ok {
close(ch)
}
s.tunnels.Delete(tunnelID)
s.lastActivity.Delete(tunnelID)
}
return true
})
}
}()
}
Loading
Loading