From bf7110b1abe496915121b7b05d86e5ee46852ef2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 18 Apr 2023 09:29:29 +0800 Subject: [PATCH 1/5] Update udpnat usage --- go.mod | 4 ++-- go.sum | 8 ++++---- gvisor_udp.go | 2 +- system.go | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/go.mod b/go.mod index ccd0e19..ac846d6 100644 --- a/go.mod +++ b/go.mod @@ -6,9 +6,9 @@ require ( github.com/fsnotify/fsnotify v1.6.0 github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61 github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97 - github.com/sagernet/sing v0.2.1 + github.com/sagernet/sing v0.2.4-0.20230418025125-f196b4303e31 golang.org/x/net v0.8.0 - golang.org/x/sys v0.6.0 + golang.org/x/sys v0.7.0 gvisor.dev/gvisor v0.0.0-20220901235040-6ca97ef2ce1c ) diff --git a/go.sum b/go.sum index 77e297e..19e759e 100644 --- a/go.sum +++ b/go.sum @@ -7,8 +7,8 @@ github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61/go.mod h github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97 h1:iL5gZI3uFp0X6EslacyapiRz7LLSJyr4RajF/BhMVyE= github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM= github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY= -github.com/sagernet/sing v0.2.1 h1:r0STYeyfKBBtoAHsBtW1dQonxG+3Qidde7/1VAMhdn8= -github.com/sagernet/sing v0.2.1/go.mod h1:9uHswk2hITw8leDbiLS/xn0t9nzBcbePxzm9PJhwdlw= +github.com/sagernet/sing v0.2.4-0.20230418025125-f196b4303e31 h1:qgq8jeY/rbnY9NwYXByO//AP0ByIxnsKUxQx1tOB3W0= +github.com/sagernet/sing v0.2.4-0.20230418025125-f196b4303e31/go.mod h1:Ta8nHnDLAwqySzKhGoKk4ZIB+vJ3GTKj7UPrWYvM+4w= github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 h1:gga7acRE695APm9hlsSMoOoE65U4/TcqNj90mc69Rlg= github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= @@ -16,8 +16,8 @@ golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU= +golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= gvisor.dev/gvisor v0.0.0-20220901235040-6ca97ef2ce1c h1:m5lcgWnL3OElQNVyp3qcncItJ2c0sQlSGjYK2+nJTA4= diff --git a/gvisor_udp.go b/gvisor_udp.go index e62157d..38a24a4 100644 --- a/gvisor_udp.go +++ b/gvisor_udp.go @@ -29,7 +29,7 @@ func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, u return &UDPForwarder{ ctx: ctx, stack: stack, - udpNat: udpnat.New[netip.AddrPort](udpTimeout, handler), + udpNat: udpnat.New[netip.AddrPort](ctx, udpTimeout, handler), } } diff --git a/system.go b/system.go index 1ebeb56..f6e4e04 100644 --- a/system.go +++ b/system.go @@ -116,7 +116,7 @@ func (s *System) Start() error { go s.acceptLoop(tcpListener) } s.tcpNat = NewNat() - s.udpNat = udpnat.New[netip.AddrPort](s.udpTimeout, s.handler) + s.udpNat = udpnat.New[netip.AddrPort](s.ctx, s.udpTimeout, s.handler) go s.tunLoop() return nil } From 53f50347e04032a4a67204ccde578faac5665e47 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 18 Apr 2023 21:21:59 +0800 Subject: [PATCH 2/5] Fix system nat mapping --- gvisor.go | 9 +++++--- route_mapping.go | 5 ++++- system.go | 15 ++++++++------ system_nat.go | 54 ++++++++++++++++++++++++++++++++++++------------ 4 files changed, 60 insertions(+), 23 deletions(-) diff --git a/gvisor.go b/gvisor.go index 81e8efd..9e42a7c 100644 --- a/gvisor.go +++ b/gvisor.go @@ -57,7 +57,7 @@ func NewGVisor( return nil, E.New("gVisor stack is unsupported on current platform") } - return &GVisor{ + gStack := &GVisor{ ctx: options.Context, tun: gTun, tunMtu: options.MTU, @@ -66,8 +66,11 @@ func NewGVisor( router: options.Router, handler: options.Handler, logger: options.Logger, - routeMapping: NewRouteMapping(options.UDPTimeout), - }, nil + } + if gStack.router != nil { + gStack.routeMapping = NewRouteMapping(options.Context, options.UDPTimeout) + } + return gStack, nil } func (t *GVisor) Start() error { diff --git a/route_mapping.go b/route_mapping.go index b0ccad8..be89853 100644 --- a/route_mapping.go +++ b/route_mapping.go @@ -1,6 +1,8 @@ package tun import ( + "context" + "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/cache" ) @@ -9,9 +11,10 @@ type RouteMapping struct { status *cache.LruCache[RouteSession, RouteAction] } -func NewRouteMapping(maxAge int64) *RouteMapping { +func NewRouteMapping(ctx context.Context, maxAge int64) *RouteMapping { return &RouteMapping{ status: cache.New( + cache.WithContext[RouteSession, RouteAction](ctx), cache.WithAge[RouteSession, RouteAction](maxAge), cache.WithUpdateAgeOnGet[RouteSession, RouteAction](), cache.WithEvict[RouteSession, RouteAction](func(key RouteSession, conn RouteAction) { diff --git a/system.go b/system.go index f6e4e04..9705d07 100644 --- a/system.go +++ b/system.go @@ -63,7 +63,9 @@ func NewSystem(options StackOptions) (Stack, error) { inet4Prefixes: options.Inet4Address, inet6Prefixes: options.Inet6Address, underPlatform: options.UnderPlatform, - routeMapping: NewRouteMapping(options.UDPTimeout), + } + if stack.router != nil { + stack.routeMapping = NewRouteMapping(options.Context, options.UDPTimeout) } if len(options.Inet4Address) > 0 { if options.Inet4Address[0].Bits() == 32 { @@ -115,7 +117,7 @@ func (s *System) Start() error { s.tcpPort6 = M.SocksaddrFromNet(tcpListener.Addr()).Port go s.acceptLoop(tcpListener) } - s.tcpNat = NewNat() + s.tcpNat = NewNat(s.ctx, time.Second*time.Duration(s.udpTimeout)) s.udpNat = udpnat.New[netip.AddrPort](s.ctx, s.udpTimeout, s.handler) go s.tunLoop() return nil @@ -208,13 +210,14 @@ func (s *System) acceptLoop(listener net.Listener) { } } go func() { - s.handler.NewConnection(s.ctx, conn, M.Metadata{ + _ = s.handler.NewConnection(s.ctx, conn, M.Metadata{ Source: M.SocksaddrFromNetIP(session.Source), Destination: destination, }) - conn.Close() - time.Sleep(time.Second) - s.tcpNat.Revoke(connPort, session) + if tcpConn, isTCPConn := conn.(*net.TCPConn); isTCPConn { + _ = tcpConn.SetLinger(0) + } + _ = conn.Close() }() } } diff --git a/system_nat.go b/system_nat.go index adac1a6..ff80413 100644 --- a/system_nat.go +++ b/system_nat.go @@ -1,8 +1,10 @@ package tun import ( + "context" "net/netip" "sync" + "time" ) type TCPNat struct { @@ -16,20 +18,54 @@ type TCPNat struct { type TCPSession struct { Source netip.AddrPort Destination netip.AddrPort + LastActive time.Time } -func NewNat() *TCPNat { - return &TCPNat{ +func NewNat(ctx context.Context, timeout time.Duration) *TCPNat { + natMap := &TCPNat{ portIndex: 10000, addrMap: make(map[netip.AddrPort]uint16), portMap: make(map[uint16]*TCPSession), } + go natMap.loopCheckTimeout(ctx, timeout) + return natMap +} + +func (n *TCPNat) loopCheckTimeout(ctx context.Context, timeout time.Duration) { + ticker := time.NewTicker(timeout) + defer ticker.Stop() + for { + select { + case <-ticker.C: + n.checkTimeout(timeout) + case <-ctx.Done(): + return + } + } +} + +func (n *TCPNat) checkTimeout(timeout time.Duration) { + now := time.Now() + n.portAccess.Lock() + defer n.portAccess.Unlock() + n.addrAccess.Lock() + defer n.addrAccess.Unlock() + for natPort, session := range n.portMap { + if now.Sub(session.LastActive) > timeout { + delete(n.addrMap, session.Source) + delete(n.portMap, natPort) + } + } } func (n *TCPNat) LookupBack(port uint16) *TCPSession { n.portAccess.RLock() - defer n.portAccess.RUnlock() - return n.portMap[port] + session := n.portMap[port] + n.portAccess.RUnlock() + if session != nil { + session.LastActive = time.Now() + } + return session } func (n *TCPNat) Lookup(source netip.AddrPort, destination netip.AddrPort) uint16 { @@ -53,16 +89,8 @@ func (n *TCPNat) Lookup(source netip.AddrPort, destination netip.AddrPort) uint1 n.portMap[nextPort] = &TCPSession{ Source: source, Destination: destination, + LastActive: time.Now(), } n.portAccess.Unlock() return nextPort } - -func (n *TCPNat) Revoke(natPort uint16, session *TCPSession) { - n.addrAccess.Lock() - delete(n.addrMap, session.Source) - n.addrAccess.Unlock() - n.portAccess.Lock() - delete(n.portMap, natPort) - n.portAccess.Unlock() -} From d744d03d9302b841101d611b637b3aae564a83c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 19 Apr 2023 09:10:45 +0800 Subject: [PATCH 3/5] Fix default interface check for darwin --- monitor_darwin.go | 117 +++++++++++++++++++++++++++++++++++++++------- 1 file changed, 99 insertions(+), 18 deletions(-) diff --git a/monitor_darwin.go b/monitor_darwin.go index f6b0f52..296ba54 100644 --- a/monitor_darwin.go +++ b/monitor_darwin.go @@ -5,10 +5,9 @@ import ( "net" "net/netip" "os" - "runtime" - "strings" "sync" "syscall" + "time" "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" @@ -85,32 +84,114 @@ func (m *defaultInterfaceMonitor) checkUpdate() error { if err != nil { return err } + var defaultInterface *net.Interface for _, rawRouteMessage := range routeMessages { routeMessage := rawRouteMessage.(*route.RouteMessage) + if len(routeMessage.Addrs) <= unix.RTAX_NETMASK { + continue + } + destination, isIPv4Destination := routeMessage.Addrs[unix.RTAX_DST].(*route.Inet4Addr) + if !isIPv4Destination { + continue + } + if destination.IP != netip.IPv4Unspecified().As4() { + continue + } + mask, isIPv4Mask := routeMessage.Addrs[unix.RTAX_NETMASK].(*route.Inet4Addr) + if !isIPv4Mask { + continue + } + ones, _ := net.IPMask(mask.IP[:]).Size() + if ones != 0 { + continue + } routeInterface, err := net.InterfaceByIndex(routeMessage.Index) if err != nil { return err } - if runtime.GOOS == "ios" && strings.HasPrefix(routeInterface.Name, "utun") { + if routeMessage.Flags&unix.RTF_UP == 0 { + continue + } + if routeMessage.Flags&unix.RTF_GATEWAY == 0 { + continue + } + if routeMessage.Flags&unix.RTF_IFSCOPE != 0 { continue } - if common.Any(common.FilterIsInstance(routeMessage.Addrs, func(it route.Addr) (*route.Inet4Addr, bool) { - addr, loaded := it.(*route.Inet4Addr) - return addr, loaded - }), func(addr *route.Inet4Addr) bool { - return addr.IP == netip.IPv4Unspecified().As4() - }) { - oldInterface := m.defaultInterfaceName - oldIndex := m.defaultInterfaceIndex + defaultInterface = routeInterface + break + } + if defaultInterface == nil { + defaultInterface, err = getDefaultInterfaceBySocket() + if err != nil { + return err + } + } + oldInterface := m.defaultInterfaceName + oldIndex := m.defaultInterfaceIndex + m.defaultInterfaceIndex = defaultInterface.Index + m.defaultInterfaceName = defaultInterface.Name + if oldInterface == m.defaultInterfaceName && oldIndex == m.defaultInterfaceIndex { + return nil + } + m.emit(EventInterfaceUpdate) + return nil +} - m.defaultInterfaceIndex = routeMessage.Index - m.defaultInterfaceName = routeInterface.Name - if oldInterface == m.defaultInterfaceName && oldIndex == m.defaultInterfaceIndex { - return nil +func getDefaultInterfaceBySocket() (*net.Interface, error) { + socketFd, err := unix.Socket(unix.AF_INET, unix.SOCK_STREAM, 0) + if err != nil { + return nil, E.Cause(err, "create file descriptor") + } + defer unix.Close(socketFd) + go unix.Connect(socketFd, &unix.SockaddrInet4{ + Addr: [4]byte{10, 255, 255, 255}, + Port: 80, + }) + result := make(chan netip.Addr, 1) + go func() { + for { + sockname, sockErr := unix.Getsockname(socketFd) + if sockErr != nil { + break + } + sockaddr, isInet4Sockaddr := sockname.(*unix.SockaddrInet4) + if !isInet4Sockaddr { + break + } + addr := netip.AddrFrom4(sockaddr.Addr) + if addr.IsUnspecified() { + time.Sleep(time.Millisecond) + continue + } + result <- addr + break + } + }() + var selectedAddr netip.Addr + select { + case selectedAddr = <-result: + case <-time.After(time.Second): + return nil, os.ErrDeadlineExceeded + } + interfaces, err := net.Interfaces() + if err != nil { + return nil, E.Cause(err, "net.Interfaces") + } + for _, netInterface := range interfaces { + interfaceAddrs, err := netInterface.Addrs() + if err != nil { + return nil, E.Cause(err, "net.Interfaces.Addrs") + } + for _, interfaceAddr := range interfaceAddrs { + ipNet, isIPNet := interfaceAddr.(*net.IPNet) + if !isIPNet { + continue + } + if ipNet.Contains(selectedAddr.AsSlice()) { + return &netInterface, nil } - m.emit(EventInterfaceUpdate) - return nil } } - return ErrNoRoute + return nil, E.New("no interface found for address ", selectedAddr) } From e46ae0b2b097a82ecca39c02a7e52907e5b6f29b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 19 Apr 2023 22:57:56 +0800 Subject: [PATCH 4/5] Fix lwip build --- lwip.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lwip.go b/lwip.go index 7a37c3a..23954c2 100644 --- a/lwip.go +++ b/lwip.go @@ -35,7 +35,7 @@ func NewLWIP( tunMtu: options.MTU, handler: options.Handler, stack: lwip.NewLWIPStack(), - udpNat: udpnat.New[netip.AddrPort](options.UDPTimeout, options.Handler), + udpNat: udpnat.New[netip.AddrPort](options.Context, options.UDPTimeout, options.Handler), }, nil } From 510a1815ba26c633cd728ff8d7de6e7412527725 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 19 Apr 2023 23:18:49 +0800 Subject: [PATCH 5/5] Update udpnat usage --- go.mod | 2 +- go.sum | 4 ++-- gvisor.go | 12 +++++++++--- gvisor_udp.go | 6 +++++- lwip.go | 7 +++++-- route_mapping.go | 11 ++++++++++- system.go | 6 ++++-- 7 files changed, 36 insertions(+), 12 deletions(-) diff --git a/go.mod b/go.mod index ac846d6..1600f2e 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/fsnotify/fsnotify v1.6.0 github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61 github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97 - github.com/sagernet/sing v0.2.4-0.20230418025125-f196b4303e31 + github.com/sagernet/sing v0.2.4-0.20230419150837-2b3a62786474 golang.org/x/net v0.8.0 golang.org/x/sys v0.7.0 gvisor.dev/gvisor v0.0.0-20220901235040-6ca97ef2ce1c diff --git a/go.sum b/go.sum index 19e759e..b00f9c6 100644 --- a/go.sum +++ b/go.sum @@ -7,8 +7,8 @@ github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61/go.mod h github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97 h1:iL5gZI3uFp0X6EslacyapiRz7LLSJyr4RajF/BhMVyE= github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM= github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY= -github.com/sagernet/sing v0.2.4-0.20230418025125-f196b4303e31 h1:qgq8jeY/rbnY9NwYXByO//AP0ByIxnsKUxQx1tOB3W0= -github.com/sagernet/sing v0.2.4-0.20230418025125-f196b4303e31/go.mod h1:Ta8nHnDLAwqySzKhGoKk4ZIB+vJ3GTKj7UPrWYvM+4w= +github.com/sagernet/sing v0.2.4-0.20230419150837-2b3a62786474 h1:eSYMHrZvHo9hTKSAPTFaGsiafUss7FPhTDanXsNrfwE= +github.com/sagernet/sing v0.2.4-0.20230419150837-2b3a62786474/go.mod h1:Ta8nHnDLAwqySzKhGoKk4ZIB+vJ3GTKj7UPrWYvM+4w= github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 h1:gga7acRE695APm9hlsSMoOoE65U4/TcqNj90mc69Rlg= github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= diff --git a/gvisor.go b/gvisor.go index 9e42a7c..9406466 100644 --- a/gvisor.go +++ b/gvisor.go @@ -8,6 +8,7 @@ import ( "syscall" "time" + "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/canceler" E "github.com/sagernet/sing/common/exceptions" @@ -42,6 +43,7 @@ type GVisor struct { stack *stack.Stack endpoint stack.LinkEndpoint routeMapping *RouteMapping + udpForwarder *UDPForwarder } type GVisorTun interface { @@ -68,7 +70,7 @@ func NewGVisor( logger: options.Logger, } if gStack.router != nil { - gStack.routeMapping = NewRouteMapping(options.Context, options.UDPTimeout) + gStack.routeMapping = NewRouteMapping(options.UDPTimeout) } return gStack, nil } @@ -256,7 +258,8 @@ func (t *GVisor) Start() error { return udpForwarder.HandlePacket(id, buffer) }) } else { - ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket) + t.udpForwarder = NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout) + ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, t.udpForwarder.HandlePacket) } t.stack = ipStack @@ -270,5 +273,8 @@ func (t *GVisor) Close() error { for _, endpoint := range t.stack.CleanupEndpoints() { endpoint.Abort() } - return nil + return common.Close( + common.PtrOrNil(t.routeMapping), + common.PtrOrNil(t.udpForwarder), + ) } diff --git a/gvisor_udp.go b/gvisor_udp.go index 38a24a4..688b2c7 100644 --- a/gvisor_udp.go +++ b/gvisor_udp.go @@ -29,10 +29,14 @@ func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, u return &UDPForwarder{ ctx: ctx, stack: stack, - udpNat: udpnat.New[netip.AddrPort](ctx, udpTimeout, handler), + udpNat: udpnat.New[netip.AddrPort](udpTimeout, handler), } } +func (f *UDPForwarder) Close() error { + return f.udpNat.Close() +} + func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { var upstreamMetadata M.Metadata upstreamMetadata.Source = M.SocksaddrFrom(M.AddrFromIP(net.IP(id.RemoteAddress)), id.RemotePort) diff --git a/lwip.go b/lwip.go index 23954c2..2933dd9 100644 --- a/lwip.go +++ b/lwip.go @@ -35,7 +35,7 @@ func NewLWIP( tunMtu: options.MTU, handler: options.Handler, stack: lwip.NewLWIPStack(), - udpNat: udpnat.New[netip.AddrPort](options.Context, options.UDPTimeout, options.Handler), + udpNat: udpnat.New[netip.AddrPort](options.UDPTimeout, options.Handler), }, nil } @@ -96,7 +96,10 @@ func (l *LWIP) Close() error { lwip.RegisterOutputFn(func(bytes []byte) (int, error) { return 0, os.ErrClosed }) - return l.stack.Close() + return common.Close( + l.stack, + common.PtrOrNil(l.udpNat), + ) } func (l *LWIP) Handle(conn net.Conn) error { diff --git a/route_mapping.go b/route_mapping.go index be89853..ca4feda 100644 --- a/route_mapping.go +++ b/route_mapping.go @@ -2,6 +2,7 @@ package tun import ( "context" + "net" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/cache" @@ -9,9 +10,11 @@ import ( type RouteMapping struct { status *cache.LruCache[RouteSession, RouteAction] + cancel common.ContextCancelCauseFunc } -func NewRouteMapping(ctx context.Context, maxAge int64) *RouteMapping { +func NewRouteMapping(maxAge int64) *RouteMapping { + ctx, cancel := common.ContextWithCancelCause(context.Background()) return &RouteMapping{ status: cache.New( cache.WithContext[RouteSession, RouteAction](ctx), @@ -21,6 +24,7 @@ func NewRouteMapping(ctx context.Context, maxAge int64) *RouteMapping { common.Close(conn) }), ), + cancel: cancel, } } @@ -33,3 +37,8 @@ func (m *RouteMapping) Lookup(session RouteSession, constructor func() RouteActi } return action } + +func (m *RouteMapping) Close() error { + m.cancel(net.ErrClosed) + return nil +} diff --git a/system.go b/system.go index 9705d07..2067dbd 100644 --- a/system.go +++ b/system.go @@ -65,7 +65,7 @@ func NewSystem(options StackOptions) (Stack, error) { underPlatform: options.UnderPlatform, } if stack.router != nil { - stack.routeMapping = NewRouteMapping(options.Context, options.UDPTimeout) + stack.routeMapping = NewRouteMapping(options.UDPTimeout) } if len(options.Inet4Address) > 0 { if options.Inet4Address[0].Bits() == 32 { @@ -91,6 +91,8 @@ func (s *System) Close() error { return common.Close( s.tcpListener, s.tcpListener6, + s.udpNat, + common.PtrOrNil(s.routeMapping), ) } @@ -118,7 +120,7 @@ func (s *System) Start() error { go s.acceptLoop(tcpListener) } s.tcpNat = NewNat(s.ctx, time.Second*time.Duration(s.udpTimeout)) - s.udpNat = udpnat.New[netip.AddrPort](s.ctx, s.udpTimeout, s.handler) + s.udpNat = udpnat.New[netip.AddrPort](s.udpTimeout, s.handler) go s.tunLoop() return nil }