diff --git a/pkg/server/udp.go b/pkg/server/udp.go index 339e4e442..5a6e64d3e 100644 --- a/pkg/server/udp.go +++ b/pkg/server/udp.go @@ -33,8 +33,8 @@ import ( // cmcUDPConn can read and write cmsg. type cmcUDPConn interface { - readFrom(b []byte) (n int, cm any, src net.Addr, err error) - writeTo(b []byte, cm any, dst net.Addr) (n int, err error) + readFrom(b []byte) (n int, dst net.IP, IfIndex int, src net.Addr, err error) + writeTo(b []byte, src net.IP, IfIndex int, dst net.Addr) (n int, err error) } func (s *Server) ServeUDP(c net.PacketConn) error { @@ -61,28 +61,28 @@ func (s *Server) ServeUDP(c net.PacketConn) error { var cmc cmcUDPConn var err error uc, ok := c.(*net.UDPConn) - if ok { - cmc, err = newUDPConn(uc) + if ok && uc.LocalAddr().(*net.UDPAddr).IP.IsUnspecified() { + cmc, err = newCmc(uc) if err != nil { return fmt.Errorf("failed to control socket cmsg, %w", err) } } else { - cmc = newDummyUDPConn(c) + cmc = newDummyCmc(c) } for { - n, cm, clientNetAddr, err := cmc.readFrom(rb) + n, localAddr, ifIndex, remoteAddr, err := cmc.readFrom(rb) if err != nil { if s.Closed() { return ErrServerClosed } return fmt.Errorf("unexpected read err: %w", err) } - clientAddr := utils.GetAddrFromAddr(clientNetAddr) + clientAddr := utils.GetAddrFromAddr(remoteAddr) q := new(dns.Msg) if err := q.Unpack(rb[:n]); err != nil { - s.opts.Logger.Warn("invalid msg", zap.Error(err), zap.Binary("msg", rb[:n]), zap.Stringer("from", clientNetAddr)) + s.opts.Logger.Warn("invalid msg", zap.Error(err), zap.Binary("msg", rb[:n]), zap.Stringer("from", remoteAddr)) continue } @@ -105,8 +105,8 @@ func (s *Server) ServeUDP(c net.PacketConn) error { return } defer buf.Release() - if _, err := cmc.writeTo(b, cm, clientNetAddr); err != nil { - s.opts.Logger.Warn("failed to write response", zap.Stringer("client", clientNetAddr), zap.Error(err)) + if _, err := cmc.writeTo(b, localAddr, ifIndex, remoteAddr); err != nil { + s.opts.Logger.Warn("failed to write response", zap.Stringer("client", remoteAddr), zap.Error(err)) } } }() @@ -124,22 +124,22 @@ func getUDPSize(m *dns.Msg) int { return int(s) } -// newDummyUDPConn returns a dummyWrapper. -func newDummyUDPConn(c net.PacketConn) cmcUDPConn { - return dummyWrapper{c: c} +// newDummyCmc returns a dummyCmcWrapper. +func newDummyCmc(c net.PacketConn) cmcUDPConn { + return dummyCmcWrapper{c: c} } -// dummyWrapper is just a wrapper that implements cmcUDPConn but does not +// dummyCmcWrapper is just a wrapper that implements cmcUDPConn but does not // write or read any control msg. -type dummyWrapper struct { +type dummyCmcWrapper struct { c net.PacketConn } -func (w dummyWrapper) readFrom(b []byte) (n int, cm any, src net.Addr, err error) { +func (w dummyCmcWrapper) readFrom(b []byte) (n int, dst net.IP, IfIndex int, src net.Addr, err error) { n, src, err = w.c.ReadFrom(b) return } -func (w dummyWrapper) writeTo(b []byte, cm any, dst net.Addr) (n int, err error) { +func (w dummyCmcWrapper) writeTo(b []byte, src net.IP, IfIndex int, dst net.Addr) (n int, err error) { return w.c.WriteTo(b, dst) } diff --git a/pkg/server/udp_linux.go b/pkg/server/udp_linux.go index c0b26400b..4eb466de7 100644 --- a/pkg/server/udp_linux.go +++ b/pkg/server/udp_linux.go @@ -30,102 +30,106 @@ import ( "os" ) -type protocol int - -const ( - invalid protocol = iota - v4 - v6 -) - -type ipv4PacketConn struct { +type ipv4cmc struct { c *ipv4.PacketConn } -func (i ipv4PacketConn) readFrom(b []byte) (n int, cm any, src net.Addr, err error) { - return i.c.ReadFrom(b) +func newIpv4cmc(c *ipv4.PacketConn) *ipv4cmc { + return &ipv4cmc{c: c} } -func (i ipv4PacketConn) writeTo(b []byte, cm any, dst net.Addr) (n int, err error) { - cm4 := cm.(*ipv4.ControlMessage) - cm4.Src = cm4.Dst - cm4.Dst = nil - return i.c.WriteTo(b, cm4, dst) +func (i *ipv4cmc) readFrom(b []byte) (n int, dst net.IP, IfIndex int, src net.Addr, err error) { + n, cm, src, err := i.c.ReadFrom(b) + if cm != nil { + dst, IfIndex = cm.Dst, cm.IfIndex + } + return } -type ipv6PacketConn struct { +func (i *ipv4cmc) writeTo(b []byte, src net.IP, IfIndex int, dst net.Addr) (n int, err error) { + cm := &ipv4.ControlMessage{ + Src: src, + IfIndex: IfIndex, + } + return i.c.WriteTo(b, cm, dst) +} + +type ipv6cmc struct { c4 *ipv4.PacketConn // ipv4 entrypoint for sending ipv4 packages. c6 *ipv6.PacketConn } -func (i ipv6PacketConn) readFrom(b []byte) (n int, cm any, src net.Addr, err error) { - return i.c6.ReadFrom(b) +func newIpv6PacketConn(c4 *ipv4.PacketConn, c6 *ipv6.PacketConn) *ipv6cmc { + return &ipv6cmc{c4: c4, c6: c6} } -func (i ipv6PacketConn) writeTo(b []byte, cm any, dst net.Addr) (n int, err error) { - cm6 := cm.(*ipv6.ControlMessage) - cm6.Src = cm6.Dst - cm6.Dst = nil - - // If src is ipv4, use IP_PKTINFO instead of IPV6_PKTINFO. - // Otherwise, sendmsg will raise "invalid argument" error. - // No official doc found. - if src4 := cm6.Src.To4(); src4 != nil { - return i.c4.WriteTo(b, &ipv4.ControlMessage{ - Src: src4, - IfIndex: cm6.IfIndex, - }, dst) - } else { - return i.c6.WriteTo(b, cm6, dst) +func (i *ipv6cmc) readFrom(b []byte) (n int, dst net.IP, IfIndex int, src net.Addr, err error) { + n, cm, src, err := i.c6.ReadFrom(b) + if cm != nil { + dst, IfIndex = cm.Dst, cm.IfIndex } + return } -func newUDPConn(c *net.UDPConn) (cmcUDPConn, error) { - p, err := getSocketIPProtocol(c) - if err != nil { - return nil, fmt.Errorf("failed to get socket ip protocol, %w", err) - } - switch p { - case v4: - c := ipv4.NewPacketConn(c) - if err := c.SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true); err != nil { - return nil, fmt.Errorf("failed to set ipv4 cmsg flags, %w", err) - } - return ipv4PacketConn{c: c}, nil - case v6: - c6 := ipv6.NewPacketConn(c) - if err := c6.SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true); err != nil { - return nil, fmt.Errorf("failed to set ipv6 cmsg flags, %w", err) +func (i *ipv6cmc) writeTo(b []byte, src net.IP, IfIndex int, dst net.Addr) (n int, err error) { + if src != nil { + // If src is ipv4, use IP_PKTINFO instead of IPV6_PKTINFO. + // Otherwise, sendmsg will raise "invalid argument" error. + // No official doc found. + if src4 := src.To4(); src4 != nil { + cm4 := &ipv4.ControlMessage{ + Src: src4, + IfIndex: IfIndex, + } + return i.c4.WriteTo(b, cm4, dst) } - return ipv6PacketConn{c6: c6, c4: ipv4.NewPacketConn(c)}, nil - default: - return nil, fmt.Errorf("unknow protocol %d", p) } + cm6 := &ipv6.ControlMessage{ + Src: src, + IfIndex: IfIndex, + } + return i.c6.WriteTo(b, cm6, dst) } -func getSocketIPProtocol(c *net.UDPConn) (protocol, error) { +func newCmc(c *net.UDPConn) (cmcUDPConn, error) { sc, err := c.SyscallConn() if err != nil { - return 0, err + return nil, err } - proto := invalid - var syscallErr error - if controlErr := sc.Control(func(fd uintptr) { + + var controlErr error + var cmc cmcUDPConn + + if err := sc.Control(func(fd uintptr) { v, err := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_DOMAIN) if err != nil { - syscallErr = os.NewSyscallError("failed to get SO_PROTOCOL", err) + controlErr = os.NewSyscallError("failed to get SO_PROTOCOL", err) return } switch v { case unix.AF_INET: - proto = v4 + c4 := ipv4.NewPacketConn(c) + if err := c4.SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true); err != nil { + controlErr = fmt.Errorf("failed to set ipv4 cmsg flags, %w", err) + } + cmc = newIpv4cmc(c4) + return case unix.AF_INET6: - proto = v6 + c6 := ipv6.NewPacketConn(c) + if err := c6.SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true); err != nil { + controlErr = fmt.Errorf("failed to set ipv6 cmsg flags, %w", err) + } + cmc = newIpv6PacketConn(ipv4.NewPacketConn(c), c6) + return default: - syscallErr = fmt.Errorf("socket protocol %d is not supported", v) + controlErr = fmt.Errorf("socket protocol %d is not supported", v) } }); err != nil { - return 0, fmt.Errorf("control fd err, %w", controlErr) + return nil, fmt.Errorf("control fd err, %w", controlErr) + } + + if controlErr != nil { + return nil, fmt.Errorf("failed to set up socket, %w", controlErr) } - return proto, syscallErr + return cmc, nil } diff --git a/pkg/server/udp_others.go b/pkg/server/udp_others.go index 4842e742e..8ce628064 100644 --- a/pkg/server/udp_others.go +++ b/pkg/server/udp_others.go @@ -23,6 +23,6 @@ package server import "net" -func newUDPConn(c *net.UDPConn) (cmcUDPConn, error) { - return newDummyUDPConn(c), nil +func newCmc(c *net.UDPConn) (cmcUDPConn, error) { + return newDummyCmc(c), nil }