Skip to content

Commit

Permalink
Merge pull request #3936 from amogh09/ebs-healthcheck
Browse files Browse the repository at this point in the history
Add health check feature to EBS CSI Driver
  • Loading branch information
fierlion authored Oct 2, 2023
2 parents 57037b5 + 886be23 commit 8e63694
Show file tree
Hide file tree
Showing 7 changed files with 182 additions and 4 deletions.
26 changes: 26 additions & 0 deletions ecs-agent/daemonimages/csidriver/driver/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ var (
FSTypeXfs: {},
FSTypeNtfs: {},
}

// nodeCaps lists the RPC capability types of the node service.
// NodeGetCapabilities wraps each entry in a csi.NodeServiceCapability and
// returns them in its response.
nodeCaps = []csi.NodeServiceCapability_RPC_Type{
csi.NodeServiceCapability_RPC_STAGE_UNSTAGE_VOLUME,
csi.NodeServiceCapability_RPC_GET_VOLUME_STATS,
}
)

// nodeService represents the node service of CSI driver.
Expand Down Expand Up @@ -412,3 +418,23 @@ func hasMountOption(options []string, opt string) bool {
}
return false
}

// NodeGetCapabilities returns the capabilities of this node service.
//
// Every RPC type listed in nodeCaps is wrapped in a csi.NodeServiceCapability
// and placed in the response. The returned error is always nil.
func (d *nodeService) NodeGetCapabilities(
	ctx context.Context,
	req *csi.NodeGetCapabilitiesRequest,
) (*csi.NodeGetCapabilitiesResponse, error) {
	klog.V(4).InfoS("NodeGetCapabilities: called", "args", *req)

	// The final length is known up front, so allocate the slice once.
	caps := make([]*csi.NodeServiceCapability, 0, len(nodeCaps))
	// The loop variable is deliberately not named "cap" so that the builtin
	// cap function is not shadowed.
	for _, rpcType := range nodeCaps {
		caps = append(caps, &csi.NodeServiceCapability{
			Type: &csi.NodeServiceCapability_Rpc{
				Rpc: &csi.NodeServiceCapability_RPC{
					Type: rpcType,
				},
			},
		})
	}
	return &csi.NodeGetCapabilitiesResponse{Capabilities: caps}, nil
}
47 changes: 47 additions & 0 deletions ecs-agent/daemonimages/csidriver/driver/node_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
//go:build unit
// +build unit

// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

package driver

import (
"context"
"testing"

"github.com/container-storage-interface/spec/lib/go/csi"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// Tests that NodeGetCapabilities returns the node's capabilities
func TestNodeGetCapabilities(t *testing.T) {
node := &nodeService{}
res, err := node.NodeGetCapabilities(context.Background(), &csi.NodeGetCapabilitiesRequest{})
require.NoError(t, err)

capTypes := []csi.NodeServiceCapability_RPC_Type{}
for _, cap := range res.Capabilities {
capTypes = append(capTypes, cap.GetRpc().GetType())
}

expectedCapTypes := []csi.NodeServiceCapability_RPC_Type{
csi.NodeServiceCapability_RPC_GET_VOLUME_STATS,
csi.NodeServiceCapability_RPC_STAGE_UNSTAGE_VOLUME,
}
assert.Equal(t, len(expectedCapTypes), len(capTypes))
for _, expectedCapType := range expectedCapTypes {
assert.Contains(t, capTypes, expectedCapType)
}
}
68 changes: 68 additions & 0 deletions ecs-agent/daemonimages/csidriver/health/health.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

// This package contains utilities related to the health of the CSI Driver.
package health

import (
"context"
"fmt"
"net"
"time"

"github.com/aws/amazon-ecs-agent/ecs-agent/daemonimages/csidriver/util"
"github.com/container-storage-interface/spec/lib/go/csi"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"k8s.io/klog/v2"
)

const (
defaultResponseTimeout = 5 * time.Second
)

// CheckHealth checks the health of an already running CSI Driver server by
// querying it for Node Service capabilities.
//
// The endpoint is parsed with util.ParseEndpointNoRemove. A nil return means
// the server answered the NodeGetCapabilities call before
// defaultResponseTimeout elapsed; a parse, connection or RPC failure is
// returned as a wrapped error.
func CheckHealth(endpoint string) error {
	// Parse the endpoint
	scheme, addr, err := util.ParseEndpointNoRemove(endpoint)
	if err != nil {
		return fmt.Errorf("failed to parse endpoint: %w", err)
	}

	// Connect to the server. Dial through a net.Dialer so that the context
	// handed to the dialer (and with it any cancellation or deadline) is
	// honored while the connection is being established.
	dialer := func(ctx context.Context, target string) (net.Conn, error) {
		var d net.Dialer
		return d.DialContext(ctx, scheme, target)
	}
	conn, err := grpc.Dial(
		addr,
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithContextDialer(dialer),
	)
	if err != nil {
		return fmt.Errorf("failed to connect to the server: %w", err)
	}
	defer conn.Close()

	// Call the server to fetch node capabilities
	client := csi.NewNodeClient(conn)
	ctx, cancel := context.WithTimeout(context.Background(), defaultResponseTimeout)
	defer cancel()
	res, err := client.NodeGetCapabilities(ctx, &csi.NodeGetCapabilitiesRequest{})
	if err != nil {
		return fmt.Errorf("failed to get node capabilities: %w", err)
	}
	klog.Infof("Node capabilities: %s", res.String())

	return nil
}
12 changes: 12 additions & 0 deletions ecs-agent/daemonimages/csidriver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ package main

import (
"flag"
"os"

"k8s.io/klog/v2"

"github.com/aws/amazon-ecs-agent/ecs-agent/daemonimages/csidriver/driver"
"github.com/aws/amazon-ecs-agent/ecs-agent/daemonimages/csidriver/health"
)

func main() {
Expand All @@ -32,6 +34,16 @@ func main() {

klog.V(4).InfoS("Server Options are provided", "ServerOptions", srvOptions)

if srvOptions.HealthCheck {
// Perform health check and exit
err := health.CheckHealth(srvOptions.Endpoint)
if err != nil {
klog.Errorf("Health check failed: %s", err.Error())
os.Exit(1)
}
os.Exit(0)
}

drv, err := driver.NewDriver(
driver.WithEndpoint(srvOptions.Endpoint),
)
Expand Down
4 changes: 4 additions & 0 deletions ecs-agent/daemonimages/csidriver/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,15 @@ const emptyCSIEndpoint = ""
// ServerOptions holds the values of the command-line flags parsed by
// GetServerOptions.
type ServerOptions struct {
// Endpoint is the endpoint that the driver server should listen on.
Endpoint string
// HealthCheck, if enabled, makes the program perform a health check on an
// existing server instead of starting a new server.
HealthCheck bool
}

func GetServerOptions(fs *flag.FlagSet) (*ServerOptions, error) {
serverOptions := &ServerOptions{}
fs.StringVar(&serverOptions.Endpoint, "endpoint", emptyCSIEndpoint, "Endpoint for the CSI driver server")
fs.BoolVar(&serverOptions.HealthCheck, "healthcheck", false, "Check health of an existing server")

args := os.Args[1:]
if err := fs.Parse(args); err != nil {
Expand Down
8 changes: 8 additions & 0 deletions ecs-agent/daemonimages/csidriver/options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@ func TestGetServerOptions(t *testing.T) {
assert.EqualError(t, err, "no argument is provided")
},
},
{
name: "healthcheck argument is given",
testFunc: func(t *testing.T) {
opts, err := testFunc(t, []string{"--endpoint=foo", "--healthcheck"})
assert.NoError(t, err)
assert.True(t, opts.HealthCheck)
},
},
}

for _, tc := range testCases {
Expand Down
21 changes: 17 additions & 4 deletions ecs-agent/daemonimages/csidriver/util/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,34 @@ import (
)

// ParseEndpoint parses endpoint via ParseEndpointNoRemove and, when the
// scheme is "unix", additionally removes the file at the resulting socket
// path. A socket file that does not exist is not treated as an error.
// It returns the scheme and the address.
func ParseEndpoint(endpoint string) (string, string, error) {
	proto, sockAddr, parseErr := ParseEndpointNoRemove(endpoint)
	switch {
	case parseErr != nil:
		return "", "", parseErr
	case proto != "unix":
		// Only unix sockets leave a file behind that needs removing.
		return proto, sockAddr, nil
	}

	rmErr := os.Remove(sockAddr)
	if rmErr == nil || os.IsNotExist(rmErr) {
		return proto, sockAddr, nil
	}
	return "", "", fmt.Errorf("could not remove unix domain socket %q: %w", sockAddr, rmErr)
}

// ParseEndpointNoRemove parses the endpoint but doesn't remove the unix
// socket file if the endpoint is a unix socket.
func ParseEndpointNoRemove(endpoint string) (string, string, error) {
u, err := url.Parse(endpoint)
if err != nil {
return "", "", fmt.Errorf("could not parse endpoint: %w", err)
}

addr := filepath.Join(u.Host, filepath.FromSlash(u.Path))

scheme := strings.ToLower(u.Scheme)
switch scheme {
case "tcp":
case "unix":
addr = filepath.Join("/", addr)
if err := os.Remove(addr); err != nil && !os.IsNotExist(err) {
return "", "", fmt.Errorf("could not remove unix domain socket %q: %w", addr, err)
}
default:
return "", "", fmt.Errorf("unsupported protocol: %s", scheme)
}
Expand Down

0 comments on commit 8e63694

Please sign in to comment.