diff --git a/data-prepper-plugins/sqs-common/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsBackoffTest.java b/data-prepper-plugins/sqs-common/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsBackoffTest.java new file mode 100644 index 0000000000..b793cf39bc --- /dev/null +++ b/data-prepper-plugins/sqs-common/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsBackoffTest.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.dataprepper.plugins.source.sqs.common; + +import com.linecorp.armeria.client.retry.Backoff; +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class SqsBackoffTest { + + @Test + void testCreateExponentialBackoff() { + final Backoff backoff = SqsBackoff.createExponentialBackoff(); + assertNotNull(backoff, "Backoff should not be null"); + final long firstDelay = backoff.nextDelayMillis(1); + final long expectedBaseDelay = 20_000L; + final double jitterRate = 0.20; + final long minDelay = (long) (expectedBaseDelay * (1 - jitterRate)); + final long maxDelay = (long) (expectedBaseDelay * (1 + jitterRate)); + + assertTrue( + firstDelay >= minDelay && firstDelay <= maxDelay, + String.format("First delay %dms should be between %dms and %dms", + firstDelay, minDelay, maxDelay) + ); + } +} diff --git a/data-prepper-plugins/sqs-common/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsClientFactoryTest.java b/data-prepper-plugins/sqs-common/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsClientFactoryTest.java new file mode 100644 index 0000000000..5f2f64e48e --- /dev/null +++ b/data-prepper-plugins/sqs-common/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsClientFactoryTest.java @@ -0,0 +1,30 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.dataprepper.plugins.source.sqs.common; + +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.sqs.SqsClient; + +import static org.junit.jupiter.api.Assertions.assertNotNull; + +class SqsClientFactoryTest { + + @Test + void testCreateSqsClientReturnsNonNull() { + final StaticCredentialsProvider credentialsProvider = + StaticCredentialsProvider.create(AwsBasicCredentials.create("testKey", "testSecret")); + + final SqsClient sqsClient = SqsClientFactory.createSqsClient(Region.US_EAST_1, credentialsProvider); + assertNotNull(sqsClient, "SqsClient should not be null"); + } +} diff --git a/data-prepper-plugins/sqs-common/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsWorkerCommonTest.java b/data-prepper-plugins/sqs-common/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsWorkerCommonTest.java new file mode 100644 index 0000000000..32cbf3fd69 --- /dev/null +++ b/data-prepper-plugins/sqs-common/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsWorkerCommonTest.java @@ -0,0 +1,106 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.source.sqs.common; + +import com.linecorp.armeria.client.retry.Backoff; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Timer; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; +import software.amazon.awssdk.services.sqs.SqsClient; +import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequest; +import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequestEntry; +import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchResponse; +import software.amazon.awssdk.services.sqs.model.ReceiveMessageRequest; +import software.amazon.awssdk.services.sqs.model.ReceiveMessageResponse; + +import java.time.Duration; +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +class SqsWorkerCommonTest { + private SqsClient sqsClient; + private Backoff backoff; + private PluginMetrics pluginMetrics; + private AcknowledgementSetManager acknowledgementSetManager; + private SqsWorkerCommon sqsWorkerCommon; + + @BeforeEach + void setUp() { + sqsClient = Mockito.mock(SqsClient.class); + backoff = Mockito.mock(Backoff.class); + pluginMetrics = Mockito.mock(PluginMetrics.class); + acknowledgementSetManager = Mockito.mock(AcknowledgementSetManager.class); + when(pluginMetrics.counter(Mockito.anyString())).thenReturn(Mockito.mock(Counter.class)); + when(pluginMetrics.timer(Mockito.anyString())).thenReturn(Mockito.mock(Timer.class)); + sqsWorkerCommon = new SqsWorkerCommon(sqsClient, backoff, pluginMetrics, acknowledgementSetManager); + } + + @Test + void testPollSqsMessages_handlesEmptyList() { + when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))) + .thenReturn(ReceiveMessageResponse.builder() + .messages(Collections.emptyList()) + .build()); + var messages = sqsWorkerCommon.pollSqsMessages( + "testQueueUrl", + 10, + Duration.ofSeconds(5), + Duration.ofSeconds(30) + ); + + assertNotNull(messages); + assertTrue(messages.isEmpty()); + Mockito.verify(sqsClient).receiveMessage(any(ReceiveMessageRequest.class)); + Mockito.verify(backoff, Mockito.never()).nextDelayMillis(Mockito.anyInt()); + } + + @Test + void testDeleteSqsMessages_callsClientWhenNotStopped() { + var entries = Collections.singletonList( + DeleteMessageBatchRequestEntry.builder() + .id("msg-id") + .receiptHandle("receipt-handle") + .build() + ); + + when(sqsClient.deleteMessageBatch(any(DeleteMessageBatchRequest.class))) + .thenReturn(DeleteMessageBatchResponse.builder().build()); + + sqsWorkerCommon.deleteSqsMessages("testQueueUrl", entries); + ArgumentCaptor captor = + ArgumentCaptor.forClass(DeleteMessageBatchRequest.class); + Mockito.verify(sqsClient).deleteMessageBatch(captor.capture()); + assertEquals("testQueueUrl", captor.getValue().queueUrl()); + assertEquals(1, captor.getValue().entries().size()); + } + + @Test + void testStop_skipsFurtherOperations() { + sqsWorkerCommon.stop(); + sqsWorkerCommon.deleteSqsMessages("testQueueUrl", Collections.singletonList( + DeleteMessageBatchRequestEntry.builder() + .id("msg-id") + .receiptHandle("receipt-handle") + .build() + )); + Mockito.verify(sqsClient, Mockito.never()).deleteMessageBatch((DeleteMessageBatchRequest) any()); + } +}