diff --git a/close.go b/close.go index ff2e878a..820354e4 100644 --- a/close.go +++ b/close.go @@ -232,12 +232,6 @@ func (c *Conn) waitGoroutines() error { t := time.NewTimer(time.Second * 15) defer t.Stop() - select { - case <-c.timeoutLoopDone: - case <-t.C: - return errors.New("failed to wait for timeoutLoop goroutine to exit") - } - c.closeReadMu.Lock() closeRead := c.closeReadCtx != nil c.closeReadMu.Unlock() diff --git a/conn.go b/conn.go index d7434a9d..052b1e32 100644 --- a/conn.go +++ b/conn.go @@ -52,9 +52,8 @@ type Conn struct { br *bufio.Reader bw *bufio.Writer - readTimeout chan context.Context - writeTimeout chan context.Context - timeoutLoopDone chan struct{} + readTimeoutCloser atomic.Value + writeTimeoutCloser atomic.Value // Read state. readMu *mu @@ -104,10 +103,6 @@ func newConn(cfg connConfig) *Conn { br: cfg.br, bw: cfg.bw, - readTimeout: make(chan context.Context), - writeTimeout: make(chan context.Context), - timeoutLoopDone: make(chan struct{}), - closed: make(chan struct{}), activePings: make(map[string]chan<- struct{}), } @@ -133,8 +128,6 @@ func newConn(cfg connConfig) *Conn { c.close() }) - go c.timeoutLoop() - return c } @@ -164,26 +157,42 @@ func (c *Conn) close() error { return err } -func (c *Conn) timeoutLoop() { - defer close(c.timeoutLoopDone) +func (c *Conn) setupWriteTimeout(ctx context.Context) { + hammerTime := context.AfterFunc(ctx, func() { + c.close() + }) - readCtx := context.Background() - writeCtx := context.Background() + if closer := c.writeTimeoutCloser.Swap(hammerTime); closer != nil { + if fn, ok := closer.(func() bool); ok { + fn() + } + } +} - for { - select { - case <-c.closed: - return - - case writeCtx = <-c.writeTimeout: - case readCtx = <-c.readTimeout: - - case <-readCtx.Done(): - c.close() - return - case <-writeCtx.Done(): - c.close() - return +func (c *Conn) clearWriteTimeout() { + if closer := c.writeTimeoutCloser.Load(); closer != nil { + if fn, ok := closer.(func() bool); ok { + fn() + } + } +} + +func (c *Conn) setupReadTimeout(ctx context.Context) { + hammerTime := context.AfterFunc(ctx, func() { + c.close() + }) + + if closer := c.readTimeoutCloser.Swap(hammerTime); closer != nil { + if fn, ok := closer.(func() bool); ok { + fn() + } + } +} + +func (c *Conn) clearReadTimeout() { + if closer := c.readTimeoutCloser.Load(); closer != nil { + if fn, ok := closer.(func() bool); ok { + fn() } } } diff --git a/go.mod b/go.mod index 336411a5..ddd420d8 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/coder/websocket -go 1.19 +go 1.21 diff --git a/internal/examples/go.mod b/internal/examples/go.mod index 2aa1ee02..d559967b 100644 --- a/internal/examples/go.mod +++ b/internal/examples/go.mod @@ -1,6 +1,6 @@ module github.com/coder/websocket/examples -go 1.19 +go 1.21 replace github.com/coder/websocket => ../.. diff --git a/internal/thirdparty/go.mod b/internal/thirdparty/go.mod index e060ce67..be128ef4 100644 --- a/internal/thirdparty/go.mod +++ b/internal/thirdparty/go.mod @@ -1,6 +1,6 @@ module github.com/coder/websocket/internal/thirdparty -go 1.19 +go 1.21 replace github.com/coder/websocket => ../.. diff --git a/internal/thirdparty/go.sum b/internal/thirdparty/go.sum index 2352ac75..a7be7082 100644 --- a/internal/thirdparty/go.sum +++ b/internal/thirdparty/go.sum @@ -16,6 +16,7 @@ github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= @@ -31,6 +32,7 @@ github.com/gobwas/ws v1.4.0/go.mod h1:G3gNqMNtPppf5XUz7O4shetPpcZ1VJ7zt18dlUeakr github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= @@ -96,6 +98,7 @@ golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/read.go b/read.go index e2699da5..89292356 100644 --- a/read.go +++ b/read.go @@ -221,7 +221,8 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { select { case <-c.closed: return header{}, net.ErrClosed - case c.readTimeout <- ctx: + default: + c.setupReadTimeout(ctx) } h, err := readFrameHeader(c.br, c.readHeaderBuf[:]) @@ -239,7 +240,8 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { select { case <-c.closed: return header{}, net.ErrClosed - case c.readTimeout <- context.Background(): + default: + c.clearReadTimeout() } return h, nil @@ -249,7 +251,8 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { select { case <-c.closed: return 0, net.ErrClosed - case c.readTimeout <- ctx: + default: + c.setupReadTimeout(ctx) } n, err := io.ReadFull(c.br, p) @@ -267,7 +270,8 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { select { case <-c.closed: return n, net.ErrClosed - case c.readTimeout <- context.Background(): + default: + c.clearReadTimeout() } return n, err diff --git a/write.go b/write.go index e294a680..ac0a1ac7 100644 --- a/write.go +++ b/write.go @@ -252,7 +252,8 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco select { case <-c.closed: return 0, net.ErrClosed - case c.writeTimeout <- ctx: + default: + c.setupWriteTimeout(ctx) } defer func() { @@ -309,7 +310,8 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco return n, nil } return n, net.ErrClosed - case c.writeTimeout <- context.Background(): + default: + c.clearWriteTimeout() } return n, nil