From 4e6d762c36e680b96d7b344b81002982861ca431 Mon Sep 17 00:00:00 2001 From: Fae Charlton Date: Thu, 2 May 2024 22:57:33 -0400 Subject: [PATCH] Fix sqs region selection (#39327) Fix an error in region selection that was introduced in a previous cleanup (https://github.com/elastic/beats/pull/38958). When the configured region disagrees with the region detected from the queue URL, the configured region is supposed to take precedence. Due to a misreading, my code instead chose the URL region when there is a conflict. I've broken region selection out into another helper function to make this logic easier to test, and added several unit test cases that would have caught this mistake. --- x-pack/filebeat/input/awss3/input.go | 69 ++++++++++++++--------- x-pack/filebeat/input/awss3/input_test.go | 68 ++++++++++++++++------ 2 files changed, 92 insertions(+), 45 deletions(-) diff --git a/x-pack/filebeat/input/awss3/input.go b/x-pack/filebeat/input/awss3/input.go index 51e8c9808edb..6c9c202b1f54 100644 --- a/x-pack/filebeat/input/awss3/input.go +++ b/x-pack/filebeat/input/awss3/input.go @@ -24,6 +24,7 @@ import ( "github.com/elastic/beats/v7/libbeat/feature" awscommon "github.com/elastic/beats/v7/x-pack/libbeat/common/aws" conf "github.com/elastic/elastic-agent-libs/config" + "github.com/elastic/elastic-agent-libs/logp" "github.com/elastic/go-concert/unison" ) @@ -117,17 +118,12 @@ func (in *s3Input) runQueueReader( inputContext v2.Context, pipeline beat.Pipeline, ) error { - configRegion := in.config.RegionName - urlRegion, err := getRegionFromQueueURL(in.config.QueueURL, in.config.AWSConfig.Endpoint) - if err != nil && configRegion == "" { - // Only report an error if we don't have a configured region - // to fall back on. - return fmt.Errorf("failed to get AWS region from queue_url: %w", err) - } else if configRegion != "" && configRegion != urlRegion { - inputContext.Logger.Warnf("configured region disagrees with queue_url region (%q != %q): using %q", configRegion, urlRegion, urlRegion) + // Set awsConfig.Region based on the config and queue URL + region, err := chooseRegion(inputContext.Logger, in.config) + if err != nil { + return err } - - in.awsConfig.Region = urlRegion + in.awsConfig.Region = region // Create SQS receiver and S3 notification processor. receiver, err := in.createSQSReceiver(inputContext, pipeline) @@ -326,32 +322,51 @@ func (in *s3Input) createS3Poller(ctx v2.Context, cancelCtx context.Context, cli var errBadQueueURL = errors.New("QueueURL is not in format: https://sqs.{REGION_ENDPOINT}.{ENDPOINT}/{ACCOUNT_NUMBER}/{QUEUE_NAME} or https://{VPC_ENDPOINT}.sqs.{REGION_ENDPOINT}.vpce.{ENDPOINT}/{ACCOUNT_NUMBER}/{QUEUE_NAME}") -func getRegionFromQueueURL(queueURL, endpoint string) (string, error) { +func chooseRegion(log *logp.Logger, config config) (string, error) { + urlRegion := getRegionFromQueueURL(config.QueueURL, config.AWSConfig.Endpoint) + if config.RegionName != "" { + // If a region is configured, that takes precedence over the URL. + if log != nil && config.RegionName != urlRegion { + log.Warnf("configured region disagrees with queue_url region (%q != %q): using %q", config.RegionName, urlRegion, config.RegionName) + } + return config.RegionName, nil + } + if urlRegion != "" { + // If no region is configured, fall back on the URL. + return urlRegion, nil + } + // If we can't get the region from the config or the URL, report an error. + return "", fmt.Errorf("failed to get AWS region from queue_url: %w", errBadQueueURL) +} + +// getRegionFromQueueURL returns the region from standard queue URLs, or the +// empty string if it couldn't be determined. +func getRegionFromQueueURL(queueURL, endpoint string) string { // get region from queueURL // Example for sqs queue: https://sqs.us-east-1.amazonaws.com/12345678912/test-s3-logs // Example for vpce: https://vpce-test.sqs.us-east-1.vpce.amazonaws.com/12345678912/sqs-queue u, err := url.Parse(queueURL) if err != nil { - return "", fmt.Errorf(queueURL + " is not a valid URL") - } - if (u.Scheme == "https" || u.Scheme == "http") && u.Host != "" { - queueHostSplit := strings.SplitN(u.Host, ".", 3) - // check for sqs queue url - if len(queueHostSplit) == 3 && queueHostSplit[0] == "sqs" { - if queueHostSplit[2] == endpoint || (endpoint == "" && strings.HasPrefix(queueHostSplit[2], "amazonaws.")) { - return queueHostSplit[1], nil - } + return "" + } + + // check for sqs queue url + host := strings.SplitN(u.Host, ".", 3) + if len(host) == 3 && host[0] == "sqs" { + if host[2] == endpoint || (endpoint == "" && strings.HasPrefix(host[2], "amazonaws.")) { + return host[1] } + } - // check for vpce url - queueHostSplitVPC := strings.SplitN(u.Host, ".", 5) - if len(queueHostSplitVPC) == 5 && queueHostSplitVPC[1] == "sqs" { - if queueHostSplitVPC[4] == endpoint || (endpoint == "" && strings.HasPrefix(queueHostSplitVPC[4], "amazonaws.")) { - return queueHostSplitVPC[2], nil - } + // check for vpce url + host = strings.SplitN(u.Host, ".", 5) + if len(host) == 5 && host[1] == "sqs" { + if host[4] == endpoint || (endpoint == "" && strings.HasPrefix(host[4], "amazonaws.")) { + return host[2] } } - return "", errBadQueueURL + + return "" } func getRegionForBucket(ctx context.Context, s3Client *s3.Client, bucketName string) (string, error) { diff --git a/x-pack/filebeat/input/awss3/input_test.go b/x-pack/filebeat/input/awss3/input_test.go index 0a3053f7f1b9..c76e939424f7 100644 --- a/x-pack/filebeat/input/awss3/input_test.go +++ b/x-pack/filebeat/input/awss3/input_test.go @@ -5,8 +5,10 @@ package awss3 import ( + "errors" "testing" + aws "github.com/elastic/beats/v7/x-pack/libbeat/common/aws" "github.com/stretchr/testify/assert" ) @@ -51,23 +53,36 @@ func TestGetProviderFromDomain(t *testing.T) { func TestGetRegionFromQueueURL(t *testing.T) { tests := []struct { - name string - queueURL string - endpoint string - want string - wantErr error + name string + queueURL string + regionName string + endpoint string + want string + wantErr error }{ { name: "amazonaws.com_domain_with_blank_endpoint", queueURL: "https://sqs.us-east-1.amazonaws.com/627959692251/test-s3-logs", want: "us-east-1", }, + { + name: "amazonaws.com_domain_with_region_override", + queueURL: "https://sqs.us-east-1.amazonaws.com/627959692251/test-s3-logs", + regionName: "us-east-2", + want: "us-east-2", + }, { name: "abc.xyz_and_domain_with_matching_endpoint", queueURL: "https://sqs.us-east-1.abc.xyz/627959692251/test-s3-logs", endpoint: "abc.xyz", want: "us-east-1", }, + { + name: "abc.xyz_with_region_override", + queueURL: "https://sqs.us-east-1.abc.xyz/627959692251/test-s3-logs", + regionName: "us-west-3", + want: "us-west-3", + }, { name: "abc.xyz_and_domain_with_blank_endpoint", queueURL: "https://sqs.us-east-1.abc.xyz/627959692251/test-s3-logs", @@ -78,18 +93,46 @@ func TestGetRegionFromQueueURL(t *testing.T) { queueURL: "https://vpce-test.sqs.us-east-2.vpce.amazonaws.com/12345678912/sqs-queue", want: "us-east-2", }, + { + name: "vpce_endpoint_with_region_override", + queueURL: "https://vpce-test.sqs.us-east-2.vpce.amazonaws.com/12345678912/sqs-queue", + regionName: "us-west-1", + want: "us-west-1", + }, { name: "vpce_endpoint_with_endpoint", queueURL: "https://vpce-test.sqs.us-east-1.vpce.amazonaws.com/12345678912/sqs-queue", endpoint: "amazonaws.com", want: "us-east-1", }, + { + name: "non_aws_vpce_with_endpoint", + queueURL: "https://vpce-test.sqs.us-east-1.vpce.abc.xyz/12345678912/sqs-queue", + endpoint: "abc.xyz", + want: "us-east-1", + }, + { + name: "non_aws_vpce_without_endpoint", + queueURL: "https://vpce-test.sqs.us-east-1.vpce.abc.xyz/12345678912/sqs-queue", + wantErr: errBadQueueURL, + }, + { + name: "non_aws_vpce_with_region_override", + queueURL: "https://vpce-test.sqs.us-east-1.vpce.abc.xyz/12345678912/sqs-queue", + regionName: "us-west-1", + want: "us-west-1", + }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := getRegionFromQueueURL(test.queueURL, test.endpoint) - if !sameError(err, test.wantErr) { + config := config{ + QueueURL: test.queueURL, + RegionName: test.regionName, + AWSConfig: aws.ConfigAWS{Endpoint: test.endpoint}, + } + got, err := chooseRegion(nil, config) + if !errors.Is(err, test.wantErr) { t.Errorf("unexpected error: got:%v want:%v", err, test.wantErr) } if got != test.want { @@ -98,14 +141,3 @@ func TestGetRegionFromQueueURL(t *testing.T) { }) } } - -func sameError(a, b error) bool { - switch { - case a == nil && b == nil: - return true - case a == nil, b == nil: - return false - default: - return a.Error() == b.Error() - } -}