generated from mccutchen/go-pkg-template
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathwebsocket.go
401 lines (363 loc) · 11 KB
/
websocket.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
// Package websocket implements a basic websocket server.
package websocket
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"unicode/utf8"
)
// Default limits applied by [New] when the corresponding [Options] field is
// zero or negative.
const (
	// DefaultMaxFrameSize is the default for [Options.MaxFrameSize]: 16 KiB.
	DefaultMaxFrameSize int = 16 * 1024

	// DefaultMaxMessageSize is the default for [Options.MaxMessageSize]: 256 KiB.
	DefaultMaxMessageSize int = 256 * 1024
)
// Mode selects whether a [Websocket] behaves as the server side or the client
// side of a connection. Among other things, it controls whether outgoing
// frames are masked.
type Mode bool

// Valid modes.
const (
	// ServerMode is the mode used by [Accept]; frames written in this mode
	// are not masked.
	ServerMode Mode = false

	// ClientMode masks each written frame with a key from NewMaskingKey.
	// It is explicitly typed so that it is a Mode, not an untyped bool.
	ClientMode Mode = true
)
// Options configures a websocket connection. The zero value is usable: unset
// size limits fall back to the package defaults and unset hooks become no-ops.
type Options struct {
	// Hooks receives lifecycle callbacks; any nil hook is replaced with a
	// no-op.
	Hooks Hooks

	// ReadTimeout and WriteTimeout, when positive, refresh the connection's
	// read/write deadline before reading each frame and before writing each
	// frame of a message. Any non-zero value requires the source passed to
	// [New] to support SetReadDeadline and SetWriteDeadline; New panics
	// otherwise.
	ReadTimeout, WriteTimeout time.Duration

	// MaxFrameSize bounds individual frames and MaxMessageSize bounds the
	// payload of a message reassembled from fragments. Values <= 0 select
	// DefaultMaxFrameSize and DefaultMaxMessageSize respectively.
	MaxFrameSize, MaxMessageSize int
}
// deadliner is satisfied by connections that support per-direction I/O
// deadlines, such as net.Conn. [New] requires it whenever a read or write
// timeout is configured.
type deadliner interface {
	SetWriteDeadline(t time.Time) error
	SetReadDeadline(t time.Time) error
}
// Websocket is a single websocket connection. Create one with [Accept] or,
// for lower-level control, [New].
type Websocket struct {
	// Connection state. closedCh is closed once the connection has been
	// closed; mode selects server or client framing behavior.
	conn     io.ReadWriteCloser
	closedCh chan struct{}
	mode     Mode

	// Observability: every hook invocation is tagged with clientKey.
	clientKey ClientKey
	hooks     Hooks

	// Limits. A non-positive timeout disables the corresponding deadline.
	readTimeout, writeTimeout    time.Duration
	maxFrameSize, maxMessageSize int
}
// Accept performs the websocket opening handshake on an incoming HTTP request,
// takes over (hijacks) the underlying TCP connection, and returns a
// server-side [Websocket] configured with opts.
//
// Hijacker support is checked before the handshake response is written. If w
// cannot be hijacked, an error is returned and the caller may still answer
// with a regular HTTP response. Handshake validation failures are also
// returned without anything being written. Once the handshake succeeds, the
// 101 response is committed. Any later error (a failed hijack) is returned,
// but the HTTP response can no longer be changed.
func Accept(w http.ResponseWriter, r *http.Request, opts Options) (*Websocket, error) {
	// Check this first: after Handshake writes 101 Switching Protocols the
	// caller can no longer report a failure over HTTP.
	hj, ok := w.(http.Hijacker)
	if !ok {
		return nil, errors.New("websocket: accept: server does not support hijacking")
	}

	clientKey, err := Handshake(w, r)
	if err != nil {
		return nil, fmt.Errorf("websocket: accept: handshake failed: %w", err)
	}

	// NOTE(review): the hijacked bufio.ReadWriter is discarded. Any bytes the
	// HTTP server buffered past the request would be lost. RFC 6455 clients
	// must wait for the handshake response before sending frames, so it is
	// normally empty.
	conn, _, err := hj.Hijack()
	if err != nil {
		return nil, fmt.Errorf("websocket: accept: hijack failed: %w", err)
	}
	return New(conn, clientKey, ServerMode, opts), nil
}
// New wraps an already-upgraded connection src in a [Websocket]. It is a
// low-level API: the caller must complete the opening handshake (see
// [Handshake]) before calling it. Prefer [Accept] in server-side HTTP
// handlers.
//
// Non-positive size limits in opts are replaced with package defaults, and
// nil hooks are replaced with no-ops. New panics if opts sets a non-zero read
// or write timeout but src does not support SetReadDeadline and
// SetWriteDeadline.
func New(src io.ReadWriteCloser, clientKey ClientKey, mode Mode, opts Options) *Websocket {
	setDefaults(&opts)

	wantsDeadlines := opts.ReadTimeout != 0 || opts.WriteTimeout != 0
	if _, canDeadline := src.(deadliner); wantsDeadlines && !canDeadline {
		panic("ReadTimeout and WriteTimeout may only be used when input source supports setting read/write deadlines")
	}

	ws := &Websocket{
		conn:           src,
		closedCh:       make(chan struct{}),
		mode:           mode,
		clientKey:      clientKey,
		hooks:          opts.Hooks,
		readTimeout:    opts.ReadTimeout,
		writeTimeout:   opts.WriteTimeout,
		maxFrameSize:   opts.MaxFrameSize,
		maxMessageSize: opts.MaxMessageSize,
	}
	return ws
}
// setDefaults fills in opts in place. A size limit that is zero or negative
// is replaced by its package default, and every nil hook becomes a no-op.
func setDefaults(opts *Options) {
	orDefault := func(value, fallback int) int {
		if value > 0 {
			return value
		}
		return fallback
	}
	opts.MaxFrameSize = orDefault(opts.MaxFrameSize, DefaultMaxFrameSize)
	opts.MaxMessageSize = orDefault(opts.MaxMessageSize, DefaultMaxMessageSize)
	setupHooks(&opts.Hooks)
}
// Handshake is a low-level helper that validates a websocket upgrade request
// and writes the 101 Switching Protocols response. After a successful return,
// only websocket frames may be written to the underlying connection.
//
// It checks, in order, that the request:
//   - has an Upgrade header equal to "websocket" (case-insensitive),
//   - requests the supported Sec-Websocket-Version, and
//   - carries a non-empty Sec-Websocket-Key.
//
// On failure, nothing is written to w and the error describes the problem.
// On success, it returns the client's key as a [ClientKey].
//
// Prefer the higher-level [Accept] API when possible.
func Handshake(w http.ResponseWriter, r *http.Request) (ClientKey, error) {
	upgrade := strings.ToLower(r.Header.Get("Upgrade"))
	version := r.Header.Get("Sec-Websocket-Version")
	key := r.Header.Get("Sec-Websocket-Key")

	switch {
	case upgrade != "websocket":
		return "", errors.New("missing required `Upgrade: websocket` header")
	case version != requiredVersion:
		return "", fmt.Errorf("only websocket version %q is supported, got %q", requiredVersion, version)
	case key == "":
		return "", errors.New("missing required `Sec-Websocket-Key` header")
	}

	h := w.Header()
	h.Set("Connection", "upgrade")
	h.Set("Upgrade", "websocket")
	h.Set("Sec-Websocket-Accept", acceptKey(key))
	w.WriteHeader(http.StatusSwitchingProtocols)
	return ClientKey(key), nil
}
// ReadMessage reads the next complete [Message] from the connection.
//
// Fragmented messages are reassembled. Control frames are handled
// transparently while waiting for data:
//   - pings are answered with a pong carrying the same payload,
//   - pongs are ignored,
//   - a close frame closes the connection and yields io.EOF.
//
// Text messages must be valid UTF-8, and a message's payload may not exceed
// the configured maximum message size.
//
// ctx is checked before each frame is read, not while a read is blocked. If
// ctx is done, the connection is closed and ctx.Err() is returned. io.EOF is
// returned if the connection has already been closed. The connection is
// closed on any other error as well.
func (ws *Websocket) ReadMessage(ctx context.Context) (*Message, error) {
	var msg *Message
	for {
		// Check closedCh before ctx. If both were ready, a single select would
		// choose at random and might call Close on an already-closed connection.
		select {
		case <-ws.closedCh:
			return nil, io.EOF
		default:
		}
		select {
		case <-ctx.Done():
			_ = ws.Close()
			return nil, ctx.Err()
		default:
		}
		ws.resetReadDeadline()

		frame, err := ReadFrame(ws.conn, ws.mode, ws.maxFrameSize)
		if err != nil {
			return nil, ws.closeOnReadError(err)
		}
		if err := validateFrame(frame); err != nil {
			return nil, ws.closeOnReadError(err)
		}
		ws.hooks.OnReadFrame(ws.clientKey, frame)

		opcode := frame.Opcode()
		switch opcode {
		case OpcodeBinary, OpcodeText:
			// A new data frame may only start a message when no fragmented
			// message is already in progress.
			if msg != nil {
				return nil, ws.closeOnReadError(ErrContinuationExpected)
			}
			// The first fragment counts toward the message limit too, in case
			// the frame limit is configured larger than the message limit.
			if len(frame.Payload) > ws.maxMessageSize {
				return nil, ws.closeOnReadError(ErrMessageTooLarge)
			}
			msg = &Message{
				Binary:  opcode == OpcodeBinary,
				Payload: frame.Payload,
			}
		case OpcodeContinuation:
			if msg == nil {
				return nil, ws.closeOnReadError(ErrContinuationUnexpected)
			}
			if len(msg.Payload)+len(frame.Payload) > ws.maxMessageSize {
				return nil, ws.closeOnReadError(ErrMessageTooLarge)
			}
			msg.Payload = append(msg.Payload, frame.Payload...)
		case OpcodeClose:
			_ = ws.Close()
			return nil, io.EOF
		case OpcodePing:
			pong := NewFrame(OpcodePong, true, frame.Payload)
			ws.hooks.OnWriteFrame(ws.clientKey, pong)
			// Refresh the write deadline first. A deadline left over from an
			// earlier write may already have expired and would fail this pong.
			ws.resetWriteDeadline()
			if err := WriteFrame(ws.conn, ws.mask(), pong); err != nil {
				// Honor the contract that any error closes the connection.
				return nil, ws.closeOnWriteError(err)
			}
			continue
		case OpcodePong:
			continue // pongs require no action
		default:
			return nil, ws.closeOnReadError(ErrOpcodeUnknown)
		}

		// msg is non-nil here: only the data and continuation cases fall through.
		if frame.Fin() {
			if !msg.Binary && !utf8.Valid(msg.Payload) {
				return nil, ws.closeOnReadError(ErrInvalidFramePayload)
			}
			ws.hooks.OnReadMessage(ws.clientKey, msg)
			return msg, nil
		}
	}
}
// WriteMessage writes a single [Message] to the connection. The message is
// first split into fragments by FrameMessage, using the configured maximum
// frame size.
//
// The connection state is checked before each fragment is written:
//   - If the connection has already been closed, io.EOF is returned.
//   - If ctx is done, the connection is closed and ctx.Err() is returned.
//     Closing matters here because a partially sent fragmented message would
//     otherwise leave the stream unusable.
//
// The connection is also closed if writing a frame fails.
func (ws *Websocket) WriteMessage(ctx context.Context, msg *Message) error {
	ws.hooks.OnWriteMessage(ws.clientKey, msg)
	for _, frame := range FrameMessage(msg, ws.maxFrameSize) {
		// Check closedCh before ctx, so that a connection which is both closed
		// and cancelled is never closed a second time.
		select {
		case <-ws.closedCh:
			return io.EOF
		default:
		}
		select {
		case <-ctx.Done():
			_ = ws.Close()
			return ctx.Err()
		default:
		}
		ws.resetWriteDeadline()
		// Report the frame only once it is actually about to be written.
		ws.hooks.OnWriteFrame(ws.clientKey, frame)
		if err := WriteFrame(ws.conn, ws.mask(), frame); err != nil {
			return ws.closeOnWriteError(err)
		}
	}
	return nil
}
// Serve is a high-level convenience method for request-response style
// websocket connections. It reads messages until an error occurs, passes each
// one to handler, and sends any non-nil response back to the client.
//
// Serve returns when:
//   - reading or writing fails (ReadMessage and WriteMessage are documented
//     to close the connection on error), or
//   - handler returns an error, in which case the connection is closed with
//     a status code derived from that error.
//
// ctx is passed to ReadMessage, WriteMessage, and handler.
//
// See also [EchoHandler].
func (ws *Websocket) Serve(ctx context.Context, handler Handler) {
	for {
		msg, err := ws.ReadMessage(ctx)
		if err != nil {
			// ReadMessage has already closed the connection.
			return
		}

		resp, err := handler(ctx, msg)
		if err != nil {
			_ = ws.closeWithError(err)
			return
		}
		if resp == nil {
			continue // nothing to send for this message
		}

		if err := ws.WriteMessage(ctx, resp); err != nil {
			// WriteMessage closes the connection on error.
			return
		}
	}
}
// mask picks the masking key for the next outgoing frame. Server-mode
// connections write unmasked frames. Client-mode connections get a key from
// NewMaskingKey on every call.
func (ws *Websocket) mask() MaskingKey {
	switch ws.mode {
	case ServerMode:
		return Unmasked
	default:
		return NewMaskingKey()
	}
}
// Close closes the websocket connection with a normal-closure status and no
// error cause. It delegates to closeWithError for the exact close sequence.
func (ws *Websocket) Close() error {
	var noCause error
	return ws.closeWithError(noCause)
}
// closeWithError closes the websocket, reporting err as the cause. A nil err
// means a normal closure.
//
// It runs these steps in order:
//  1. Fire the OnClose hook.
//  2. Mark the connection closed.
//  3. Make a best-effort attempt to send a close frame, with a status code
//     and reason derived from err.
//  4. Close the underlying connection. This always happens, even if the
//     close frame could not be written.
//
// A close-frame write error takes precedence in the returned error.
//
// Calling it on an already-closed connection is a no-op that returns nil.
// This is common: ReadMessage closes on error and callers often also defer
// Close. NOTE(review): the already-closed check is not synchronized, so
// concurrent closers can still race.
func (ws *Websocket) closeWithError(err error) error {
	select {
	case <-ws.closedCh:
		// Closing closedCh a second time would panic.
		return nil
	default:
	}

	code, reason := statusCodeForError(err)
	ws.hooks.OnClose(ws.clientKey, code, err)
	close(ws.closedCh)

	writeErr := WriteFrame(ws.conn, ws.mask(), NewCloseFrame(code, reason))
	// Release the connection regardless of whether the close frame was sent;
	// returning early on a write failure would leak it.
	closeErr := ws.conn.Close()
	if writeErr != nil {
		return fmt.Errorf("websocket: failed to write close frame: %w", writeErr)
	}
	if closeErr != nil {
		return fmt.Errorf("websocket: failed to close connection: %w", closeErr)
	}
	return nil
}
// closeOnReadError reports readErr to the OnReadError hook and closes the
// connection with readErr as the cause. It returns readErr unchanged, so
// callers can write `return nil, ws.closeOnReadError(err)`.
func (ws *Websocket) closeOnReadError(readErr error) error {
	ws.hooks.OnReadError(ws.clientKey, readErr)
	// Closing is best effort: the read error is what the caller needs to see,
	// so any error from the close itself is intentionally discarded.
	_ = ws.closeWithError(readErr)
	return readErr
}
// closeOnWriteError reports writeErr to the OnWriteError hook and closes the
// connection with writeErr as the cause. It returns writeErr unchanged.
func (ws *Websocket) closeOnWriteError(writeErr error) error {
	ws.hooks.OnWriteError(ws.clientKey, writeErr)
	// Best effort: the original write error matters more to the caller than
	// any failure to close cleanly, which is intentionally discarded.
	_ = ws.closeWithError(writeErr)
	return writeErr
}
// resetReadDeadline moves the connection's read deadline to readTimeout from
// now. It does nothing when readTimeout is not positive, and panics if the
// deadline cannot be set.
//
// NOTE(review): SetReadDeadline can fail on an already-closed net.Conn, which
// would turn a concurrent close into a panic here. Consider returning the
// error instead.
func (ws *Websocket) resetReadDeadline() {
	if ws.readTimeout > 0 {
		dl := ws.conn.(deadliner)
		if err := dl.SetReadDeadline(time.Now().Add(ws.readTimeout)); err != nil {
			panic(fmt.Sprintf("websocket: failed to set read deadline: %s", err))
		}
	}
}
// resetWriteDeadline moves the connection's write deadline to writeTimeout
// from now. It does nothing when writeTimeout is not positive, and panics if
// the deadline cannot be set.
//
// NOTE(review): SetWriteDeadline can fail on an already-closed net.Conn,
// which would turn a concurrent close into a panic here.
func (ws *Websocket) resetWriteDeadline() {
	if ws.writeTimeout > 0 {
		dl := ws.conn.(deadliner)
		if err := dl.SetWriteDeadline(time.Now().Add(ws.writeTimeout)); err != nil {
			panic(fmt.Sprintf("websocket: failed to set write deadline: %s", err))
		}
	}
}
// ClientKey returns the [ClientKey] this connection was created with. For
// connections established by [Accept], it is the Sec-Websocket-Key value the
// client sent during the handshake.
func (ws *Websocket) ClientKey() ClientKey {
	key := ws.clientKey
	return key
}
// Handler handles a single websocket [Message] as part of the high-level
// request-response API provided by [Websocket.Serve].
//
// If the returned message is non-nil, it is sent back to the client. If an
// error is returned, the connection is closed with a status code derived
// from that error and Serve returns.
type Handler func(ctx context.Context, msg *Message) (*Message, error)
// EchoHandler is a [Handler] that replies to every incoming [Message] with
// that same message, unmodified. It never returns an error.
var EchoHandler = Handler(func(_ context.Context, in *Message) (*Message, error) {
	return in, nil
})
// statusCodeForError maps err to the status code and reason sent in a close
// frame:
//   - nil: StatusNormalClosure with an empty reason.
//   - a protocol *Error: that error's own code, with its message as reason.
//   - an error wrapping io.EOF: StatusNormalClosure, with its message as reason.
//   - anything else: StatusInternalError, with its message as reason.
//
// RFC 6455 limits a control frame payload to 125 bytes, two of which hold
// the status code. The reason is therefore truncated to at most 123 bytes,
// without splitting a multi-byte UTF-8 sequence.
func statusCodeForError(err error) (StatusCode, string) {
	if err == nil {
		return StatusNormalClosure, ""
	}

	// Largest reason that fits in a close frame (RFC 6455 §5.5): 125-byte
	// control frame payload minus the 2-byte status code.
	const maxReasonLen = 123
	truncate := func(reason string) string {
		if len(reason) <= maxReasonLen {
			return reason
		}
		cut := maxReasonLen
		// Back up to the start of a rune so a multi-byte sequence is never cut.
		for cut > 0 && !utf8.RuneStart(reason[cut]) {
			cut--
		}
		return reason[:cut]
	}

	var protoErr *Error
	if errors.As(err, &protoErr) {
		return protoErr.code, truncate(protoErr.Error())
	}

	code := StatusInternalError
	if errors.Is(err, io.EOF) {
		code = StatusNormalClosure
	}
	return code, truncate(err.Error())
}
// Hooks are optional callbacks invoked during the lifecycle of a websocket
// connection, e.g. for logging or metrics. Every callback receives the
// connection's [ClientKey]. Nil hooks are replaced with no-ops by [New].
type Hooks struct {
	// OnClose is called when the connection starts closing. It receives the
	// close status code and the error that caused the close (nil for a plain
	// Close).
	OnClose func(ClientKey, StatusCode, error)
	// OnReadError is called when reading fails or an incoming frame or
	// message is invalid, just before the connection is closed because of it.
	OnReadError func(ClientKey, error)
	// OnReadFrame is called for each frame that is read and validated,
	// including control frames, before the frame is acted upon.
	OnReadFrame func(ClientKey, *Frame)
	// OnReadMessage is called when a complete message has been read.
	OnReadMessage func(ClientKey, *Message)
	// OnWriteError is called when writing a frame fails, just before the
	// connection is closed because of it.
	OnWriteError func(ClientKey, error)
	// OnWriteFrame is called for each outgoing data or pong frame before it
	// is written. Being reported does not guarantee the frame was written
	// successfully.
	OnWriteFrame func(ClientKey, *Frame)
	// OnWriteMessage is called when writing of a message begins, before it is
	// split into frames. It fires even if the write later fails.
	OnWriteMessage func(ClientKey, *Message)
}
// setupHooks replaces every nil callback in hooks with a no-op, so the rest
// of the package can invoke hooks without nil checks. Hooks that are already
// set are left untouched.
func setupHooks(hooks *Hooks) {
	noopErr := func(ClientKey, error) {}
	noopFrame := func(ClientKey, *Frame) {}
	noopMessage := func(ClientKey, *Message) {}

	if hooks.OnClose == nil {
		hooks.OnClose = func(ClientKey, StatusCode, error) {}
	}

	if hooks.OnReadError == nil {
		hooks.OnReadError = noopErr
	}
	if hooks.OnWriteError == nil {
		hooks.OnWriteError = noopErr
	}

	if hooks.OnReadFrame == nil {
		hooks.OnReadFrame = noopFrame
	}
	if hooks.OnWriteFrame == nil {
		hooks.OnWriteFrame = noopFrame
	}

	if hooks.OnReadMessage == nil {
		hooks.OnReadMessage = noopMessage
	}
	if hooks.OnWriteMessage == nil {
		hooks.OnWriteMessage = noopMessage
	}
}