Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add limit chain elements to prevent grpc conn leaking #1663

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ linters:
- deadcode
- depguard
- dogsled
- dupl
- errcheck
- funlen
- gochecknoinits
Expand Down
4 changes: 3 additions & 1 deletion pkg/networkservice/chains/client/client.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021-2022 Cisco and/or its affiliates.
// Copyright (c) 2021-2024 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand Down Expand Up @@ -28,6 +28,7 @@ import (
"github.com/networkservicemesh/sdk/pkg/networkservice/common/clienturl"
"github.com/networkservicemesh/sdk/pkg/networkservice/common/connect"
"github.com/networkservicemesh/sdk/pkg/networkservice/common/dial"
"github.com/networkservicemesh/sdk/pkg/networkservice/common/limit"
"github.com/networkservicemesh/sdk/pkg/networkservice/common/null"
"github.com/networkservicemesh/sdk/pkg/networkservice/common/refresh"
"github.com/networkservicemesh/sdk/pkg/networkservice/common/trimpath"
Expand Down Expand Up @@ -63,6 +64,7 @@ func NewClient(ctx context.Context, clientOpts ...Option) networkservice.Network
dial.WithDialOptions(opts.dialOptions...),
dial.WithDialTimeout(opts.dialTimeout),
),
limit.NewClient(),
},
append(
opts.additionalFunctionality,
Expand Down
114 changes: 114 additions & 0 deletions pkg/networkservice/common/limit/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Copyright (c) 2024 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package limit provides a chain element that can set limits for the RPC calls.
package limit

import (
"context"
"time"

"github.com/golang/protobuf/ptypes/empty"
"github.com/networkservicemesh/api/pkg/api/networkservice"
"google.golang.org/grpc"

"github.com/networkservicemesh/sdk/pkg/networkservice/common/clientconn"
"github.com/networkservicemesh/sdk/pkg/networkservice/core/next"
"github.com/networkservicemesh/sdk/pkg/tools/log"
)

// Option overrides default values
type Option func(c *limitClient)

// WithDialLimit sets dial limit
func WithDialLimit(d time.Duration) Option {
return func(c *limitClient) {
c.dialLimit = d
}
}

type limitClient struct {
dialLimit time.Duration
}

// NewClient returns new NetworkServiceClient that limits rpc
func NewClient(opts ...Option) networkservice.NetworkServiceClient {
ret := &limitClient{
dialLimit: time.Minute,
}

for _, opt := range opts {
opt(ret)
}

return ret
}

func (n *limitClient) Request(ctx context.Context, request *networkservice.NetworkServiceRequest, opts ...grpc.CallOption) (*networkservice.Connection, error) {
cc, ok := clientconn.Load(ctx)
if !ok {
return next.Server(ctx).Request(ctx, request)
}

closer, ok := cc.(interface{ Close() error })
if !ok {
return next.Server(ctx).Request(ctx, request)
}

doneCh := make(chan struct{})
defer close(doneCh)

logger := log.FromContext(ctx).WithField("throttleClient", "Request")

go func() {
select {
case <-time.After(n.dialLimit):
logger.Warn("Reached dial limit, closing connection...")
_ = closer.Close()
case <-doneCh:
return
}
}()
return next.Client(ctx).Request(ctx, request, opts...)
}

func (n *limitClient) Close(ctx context.Context, conn *networkservice.Connection, opts ...grpc.CallOption) (*empty.Empty, error) {
cc, ok := clientconn.Load(ctx)
if !ok {
return next.Server(ctx).Close(ctx, conn)
}

closer, ok := cc.(interface{ Close() error })
if !ok {
return next.Server(ctx).Close(ctx, conn)
}

doneCh := make(chan struct{})
defer close(doneCh)

logger := log.FromContext(ctx).WithField("throttleClient", "Close")

go func() {
select {
case <-time.After(n.dialLimit):
logger.Warn("Reached dial limit, closing connection...")
_ = closer.Close()
case <-doneCh:
return
}
}()
return next.Client(ctx).Close(ctx, conn, opts...)
}
106 changes: 106 additions & 0 deletions pkg/networkservice/common/limit/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package limit_test

import (
"context"
"sync/atomic"
"testing"
"time"

"github.com/networkservicemesh/api/pkg/api/networkservice"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"google.golang.org/grpc"

"github.com/networkservicemesh/sdk/pkg/networkservice/common/clientconn"
"github.com/networkservicemesh/sdk/pkg/networkservice/common/limit"
"github.com/networkservicemesh/sdk/pkg/networkservice/core/chain"
"github.com/networkservicemesh/sdk/pkg/networkservice/utils/checks/checkclose"
"github.com/networkservicemesh/sdk/pkg/networkservice/utils/checks/checkrequest"
"github.com/networkservicemesh/sdk/pkg/networkservice/utils/metadata"
)

type myConnection struct {
closed atomic.Bool
grpc.ClientConnInterface
}

func (cc *myConnection) Close() error {
cc.closed.Store(true)
return nil
}

func Test_DialLimitShouldCalled_OnLimitReached_Request(t *testing.T) {
t.Cleanup(func() { goleak.VerifyNone(t) })

var cc = new(myConnection)
var myChain = chain.NewNetworkServiceClient(
metadata.NewClient(),
clientconn.NewClient(cc),
limit.NewClient(limit.WithDialLimit(time.Second/5)),
checkrequest.NewClient(t, func(t *testing.T, nsr *networkservice.NetworkServiceRequest) {
time.Sleep(time.Second / 4)
}),
)

_, _ = myChain.Request(context.Background(), &networkservice.NetworkServiceRequest{})

require.Eventually(t, func() bool {
return cc.closed.Load()
}, time.Second/2, time.Millisecond*75)
}

func Test_DialLimitShouldCalled_OnLimitReached_Close(t *testing.T) {
t.Cleanup(func() { goleak.VerifyNone(t) })

var cc = new(myConnection)
var myChain = chain.NewNetworkServiceClient(
metadata.NewClient(),
clientconn.NewClient(cc),
limit.NewClient(limit.WithDialLimit(time.Second/5)),
checkclose.NewClient(t, func(t *testing.T, nsr *networkservice.Connection) {
time.Sleep(time.Second / 4)
}),
)

_, _ = myChain.Request(context.Background(), &networkservice.NetworkServiceRequest{})
_, _ = myChain.Close(context.Background(), &networkservice.Connection{})

require.Eventually(t, func() bool {
return cc.closed.Load()
}, time.Second/2, time.Millisecond*75)
}

func Test_DialLimitShouldNotBeCalled_OnSuccesRequest(t *testing.T) {
t.Cleanup(func() { goleak.VerifyNone(t) })

var cc = new(myConnection)
var myChain = chain.NewNetworkServiceClient(
metadata.NewClient(),
clientconn.NewClient(cc),
limit.NewClient(limit.WithDialLimit(time.Second/5)),
)

_, _ = myChain.Request(context.Background(), &networkservice.NetworkServiceRequest{})

require.Never(t, func() bool {
return cc.closed.Load()
}, time.Second/2, time.Millisecond*75)
}

func Test_DialLimitShouldNotBeCalled_OnSuccessClose(t *testing.T) {
t.Cleanup(func() { goleak.VerifyNone(t) })

var cc = new(myConnection)
var myChain = chain.NewNetworkServiceClient(
metadata.NewClient(),
clientconn.NewClient(cc),
limit.NewClient(limit.WithDialLimit(time.Second/5)),
)

_, _ = myChain.Request(context.Background(), &networkservice.NetworkServiceRequest{})
_, _ = myChain.Close(context.Background(), &networkservice.Connection{})

require.Never(t, func() bool {
return cc.closed.Load()
}, time.Second/2, time.Millisecond*75)
}
2 changes: 2 additions & 0 deletions pkg/registry/chains/client/ns_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/networkservicemesh/sdk/pkg/registry/common/dial"
"github.com/networkservicemesh/sdk/pkg/registry/common/grpcmetadata"
"github.com/networkservicemesh/sdk/pkg/registry/common/heal"
"github.com/networkservicemesh/sdk/pkg/registry/common/limit"
"github.com/networkservicemesh/sdk/pkg/registry/common/null"
"github.com/networkservicemesh/sdk/pkg/registry/common/retry"
"github.com/networkservicemesh/sdk/pkg/registry/core/chain"
Expand Down Expand Up @@ -63,6 +64,7 @@ func NewNetworkServiceRegistryClient(ctx context.Context, opts ...Option) regist
dial.WithDialTimeout(clientOpts.dialTimeout),
dial.WithDialOptions(clientOpts.dialOptions...),
),
limit.NewNetworkServiceRegistryClient(),
},
append(
clientOpts.nsAdditionalFunctionality,
Expand Down
2 changes: 2 additions & 0 deletions pkg/registry/chains/client/nse_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/networkservicemesh/sdk/pkg/registry/common/dial"
"github.com/networkservicemesh/sdk/pkg/registry/common/grpcmetadata"
"github.com/networkservicemesh/sdk/pkg/registry/common/heal"
"github.com/networkservicemesh/sdk/pkg/registry/common/limit"
"github.com/networkservicemesh/sdk/pkg/registry/common/null"
"github.com/networkservicemesh/sdk/pkg/registry/common/refresh"
"github.com/networkservicemesh/sdk/pkg/registry/common/retry"
Expand Down Expand Up @@ -66,6 +67,7 @@ func NewNetworkServiceEndpointRegistryClient(ctx context.Context, opts ...Option
dial.WithDialTimeout(clientOpts.dialTimeout),
dial.WithDialOptions(clientOpts.dialOptions...),
),
limit.NewNetworkServiceEndpointRegistryClient(),
},
append(
clientOpts.nseAdditionalFunctionality,
Expand Down
4 changes: 2 additions & 2 deletions pkg/registry/chains/proxydns/server_ns_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright (c) 2020-2022 Doc.ai and/or its affiliates.
//
// Copyright (c) 2022 Cisco Systems, Inc.
// Copyright (c) 2022-2024 Cisco Systems, Inc.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand Down Expand Up @@ -143,7 +143,7 @@ func TestLocalDomain_NetworkServiceRegistry(t *testing.T) {
registryclient.WithDialOptions(grpc.WithTransportCredentials(insecure.NewCredentials())),
registryclient.WithClientURL(domain1.Registry.URL))

stream, err := client2.Find(context.Background(), &registryapi.NetworkServiceQuery{
stream, err := client2.Find(ctx, &registryapi.NetworkServiceQuery{
NetworkService: &registryapi.NetworkService{
Name: expected.Name,
},
Expand Down
33 changes: 33 additions & 0 deletions pkg/registry/common/limit/common.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (c) 2024 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package limit

import "time"

type limitConfig struct {
dialLimit time.Duration
}

// Option overrides default values
type Option func(cfg *limitConfig)

// WithDialLimit sets dial time limit
func WithDialLimit(t time.Duration) Option {
return Option(func(cfg *limitConfig) {
cfg.dialLimit = t
})
}
Loading
Loading