From 2d777a87b54bbff3388e63f9f248c08b94cac11c Mon Sep 17 00:00:00 2001 From: Mark Pashmfouroush Date: Thu, 14 Mar 2024 19:38:11 +0000 Subject: [PATCH] proxy: copypasta proxy into here for fast fixes Signed-off-by: Mark Pashmfouroush --- app/app.go | 6 +- go.mod | 1 - go.sum | 2 - proxy/README.md | 67 ++++ proxy/example/customHandler/main.go | 35 +++ proxy/example/minimal/main.go | 10 + proxy/example/udpClient/main.go | 69 +++++ proxy/pkg/http/common.go | 84 +++++ proxy/pkg/http/server.go | 237 +++++++++++++++ proxy/pkg/mixed/handlers.go | 90 ++++++ proxy/pkg/mixed/proxy.go | 147 +++++++++ proxy/pkg/socks4/common.go | 167 ++++++++++ proxy/pkg/socks4/server.go | 242 +++++++++++++++ proxy/pkg/socks5/common.go | 415 +++++++++++++++++++++++++ proxy/pkg/socks5/server.go | 455 ++++++++++++++++++++++++++++ proxy/pkg/statute/statute.go | 76 +++++ proxy/pkg/statute/tunnel.go | 80 +++++ wiresocks/proxy.go | 9 +- 18 files changed, 2181 insertions(+), 11 deletions(-) create mode 100644 proxy/README.md create mode 100644 proxy/example/customHandler/main.go create mode 100644 proxy/example/minimal/main.go create mode 100644 proxy/example/udpClient/main.go create mode 100644 proxy/pkg/http/common.go create mode 100644 proxy/pkg/http/server.go create mode 100644 proxy/pkg/mixed/handlers.go create mode 100644 proxy/pkg/mixed/proxy.go create mode 100644 proxy/pkg/socks4/common.go create mode 100644 proxy/pkg/socks4/server.go create mode 100644 proxy/pkg/socks5/common.go create mode 100644 proxy/pkg/socks5/server.go create mode 100644 proxy/pkg/statute/statute.go create mode 100644 proxy/pkg/statute/tunnel.go diff --git a/app/app.go b/app/app.go index ffa8f2703..ebee0de97 100644 --- a/app/app.go +++ b/app/app.go @@ -118,7 +118,7 @@ func runWarp(ctx context.Context, l *slog.Logger, bind netip.AddrPort, endpoint } tnet.StartProxy(bind) - l.Info("Serving proxy", "address", bind) + l.Info("serving proxy", "address", bind) return nil } @@ -155,7 +155,7 @@ func runWarpWithPsiphon(ctx context.Context, l *slog.Logger, bind netip.AddrPort return fmt.Errorf("unable to run psiphon %w", err) } - l.Info("Serving proxy", "address", bind) + l.Info("serving proxy", "address", bind) return nil } @@ -210,7 +210,7 @@ func runWarpInWarp(ctx context.Context, l *slog.Logger, bind netip.AddrPort, end tnet.StartProxy(bind) - l.Info("Serving proxy", "address", bind) + l.Info("serving proxy", "address", bind) return nil } diff --git a/go.mod b/go.mod index ba1211a14..bbc54c0fc 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,6 @@ replace github.com/Psiphon-Labs/psiphon-tunnel-core => github.com/bepass-org/psi require ( github.com/Psiphon-Labs/psiphon-tunnel-core v2.0.28+incompatible - github.com/bepass-org/proxy v0.0.0-20240201095508-c86216dd0aea github.com/fatih/color v1.16.0 github.com/flynn/noise v1.1.0 github.com/frankban/quicktest v1.14.6 diff --git a/go.sum b/go.sum index 730a87232..678cc3077 100644 --- a/go.sum +++ b/go.sum @@ -20,8 +20,6 @@ github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/ github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/armon/go-proxyproto v0.0.0-20180202201750-5b7edb60ff5f h1:SaJ6yqg936TshyeFZqQE+N+9hYkIeL9AMr7S4voCl10= github.com/armon/go-proxyproto v0.0.0-20180202201750-5b7edb60ff5f/go.mod h1:QmP9hvJ91BbJmGVGSbutW19IC0Q9phDCLGaomwTJbgU= -github.com/bepass-org/proxy v0.0.0-20240201095508-c86216dd0aea h1:6GKkjxDUxqq7uwA8U15N4PFURhdNN0OrxFuXc58MGUU= -github.com/bepass-org/proxy v0.0.0-20240201095508-c86216dd0aea/go.mod h1:RlF0oO3D6Ju6VYjtL1I6lVLdc3l8jA4ggleJc8S+P0Y= github.com/bepass-org/psiphon-tunnel-core v0.0.0-20240311155012-9c2e10df08e5 h1:UVdsUQXhviRMzVA02BGzEHUYUBAAeSJYijqKWJvMCxs= github.com/bepass-org/psiphon-tunnel-core v0.0.0-20240311155012-9c2e10df08e5/go.mod h1:vA5iCui7nfavWyBN8MsLYZ5xpKItjrTvPC0SuMWz48Q= github.com/bifurcation/mint v0.0.0-20180306135233-198357931e61 h1:BU+NxuoaYPIvvp8NNkNlLr8aA0utGyuunf4Q3LJ0bh0= diff --git a/proxy/README.md b/proxy/README.md new file mode 100644 index 000000000..e258a37f1 --- /dev/null +++ b/proxy/README.md @@ -0,0 +1,67 @@ +# Table of Contents +- [Introduction](#introduction) +- [Features](#features) +- [Installation](#installation) +- [Examples](#examples) + - [Minimal](#minimal) + - [Customized](#customized) + + +## Introduction +The proxy module simplifies connection handling and offers a generic way to work with both HTTP and SOCKS connections, +making it a powerful tool for managing network traffic. + + +## Features +The Inbound Proxy project offers the following features: + +- Full support for `HTTP`, `SOCKS5`, `SOCKS5h`, `SOCKS4` and `SOCKS4a` protocols. +- Handling of `HTTP` and `HTTPS-connect` proxy requests. +- Full support for both `IPv4` and `IPv6`. +- Able to handle both `TCP` and `UDP` traffic. + +## Installation + +```bash +go get github.com/bepass-org/proxy +``` + +### Examples + +#### Minimal + +```go +package main + +import ( + "github.com/bepass-org/proxy/pkg/mixed" +) + +func main() { + proxy := mixed.NewProxy() + _ = proxy.ListenAndServe() +} +``` + +#### Customized + +```go +package main + +import ( + "github.com/bepass-org/proxy/pkg/mixed" +) + +func main() { + proxy := mixed.NewProxy( + mixed.WithBindAddress("0.0.0.0:8080"), + ) + _ = proxy.ListenAndServe() +} + +``` + +There are other examples provided in the [example](https://github.com/bepass-org/proxy/tree/main/example) directory + + + diff --git a/proxy/example/customHandler/main.go b/proxy/example/customHandler/main.go new file mode 100644 index 000000000..fd59fc150 --- /dev/null +++ b/proxy/example/customHandler/main.go @@ -0,0 +1,35 @@ +package main + +import ( + "fmt" + "io" + "log" + "net" + + "github.com/bepass-org/warp-plus/proxy/pkg/mixed" + "github.com/bepass-org/warp-plus/proxy/pkg/statute" +) + +func main() { + proxy := mixed.NewProxy( + mixed.WithBindAddress("127.0.0.1:1080"), + mixed.WithUserHandler(generalHandler), + ) + _ = proxy.ListenAndServe() +} + +func generalHandler(req *statute.ProxyRequest) error { + fmt.Println("handling request to", req.Destination) + conn, err := net.Dial(req.Network, req.Destination) + if err != nil { + return err + } + go func() { + _, err := io.Copy(conn, req.Conn) + if err != nil { + log.Println(err) + } + }() + _, err = io.Copy(req.Conn, conn) + return err +} diff --git a/proxy/example/minimal/main.go b/proxy/example/minimal/main.go new file mode 100644 index 000000000..ddb5546f7 --- /dev/null +++ b/proxy/example/minimal/main.go @@ -0,0 +1,10 @@ +package main + +import ( + "github.com/bepass-org/warp-plus/proxy/pkg/mixed" +) + +func main() { + proxy := mixed.NewProxy() + _ = proxy.ListenAndServe() +} diff --git a/proxy/example/udpClient/main.go b/proxy/example/udpClient/main.go new file mode 100644 index 000000000..9046b5aa4 --- /dev/null +++ b/proxy/example/udpClient/main.go @@ -0,0 +1,69 @@ +package main + +import ( + "encoding/binary" + "fmt" + "io" + "net" + "strconv" +) + +func main() { + proxyAddr := "127.0.0.1:1080" + targetAddr := ":4444" + + // Connect to SOCKS5 proxy + conn, err := net.Dial("tcp", proxyAddr) + if err != nil { + panic(err) + } + defer conn.Close() + + // Send greeting to SOCKS5 proxy + conn.Write([]byte{0x05, 0x01, 0x00}) + + // Read greeting response + response := make([]byte, 2) + io.ReadFull(conn, response) + + // Send UDP ASSOCIATE request + conn.Write([]byte{0x05, 0x03, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) + + // Read UDP ASSOCIATE response + response = make([]byte, 10) + io.ReadFull(conn, response) + + // Extract the bind address and port + bindIP := net.IP(response[4:8]) + bindPort := binary.BigEndian.Uint16(response[8:10]) + + // Print the bind address + fmt.Printf("Bind address: %s:%d\n", bindIP, bindPort) + + // Create UDP connection + udpConn, err := net.Dial("udp", fmt.Sprintf("%s:%d", bindIP, bindPort)) + if err != nil { + panic(err) + } + defer udpConn.Close() + + // Extract target IP and port + dstIP, dstPortStr, _ := net.SplitHostPort(targetAddr) + dstPort, _ := strconv.Atoi(dstPortStr) + + // Construct the UDP packet with the target address and message + packet := make([]byte, 0) + packet = append(packet, 0x00, 0x00, 0x00) // RSV and FRAG + packet = append(packet, 0x01) // ATYP for IPv4 + packet = append(packet, net.ParseIP(dstIP).To4()...) + packet = append(packet, byte(dstPort>>8), byte(dstPort&0xFF)) + packet = append(packet, []byte("Hello, UDP through SOCKS5!")...) + + // Send the UDP packet + udpConn.Write(packet) + + // Read the response + buffer := make([]byte, 1024) + n, _ := udpConn.Read(buffer) + fmt.Println("Received:", string(buffer[10:n])) +} diff --git a/proxy/pkg/http/common.go b/proxy/pkg/http/common.go new file mode 100644 index 000000000..28c2410c1 --- /dev/null +++ b/proxy/pkg/http/common.go @@ -0,0 +1,84 @@ +package http + +import ( + "bytes" + "fmt" + "net" + "net/http" + "sync" +) + +// copyBuffer is a helper function to copy data between two net.Conn objects. +// func copyBuffer(dst, src net.Conn, buf []byte) (int64, error) { +// return io.CopyBuffer(dst, src, buf) +// } + +type responseWriter struct { + conn net.Conn + headers http.Header + status int + written bool +} + +func NewHTTPResponseWriter(conn net.Conn) http.ResponseWriter { + return &responseWriter{ + conn: conn, + headers: http.Header{}, + status: http.StatusOK, + } +} + +func (rw *responseWriter) Header() http.Header { + return rw.headers +} + +func (rw *responseWriter) WriteHeader(statusCode int) { + if rw.written { + return + } + rw.status = statusCode + rw.written = true + + statusText := http.StatusText(statusCode) + if statusText == "" { + statusText = fmt.Sprintf("status code %d", statusCode) + } + _, _ = fmt.Fprintf(rw.conn, "HTTP/1.1 %d %s\r\n", statusCode, statusText) + _ = rw.headers.Write(rw.conn) + _, _ = rw.conn.Write([]byte("\r\n")) +} + +func (rw *responseWriter) Write(data []byte) (int, error) { + if !rw.written { + rw.WriteHeader(http.StatusOK) + } + return rw.conn.Write(data) +} + +type customConn struct { + net.Conn + req *http.Request + initialData []byte + once sync.Once +} + +func (c *customConn) Read(p []byte) (n int, err error) { + c.once.Do(func() { + buf := &bytes.Buffer{} + err = c.req.Write(buf) + if err != nil { + n = 0 + return + } + c.initialData = buf.Bytes() + }) + + if len(c.initialData) > 0 { + copy(p, c.initialData) + n = len(p) + c.initialData = nil + return + } + + return c.Conn.Read(p) +} diff --git a/proxy/pkg/http/server.go b/proxy/pkg/http/server.go new file mode 100644 index 000000000..101c8ca1b --- /dev/null +++ b/proxy/pkg/http/server.go @@ -0,0 +1,237 @@ +package http + +import ( + "bufio" + "context" + "io" + "log/slog" + "net" + "net/http" + "strconv" + + "github.com/bepass-org/warp-plus/proxy/pkg/statute" +) + +type Server struct { + // bind is the address to listen on + Bind string + // ProxyDial specifies the optional proxyDial function for + // establishing the transport connection. + ProxyDial statute.ProxyDialFunc + // UserConnectHandle gives the user control to handle the TCP CONNECT requests + UserConnectHandle statute.UserConnectHandler + // Logger error log + Logger *slog.Logger + // Context is default context + Context context.Context + // BytesPool getting and returning temporary bytes for use by io.CopyBuffer + BytesPool statute.BytesPool +} + +func NewServer(options ...ServerOption) *Server { + s := &Server{ + Bind: statute.DefaultBindAddress, + ProxyDial: statute.DefaultProxyDial(), + Logger: slog.Default(), + Context: statute.DefaultContext(), + } + + for _, option := range options { + option(s) + } + + return s +} + +type ServerOption func(*Server) + +func (s *Server) ListenAndServe() error { + // Create a new listener + ln, err := net.Listen("tcp", s.Bind) + if err != nil { + return err // Return error if binding was unsuccessful + } + + // ensure listener will be closed + defer func() { + _ = ln.Close() + }() + + // Create a cancelable context based on s.Context + ctx, cancel := context.WithCancel(s.Context) + defer cancel() // Ensure resources are cleaned up + + // Start to accept connections and serve them + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + conn, err := ln.Accept() + if err != nil { + s.Logger.Error(err.Error()) + continue + } + + // Start a new goroutine to handle each connection + // This way, the server can handle multiple connections concurrently + go func() { + err := s.ServeConn(conn) + if err != nil { + s.Logger.Error(err.Error()) // Log errors from ServeConn + } + }() + } + } +} + +func WithLogger(logger *slog.Logger) ServerOption { + return func(s *Server) { + s.Logger = logger + } +} + +func WithBind(bindAddress string) ServerOption { + return func(s *Server) { + s.Bind = bindAddress + } +} + +func WithConnectHandle(handler statute.UserConnectHandler) ServerOption { + return func(s *Server) { + s.UserConnectHandle = handler + } +} + +func WithProxyDial(proxyDial statute.ProxyDialFunc) ServerOption { + return func(s *Server) { + s.ProxyDial = proxyDial + } +} + +func WithContext(ctx context.Context) ServerOption { + return func(s *Server) { + s.Context = ctx + } +} + +func WithBytesPool(bytesPool statute.BytesPool) ServerOption { + return func(s *Server) { + s.BytesPool = bytesPool + } +} + +func (s *Server) ServeConn(conn net.Conn) error { + reader := bufio.NewReader(conn) + req, err := http.ReadRequest(reader) + if err != nil { + return err + } + + return s.handleHTTP(conn, req, req.Method == http.MethodConnect) +} + +func (s *Server) handleHTTP(conn net.Conn, req *http.Request, isConnectMethod bool) error { + if s.UserConnectHandle == nil { + return s.embedHandleHTTP(conn, req, isConnectMethod) + } + + if isConnectMethod { + _, err := conn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")) + if err != nil { + return err + } + } else { + cConn := &customConn{ + Conn: conn, + req: req, + } + conn = cConn + } + + targetAddr := req.URL.Host + host, portStr, err := net.SplitHostPort(targetAddr) + if err != nil { + host = targetAddr + if req.URL.Scheme == "https" || isConnectMethod { + portStr = "443" + } else { + portStr = "80" + } + targetAddr = net.JoinHostPort(host, portStr) + } + + portInt, err := strconv.Atoi(portStr) + if err != nil { + return err // Handle the error if the port string is not a valid integer. + } + port := int32(portInt) + + proxyReq := &statute.ProxyRequest{ + Conn: conn, + Reader: io.Reader(conn), + Writer: io.Writer(conn), + Network: "tcp", + Destination: targetAddr, + DestHost: host, + DestPort: port, + } + + return s.UserConnectHandle(proxyReq) +} + +func (s *Server) embedHandleHTTP(conn net.Conn, req *http.Request, isConnectMethod bool) error { + defer func() { + _ = conn.Close() + }() + + host, portStr, err := net.SplitHostPort(req.URL.Host) + if err != nil { + host = req.URL.Host + if req.URL.Scheme == "https" || isConnectMethod { + portStr = "443" + } else { + portStr = "80" + } + } + targetAddr := net.JoinHostPort(host, portStr) + + target, err := s.ProxyDial(s.Context, "tcp", targetAddr) + if err != nil { + http.Error( + NewHTTPResponseWriter(conn), + err.Error(), + http.StatusServiceUnavailable, + ) + return err + } + defer func() { + _ = target.Close() + }() + + if isConnectMethod { + _, err = conn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")) + if err != nil { + return err + } + } else { + err = req.Write(target) + if err != nil { + return err + } + } + + var buf1, buf2 []byte + if s.BytesPool != nil { + buf1 = s.BytesPool.Get() + buf2 = s.BytesPool.Get() + defer func() { + s.BytesPool.Put(buf1) + s.BytesPool.Put(buf2) + }() + } else { + buf1 = make([]byte, 32*1024) + buf2 = make([]byte, 32*1024) + } + return statute.Tunnel(s.Context, target, conn, buf1, buf2) +} diff --git a/proxy/pkg/mixed/handlers.go b/proxy/pkg/mixed/handlers.go new file mode 100644 index 000000000..9d960863b --- /dev/null +++ b/proxy/pkg/mixed/handlers.go @@ -0,0 +1,90 @@ +package mixed + +import ( + "context" + "log/slog" + + "github.com/bepass-org/warp-plus/proxy/pkg/statute" +) + +func WithBindAddress(binAddress string) Option { + return func(p *Proxy) { + p.bind = binAddress + p.socks5Proxy.Bind = binAddress + p.socks4Proxy.Bind = binAddress + p.httpProxy.Bind = binAddress + } +} + +func WithLogger(logger *slog.Logger) Option { + return func(p *Proxy) { + p.logger = logger + p.socks5Proxy.Logger = logger + p.socks4Proxy.Logger = logger + p.httpProxy.Logger = logger + } +} + +func WithUserHandler(handler userHandler) Option { + return func(p *Proxy) { + p.userHandler = handler + p.socks5Proxy.UserConnectHandle = statute.UserConnectHandler(handler) + p.socks5Proxy.UserAssociateHandle = statute.UserAssociateHandler(handler) + p.socks4Proxy.UserConnectHandle = statute.UserConnectHandler(handler) + p.httpProxy.UserConnectHandle = statute.UserConnectHandler(handler) + } +} + +func WithUserTCPHandler(handler userHandler) Option { + return func(p *Proxy) { + p.userTCPHandler = handler + p.socks5Proxy.UserConnectHandle = statute.UserConnectHandler(handler) + p.socks4Proxy.UserConnectHandle = statute.UserConnectHandler(handler) + p.httpProxy.UserConnectHandle = statute.UserConnectHandler(handler) + } +} + +func WithUserUDPHandler(handler userHandler) Option { + return func(p *Proxy) { + p.userUDPHandler = handler + p.socks5Proxy.UserAssociateHandle = statute.UserAssociateHandler(handler) + } +} + +func WithUserDialFunc(proxyDial statute.ProxyDialFunc) Option { + return func(p *Proxy) { + p.userDialFunc = proxyDial + p.socks5Proxy.ProxyDial = proxyDial + p.socks4Proxy.ProxyDial = proxyDial + p.httpProxy.ProxyDial = proxyDial + } +} + +func WithUserListenPacketFunc(proxyListenPacket statute.ProxyListenPacket) Option { + return func(p *Proxy) { + p.socks5Proxy.ProxyListenPacket = proxyListenPacket + } +} + +func WithUserForwardAddressFunc(packetForwardAddress statute.PacketForwardAddress) Option { + return func(p *Proxy) { + p.socks5Proxy.PacketForwardAddress = packetForwardAddress + } +} + +func WithContext(ctx context.Context) Option { + return func(p *Proxy) { + p.ctx = ctx + p.socks5Proxy.Context = ctx + p.socks4Proxy.Context = ctx + p.httpProxy.Context = ctx + } +} + +func WithBytesPool(bytesPool statute.BytesPool) Option { + return func(p *Proxy) { + p.socks5Proxy.BytesPool = bytesPool + p.socks4Proxy.BytesPool = bytesPool + p.httpProxy.BytesPool = bytesPool + } +} diff --git a/proxy/pkg/mixed/proxy.go b/proxy/pkg/mixed/proxy.go new file mode 100644 index 000000000..41833eab9 --- /dev/null +++ b/proxy/pkg/mixed/proxy.go @@ -0,0 +1,147 @@ +package mixed + +import ( + "bufio" + "context" + "log/slog" + "net" + + "github.com/bepass-org/warp-plus/proxy/pkg/http" + "github.com/bepass-org/warp-plus/proxy/pkg/socks4" + "github.com/bepass-org/warp-plus/proxy/pkg/socks5" + "github.com/bepass-org/warp-plus/proxy/pkg/statute" +) + +type userHandler func(request *statute.ProxyRequest) error + +type Proxy struct { + // bind is the address to listen on + bind string + // socks5Proxy is a socks5 server with tcp and udp support + socks5Proxy *socks5.Server + // socks4Proxy is a socks4 server with tcp support + socks4Proxy *socks4.Server + // httpProxy is a http proxy server with http and http-connect support + httpProxy *http.Server + // userConnectHandle is a user handler for tcp and udp requests(its general handler) + userHandler userHandler + // if user doesnt set userHandler, it can specify userTCPHandler for manual handling of tcp requests + userTCPHandler userHandler + // if user doesnt set userHandler, it can specify userUDPHandler for manual handling of udp requests + userUDPHandler userHandler + // overwrite dial functions of http, socks4, socks5 + userDialFunc statute.ProxyDialFunc + // logger error log + logger *slog.Logger + // ctx is default context + ctx context.Context +} + +func NewProxy(options ...Option) *Proxy { + p := &Proxy{ + bind: statute.DefaultBindAddress, + socks5Proxy: socks5.NewServer(), + socks4Proxy: socks4.NewServer(), + httpProxy: http.NewServer(), + userDialFunc: statute.DefaultProxyDial(), + logger: slog.Default(), + ctx: statute.DefaultContext(), + } + + for _, option := range options { + option(p) + } + + return p +} + +type Option func(*Proxy) + +// SwitchConn wraps a net.Conn and a bufio.Reader +type SwitchConn struct { + net.Conn + reader *bufio.Reader +} + +// NewSwitchConn creates a new SwitchConn +func NewSwitchConn(conn net.Conn) *SwitchConn { + return &SwitchConn{ + Conn: conn, + reader: bufio.NewReader(conn), + } +} + +// Read reads data into p, first from the bufio.Reader, then from the net.Conn +func (c *SwitchConn) Read(p []byte) (n int, err error) { + return c.reader.Read(p) +} + +func (p *Proxy) ListenAndServe() error { + // Create a new listener + ln, err := net.Listen("tcp", p.bind) + if err != nil { + return err // Return error if binding was unsuccessful + } + p.logger.Debug("started proxy", "address", p.bind) + + // ensure listener will be closed + defer func() { + _ = ln.Close() + }() + + // Create a cancelable context based on p.Context + ctx, cancel := context.WithCancel(p.ctx) + defer cancel() // Ensure resources are cleaned up + + // Start to accept connections and serve them + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + conn, err := ln.Accept() + if err != nil { + p.logger.Error(err.Error()) + continue + } + + // Start a new goroutine to handle each connection + // This way, the server can handle multiple connections concurrently + go func() { + err := p.handleConnection(conn) + if err != nil { + p.logger.Error(err.Error()) // Log errors from ServeConn + } + }() + } + } +} + +func (p *Proxy) handleConnection(conn net.Conn) error { + // Create a SwitchConn + switchConn := NewSwitchConn(conn) + + // Read one byte to determine the protocol + buf := make([]byte, 1) + _, err := switchConn.Read(buf) + if err != nil { + return err + } + + // Unread the byte so it's available for the next read + err = switchConn.reader.UnreadByte() + if err != nil { + return err + } + + switch { + case buf[0] == 5: + err = p.socks5Proxy.ServeConn(switchConn) + case buf[0] == 4: + err = p.socks4Proxy.ServeConn(switchConn) + default: + err = p.httpProxy.ServeConn(switchConn) + } + + return err +} diff --git a/proxy/pkg/socks4/common.go b/proxy/pkg/socks4/common.go new file mode 100644 index 000000000..978119c67 --- /dev/null +++ b/proxy/pkg/socks4/common.go @@ -0,0 +1,167 @@ +package socks4 + +import ( + "bytes" + "encoding/binary" + "io" + "net" + "strconv" +) + +var ( + isSocks4a = []byte{0, 0, 0, 1} + isNone = []byte{0, 0, 0, 0} +) + +const ( + socks4Version = 0x04 +) + +const ( + ConnectCommand Command = 0x01 +) + +// Command is a SOCKS Command. +type Command byte + +func (cmd Command) String() string { + switch cmd { + case ConnectCommand: + return "socks connect" + default: + return "socks " + strconv.Itoa(int(cmd)) + } +} + +const ( + grantedReply reply = 0x5a + rejectedReply reply = 0x5b + noIdentdReply reply = 0x5c + invalidUserReply reply = 0x5d +) + +// reply is a SOCKS Command reply code. +type reply byte + +func (code reply) String() string { + switch code { + case grantedReply: + return "request granted" + case rejectedReply: + return "request rejected or failed" + case noIdentdReply: + return "request rejected becasue SOCKS server cannot connect to identd on the client" + case invalidUserReply: + return "request rejected because the client program and identd report different user-ids" + default: + return "unknown code: " + strconv.Itoa(int(code)) + } +} + +// address is a SOCKS-specific address. +// Either Name or IP is used exclusively. +type address struct { + Name string // fully-qualified domain name + IP net.IP + Port int +} + +func (a *address) Network() string { return "socks4" } + +func (a *address) String() string { + if a == nil { + return "" + } + return a.Address() +} + +// Address returns a string suitable to dial; prefer returning IP-based +// address, fallback to Name +func (a address) Address() string { + port := strconv.Itoa(a.Port) + if a.Name != "" { + return net.JoinHostPort(a.Name, port) + } + return net.JoinHostPort(a.IP.String(), port) +} + +type AddrAnfUser struct { + address + Username string +} + +func readBytes(r io.Reader) ([]byte, error) { + buf := []byte{} + var data [1]byte + for { + _, err := r.Read(data[:]) + if err != nil { + return nil, err + } + if data[0] == 0 { + return buf, nil + } + buf = append(buf, data[0]) + } +} + +func readByte(r io.Reader) (byte, error) { + var buf [1]byte + _, err := r.Read(buf[:]) + if err != nil { + return 0, err + } + return buf[0], nil +} + +func readAddrAndUser(r io.Reader) (*AddrAnfUser, error) { + address := &AddrAnfUser{} + var port [2]byte + if _, err := io.ReadFull(r, port[:]); err != nil { + return nil, err + } + address.Port = int(binary.BigEndian.Uint16(port[:])) + ip := make(net.IP, net.IPv4len) + if _, err := io.ReadFull(r, ip); err != nil { + return nil, err + } + socks4a := bytes.Equal(ip, isSocks4a) + + username, err := readBytes(r) + if err != nil { + return nil, err + } + address.Username = string(username) + if socks4a { + hostname, err := readBytes(r) + if err != nil { + return nil, err + } + address.Name = string(hostname) + } else { + address.IP = ip + } + return address, nil +} + +func writeAddr(w io.Writer, addr *address) error { + var ip net.IP + var port uint16 + if addr != nil { + ip = addr.IP.To4() + port = uint16(addr.Port) + } + var p [2]byte + binary.BigEndian.PutUint16(p[:], port) + _, err := w.Write(p[:]) + if err != nil { + return err + } + + if ip == nil { + _, err = w.Write(isNone) + } else { + _, err = w.Write(ip) + } + return err +} diff --git a/proxy/pkg/socks4/server.go b/proxy/pkg/socks4/server.go new file mode 100644 index 000000000..01b0c2a6d --- /dev/null +++ b/proxy/pkg/socks4/server.go @@ -0,0 +1,242 @@ +package socks4 + +import ( + "context" + "fmt" + "io" + "log/slog" + "net" + + "github.com/bepass-org/warp-plus/proxy/pkg/statute" +) + +// Server is accepting connections and handling the details of the SOCKS4 protocol +type Server struct { + // bind is the address to listen on + Bind string + // ProxyDial specifies the optional proxyDial function for + // establishing the transport connection. + ProxyDial statute.ProxyDialFunc + // UserConnectHandle gives the user control to handle the TCP CONNECT requests + UserConnectHandle statute.UserConnectHandler + // Logger error log + Logger *slog.Logger + // Context is default context + Context context.Context + // BytesPool getting and returning temporary bytes for use by io.CopyBuffer + BytesPool statute.BytesPool +} + +func NewServer(options ...ServerOption) *Server { + s := &Server{ + ProxyDial: statute.DefaultProxyDial(), + Logger: slog.Default(), + Context: statute.DefaultContext(), + } + + for _, option := range options { + option(s) + } + + return s +} + +type ServerOption func(*Server) + +func (s *Server) ListenAndServe() error { + // Create a new listener + ln, err := net.Listen("tcp", s.Bind) + if err != nil { + return err // Return error if binding was unsuccessful + } + s.Logger.Debug("started proxy", "address", s.Bind) + + // ensure listener will be closed + defer func() { + _ = ln.Close() + }() + + // Create a cancelable context based on s.Context + ctx, cancel := context.WithCancel(s.Context) + defer cancel() // Ensure resources are cleaned up + + // Start to accept connections and serve them + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + conn, err := ln.Accept() + if err != nil { + s.Logger.Error(err.Error()) + continue + } + + // Start a new goroutine to handle each connection + // This way, the server can handle multiple connections concurrently + go func() { + err := s.ServeConn(conn) + if err != nil { + s.Logger.Error(err.Error()) // Log errors from ServeConn + } + }() + } + } +} + +func WithLogger(logger *slog.Logger) ServerOption { + return func(s *Server) { + s.Logger = logger + } +} + +func WithBind(bindAddress string) ServerOption { + return func(s *Server) { + s.Bind = bindAddress + } +} + +func WithConnectHandle(handler statute.UserConnectHandler) ServerOption { + return func(s *Server) { + s.UserConnectHandle = handler + } +} + +func WithProxyDial(proxyDial statute.ProxyDialFunc) ServerOption { + return func(s *Server) { + s.ProxyDial = proxyDial + } +} + +func WithContext(ctx context.Context) ServerOption { + return func(s *Server) { + s.Context = ctx + } +} + +func WithBytesPool(bytesPool statute.BytesPool) ServerOption { + return func(s *Server) { + s.BytesPool = bytesPool + } +} + +func (s *Server) ServeConn(conn net.Conn) error { + version, err := readByte(conn) + if err != nil { + return err + } + if version != socks4Version { + return fmt.Errorf("unsupported SOCKS version: %d", version) + } + req := &request{ + Version: socks4Version, + Conn: conn, + } + + cmd, err := readByte(conn) + if err != nil { + return err + } + req.Command = Command(cmd) + + addr, err := readAddrAndUser(conn) + if err != nil { + if err := sendReply(req.Conn, rejectedReply, nil); err != nil { + return fmt.Errorf("failed to send reply: %v", err) + } + return err + } + req.DestinationAddr = &addr.address + req.Username = addr.Username + return s.handle(req) +} + +func (s *Server) handle(req *request) error { + switch req.Command { + case ConnectCommand: + return s.handleConnect(req) + default: + if err := sendReply(req.Conn, rejectedReply, nil); err != nil { + return err + } + return fmt.Errorf("unsupported Command: %v", req.Command) + } +} + +func (s *Server) handleConnect(req *request) error { + if s.UserConnectHandle == nil { + return s.embedHandleConnect(req) + } + + if err := sendReply(req.Conn, grantedReply, nil); err != nil { + return fmt.Errorf("failed to send reply: %v", err) + } + host := req.DestinationAddr.IP.String() + if req.DestinationAddr.Name != "" { + host = req.DestinationAddr.Name + } + + proxyReq := &statute.ProxyRequest{ + Conn: req.Conn, + Reader: io.Reader(req.Conn), + Writer: io.Writer(req.Conn), + Network: "tcp", + Destination: req.DestinationAddr.String(), + DestHost: host, + DestPort: int32(req.DestinationAddr.Port), + } + + return s.UserConnectHandle(proxyReq) +} + +func (s *Server) embedHandleConnect(req *request) error { + defer func() { + _ = req.Conn.Close() + }() + target, err := s.ProxyDial(s.Context, "tcp", req.DestinationAddr.Address()) + if err != nil { + if err := sendReply(req.Conn, rejectedReply, nil); err != nil { + return fmt.Errorf("failed to send reply: %v", err) + } + return fmt.Errorf("connect to %v failed: %w", req.DestinationAddr, err) + } + defer func() { + _ = target.Close() + }() + local := target.LocalAddr().(*net.TCPAddr) + bind := address{IP: local.IP, Port: local.Port} + if err := sendReply(req.Conn, grantedReply, &bind); err != nil { + return fmt.Errorf("failed to send reply: %v", err) + } + + var buf1, buf2 []byte + if s.BytesPool != nil { + buf1 = s.BytesPool.Get() + buf2 = s.BytesPool.Get() + defer func() { + s.BytesPool.Put(buf1) + s.BytesPool.Put(buf2) + }() + } else { + buf1 = make([]byte, 32*1024) + buf2 = make([]byte, 32*1024) + } + return statute.Tunnel(s.Context, target, req.Conn, buf1, buf2) +} + +func sendReply(w io.Writer, resp reply, addr *address) error { + _, err := w.Write([]byte{0, byte(resp)}) + if err != nil { + return err + } + err = writeAddr(w, addr) + return err +} + +type request struct { + Version uint8 + Command Command + DestinationAddr *address + Username string + Conn net.Conn +} diff --git a/proxy/pkg/socks5/common.go b/proxy/pkg/socks5/common.go new file mode 100644 index 000000000..bffccf645 --- /dev/null +++ b/proxy/pkg/socks5/common.go @@ -0,0 +1,415 @@ +package socks5 + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "math" + "net" + "strconv" + "strings" + "sync" +) + +var ( + errStringTooLong = errors.New("string too long") + errNoSupportedAuth = errors.New("no supported authentication mechanism") + errUnrecognizedAddrType = errors.New("unrecognized address type") +) + +const ( + maxUdpPacket = math.MaxUint16 - 28 +) + +const ( + socks5Version = 0x05 +) + +const ( + ConnectCommand Command = 0x01 + AssociateCommand Command = 0x03 +) + +// Command is a SOCKS Command. +type Command byte + +func (cmd Command) String() string { + switch cmd { + case ConnectCommand: + return "socks connect" + case AssociateCommand: + return "socks associate" + default: + return "socks " + strconv.Itoa(int(cmd)) + } +} + +const ( + successReply reply = 0x00 + serverFailure reply = 0x01 + ruleFailure reply = 0x02 + networkUnreachable reply = 0x03 + hostUnreachable reply = 0x04 + connectionRefused reply = 0x05 + ttlExpired reply = 0x06 + commandNotSupported reply = 0x07 + addrTypeNotSupported reply = 0x08 +) + +func errToReply(err error) reply { + if err == nil { + return successReply + } + msg := err.Error() + resp := hostUnreachable + if strings.Contains(msg, "refused") { + resp = connectionRefused + } else if strings.Contains(msg, "network is unreachable") { + resp = networkUnreachable + } + return resp +} + +// reply is a SOCKS Command reply code. +type reply byte + +func (code reply) String() string { + switch code { + case successReply: + return "succeeded" + case serverFailure: + return "general SOCKS server failure" + case ruleFailure: + return "connection not allowed by ruleset" + case networkUnreachable: + return "network unreachable" + case hostUnreachable: + return "host unreachable" + case connectionRefused: + return "connection refused" + case ttlExpired: + return "TTL expired" + case commandNotSupported: + return "Command not supported" + case addrTypeNotSupported: + return "address type not supported" + default: + return "unknown code: " + strconv.Itoa(int(code)) + } +} + +const ( + ipv4Address = 0x01 + fqdnAddress = 0x03 + ipv6Address = 0x04 +) + +// address is a SOCKS-specific address. +// Either Name or IP is used exclusively. +type address struct { + Name string // fully-qualified domain name + IP net.IP + Port int +} + +func (a *address) Network() string { return "socks5" } + +func (a *address) String() string { + if a == nil { + return "" + } + return a.Address() +} + +// Address returns a string suitable to dial; prefer returning IP-based +// address, fallback to Name +func (a address) Address() string { + port := strconv.Itoa(a.Port) + if len(a.IP) != 0 { + return net.JoinHostPort(a.IP.String(), port) + } + return net.JoinHostPort(a.Name, port) +} + +// authMethod is a SOCKS authentication method. +type authMethod byte + +const ( + noAuth authMethod = 0x00 // no authentication required + noAcceptable authMethod = 0xff // no acceptable authentication methods +) + +func readBytes(r io.Reader) ([]byte, error) { + var buf [1]byte + _, err := r.Read(buf[:]) + if err != nil { + return nil, err + } + bytes := make([]byte, buf[0]) + _, err = io.ReadFull(r, bytes) + if err != nil { + return nil, err + } + return bytes, nil +} + +// func writeBytes(w io.Writer, b []byte) error { +// _, err := w.Write([]byte{byte(len(b))}) +// if err != nil { +// return err +// } +// _, err = w.Write(b) +// return err +// } + +func readByte(r io.Reader) (byte, error) { + var buf [1]byte + _, err := r.Read(buf[:]) + if err != nil { + return 0, err + } + return buf[0], nil +} + +func readAddr(r io.Reader) (*address, error) { + address := &address{} + + var addrType [1]byte + if _, err := r.Read(addrType[:]); err != nil { + return nil, err + } + + switch addrType[0] { + case ipv4Address: + addr := make(net.IP, net.IPv4len) + if _, err := io.ReadFull(r, addr); err != nil { + return nil, err + } + address.IP = addr + case ipv6Address: + addr := make(net.IP, net.IPv6len) + if _, err := io.ReadFull(r, addr); err != nil { + return nil, err + } + address.IP = addr + case fqdnAddress: + if _, err := r.Read(addrType[:]); err != nil { + return nil, err + } + addrLen := int(addrType[0]) + fqdn := make([]byte, addrLen) + if _, err := io.ReadFull(r, fqdn); err != nil { + return nil, err + } + address.Name = string(fqdn) + default: + return nil, errUnrecognizedAddrType + } + var port [2]byte + if _, err := io.ReadFull(r, port[:]); err != nil { + return nil, err + } + address.Port = int(binary.BigEndian.Uint16(port[:])) + return address, nil +} + +func writeAddr(w io.Writer, addr *address) error { + if addr == nil { + _, err := w.Write([]byte{ipv4Address, 0, 0, 0, 0, 0, 0}) + if err != nil { + return err + } + return nil + } + if addr.IP != nil { + if ip4 := addr.IP.To4(); ip4 != nil { + _, err := w.Write([]byte{ipv4Address}) + if err != nil { + return err + } + _, err = w.Write(ip4) + if err != nil { + return err + } + } else if ip6 := addr.IP.To16(); ip6 != nil { + _, err := w.Write([]byte{ipv6Address}) + if err != nil { + return err + } + _, err = w.Write(ip6) + if err != nil { + return err + } + } else { + _, err := w.Write([]byte{ipv4Address, 0, 0, 0, 0}) + if err != nil { + return err + } + } + } else if addr.Name != "" { + if len(addr.Name) > 255 { + return errStringTooLong + } + _, err := w.Write([]byte{fqdnAddress, byte(len(addr.Name))}) + if err != nil { + return err + } + _, err = w.Write([]byte(addr.Name)) + if err != nil { + return err + } + } else { + _, err := w.Write([]byte{ipv4Address, 0, 0, 0, 0}) + if err != nil { + return err + } + } + var p [2]byte + binary.BigEndian.PutUint16(p[:], uint16(addr.Port)) + _, err := w.Write(p[:]) + return err +} + +func writeAddrWithStr(w io.Writer, addr string) error { + host, port, err := splitHostPort(addr) + if err != nil { + return err + } + if ip := net.ParseIP(host); ip != nil { + return writeAddr(w, &address{IP: ip, Port: port}) + } + return writeAddr(w, &address{Name: host, Port: port}) +} + +func splitHostPort(address string) (string, int, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return "", 0, err + } + portnum, err := strconv.Atoi(port) + if err != nil { + return "", 0, err + } + if 0 > portnum || portnum > 0xffff { + return "", 0, errors.New("port number out of range " + port) + } + return host, portnum, nil +} + +type readStruct struct { + data []byte + err error +} + +type udpCustomConn struct { + net.PacketConn + assocTCPConn net.Conn + lock sync.Mutex + sourceAddr net.Addr + targetAddr net.Addr + replyPrefix []byte + firstRead sync.Once + frc chan bool + packetQueue chan *readStruct +} + +func (cc *udpCustomConn) RemoteAddr() net.Addr { + return cc.targetAddr +} + +func (cc *udpCustomConn) asyncReadPackets() { + go func() { + for { + tempBuf := make([]byte, maxUdpPacket) + n, addr, err := cc.ReadFrom(tempBuf) + if err != nil { + cc.packetQueue <- &readStruct{ + data: nil, + err: err, + } + break + } + if cc.sourceAddr == nil { + cc.sourceAddr = addr + } + packetData := tempBuf[:n] + if len(packetData) < 3 { + cc.packetQueue <- &readStruct{ + data: nil, + err: err, + } + break + } + reader := bytes.NewBuffer(packetData[3:]) + targetAddr, err := readAddr(reader) + + if err != nil { + cc.packetQueue <- &readStruct{ + data: nil, + err: err, + } + break + } + if cc.targetAddr == nil { + cc.targetAddr = &net.UDPAddr{ + IP: targetAddr.IP, + Port: targetAddr.Port, + } + } + if targetAddr.String() != cc.targetAddr.String() { + cc.packetQueue <- &readStruct{ + data: nil, + err: fmt.Errorf("ignore non-target addresses %s", targetAddr.String()), + } + break + } + cc.firstRead.Do(func() { + // ok we have source and destination address now user can handle new ProxyReq + cc.frc <- true + }) + cc.packetQueue <- &readStruct{ + data: reader.Bytes(), + err: nil, + } + } + }() +} + +func (cc *udpCustomConn) Read(b []byte) (int, error) { + // wait for packet data + read := <-cc.packetQueue + if read.err != nil { + return 0, read.err + } + copy(b, read.data) + return len(read.data), nil +} + +func (cc *udpCustomConn) Write(b []byte) (int, error) { + cc.lock.Lock() + defer cc.lock.Unlock() + if cc.replyPrefix == nil { + prefix := bytes.NewBuffer(make([]byte, 3, 16)) + err := writeAddrWithStr(prefix, cc.targetAddr.String()) + if err != nil { + return 0, err + } + cc.replyPrefix = prefix.Bytes() + } + buff := append(cc.replyPrefix, b...) + _, err := cc.WriteTo(buff[:len(cc.replyPrefix)+len(b)], cc.sourceAddr) + return len(b), err +} + +func (cc *udpCustomConn) Close() error { + cc.lock.Lock() + defer cc.lock.Unlock() + udpErr := cc.PacketConn.Close() + tcpErr := cc.assocTCPConn.Close() + if udpErr != nil { + return udpErr + } + return tcpErr +} diff --git a/proxy/pkg/socks5/server.go b/proxy/pkg/socks5/server.go new file mode 100644 index 000000000..3541ed0b9 --- /dev/null +++ b/proxy/pkg/socks5/server.go @@ -0,0 +1,455 @@ +package socks5 + +import ( + "bytes" + "context" + "fmt" + "io" + "log/slog" + "net" + + "github.com/bepass-org/warp-plus/proxy/pkg/statute" +) + +// Server is accepting connections and handling the details of the SOCKS5 protocol +type Server struct { + // bind is the address to listen on + Bind string + // ProxyDial specifies the optional proxyDial function for + // establishing the transport connection. + ProxyDial statute.ProxyDialFunc + // ProxyListenPacket specifies the optional proxyListenPacket function for + // establishing the transport connection. + ProxyListenPacket statute.ProxyListenPacket + // PacketForwardAddress specifies the packet forwarding address + PacketForwardAddress statute.PacketForwardAddress + // UserConnectHandle gives the user control to handle the TCP CONNECT requests + UserConnectHandle statute.UserConnectHandler + // UserAssociateHandle gives the user control to handle the UDP ASSOCIATE requests + UserAssociateHandle statute.UserAssociateHandler + // Logger error log + Logger *slog.Logger + // Context is default context + Context context.Context + // BytesPool getting and returning temporary bytes for use by io.CopyBuffer + BytesPool statute.BytesPool +} + +func NewServer(options ...ServerOption) *Server { + s := &Server{ + Bind: statute.DefaultBindAddress, + ProxyDial: statute.DefaultProxyDial(), + ProxyListenPacket: statute.DefaultProxyListenPacket(), + PacketForwardAddress: defaultReplyPacketForwardAddress, + Logger: slog.Default(), + Context: statute.DefaultContext(), + } + + for _, option := range options { + option(s) + } + + return s +} + +type ServerOption func(*Server) + +func (s *Server) ListenAndServe() error { + // Create a new listener + ln, err := net.Listen("tcp", s.Bind) + if err != nil { + return err // Return error if binding was unsuccessful + } + + // ensure listener will be closed + defer func() { + _ = ln.Close() + }() + + // Create a cancelable context based on s.Context + ctx, cancel := context.WithCancel(s.Context) + defer cancel() // Ensure resources are cleaned up + + // Start to accept connections and serve them + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + conn, err := ln.Accept() + if err != nil { + s.Logger.Error(err.Error()) + continue + } + + // Start a new goroutine to handle each connection + // This way, the server can handle multiple connections concurrently + go func() { + err := s.ServeConn(conn) + if err != nil { + s.Logger.Error(err.Error()) // Log errors from ServeConn + } + }() + } + } +} + +func WithLogger(logger *slog.Logger) ServerOption { + return func(s *Server) { + s.Logger = logger + } +} + +func WithBind(bindAddress string) ServerOption { + return func(s *Server) { + s.Bind = bindAddress + } +} + +func WithConnectHandle(handler statute.UserConnectHandler) ServerOption { + return func(s *Server) { + s.UserConnectHandle = handler + } +} + +func WithAssociateHandle(handler statute.UserAssociateHandler) ServerOption { + return func(s *Server) { + s.UserAssociateHandle = handler + } +} + +func WithProxyDial(proxyDial statute.ProxyDialFunc) ServerOption { + return func(s *Server) { + s.ProxyDial = proxyDial + } +} + +func WithProxyListenPacket(proxyListenPacket statute.ProxyListenPacket) ServerOption { + return func(s *Server) { + s.ProxyListenPacket = proxyListenPacket + } +} + +func WithPacketForwardAddress(packetForwardAddress statute.PacketForwardAddress) ServerOption { + return func(s *Server) { + s.PacketForwardAddress = packetForwardAddress + } +} + +func WithContext(ctx context.Context) ServerOption { + return func(s *Server) { + s.Context = ctx + } +} + +func WithBytesPool(bytesPool statute.BytesPool) ServerOption { + return func(s *Server) { + s.BytesPool = bytesPool + } +} + +func (s *Server) ServeConn(conn net.Conn) error { + version, err := readByte(conn) + if err != nil { + return err + } + if version != socks5Version { + return fmt.Errorf("unsupported SOCKS version: %d", version) + } + + req := &request{ + Version: socks5Version, + Conn: conn, + } + + methods, err := readBytes(conn) + if err != nil { + return err + } + + if bytes.IndexByte(methods, byte(noAuth)) != -1 { + _, err := conn.Write([]byte{socks5Version, byte(noAuth)}) + if err != nil { + return err + } + } else { + _, err := conn.Write([]byte{socks5Version, byte(noAcceptable)}) + if err != nil { + return err + } + return errNoSupportedAuth + } + + var header [3]byte + _, err = io.ReadFull(conn, header[:]) + if err != nil { + return err + } + + if header[0] != socks5Version { + return fmt.Errorf("unsupported Command version: %d", header[0]) + } + + req.Command = Command(header[1]) + + dest, err := readAddr(conn) + if err != nil { + if err == errUnrecognizedAddrType { + err := sendReply(conn, addrTypeNotSupported, nil) + if err != nil { + return err + } + } + return err + } + req.DestinationAddr = dest + err = s.handle(req) + if err != nil { + return err + } + + return nil +} + +func (s *Server) handle(req *request) error { + switch req.Command { + case ConnectCommand: + return s.handleConnect(req) + case AssociateCommand: + return s.handleAssociate(req) + default: + if err := sendReply(req.Conn, commandNotSupported, nil); err != nil { + return err + } + return fmt.Errorf("unsupported Command: %v", req.Command) + } +} + +func (s *Server) handleConnect(req *request) error { + if s.UserConnectHandle == nil { + return s.embedHandleConnect(req) + } + + if err := sendReply(req.Conn, successReply, nil); err != nil { + return fmt.Errorf("failed to send reply: %v", err) + } + host := req.DestinationAddr.IP.String() + if req.DestinationAddr.Name != "" { + host = req.DestinationAddr.Name + } + + proxyReq := &statute.ProxyRequest{ + Conn: req.Conn, + Reader: io.Reader(req.Conn), + Writer: io.Writer(req.Conn), + Network: "tcp", + Destination: req.DestinationAddr.String(), + DestHost: host, + DestPort: int32(req.DestinationAddr.Port), + } + + return s.UserConnectHandle(proxyReq) +} + +func (s *Server) embedHandleConnect(req *request) error { + defer func() { + _ = req.Conn.Close() + }() + + target, err := s.ProxyDial(s.Context, "tcp", req.DestinationAddr.Address()) + if err != nil { + if err := sendReply(req.Conn, errToReply(err), nil); err != nil { + return fmt.Errorf("failed to send reply: %v", err) + } + return fmt.Errorf("connect to %v failed: %w", req.DestinationAddr, err) + } + defer func() { + _ = target.Close() + }() + + localAddr := target.LocalAddr() + local, ok := localAddr.(*net.TCPAddr) + if !ok { + return fmt.Errorf("connect to %v failed: local address is %s://%s", req.DestinationAddr, localAddr.Network(), localAddr.String()) + } + bind := address{IP: local.IP, Port: local.Port} + if err := sendReply(req.Conn, successReply, &bind); err != nil { + return fmt.Errorf("failed to send reply: %v", err) + } + + var buf1, buf2 []byte + if s.BytesPool != nil { + buf1 = s.BytesPool.Get() + buf2 = s.BytesPool.Get() + defer func() { + s.BytesPool.Put(buf1) + s.BytesPool.Put(buf2) + }() + } else { + buf1 = make([]byte, 32*1024) + buf2 = make([]byte, 32*1024) + } + return statute.Tunnel(s.Context, target, req.Conn, buf1, buf2) +} + +func (s *Server) handleAssociate(req *request) error { + destinationAddr := req.DestinationAddr.String() + udpConn, err := s.ProxyListenPacket(s.Context, "udp", destinationAddr) + if err != nil { + if err := sendReply(req.Conn, errToReply(err), nil); err != nil { + return fmt.Errorf("failed to send reply: %v", err) + } + return fmt.Errorf("connect to %v failed: %w", req.DestinationAddr, err) + } + + ip, port, err := s.PacketForwardAddress(s.Context, destinationAddr, udpConn, req.Conn) + if err != nil { + return err + } + bind := address{IP: ip, Port: port} + if err := sendReply(req.Conn, successReply, &bind); err != nil { + return fmt.Errorf("failed to send reply: %v", err) + } + + if s.UserAssociateHandle == nil { + return s.embedHandleAssociate(req, udpConn) + } + + cConn := &udpCustomConn{ + PacketConn: udpConn, + assocTCPConn: req.Conn, + frc: make(chan bool), + packetQueue: make(chan *readStruct), + } + + cConn.asyncReadPackets() + + // wait for first packet so that target sender and receiver get known + <-cConn.frc + + proxyReq := &statute.ProxyRequest{ + Conn: cConn, + Reader: cConn, + Writer: cConn, + Network: "udp", + Destination: cConn.targetAddr.String(), + DestHost: cConn.targetAddr.(*net.UDPAddr).IP.String(), + DestPort: int32(cConn.targetAddr.(*net.UDPAddr).Port), + } + + return s.UserAssociateHandle(proxyReq) +} + +func (s *Server) embedHandleAssociate(req *request, udpConn net.PacketConn) error { + defer func() { + _ = udpConn.Close() + }() + + go func() { + var buf [1]byte + for { + _, err := req.Conn.Read(buf[:]) + if err != nil { + _ = udpConn.Close() + break + } + } + }() + + var ( + sourceAddr net.Addr + wantSource string + targetAddr net.Addr + wantTarget string + replyPrefix []byte + buf [maxUdpPacket]byte + ) + + for { + n, addr, err := udpConn.ReadFrom(buf[:]) + if err != nil { + return err + } + + if sourceAddr == nil { + sourceAddr = addr + wantSource = sourceAddr.String() + } + + gotAddr := addr.String() + if wantSource == gotAddr { + if n < 3 { + continue + } + reader := bytes.NewBuffer(buf[3:n]) + addr, err := readAddr(reader) + if err != nil { + s.Logger.Debug(err.Error()) + continue + } + if targetAddr == nil { + targetAddr = &net.UDPAddr{ + IP: addr.IP, + Port: addr.Port, + } + wantTarget = targetAddr.String() + } + if addr.String() != wantTarget { + s.Logger.Debug("ignore non-target addresses", "address", addr) + continue + } + _, err = udpConn.WriteTo(reader.Bytes(), targetAddr) + if err != nil { + return err + } + } else if targetAddr != nil && wantTarget == gotAddr { + if replyPrefix == nil { + b := bytes.NewBuffer(make([]byte, 3, 16)) + err = writeAddrWithStr(b, wantTarget) + if err != nil { + return err + } + replyPrefix = b.Bytes() + } + copy(buf[len(replyPrefix):len(replyPrefix)+n], buf[:n]) + copy(buf[:len(replyPrefix)], replyPrefix) + _, err = udpConn.WriteTo(buf[:len(replyPrefix)+n], sourceAddr) + if err != nil { + return err + } + } + } +} + +func sendReply(w io.Writer, resp reply, addr *address) error { + _, err := w.Write([]byte{socks5Version, byte(resp), 0}) + if err != nil { + return err + } + err = writeAddr(w, addr) + return err +} + +type request struct { + Version uint8 + Command Command + DestinationAddr *address + Username string + Password string + Conn net.Conn +} + +func defaultReplyPacketForwardAddress(_ context.Context, destinationAddr string, packet net.PacketConn, conn net.Conn) (net.IP, int, error) { + udpLocal := packet.LocalAddr() + udpLocalAddr, ok := udpLocal.(*net.UDPAddr) + if !ok { + return nil, 0, fmt.Errorf("connect to %v failed: local address is %s://%s", destinationAddr, udpLocal.Network(), udpLocal.String()) + } + + tcpLocal := conn.LocalAddr() + tcpLocalAddr, ok := tcpLocal.(*net.TCPAddr) + if !ok { + return nil, 0, fmt.Errorf("connect to %v failed: local address is %s://%s", destinationAddr, tcpLocal.Network(), tcpLocal.String()) + } + return tcpLocalAddr.IP, udpLocalAddr.Port, nil +} diff --git a/proxy/pkg/statute/statute.go b/proxy/pkg/statute/statute.go new file mode 100644 index 000000000..344a95d96 --- /dev/null +++ b/proxy/pkg/statute/statute.go @@ -0,0 +1,76 @@ +package statute + +import ( + "context" + "fmt" + "io" + "net" +) + +type Logger interface { + Debug(v ...interface{}) + Error(v ...interface{}) +} + +type DefaultLogger struct{} + +func (l DefaultLogger) Debug(v ...interface{}) { + fmt.Println(v...) +} + +func (l DefaultLogger) Error(v ...interface{}) { + fmt.Println(v...) +} + +type ProxyRequest struct { + Conn net.Conn + Reader io.Reader + Writer io.Writer + Network string + Destination string + DestHost string + DestPort int32 +} + +// UserConnectHandler is used for socks5, socks4 and http +type UserConnectHandler func(request *ProxyRequest) error + +// UserAssociateHandler is used for socks5 +type UserAssociateHandler func(request *ProxyRequest) error + +// ProxyDialFunc is used for socks5, socks4 and http +type ProxyDialFunc func(ctx context.Context, network string, address string) (net.Conn, error) + +// DefaultProxyDial for ProxyDialFunc type +func DefaultProxyDial() ProxyDialFunc { + var dialer net.Dialer + return dialer.DialContext +} + +// ProxyListenPacket specifies the optional proxyListenPacket function for +// establishing the transport connection. +type ProxyListenPacket func(ctx context.Context, network string, address string) (net.PacketConn, error) + +// DefaultProxyListenPacket for ProxyListenPacket type +func DefaultProxyListenPacket() ProxyListenPacket { + var listener net.ListenConfig + return listener.ListenPacket +} + +// PacketForwardAddress specifies the packet forwarding address +type PacketForwardAddress func(ctx context.Context, destinationAddr string, + packet net.PacketConn, conn net.Conn) (net.IP, int, error) + +// BytesPool is an interface for getting and returning temporary +// bytes for use by io.CopyBuffer. +type BytesPool interface { + Get() []byte + Put([]byte) +} + +// DefaultContext for context.Context type +func DefaultContext() context.Context { + return context.Background() +} + +const DefaultBindAddress = "127.0.0.1:1080" diff --git a/proxy/pkg/statute/tunnel.go b/proxy/pkg/statute/tunnel.go new file mode 100644 index 000000000..34afb2e61 --- /dev/null +++ b/proxy/pkg/statute/tunnel.go @@ -0,0 +1,80 @@ +package statute + +import ( + "context" + "io" + "net" + "os" + "reflect" + "runtime" + "strings" +) + +// isClosedConnError reports whether err is an error from use of a closed +// network connection. +func isClosedConnError(err error) bool { + if err == nil { + return false + } + + str := err.Error() + if strings.Contains(str, "use of closed network connection") { + return true + } + + if runtime.GOOS == "windows" { + if oe, ok := err.(*net.OpError); ok && oe.Op == "read" { + if se, ok := oe.Err.(*os.SyscallError); ok && se.Syscall == "wsarecv" { + const WSAECONNABORTED = 10053 + const WSAECONNRESET = 10054 + if n := errno(se.Err); n == WSAECONNRESET || n == WSAECONNABORTED { + return true + } + } + } + } + return false +} + +func errno(v error) uintptr { + if rv := reflect.ValueOf(v); rv.Kind() == reflect.Uintptr { + return uintptr(rv.Uint()) + } + return 0 +} + +// Tunnel create tunnels for two io.ReadWriteCloser +func Tunnel(ctx context.Context, c1, c2 io.ReadWriteCloser, buf1, buf2 []byte) error { + ctx, cancel := context.WithCancel(ctx) + var errs tunnelErr + go func() { + _, errs[0] = io.CopyBuffer(c1, c2, buf1) + cancel() + }() + go func() { + _, errs[1] = io.CopyBuffer(c2, c1, buf2) + cancel() + }() + <-ctx.Done() + errs[2] = c1.Close() + errs[3] = c2.Close() + errs[4] = ctx.Err() + if errs[4] == context.Canceled { + errs[4] = nil + } + return errs.FirstError() +} + +type tunnelErr [5]error + +func (t tunnelErr) FirstError() error { + for _, err := range t { + if err != nil { + if isClosedConnError(err) { + return nil + } + return err + } + } + return nil +} diff --git a/wiresocks/proxy.go b/wiresocks/proxy.go index 99c58515a..94cd8b100 100644 --- a/wiresocks/proxy.go +++ b/wiresocks/proxy.go @@ -7,8 +7,8 @@ import ( "net/netip" "time" - "github.com/bepass-org/proxy/pkg/mixed" - "github.com/bepass-org/proxy/pkg/statute" + "github.com/bepass-org/warp-plus/proxy/pkg/mixed" + "github.com/bepass-org/warp-plus/proxy/pkg/statute" "github.com/bepass-org/warp-plus/wireguard/device" "github.com/bepass-org/warp-plus/wireguard/tun/netstack" ) @@ -26,8 +26,7 @@ type VirtualTun struct { func (vt *VirtualTun) StartProxy(bindAddress netip.AddrPort) { proxy := mixed.NewProxy( mixed.WithBindAddress(bindAddress.String()), - // TODO - // mixed.WithLogger(vt.Logger), + mixed.WithLogger(vt.Logger), mixed.WithContext(vt.Ctx), mixed.WithUserHandler(func(request *statute.ProxyRequest) error { return vt.generalHandler(request) @@ -50,7 +49,7 @@ func (vt *VirtualTun) StartProxy(bindAddress netip.AddrPort) { } func (vt *VirtualTun) generalHandler(req *statute.ProxyRequest) error { - vt.Logger.Debug("handling request", "protocol", req.Network, "destination", req.Destination) + vt.Logger.Info("handling connection", "protocol", req.Network, "destination", req.Destination) conn, err := vt.Tnet.Dial(req.Network, req.Destination) if err != nil { return err