From 0e6749a50f6a1e419d105cd9eb7b77c21a0655c0 Mon Sep 17 00:00:00 2001 From: Nikolai Kabanenkov <58106770+nikolaikabanenkov@users.noreply.github.com> Date: Thu, 23 Jan 2025 22:53:55 +0500 Subject: [PATCH] feat(x): add https and http/2 support to the CONNECT client --- x/httpconnect/connect_client.go | 116 ++++++++++++++++++--------- x/httpconnect/connect_client_test.go | 89 ++++++++++++++++++-- x/httpconnect/pipe_conn.go | 3 +- 3 files changed, 162 insertions(+), 46 deletions(-) diff --git a/x/httpconnect/connect_client.go b/x/httpconnect/connect_client.go index 03990bc2..37e5aadc 100644 --- a/x/httpconnect/connect_client.go +++ b/x/httpconnect/connect_client.go @@ -16,28 +16,31 @@ package httpconnect import ( "context" + "crypto/tls" "errors" "fmt" "github.com/Jigsaw-Code/outline-sdk/transport" + "golang.org/x/net/http2" "io" "net" "net/http" ) -// connectClient is a [transport.StreamDialer] implementation that dials [proxyAddr] with the given [dialer] -// and sends a CONNECT request to the dialed proxy. -type connectClient struct { +// ConnectClient is a [transport.StreamDialer] implementation that dials proxyAddr with the given dialer and sends a CONNECT request to the dialed proxy. +// By default, the client uses "http", but it can be changed to "https" with the [WithHTTPS] option. +type ConnectClient struct { dialer transport.StreamDialer proxyAddr string - - headers http.Header + scheme string + tlsConfig *tls.Config + headers http.Header } -var _ transport.StreamDialer = (*connectClient)(nil) +var _ transport.StreamDialer = (*ConnectClient)(nil) -type ClientOption func(c *connectClient) +type ClientOption func(c *ConnectClient) -func NewConnectClient(dialer transport.StreamDialer, proxyAddr string, opts ...ClientOption) (transport.StreamDialer, error) { +func NewConnectClient(dialer transport.StreamDialer, proxyAddr string, opts ...ClientOption) (*ConnectClient, error) { if dialer == nil { return nil, errors.New("dialer must not be nil") } @@ -46,10 +49,10 @@ func NewConnectClient(dialer transport.StreamDialer, proxyAddr string, opts ...C return nil, fmt.Errorf("failed to parse proxy address %s: %w", proxyAddr, err) } - cc := &connectClient{ + cc := &ConnectClient{ dialer: dialer, proxyAddr: proxyAddr, - headers: make(http.Header), + scheme: "http", } for _, opt := range opts { @@ -59,69 +62,102 @@ func NewConnectClient(dialer transport.StreamDialer, proxyAddr string, opts ...C return cc, nil } -// WithHeaders appends the given [headers] to the CONNECT request +// WithHTTPS sets the scheme to "https" and the given tlsConfig to the transport +func WithHTTPS(tlsConfig *tls.Config) ClientOption { + return func(c *ConnectClient) { + c.scheme = "https" + c.tlsConfig = tlsConfig.Clone() + } +} + +// WithHeaders appends the given headers to the CONNECT request func WithHeaders(headers http.Header) ClientOption { - return func(c *connectClient) { + return func(c *ConnectClient) { c.headers = headers.Clone() } } // DialStream - connects to the proxy and sends a CONNECT request to it, closes the connection if the request fails -func (cc *connectClient) DialStream(ctx context.Context, remoteAddr string) (transport.StreamConn, error) { +func (cc *ConnectClient) DialStream(ctx context.Context, remoteAddr string) (streamConn transport.StreamConn, err error) { + _, _, err = net.SplitHostPort(remoteAddr) + if err != nil { + return nil, fmt.Errorf("failed to parse remote address %s: %w", remoteAddr, err) + } + innerConn, err := cc.dialer.DialStream(ctx, cc.proxyAddr) if err != nil { return nil, fmt.Errorf("failed to dial proxy %s: %w", cc.proxyAddr, err) } + defer func() { + if err != nil { + _ = innerConn.Close() + } + }() - conn, err := cc.doConnect(ctx, remoteAddr, innerConn) + roundTripper, err := cc.buildTransport(innerConn) + if err != nil { + return nil, fmt.Errorf("failed to build roundTripper: %w", err) + } + + reader, writer, err := doConnect(ctx, roundTripper, cc.scheme, remoteAddr, cc.headers) if err != nil { - _ = innerConn.Close() return nil, fmt.Errorf("doConnect %s: %w", remoteAddr, err) } - return conn, nil + return &pipeConn{ + reader: reader, + writer: writer, + StreamConn: innerConn, + }, nil } -func (cc *connectClient) doConnect(ctx context.Context, remoteAddr string, conn transport.StreamConn) (transport.StreamConn, error) { - _, _, err := net.SplitHostPort(remoteAddr) +func (cc *ConnectClient) buildTransport(conn transport.StreamConn) (http.RoundTripper, error) { + tr := &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return conn, nil + }, + TLSClientConfig: cc.tlsConfig, + } + + err := http2.ConfigureTransport(tr) if err != nil { - return nil, fmt.Errorf("failed to parse remote address %s: %w", remoteAddr, err) + return nil, fmt.Errorf("failed to configure transport for HTTP/2: %w", err) } - pr, pw := io.Pipe() + return tr, nil +} - req, err := http.NewRequestWithContext(ctx, http.MethodConnect, "http://"+remoteAddr, pr) // TODO: HTTPS support - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) +func doConnect( + ctx context.Context, + roundTripper http.RoundTripper, + scheme, remoteAddr string, + headers http.Header, +) (io.ReadCloser, io.WriteCloser, error) { + if scheme != "http" && scheme != "https" { + return nil, nil, fmt.Errorf("unsupported scheme: %s", scheme) } - req.ContentLength = -1 // -1 means length unknown - mergeHeaders(req.Header, cc.headers) - tr := &http.Transport{ - // TODO: HTTP/2 support with [http2.ConfigureTransport] - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return conn, nil - }, + pr, pw := io.Pipe() + remoteURL := fmt.Sprintf("%s://%s", scheme, remoteAddr) + req, err := http.NewRequestWithContext(ctx, http.MethodConnect, remoteURL, pr) + if err != nil { + return nil, nil, fmt.Errorf("failed to create request: %w", err) } + req.ContentLength = -1 // -1 means unknown length + mergeHeaders(req.Header, headers) hc := http.Client{ - Transport: tr, + Transport: roundTripper, } - resp, err := hc.Do(req) if err != nil { - return nil, fmt.Errorf("do: %w", err) + return nil, nil, fmt.Errorf("do: %w", err) } if resp.StatusCode != http.StatusOK { - _ = resp.Body.Close() - return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + return nil, nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) } - return &pipeConn{ - reader: resp.Body, - writer: pw, - StreamConn: conn, - }, nil + return resp.Body, pw, nil } func mergeHeaders(dst http.Header, src http.Header) { diff --git a/x/httpconnect/connect_client_test.go b/x/httpconnect/connect_client_test.go index 0c28e03b..dcff8e28 100644 --- a/x/httpconnect/connect_client_test.go +++ b/x/httpconnect/connect_client_test.go @@ -17,10 +17,13 @@ package httpconnect import ( "bufio" "context" + "crypto/tls" + "crypto/x509" "encoding/base64" "github.com/Jigsaw-Code/outline-sdk/transport" "github.com/Jigsaw-Code/outline-sdk/x/httpproxy" "github.com/stretchr/testify/require" + "io" "net" "net/http" "net/http/httptest" @@ -28,16 +31,13 @@ import ( "testing" ) -func TestConnectClientOk(t *testing.T) { +func Test_ConnectClient_HTTP_Ok(t *testing.T) { t.Parallel() creds := base64.StdEncoding.EncodeToString([]byte("username:password")) targetSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { require.Equal(t, http.MethodGet, r.Method, "Method") - w.WriteHeader(http.StatusOK) - _, err := w.Write([]byte("HTTP/1.1 200 OK\r\n")) - require.NoError(t, err) })) defer targetSrv.Close() @@ -78,7 +78,7 @@ func TestConnectClientOk(t *testing.T) { require.Equal(t, http.StatusOK, resp.StatusCode) } -func TestConnectClientFail(t *testing.T) { +func Test_ConnectClient_HTTP_Fail(t *testing.T) { t.Parallel() targetURL := "somehost:1234" @@ -107,3 +107,82 @@ func TestConnectClientFail(t *testing.T) { _, err = connClient.DialStream(context.Background(), targetURL) require.Error(t, err, "unexpected status code: 400") } + +func Test_ConnectClient_HTTP2_Ok(t *testing.T) { + t.Parallel() + + targetSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method, "Method") + w.Header().Set("Content-Type", "text/plain") + _, err := w.Write([]byte("Hello, world!")) + require.NoError(t, err) + })) + defer targetSrv.Close() + + targetURL, err := url.Parse(targetSrv.URL) + require.NoError(t, err) + + tcpDialer := &transport.TCPDialer{Dialer: net.Dialer{}} + proxySrv := httptest.NewUnstartedServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + require.Equal(t, "HTTP/2.0", request.Proto, "Proto") + require.Equal(t, http.MethodConnect, request.Method, "Method") + require.Equal(t, targetURL.Host, request.URL.Host, "Host") + + conn, err := tcpDialer.DialStream(request.Context(), request.URL.Host) + require.NoError(t, err, "DialStream") + + writer.WriteHeader(http.StatusOK) + writer.(http.Flusher).Flush() + + go func() { + _, _ = io.Copy(conn, request.Body) + require.NoError(t, err, "io.Copy") + }() + + _, _ = io.Copy(writer, conn) + require.NoError(t, err, "io.Copy") + })) + proxySrv.EnableHTTP2 = true + proxySrv.StartTLS() + defer proxySrv.Close() + + proxyURL, err := url.Parse(proxySrv.URL) + require.NoError(t, err, "Parse") + + certs := x509.NewCertPool() + for _, c := range proxySrv.TLS.Certificates { + roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1]) + require.NoError(t, err, "x509.ParseCertificates") + for _, root := range roots { + certs.AddCert(root) + } + } + + connClient, err := NewConnectClient( + tcpDialer, + proxyURL.Host, + WithHTTPS(&tls.Config{RootCAs: certs}), + ) + require.NoError(t, err, "NewConnectClient") + + streamConn, err := connClient.DialStream(context.Background(), targetURL.Host) + require.NoError(t, err, "DialStream") + require.NotNil(t, streamConn, "StreamConn") + + req, err := http.NewRequest(http.MethodGet, targetSrv.URL, nil) + require.NoError(t, err, "NewRequest") + req.Header.Add("Connection", "close") + + err = req.Write(streamConn) + require.NoError(t, err, "Write") + + rd := bufio.NewReader(streamConn) + resp, err := http.ReadResponse(rd, req) + require.NoError(t, err, "ReadResponse") + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err, "ReadAll") + require.Equal(t, "Hello, world!", string(body)) + + require.Equal(t, http.StatusOK, resp.StatusCode) +} diff --git a/x/httpconnect/pipe_conn.go b/x/httpconnect/pipe_conn.go index 50f174a3..fc435781 100644 --- a/x/httpconnect/pipe_conn.go +++ b/x/httpconnect/pipe_conn.go @@ -22,7 +22,8 @@ import ( var _ transport.StreamConn = (*pipeConn)(nil) -// pipeConn is a [transport.StreamConn] that overrides [Read], [Write] (and corresponding [Close]) functions with the given [reader] and [writer] +// pipeConn is a [transport.StreamConn] that overrides the Read and Write functions with the provided [io.ReadCloser] and [io.WriteCloser], respectively. +// The CloseRead, CloseWrite, and Close functions first close the [io.ReadCloser] and [io.WriteCloser], and then call the corresponding functions on the connection. type pipeConn struct { reader io.ReadCloser writer io.WriteCloser