Skip to content

Commit

Permalink
proper error handling in smux
Browse files Browse the repository at this point in the history
  • Loading branch information
xtaci committed May 10, 2019
1 parent 401b0da commit 3752dae
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 54 deletions.
87 changes: 56 additions & 31 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ const (
defaultAcceptBacklog = 1024
)

const (
errBrokenPipe = "broken pipe"
errInvalidProtocol = "invalid protocol version"
errGoAway = "stream id overflows, should start a new connection"
var (
errInvalidProtocol = errors.New("invalid protocol")
errGoAway = errors.New("stream id overflows, should start a new connection")
errTimeout = errors.New("timeout")
)

type writeRequest struct {
Expand Down Expand Up @@ -48,8 +48,10 @@ type Session struct {
streams map[uint32]*Stream // all streams in this session
streamLock sync.Mutex // locks streams

die chan struct{} // flag session has died
dieLock sync.Mutex
die chan struct{} // flag session has died
dieOnce sync.Once
socketError atomic.Value // errors from underlying conn

chAccepts chan *Stream

dataReady int32 // flag data has arrived
Expand Down Expand Up @@ -86,36 +88,40 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
// OpenStream is used to create a new stream
func (s *Session) OpenStream() (*Stream, error) {
if s.IsClosed() {
return nil, errors.New(errBrokenPipe)
return nil, errors.WithStack(io.ErrClosedPipe)
}

// generate stream id
s.nextStreamIDLock.Lock()
if s.goAway > 0 {
s.nextStreamIDLock.Unlock()
return nil, errors.New(errGoAway)
return nil, errors.WithStack(errGoAway)
}

s.nextStreamID += 2
sid := s.nextStreamID
if sid == sid%2 { // stream-id overflows
s.goAway = 1
s.nextStreamIDLock.Unlock()
return nil, errors.New(errGoAway)
return nil, errors.WithStack(errGoAway)
}
s.nextStreamIDLock.Unlock()

stream := newStream(sid, s.config.MaxFrameSize, s)

if _, err := s.writeFrame(newFrame(cmdSYN, sid)); err != nil {
return nil, errors.Wrap(err, "writeFrame")
return nil, errors.WithStack(err)
}

s.streamLock.Lock()
defer s.streamLock.Unlock()
select {
case <-s.die:
return nil, errors.New(errBrokenPipe)
if err := s.socketError.Load(); err != nil {
return nil, errors.WithStack(err.(error))
} else {
return nil, errors.WithStack(io.ErrClosedPipe)
}
default:
s.streams[sid] = stream
return stream, nil
Expand All @@ -135,31 +141,40 @@ func (s *Session) AcceptStream() (*Stream, error) {
case stream := <-s.chAccepts:
return stream, nil
case <-deadline:
return nil, errTimeout
return nil, errors.WithStack(errTimeout)
case <-s.die:
return nil, errors.New(errBrokenPipe)
if err := s.socketError.Load(); err != nil {
return nil, errors.WithStack(err.(error))
} else {
return nil, errors.WithStack(io.ErrClosedPipe)
}
}
}

// Close is used to close the session and all streams.
func (s *Session) Close() (err error) {
s.dieLock.Lock()

select {
case <-s.die:
s.dieLock.Unlock()
return errors.New(errBrokenPipe)
default:
close(s.die)
s.dieLock.Unlock()
err = s.conn.Close()
func (s *Session) Close() error {
var once bool
s.dieOnce.Do(func() {
if err := s.conn.Close(); err != nil {
s.socketError.Store(errors.WithStack(err))
}
s.streamLock.Lock()
for k := range s.streams {
s.streams[k].sessionClose()
}
s.streamLock.Unlock()
return
close(s.die)
})

if err := s.socketError.Load(); err != nil {
return errors.WithStack(err.(error))
}

if !once {
return errors.WithStack(io.ErrClosedPipe)
}

return nil
}

// notifyBucket notifies recvLoop that bucket is available
Expand Down Expand Up @@ -269,15 +284,18 @@ func (s *Session) recvLoop() {
}
s.streamLock.Unlock()
} else {
s.socketError.Store(errors.WithStack(err))
s.Close()
return
}
}
default:
s.socketError.Store(errors.WithStack(errInvalidProtocol))
s.Close()
return
}
} else {
s.socketError.Store(errors.WithStack(err))
s.Close()
return
}
Expand Down Expand Up @@ -345,11 +363,18 @@ func (s *Session) sendLoop() {

result := writeResult{
n: n,
err: err,
err: errors.WithStack(err),
}

request.result <- result
close(request.result)

// store conn error
if err != nil {
s.socketError.Store(errors.WithStack(err))
s.Close()
return
}
}
}
}
Expand All @@ -368,18 +393,18 @@ func (s *Session) writeFrameInternal(f Frame, deadline <-chan time.Time) (int, e
}
select {
case <-s.die:
return 0, errors.New(errBrokenPipe)
return 0, errors.WithStack(io.ErrClosedPipe)
case s.writes <- req:
case <-deadline:
return 0, errTimeout
return 0, errors.WithStack(errTimeout)
}

select {
case result := <-req.result:
return result.n, result.err
return result.n, errors.WithStack(result.err)
case <-deadline:
return 0, errTimeout
return 0, errors.WithStack(errTimeout)
case <-s.die:
return 0, errors.New(errBrokenPipe)
return 0, errors.WithStack(io.ErrClosedPipe)
}
}
10 changes: 5 additions & 5 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,7 @@ func TestWriteFrameInternal(t *testing.T) {
close(c)
f := newFrame(allcmds[rand.Int()%len(allcmds)], rand.Uint32())
_, err := session.writeFrameInternal(f, c)
if err != errTimeout {
if !strings.Contains(err.Error(), "timeout") {
t.Fatal("write frame with deadline failed", err)
}
}
Expand All @@ -664,8 +664,8 @@ func TestWriteFrameInternal(t *testing.T) {
close(c)
}()
_, err = session.writeFrameInternal(f, c)
if err.Error() != errBrokenPipe {
t.Fatal("write frame with deadline failed", err)
if !strings.Contains(err.Error(), "closed pipe") {
t.Fatal("write frame with to closed conn failed", err)
}
}
}
Expand All @@ -688,7 +688,7 @@ func TestReadDeadline(t *testing.T) {
}
}
if readErr != nil {
if !strings.Contains(readErr.Error(), "i/o timeout") {
if !strings.Contains(readErr.Error(), "timeout") {
t.Fatalf("Wrong error: %v", readErr)
}
} else {
Expand All @@ -710,7 +710,7 @@ func TestWriteDeadline(t *testing.T) {
for {
stream.SetWriteDeadline(time.Now().Add(-1 * time.Minute))
if _, writeErr = stream.Write(buf); writeErr != nil {
if !strings.Contains(writeErr.Error(), "i/o timeout") {
if !strings.Contains(writeErr.Error(), "timeout") {
t.Fatalf("Wrong error: %v", writeErr)
}
break
Expand Down
28 changes: 10 additions & 18 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func (s *Stream) Read(b []byte) (n int, err error) {
if len(b) == 0 {
select {
case <-s.die:
return 0, errors.New(errBrokenPipe)
return 0, errors.WithStack(io.ErrClosedPipe)
default:
return 0, nil
}
Expand Down Expand Up @@ -76,16 +76,16 @@ READ:
return n, nil
} else if atomic.LoadInt32(&s.rstflag) == 1 {
_ = s.Close()
return 0, io.EOF
return 0, errors.WithStack(io.EOF)
}

select {
case <-s.chReadEvent:
goto READ
case <-deadline:
return n, errTimeout
return n, errors.WithStack(errTimeout)
case <-s.die:
return 0, errors.New(errBrokenPipe)
return 0, errors.WithStack(io.ErrClosedPipe)
}
}

Expand All @@ -100,7 +100,7 @@ func (s *Stream) Write(b []byte) (n int, err error) {

select {
case <-s.die:
return 0, errors.New(errBrokenPipe)
return 0, errors.WithStack(io.ErrClosedPipe)
default:
}

Expand All @@ -118,7 +118,7 @@ func (s *Stream) Write(b []byte) (n int, err error) {
n, err := s.sess.writeFrameInternal(frame, deadline)
sent += n
if err != nil {
return sent, err
return sent, errors.WithStack(err)
}
}

Expand All @@ -132,13 +132,13 @@ func (s *Stream) Close() error {
select {
case <-s.die:
s.dieLock.Unlock()
return errors.New(errBrokenPipe)
return errors.WithStack(io.ErrClosedPipe)
default:
close(s.die)
s.dieLock.Unlock()
s.sess.streamClosed(s.id)
_, err := s.sess.writeFrame(newFrame(cmdFIN, s.id))
return err
return errors.WithStack(err)
}
}

Expand Down Expand Up @@ -169,10 +169,10 @@ func (s *Stream) SetWriteDeadline(t time.Time) error {
// A zero time value disables the deadlines.
func (s *Stream) SetDeadline(t time.Time) error {
if err := s.SetReadDeadline(t); err != nil {
return err
return errors.WithStack(err)
}
if err := s.SetWriteDeadline(t); err != nil {
return err
return errors.WithStack(err)
}
return nil
}
Expand Down Expand Up @@ -240,11 +240,3 @@ func (s *Stream) notifyReadEvent() {
func (s *Stream) markRST() {
atomic.StoreInt32(&s.rstflag, 1)
}

var errTimeout error = &timeoutError{}

type timeoutError struct{}

func (e *timeoutError) Error() string { return "i/o timeout" }
func (e *timeoutError) Timeout() bool { return true }
func (e *timeoutError) Temporary() bool { return true }

0 comments on commit 3752dae

Please sign in to comment.