diff --git a/data-prepper-plugins/s3-source/build.gradle b/data-prepper-plugins/s3-source/build.gradle index f4afbfbfe3..30bced953d 100644 --- a/data-prepper-plugins/s3-source/build.gradle +++ b/data-prepper-plugins/s3-source/build.gradle @@ -11,6 +11,8 @@ dependencies { implementation project(':data-prepper-api') implementation project(':data-prepper-plugins:buffer-common') implementation project(':data-prepper-plugins:common') + implementation project(':data-prepper-plugins:sqs-common') + implementation libs.armeria.core implementation 'io.micrometer:micrometer-core' implementation 'software.amazon.awssdk:s3' diff --git a/data-prepper-plugins/s3-source/src/integrationTest/java/org/opensearch/dataprepper/plugins/source/s3/SqsServiceIT.java b/data-prepper-plugins/s3-source/src/integrationTest/java/org/opensearch/dataprepper/plugins/source/s3/SqsServiceIT.java index 25cdb2be67..3d1e6343f0 100644 --- a/data-prepper-plugins/s3-source/src/integrationTest/java/org/opensearch/dataprepper/plugins/source/s3/SqsServiceIT.java +++ b/data-prepper-plugins/s3-source/src/integrationTest/java/org/opensearch/dataprepper/plugins/source/s3/SqsServiceIT.java @@ -22,6 +22,7 @@ import org.opensearch.dataprepper.plugins.source.s3.configuration.OnErrorOption; import org.opensearch.dataprepper.plugins.source.s3.configuration.AwsAuthenticationOptions; import org.opensearch.dataprepper.plugins.source.s3.configuration.SqsOptions; +import org.opensearch.dataprepper.plugins.source.sqs.common.SqsBackoff; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; import software.amazon.awssdk.regions.Region; @@ -154,8 +155,7 @@ public void test_sqsService(int numWorkers) throws IOException { } private void clearSqsQueue() { - Backoff backoff = Backoff.exponential(SqsService.INITIAL_DELAY, SqsService.MAXIMUM_DELAY).withJitter(SqsService.JITTER_RATE) - .withMaxAttempts(Integer.MAX_VALUE); + Backoff backoff = SqsBackoff.createExponentialBackoff(); final SqsWorker sqsWorker = new SqsWorker(acknowledgementSetManager, sqsClient, s3Service, s3SourceConfig, pluginMetrics, backoff); //final SqsService objectUnderTest = createObjectUnderTest(); int sqsMessagesProcessed; diff --git a/data-prepper-plugins/s3-source/src/integrationTest/java/org/opensearch/dataprepper/plugins/source/s3/SqsWorkerIT.java b/data-prepper-plugins/s3-source/src/integrationTest/java/org/opensearch/dataprepper/plugins/source/s3/SqsWorkerIT.java index 21475930ec..a74e1b35bb 100644 --- a/data-prepper-plugins/s3-source/src/integrationTest/java/org/opensearch/dataprepper/plugins/source/s3/SqsWorkerIT.java +++ b/data-prepper-plugins/s3-source/src/integrationTest/java/org/opensearch/dataprepper/plugins/source/s3/SqsWorkerIT.java @@ -27,6 +27,7 @@ import org.opensearch.dataprepper.plugins.source.s3.configuration.NotificationSourceOption; import org.opensearch.dataprepper.plugins.source.s3.configuration.OnErrorOption; import org.opensearch.dataprepper.plugins.source.s3.configuration.SqsOptions; +import org.opensearch.dataprepper.plugins.source.sqs.common.SqsBackoff; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.sqs.SqsClient; @@ -93,8 +94,7 @@ void setUp() { .region(Region.of(System.getProperty("tests.s3source.region"))) .build(); - backoff = Backoff.exponential(SqsService.INITIAL_DELAY, SqsService.MAXIMUM_DELAY).withJitter(SqsService.JITTER_RATE) - .withMaxAttempts(Integer.MAX_VALUE); + backoff = SqsBackoff.createExponentialBackoff(); s3SourceConfig = mock(S3SourceConfig.class); s3Service = mock(S3Service.class); diff --git a/data-prepper-plugins/s3-source/src/main/java/org/opensearch/dataprepper/plugins/source/s3/SqsService.java b/data-prepper-plugins/s3-source/src/main/java/org/opensearch/dataprepper/plugins/source/s3/SqsService.java index c674be5f68..de35592b90 100644 --- a/data-prepper-plugins/s3-source/src/main/java/org/opensearch/dataprepper/plugins/source/s3/SqsService.java +++ b/data-prepper-plugins/s3-source/src/main/java/org/opensearch/dataprepper/plugins/source/s3/SqsService.java @@ -12,11 +12,10 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; -import software.amazon.awssdk.core.retry.RetryPolicy; import software.amazon.awssdk.services.sqs.SqsClient; +import org.opensearch.dataprepper.plugins.source.sqs.common.SqsBackoff; +import org.opensearch.dataprepper.plugins.source.sqs.common.SqsClientFactory; -import java.time.Duration; import java.util.List; import java.util.concurrent.TimeUnit; import java.util.concurrent.Executors; @@ -27,10 +26,6 @@ public class SqsService { private static final Logger LOG = LoggerFactory.getLogger(SqsService.class); static final long SHUTDOWN_TIMEOUT = 30L; - static final long INITIAL_DELAY = Duration.ofSeconds(20).toMillis(); - static final long MAXIMUM_DELAY = Duration.ofMinutes(5).toMillis(); - static final double JITTER_RATE = 0.20; - private final S3SourceConfig s3SourceConfig; private final S3Service s3Accessor; private final SqsClient sqsClient; @@ -38,6 +33,7 @@ public class SqsService { private final AcknowledgementSetManager acknowledgementSetManager; private final ExecutorService executorService; private final List sqsWorkers; + private final Backoff backoff; public SqsService(final AcknowledgementSetManager acknowledgementSetManager, final S3SourceConfig s3SourceConfig, @@ -48,11 +44,9 @@ public SqsService(final AcknowledgementSetManager acknowledgementSetManager, this.s3Accessor = s3Accessor; this.pluginMetrics = pluginMetrics; this.acknowledgementSetManager = acknowledgementSetManager; - this.sqsClient = createSqsClient(credentialsProvider); + this.sqsClient = SqsClientFactory.createSqsClient(s3SourceConfig.getAwsAuthenticationOptions().getAwsRegion(), credentialsProvider); executorService = Executors.newFixedThreadPool(s3SourceConfig.getNumWorkers(), BackgroundThreadFactory.defaultExecutorThreadFactory("s3-source-sqs")); - - final Backoff backoff = Backoff.exponential(INITIAL_DELAY, MAXIMUM_DELAY).withJitter(JITTER_RATE) - .withMaxAttempts(Integer.MAX_VALUE); + backoff = SqsBackoff.createExponentialBackoff(); sqsWorkers = IntStream.range(0, s3SourceConfig.getNumWorkers()) .mapToObj(i -> new SqsWorker(acknowledgementSetManager, sqsClient, s3Accessor, s3SourceConfig, pluginMetrics, backoff)) .collect(Collectors.toList()); @@ -62,17 +56,6 @@ public void start() { sqsWorkers.forEach(executorService::submit); } - SqsClient createSqsClient(final AwsCredentialsProvider credentialsProvider) { - LOG.debug("Creating SQS client"); - return SqsClient.builder() - .region(s3SourceConfig.getAwsAuthenticationOptions().getAwsRegion()) - .credentialsProvider(credentialsProvider) - .overrideConfiguration(ClientOverrideConfiguration.builder() - .retryPolicy(RetryPolicy.builder().numRetries(5).build()) - .build()) - .build(); - } - public void stop() { executorService.shutdown(); sqsWorkers.forEach(SqsWorker::stop); diff --git a/data-prepper-plugins/sqs-common/build.gradle b/data-prepper-plugins/sqs-common/build.gradle new file mode 100644 index 0000000000..b4ffbc8e5e --- /dev/null +++ b/data-prepper-plugins/sqs-common/build.gradle @@ -0,0 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +plugins { + id 'java' +} + +dependencies { + implementation project(':data-prepper-api') + implementation project(':data-prepper-plugins:buffer-common') + implementation project(':data-prepper-plugins:common') + implementation libs.armeria.core + implementation project(':data-prepper-plugins:aws-plugin-api') + implementation 'software.amazon.awssdk:sqs' + implementation 'software.amazon.awssdk:arns' + implementation 'software.amazon.awssdk:sts' + implementation 'io.micrometer:micrometer-core' + implementation 'com.fasterxml.jackson.core:jackson-annotations' + implementation 'com.fasterxml.jackson.datatype:jackson-datatype-jsr310' + implementation 'org.hibernate.validator:hibernate-validator:8.0.1.Final' + testImplementation 'com.fasterxml.jackson.dataformat:jackson-dataformat-yaml' + testImplementation project(':data-prepper-plugins:blocking-buffer') +} +test { + useJUnitPlatform() +} diff --git a/data-prepper-plugins/sqs-common/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsBackoff.java b/data-prepper-plugins/sqs-common/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsBackoff.java new file mode 100644 index 0000000000..b5f85cd61d --- /dev/null +++ b/data-prepper-plugins/sqs-common/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsBackoff.java @@ -0,0 +1,27 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.dataprepper.plugins.source.sqs.common; + +import com.linecorp.armeria.client.retry.Backoff; +import java.time.Duration; + +public final class SqsBackoff { + private static final long INITIAL_DELAY_MILLIS = Duration.ofSeconds(20).toMillis(); + private static final long MAX_DELAY_MILLIS = Duration.ofMinutes(5).toMillis(); + private static final double JITTER_RATE = 0.20; + + private SqsBackoff() {} + + public static Backoff createExponentialBackoff() { + return Backoff.exponential(INITIAL_DELAY_MILLIS, MAX_DELAY_MILLIS) + .withJitter(JITTER_RATE) + .withMaxAttempts(Integer.MAX_VALUE); + } +} diff --git a/data-prepper-plugins/sqs-common/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsClientFactory.java b/data-prepper-plugins/sqs-common/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsClientFactory.java new file mode 100644 index 0000000000..8754d87749 --- /dev/null +++ b/data-prepper-plugins/sqs-common/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsClientFactory.java @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.source.sqs.common; + +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; +import software.amazon.awssdk.core.retry.RetryPolicy; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.sqs.SqsClient; + +/** + * A common factory to create SQS clients + */ +public final class SqsClientFactory { + + private SqsClientFactory() { + } + + public static SqsClient createSqsClient( + final Region region, + final AwsCredentialsProvider credentialsProvider) { + + return SqsClient.builder() + .region(region) + .credentialsProvider(credentialsProvider) + .overrideConfiguration(ClientOverrideConfiguration.builder() + .retryPolicy(RetryPolicy.builder().numRetries(5).build()) + .build()) + .build(); + } +} diff --git a/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/SqsRetriesExhaustedException.java b/data-prepper-plugins/sqs-common/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsRetriesExhaustedException.java similarity index 89% rename from data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/SqsRetriesExhaustedException.java rename to data-prepper-plugins/sqs-common/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsRetriesExhaustedException.java index e1fd536cb7..6dd42ce95b 100644 --- a/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/SqsRetriesExhaustedException.java +++ b/data-prepper-plugins/sqs-common/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsRetriesExhaustedException.java @@ -8,7 +8,7 @@ * */ -package org.opensearch.dataprepper.plugins.source.sqs; +package org.opensearch.dataprepper.plugins.source.sqs.common; /** * This exception is thrown when SQS retries are exhausted diff --git a/data-prepper-plugins/sqs-common/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsWorkerCommon.java b/data-prepper-plugins/sqs-common/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsWorkerCommon.java new file mode 100644 index 0000000000..9301574237 --- /dev/null +++ b/data-prepper-plugins/sqs-common/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsWorkerCommon.java @@ -0,0 +1,212 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.source.sqs.common; + +import com.linecorp.armeria.client.retry.Backoff; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Timer; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.exception.SdkException; +import software.amazon.awssdk.services.sqs.SqsClient; +import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityRequest; +import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequest; +import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequestEntry; +import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchResponse; +import software.amazon.awssdk.services.sqs.model.Message; +import software.amazon.awssdk.services.sqs.model.ReceiveMessageRequest; +import software.amazon.awssdk.services.sqs.model.SqsException; +import software.amazon.awssdk.services.sts.model.StsException; + +import java.time.Duration; +import java.util.Collections; +import java.util.List; + +public class SqsWorkerCommon { + private static final Logger LOG = LoggerFactory.getLogger(SqsWorkerCommon.class); + public static final String ACKNOWLEDGEMENT_SET_CALLACK_METRIC_NAME = "acknowledgementSetCallbackCounter"; + public static final String SQS_MESSAGES_RECEIVED_METRIC_NAME = "sqsMessagesReceived"; + public static final String SQS_MESSAGES_DELETED_METRIC_NAME = "sqsMessagesDeleted"; + public static final String SQS_MESSAGES_FAILED_METRIC_NAME = "sqsMessagesFailed"; + public static final String SQS_MESSAGES_DELETE_FAILED_METRIC_NAME = "sqsMessagesDeleteFailed"; + public static final String SQS_VISIBILITY_TIMEOUT_CHANGED_COUNT_METRIC_NAME = "sqsVisibilityTimeoutChangedCount"; + public static final String SQS_VISIBILITY_TIMEOUT_CHANGE_FAILED_COUNT_METRIC_NAME = "sqsVisibilityTimeoutChangeFailedCount"; + + private final SqsClient sqsClient; + private final Backoff standardBackoff; + private final PluginMetrics pluginMetrics; + private final AcknowledgementSetManager acknowledgementSetManager; + private volatile boolean isStopped; + private int failedAttemptCount; + private final Counter sqsMessagesReceivedCounter; + private final Counter sqsMessagesDeletedCounter; + private final Counter sqsMessagesFailedCounter; + private final Counter sqsMessagesDeleteFailedCounter; + private final Counter acknowledgementSetCallbackCounter; + private final Counter sqsVisibilityTimeoutChangedCount; + private final Counter sqsVisibilityTimeoutChangeFailedCount; + + public SqsWorkerCommon(final SqsClient sqsClient, + final Backoff standardBackoff, + final PluginMetrics pluginMetrics, + final AcknowledgementSetManager acknowledgementSetManager) { + + this.sqsClient = sqsClient; + this.standardBackoff = standardBackoff; + this.pluginMetrics = pluginMetrics; + this.acknowledgementSetManager = acknowledgementSetManager; + this.isStopped = false; + this.failedAttemptCount = 0; + + sqsMessagesReceivedCounter = pluginMetrics.counter(SQS_MESSAGES_RECEIVED_METRIC_NAME); + sqsMessagesDeletedCounter = pluginMetrics.counter(SQS_MESSAGES_DELETED_METRIC_NAME); + sqsMessagesFailedCounter = pluginMetrics.counter(SQS_MESSAGES_FAILED_METRIC_NAME); + sqsMessagesDeleteFailedCounter = pluginMetrics.counter(SQS_MESSAGES_DELETE_FAILED_METRIC_NAME); + acknowledgementSetCallbackCounter = pluginMetrics.counter(ACKNOWLEDGEMENT_SET_CALLACK_METRIC_NAME); + sqsVisibilityTimeoutChangedCount = pluginMetrics.counter(SQS_VISIBILITY_TIMEOUT_CHANGED_COUNT_METRIC_NAME); + sqsVisibilityTimeoutChangeFailedCount = pluginMetrics.counter(SQS_VISIBILITY_TIMEOUT_CHANGE_FAILED_COUNT_METRIC_NAME); + } + + public List pollSqsMessages(final String queueUrl, + final Integer maxNumberOfMessages, + final Duration waitTime, + final Duration visibilityTimeout) { + try { + final ReceiveMessageRequest request = createReceiveMessageRequest(queueUrl, maxNumberOfMessages, waitTime, visibilityTimeout); + final List messages = sqsClient.receiveMessage(request).messages(); + failedAttemptCount = 0; + if (!messages.isEmpty()) { + sqsMessagesReceivedCounter.increment(messages.size()); + } + return messages; + } + catch (SqsException | StsException e) { + LOG.error("Error reading from SQS: {}. Retrying with exponential backoff.", e.getMessage()); + applyBackoff(); + return Collections.emptyList(); + } + } + + private ReceiveMessageRequest createReceiveMessageRequest(String queueUrl, Integer maxNumberOfMessages, Duration waitTime, Duration visibilityTimeout) { + ReceiveMessageRequest.Builder requestBuilder = ReceiveMessageRequest.builder() + .queueUrl(queueUrl) + .attributeNamesWithStrings("All") + .messageAttributeNames("All"); + + if (waitTime != null) { + requestBuilder.waitTimeSeconds((int) waitTime.getSeconds()); + } + if (maxNumberOfMessages != null) { + requestBuilder.maxNumberOfMessages(maxNumberOfMessages); + } + if (visibilityTimeout != null) { + requestBuilder.visibilityTimeout((int) visibilityTimeout.getSeconds()); + } + return requestBuilder.build(); + } + + public void applyBackoff() { + final long delayMillis = standardBackoff.nextDelayMillis(++failedAttemptCount); + if (delayMillis < 0) { + Thread.currentThread().interrupt(); + throw new SqsRetriesExhaustedException("SQS retries exhausted. Check your SQS configuration."); + } + + final Duration delayDuration = Duration.ofMillis(delayMillis); + LOG.info("Pausing SQS processing for {}.{} seconds due to an error.", + delayDuration.getSeconds(), delayDuration.toMillisPart()); + + try { + Thread.sleep(delayMillis); + } catch (InterruptedException e) { + LOG.error("Thread interrupted during SQS backoff sleep.", e); + Thread.currentThread().interrupt(); + } + } + + public void deleteSqsMessages(final String queueUrl, final List entries) { + if (entries == null || entries.isEmpty() || isStopped) { + return; + } + + try { + final DeleteMessageBatchRequest request = DeleteMessageBatchRequest.builder() + .queueUrl(queueUrl) + .entries(entries) + .build(); + + final DeleteMessageBatchResponse response = sqsClient.deleteMessageBatch(request); + + if (response.hasSuccessful()) { + final int successCount = response.successful().size(); + sqsMessagesDeletedCounter.increment(successCount); + LOG.debug("Deleted {} messages from SQS queue [{}]", successCount, queueUrl); + } + if (response.hasFailed()) { + final int failCount = response.failed().size(); + sqsMessagesDeleteFailedCounter.increment(failCount); + LOG.error("Failed to delete {} messages from SQS queue [{}].", failCount, queueUrl); + } + } catch (SdkException e) { + sqsMessagesDeleteFailedCounter.increment(entries.size()); + LOG.error("Failed to delete messages from SQS queue [{}]: {}", queueUrl, e.getMessage()); + } + } + + public void increaseVisibilityTimeout(final String queueUrl, + final String receiptHandle, + final int newVisibilityTimeoutSeconds, + final String messageIdForLogging) { + if (isStopped) { + LOG.info("Skipping visibility timeout extension because worker is stopping. ID: {}", messageIdForLogging); + return; + } + + try { + ChangeMessageVisibilityRequest request = ChangeMessageVisibilityRequest.builder() + .queueUrl(queueUrl) + .receiptHandle(receiptHandle) + .visibilityTimeout(newVisibilityTimeoutSeconds) + .build(); + + sqsClient.changeMessageVisibility(request); + sqsVisibilityTimeoutChangedCount.increment(); + LOG.debug("Set visibility timeout for message {} to {} seconds", messageIdForLogging, newVisibilityTimeoutSeconds); + } + catch (Exception e) { + sqsVisibilityTimeoutChangeFailedCount.increment(); + LOG.error("Failed to set visibility timeout for message {} to {}. Reason: {}", + messageIdForLogging, newVisibilityTimeoutSeconds, e.getMessage()); + } + } + + public DeleteMessageBatchRequestEntry buildDeleteMessageBatchRequestEntry(final String messageId, + final String receiptHandle) { + return DeleteMessageBatchRequestEntry.builder() + .id(messageId) + .receiptHandle(receiptHandle) + .build(); + } + + public Timer createTimer(final String timerName) { + return pluginMetrics.timer(timerName); + } + + public Counter getSqsMessagesFailedCounter() { + return sqsMessagesFailedCounter; + } + + public void stop() { + isStopped = true; + } +} diff --git a/data-prepper-plugins/sqs-common/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsBackoffTest.java b/data-prepper-plugins/sqs-common/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsBackoffTest.java new file mode 100644 index 0000000000..b793cf39bc --- /dev/null +++ b/data-prepper-plugins/sqs-common/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsBackoffTest.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.dataprepper.plugins.source.sqs.common; + +import com.linecorp.armeria.client.retry.Backoff; +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class SqsBackoffTest { + + @Test + void testCreateExponentialBackoff() { + final Backoff backoff = SqsBackoff.createExponentialBackoff(); + assertNotNull(backoff, "Backoff should not be null"); + final long firstDelay = backoff.nextDelayMillis(1); + final long expectedBaseDelay = 20_000L; + final double jitterRate = 0.20; + final long minDelay = (long) (expectedBaseDelay * (1 - jitterRate)); + final long maxDelay = (long) (expectedBaseDelay * (1 + jitterRate)); + + assertTrue( + firstDelay >= minDelay && firstDelay <= maxDelay, + String.format("First delay %dms should be between %dms and %dms", + firstDelay, minDelay, maxDelay) + ); + } +} diff --git a/data-prepper-plugins/sqs-common/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsClientFactoryTest.java b/data-prepper-plugins/sqs-common/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsClientFactoryTest.java new file mode 100644 index 0000000000..5f2f64e48e --- /dev/null +++ b/data-prepper-plugins/sqs-common/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsClientFactoryTest.java @@ -0,0 +1,30 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.dataprepper.plugins.source.sqs.common; + +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.sqs.SqsClient; + +import static org.junit.jupiter.api.Assertions.assertNotNull; + +class SqsClientFactoryTest { + + @Test + void testCreateSqsClientReturnsNonNull() { + final StaticCredentialsProvider credentialsProvider = + StaticCredentialsProvider.create(AwsBasicCredentials.create("testKey", "testSecret")); + + final SqsClient sqsClient = SqsClientFactory.createSqsClient(Region.US_EAST_1, credentialsProvider); + assertNotNull(sqsClient, "SqsClient should not be null"); + } +} diff --git a/data-prepper-plugins/sqs-common/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsWorkerCommonTest.java b/data-prepper-plugins/sqs-common/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsWorkerCommonTest.java new file mode 100644 index 0000000000..32cbf3fd69 --- /dev/null +++ b/data-prepper-plugins/sqs-common/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsWorkerCommonTest.java @@ -0,0 +1,106 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.source.sqs.common; + +import com.linecorp.armeria.client.retry.Backoff; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Timer; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; +import software.amazon.awssdk.services.sqs.SqsClient; +import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequest; +import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequestEntry; +import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchResponse; +import software.amazon.awssdk.services.sqs.model.ReceiveMessageRequest; +import software.amazon.awssdk.services.sqs.model.ReceiveMessageResponse; + +import java.time.Duration; +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +class SqsWorkerCommonTest { + private SqsClient sqsClient; + private Backoff backoff; + private PluginMetrics pluginMetrics; + private AcknowledgementSetManager acknowledgementSetManager; + private SqsWorkerCommon sqsWorkerCommon; + + @BeforeEach + void setUp() { + sqsClient = Mockito.mock(SqsClient.class); + backoff = Mockito.mock(Backoff.class); + pluginMetrics = Mockito.mock(PluginMetrics.class); + acknowledgementSetManager = Mockito.mock(AcknowledgementSetManager.class); + when(pluginMetrics.counter(Mockito.anyString())).thenReturn(Mockito.mock(Counter.class)); + when(pluginMetrics.timer(Mockito.anyString())).thenReturn(Mockito.mock(Timer.class)); + sqsWorkerCommon = new SqsWorkerCommon(sqsClient, backoff, pluginMetrics, acknowledgementSetManager); + } + + @Test + void testPollSqsMessages_handlesEmptyList() { + when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))) + .thenReturn(ReceiveMessageResponse.builder() + .messages(Collections.emptyList()) + .build()); + var messages = sqsWorkerCommon.pollSqsMessages( + "testQueueUrl", + 10, + Duration.ofSeconds(5), + Duration.ofSeconds(30) + ); + + assertNotNull(messages); + assertTrue(messages.isEmpty()); + Mockito.verify(sqsClient).receiveMessage(any(ReceiveMessageRequest.class)); + Mockito.verify(backoff, Mockito.never()).nextDelayMillis(Mockito.anyInt()); + } + + @Test + void testDeleteSqsMessages_callsClientWhenNotStopped() { + var entries = Collections.singletonList( + DeleteMessageBatchRequestEntry.builder() + .id("msg-id") + .receiptHandle("receipt-handle") + .build() + ); + + when(sqsClient.deleteMessageBatch(any(DeleteMessageBatchRequest.class))) + .thenReturn(DeleteMessageBatchResponse.builder().build()); + + sqsWorkerCommon.deleteSqsMessages("testQueueUrl", entries); + ArgumentCaptor captor = + ArgumentCaptor.forClass(DeleteMessageBatchRequest.class); + Mockito.verify(sqsClient).deleteMessageBatch(captor.capture()); + assertEquals("testQueueUrl", captor.getValue().queueUrl()); + assertEquals(1, captor.getValue().entries().size()); + } + + @Test + void testStop_skipsFurtherOperations() { + sqsWorkerCommon.stop(); + sqsWorkerCommon.deleteSqsMessages("testQueueUrl", Collections.singletonList( + DeleteMessageBatchRequestEntry.builder() + .id("msg-id") + .receiptHandle("receipt-handle") + .build() + )); + Mockito.verify(sqsClient, Mockito.never()).deleteMessageBatch((DeleteMessageBatchRequest) any()); + } +} diff --git a/data-prepper-plugins/sqs-source/build.gradle b/data-prepper-plugins/sqs-source/build.gradle index b4ffbc8e5e..7a8ce38f29 100644 --- a/data-prepper-plugins/sqs-source/build.gradle +++ b/data-prepper-plugins/sqs-source/build.gradle @@ -11,6 +11,7 @@ dependencies { implementation project(':data-prepper-api') implementation project(':data-prepper-plugins:buffer-common') implementation project(':data-prepper-plugins:common') + implementation project(':data-prepper-plugins:sqs-common') implementation libs.armeria.core implementation project(':data-prepper-plugins:aws-plugin-api') implementation 'software.amazon.awssdk:sqs' diff --git a/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/SqsService.java b/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/SqsService.java index 672ee9874c..5c3e3ba2d2 100644 --- a/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/SqsService.java +++ b/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/SqsService.java @@ -18,16 +18,16 @@ import org.opensearch.dataprepper.model.configuration.PluginModel; import org.opensearch.dataprepper.model.configuration.PluginSetting; import org.opensearch.dataprepper.model.plugin.PluginFactory; + import org.opensearch.dataprepper.plugins.source.sqs.common.SqsBackoff; + import org.opensearch.dataprepper.plugins.source.sqs.common.SqsClientFactory; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; - import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; - import software.amazon.awssdk.core.retry.RetryPolicy; import software.amazon.awssdk.services.sqs.SqsClient; import org.opensearch.dataprepper.model.buffer.Buffer; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; - import java.time.Duration; + import java.util.ArrayList; import java.util.List; import java.util.concurrent.TimeUnit; @@ -39,10 +39,6 @@ public class SqsService { private static final Logger LOG = LoggerFactory.getLogger(SqsService.class); static final long SHUTDOWN_TIMEOUT = 30L; - static final long INITIAL_DELAY = Duration.ofSeconds(20).toMillis(); - static final long MAXIMUM_DELAY = Duration.ofMinutes(5).toMillis(); - static final double JITTER_RATE = 0.20; - private final SqsSourceConfig sqsSourceConfig; private final SqsClient sqsClient; private final PluginMetrics pluginMetrics; @@ -51,6 +47,7 @@ public class SqsService { private final List allSqsUrlExecutorServices; private final List sqsWorkers; private final Buffer> buffer; + private final Backoff backoff; public SqsService(final Buffer> buffer, final AcknowledgementSetManager acknowledgementSetManager, @@ -65,17 +62,13 @@ public SqsService(final Buffer> buffer, this.acknowledgementSetManager = acknowledgementSetManager; this.allSqsUrlExecutorServices = new ArrayList<>(); this.sqsWorkers = new ArrayList<>(); - this.sqsClient = createSqsClient(credentialsProvider); + this.sqsClient = SqsClientFactory.createSqsClient(sqsSourceConfig.getAwsAuthenticationOptions().getAwsRegion(), credentialsProvider); this.buffer = buffer; + backoff = SqsBackoff.createExponentialBackoff(); } - public void start() { - final Backoff backoff = Backoff.exponential(INITIAL_DELAY, MAXIMUM_DELAY).withJitter(JITTER_RATE) - .withMaxAttempts(Integer.MAX_VALUE); - LOG.info("Starting SqsService"); - sqsSourceConfig.getQueues().forEach(queueConfig -> { String queueUrl = queueConfig.getUrl(); String queueName = queueUrl.substring(queueUrl.lastIndexOf('/') + 1); @@ -112,17 +105,7 @@ public void start() { LOG.info("Started SQS workers for queue {} with {} workers", queueUrl, numWorkers); }); } - - SqsClient createSqsClient(final AwsCredentialsProvider credentialsProvider) { - LOG.debug("Creating SQS client"); - return SqsClient.builder() - .region(sqsSourceConfig.getAwsAuthenticationOptions().getAwsRegion()) - .credentialsProvider(credentialsProvider) - .overrideConfiguration(ClientOverrideConfiguration.builder() - .retryPolicy(RetryPolicy.builder().numRetries(5).build()) - .build()) - .build(); - } + public void stop() { allSqsUrlExecutorServices.forEach(ExecutorService::shutdown); diff --git a/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/SqsWorker.java b/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/SqsWorker.java index f6de0b9ee1..cb4c168345 100644 --- a/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/SqsWorker.java +++ b/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/SqsWorker.java @@ -17,58 +17,35 @@ import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.plugins.source.sqs.common.SqsWorkerCommon; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import software.amazon.awssdk.core.exception.SdkException; import software.amazon.awssdk.services.sqs.SqsClient; -import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityRequest; -import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequest; import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequestEntry; -import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchResponse; import software.amazon.awssdk.services.sqs.model.Message; -import software.amazon.awssdk.services.sqs.model.ReceiveMessageRequest; -import software.amazon.awssdk.services.sqs.model.ReceiveMessageResponse; -import software.amazon.awssdk.services.sqs.model.SqsException; -import software.amazon.awssdk.services.sts.model.StsException; -import org.opensearch.dataprepper.model.buffer.Buffer; import java.time.Duration; import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; - public class SqsWorker implements Runnable { private static final Logger LOG = LoggerFactory.getLogger(SqsWorker.class); - static final String SQS_MESSAGES_RECEIVED_METRIC_NAME = "sqsMessagesReceived"; - static final String SQS_MESSAGES_DELETED_METRIC_NAME = "sqsMessagesDeleted"; - static final String SQS_MESSAGES_FAILED_METRIC_NAME = "sqsMessagesFailed"; - static final String SQS_MESSAGES_DELETE_FAILED_METRIC_NAME = "sqsMessagesDeleteFailed"; - static final String SQS_VISIBILITY_TIMEOUT_CHANGED_COUNT_METRIC_NAME = "sqsVisibilityTimeoutChangedCount"; - static final String SQS_VISIBILITY_TIMEOUT_CHANGE_FAILED_COUNT_METRIC_NAME = "sqsVisibilityTimeoutChangeFailedCount"; static final String ACKNOWLEDGEMENT_SET_CALLACK_METRIC_NAME = "acknowledgementSetCallbackCounter"; - - private final SqsClient sqsClient; + private final SqsWorkerCommon sqsWorkerCommon; private final SqsEventProcessor sqsEventProcessor; - private final Counter sqsMessagesReceivedCounter; - private final Counter sqsMessagesDeletedCounter; - private final Counter sqsMessagesFailedCounter; - private final Counter sqsMessagesDeleteFailedCounter; - private final Counter acknowledgementSetCallbackCounter; - private final Counter sqsVisibilityTimeoutChangedCount; - private final Counter sqsVisibilityTimeoutChangeFailedCount; - private final Backoff standardBackoff; private final QueueConfig queueConfig; - private int failedAttemptCount; private final boolean endToEndAcknowledgementsEnabled; - private final AcknowledgementSetManager acknowledgementSetManager; - private volatile boolean isStopped = false; private final Buffer> buffer; private final int bufferTimeoutMillis; - private Map messageVisibilityTimesMap; + private final AcknowledgementSetManager acknowledgementSetManager; + private final Counter acknowledgementSetCallbackCounter; + private int failedAttemptCount; + private volatile boolean isStopped = false; + private final Map messageVisibilityTimesMap; public SqsWorker(final Buffer> buffer, final AcknowledgementSetManager acknowledgementSetManager, @@ -78,24 +55,16 @@ public SqsWorker(final Buffer> buffer, final PluginMetrics pluginMetrics, final SqsEventProcessor sqsEventProcessor, final Backoff backoff) { - - this.sqsClient = sqsClient; + this.sqsWorkerCommon = new SqsWorkerCommon(sqsClient, backoff, pluginMetrics, acknowledgementSetManager); this.queueConfig = queueConfig; this.acknowledgementSetManager = acknowledgementSetManager; - this.standardBackoff = backoff; - this.endToEndAcknowledgementsEnabled = sqsSourceConfig.getAcknowledgements(); + this.sqsEventProcessor = sqsEventProcessor; this.buffer = buffer; this.bufferTimeoutMillis = (int) sqsSourceConfig.getBufferTimeout().toMillis(); - this.sqsEventProcessor = sqsEventProcessor; - messageVisibilityTimesMap = new HashMap<>(); - failedAttemptCount = 0; - sqsMessagesReceivedCounter = pluginMetrics.counter(SQS_MESSAGES_RECEIVED_METRIC_NAME); - sqsMessagesDeletedCounter = pluginMetrics.counter(SQS_MESSAGES_DELETED_METRIC_NAME); - sqsMessagesFailedCounter = pluginMetrics.counter(SQS_MESSAGES_FAILED_METRIC_NAME); - sqsMessagesDeleteFailedCounter = pluginMetrics.counter(SQS_MESSAGES_DELETE_FAILED_METRIC_NAME); - acknowledgementSetCallbackCounter = pluginMetrics.counter(ACKNOWLEDGEMENT_SET_CALLACK_METRIC_NAME); - sqsVisibilityTimeoutChangedCount = pluginMetrics.counter(SQS_VISIBILITY_TIMEOUT_CHANGED_COUNT_METRIC_NAME); - sqsVisibilityTimeoutChangeFailedCount = pluginMetrics.counter(SQS_VISIBILITY_TIMEOUT_CHANGE_FAILED_COUNT_METRIC_NAME); + this.endToEndAcknowledgementsEnabled = sqsSourceConfig.getAcknowledgements(); + this.messageVisibilityTimesMap = new HashMap<>(); + this.failedAttemptCount = 0; + this.acknowledgementSetCallbackCounter = pluginMetrics.counter(ACKNOWLEDGEMENT_SET_CALLACK_METRIC_NAME); } @Override @@ -104,10 +73,9 @@ public void run() { int messagesProcessed = 0; try { messagesProcessed = processSqsMessages(); - } catch (final Exception e) { LOG.error("Unable to process SQS messages. Processing error due to: {}", e.getMessage()); - applyBackoff(); + sqsWorkerCommon.applyBackoff(); } if (messagesProcessed > 0 && queueConfig.getPollDelay().toMillis() > 0) { @@ -115,216 +83,103 @@ public void run() { Thread.sleep(queueConfig.getPollDelay().toMillis()); } catch (final InterruptedException e) { LOG.error("Thread is interrupted while polling SQS.", e); + Thread.currentThread().interrupt(); } } } } int processSqsMessages() { - final List messages = getMessagesFromSqs(); + List messages = sqsWorkerCommon.pollSqsMessages(queueConfig.getUrl(), + queueConfig.getMaximumMessages(), + queueConfig.getWaitTime(), + queueConfig.getVisibilityTimeout()); if (!messages.isEmpty()) { - sqsMessagesReceivedCounter.increment(messages.size()); final List deleteMessageBatchRequestEntries = processSqsEvents(messages); if (!deleteMessageBatchRequestEntries.isEmpty()) { - deleteSqsMessages(deleteMessageBatchRequestEntries); + sqsWorkerCommon.deleteSqsMessages(queueConfig.getUrl(), deleteMessageBatchRequestEntries); } } return messages.size(); } - private List getMessagesFromSqs() { - try { - final ReceiveMessageRequest request = createReceiveMessageRequest(); - final ReceiveMessageResponse response = sqsClient.receiveMessage(request); - List messages = response.messages(); - failedAttemptCount = 0; - return messages; - - } catch (final SqsException | StsException e) { - LOG.error("Error reading from SQS: {}. Retrying with exponential backoff.", e.getMessage()); - applyBackoff(); - return Collections.emptyList(); - } - } - - private void applyBackoff() { - final long delayMillis = standardBackoff.nextDelayMillis(++failedAttemptCount); - if (delayMillis < 0) { - Thread.currentThread().interrupt(); - throw new SqsRetriesExhaustedException("SQS retries exhausted. Make sure that SQS configuration is valid, SQS queue exists, and IAM role has required permissions."); - } - final Duration delayDuration = Duration.ofMillis(delayMillis); - LOG.info("Pausing SQS processing for {}.{} seconds due to an error in processing.", - delayDuration.getSeconds(), delayDuration.toMillisPart()); - try { - Thread.sleep(delayMillis); - } catch (final InterruptedException e){ - LOG.error("Thread is interrupted while polling SQS with retry.", e); - } - } - - private ReceiveMessageRequest createReceiveMessageRequest() { - ReceiveMessageRequest.Builder requestBuilder = ReceiveMessageRequest.builder() - .queueUrl(queueConfig.getUrl()) - .attributeNamesWithStrings("All") - .messageAttributeNames("All"); - - if (queueConfig.getWaitTime() != null) { - requestBuilder.waitTimeSeconds((int) queueConfig.getWaitTime().getSeconds()); - } - if (queueConfig.getMaximumMessages() != null) { - requestBuilder.maxNumberOfMessages(queueConfig.getMaximumMessages()); - } - if (queueConfig.getVisibilityTimeout() != null) { - requestBuilder.visibilityTimeout((int) queueConfig.getVisibilityTimeout().getSeconds()); - } - return requestBuilder.build(); - } - private List processSqsEvents(final List messages) { final List deleteMessageBatchRequestEntryCollection = new ArrayList<>(); final Map messageAcknowledgementSetMap = new HashMap<>(); final Map> messageWaitingForAcknowledgementsMap = new HashMap<>(); - + for (Message message : messages) { List waitingForAcknowledgements = new ArrayList<>(); AcknowledgementSet acknowledgementSet = null; - final int visibilityTimeout; - if (queueConfig.getVisibilityTimeout() != null) { - visibilityTimeout = (int) queueConfig.getVisibilityTimeout().getSeconds(); - } else { - visibilityTimeout = (int) Duration.ofSeconds(30).getSeconds(); - - } + final int visibilityTimeout = queueConfig.getVisibilityTimeout() != null + ? (int) queueConfig.getVisibilityTimeout().getSeconds() + : 30; - final int maxVisibilityTimeout = (int)queueConfig.getVisibilityDuplicateProtectionTimeout().getSeconds(); - final int progressCheckInterval = visibilityTimeout/2 - 1; + final int maxVisibilityTimeout = (int) queueConfig.getVisibilityDuplicateProtectionTimeout().getSeconds(); + final int progressCheckInterval = visibilityTimeout / 2 - 1; if (endToEndAcknowledgementsEnabled) { - int expiryTimeout = visibilityTimeout - 2; - final boolean visibilityDuplicateProtectionEnabled = queueConfig.getVisibilityDuplicateProtection(); - if (visibilityDuplicateProtectionEnabled) { - expiryTimeout = maxVisibilityTimeout; - } - acknowledgementSet = acknowledgementSetManager.create( - (result) -> { - acknowledgementSetCallbackCounter.increment(); - // Delete only if this is positive acknowledgement - if (visibilityDuplicateProtectionEnabled) { - messageVisibilityTimesMap.remove(message); + int expiryTimeout = queueConfig.getVisibilityDuplicateProtection() + ? maxVisibilityTimeout + : visibilityTimeout - 2; + acknowledgementSet = acknowledgementSetManager.create(result -> { + acknowledgementSetCallbackCounter.increment(); + if (queueConfig.getVisibilityDuplicateProtection()) { + messageVisibilityTimesMap.remove(message); + } + if (result) { + sqsWorkerCommon.deleteSqsMessages(queueConfig.getUrl(), waitingForAcknowledgements); + } + }, Duration.ofSeconds(expiryTimeout)); + if (queueConfig.getVisibilityDuplicateProtection()) { + acknowledgementSet.addProgressCheck(ratio -> { + int newValue = messageVisibilityTimesMap.getOrDefault(message, visibilityTimeout) + progressCheckInterval; + if (newValue >= maxVisibilityTimeout) { + return; } - if (result) { - deleteSqsMessages(waitingForAcknowledgements); - } - }, - Duration.ofSeconds(expiryTimeout)); - if (visibilityDuplicateProtectionEnabled) { - acknowledgementSet.addProgressCheck( - (ratio) -> { - int newValue = messageVisibilityTimesMap.getOrDefault(message, visibilityTimeout) + progressCheckInterval; - if (newValue >= maxVisibilityTimeout) { - return; - } - messageVisibilityTimesMap.put(message, newValue); - final int newVisibilityTimeoutSeconds = visibilityTimeout; - increaseVisibilityTimeout(message, newVisibilityTimeoutSeconds); - }, - Duration.ofSeconds(progressCheckInterval)); + messageVisibilityTimesMap.put(message, newValue); + sqsWorkerCommon.increaseVisibilityTimeout(queueConfig.getUrl(), + message.receiptHandle(), + visibilityTimeout, + message.messageId()); + }, Duration.ofSeconds(progressCheckInterval)); } messageAcknowledgementSetMap.put(message, acknowledgementSet); messageWaitingForAcknowledgementsMap.put(message, waitingForAcknowledgements); } } - - if (endToEndAcknowledgementsEnabled) { - LOG.debug("Created acknowledgement sets for {} messages.", messages.size()); - } for (Message message : messages) { final AcknowledgementSet acknowledgementSet = messageAcknowledgementSetMap.get(message); final List waitingForAcknowledgements = messageWaitingForAcknowledgementsMap.get(message); - final Optional deleteMessageBatchRequestEntry = processSqsObject(message, acknowledgementSet); + final Optional deleteEntry = processSqsObject(message, acknowledgementSet); if (endToEndAcknowledgementsEnabled) { - deleteMessageBatchRequestEntry.ifPresent(waitingForAcknowledgements::add); - acknowledgementSet.complete(); + deleteEntry.ifPresent(waitingForAcknowledgements::add); + if (acknowledgementSet != null) { + acknowledgementSet.complete(); + } } else { - deleteMessageBatchRequestEntry.ifPresent(deleteMessageBatchRequestEntryCollection::add); + deleteEntry.ifPresent(deleteMessageBatchRequestEntryCollection::add); } } - return deleteMessageBatchRequestEntryCollection; } - - private Optional processSqsObject( - final Message message, - final AcknowledgementSet acknowledgementSet) { + private Optional processSqsObject(final Message message, + final AcknowledgementSet acknowledgementSet) { try { sqsEventProcessor.addSqsObject(message, queueConfig.getUrl(), buffer, bufferTimeoutMillis, acknowledgementSet); - return Optional.of(buildDeleteMessageBatchRequestEntry(message)); + return Optional.of(sqsWorkerCommon.buildDeleteMessageBatchRequestEntry(message.messageId(), message.receiptHandle())); } catch (final Exception e) { - sqsMessagesFailedCounter.increment(); + sqsWorkerCommon.getSqsMessagesFailedCounter().increment(); LOG.error("Error processing from SQS: {}. Retrying with exponential backoff.", e.getMessage()); - applyBackoff(); + sqsWorkerCommon.applyBackoff(); return Optional.empty(); } } - private void increaseVisibilityTimeout(final Message message, final int newVisibilityTimeoutSeconds) { - if(isStopped) { - LOG.info("Some messages are pending completion of acknowledgments. Data Prepper will not increase the visibility timeout because it is shutting down. {}", message); - return; - } - final ChangeMessageVisibilityRequest changeMessageVisibilityRequest = ChangeMessageVisibilityRequest.builder() - .visibilityTimeout(newVisibilityTimeoutSeconds) - .queueUrl(queueConfig.getUrl()) - .receiptHandle(message.receiptHandle()) - .build(); - - try { - sqsClient.changeMessageVisibility(changeMessageVisibilityRequest); - sqsVisibilityTimeoutChangedCount.increment(); - LOG.debug("Set visibility timeout for message {} to {}", message.messageId(), newVisibilityTimeoutSeconds); - } catch (Exception e) { - LOG.error("Failed to set visibility timeout for message {} to {}", message.messageId(), newVisibilityTimeoutSeconds, e); - sqsVisibilityTimeoutChangeFailedCount.increment(); - } - } - - - private DeleteMessageBatchRequestEntry buildDeleteMessageBatchRequestEntry(Message message) { - return DeleteMessageBatchRequestEntry.builder() - .id(message.messageId()) - .receiptHandle(message.receiptHandle()) - .build(); - } - - private void deleteSqsMessages(final List deleteEntries) { - if (deleteEntries.isEmpty()) return; - - try { - DeleteMessageBatchRequest deleteRequest = DeleteMessageBatchRequest.builder() - .queueUrl(queueConfig.getUrl()) - .entries(deleteEntries) - .build(); - DeleteMessageBatchResponse response = sqsClient.deleteMessageBatch(deleteRequest); - - if (response.hasSuccessful()) { - int successfulDeletes = response.successful().size(); - sqsMessagesDeletedCounter.increment(successfulDeletes); - } - if (response.hasFailed()) { - int failedDeletes = response.failed().size(); - sqsMessagesDeleteFailedCounter.increment(failedDeletes); - LOG.error("Failed to delete {} messages from SQS.", failedDeletes); - } - } catch (SdkException e) { - LOG.error("Failed to delete messages from SQS: {}", e.getMessage()); - sqsMessagesDeleteFailedCounter.increment(deleteEntries.size()); - } - } - void stop() { isStopped = true; + sqsWorkerCommon.stop(); } -} \ No newline at end of file +} diff --git a/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/SqsServiceTest.java b/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/SqsServiceTest.java index 695164db82..83a12e5940 100644 --- a/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/SqsServiceTest.java +++ b/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/SqsServiceTest.java @@ -22,11 +22,9 @@ import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.sqs.SqsClient; import java.util.List; -import static org.mockito.Mockito.doReturn; + import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.mockito.Mockito.withSettings; @@ -62,26 +60,19 @@ void start_with_single_queue_starts_workers() { when(queueConfig.getNumWorkers()).thenReturn(2); when(sqsSourceConfig.getQueues()).thenReturn(List.of(queueConfig)); SqsService sqsService = spy(new SqsService(buffer, acknowledgementSetManager, sqsSourceConfig, pluginMetrics, pluginFactory, credentialsProvider)); - doReturn(sqsClient).when(sqsService).createSqsClient(credentialsProvider); sqsService.start(); // if no exception is thrown here, then workers have been started } @Test - void stop_should_shutdown_executors_and_workers_and_close_client() throws InterruptedException { + void stop_should_shutdown_executors_and_workers() throws InterruptedException { QueueConfig queueConfig = mock(QueueConfig.class); when(queueConfig.getUrl()).thenReturn("MyQueue"); when(queueConfig.getNumWorkers()).thenReturn(1); when(sqsSourceConfig.getQueues()).thenReturn(List.of(queueConfig)); SqsClient sqsClient = mock(SqsClient.class); - SqsService sqsService = new SqsService(buffer, acknowledgementSetManager, sqsSourceConfig, pluginMetrics, pluginFactory, credentialsProvider) { - @Override - SqsClient createSqsClient(final AwsCredentialsProvider credentialsProvider) { - return sqsClient; - } - }; + SqsService sqsService = new SqsService(buffer, acknowledgementSetManager, sqsSourceConfig, pluginMetrics, pluginFactory, credentialsProvider) {}; sqsService.start(); - sqsService.stop(); - verify(sqsClient, times(1)).close(); + sqsService.stop(); // again assuming that if no exception is thrown here, then workers and client have been stopped } } \ No newline at end of file diff --git a/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/SqsWorkerTest.java b/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/SqsWorkerTest.java index 22bf48596f..e7339543c2 100644 --- a/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/SqsWorkerTest.java +++ b/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/SqsWorkerTest.java @@ -26,6 +26,8 @@ import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.plugin.PluginFactory; import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.plugins.source.sqs.common.SqsWorkerCommon; +import org.opensearch.dataprepper.plugins.source.sqs.common.SqsRetriesExhaustedException; import software.amazon.awssdk.services.sqs.SqsClient; import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityRequest; import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequest; @@ -111,13 +113,20 @@ private SqsWorker createObjectUnderTest() { @BeforeEach void setUp() { - when(pluginMetrics.counter(SqsWorker.SQS_MESSAGES_RECEIVED_METRIC_NAME)).thenReturn(sqsMessagesReceivedCounter); - when(pluginMetrics.counter(SqsWorker.SQS_MESSAGES_DELETED_METRIC_NAME)).thenReturn(sqsMessagesDeletedCounter); - when(pluginMetrics.counter(SqsWorker.SQS_MESSAGES_FAILED_METRIC_NAME)).thenReturn(sqsMessagesFailedCounter); - when(pluginMetrics.counter(SqsWorker.SQS_MESSAGES_DELETE_FAILED_METRIC_NAME)).thenReturn(sqsMessagesDeleteFailedCounter); - when(pluginMetrics.counter(SqsWorker.ACKNOWLEDGEMENT_SET_CALLACK_METRIC_NAME)).thenReturn(acknowledgementSetCallbackCounter); - when(pluginMetrics.counter(SqsWorker.SQS_VISIBILITY_TIMEOUT_CHANGED_COUNT_METRIC_NAME)).thenReturn(sqsVisibilityTimeoutChangedCount); - when(pluginMetrics.counter(SqsWorker.SQS_VISIBILITY_TIMEOUT_CHANGE_FAILED_COUNT_METRIC_NAME)).thenReturn(sqsVisibilityTimeoutChangeFailedCount); + when(pluginMetrics.counter(SqsWorkerCommon.SQS_MESSAGES_RECEIVED_METRIC_NAME)) + .thenReturn(sqsMessagesReceivedCounter); + when(pluginMetrics.counter(SqsWorkerCommon.SQS_MESSAGES_DELETED_METRIC_NAME)) + .thenReturn(sqsMessagesDeletedCounter); + when(pluginMetrics.counter(SqsWorkerCommon.SQS_MESSAGES_FAILED_METRIC_NAME)) + .thenReturn(sqsMessagesFailedCounter); + when(pluginMetrics.counter(SqsWorkerCommon.SQS_MESSAGES_DELETE_FAILED_METRIC_NAME)) + .thenReturn(sqsMessagesDeleteFailedCounter); + when(pluginMetrics.counter(SqsWorkerCommon.ACKNOWLEDGEMENT_SET_CALLACK_METRIC_NAME)) + .thenReturn(acknowledgementSetCallbackCounter); + when(pluginMetrics.counter(SqsWorkerCommon.SQS_VISIBILITY_TIMEOUT_CHANGED_COUNT_METRIC_NAME)) + .thenReturn(sqsVisibilityTimeoutChangedCount); + when(pluginMetrics.counter(SqsWorkerCommon.SQS_VISIBILITY_TIMEOUT_CHANGE_FAILED_COUNT_METRIC_NAME)) + .thenReturn(sqsVisibilityTimeoutChangeFailedCount); when(sqsSourceConfig.getAcknowledgements()).thenReturn(false); when(sqsSourceConfig.getBufferTimeout()).thenReturn(Duration.ofSeconds(10)); when(queueConfig.getUrl()).thenReturn("https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue"); diff --git a/settings.gradle b/settings.gradle index d86bc7e1da..d2aa09b52c 100644 --- a/settings.gradle +++ b/settings.gradle @@ -167,6 +167,7 @@ include 'data-prepper-plugins:parquet-codecs' include 'data-prepper-plugins:aws-sqs-common' include 'data-prepper-plugins:buffer-common' include 'data-prepper-plugins:sqs-source' +include 'data-prepper-plugins:sqs-common' //include 'data-prepper-plugins:cloudwatch-logs' //include 'data-prepper-plugins:http-sink' //include 'data-prepper-plugins:sns-sink'