diff --git a/README.md b/README.md index 3f23b4bd..0a906375 100644 --- a/README.md +++ b/README.md @@ -218,8 +218,12 @@ You'll need the following AWS infrastructure components: 1. Amazon Simple Queue Service (SQS) Queue 2. AutoScaling Group Termination Lifecycle Hook -3. Amazon EventBridge Rule -4. IAM Role for the aws-node-termination-handler Queue Processing Pods +3. Instance Tagging +4. Amazon EventBridge Rule +5. IAM Role for the aws-node-termination-handler Queue Processing Pods + +Optional AWS infrastructure components: +1. AutoScaling Group Launch Lifecycle Hook #### 1. Create an SQS Queue: @@ -290,7 +294,7 @@ aws autoscaling put-lifecycle-hook \ --lifecycle-transition=autoscaling:EC2_INSTANCE_TERMINATING \ --default-result=CONTINUE \ --heartbeat-timeout=300 \ - --notification-target-arn \ + --notification-target-arn \ --role-arn ``` @@ -398,6 +402,36 @@ IAM Policy for aws-node-termination-handler Deployment: } ``` +#### 1. Handle ASG Instance Launch Lifecycle Notifications (optional): + +NTH can monitor for new instances launched by an ASG and notify the ASG when the instance is available in the EKS cluster. + +NTH will need to receive notifications of new instance launches within the ASG. We can add a lifecycle hook to the ASG that will send instance launch notifications via EventBridge: + +``` +aws autoscaling put-lifecycle-hook \ + --lifecycle-hook-name=my-k8s-launch-hook \ + --auto-scaling-group-name=my-k8s-asg \ + --lifecycle-transition=autoscaling:EC2_INSTANCE_LAUNCHING \ + --default-result="ABANDON" \ + --heartbeat-timeout=300 +``` + +Alternatively, ASG can send the instance launch notification directly to an SQS Queue: + +``` +aws autoscaling put-lifecycle-hook \ + --lifecycle-hook-name=my-k8s-launch-hook \ + --auto-scaling-group-name=my-k8s-asg \ + --lifecycle-transition=autoscaling:EC2_INSTANCE_LAUNCHING \ + --default-result="ABANDON" \ + --heartbeat-timeout=300 \ + --notification-target-arn \ + --role-arn +``` + +When NTH receives a launch notification, it will periodically check for a node backed by the EC2 instance to join the cluster and for the node to have a status of 'ready.' Once a node becomes ready, NTH will complete the lifecycle hook, prompting the ASG to proceed with terminating the previous instance. If the lifecycle hook is not completed before the timeout, the ASG will take the default action. If the default action is 'ABANDON', the new instance will be terminated, and the notification process will be repeated with another new instance. 
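+If the launch notifications are delivered via EventBridge (the first option above) rather than directly to SQS, you will also need an EventBridge rule that forwards the launch lifecycle events to the SQS queue NTH polls. Below is a minimal sketch following the same pattern as the termination rules from the Amazon EventBridge Rule step; the rule name is a placeholder, and the target ARN assumes the queue created in step 1:
+
+```
+aws events put-rule \
+  --name=MyK8sASGLaunchRule \
+  --event-pattern "{\"source\":[\"aws.autoscaling\"],\"detail-type\":[\"EC2 Instance-launch Lifecycle Action\"]}"
+
+aws events put-targets --rule=MyK8sASGLaunchRule \
+  --targets "Id=1,Arn=arn:aws:sqs:us-east-1:123456789012:MyK8sTermQueue"
+```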
+ ### Installation #### Pod Security Admission diff --git a/cmd/node-termination-handler.go b/cmd/node-termination-handler.go index 9f7dcaf1..6145e0e5 100644 --- a/cmd/node-termination-handler.go +++ b/cmd/node-termination-handler.go @@ -25,6 +25,8 @@ import ( "github.com/aws/aws-node-termination-handler/pkg/config" "github.com/aws/aws-node-termination-handler/pkg/ec2metadata" + "github.com/aws/aws-node-termination-handler/pkg/interruptionevent/asg/launch" + "github.com/aws/aws-node-termination-handler/pkg/interruptionevent/draincordon" "github.com/aws/aws-node-termination-handler/pkg/interruptioneventstore" "github.com/aws/aws-node-termination-handler/pkg/logging" "github.com/aws/aws-node-termination-handler/pkg/monitor" @@ -43,8 +45,9 @@ import ( "github.com/aws/aws-sdk-go/service/sqs" "github.com/rs/zerolog" "github.com/rs/zerolog/log" - "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" ) const ( @@ -56,6 +59,10 @@ const ( duplicateErrThreshold = 3 ) +type interruptionEventHandler interface { + HandleEvent(*monitor.InterruptionEvent) error +} + func main() { // Zerolog uses json formatting by default, so change that to a human-readable format instead log.Logger = log.Output(logging.RoutingLevelWriter{ @@ -97,7 +104,16 @@ func main() { nthConfig.Print() log.Fatal().Err(err).Msg("Webhook validation failed,") } - node, err := node.New(nthConfig) + + clusterConfig, err := rest.InClusterConfig() + if err != nil { + log.Fatal().Err(err).Msg("retrieving cluster config") + } + clientset, err := kubernetes.NewForConfig(clusterConfig) + if err != nil { + log.Fatal().Err(err).Msg("creating new clientset with config") + } + node, err := node.New(nthConfig, clientset) if err != nil { nthConfig.Print() log.Fatal().Err(err).Msg("Unable to instantiate a node for various kubernetes node functions,") @@ -137,7 +153,7 @@ func main() { log.Fatal().Msgf("Unable to find the AWS region to process queue events.") } - recorder, err := observability.InitK8sEventRecorder(nthConfig.EmitKubernetesEvents, nthConfig.NodeName, nthConfig.EnableSQSTerminationDraining, nodeMetadata, nthConfig.KubernetesEventsExtraAnnotations) + recorder, err := observability.InitK8sEventRecorder(nthConfig.EmitKubernetesEvents, nthConfig.NodeName, nthConfig.EnableSQSTerminationDraining, nodeMetadata, nthConfig.KubernetesEventsExtraAnnotations, clientset) if err != nil { nthConfig.Print() log.Fatal().Err(err).Msg("Unable to create Kubernetes event recorder,") @@ -243,6 +259,9 @@ func main() { var wg sync.WaitGroup + asgLaunchHandler := launch.New(interruptionEventStore, *node, nthConfig, metrics, recorder, clientset) + drainCordonHandler := draincordon.New(interruptionEventStore, *node, nthConfig, nodeMetadata, metrics, recorder) + for range time.NewTicker(1 * time.Second).C { select { case <-signalChan: @@ -257,7 +276,7 @@ func main() { event.InProgress = true wg.Add(1) recorder.Emit(event.NodeName, observability.Normal, observability.GetReasonForKind(event.Kind, event.Monitor), event.Description) - go drainOrCordonIfNecessary(interruptionEventStore, event, *node, nthConfig, nodeMetadata, metrics, recorder, &wg) + go processInterruptionEvent(interruptionEventStore, event, []interruptionEventHandler{asgLaunchHandler, drainCordonHandler}, &wg) default: log.Warn().Msg("all workers busy, waiting") break EventLoop @@ -329,122 +348,25 @@ func watchForCancellationEvents(cancelChan <-chan monitor.InterruptionEvent, int } } -func 
drainOrCordonIfNecessary(interruptionEventStore *interruptioneventstore.Store, drainEvent *monitor.InterruptionEvent, node node.Node, nthConfig config.Config, nodeMetadata ec2metadata.NodeMetadata, metrics observability.Metrics, recorder observability.K8sEventRecorder, wg *sync.WaitGroup) { +func processInterruptionEvent(interruptionEventStore *interruptioneventstore.Store, event *monitor.InterruptionEvent, eventHandlers []interruptionEventHandler, wg *sync.WaitGroup) { defer wg.Done() - nodeFound := true - nodeName := drainEvent.NodeName - if nthConfig.UseProviderId { - newNodeName, err := node.GetNodeNameFromProviderID(drainEvent.ProviderID) + if event == nil { + log.Error().Msg("processing nil interruption event") + <-interruptionEventStore.Workers + return + } + var err error + for _, eventHandler := range eventHandlers { + err = eventHandler.HandleEvent(event) if err != nil { - log.Err(err).Msgf("Unable to get node name for node with ProviderID '%s' using original AWS event node name ", drainEvent.ProviderID) - } else { - nodeName = newNodeName + log.Error().Err(err).Interface("event", event).Msg("handling event") } } - - nodeLabels, err := node.GetNodeLabels(nodeName) - if err != nil { - log.Err(err).Msgf("Unable to fetch node labels for node '%s' ", nodeName) - nodeFound = false - } - drainEvent.NodeLabels = nodeLabels - if drainEvent.PreDrainTask != nil { - runPreDrainTask(node, nodeName, drainEvent, metrics, recorder) - } - - podNameList, err := node.FetchPodNameList(nodeName) - if err != nil { - log.Err(err).Msgf("Unable to fetch running pods for node '%s' ", nodeName) - } - drainEvent.Pods = podNameList - err = node.LogPods(podNameList, nodeName) - if err != nil { - log.Err(err).Msg("There was a problem while trying to log all pod names on the node") - } - - if nthConfig.CordonOnly || (!nthConfig.EnableSQSTerminationDraining && drainEvent.IsRebalanceRecommendation() && !nthConfig.EnableRebalanceDraining) { - err = cordonNode(node, nodeName, drainEvent, metrics, recorder) - } else { - err = cordonAndDrainNode(node, nodeName, drainEvent, metrics, recorder, nthConfig.EnableSQSTerminationDraining) - } - - if nthConfig.WebhookURL != "" { - webhook.Post(nodeMetadata, drainEvent, nthConfig) - } - - if err != nil { - interruptionEventStore.CancelInterruptionEvent(drainEvent.EventID) - } else { - interruptionEventStore.MarkAllAsProcessed(nodeName) - } - - if (err == nil || (!nodeFound && nthConfig.DeleteSqsMsgIfNodeNotFound)) && drainEvent.PostDrainTask != nil { - runPostDrainTask(node, nodeName, drainEvent, metrics, recorder) - } <-interruptionEventStore.Workers } -func runPreDrainTask(node node.Node, nodeName string, drainEvent *monitor.InterruptionEvent, metrics observability.Metrics, recorder observability.K8sEventRecorder) { - err := drainEvent.PreDrainTask(*drainEvent, node) - if err != nil { - log.Err(err).Msg("There was a problem executing the pre-drain task") - recorder.Emit(nodeName, observability.Warning, observability.PreDrainErrReason, observability.PreDrainErrMsgFmt, err.Error()) - } else { - recorder.Emit(nodeName, observability.Normal, observability.PreDrainReason, observability.PreDrainMsg) - } - metrics.NodeActionsInc("pre-drain", nodeName, drainEvent.EventID, err) -} - -func cordonNode(node node.Node, nodeName string, drainEvent *monitor.InterruptionEvent, metrics observability.Metrics, recorder observability.K8sEventRecorder) error { - err := node.Cordon(nodeName, drainEvent.Description) - if err != nil { - if errors.IsNotFound(err) { - log.Err(err).Msgf("node '%s' 
not found in the cluster", nodeName) - } else { - log.Err(err).Msg("There was a problem while trying to cordon the node") - recorder.Emit(nodeName, observability.Warning, observability.CordonErrReason, observability.CordonErrMsgFmt, err.Error()) - } - return err - } else { - log.Info().Str("node_name", nodeName).Str("reason", drainEvent.Description).Msg("Node successfully cordoned") - metrics.NodeActionsInc("cordon", nodeName, drainEvent.EventID, err) - recorder.Emit(nodeName, observability.Normal, observability.CordonReason, observability.CordonMsg) - } - return nil -} - -func cordonAndDrainNode(node node.Node, nodeName string, drainEvent *monitor.InterruptionEvent, metrics observability.Metrics, recorder observability.K8sEventRecorder, sqsTerminationDraining bool) error { - err := node.CordonAndDrain(nodeName, drainEvent.Description, recorder.EventRecorder) - if err != nil { - if errors.IsNotFound(err) { - log.Err(err).Msgf("node '%s' not found in the cluster", nodeName) - } else { - log.Err(err).Msg("There was a problem while trying to cordon and drain the node") - metrics.NodeActionsInc("cordon-and-drain", nodeName, drainEvent.EventID, err) - recorder.Emit(nodeName, observability.Warning, observability.CordonAndDrainErrReason, observability.CordonAndDrainErrMsgFmt, err.Error()) - } - return err - } else { - log.Info().Str("node_name", nodeName).Str("reason", drainEvent.Description).Msg("Node successfully cordoned and drained") - metrics.NodeActionsInc("cordon-and-drain", nodeName, drainEvent.EventID, err) - recorder.Emit(nodeName, observability.Normal, observability.CordonAndDrainReason, observability.CordonAndDrainMsg) - } - return nil -} - -func runPostDrainTask(node node.Node, nodeName string, drainEvent *monitor.InterruptionEvent, metrics observability.Metrics, recorder observability.K8sEventRecorder) { - err := drainEvent.PostDrainTask(*drainEvent, node) - if err != nil { - log.Err(err).Msg("There was a problem executing the post-drain task") - recorder.Emit(nodeName, observability.Warning, observability.PostDrainErrReason, observability.PostDrainErrMsgFmt, err.Error()) - } else { - recorder.Emit(nodeName, observability.Normal, observability.PostDrainReason, observability.PostDrainMsg) - } - metrics.NodeActionsInc("post-drain", nodeName, drainEvent.EventID, err) -} - func getRegionFromQueueURL(queueURL string) string { for _, partition := range endpoints.DefaultPartitions() { for regionID := range partition.Regions() { diff --git a/pkg/interruptionevent/asg/launch/handler.go b/pkg/interruptionevent/asg/launch/handler.go new file mode 100644 index 00000000..00df82c4 --- /dev/null +++ b/pkg/interruptionevent/asg/launch/handler.go @@ -0,0 +1,155 @@ +// Copyright 2016-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. 
See the License for the specific language governing +// permissions and limitations under the License + +package launch + +import ( + "context" + "fmt" + "strings" + + "github.com/aws/aws-node-termination-handler/pkg/config" + "github.com/aws/aws-node-termination-handler/pkg/interruptionevent/internal/common" + "github.com/aws/aws-node-termination-handler/pkg/interruptioneventstore" + "github.com/aws/aws-node-termination-handler/pkg/monitor" + "github.com/aws/aws-node-termination-handler/pkg/node" + "github.com/aws/aws-node-termination-handler/pkg/observability" + "github.com/rs/zerolog/log" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/labels" + "k8s.io/apimachinery/pkg/selection" + "k8s.io/client-go/kubernetes" +) + +const instanceIDLabel = "alpha.eksctl.io/instance-id" + +type Handler struct { + commonHandler *common.Handler + clientset *kubernetes.Clientset +} + +func New(interruptionEventStore *interruptioneventstore.Store, node node.Node, nthConfig config.Config, metrics observability.Metrics, recorder observability.K8sEventRecorder, clientset *kubernetes.Clientset) *Handler { + commonHandler := &common.Handler{ + InterruptionEventStore: interruptionEventStore, + Node: node, + NthConfig: nthConfig, + Metrics: metrics, + Recorder: recorder, + } + + return &Handler{ + commonHandler: commonHandler, + clientset: clientset, + } +} + +func (h *Handler) HandleEvent(drainEvent *monitor.InterruptionEvent) error { + if drainEvent == nil { + return fmt.Errorf("drainEvent is nil") + } + + if !common.IsAllowedKind(drainEvent.Kind, monitor.ASGLaunchLifecycleKind) { + return nil + } + + isNodeReady, err := h.isNodeReady(drainEvent.InstanceID) + if err != nil { + h.commonHandler.InterruptionEventStore.CancelInterruptionEvent(drainEvent.EventID) + return fmt.Errorf("check if node (instanceID=%s) is present and ready: %w", drainEvent.InstanceID, err) + } + if !isNodeReady { + h.commonHandler.InterruptionEventStore.CancelInterruptionEvent(drainEvent.EventID) + return nil + } + + nodeName, err := h.commonHandler.GetNodeName(drainEvent) + if err != nil { + return fmt.Errorf("get node name for instanceID=%s: %w", drainEvent.InstanceID, err) + } + + if drainEvent.PostDrainTask != nil { + h.commonHandler.RunPostDrainTask(nodeName, drainEvent) + } + return nil +} + +func (h *Handler) isNodeReady(instanceID string) (bool, error) { + nodes, err := h.getNodesWithInstanceID(instanceID) + if err != nil { + return false, fmt.Errorf("find node(s) with instanceId=%s: %w", instanceID, err) + } + + if len(nodes) == 0 { + log.Info().Str("instanceID", instanceID).Msg("EC2 instance not found") + return false, nil + } + + for _, node := range nodes { + conditions := node.Status.Conditions + for _, condition := range conditions { + if condition.Type == "Ready" && condition.Status != "True" { + log.Info().Str("instanceID", instanceID).Msg("EC2 instance found, but not ready") + return false, nil + } + } + } + log.Info().Str("instanceID", instanceID).Msg("EC2 instance is found and ready") + return true, nil +} + +// Gets Nodes connected to K8s cluster +func (h *Handler) getNodesWithInstanceID(instanceID string) ([]v1.Node, error) { + nodes, err := h.getNodesWithInstanceFromLabel(instanceID) + if err != nil { + return nil, err + } + if len(nodes) != 0 { + return nodes, nil + } + + nodes, err = h.getNodesWithInstanceFromProviderID(instanceID) + if err != nil { + return nil, err + } + return nodes, nil +} + +func (h *Handler) getNodesWithInstanceFromLabel(instanceID string) 
([]v1.Node, error) { + instanceIDReq, err := labels.NewRequirement(instanceIDLabel, selection.Equals, []string{instanceID}) + if err != nil { + return nil, fmt.Errorf("construct node search requirement %s=%s: %w", instanceIDLabel, instanceID, err) + } + selector := labels.NewSelector().Add(*instanceIDReq) + nodeList, err := h.clientset.CoreV1().Nodes().List(context.TODO(), metav1.ListOptions{LabelSelector: selector.String()}) + if err != nil { + return nil, fmt.Errorf("list nodes using selector %q: %w", selector.String(), err) + } + return nodeList.Items, nil +} + +func (h *Handler) getNodesWithInstanceFromProviderID(instanceID string) ([]v1.Node, error) { + nodeList, err := h.clientset.CoreV1().Nodes().List(context.TODO(), metav1.ListOptions{}) + if err != nil { + return nil, fmt.Errorf("list all nodes: %w", err) + } + + var filteredNodes []v1.Node + for _, node := range nodeList.Items { + if !strings.Contains(node.Spec.ProviderID, instanceID) { + continue + } + filteredNodes = append(filteredNodes, node) + } + return filteredNodes, nil +} diff --git a/pkg/interruptionevent/draincordon/handler.go b/pkg/interruptionevent/draincordon/handler.go new file mode 100644 index 00000000..0360a31c --- /dev/null +++ b/pkg/interruptionevent/draincordon/handler.go @@ -0,0 +1,160 @@ +// Copyright 2016-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License + +package draincordon + +import ( + "fmt" + + "github.com/aws/aws-node-termination-handler/pkg/config" + "github.com/aws/aws-node-termination-handler/pkg/ec2metadata" + "github.com/aws/aws-node-termination-handler/pkg/interruptionevent/internal/common" + "github.com/aws/aws-node-termination-handler/pkg/interruptioneventstore" + "github.com/aws/aws-node-termination-handler/pkg/monitor" + "github.com/aws/aws-node-termination-handler/pkg/node" + "github.com/aws/aws-node-termination-handler/pkg/observability" + "github.com/aws/aws-node-termination-handler/pkg/webhook" + "github.com/rs/zerolog/log" + "k8s.io/apimachinery/pkg/api/errors" +) + +var allowedKinds = []string{ + monitor.ASGLifecycleKind, + monitor.RebalanceRecommendationKind, + monitor.SQSTerminateKind, + monitor.ScheduledEventKind, + monitor.SpotITNKind, + monitor.StateChangeKind, +} + +type Handler struct { + commonHandler *common.Handler + nodeMetadata ec2metadata.NodeMetadata +} + +func New(interruptionEventStore *interruptioneventstore.Store, node node.Node, nthConfig config.Config, nodeMetadata ec2metadata.NodeMetadata, metrics observability.Metrics, recorder observability.K8sEventRecorder) *Handler { + commonHandler := &common.Handler{ + InterruptionEventStore: interruptionEventStore, + Node: node, + NthConfig: nthConfig, + Metrics: metrics, + Recorder: recorder, + } + + return &Handler{ + commonHandler: commonHandler, + nodeMetadata: nodeMetadata, + } +} + +func (h *Handler) HandleEvent(drainEvent *monitor.InterruptionEvent) error { + if !common.IsAllowedKind(drainEvent.Kind, allowedKinds...) 
{ + return nil + } + + nodeFound := true + nodeName, err := h.commonHandler.GetNodeName(drainEvent) + if err != nil { + return fmt.Errorf("get node name for instanceID=%s: %w", drainEvent.InstanceID, err) + } + + nodeLabels, err := h.commonHandler.Node.GetNodeLabels(nodeName) + if err != nil { + log.Warn(). + Err(err). + Interface("fallbackNodeLabels", drainEvent.NodeLabels). + Str("nodeName", nodeName). + Msg("Failed to get node labels. Proceeding with fallback labels") + nodeFound = false + } else { + drainEvent.NodeLabels = nodeLabels + } + + if drainEvent.PreDrainTask != nil { + h.commonHandler.RunPreDrainTask(nodeName, drainEvent) + } + + podNameList, err := h.commonHandler.Node.FetchPodNameList(nodeName) + if err != nil { + log.Warn(). + Err(err). + Strs("fallbackPodNames", podNameList). + Str("nodeName", nodeName). + Msg("Failed to fetch pod names. Proceeding with fallback pod names") + } else { + drainEvent.Pods = podNameList + } + + err = h.commonHandler.Node.LogPods(podNameList, nodeName) + if err != nil { + log.Warn().Err(err).Str("nodeName", nodeName).Msg("Failed to log pods") + } + + if h.commonHandler.NthConfig.CordonOnly || (!h.commonHandler.NthConfig.EnableSQSTerminationDraining && drainEvent.IsRebalanceRecommendation() && !h.commonHandler.NthConfig.EnableRebalanceDraining) { + err = h.cordonNode(nodeName, drainEvent) + } else { + err = h.cordonAndDrainNode(nodeName, drainEvent) + } + + if h.commonHandler.NthConfig.WebhookURL != "" { + webhook.Post(h.nodeMetadata, drainEvent, h.commonHandler.NthConfig) + } + + if err != nil { + h.commonHandler.InterruptionEventStore.CancelInterruptionEvent(drainEvent.EventID) + } else { + h.commonHandler.InterruptionEventStore.MarkAllAsProcessed(nodeName) + } + + if (err == nil || (!nodeFound && h.commonHandler.NthConfig.DeleteSqsMsgIfNodeNotFound)) && drainEvent.PostDrainTask != nil { + h.commonHandler.RunPostDrainTask(nodeName, drainEvent) + } + return nil +} + +func (h *Handler) cordonNode(nodeName string, drainEvent *monitor.InterruptionEvent) error { + err := h.commonHandler.Node.Cordon(nodeName, drainEvent.Description) + if err != nil { + if errors.IsNotFound(err) { + log.Err(err).Msgf("node '%s' not found in the cluster", nodeName) + } else { + log.Err(err).Msg("There was a problem while trying to cordon the node") + h.commonHandler.Recorder.Emit(nodeName, observability.Warning, observability.CordonErrReason, observability.CordonErrMsgFmt, err.Error()) + } + return err + } else { + log.Info().Str("node_name", nodeName).Str("reason", drainEvent.Description).Msg("Node successfully cordoned") + h.commonHandler.Metrics.NodeActionsInc("cordon", nodeName, drainEvent.EventID, err) + h.commonHandler.Recorder.Emit(nodeName, observability.Normal, observability.CordonReason, observability.CordonMsg) + } + return nil +} + +func (h *Handler) cordonAndDrainNode(nodeName string, drainEvent *monitor.InterruptionEvent) error { + err := h.commonHandler.Node.CordonAndDrain(nodeName, drainEvent.Description, h.commonHandler.Recorder.EventRecorder) + if err != nil { + if errors.IsNotFound(err) { + log.Err(err).Msgf("node '%s' not found in the cluster", nodeName) + } else { + log.Err(err).Msg("There was a problem while trying to cordon and drain the node") + h.commonHandler.Metrics.NodeActionsInc("cordon-and-drain", nodeName, drainEvent.EventID, err) + h.commonHandler.Recorder.Emit(nodeName, observability.Warning, observability.CordonAndDrainErrReason, observability.CordonAndDrainErrMsgFmt, err.Error()) + } + return err + } else { + 
log.Info().Str("node_name", nodeName).Str("reason", drainEvent.Description).Msg("Node successfully cordoned and drained") + h.commonHandler.Metrics.NodeActionsInc("cordon-and-drain", nodeName, drainEvent.EventID, err) + h.commonHandler.Recorder.Emit(nodeName, observability.Normal, observability.CordonAndDrainReason, observability.CordonAndDrainMsg) + } + return nil +} diff --git a/pkg/interruptionevent/internal/common/handler.go b/pkg/interruptionevent/internal/common/handler.go new file mode 100644 index 00000000..0c58366a --- /dev/null +++ b/pkg/interruptionevent/internal/common/handler.go @@ -0,0 +1,76 @@ +// Copyright 2016-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License + +package common + +import ( + "fmt" + + "github.com/aws/aws-node-termination-handler/pkg/config" + "github.com/aws/aws-node-termination-handler/pkg/interruptioneventstore" + "github.com/aws/aws-node-termination-handler/pkg/monitor" + "github.com/aws/aws-node-termination-handler/pkg/node" + "github.com/aws/aws-node-termination-handler/pkg/observability" + "github.com/rs/zerolog/log" +) + +type Handler struct { + InterruptionEventStore *interruptioneventstore.Store + Node node.Node + NthConfig config.Config + Metrics observability.Metrics + Recorder observability.K8sEventRecorder +} + +func (h *Handler) GetNodeName(drainEvent *monitor.InterruptionEvent) (string, error) { + if !h.NthConfig.UseProviderId { + return drainEvent.NodeName, nil + } + + nodeName, err := h.Node.GetNodeNameFromProviderID(drainEvent.ProviderID) + if err != nil { + return "", fmt.Errorf("parse node name from providerID=%q: %w", drainEvent.ProviderID, err) + } + return nodeName, nil +} + +func (h *Handler) RunPreDrainTask(nodeName string, drainEvent *monitor.InterruptionEvent) { + err := drainEvent.PreDrainTask(*drainEvent, h.Node) + if err != nil { + log.Err(err).Msg("There was a problem executing the pre-drain task") + h.Recorder.Emit(nodeName, observability.Warning, observability.PreDrainErrReason, observability.PreDrainErrMsgFmt, err.Error()) + } else { + h.Recorder.Emit(nodeName, observability.Normal, observability.PreDrainReason, observability.PreDrainMsg) + } + h.Metrics.NodeActionsInc("pre-drain", nodeName, drainEvent.EventID, err) +} + +func (h *Handler) RunPostDrainTask(nodeName string, drainEvent *monitor.InterruptionEvent) { + err := drainEvent.PostDrainTask(*drainEvent, h.Node) + if err != nil { + log.Err(err).Msg("There was a problem executing the post-drain task") + h.Recorder.Emit(nodeName, observability.Warning, observability.PostDrainErrReason, observability.PostDrainErrMsgFmt, err.Error()) + } else { + h.Recorder.Emit(nodeName, observability.Normal, observability.PostDrainReason, observability.PostDrainMsg) + } + h.Metrics.NodeActionsInc("post-drain", nodeName, drainEvent.EventID, err) +} + +func IsAllowedKind(kind string, allowedKinds ...string) bool { + for _, allowedKind := range allowedKinds { + if kind == allowedKind { + return true + } + } + return false +} diff --git a/pkg/monitor/sqsevent/asg-lifecycle-event.go 
b/pkg/monitor/sqsevent/asg-lifecycle-event.go index 5c088030..c1262519 100644 --- a/pkg/monitor/sqsevent/asg-lifecycle-event.go +++ b/pkg/monitor/sqsevent/asg-lifecycle-event.go @@ -20,7 +20,6 @@ import ( "github.com/aws/aws-node-termination-handler/pkg/monitor" "github.com/aws/aws-node-termination-handler/pkg/node" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/service/autoscaling" "github.com/aws/aws-sdk-go/service/sqs" "github.com/rs/zerolog/log" @@ -50,6 +49,10 @@ import ( const TEST_NOTIFICATION = "autoscaling:TEST_NOTIFICATION" +type LifecycleDetailMessage struct { + Message interface{} `json:"Message"` +} + // LifecycleDetail provides the ASG lifecycle event details type LifecycleDetail struct { LifecycleActionToken string `json:"LifecycleActionToken"` @@ -92,26 +95,7 @@ func (m SQSMonitor) asgTerminationToInterruptionEvent(event *EventBridgeEvent, m } interruptionEvent.PostDrainTask = func(interruptionEvent monitor.InterruptionEvent, _ node.Node) error { - _, err := m.completeLifecycleAction(&autoscaling.CompleteLifecycleActionInput{ - AutoScalingGroupName: &lifecycleDetail.AutoScalingGroupName, - LifecycleActionResult: aws.String("CONTINUE"), - LifecycleHookName: &lifecycleDetail.LifecycleHookName, - LifecycleActionToken: &lifecycleDetail.LifecycleActionToken, - InstanceId: &lifecycleDetail.EC2InstanceID, - }) - if err != nil { - if aerr, ok := err.(awserr.RequestFailure); ok && aerr.StatusCode() != 400 { - return err - } - } - log.Info().Msgf("Completed ASG Lifecycle Hook (%s) for instance %s", - lifecycleDetail.LifecycleHookName, - lifecycleDetail.EC2InstanceID) - errs := m.deleteMessages([]*sqs.Message{message}) - if errs != nil { - return errs[0] - } - return nil + return m.deleteMessage(message) } interruptionEvent.PreDrainTask = func(interruptionEvent monitor.InterruptionEvent, n node.Node) error { @@ -124,3 +108,72 @@ func (m SQSMonitor) asgTerminationToInterruptionEvent(event *EventBridgeEvent, m return &interruptionEvent, nil } + +func (m SQSMonitor) deleteMessage(message *sqs.Message) error { + errs := m.deleteMessages([]*sqs.Message{message}) + if errs != nil { + return errs[0] + } + return nil +} + +// Continues the lifecycle hook thereby indicating a successful action occured +func (m SQSMonitor) continueLifecycleAction(lifecycleDetail *LifecycleDetail) (*autoscaling.CompleteLifecycleActionOutput, error) { + return m.completeLifecycleAction(&autoscaling.CompleteLifecycleActionInput{ + AutoScalingGroupName: &lifecycleDetail.AutoScalingGroupName, + LifecycleActionResult: aws.String("CONTINUE"), + LifecycleHookName: &lifecycleDetail.LifecycleHookName, + LifecycleActionToken: &lifecycleDetail.LifecycleActionToken, + InstanceId: &lifecycleDetail.EC2InstanceID, + }) +} + +// Completes the ASG launch lifecycle hook if the new EC2 instance launched by ASG is Ready in the cluster +func (m SQSMonitor) createAsgInstanceLaunchEvent(event *EventBridgeEvent, message *sqs.Message) (*monitor.InterruptionEvent, error) { + if event == nil { + return nil, fmt.Errorf("event is nil") + } + + if message == nil { + return nil, fmt.Errorf("message is nil") + } + + lifecycleDetail := &LifecycleDetail{} + err := json.Unmarshal(event.Detail, lifecycleDetail) + if err != nil { + return nil, fmt.Errorf("unmarshaling message, %s, from ASG launch lifecycle event: %w", *message.MessageId, err) + } + + if lifecycleDetail.Event == TEST_NOTIFICATION || lifecycleDetail.LifecycleTransition == TEST_NOTIFICATION { + return nil, skip{fmt.Errorf("message 
is an ASG test notification")} + } + + nodeInfo, err := m.getNodeInfo(lifecycleDetail.EC2InstanceID) + if err != nil { + return nil, err + } + + interruptionEvent := monitor.InterruptionEvent{ + EventID: fmt.Sprintf("asg-lifecycle-term-%x", event.ID), + Kind: monitor.ASGLaunchLifecycleKind, + Monitor: SQSMonitorKind, + AutoScalingGroupName: lifecycleDetail.AutoScalingGroupName, + StartTime: event.getTime(), + NodeName: nodeInfo.Name, + IsManaged: nodeInfo.IsManaged, + InstanceID: lifecycleDetail.EC2InstanceID, + ProviderID: nodeInfo.ProviderID, + Description: fmt.Sprintf("ASG Lifecycle Launch event received. Instance will be interrupted at %s \n", event.getTime()), + } + + interruptionEvent.PostDrainTask = func(interruptionEvent monitor.InterruptionEvent, _ node.Node) error { + _, err = m.continueLifecycleAction(lifecycleDetail) + if err != nil { + return fmt.Errorf("continuing ASG launch lifecycle: %w", err) + } + log.Info().Str("lifecycleHookName", lifecycleDetail.LifecycleHookName).Str("instanceID", lifecycleDetail.EC2InstanceID).Msg("Completed ASG Lifecycle Hook") + return m.deleteMessage(message) + } + + return &interruptionEvent, err +} diff --git a/pkg/monitor/sqsevent/sqs-monitor.go b/pkg/monitor/sqsevent/sqs-monitor.go index e028506c..7dea9308 100644 --- a/pkg/monitor/sqsevent/sqs-monitor.go +++ b/pkg/monitor/sqsevent/sqs-monitor.go @@ -38,7 +38,9 @@ const ( // SQSMonitorKind is a const to define this monitor kind SQSMonitorKind = "SQS_MONITOR" // ASGTagName is the name of the instance tag whose value is the AutoScaling group name - ASGTagName = "aws:autoscaling:groupName" + ASGTagName = "aws:autoscaling:groupName" + ASGTerminatingLifecycleTransition = "autoscaling:EC2_INSTANCE_TERMINATING" + ASGLaunchingLifecycleTransition = "autoscaling:EC2_INSTANCE_LAUNCHING" ) // SQSMonitor is a struct definition that knows how to process events from Amazon EventBridge @@ -60,6 +62,7 @@ type InterruptionEventWrapper struct { Err error } +// Used to skip processing an error, but acknowledge an error occured during a termination event type skip struct { err error } @@ -130,16 +133,41 @@ func (m SQSMonitor) processSQSMessage(message *sqs.Message) (*EventBridgeEvent, return &event, err } +func parseLifecycleEvent(message string) (LifecycleDetail, error) { + lifecycleEventMessage := LifecycleDetailMessage{} + lifecycleEvent := LifecycleDetail{} + err := json.Unmarshal([]byte(message), &lifecycleEventMessage) + if err != nil { + return lifecycleEvent, fmt.Errorf("unmarshalling SQS message: %w", err) + } + // Converts escaped JSON object to string, to lifecycle event + if lifecycleEventMessage.Message != nil { + err = json.Unmarshal([]byte(fmt.Sprintf("%v", lifecycleEventMessage.Message)), &lifecycleEvent) + if err != nil { + err = fmt.Errorf("unmarshalling message body from '.Message': %w", err) + } + } else { + err = json.Unmarshal([]byte(fmt.Sprintf("%v", message)), &lifecycleEvent) + if err != nil { + err = fmt.Errorf("unmarshalling message body: %w", err) + } + } + return lifecycleEvent, err +} + // processLifecycleEventFromASG checks for a Lifecycle event from ASG to SQS, and wraps it in an EventBridgeEvent func (m SQSMonitor) processLifecycleEventFromASG(message *sqs.Message) (EventBridgeEvent, error) { + log.Debug().Interface("message", message).Msg("processing lifecycle event from ASG") eventBridgeEvent := EventBridgeEvent{} - lifecycleEvent := LifecycleDetail{} - err := json.Unmarshal([]byte(*message.Body), &lifecycleEvent) + + if message == nil { + return eventBridgeEvent, fmt.Errorf("ASG 
event message is nil") + } + lifecycleEvent, err := parseLifecycleEvent(*message.Body) switch { case err != nil: - log.Err(err).Msg("only lifecycle events from ASG to SQS are supported outside EventBridge") - return eventBridgeEvent, err + return eventBridgeEvent, fmt.Errorf("parsing lifecycle event messsage from ASG: %w", err) case lifecycleEvent.Event == TEST_NOTIFICATION || lifecycleEvent.LifecycleTransition == TEST_NOTIFICATION: err := fmt.Errorf("message is a test notification") @@ -148,18 +176,15 @@ func (m SQSMonitor) processLifecycleEventFromASG(message *sqs.Message) (EventBri } return eventBridgeEvent, skip{err} - case lifecycleEvent.LifecycleTransition != "autoscaling:EC2_INSTANCE_TERMINATING": - log.Err(err).Msg("only lifecycle termination events from ASG to SQS are supported outside EventBridge") - err = fmt.Errorf("unsupported message type (%s)", message.String()) - return eventBridgeEvent, err + case lifecycleEvent.LifecycleTransition != ASGTerminatingLifecycleTransition && + lifecycleEvent.LifecycleTransition != ASGLaunchingLifecycleTransition: + return eventBridgeEvent, fmt.Errorf("lifecycle transition must be %s or %s. Got %s", ASGTerminatingLifecycleTransition, ASGLaunchingLifecycleTransition, lifecycleEvent.LifecycleTransition) } eventBridgeEvent.Source = "aws.autoscaling" eventBridgeEvent.Time = lifecycleEvent.Time eventBridgeEvent.ID = lifecycleEvent.RequestID eventBridgeEvent.Detail, err = json.Marshal(lifecycleEvent) - - log.Debug().Msg("processing lifecycle termination event from ASG") return eventBridgeEvent, err } @@ -169,10 +194,29 @@ func (m SQSMonitor) processEventBridgeEvent(eventBridgeEvent *EventBridgeEvent, interruptionEvent := &monitor.InterruptionEvent{} var err error + if eventBridgeEvent == nil { + return append(interruptionEventWrappers, InterruptionEventWrapper{nil, fmt.Errorf("eventBridgeEvent is nil")}) + } + if message == nil { + return append(interruptionEventWrappers, InterruptionEventWrapper{nil, fmt.Errorf("message is nil")}) + } + switch eventBridgeEvent.Source { case "aws.autoscaling": - interruptionEvent, err = m.asgTerminationToInterruptionEvent(eventBridgeEvent, message) - return append(interruptionEventWrappers, InterruptionEventWrapper{interruptionEvent, err}) + lifecycleEvent := LifecycleDetail{} + err = json.Unmarshal([]byte(eventBridgeEvent.Detail), &lifecycleEvent) + if err != nil { + interruptionEvent, err = nil, fmt.Errorf("unmarshaling message, %s, from ASG lifecycle event: %w", *message.MessageId, err) + interruptionEventWrappers = append(interruptionEventWrappers, InterruptionEventWrapper{interruptionEvent, err}) + } + if lifecycleEvent.LifecycleTransition == ASGLaunchingLifecycleTransition { + interruptionEvent, err = m.createAsgInstanceLaunchEvent(eventBridgeEvent, message) + interruptionEventWrappers = append(interruptionEventWrappers, InterruptionEventWrapper{interruptionEvent, err}) + } else if lifecycleEvent.LifecycleTransition == ASGTerminatingLifecycleTransition { + interruptionEvent, err = m.asgTerminationToInterruptionEvent(eventBridgeEvent, message) + interruptionEventWrappers = append(interruptionEventWrappers, InterruptionEventWrapper{interruptionEvent, err}) + } + return interruptionEventWrappers case "aws.ec2": if eventBridgeEvent.DetailType == "EC2 Instance State-change Notification" { diff --git a/pkg/monitor/sqsevent/sqs-monitor_test.go b/pkg/monitor/sqsevent/sqs-monitor_test.go index 61199ed2..8e827377 100644 --- a/pkg/monitor/sqsevent/sqs-monitor_test.go +++ b/pkg/monitor/sqsevent/sqs-monitor_test.go @@ -67,6 
+67,26 @@ var asgLifecycleEvent = sqsevent.EventBridgeEvent{ }`), } +var asgLaunchLifecycleEvent = sqsevent.EventBridgeEvent{ + Version: "0", + ID: "83c632dd-0145-1ab0-ae93-a756ebf429b5", + DetailType: "EC2 Instance-launch Lifecycle Action", + Source: "aws.autoscaling", + Account: "123456789012", + Time: "2020-07-01T22:30:58Z", + Region: "us-east-1", + Resources: []string{ + "arn:aws:autoscaling:us-east-1:123456789012:autoScalingGroup:c4c64181-52c1-dd3f-20bb-f4a0965a09db:autoScalingGroupName/nth-test1", + }, + Detail: []byte(`{ + "LifecycleActionToken": "524632c5-3333-d52d-3992-d9633ec24ed7", + "AutoScalingGroupName": "nth-test1", + "LifecycleHookName": "node-termination-handler-launch", + "EC2InstanceId": "i-0a68bf5ef13e21b52", + "LifecycleTransition": "autoscaling:EC2_INSTANCE_LAUNCHING" + }`), +} + var asgLifecycleEventFromSQS = sqsevent.LifecycleDetail{ LifecycleHookName: "test-nth-asg-to-sqs", RequestID: "3775fac9-93c3-7ead-8713-159816566000", @@ -352,7 +372,7 @@ func TestMonitor_DrainTasks(t *testing.T) { } func TestMonitor_DrainTasks_Delay(t *testing.T) { - msg, err := getSQSMessageFromEvent(asgLifecycleEvent) + msg, err := getSQSMessageFromEvent(asgLaunchLifecycleEvent) h.Ok(t, err) sqsMock := h.MockedSQS{ @@ -384,13 +404,12 @@ func TestMonitor_DrainTasks_Delay(t *testing.T) { err = sqsMonitor.Monitor() h.Ok(t, err) - t.Run(asgLifecycleEvent.DetailType, func(st *testing.T) { + t.Run(asgLaunchLifecycleEvent.DetailType, func(st *testing.T) { result := <-drainChan - h.Equals(st, monitor.ASGLifecycleKind, result.Kind) + h.Equals(st, monitor.ASGLaunchLifecycleKind, result.Kind) h.Equals(st, sqsevent.SQSMonitorKind, result.Monitor) h.Equals(st, result.NodeName, dnsNodeName) h.Assert(st, result.PostDrainTask != nil, "PostDrainTask should have been set") - h.Assert(st, result.PreDrainTask != nil, "PreDrainTask should have been set") err := result.PostDrainTask(result, node.Node{}) h.Ok(st, err) h.Assert(st, hookCalled, "BeforeCompleteLifecycleAction hook not called") @@ -457,7 +476,7 @@ func TestMonitor_DrainTasks_Errors(t *testing.T) { } func TestMonitor_DrainTasksASGFailure(t *testing.T) { - msg, err := getSQSMessageFromEvent(asgLifecycleEvent) + msg, err := getSQSMessageFromEvent(asgLaunchLifecycleEvent) h.Ok(t, err) messages := []*sqs.Message{ &msg, @@ -492,11 +511,10 @@ func TestMonitor_DrainTasksASGFailure(t *testing.T) { select { case result := <-drainChan: - h.Equals(t, monitor.ASGLifecycleKind, result.Kind) + h.Equals(t, monitor.ASGLaunchLifecycleKind, result.Kind) h.Equals(t, sqsevent.SQSMonitorKind, result.Monitor) h.Equals(t, result.NodeName, dnsNodeName) h.Assert(t, result.PostDrainTask != nil, "PostDrainTask should have been set") - h.Assert(t, result.PreDrainTask != nil, "PreDrainTask should have been set") err = result.PostDrainTask(result, node.Node{}) h.Nok(t, err) default: diff --git a/pkg/monitor/types.go b/pkg/monitor/types.go index c3c587d2..93d56625 100644 --- a/pkg/monitor/types.go +++ b/pkg/monitor/types.go @@ -31,6 +31,8 @@ const ( StateChangeKind = "STATE_CHANGE" // ASGLifecycleKind is a const to define an ASG Lifecycle kind of interruption event ASGLifecycleKind = "ASG_LIFECYCLE" + // ASGLifecycleKind is a const to define an ASG Launch Lifecycle kind of interruption event + ASGLaunchLifecycleKind = "ASG_LAUNCH_LIFECYCLE" // SQSTerminateKind is a const to define an SQS termination kind of interruption event SQSTerminateKind = "SQS_TERMINATE" ) diff --git a/pkg/node/node.go b/pkg/node/node.go index ffd04bb5..2b62e768 100644 --- a/pkg/node/node.go +++ 
b/pkg/node/node.go @@ -31,7 +31,6 @@ import ( "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/kubernetes" - "k8s.io/client-go/rest" "k8s.io/kubectl/pkg/drain" ) @@ -84,8 +83,8 @@ type Node struct { } // New will construct a node struct to perform various node function through the kubernetes api server -func New(nthConfig config.Config) (*Node, error) { - drainHelper, err := getDrainHelper(nthConfig) +func New(nthConfig config.Config, clientset *kubernetes.Clientset) (*Node, error) { + drainHelper, err := getDrainHelper(nthConfig, clientset) if err != nil { return nil, err } @@ -634,7 +633,7 @@ func (n Node) fetchAllPods(nodeName string) (*corev1.PodList, error) { }) } -func getDrainHelper(nthConfig config.Config) (*drain.Helper, error) { +func getDrainHelper(nthConfig config.Config, clientset *kubernetes.Clientset) (*drain.Helper, error) { drainHelper := &drain.Helper{ Ctx: context.TODO(), Client: &kubernetes.Clientset{}, @@ -652,17 +651,7 @@ func getDrainHelper(nthConfig config.Config) (*drain.Helper, error) { return drainHelper, nil } - clusterConfig, err := rest.InClusterConfig() - if err != nil { - return nil, err - } - // creates the clientset - clientset, err := kubernetes.NewForConfig(clusterConfig) - if err != nil { - return nil, err - } drainHelper.Client = clientset - return drainHelper, nil } diff --git a/pkg/node/node_test.go b/pkg/node/node_test.go index 93496872..945b98af 100644 --- a/pkg/node/node_test.go +++ b/pkg/node/node_test.go @@ -63,8 +63,13 @@ func getNode(t *testing.T, drainHelper *drain.Helper) *node.Node { return tNode } +func newNode(nthConfig config.Config, client *fake.Clientset) (*node.Node, error) { + drainHelper := getDrainHelper(client) + return node.NewWithValues(nthConfig, drainHelper, uptime.Uptime) +} + func TestDryRun(t *testing.T) { - tNode, err := node.New(config.Config{DryRun: true}) + tNode, err := newNode(config.Config{DryRun: true}, fake.NewSimpleClientset()) h.Ok(t, err) fakeRecorder := record.NewFakeRecorder(recorderBufferSize) @@ -103,7 +108,8 @@ func TestDryRun(t *testing.T) { } func TestNewFailure(t *testing.T) { - _, err := node.New(config.Config{}) + client := fake.NewSimpleClientset() + _, err := newNode(config.Config{}, client) h.Assert(t, true, "Failed to return error when creating new Node.", err != nil) } diff --git a/pkg/observability/k8s-events.go b/pkg/observability/k8s-events.go index a3da3778..6b7caf25 100644 --- a/pkg/observability/k8s-events.go +++ b/pkg/observability/k8s-events.go @@ -27,7 +27,6 @@ import ( "k8s.io/client-go/kubernetes" "k8s.io/client-go/kubernetes/scheme" typedcorev1 "k8s.io/client-go/kubernetes/typed/core/v1" - "k8s.io/client-go/rest" "k8s.io/client-go/tools/record" ) @@ -80,7 +79,7 @@ type K8sEventRecorder struct { } // InitK8sEventRecorder creates a Kubernetes event recorder -func InitK8sEventRecorder(enabled bool, nodeName string, sqsMode bool, nodeMetadata ec2metadata.NodeMetadata, extraAnnotationsStr string) (K8sEventRecorder, error) { +func InitK8sEventRecorder(enabled bool, nodeName string, sqsMode bool, nodeMetadata ec2metadata.NodeMetadata, extraAnnotationsStr string, clientSet *kubernetes.Clientset) (K8sEventRecorder, error) { if !enabled { return K8sEventRecorder{}, nil } @@ -107,16 +106,6 @@ func InitK8sEventRecorder(enabled bool, nodeName string, sqsMode bool, nodeMetad } } - config, err := rest.InClusterConfig() - if err != nil { - return K8sEventRecorder{}, err - } - - clientSet, err := kubernetes.NewForConfig(config) - if err != nil { - return 
K8sEventRecorder{}, err - } - broadcaster := record.NewBroadcaster() broadcaster.StartRecordingToSink(&typedcorev1.EventSinkImpl{Interface: clientSet.CoreV1().Events("")}) diff --git a/test/e2e/asg-launch-lifecycle-sqs-test b/test/e2e/asg-launch-lifecycle-sqs-test new file mode 100755 index 00000000..42f21c98 --- /dev/null +++ b/test/e2e/asg-launch-lifecycle-sqs-test @@ -0,0 +1,509 @@ +#!/bin/bash +set -euo pipefail + +node_group_name="linux-ng" +sqs_queue_name="nth-sqs-test" +sns_topic_name="nth-sns-test" +node_policy_name="nth-test-node-policy" +auto_scaling_role_name="AWSServiceRoleForAutoScaling_nth-test" +fis_role_name="nth-test-fis-role" +fis_template_name="nth-fis-test" +fis_policy_arn="arn:aws:iam::aws:policy/service-role/AWSFaultInjectionSimulatorEC2Access" +SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" +account_id=$(aws sts get-caller-identity | jq -r '.Account') +nth_label="Use-Case=NTH" +heartbeat_timeout=$((3 * 60)) +LAUNCH_CHECK_CYCLES=15 +LAUNCH_ACTIVITY_CHECK_SLEEP=15 +LAUNCH_STATUS_CHECK_SLEEP=$((heartbeat_timeout / LAUNCH_CHECK_CYCLES)) + +##### JSON FILES ##### + +### SQS ### +sqs_queue_policy=$(cat < /tmp/sqs-subscription-policy.json +{ + "Policy": "{\"Version\":\"2012-10-17\",\"Id\":\"MyQueuePolicy\",\"Statement\":[{\"Effect\":\"Allow\",\"Principal\":{\"Service\":[\"events.amazonaws.com\",\"sqs.amazonaws.com\"]},\"Action\":\"sqs:SendMessage\",\"Resource\":\"arn:aws:sqs:${REGION}:${account_id}:${sqs_queue_name}\"},{\"Sid\":\"topic-subscription-arn:aws:sns:${REGION}:${account_id}:${sns_topic_name}\",\"Effect\":\"Allow\",\"Principal\":{\"AWS\":\"*\"},\"Action\":\"SQS:SendMessage\",\"Resource\":\"arn:aws:sqs:${REGION}:${account_id}:${sqs_queue_name}\",\"Condition\":{\"ArnLike\":{\"aws:SourceArn\":\"arn:aws:sns:${REGION}:${account_id}:${sns_topic_name}\"}}}]}" +} +EOF + +cat << EOF > /tmp/queue-attributes.json +{ + "MessageRetentionPeriod": "300", + "Policy": "$(echo $sqs_queue_policy | sed 's/\"/\\"/g' | tr -d -s '\n' " ")", + "SqsManagedSseEnabled": "true" +} +EOF + +### NODEGROUP ### +cat << EOF > /tmp/nth-nodegroup-policy.json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "autoscaling:CompleteLifecycleAction", + "autoscaling:DescribeAutoScalingInstances", + "autoscaling:DescribeTags", + "ec2:DescribeInstances", + "sqs:DeleteMessage", + "sqs:ReceiveMessage" + ], + "Resource": "*" + } + ] +} +EOF + +### FIS ### +cat << EOF > /tmp/fis-role-trust-policy.json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Principal": { + "Service": [ + "fis.amazonaws.com" + ] + }, + "Action": "sts:AssumeRole" + } + ] +} +EOF + +function create_FIS_Template_JSON { +cat << EOF > /tmp/fis-experiment-template.json +{ + "description": "Test Spot Instance interruptions", + "targets": { + "oneSpotInstance": { + "resourceType": "aws:ec2:spot-instance", + "resourceTags": { + "Name": "interruptMe" + }, + "filters": [ + { + "path": "State.Name", + "values": [ + "running" + ] + } + ], + "selectionMode": "COUNT(1)" + } + }, + "actions": { + "interruptSpotInstance": { + "actionId": "aws:ec2:send-spot-instance-interruptions", + "parameters": { + "durationBeforeInterruption": "PT2M" + }, + "targets": { + "SpotInstances": "oneSpotInstance" + } + } + }, + "stopConditions": [ + { + "source": "none" + } + ], + "roleArn": "$fis_role_arn", + "tags": { + "Name": "$fis_template_name" + } +} +EOF +} + + +##### SETUP ##### + +function validate_aws_account { + if [[ -n "$account_id" ]]; then + echo "๐Ÿฅ‘ AWS Account ID: $account_id" + else + echo 
"โŒ Failed to retrieve AWS Account ID โŒ" + exit 1 + fi +} + +### SQS ### +function provision_sqs_queue { + queue_exists=$(aws sqs list-queues --queue-name-prefix $sqs_queue_name) + if [[ -z $queue_exists ]]; then + echo "๐Ÿฅ‘ Provisioning SQS Queue" + queue_url=$(aws sqs create-queue --queue-name "${sqs_queue_name}" --attributes file:///tmp/queue-attributes.json | jq -r .QueueUrl) + else + echo "๐Ÿฅ‘ $sqs_queue_name already exists; continuing with test run" + queue_url=$(aws sqs list-queues --queue-name-prefix $sqs_queue_name | jq -r '.QueueUrls | .[0]') + fi + sqs_arn=$(aws sqs get-queue-attributes --queue-url=$queue_url --attribute-names=QueueArn | jq -r .Attributes.QueueArn) + aws sqs set-queue-attributes --queue-url $queue_url --attributes file:///tmp/sqs-subscription-policy.json +} + +### SNS ### +function provision_sns_topic { + topic_exists=$(aws sns list-topics | grep "$sns_topic_name" || :) + if [[ -z $topic_exists ]]; then + echo "๐Ÿฅ‘ Provisioning SNS Topic" + sns_arn=$(aws sns create-topic --name $sns_topic_name | jq -r .TopicArn) + else + echo "๐Ÿฅ‘ $sns_topic_name already exists; continuing with test run" + sns_arn=$(aws sns list-topics | jq -r '.Topics | .[].TopicArn' | grep "$sns_topic_name") + fi +} + +function subscribe_sqs_to_sns { + num_subscriptions=$(aws sns list-subscriptions-by-topic --topic-arn $sns_arn | jq '.Subscriptions | length') + if [[ $num_subscriptions -eq 0 ]]; then + echo "๐Ÿฅ‘ Subscribing $sns_topic_name to $sqs_queue_name" + subscription_arn=$(aws sns subscribe --topic-arn $sns_arn --protocol sqs --notification-endpoint $sqs_arn | jq -r .SubscriptionArn) + else + echo "๐Ÿฅ‘ $sns_topic_name already subscribed to $sqs_queue_name; continuing with test run" + subscription_arn=$(aws sns list-subscriptions-by-topic --topic-arn $sns_arn | jq -r '.Subscriptions | .[0].SubscriptionArn') + fi +} + +### NODEGROUP ### +function update_node_group { + create_node_policy + + echo "๐Ÿฅ‘ Attaching Node policy to Node role" + get_node_role_name + aws iam attach-role-policy --role-name $node_role_name --policy-arn $node_policy_arn + + update_ASG + set_node_data + kubectl label nodes $nth_node_ip $nth_label +} + +function create_node_policy { + node_policy_exists=$(aws iam list-policies | grep "$node_policy_name" || :) + if [[ -z $node_policy_exists ]]; then + echo "๐Ÿฅ‘ Creating Node policy" + node_policy_arn=$(aws iam create-policy --policy-name $node_policy_name --policy-document file:///tmp/nth-nodegroup-policy.json | jq -r .Policy.Arn) + else + echo "๐Ÿฅ‘ $node_policy_name already exists; continuing with test run" + node_policy_arn=$(aws iam list-policies | jq -r --arg policy_name $node_policy_name '.Policies | .[] | select(.PolicyName | contains($policy_name)) | .Arn') + fi + + sleep 10 +} + +function get_node_role_name { + node_role_arn=$(aws eks describe-nodegroup --cluster-name $CLUSTER_NAME --nodegroup-name $node_group_name | jq -r .nodegroup.nodeRole) + IFS="/" read -r -a node_role_arn_array <<< "$node_role_arn" + node_role_name=${node_role_arn_array[1]} +} + +function set_node_data { + instance_ids=$(aws autoscaling describe-auto-scaling-groups --auto-scaling-group-names $asg_name | jq -r '.AutoScalingGroups | .[0].Instances | .[].InstanceId') + instance_data=$(aws ec2 describe-instances --instance-ids $instance_ids | jq -r '[.Reservations | .[].Instances | .[].InstanceId, .[].PrivateDnsName]') + + nth_node_ip=$(jq -r '.[1]' <<< $instance_data) + termination_node_id=$(jq -r '.[2]' <<< $instance_data) +} + +function update_ASG { + asg_name=$(eksctl get 
nodegroup --cluster=$CLUSTER_NAME --name=$node_group_name --output=json | jq -r '.[0].AutoScalingGroupName') + + echo "๐Ÿฅ‘ Setting Capacity Rebalance" + aws autoscaling update-auto-scaling-group --auto-scaling-group-name $asg_name --capacity-rebalance + echo "๐Ÿฅ‘ Tagging ASG" + aws autoscaling create-or-update-tags --tags ResourceId=$asg_name,ResourceType=auto-scaling-group,Key=aws-node-termination-handler/managed,Value=,PropagateAtLaunch=true + + create_auto_scaling_role + echo "๐Ÿฅ‘ Creating Lifecycle Hooks" + aws autoscaling put-lifecycle-hook \ + --lifecycle-hook-name "Launch-LC-Hook" \ + --auto-scaling-group-name $asg_name \ + --lifecycle-transition="autoscaling:EC2_INSTANCE_LAUNCHING" \ + --heartbeat-timeout=$heartbeat_timeout \ + --notification-target-arn=$sns_arn \ + --role-arn=$auto_scaling_role_arn \ + --default-result="ABANDON" + aws autoscaling put-lifecycle-hook \ + --lifecycle-hook-name "Terminate-LC-Hook" \ + --auto-scaling-group-name $asg_name \ + --lifecycle-transition="autoscaling:EC2_INSTANCE_TERMINATING" \ + --heartbeat-timeout=$heartbeat_timeout \ + --notification-target-arn=$sns_arn \ + --role-arn=$auto_scaling_role_arn \ + --default-result="CONTINUE" +} + +function create_auto_scaling_role { + auto_scaling_role_exists=$(aws iam get-role --role-name=$auto_scaling_role_name 2> /dev/null | grep "$auto_scaling_role_name" || :) + if [[ -z $auto_scaling_role_exists ]]; then + echo "๐Ÿฅ‘ Creating Auto Scaling Role" + auto_scaling_role_arn=$(aws iam create-service-linked-role --aws-service-name autoscaling.amazonaws.com --custom-suffix "nth-test" | jq -r '.Role.Arn') + sleep 10 + else + echo "๐Ÿฅ‘ $auto_scaling_role_name already exists; continuing with test run" + auto_scaling_role_arn=$(aws iam get-role --role-name=$auto_scaling_role_name 2> /dev/null | jq -r '.Role.Arn') + fi +} + +### HELM ### +function install_helm { + get_aws_credentials + + anth_helm_args=( + upgrade + --install + --namespace kube-system + "$CLUSTER_NAME-acth" + "$SCRIPTPATH/../../config/helm/aws-node-termination-handler/" + --set image.repository="$NODE_TERMINATION_HANDLER_DOCKER_REPO" + --set image.tag="$NODE_TERMINATION_HANDLER_DOCKER_TAG" + --set image.pullPolicy="Always" + --set nodeSelector."${nth_label}" + --set tolerations[0].operator=Exists + --set awsAccessKeyID="$aws_access_key_id" + --set awsSecretAccessKey="$aws_secret_access_key" + --set awsRegion="${REGION}" + --set checkTagBeforeDraining=false + --set enableSqsTerminationDraining=true + --set queueURL="${queue_url}" + --wait + ) + + set -x + helm "${anth_helm_args[@]}" + set +x + + sleep 15 +} + +function get_aws_credentials { + echo "๐Ÿฅ‘ Retrieving AWS Credentials" + aws_access_key_id=$(aws --profile default configure get aws_access_key_id 2> /dev/null) + if [[ -z $aws_access_key_id ]]; then + echo "โŒ Failed to retrieve AWS Access Key โŒ" + exit 1 + fi + + aws_secret_access_key=$(aws --profile default configure get aws_secret_access_key 2> /dev/null) + if [[ -z $aws_access_key_id ]]; then + echo "โŒ Failed to retrieve AWS Secret Access Key โŒ" + exit 1 + fi +} + +### FIS ### +function create_FIS_role { + fis_role_exists=$(aws iam get-role --role-name $fis_role_name 2> /dev/null | grep "$fis_role_name" || :) + if [[ -z $fis_role_exists ]]; then + echo "๐Ÿฅ‘ Creating FIS Role" + fis_role_arn=$(aws iam create-role --role-name $fis_role_name --assume-role-policy-document file:///tmp/fis-role-trust-policy.json | jq -r '.Role.Arn') + aws iam attach-role-policy --role-name $fis_role_name --policy-arn $fis_policy_arn + sleep 10 + else 
+ echo "๐Ÿฅ‘ $fis_role_name already exists; continuing with test run" + fis_role_arn=$(aws iam get-role --role-name=$fis_role_name 2> /dev/null | jq -r '.Role.Arn') + fi +} + +function create_experiment_template { + experiment_exists=$(aws fis list-experiment-templates | grep "$fis_template_name" || :) + if [[ -z $experiment_exists ]]; then + create_FIS_Template_JSON + echo "๐Ÿฅ‘ Creating experiment template" + template_id=$(aws fis create-experiment-template --cli-input-json file:///tmp/fis-experiment-template.json | jq -r .experimentTemplate.id) + else + template_id=$(aws fis list-experiment-templates | jq -r --arg template_name $fis_template_name '.experimentTemplates | .[] | select(.tags | has("Name")) | select(.tags.Name | contains($template_name)) | .id') + echo "๐Ÿฅ‘ $fis_template_name already exists; continuing with test run" + fi +} + +function create_tags { + echo "๐Ÿฅ‘ Creating instance tags" + instance_id_string=$(tr '\n' ' ' <<< ${instance_ids}) + eval 'aws ec2 create-tags --resources'" $instance_id_string "'--tags 'Key="aws-node-termination-handler/managed",Value='' + aws ec2 create-tags --resources "${termination_node_id}" --tags Key=Name,Value=interruptMe +} + +function start_FIS_experiment { + create_tags + create_FIS_role + create_experiment_template + echo "๐Ÿฅ‘ Starting Experiment" + experiment_start_time=$(date +%s) + aws fis start-experiment --experiment-template-id $template_id > /dev/null +} + + +##### TESTING ##### +function convert_date_to_epoch_seconds { + IFS='T' read -r date_part time_part <<< "$1" + IFS='-' read -r year month day <<< "$date_part" + IFS=':' read -r hour minute second_fractional <<< "$time_part" + IFS='.' read -r -a seconds_array <<< "$second_fractional" + IFS=':' read -r offset_hours offset_minutes <<< "${time_part:16:5}" + + # Convert time strings to base-10 integers + year=$((10#$year + 0)); month=$((10#$month + 0)); day=$((10#$day + 0)) + hour=$((10#$hour + 0)); minute=$((10#$minute + 0)); second=$((10#${seconds_array[0]} + 0)) + offset_hours=$((10#$offset_hours + 0)); offset_minutes=$((10#$offset_minutes + 0)) + + if [[ $time_part =~ .*"-".* ]]; then + offset_hours=$((offset_hours * -1)) + offset_minutes=$((offset_minutes * -1)) + fi + + total_days=$(((year - 1970) * 365 + (year - 1970)/4)) + for ((k = 1; k < month; k++)); do + total_days=$((total_days + $(cal $k $year | awk 'NF {DAYS = $NF} END {print DAYS}'))) + done + total_days=$((total_days + day - 1)) + total_seconds=$((total_days * 86400 + (hour + offset_hours) * 3600 + (minute + offset_minutes) * 60 + second)) +} + +function get_launch_activity { + echo "๐Ÿฅ‘ Finding launch activity " + launch_activity="" + for i in $(seq 1 $LAUNCH_CHECK_CYCLES); do + activities=$(aws autoscaling describe-scaling-activities --auto-scaling-group-name $asg_name) + activities_details=$(jq -r '[.Activities | .[] | .ActivityId, .Description, .StartTime]' <<< $activities) + num_activities=$(jq -r 'length' <<< $activities_details) + for j in $(seq 0 3 $((--num_activities))); do + id=$(jq -r .[$j] <<< $activities_details) + description=$(jq -r .[$((++j))] <<< $activities_details) + start=$(jq -r .[$((j+=2))] <<< $activities_details) + activity_instance=${description##*:} + convert_date_to_epoch_seconds $start + if [[ $description =~ .*"Launching".* && $total_seconds -gt $experiment_start_time ]]; then + launch_activity=$id + break 2 + fi + done + + echo "Setup Loop $i/$LAUNCH_CHECK_CYCLES, sleeping for $LAUNCH_ACTIVITY_CHECK_SLEEP seconds" + sleep $LAUNCH_ACTIVITY_CHECK_SLEEP + done + + if [[ -n 
$launch_activity ]]; then + echo "โœ… Launch Activity found for instance $activity_instance" + else + echo "โŒ Failed to find a new launched instance โŒ" + exit 1 + fi +} + +function test_launch_lifecycle { + echo "๐Ÿฅ‘ Verifying launch hook completion " + for i in $(seq 1 $LAUNCH_CHECK_CYCLES); do + activity_status=$(aws autoscaling describe-scaling-activities --auto-scaling-group-name $asg_name --activity-ids $launch_activity | jq -r '.Activities | .[].StatusCode') + if [[ $activity_status == "Successful" ]]; then + echo "" + echo "โœ… Launch Lifecycle Successfully Completed โœ…" + exit 0 + elif [[ $activity_status == "Cancelled" || $activity_status == "Failed" ]]; then + echo "" + echo "โŒ Launch Lifecycle $activity_status โŒ" + exit 1 + fi + + echo "Assertion Loop $i/$LAUNCH_CHECK_CYCLES, sleeping for $LAUNCH_STATUS_CHECK_SLEEP seconds" + sleep $LAUNCH_STATUS_CHECK_SLEEP + done + + echo "โŒ Failed to verify launch hook completion โŒ" + exit 1 +} + + +##### CLEAN UP ##### +function clean_up { + echo "=====================================================================================================" + echo "๐Ÿงน Cleaning up SQS, SNS, NodeGroup, IAM, FIS ๐Ÿงน" + echo "=====================================================================================================" + print_logs + uninstall_helm + delete_node_group_policy + if [[ -n $subscription_arn ]]; then + echo "๐Ÿฅ‘ Unsubscribing SNS from SQS" + aws sns unsubscribe --subscription-arn $subscription_arn + fi + if [[ -n $queue_url ]]; then + echo "๐Ÿฅ‘ Deleting SQS queue" + aws sqs delete-queue --queue-url $queue_url + fi + if [[ -n $sns_arn ]]; then + echo "๐Ÿฅ‘ Deleting SNS topic" + aws sns delete-topic --topic-arn $sns_arn + fi + if [[ -n $template_id ]]; then + echo "๐Ÿฅ‘ Deleting FIS experiment template" + aws fis delete-experiment-template --id $template_id --no-paginate > /dev/null + fi + echo "๐Ÿฅ‘ Detaching FIS role policy" + aws iam detach-role-policy --role-name $fis_role_name --policy-arn $fis_policy_arn + echo "๐Ÿฅ‘ Deleting FIS role" + aws iam delete-role --role-name $fis_role_name + echo "๐Ÿฅ‘ Deleting autoscaling role" + aws iam delete-service-linked-role --role-name $auto_scaling_role_name > /dev/null + if [[ -n $node_policy_arn ]]; then + echo "๐Ÿฅ‘ Deleting Node role policy" + aws iam delete-policy --policy-arn $node_policy_arn + fi +} + +function print_logs { + pod_id=$(get_nth_worker_pod || :) + if [[ -n $pod_id ]]; then + kubectl logs $pod_id --namespace kube-system || : + else + echo "โŒ Failed to get pod ID. 
Unable to print logs โŒ" + fi +} + +function uninstall_helm { + helm_exists=$(helm ls -A | grep "$CLUSTER_NAME-acth") + if [[ -n $helm_exists ]]; then + echo "๐Ÿฅ‘ Uninstalling NTH helm chart" + helm uninstall "$CLUSTER_NAME-acth" -n kube-system + fi +} + +function delete_node_group_policy { + if [[ -z $node_role_name || -z $node_policy_name ]]; then return; fi + + node_policy_exists=$(aws iam list-attached-role-policies --role-name $node_role_name | grep "$node_policy_name" || :) + if [[ -n $node_policy_exists ]]; then + echo "๐Ÿฅ‘ Detaching NTH Node Group policy" + aws iam detach-role-policy --role-name $node_role_name --policy-arn $node_policy_arn + fi +} + +trap "clean_up" EXIT +validate_aws_account +provision_sqs_queue +provision_sns_topic +subscribe_sqs_to_sns +update_node_group +install_helm +start_FIS_experiment +get_launch_activity +test_launch_lifecycle diff --git a/test/eks-cluster-test/run-test b/test/eks-cluster-test/run-test index 9d6a7f77..949f68d1 100755 --- a/test/eks-cluster-test/run-test +++ b/test/eks-cluster-test/run-test @@ -205,6 +205,8 @@ if [[ -z ${assertion_scripts+x} ]]; then #"$SCRIPTPATH/../e2e/webhook-http-proxy-test" #"$SCRIPTPATH/../e2e/webhook-secret-test" "$SCRIPTPATH/../e2e/webhook-test" + # This test terminates nodes in the cluster and needs to be run last + "$SCRIPTPATH/../e2e/asg-launch-lifecycle-sqs-test" ) fi diff --git a/test/k8s-local-cluster-test/run-test b/test/k8s-local-cluster-test/run-test index b21b2c64..be1e243c 100755 --- a/test/k8s-local-cluster-test/run-test +++ b/test/k8s-local-cluster-test/run-test @@ -23,6 +23,9 @@ AEMM_DL_URL="https://github.com/aws/amazon-ec2-metadata-mock/releases/download/v WEBHOOK_URL=${WEBHOOK_URL:="http://webhook-test-proxy.default.svc.cluster.local"} ASSERTION_SCRIPTS=$(find "$SCRIPTPATH/../e2e" -type f | sort) +SCRIPT_DENYLIST=( + "$SCRIPTPATH/../e2e/asg-launch-lifecycle-sqs-test" +) function timeout() { perl -e 'alarm shift; exec @ARGV' "$@"; } @@ -271,8 +274,17 @@ kubectl label node "${CLUSTER_NAME}-worker" "$(echo $NTH_WORKER_LABEL | tr -d '\ ## Mark worker2 only for Critical Add-Ons like dns kubectl taint node "${CLUSTER_NAME}-worker2" CriticalAddonsOnly=true:NoSchedule --overwrite +function is_denylisted { + if [[ ${SCRIPT_DENYLIST[*]} =~ (^|[[:space:]])$1($|[[:space:]]) ]]; then + return 0 + fi + return 1 +} + i=0 for assert_script in $ASSERTION_SCRIPTS; do + if is_denylisted $assert_script; then continue; fi + reset_cluster START_FOR_QUERYING=$(date -u +"%Y-%m-%dT%TZ") IMDS_PORT=$((i + 1338))