Skip to content

Commit

Permalink
Embed mutex in the server's websocket client (#200)
Browse files Browse the repository at this point in the history
This will prevent concurrent writes both in server implementations
that store a reference to the client from the OnConnectedFunc callback.
Concurrent writes can either occur when the server library
responds to agent messages or when the implementation
tries to send messages to the server on different threads.
  • Loading branch information
evan-bradley authored Sep 20, 2023
1 parent deb3388 commit fd3066f
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 10 deletions.
5 changes: 0 additions & 5 deletions internal/examples/server/data/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ type Agent struct {

// Connection to the Agent.
conn types.Connection
// Mutex to protect Send() operation.
connMutex sync.Mutex

// mutex for the fields that follow it.
mux sync.RWMutex
Expand Down Expand Up @@ -421,9 +419,6 @@ func (agent *Agent) calcConnectionSettings(response *protobufs.ServerToAgent) {
}

func (agent *Agent) SendToAgent(msg *protobufs.ServerToAgent) {
agent.connMutex.Lock()
defer agent.connMutex.Unlock()

agent.conn.Send(context.Background(), msg)
}

Expand Down
3 changes: 2 additions & 1 deletion server/serverimpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"io"
"net"
"net/http"
"sync"

"github.com/gorilla/websocket"
"google.golang.org/protobuf/proto"
Expand Down Expand Up @@ -179,7 +180,7 @@ func (s *server) httpHandler(w http.ResponseWriter, req *http.Request) {
}

func (s *server) handleWSConnection(wsConn *websocket.Conn, connectionCallbacks serverTypes.ConnectionCallbacks) {
agentConn := wsConnection{wsConn: wsConn}
agentConn := wsConnection{wsConn: wsConn, connMutex: &sync.Mutex{}}

defer func() {
// Close the connection when all is done.
Expand Down
122 changes: 122 additions & 0 deletions server/serverimpl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -723,3 +724,124 @@ func TestDecodeMessage(t *testing.T) {
}
}
}

func TestConnectionAllowsConcurrentWrites(t *testing.T) {
srvConnVal := atomic.Value{}
callbacks := CallbacksStruct{
OnConnectingFunc: func(request *http.Request) types.ConnectionResponse {
return types.ConnectionResponse{Accept: true, ConnectionCallbacks: ConnectionCallbacksStruct{
OnConnectedFunc: func(conn types.Connection) {
srvConnVal.Store(conn)
},
}}
},
}

// Start a Server.
settings := &StartSettings{Settings: Settings{Callbacks: callbacks}}
srv := startServer(t, settings)
defer srv.Stop(context.Background())

// Connect to the Server.
conn, _, err := dialClient(settings)

// Verify that the connection is successful.
assert.NoError(t, err)
assert.NotNil(t, conn)

defer conn.Close()

timeout, cancel := context.WithTimeout(context.Background(), 10*time.Second)

select {
case <-timeout.Done():
t.Error("Client failed to connect before timeout")
default:
if _, ok := srvConnVal.Load().(types.Connection); ok == true {
break
}
}

cancel()

srvConn := srvConnVal.Load().(types.Connection)
for i := 0; i < 20; i++ {
go func() {
defer func() {
if recover() != nil {
require.Fail(t, "Sending to client panicked")
}
}()

srvConn.Send(context.Background(), &protobufs.ServerToAgent{})
}()
}
}

func BenchmarkSendToClient(b *testing.B) {
clientConnections := []*websocket.Conn{}
serverConnections := []types.Connection{}
srvConnectionsMutex := sync.Mutex{}
callbacks := CallbacksStruct{
OnConnectingFunc: func(request *http.Request) types.ConnectionResponse {
return types.ConnectionResponse{Accept: true, ConnectionCallbacks: ConnectionCallbacksStruct{
OnConnectedFunc: func(conn types.Connection) {
srvConnectionsMutex.Lock()
serverConnections = append(serverConnections, conn)
srvConnectionsMutex.Unlock()
},
}}
},
}

// Start a Server.
settings := &StartSettings{
Settings: Settings{Callbacks: callbacks},
ListenEndpoint: testhelpers.GetAvailableLocalAddress(),
ListenPath: "/",
}
srv := New(&sharedinternal.NopLogger{})
err := srv.Start(*settings)

if err != nil {
b.Error(err)
}

defer srv.Stop(context.Background())

for i := 0; i < b.N; i++ {
conn, resp, err := dialClient(settings)

if err != nil || resp == nil || conn == nil {
b.Error("Could not establish connection:", err)
}

clientConnections = append(clientConnections, conn)
}

timeout, cancel := context.WithTimeout(context.Background(), 10*time.Second)

select {
case <-timeout.Done():
b.Error("Connections failed to establish in time")
default:
if len(serverConnections) == b.N {
break
}
}

cancel()

for _, conn := range serverConnections {
err := conn.Send(context.Background(), &protobufs.ServerToAgent{})

if err != nil {
b.Error(err)
}
}

for _, conn := range clientConnections {
conn.Close()
}

}
13 changes: 9 additions & 4 deletions server/wsconnection.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package server
import (
"context"
"net"
"sync"

"github.com/gorilla/websocket"

Expand All @@ -13,7 +14,11 @@ import (

// wsConnection represents a persistent OpAMP connection over a WebSocket.
type wsConnection struct {
wsConn *websocket.Conn
// The websocket library does not allow multiple concurrent write operations,
// so ensure that we only have a single operation in progress at a time.
// For more: https://pkg.go.dev/github.com/gorilla/websocket#hdr-Concurrency
connMutex *sync.Mutex
wsConn *websocket.Conn
}

var _ types.Connection = (*wsConnection)(nil)
Expand All @@ -22,10 +27,10 @@ func (c wsConnection) Connection() net.Conn {
return c.wsConn.UnderlyingConn()
}

// Message header is currently uint64 zero value.
const wsMsgHeader = uint64(0)

func (c wsConnection) Send(_ context.Context, message *protobufs.ServerToAgent) error {
c.connMutex.Lock()
defer c.connMutex.Unlock()

return internal.WriteWSMessage(c.wsConn, message)
}

Expand Down

0 comments on commit fd3066f

Please sign in to comment.