Skip to content

Commit

Permalink
proxy: Allow for providing listener
Browse files Browse the repository at this point in the history
This also enables us to delete the port search code.

Signed-off-by: Mark Pashmfouroush <[email protected]>
  • Loading branch information
markpash committed Mar 15, 2024
1 parent 2d777a8 commit db5b413
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 87 deletions.
55 changes: 11 additions & 44 deletions app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"log/slog"
"net"
"net/netip"
"os"
"path/filepath"
Expand Down Expand Up @@ -124,12 +123,6 @@ func runWarp(ctx context.Context, l *slog.Logger, bind netip.AddrPort, endpoint
}

func runWarpWithPsiphon(ctx context.Context, l *slog.Logger, bind netip.AddrPort, endpoint string, country string) error {
// make a random bind address for warp
warpBindAddress, err := findFreePort("tcp")
if err != nil {
return err
}

conf, err := wiresocks.ParseConfig("./primary/wgcf-profile.ini", endpoint)
if err != nil {
return err
Expand All @@ -147,10 +140,13 @@ func runWarpWithPsiphon(ctx context.Context, l *slog.Logger, bind netip.AddrPort
return err
}

tnet.StartProxy(warpBindAddress)
warpBind, err := tnet.StartProxy(netip.MustParseAddrPort("127.0.0.1:0"))
if err != nil {
return err
}

// run psiphon
err = psiphon.RunPsiphon(ctx, l.With("subsystem", "psiphon"), warpBindAddress.String(), bind.String(), country)
err = psiphon.RunPsiphon(ctx, l.With("subsystem", "psiphon"), warpBind.String(), bind.String(), country)
if err != nil {
return fmt.Errorf("unable to run psiphon %w", err)
}
Expand Down Expand Up @@ -179,20 +175,14 @@ func runWarpInWarp(ctx context.Context, l *slog.Logger, bind netip.AddrPort, end
return err
}

// Run virtual endpoint
virtualEndpointBindAddress, err := findFreePort("udp")
if err != nil {
return err
}

// Create a UDP port forward between localhost and the remote endpoint
err = wiresocks.NewVtunUDPForwarder(ctx, virtualEndpointBindAddress.String(), endpoints[1], tnet, singleMTU)
addr, err := wiresocks.NewVtunUDPForwarder(ctx, netip.MustParseAddrPort("127.0.0.1:0"), endpoints[1], tnet, singleMTU)
if err != nil {
return err
}

// Run inner warp
conf, err = wiresocks.ParseConfig("./secondary/wgcf-profile.ini", virtualEndpointBindAddress.String())
conf, err = wiresocks.ParseConfig("./secondary/wgcf-profile.ini", addr.String())
if err != nil {
return err
}
Expand All @@ -208,36 +198,13 @@ func runWarpInWarp(ctx context.Context, l *slog.Logger, bind netip.AddrPort, end
return err
}

tnet.StartProxy(bind)

l.Info("serving proxy", "address", bind)
return nil
}

func findFreePort(network string) (netip.AddrPort, error) {
if network == "udp" {
addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
if err != nil {
return netip.AddrPort{}, err
}

conn, err := net.ListenUDP("udp", addr)
if err != nil {
return netip.AddrPort{}, err
}
defer conn.Close()

return netip.MustParseAddrPort(conn.LocalAddr().String()), nil
}
// Listen on TCP port 0, which tells the OS to pick a free port.
listener, err := net.Listen(network, "127.0.0.1:0")
_, err = tnet.StartProxy(bind)
if err != nil {
return netip.AddrPort{}, err // Return error if unable to listen on a port
return err
}
defer listener.Close() // Ensure the listener is closed when the function returns

// Get the port from the listener's address
return netip.MustParseAddrPort(listener.Addr().String()), nil
l.Info("serving proxy", "address", bind)
return nil
}

func createPrimaryAndSecondaryIdentities(l *slog.Logger, license string) error {
Expand Down
19 changes: 14 additions & 5 deletions proxy/pkg/http/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ import (
type Server struct {
// bind is the address to listen on
Bind string

Listener net.Listener

// ProxyDial specifies the optional proxyDial function for
// establishing the transport connection.
ProxyDial statute.ProxyDialFunc
Expand Down Expand Up @@ -47,14 +50,20 @@ type ServerOption func(*Server)

func (s *Server) ListenAndServe() error {
// Create a new listener
ln, err := net.Listen("tcp", s.Bind)
if err != nil {
return err // Return error if binding was unsuccessful
if s.Listener == nil {
ln, err := net.Listen("tcp", s.Bind)
if err != nil {
return err // Return error if binding was unsuccessful
}
s.Listener = ln
}

s.Bind = s.Listener.Addr().(*net.TCPAddr).String()
s.Logger.Debug("started proxy", "address", s.Bind)

// ensure listener will be closed
defer func() {
_ = ln.Close()
_ = s.Listener.Close()
}()

// Create a cancelable context based on s.Context
Expand All @@ -67,7 +76,7 @@ func (s *Server) ListenAndServe() error {
case <-ctx.Done():
return ctx.Err()
default:
conn, err := ln.Accept()
conn, err := s.Listener.Accept()
if err != nil {
s.Logger.Error(err.Error())
continue
Expand Down
10 changes: 10 additions & 0 deletions proxy/pkg/mixed/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package mixed
import (
"context"
"log/slog"
"net"

"github.com/bepass-org/warp-plus/proxy/pkg/statute"
)
Expand All @@ -16,6 +17,15 @@ func WithBindAddress(binAddress string) Option {
}
}

func WithListener(ln net.Listener) Option {
return func(p *Proxy) {
p.listener = ln
p.socks5Proxy.Listener = ln
p.socks4Proxy.Listener = ln
p.httpProxy.Listener = ln
}
}

func WithLogger(logger *slog.Logger) Option {
return func(p *Proxy) {
p.logger = logger
Expand Down
18 changes: 13 additions & 5 deletions proxy/pkg/mixed/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ type userHandler func(request *statute.ProxyRequest) error
type Proxy struct {
// bind is the address to listen on
bind string

listener net.Listener

// socks5Proxy is a socks5 server with tcp and udp support
socks5Proxy *socks5.Server
// socks4Proxy is a socks4 server with tcp support
Expand Down Expand Up @@ -78,15 +81,20 @@ func (c *SwitchConn) Read(p []byte) (n int, err error) {

func (p *Proxy) ListenAndServe() error {
// Create a new listener
ln, err := net.Listen("tcp", p.bind)
if err != nil {
return err // Return error if binding was unsuccessful
if p.listener == nil {
ln, err := net.Listen("tcp", p.bind)
if err != nil {
return err // Return error if binding was unsuccessful
}
p.listener = ln
}

p.bind = p.listener.Addr().(*net.TCPAddr).String()
p.logger.Debug("started proxy", "address", p.bind)

// ensure listener will be closed
defer func() {
_ = ln.Close()
_ = p.listener.Close()
}()

// Create a cancelable context based on p.Context
Expand All @@ -99,7 +107,7 @@ func (p *Proxy) ListenAndServe() error {
case <-ctx.Done():
return ctx.Err()
default:
conn, err := ln.Accept()
conn, err := p.listener.Accept()
if err != nil {
p.logger.Error(err.Error())
continue
Expand Down
18 changes: 13 additions & 5 deletions proxy/pkg/socks4/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ import (
type Server struct {
// bind is the address to listen on
Bind string

Listener net.Listener

// ProxyDial specifies the optional proxyDial function for
// establishing the transport connection.
ProxyDial statute.ProxyDialFunc
Expand Down Expand Up @@ -45,15 +48,20 @@ type ServerOption func(*Server)

func (s *Server) ListenAndServe() error {
// Create a new listener
ln, err := net.Listen("tcp", s.Bind)
if err != nil {
return err // Return error if binding was unsuccessful
if s.Listener == nil {
ln, err := net.Listen("tcp", s.Bind)
if err != nil {
return err // Return error if binding was unsuccessful
}
s.Listener = ln
}

s.Bind = s.Listener.Addr().(*net.TCPAddr).String()
s.Logger.Debug("started proxy", "address", s.Bind)

// ensure listener will be closed
defer func() {
_ = ln.Close()
_ = s.Listener.Close()
}()

// Create a cancelable context based on s.Context
Expand All @@ -66,7 +74,7 @@ func (s *Server) ListenAndServe() error {
case <-ctx.Done():
return ctx.Err()
default:
conn, err := ln.Accept()
conn, err := s.Listener.Accept()
if err != nil {
s.Logger.Error(err.Error())
continue
Expand Down
19 changes: 14 additions & 5 deletions proxy/pkg/socks5/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ import (
type Server struct {
// bind is the address to listen on
Bind string

Listener net.Listener

// ProxyDial specifies the optional proxyDial function for
// establishing the transport connection.
ProxyDial statute.ProxyDialFunc
Expand Down Expand Up @@ -56,14 +59,20 @@ type ServerOption func(*Server)

func (s *Server) ListenAndServe() error {
// Create a new listener
ln, err := net.Listen("tcp", s.Bind)
if err != nil {
return err // Return error if binding was unsuccessful
if s.Listener == nil {
ln, err := net.Listen("tcp", s.Bind)
if err != nil {
return err // Return error if binding was unsuccessful
}
s.Listener = ln
}

s.Bind = s.Listener.Addr().(*net.TCPAddr).String()
s.Logger.Debug("started proxy", "address", s.Bind)

// ensure listener will be closed
defer func() {
_ = ln.Close()
_ = s.Listener.Close()
}()

// Create a cancelable context based on s.Context
Expand All @@ -76,7 +85,7 @@ func (s *Server) ListenAndServe() error {
case <-ctx.Done():
return ctx.Err()
default:
conn, err := ln.Accept()
conn, err := s.Listener.Accept()
if err != nil {
s.Logger.Error(err.Error())
continue
Expand Down
24 changes: 12 additions & 12 deletions wiresocks/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import (
"context"
"io"
"log/slog"
"net"
"net/netip"
"time"

"github.com/bepass-org/warp-plus/proxy/pkg/mixed"
"github.com/bepass-org/warp-plus/proxy/pkg/statute"
Expand All @@ -23,9 +23,14 @@ type VirtualTun struct {
}

// StartProxy spawns a socks5 server.
func (vt *VirtualTun) StartProxy(bindAddress netip.AddrPort) {
func (vt *VirtualTun) StartProxy(bindAddress netip.AddrPort) (netip.AddrPort, error) {
ln, err := net.Listen("tcp", bindAddress.String())
if err != nil {
return netip.AddrPort{}, err // Return error if binding was unsuccessful
}

proxy := mixed.NewProxy(
mixed.WithBindAddress(bindAddress.String()),
mixed.WithListener(ln),
mixed.WithLogger(vt.Logger),
mixed.WithContext(vt.Ctx),
mixed.WithUserHandler(func(request *statute.ProxyRequest) error {
Expand All @@ -36,16 +41,11 @@ func (vt *VirtualTun) StartProxy(bindAddress netip.AddrPort) {
_ = proxy.ListenAndServe()
}()
go func() {
for {
select {
case <-vt.Ctx.Done():
vt.Stop()
return
default:
time.Sleep(500 * time.Millisecond)
}
}
<-vt.Ctx.Done()
vt.Stop()
}()

return ln.Addr().(*net.TCPAddr).AddrPort(), nil
}

func (vt *VirtualTun) generalHandler(req *statute.ProxyRequest) error {
Expand Down
19 changes: 8 additions & 11 deletions wiresocks/udpfw.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,24 @@ package wiresocks
import (
"context"
"net"
"net/netip"
"sync"
)

func NewVtunUDPForwarder(ctx context.Context, localBind, dest string, vtun *VirtualTun, mtu int) error {
localAddr, err := net.ResolveUDPAddr("udp", localBind)
if err != nil {
return err
}

func NewVtunUDPForwarder(ctx context.Context, localBind netip.AddrPort, dest string, vtun *VirtualTun, mtu int) (netip.AddrPort, error) {
destAddr, err := net.ResolveUDPAddr("udp", dest)
if err != nil {
return err
return netip.AddrPort{}, err
}

listener, err := net.ListenUDP("udp", localAddr)
listener, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(localBind))
if err != nil {
return err
return netip.AddrPort{}, err
}

rconn, err := vtun.Tnet.DialUDP(nil, destAddr)
if err != nil {
return err
return netip.AddrPort{}, err
}

var clientAddr *net.UDPAddr
Expand Down Expand Up @@ -73,5 +69,6 @@ func NewVtunUDPForwarder(ctx context.Context, localBind, dest string, vtun *Virt
_ = listener.Close()
_ = rconn.Close()
}()
return nil

return listener.LocalAddr().(*net.UDPAddr).AddrPort(), nil
}

0 comments on commit db5b413

Please sign in to comment.