Skip to content

Commit

Permalink
Refactor client's getDetectorProfile to use GetAnomalyDetectorTranspo…
Browse files Browse the repository at this point in the history
…rtAction (#1124)

Signed-off-by: Tyler Ohlsen <[email protected]>
  • Loading branch information
ohltyler authored Dec 28, 2023
1 parent 59b4ebe commit 106dc25
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.ad.transport.ADTaskProfileResponse;
import org.opensearch.ad.transport.GetAnomalyDetectorRequest;
import org.opensearch.ad.transport.GetAnomalyDetectorResponse;
import org.opensearch.common.action.ActionFuture;
import org.opensearch.core.action.ActionListener;

Expand Down Expand Up @@ -54,20 +55,20 @@ default ActionFuture<SearchResponse> searchAnomalyResults(SearchRequest searchRe

/**
* Get detector profile - refer to https://opensearch.org/docs/latest/observing-your-data/ad/api/#profile-detector
* @param detectorId the detector ID to fetch the profile for
* @return ActionFuture of ADTaskProfileResponse
* @param profileRequest request to fetch the detector profile
* @return ActionFuture of GetAnomalyDetectorResponse
*/
default ActionFuture<ADTaskProfileResponse> getDetectorProfile(String detectorId) {
PlainActionFuture<ADTaskProfileResponse> actionFuture = PlainActionFuture.newFuture();
getDetectorProfile(detectorId, actionFuture);
default ActionFuture<GetAnomalyDetectorResponse> getDetectorProfile(GetAnomalyDetectorRequest profileRequest) {
PlainActionFuture<GetAnomalyDetectorResponse> actionFuture = PlainActionFuture.newFuture();
getDetectorProfile(profileRequest, actionFuture);
return actionFuture;
}

/**
* Get detector profile - refer to https://opensearch.org/docs/latest/observing-your-data/ad/api/#profile-detector
* @param detectorId the detector ID to fetch the profile for
* @param profileRequest request to fetch the detector profile
* @param listener a listener to be notified of the result
*/
void getDetectorProfile(String detectorId, ActionListener<ADTaskProfileResponse> listener);
void getDetectorProfile(GetAnomalyDetectorRequest profileRequest, ActionListener<GetAnomalyDetectorResponse> listener);

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,20 @@

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.ad.transport.ADTaskProfileAction;
import org.opensearch.ad.transport.ADTaskProfileRequest;
import org.opensearch.ad.transport.ADTaskProfileResponse;
import org.opensearch.ad.transport.GetAnomalyDetectorAction;
import org.opensearch.ad.transport.GetAnomalyDetectorRequest;
import org.opensearch.ad.transport.GetAnomalyDetectorResponse;
import org.opensearch.ad.transport.SearchAnomalyDetectorAction;
import org.opensearch.ad.transport.SearchAnomalyResultAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.timeseries.util.DiscoveryNodeFilterer;

public class AnomalyDetectionNodeClient implements AnomalyDetectionClient {
private final Client client;
private final DiscoveryNodeFilterer nodeFilterer;

public AnomalyDetectionNodeClient(Client client, ClusterService clusterService) {
public AnomalyDetectionNodeClient(Client client) {
this.client = client;
this.nodeFilterer = new DiscoveryNodeFilterer(clusterService);
}

@Override
Expand All @@ -45,19 +40,20 @@ public void searchAnomalyResults(SearchRequest searchRequest, ActionListener<Sea
}

@Override
public void getDetectorProfile(String detectorId, ActionListener<ADTaskProfileResponse> listener) {
final DiscoveryNode[] eligibleNodes = this.nodeFilterer.getEligibleDataNodes();
ADTaskProfileRequest profileRequest = new ADTaskProfileRequest(detectorId, eligibleNodes);
this.client.execute(ADTaskProfileAction.INSTANCE, profileRequest, getADTaskProfileResponseActionListener(listener));
public void getDetectorProfile(GetAnomalyDetectorRequest profileRequest, ActionListener<GetAnomalyDetectorResponse> listener) {
this.client.execute(GetAnomalyDetectorAction.INSTANCE, profileRequest, getAnomalyDetectorResponseActionListener(listener));
}

// We need to wrap AD-specific response type listeners around an internal listener, and re-generate the response from a generic
// ActionResponse. This is needed to prevent classloader issues and ClassCastExceptions when executed by other plugins.
private ActionListener<ADTaskProfileResponse> getADTaskProfileResponseActionListener(ActionListener<ADTaskProfileResponse> listener) {
ActionListener<ADTaskProfileResponse> internalListener = ActionListener
.wrap(profileResponse -> { listener.onResponse(profileResponse); }, listener::onFailure);
ActionListener<ADTaskProfileResponse> actionListener = wrapActionListener(internalListener, actionResponse -> {
ADTaskProfileResponse response = ADTaskProfileResponse.fromActionResponse(actionResponse);
private ActionListener<GetAnomalyDetectorResponse> getAnomalyDetectorResponseActionListener(
ActionListener<GetAnomalyDetectorResponse> listener
) {
ActionListener<GetAnomalyDetectorResponse> internalListener = ActionListener.wrap(getAnomalyDetectorResponse -> {
listener.onResponse(getAnomalyDetectorResponse);
}, listener::onFailure);
ActionListener<GetAnomalyDetectorResponse> actionListener = wrapActionListener(internalListener, actionResponse -> {
GetAnomalyDetectorResponse response = GetAnomalyDetectorResponse.fromActionResponse(actionResponse);
return response;
});
return actionListener;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,18 @@

package org.opensearch.ad.transport;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import org.opensearch.ad.model.ADTask;
import org.opensearch.ad.model.AnomalyDetector;
import org.opensearch.ad.model.DetectorProfile;
import org.opensearch.ad.model.EntityProfile;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.rest.RestStatus;
Expand Down Expand Up @@ -212,4 +217,19 @@ public ADTask getHistoricalAdTask() {
public AnomalyDetector getDetector() {
return detector;
}

public static GetAnomalyDetectorResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof GetAnomalyDetectorResponse) {
return (GetAnomalyDetectorResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new GetAnomalyDetectorResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into GetAnomalyDetectorResponse", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
import org.mockito.MockitoAnnotations;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.ad.transport.ADTaskProfileResponse;
import org.opensearch.ad.transport.GetAnomalyDetectorRequest;
import org.opensearch.ad.transport.GetAnomalyDetectorResponse;
import org.opensearch.common.lucene.uid.Versions;
import org.opensearch.core.action.ActionListener;

public class AnomalyDetectionClientTests {
Expand All @@ -27,7 +29,7 @@ public class AnomalyDetectionClientTests {
SearchResponse searchResultsResponse;

@Mock
ADTaskProfileResponse profileResponse;
GetAnomalyDetectorResponse profileResponse;

@Before
public void setUp() {
Expand All @@ -46,7 +48,7 @@ public void searchAnomalyResults(SearchRequest searchRequest, ActionListener<Sea
}

@Override
public void getDetectorProfile(String detectorId, ActionListener<ADTaskProfileResponse> listener) {
public void getDetectorProfile(GetAnomalyDetectorRequest profileRequest, ActionListener<GetAnomalyDetectorResponse> listener) {
listener.onResponse(profileResponse);
}
};
Expand All @@ -64,7 +66,17 @@ public void searchAnomalyResults() {

@Test
public void getDetectorProfile() {
assertEquals(profileResponse, anomalyDetectionClient.getDetectorProfile("foo").actionGet());
GetAnomalyDetectorRequest profileRequest = new GetAnomalyDetectorRequest(
"foo",
Versions.MATCH_ANY,
true,
false,
"",
"",
false,
null
);
assertEquals(profileResponse, anomalyDetectionClient.getDetectorProfile(profileRequest).actionGet());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,41 +11,42 @@
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ad.indices.ADIndexManagement.ALL_AD_RESULTS_INDEX_PATTERN;
import static org.opensearch.ad.model.AnomalyDetector.DETECTOR_TYPE_FIELD;
import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG;

import java.io.IOException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutionException;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.junit.Before;
import org.junit.Test;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.ad.HistoricalAnalysisIntegTestCase;
import org.opensearch.ad.constant.ADCommonName;
import org.opensearch.ad.model.ADTaskProfile;
import org.opensearch.ad.model.ADTask;
import org.opensearch.ad.model.AnomalyDetector;
import org.opensearch.ad.model.AnomalyDetectorType;
import org.opensearch.ad.transport.ADTaskProfileAction;
import org.opensearch.ad.transport.ADTaskProfileNodeResponse;
import org.opensearch.ad.transport.ADTaskProfileResponse;
import org.opensearch.ad.model.DetectorProfile;
import org.opensearch.ad.model.DetectorState;
import org.opensearch.ad.transport.GetAnomalyDetectorAction;
import org.opensearch.ad.transport.GetAnomalyDetectorRequest;
import org.opensearch.ad.transport.GetAnomalyDetectorResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.ClusterName;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.common.lucene.uid.Versions;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.timeseries.TestHelpers;
import org.opensearch.timeseries.constant.CommonName;
import org.opensearch.timeseries.model.Job;

import com.google.common.collect.ImmutableList;

Expand All @@ -54,19 +55,16 @@
// The exhaustive set of transport action scenarios are within the respective transport action
// test suites themselves. We do not want to unnecessarily duplicate all of those tests here.
public class AnomalyDetectionNodeClientTests extends HistoricalAnalysisIntegTestCase {
private final Logger logger = LogManager.getLogger(this.getClass());

private String indexName = "test-data";
private Instant startTime = Instant.now().minus(2, ChronoUnit.DAYS);
private Client clientSpy;
private AnomalyDetectionNodeClient adClient;
private PlainActionFuture<SearchResponse> searchResponseFuture;
private PlainActionFuture<ADTaskProfileResponse> profileFuture;

@Before
public void setup() {
clientSpy = spy(client());
adClient = new AnomalyDetectionNodeClient(clientSpy, clusterService());
adClient = new AnomalyDetectionNodeClient(clientSpy);
}

@Test
Expand Down Expand Up @@ -150,39 +148,90 @@ public void testGetDetectorProfile_NoIndices() throws ExecutionException, Interr
deleteIndexIfExists(ALL_AD_RESULTS_INDEX_PATTERN);
deleteIndexIfExists(ADCommonName.DETECTION_STATE_INDEX);

profileFuture = mock(PlainActionFuture.class);
ADTaskProfileResponse response = adClient.getDetectorProfile("foo").actionGet(10000);
List<ADTaskProfileNodeResponse> responses = response.getNodes();

assertNotEquals(0, responses.size());
assertEquals(null, responses.get(0).getAdTaskProfile());
verify(clientSpy, times(1)).execute(any(ADTaskProfileAction.class), any(), any());

GetAnomalyDetectorRequest profileRequest = new GetAnomalyDetectorRequest(
"foo",
Versions.MATCH_ANY,
true,
false,
"",
"",
false,
null
);

OpenSearchStatusException exception = expectThrows(
OpenSearchStatusException.class,
() -> adClient.getDetectorProfile(profileRequest).actionGet(10000)
);

assertTrue(exception.getMessage().contains(FAIL_TO_FIND_CONFIG_MSG));
verify(clientSpy, times(1)).execute(any(GetAnomalyDetectorAction.class), any(), any());
}

@Test
public void testGetDetectorProfile_Populated() {
DiscoveryNode localNode = clusterService().localNode();
ADTaskProfile adTaskProfile = new ADTaskProfile("foo-task-id", 0, 0L, false, 0, 0L, localNode.getId());
public void testGetDetectorProfile_Populated() throws IOException {
ingestTestData(indexName, startTime, 1, "test", 10);
AnomalyDetector detector = TestHelpers
.randomAnomalyDetector(
ImmutableList.of(indexName),
ImmutableList.of(TestHelpers.randomFeature(true)),
null,
Instant.now(),
1,
false,
null
);
createDetectorIndex();
String detectorId = createDetector(detector);

doAnswer(invocation -> {
Object[] args = invocation.getArguments();

ActionListener<ADTaskProfileResponse> listener = (ActionListener<ADTaskProfileResponse>) args[2];
ADTaskProfileNodeResponse nodeResponse = new ADTaskProfileNodeResponse(localNode, adTaskProfile, null);

List<ADTaskProfileNodeResponse> nodeResponses = Arrays.asList(nodeResponse);
listener.onResponse(new ADTaskProfileResponse(new ClusterName("test-cluster"), nodeResponses, Collections.emptyList()));
ActionListener<GetAnomalyDetectorResponse> listener = (ActionListener<GetAnomalyDetectorResponse>) args[2];

// Setting up mock profile to test that the state is returned correctly in the client response
DetectorProfile mockProfile = mock(DetectorProfile.class);
when(mockProfile.getState()).thenReturn(DetectorState.DISABLED);

GetAnomalyDetectorResponse response = new GetAnomalyDetectorResponse(
1234,
"4567",
9876,
2345,
detector,
mock(Job.class),
false,
mock(ADTask.class),
mock(ADTask.class),
false,
RestStatus.OK,
mockProfile,
null,
false
);
listener.onResponse(response);

return null;
}).when(clientSpy).execute(any(ADTaskProfileAction.class), any(), any());

ADTaskProfileResponse response = adClient.getDetectorProfile("foo").actionGet(10000);
String responseTaskId = response.getNodes().get(0).getAdTaskProfile().getTaskId();

assertNotEquals(0, response.getNodes().size());
assertEquals(responseTaskId, adTaskProfile.getTaskId());
verify(clientSpy, times(1)).execute(any(ADTaskProfileAction.class), any(), any());
}).when(clientSpy).execute(any(GetAnomalyDetectorAction.class), any(), any());

GetAnomalyDetectorRequest profileRequest = new GetAnomalyDetectorRequest(
detectorId,
Versions.MATCH_ANY,
true,
false,
"",
"",
false,
null
);

GetAnomalyDetectorResponse response = adClient.getDetectorProfile(profileRequest).actionGet(10000);

assertNotEquals(null, response.getDetector());
assertNotEquals(null, response.getDetectorProfile());
assertEquals(null, response.getAdJob());
assertEquals(detector.getName(), response.getDetector().getName());
assertEquals(DetectorState.DISABLED, response.getDetectorProfile().getState());
verify(clientSpy, times(1)).execute(any(GetAnomalyDetectorAction.class), any(), any());
}

}

0 comments on commit 106dc25

Please sign in to comment.