Skip to content

Commit

Permalink
Overhaul UDP server (#141)
Browse files Browse the repository at this point in the history
* Close upstream UDP conns when downstream is closed

* Ensure UDP buffers are sufficiently sized

* Use single goroutine per UDP connection
  • Loading branch information
jtackaberry authored May 29, 2024
1 parent 83ccc7e commit 6a8be7c
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 26 deletions.
166 changes: 145 additions & 21 deletions layer4/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package layer4
import (
"bytes"
"fmt"
"io"
"net"
"sync"
"time"
Expand Down Expand Up @@ -85,23 +86,69 @@ func (s Server) serve(ln net.Listener) error {
}

func (s Server) servePacket(pc net.PacketConn) error {
// Spawn a goroutine whose only job is to consume packets from the socket
// and send to the packets channel.
packets := make(chan packet, 10)
go func(packets chan packet) {
for {
buf := udpBufPool.Get().([]byte)
n, addr, err := pc.ReadFrom(buf)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
continue
}
packets <- packet{err: err}
return
}
packets <- packet{
pooledBuf: buf,
n: n,
addr: addr,
}
}
}(packets)

// udpConns tracks active packetConns by downstream address:port. They will
// be removed from this map after being closed.
udpConns := make(map[string]*packetConn)
// closeCh is used to receive notifications of socket closures from
// packetConn, which allows us to to remove stale connections (whose
// proxy handlers have completed) from the udpConns map.
closeCh := make(chan string, 10)
for {
buf := udpBufPool.Get().([]byte)
n, addr, err := pc.ReadFrom(buf)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
continue
select {
case addr := <-closeCh:
// UDP connection is closed (either implicitly through timeout or by
// explicit call to Close()).
delete(udpConns, addr)

case pkt := <-packets:
if pkt.err != nil {
return pkt.err
}
return err
conn, ok := udpConns[pkt.addr.String()]
if !ok {
// No existing proxy handler is running for this downstream.
// Create one now.
conn = &packetConn{
PacketConn: pc,
readCh: make(chan *packet, 5),
addr: pkt.addr,
closeCh: closeCh,
}
udpConns[pkt.addr.String()] = conn
go func(conn *packetConn) {
s.handle(conn)
// It might seem cleaner to send to closeCh here rather than
// in packetConn, but doing it earlier in packetConn closes
// the gap between the proxy handler shutting down and new
// packets coming in from the same downstream. Should that
// happen, we'll just spin up a new handler concurrent to
// the old one shutting down.
}(conn)
}
conn.readCh <- &pkt
}
go func(buf []byte, n int, addr net.Addr) {
defer udpBufPool.Put(buf)
s.handle(packetConn{
PacketConn: pc,
buf: bytes.NewBuffer(buf[:n]),
addr: addr,
})
}(buf, n, addr)
}
}

Expand Down Expand Up @@ -129,29 +176,106 @@ func (s Server) handle(conn net.Conn) {
)
}

type packet struct {
// The underlying bytes slice that was gotten from udpBufPool. It's up to
// packetConn to return it to udpBufPool once it's consumed.
pooledBuf []byte
// Number of bytes read from socket
n int
// Error that occurred while reading from socket
err error
// Address of downstream
addr net.Addr
}

type packetConn struct {
net.PacketConn
buf *bytes.Buffer
addr net.Addr
addr net.Addr
readCh chan *packet
closeCh chan string
// If not nil, then the previous Read() call didn't consume all the data
// from the buffer, and this packet will be reused in the next Read()
// without waiting for readCh.
lastPacket *packet
lastBuf *bytes.Buffer
}

func (pc packetConn) Read(b []byte) (n int, err error) {
return pc.buf.Read(b)
func (pc *packetConn) Read(b []byte) (n int, err error) {
if pc.lastPacket != nil {
// There is a partial buffer to continue reading from the previous
// packet.
n, err = pc.lastBuf.Read(b)
if pc.lastBuf.Len() == 0 {
udpBufPool.Put(pc.lastPacket.pooledBuf)
pc.lastPacket = nil
pc.lastBuf = nil
}
return
}
select {
case pkt := <-pc.readCh:
if pkt == nil {
// Channel is closed. Return EOF below.
break
}
buf := bytes.NewBuffer(pkt.pooledBuf[:pkt.n])
n, err = buf.Read(b)
if buf.Len() == 0 {
// Buffer fully consumed, release it.
udpBufPool.Put(pkt.pooledBuf)
} else {
// Buffer only partially consumed. Keep track of it for
// next Read() call.
pc.lastPacket = pkt
pc.lastBuf = buf
}
return
// TODO: idle timeout should be configurable per server
case <-time.After(30 * time.Second):
break
}
// Idle timeout simulates socket closure.
//
// Although Close() also does this, we inform the server loop early about
// the closure to ensure that if any new packets are received from this
// connection in the meantime, a new handler will be started.
pc.closeCh <- pc.addr.String()
// Returning EOF here ensures that io.Copy() waiting on the downstream for
// reads will terminate.
return 0, io.EOF
}

func (pc packetConn) Write(b []byte) (n int, err error) {
return pc.PacketConn.WriteTo(b, pc.addr)
}

func (pc packetConn) Close() error {
// Do nothing, we don't want to close the UDP server
func (pc *packetConn) Close() error {
if pc.lastPacket != nil {
udpBufPool.Put(pc.lastPacket.pooledBuf)
pc.lastPacket = nil
}
// This will abort any active Read() from another goroutine and return EOF
close(pc.readCh)
// Drain pending packets to ensure we release buffers back to the pool
for pkt := range pc.readCh {
udpBufPool.Put(pkt.pooledBuf)
}
// We may have already done this earlier in Read(), but just in case
// Read() wasn't being called, (re-)notify server loop we're closed.
pc.closeCh <- pc.addr.String()
// We don't call net.PacketConn.Close() here as we would stop the UDP
// server.
return nil
}

func (pc packetConn) RemoteAddr() net.Addr { return pc.addr }

var udpBufPool = sync.Pool{
New: func() interface{} {
return make([]byte, 1024)
// Buffers need to be as large as the largest datagram we'll consume, because
// ReadFrom() can't resume partial reads. (This is standard for UDP
// sockets on *nix.) So our buffer sizes are 9000 bytes to accommodate
// networks with jumbo frames. See also https://github.com/golang/go/issues/18056
return make([]byte, 9000)
},
}
26 changes: 21 additions & 5 deletions modules/l4proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"net"
"runtime/debug"
"sync"
"sync/atomic"
"time"

"github.com/caddyserver/caddy/v2"
Expand Down Expand Up @@ -253,6 +254,7 @@ func (h *Handler) proxy(down *layer4.Connection, upConns []net.Conn) {
}

var wg sync.WaitGroup
var downClosed atomic.Bool

for _, up := range upConns {
wg.Add(1)
Expand All @@ -261,11 +263,16 @@ func (h *Handler) proxy(down *layer4.Connection, upConns []net.Conn) {
defer wg.Done()

if _, err := io.Copy(down, up); err != nil {
h.logger.Error("upstream connection",
zap.String("local_address", up.LocalAddr().String()),
zap.String("remote_address", up.RemoteAddr().String()),
zap.Error(err),
)
// If the downstream connection has been closed, we can assume this is
// the reason io.Copy() errored. That's normal operation for UDP
// connections after idle timeout, so don't log an error in that case.
if !downClosed.Load() {
h.logger.Error("upstream connection",
zap.String("local_address", up.LocalAddr().String()),
zap.String("remote_address", up.RemoteAddr().String()),
zap.Error(err),
)
}
}
}(up)
}
Expand All @@ -280,9 +287,18 @@ func (h *Handler) proxy(down *layer4.Connection, upConns []net.Conn) {

// Shut down the writing side of all upstream connections, in case
// that the downstream connection is half closed. (issue #40)
//
// UDP connections meanwhile don't implement CloseWrite(), but in order
// to ensure io.Copy() in the per-upstream goroutines (above) returns,
// we need to close the socket. This will cause io.Copy() return an
// error, which in this particular case is expected, so we signal the
// intentional closure by setting this flag.
downClosed.Store(true)
for _, up := range upConns {
if conn, ok := up.(closeWriter); ok {
_ = conn.CloseWrite()
} else {
up.Close()
}
}
}()
Expand Down

0 comments on commit 6a8be7c

Please sign in to comment.