Skip to content

Commit

Permalink
Avoid invalid retries on multiple replicas when querying
Browse files Browse the repository at this point in the history
Signed-off-by: kkewwei <[email protected]>
Signed-off-by: kkewwei <[email protected]>
  • Loading branch information
kkewwei committed Feb 19, 2025
1 parent 56825f6 commit 37168c2
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Changed
- Convert transport-reactor-netty4 to use gradle version catalog [#17233](https://github.com/opensearch-project/OpenSearch/pull/17233)
- Increase force merge threads to 1/8th of cores [#17255](https://github.com/opensearch-project/OpenSearch/pull/17255)
- Avoid invalid retries in multiple replicas when querying [#17370](https://github.com/opensearch-project/OpenSearch/pull/17370)

### Deprecated

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,12 @@ protected Map<String, List<String>> getHeaders() {
* Returns the rest status code associated with this exception.
*/
public RestStatus status() {
Throwable cause = unwrapCause();
if (cause == this) {
return status(this);
}

public static RestStatus status(Throwable t) {
Throwable cause = ExceptionsHelper.unwrapCause(t);
if (cause == t) {
return RestStatus.INTERNAL_SERVER_ERROR;
} else {
return ExceptionsHelper.status(cause);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -514,10 +514,19 @@ private void onShardFailure(final int shardIndex, @Nullable SearchShardTarget sh
// we do make sure to clean it on a successful response from a shard
setPhaseResourceUsages();
onShardFailure(shardIndex, shard, e);
SearchShardTarget nextShard = FailAwareWeightedRouting.getInstance()
.findNext(shardIt, clusterState, e, () -> totalOps.incrementAndGet());

final boolean lastShard = nextShard == null;
final SearchShardTarget nextShard;
final boolean lastShard;
final int advanceShardCount;
if (TransportActions.isRetryableSearchException(e)) {
nextShard = FailAwareWeightedRouting.getInstance().findNext(shardIt, clusterState, e, () -> totalOps.incrementAndGet());
lastShard = nextShard == null;
advanceShardCount = 1;
} else {
nextShard = null;
lastShard = true;
advanceShardCount = remainingOpsOnIterator(shardIt);
}
if (logger.isTraceEnabled()) {
logger.trace(
() -> new ParameterizedMessage(
Expand All @@ -542,7 +551,7 @@ private void onShardFailure(final int shardIndex, @Nullable SearchShardTarget sh
if (lastShard) {
onShardGroupFailure(shardIndex, shard, e);
}
final int totalOps = this.totalOps.incrementAndGet();
final int totalOps = this.totalOps.addAndGet(advanceShardCount);
if (totalOps == expectedTotalOps) {
try {
onPhaseDone();
Expand All @@ -561,6 +570,14 @@ private void onShardFailure(final int shardIndex, @Nullable SearchShardTarget sh
}
}

private int remainingOpsOnIterator(SearchShardIterator shardsIt) {
if (shardsIt.skip()) {
return shardsIt.remaining();
} else {
return shardsIt.remaining() + 1;
}
}

/**
* Executed once for every {@link ShardId} that failed on all available shard routing.
*
Expand Down Expand Up @@ -651,12 +668,7 @@ private void onShardResultConsumed(Result result, SearchShardIterator shardIt) {
}

private void successfulShardExecution(SearchShardIterator shardsIt) {
final int remainingOpsOnIterator;
if (shardsIt.skip()) {
remainingOpsOnIterator = shardsIt.remaining();
} else {
remainingOpsOnIterator = shardsIt.remaining() + 1;
}
final int remainingOpsOnIterator = remainingOpsOnIterator(shardsIt);
final int xTotalOps = totalOps.addAndGet(remainingOpsOnIterator);
if (xTotalOps == expectedTotalOps) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@

import org.apache.lucene.store.AlreadyClosedException;
import org.opensearch.ExceptionsHelper;
import org.opensearch.OpenSearchException;
import org.opensearch.action.NoShardAvailableActionException;
import org.opensearch.action.UnavailableShardsException;
import org.opensearch.core.tasks.TaskCancelledException;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.shard.IllegalIndexShardStateException;
import org.opensearch.index.shard.ShardNotFoundException;
Expand Down Expand Up @@ -64,4 +66,8 @@ public static boolean isReadOverrideException(Exception e) {
return !isShardNotAvailableException(e);
}

public static boolean isRetryableSearchException(final Exception e) {
return (OpenSearchException.status(e).getStatus() / 100 != 4) && (e.getCause() instanceof TaskCancelledException == false);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.opensearch.core.common.breaker.NoopCircuitBreaker;
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.tasks.TaskCancelledException;
import org.opensearch.core.tasks.resourcetracker.TaskResourceInfo;
import org.opensearch.core.tasks.resourcetracker.TaskResourceUsage;
import org.opensearch.index.query.MatchAllQueryBuilder;
Expand All @@ -66,6 +67,7 @@
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.Transport;
import org.opensearch.transport.TransportException;
import org.junit.After;
import org.junit.Before;

Expand Down Expand Up @@ -136,6 +138,7 @@ private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
controlled,
false,
false,
false,
expected,
resourceUsage,
new SearchShardIterator(null, null, Collections.emptyList(), null)
Expand All @@ -148,6 +151,7 @@ private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
ActionListener<SearchResponse> listener,
final boolean controlled,
final boolean failExecutePhaseOnShard,
final boolean throw4xxExceptionOnShard,
final boolean catchExceptionWhenExecutePhaseOnShard,
final AtomicLong expected,
final TaskResourceUsage resourceUsage,
Expand Down Expand Up @@ -217,7 +221,11 @@ protected void executePhaseOnShard(
final SearchActionListener<SearchPhaseResult> listener
) {
if (failExecutePhaseOnShard) {
listener.onFailure(new ShardNotFoundException(shardIt.shardId()));
if (throw4xxExceptionOnShard) {
listener.onFailure(new TransportException(new TaskCancelledException(shardIt.shardId().toString())));
} else {
listener.onFailure(new ShardNotFoundException(shardIt.shardId()));
}
} else {
if (catchExceptionWhenExecutePhaseOnShard) {
try {
Expand Down Expand Up @@ -585,6 +593,7 @@ public void onFailure(Exception e) {
false,
true,
false,
false,
new AtomicLong(),
new TaskResourceUsage(randomLong(), randomLong()),
shards
Expand All @@ -601,6 +610,62 @@ public void onFailure(Exception e) {
assertThat(searchResponse.getSuccessfulShards(), equalTo(0));
}

public void testSkipInValidRetryInMultiReplicas() throws InterruptedException {
final Index index = new Index("test", UUID.randomUUID().toString());
final CountDownLatch latch = new CountDownLatch(1);
final AtomicBoolean fail = new AtomicBoolean(true);

List<String> targetNodeIds = List.of("n1", "n2", "n3");
final SearchShardIterator[] shards = IntStream.range(2, 4)
.mapToObj(i -> new SearchShardIterator(null, new ShardId(index, i), targetNodeIds, null, null, null))
.toArray(SearchShardIterator[]::new);

SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true);
searchRequest.setMaxConcurrentShardRequests(1);

final ArraySearchPhaseResults<SearchPhaseResult> queryResult = new ArraySearchPhaseResults<>(shards.length);
AbstractSearchAsyncAction<SearchPhaseResult> action = createAction(
searchRequest,
queryResult,
new ActionListener<SearchResponse>() {
@Override
public void onResponse(SearchResponse response) {

}

@Override
public void onFailure(Exception e) {
if (fail.compareAndExchange(true, false)) {
try {
throw new RuntimeException("Simulated exception");
} finally {
executor.submit(() -> latch.countDown());
}
}
}
},
false,
true,
true,
false,
new AtomicLong(),
new TaskResourceUsage(randomLong(), randomLong()),
shards
);
action.run();
assertTrue(latch.await(1, TimeUnit.SECONDS));
InternalSearchResponse internalSearchResponse = InternalSearchResponse.empty();
SearchResponse searchResponse = action.buildSearchResponse(internalSearchResponse, action.buildShardFailures(), null, null);
assertSame(searchResponse.getAggregations(), internalSearchResponse.aggregations());
assertSame(searchResponse.getSuggest(), internalSearchResponse.suggest());
assertSame(searchResponse.getProfileResults(), internalSearchResponse.profile());
assertSame(searchResponse.getHits(), internalSearchResponse.hits());
assertThat(searchResponse.getSuccessfulShards(), equalTo(0));
for (int i = 0; i < shards.length; i++) {
assertEquals(targetNodeIds.size() - 1, shards[i].remaining());
}
}

public void testOnShardSuccessPhaseDoneFailure() throws InterruptedException {
final Index index = new Index("test", UUID.randomUUID().toString());
final CountDownLatch latch = new CountDownLatch(1);
Expand Down Expand Up @@ -633,6 +698,7 @@ public void onFailure(Exception e) {
false,
false,
false,
false,
new AtomicLong(),
new TaskResourceUsage(randomLong(), randomLong()),
shards
Expand Down Expand Up @@ -685,6 +751,7 @@ public void onFailure(Exception e) {
},
false,
false,
false,
catchExceptionWhenExecutePhaseOnShard,
new AtomicLong(),
new TaskResourceUsage(randomLong(), randomLong()),
Expand Down

0 comments on commit 37168c2

Please sign in to comment.