diff --git a/CHANGELOG.md b/CHANGELOG.md index aa6e7bce8655d..15a2764d810c6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Changed - Convert transport-reactor-netty4 to use gradle version catalog [#17233](https://github.com/opensearch-project/OpenSearch/pull/17233) - Increase force merge threads to 1/8th of cores [#17255](https://github.com/opensearch-project/OpenSearch/pull/17255) +- Avoid invalid retries on multiple replicas when querying [#17370](https://github.com/opensearch-project/OpenSearch/pull/17370) ### Deprecated diff --git a/libs/core/src/main/java/org/opensearch/OpenSearchException.java b/libs/core/src/main/java/org/opensearch/OpenSearchException.java index dda3983fbb4d1..38945fa411430 100644 --- a/libs/core/src/main/java/org/opensearch/OpenSearchException.java +++ b/libs/core/src/main/java/org/opensearch/OpenSearchException.java @@ -296,8 +296,12 @@ protected Map> getHeaders() { * Returns the rest status code associated with this exception. 
*/ public RestStatus status() { - Throwable cause = unwrapCause(); - if (cause == this) { + return status(this); + } + + public static RestStatus status(Throwable t) { + Throwable cause = ExceptionsHelper.unwrapCause(t); + if (cause == t) { return RestStatus.INTERNAL_SERVER_ERROR; } else { return ExceptionsHelper.status(cause); diff --git a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java index 85ea34e442c8f..7f7f080e165a6 100644 --- a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java @@ -514,10 +514,19 @@ private void onShardFailure(final int shardIndex, @Nullable SearchShardTarget sh // we do make sure to clean it on a successful response from a shard setPhaseResourceUsages(); onShardFailure(shardIndex, shard, e); - SearchShardTarget nextShard = FailAwareWeightedRouting.getInstance() - .findNext(shardIt, clusterState, e, () -> totalOps.incrementAndGet()); - final boolean lastShard = nextShard == null; + final SearchShardTarget nextShard; + final boolean lastShard; + final int advanceShardCount; + if (TransportActions.isRetryableSearchException(e)) { + nextShard = FailAwareWeightedRouting.getInstance().findNext(shardIt, clusterState, e, () -> totalOps.incrementAndGet()); + lastShard = nextShard == null; + advanceShardCount = 1; + } else { + nextShard = null; + lastShard = true; + advanceShardCount = remainingOpsOnIterator(shardIt); + } if (logger.isTraceEnabled()) { logger.trace( () -> new ParameterizedMessage( @@ -542,7 +551,7 @@ private void onShardFailure(final int shardIndex, @Nullable SearchShardTarget sh if (lastShard) { onShardGroupFailure(shardIndex, shard, e); } - final int totalOps = this.totalOps.incrementAndGet(); + final int totalOps = this.totalOps.addAndGet(advanceShardCount); if (totalOps == expectedTotalOps) { try { 
onPhaseDone(); @@ -561,6 +570,14 @@ private void onShardFailure(final int shardIndex, @Nullable SearchShardTarget sh } } + private int remainingOpsOnIterator(SearchShardIterator shardsIt) { + if (shardsIt.skip()) { + return shardsIt.remaining(); + } else { + return shardsIt.remaining() + 1; + } + } + /** * Executed once for every {@link ShardId} that failed on all available shard routing. * @@ -651,12 +668,7 @@ private void onShardResultConsumed(Result result, SearchShardIterator shardIt) { } private void successfulShardExecution(SearchShardIterator shardsIt) { - final int remainingOpsOnIterator; - if (shardsIt.skip()) { - remainingOpsOnIterator = shardsIt.remaining(); - } else { - remainingOpsOnIterator = shardsIt.remaining() + 1; - } + final int remainingOpsOnIterator = remainingOpsOnIterator(shardsIt); final int xTotalOps = totalOps.addAndGet(remainingOpsOnIterator); if (xTotalOps == expectedTotalOps) { try { diff --git a/server/src/main/java/org/opensearch/action/support/TransportActions.java b/server/src/main/java/org/opensearch/action/support/TransportActions.java index 03e7509b3b8e3..8c5f67d66410f 100644 --- a/server/src/main/java/org/opensearch/action/support/TransportActions.java +++ b/server/src/main/java/org/opensearch/action/support/TransportActions.java @@ -34,8 +34,10 @@ import org.apache.lucene.store.AlreadyClosedException; import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchException; import org.opensearch.action.NoShardAvailableActionException; import org.opensearch.action.UnavailableShardsException; +import org.opensearch.core.tasks.TaskCancelledException; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.shard.IllegalIndexShardStateException; import org.opensearch.index.shard.ShardNotFoundException; @@ -64,4 +66,8 @@ public static boolean isReadOverrideException(Exception e) { return !isShardNotAvailableException(e); } + public static boolean isRetryableSearchException(final Exception e) { + 
return (OpenSearchException.status(e).getStatus() / 100 != 4) && (e.getCause() instanceof TaskCancelledException == false); + } + } diff --git a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java index b0fab3b7a3556..4ffdaf9ab36f8 100644 --- a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java @@ -49,6 +49,7 @@ import org.opensearch.core.common.breaker.NoopCircuitBreaker; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.tasks.TaskCancelledException; import org.opensearch.core.tasks.resourcetracker.TaskResourceInfo; import org.opensearch.core.tasks.resourcetracker.TaskResourceUsage; import org.opensearch.index.query.MatchAllQueryBuilder; @@ -66,6 +67,7 @@ import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.Transport; +import org.opensearch.transport.TransportException; import org.junit.After; import org.junit.Before; @@ -136,6 +138,7 @@ private AbstractSearchAsyncAction createAction( controlled, false, false, + false, expected, resourceUsage, new SearchShardIterator(null, null, Collections.emptyList(), null) @@ -148,6 +151,7 @@ private AbstractSearchAsyncAction createAction( ActionListener listener, final boolean controlled, final boolean failExecutePhaseOnShard, + final boolean throw4xxExceptionOnShard, final boolean catchExceptionWhenExecutePhaseOnShard, final AtomicLong expected, final TaskResourceUsage resourceUsage, @@ -217,7 +221,11 @@ protected void executePhaseOnShard( final SearchActionListener listener ) { if (failExecutePhaseOnShard) { - listener.onFailure(new ShardNotFoundException(shardIt.shardId())); + if (throw4xxExceptionOnShard) { + listener.onFailure(new 
TransportException(new TaskCancelledException(shardIt.shardId().toString()))); + } else { + listener.onFailure(new ShardNotFoundException(shardIt.shardId())); + } } else { if (catchExceptionWhenExecutePhaseOnShard) { try { @@ -585,6 +593,7 @@ public void onFailure(Exception e) { false, true, false, + false, new AtomicLong(), new TaskResourceUsage(randomLong(), randomLong()), shards @@ -601,6 +610,62 @@ public void onFailure(Exception e) { assertThat(searchResponse.getSuccessfulShards(), equalTo(0)); } + public void testSkipInValidRetryInMultiReplicas() throws InterruptedException { + final Index index = new Index("test", UUID.randomUUID().toString()); + final CountDownLatch latch = new CountDownLatch(1); + final AtomicBoolean fail = new AtomicBoolean(true); + + List targetNodeIds = List.of("n1", "n2", "n3"); + final SearchShardIterator[] shards = IntStream.range(2, 4) + .mapToObj(i -> new SearchShardIterator(null, new ShardId(index, i), targetNodeIds, null, null, null)) + .toArray(SearchShardIterator[]::new); + + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); + searchRequest.setMaxConcurrentShardRequests(1); + + final ArraySearchPhaseResults queryResult = new ArraySearchPhaseResults<>(shards.length); + AbstractSearchAsyncAction action = createAction( + searchRequest, + queryResult, + new ActionListener() { + @Override + public void onResponse(SearchResponse response) { + + } + + @Override + public void onFailure(Exception e) { + if (fail.compareAndExchange(true, false)) { + try { + throw new RuntimeException("Simulated exception"); + } finally { + executor.submit(() -> latch.countDown()); + } + } + } + }, + false, + true, + true, + false, + new AtomicLong(), + new TaskResourceUsage(randomLong(), randomLong()), + shards + ); + action.run(); + assertTrue(latch.await(1, TimeUnit.SECONDS)); + InternalSearchResponse internalSearchResponse = InternalSearchResponse.empty(); + SearchResponse searchResponse = 
action.buildSearchResponse(internalSearchResponse, action.buildShardFailures(), null, null); + assertSame(searchResponse.getAggregations(), internalSearchResponse.aggregations()); + assertSame(searchResponse.getSuggest(), internalSearchResponse.suggest()); + assertSame(searchResponse.getProfileResults(), internalSearchResponse.profile()); + assertSame(searchResponse.getHits(), internalSearchResponse.hits()); + assertThat(searchResponse.getSuccessfulShards(), equalTo(0)); + for (int i = 0; i < shards.length; i++) { + assertEquals(targetNodeIds.size() - 1, shards[i].remaining()); + } + } + public void testOnShardSuccessPhaseDoneFailure() throws InterruptedException { final Index index = new Index("test", UUID.randomUUID().toString()); final CountDownLatch latch = new CountDownLatch(1); @@ -633,6 +698,7 @@ public void onFailure(Exception e) { false, false, false, + false, new AtomicLong(), new TaskResourceUsage(randomLong(), randomLong()), shards @@ -685,6 +751,7 @@ public void onFailure(Exception e) { }, false, false, + false, catchExceptionWhenExecutePhaseOnShard, new AtomicLong(), new TaskResourceUsage(randomLong(), randomLong()),