Skip to content

Commit

Permalink
实现了模拟管道用于未来的单元测试
Browse files Browse the repository at this point in the history
优化了SNI扩展服务端ACK,以及服务端对域名证书的验证
  • Loading branch information
Trisia committed Jan 14, 2025
1 parent 253632f commit 47f36e5
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 2 deletions.
15 changes: 13 additions & 2 deletions tlcp/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -562,8 +562,7 @@ var errNoCertificates = errors.New("tlcp: no certificates configured")
// getCertificate 根据 客户端Hello消息中的信息选择最佳的数字证书
// 默认返还 Config.Certificates[0] 的数字证书
func (c *Config) getCertificate(clientHello *ClientHelloInfo) (*Certificate, error) {
if c.GetCertificate != nil &&
(len(c.Certificates) == 0 || len(clientHello.ServerName) > 0) {
if c.GetCertificate != nil && len(c.Certificates) == 0 {
cert, err := c.GetCertificate(clientHello)
if cert != nil || err != nil {
return cert, err
Expand All @@ -573,6 +572,18 @@ func (c *Config) getCertificate(clientHello *ClientHelloInfo) (*Certificate, err
if len(c.Certificates) == 0 {
return nil, errNoCertificates
}

//
// 域名证书主机名验证交由证书验证阶段完成
//
//// 如果服务端名称不为空,那么验证证书是否匹配
//if clientHello.ServerName != "" {
// err := c.Certificates[0].Leaf.VerifyHostname(clientHello.ServerName)
// if err != nil {
// return nil, fmt.Errorf("tlcp: certificate does not match requested host name: %v", err)
// }
//}

return &c.Certificates[0], nil
}

Expand Down
5 changes: 5 additions & 0 deletions tlcp/handshake_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,11 @@ func (hs *serverHandshakeState) processClientHello() error {
}
return err
}
if hs.clientHello.serverName != "" && hs.sigCert != nil {
// 服务端证书中的主机名与客户端提供的主机名匹配
// 设置主机名ACK标志,发送ACK
hs.hello.serverNameAck = true
}

// 选择服务端加密证书
hs.encCert, err = c.config.getEKCertificate(helloInfo)
Expand Down
31 changes: 31 additions & 0 deletions tlcp/handshake_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,3 +310,34 @@ func Test_ResumedSession(t *testing.T) {
_ = conn.Close()
}
}

func Test_processClientHello(t *testing.T) {
//c, s := mockPipe()
//cli := Client(c, &Config{
// InsecureSkipVerify: true,
// Time: runtimeTime,
//})
//svr := Server(s, &Config{
// Certificates: []Certificate{sigCert, encCert},
// Time: runtimeTime,
//})
//
//done := make(chan bool)
//
//go func() {
// defer close(done)
//
// if err := svr.Handshake(); err != nil {
// t.Errorf("server: %s", err)
// return
// }
// s.Close()
//}()
//if err := cli.Handshake(); err != nil {
// t.Fatalf("client: %s", err)
//}
//
//c.Close()
//<-done

}
134 changes: 134 additions & 0 deletions tlcp/handshake_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package tlcp

import (
"bytes"
"context"
"io"
"net"
"sync/atomic"
"testing"
"time"
)

// pipeConn 实现了 net.Conn 接口
type pipeConn struct {
readCh chan []byte
writeCh chan []byte
closed <-chan struct{}
buffer []byte
closedInt int32
cancel context.CancelFunc
}

// NewMockConn 创建两个 pipeConn 对象,一个用于客户端,一个用于服务器端
func mockPipe() (cli net.Conn, svr net.Conn) {
readCh := make(chan []byte, 1)
writeCh := make(chan []byte, 1)
ctx, cancel := context.WithCancel(context.Background())
return &pipeConn{
readCh: readCh,
writeCh: writeCh,
closed: ctx.Done(),
cancel: cancel,
}, &pipeConn{
readCh: writeCh,
writeCh: readCh,
closed: ctx.Done(),
cancel: cancel,
}
}

// Read 实现了 net.Conn 接口的 Read 方法
func (c *pipeConn) Read(b []byte) (n int, err error) {
if atomic.LoadInt32(&c.closedInt) == 1 {
err = io.EOF
return
}
if len(c.buffer) > 0 {
n = copy(b, c.buffer)
c.buffer = c.buffer[n:]
return n, nil
}

select {
case data := <-c.readCh:
n = copy(b, data)
if n < len(data) {
c.buffer = data[n:]
}
case <-c.closed:
err = io.EOF
}
return
}

// Write 实现了 net.Conn 接口的 Write 方法
func (c *pipeConn) Write(b []byte) (n int, err error) {
if atomic.LoadInt32(&c.closedInt) == 1 {
err = io.EOF
return
}
select {
case c.writeCh <- b:
n = len(b)
case <-c.closed:
err = io.EOF
}
return
}

// Close 实现了 net.Conn 接口的 Close 方法
func (c *pipeConn) Close() error {
if atomic.CompareAndSwapInt32(&c.closedInt, 0, 1) {
c.cancel()
}
return nil
}

// LocalAddr 实现了 net.Conn 接口的 LocalAddr 方法
func (c *pipeConn) LocalAddr() net.Addr {
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 8080}
}

// RemoteAddr 实现了 net.Conn 接口的 RemoteAddr 方法
func (c *pipeConn) RemoteAddr() net.Addr {
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 8081}
}

// SetDeadline 实现了 net.Conn 接口的 SetDeadline 方法
func (c *pipeConn) SetDeadline(t time.Time) error {
return nil // 这里没有实现超时逻辑
}

// SetReadDeadline 实现了 net.Conn 接口的 SetReadDeadline 方法
func (c *pipeConn) SetReadDeadline(t time.Time) error {
return nil // 这里没有实现超时逻辑
}

// SetWriteDeadline 实现了 net.Conn 接口的 SetWriteDeadline 方法
func (c *pipeConn) SetWriteDeadline(t time.Time) error {
return nil // 这里没有实现超时逻辑
}

func Test_pipeConn(t *testing.T) {
cli, svr := mockPipe()
defer cli.Close()
defer svr.Close()

data := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
go func() {
time.Sleep(100 * time.Millisecond)
_, _ = svr.Write(data)
}()
buf := make([]byte, 10)
n, err := cli.Read(buf)
if err != nil {
t.Fatal(err)
}
if n != 10 {
t.Fatalf("should be read 10 bytes,but not %d", n)
}
if !bytes.Equal(buf[:n], data) {
t.Fatalf("result not match expect, %02X", buf)
}
}

0 comments on commit 47f36e5

Please sign in to comment.