Skip to content

Commit

Permalink
Upgrade rcf to 4.0
Browse files Browse the repository at this point in the history
This PR upgrades rcf to 4.0 as it has bug fixes and support for streaming imputation mode.

Testing done:
1. gradle build

Signed-off-by: Kaituo Li <[email protected]>
  • Loading branch information
kaituo committed Mar 25, 2024
1 parent 1507dd4 commit 0dd8f56
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 60 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/backport.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ on:

jobs:
backport:
if: github.event.pull_request.merged == true
runs-on: ubuntu-latest
permissions:
contents: write
Expand All @@ -25,4 +26,5 @@ jobs:
uses: VachaShah/[email protected]
with:
github_token: ${{ steps.github_app_token.outputs.token }}
branch_name: backport/backport-${{ github.event.number }}
head_template: backport/backport-<%= number %>-to-<%= base %>
failure_labels: backport-failed
6 changes: 3 additions & 3 deletions .github/workflows/test_build_multi_platform.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:

- name: Build and Run Tests
run: |
./gradlew build
./gradlew build -x spotlessJava
- name: Publish to Maven Local
run: |
./gradlew publishToMavenLocal
Expand Down Expand Up @@ -91,7 +91,7 @@ jobs:
run: |
chown -R 1000:1000 `pwd`
su `id -un 1000` -c "./gradlew assemble &&
./gradlew build &&
./gradlew build -x spotlessJava &&
./gradlew publishToMavenLocal &&
./gradlew integTest -PnumNodes=3"
- name: Upload Coverage Report
Expand Down Expand Up @@ -127,7 +127,7 @@ jobs:
./gradlew assemble
- name: Build and Run Tests
run: |
./gradlew build
./gradlew build -x spotlessJava
- name: Publish to Maven Local
run: |
./gradlew publishToMavenLocal
Expand Down
11 changes: 7 additions & 4 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ buildscript {
js_resource_folder = "src/test/resources/job-scheduler"
common_utils_version = System.getProperty("common_utils.version", opensearch_build)
job_scheduler_version = System.getProperty("job_scheduler.version", opensearch_build)
bwcVersionShort = "2.10.0"
bwcVersionShort = "2.14.0"
bwcVersion = bwcVersionShort + ".0"
bwcOpenSearchADDownload = 'https://ci.opensearch.org/ci/dbc/distribution-build-opensearch/' + bwcVersionShort + '/latest/linux/x64/tar/builds/' +
'opensearch/plugins/opensearch-anomaly-detection-' + bwcVersion + '.zip'
Expand Down Expand Up @@ -126,9 +126,9 @@ dependencies {
implementation group: 'com.yahoo.datasketches', name: 'memory', version: '0.12.2'
implementation group: 'commons-lang', name: 'commons-lang', version: '2.6'
implementation group: 'org.apache.commons', name: 'commons-pool2', version: '2.11.1'
implementation 'software.amazon.randomcutforest:randomcutforest-serialization:3.8.0'
implementation 'software.amazon.randomcutforest:randomcutforest-parkservices:3.8.0'
implementation 'software.amazon.randomcutforest:randomcutforest-core:3.8.0'
implementation 'software.amazon.randomcutforest:randomcutforest-serialization:4.0.0'
implementation 'software.amazon.randomcutforest:randomcutforest-parkservices:4.0.0'
implementation 'software.amazon.randomcutforest:randomcutforest-core:4.0.0'

// we inherit jackson-core from opensearch core
implementation "com.fasterxml.jackson.core:jackson-databind:2.16.1"
Expand All @@ -149,6 +149,9 @@ dependencies {
exclude group: 'org.ow2.asm', module: 'asm-tree'
}

// used for output encoding of config descriptions
implementation group: 'org.owasp.encoder' , name: 'encoder', version: '1.2.3'

testImplementation group: 'pl.pragmatists', name: 'JUnitParams', version: '1.1.1'
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.9.0'
testImplementation group: 'org.objenesis', name: 'objenesis', version: '3.3'
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/org/opensearch/ad/ml/CheckpointDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,8 @@ private Optional<ThresholdedRandomCutForest> convertToTRCF(Optional<RandomCutFor
if (kllThreshold.isPresent()) {
scores = kllThreshold.get().extractScores();
}
return Optional.of(new ThresholdedRandomCutForest(rcf.get(), anomalyRate, scores));
// last parameter is lastShingledInput. Since we don't know it, use all 0 double array
return Optional.of(new ThresholdedRandomCutForest(rcf.get(), anomalyRate, scores, new double[rcf.get().getDimensions()]));
}

/**
Expand Down
26 changes: 13 additions & 13 deletions src/main/java/org/opensearch/ad/task/ADTaskManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,7 @@
import static org.opensearch.ad.constant.ADCommonMessages.NO_ELIGIBLE_NODE_TO_RUN_DETECTOR;
import static org.opensearch.ad.constant.ADCommonName.DETECTION_STATE_INDEX;
import static org.opensearch.ad.indices.ADIndexManagement.ALL_AD_RESULTS_INDEX_PATTERN;
import static org.opensearch.ad.model.ADTask.COORDINATING_NODE_FIELD;
import static org.opensearch.ad.model.ADTask.DETECTOR_ID_FIELD;
import static org.opensearch.ad.model.ADTask.ERROR_FIELD;
import static org.opensearch.ad.model.ADTask.ESTIMATED_MINUTES_LEFT_FIELD;
import static org.opensearch.ad.model.ADTask.EXECUTION_END_TIME_FIELD;
import static org.opensearch.ad.model.ADTask.EXECUTION_START_TIME_FIELD;
import static org.opensearch.ad.model.ADTask.INIT_PROGRESS_FIELD;
import static org.opensearch.ad.model.ADTask.IS_LATEST_FIELD;
import static org.opensearch.ad.model.ADTask.LAST_UPDATE_TIME_FIELD;
import static org.opensearch.ad.model.ADTask.PARENT_TASK_ID_FIELD;
import static org.opensearch.ad.model.ADTask.STATE_FIELD;
import static org.opensearch.ad.model.ADTask.STOPPED_BY_FIELD;
import static org.opensearch.ad.model.ADTask.TASK_PROGRESS_FIELD;
import static org.opensearch.ad.model.ADTask.TASK_TYPE_FIELD;
import static org.opensearch.ad.model.ADTaskType.ALL_HISTORICAL_TASK_TYPES;
import static org.opensearch.ad.model.ADTaskType.HISTORICAL_DETECTOR_TASK_TYPES;
import static org.opensearch.ad.model.ADTaskType.REALTIME_TASK_TYPES;
Expand All @@ -52,6 +39,19 @@
import static org.opensearch.timeseries.constant.CommonName.TASK_ID_FIELD;
import static org.opensearch.timeseries.model.TaskState.NOT_ENDED_STATES;
import static org.opensearch.timeseries.model.TaskType.taskTypeToString;
import static org.opensearch.timeseries.model.TimeSeriesTask.COORDINATING_NODE_FIELD;
import static org.opensearch.timeseries.model.TimeSeriesTask.ERROR_FIELD;
import static org.opensearch.timeseries.model.TimeSeriesTask.ESTIMATED_MINUTES_LEFT_FIELD;
import static org.opensearch.timeseries.model.TimeSeriesTask.EXECUTION_END_TIME_FIELD;
import static org.opensearch.timeseries.model.TimeSeriesTask.EXECUTION_START_TIME_FIELD;
import static org.opensearch.timeseries.model.TimeSeriesTask.INIT_PROGRESS_FIELD;
import static org.opensearch.timeseries.model.TimeSeriesTask.IS_LATEST_FIELD;
import static org.opensearch.timeseries.model.TimeSeriesTask.LAST_UPDATE_TIME_FIELD;
import static org.opensearch.timeseries.model.TimeSeriesTask.PARENT_TASK_ID_FIELD;
import static org.opensearch.timeseries.model.TimeSeriesTask.STATE_FIELD;
import static org.opensearch.timeseries.model.TimeSeriesTask.STOPPED_BY_FIELD;
import static org.opensearch.timeseries.model.TimeSeriesTask.TASK_PROGRESS_FIELD;
import static org.opensearch.timeseries.model.TimeSeriesTask.TASK_TYPE_FIELD;
import static org.opensearch.timeseries.settings.TimeSeriesSettings.NUM_MIN_SAMPLES;
import static org.opensearch.timeseries.util.ExceptionUtil.getErrorMessage;
import static org.opensearch.timeseries.util.ExceptionUtil.getShardsFailure;
Expand Down
59 changes: 28 additions & 31 deletions src/test/java/org/opensearch/ad/ml/CheckpointDaoTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -1067,27 +1067,22 @@ public void testDeserializeTRCFModel() throws Exception {
coldStartData.add(sample4);
coldStartData.add(sample5);

// This scores were generated with the sample data but on RCF3.0-rc1 and we are comparing them
// to the scores generated by the imported RCF3.0-rc2.1
// This scores were generated with the sample data on RCF4.0. RCF4.0 changed implementation
// and we are seeing different rcf scores between 4.0 and 3.8. This is verified by switching
// rcf version between 3.8 and 4.0 while other code in AD unchanged. But we get different scores.
List<Double> scores = new ArrayList<>();
scores.add(4.814651669367903);
scores.add(5.566968073093689);
scores.add(5.919907610660049);
scores.add(5.770278090352401);
scores.add(5.319779117320102);

List<Double> grade = new ArrayList<>();
grade.add(1.0);
grade.add(0.0);
grade.add(0.0);
grade.add(0.0);
grade.add(0.0);
scores.add(5.052069275347555);
scores.add(6.117465704461799);
scores.add(6.6401649744661055);
scores.add(6.918514609476484);
scores.add(6.928318158276434);

// rcf 3.8 has a number of improvements on thresholder and predictor corrector.
// We don't expect the results have the same anomaly grade.
for (int i = 0; i < coldStartData.size(); i++) {
forest.process(coldStartData.get(i), 0);
AnomalyDescriptor descriptor = forest.process(coldStartData.get(i), 0);
assertEquals(descriptor.getRCFScore(), scores.get(i), 1e-9);
assertEquals(scores.get(i), descriptor.getRCFScore(), 1e-9);
}
}

Expand Down Expand Up @@ -1133,21 +1128,22 @@ public void testDeserialize_rcf3_rc3_single_stream_model() throws Exception {
coldStartData.add(sample4);
coldStartData.add(sample5);

// This scores were generated with the sample data but on RCF3.0-rc1 and we are comparing them
// to the scores generated by the imported RCF3.0-rc2.1
// This scores were generated with the sample data on RCF4.0. RCF4.0 changed implementation
// and we are seeing different rcf scores between 4.0 and 3.8. This is verified by switching
// rcf version between 3.8 and 4.0 while other code in AD unchanged. But we get different scores.
List<Double> scores = new ArrayList<>();
scores.add(3.3830441158587066);
scores.add(2.825961659490065);
scores.add(2.4685871670647384);
scores.add(2.3123460886413647);
scores.add(2.1401987653477135);
scores.add(3.678754481587072);
scores.add(3.6809634269790252);
scores.add(3.683659822587799);
scores.add(3.6852688612219646);
scores.add(3.6859330728661064);

// rcf 3.8 has a number of improvements on thresholder and predictor corrector.
// We don't expect the results have the same anomaly grade.
for (int i = 0; i < coldStartData.size(); i++) {
forest.process(coldStartData.get(i), 0);
AnomalyDescriptor descriptor = forest.process(coldStartData.get(i), 0);
assertEquals(descriptor.getRCFScore(), scores.get(i), 1e-9);
assertEquals(scores.get(i), descriptor.getRCFScore(), 1e-9);
}
}

Expand Down Expand Up @@ -1190,21 +1186,22 @@ public void testDeserialize_rcf3_rc3_hc_model() throws Exception {
coldStartData.add(sample4);
coldStartData.add(sample5);

// This scores were generated with the sample data but on RCF3.0-rc1 and we are comparing them
// to the scores generated by the imported RCF3.0-rc2.1
// This scores were generated with the sample data but on RCF4.0 that changed implementation
// and we are seeing different rcf scores between 4.0 and 3.8. This is verified by switching
// rcf version between 3.8 and 4.0 while other code in AD unchanged. But we get different scores.
List<Double> scores = new ArrayList<>();
scores.add(1.86645896573027);
scores.add(1.8760247712797833);
scores.add(1.6809181763279901);
scores.add(1.7126716645678555);
scores.add(1.323776514074674);
scores.add(2.119532552959117);
scores.add(2.7347456872746325);
scores.add(3.066704948143919);
scores.add(3.2965580521876725);
scores.add(3.1888920146607047);

// rcf 3.8 has a number of improvements on thresholder and predictor corrector.
// We don't expect the results have the same anomaly grade.
for (int i = 0; i < coldStartData.size(); i++) {
forest.process(coldStartData.get(i), 0);
AnomalyDescriptor descriptor = forest.process(coldStartData.get(i), 0);
assertEquals(descriptor.getRCFScore(), scores.get(i), 1e-9);
assertEquals(scores.get(i), descriptor.getRCFScore(), 1e-9);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ public void testAccuracyOneMinuteIntervalNoInterpolation() throws Exception {
clusterService
);

accuracyTemplate(1, 0.6f, 0.6f);
accuracyTemplate(1, 0.5f, 0.5f);
}

private ModelState<EntityModel> createStateForCacheRelease() {
Expand Down
25 changes: 19 additions & 6 deletions src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.ShardSearchFailure;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.ad.ADUnitTestCase;
import org.opensearch.ad.cluster.HashRing;
import org.opensearch.ad.indices.ADIndexManagement;
import org.opensearch.ad.mock.model.MockSimpleLog;
Expand All @@ -89,6 +88,7 @@
import org.opensearch.ad.model.ADTaskType;
import org.opensearch.ad.model.AnomalyDetector;
import org.opensearch.ad.rest.handler.IndexAnomalyDetectorJobActionHandler;
import org.opensearch.ad.settings.AnomalyDetectorSettings;
import org.opensearch.ad.stats.InternalStatNames;
import org.opensearch.ad.transport.ADStatsNodeResponse;
import org.opensearch.ad.transport.ADStatsNodesResponse;
Expand Down Expand Up @@ -120,6 +120,7 @@
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.timeseries.AbstractTimeSeriesTest;
import org.opensearch.timeseries.TestHelpers;
import org.opensearch.timeseries.common.exception.DuplicateTaskException;
import org.opensearch.timeseries.constant.CommonName;
Expand All @@ -139,7 +140,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

public class ADTaskManagerTests extends ADUnitTestCase {
public class ADTaskManagerTests extends AbstractTimeSeriesTest {

private Settings settings;
private Client client;
Expand Down Expand Up @@ -1447,10 +1448,22 @@ public void testForwardRequestToLeadNodeWithNotExistingNode() throws IOException
@SuppressWarnings("unchecked")
public void testScaleTaskLaneOnCoordinatingNode() {
ADTask adTask = mock(ADTask.class);
when(adTask.getCoordinatingNode()).thenReturn(node1.getId());
when(nodeFilter.getEligibleDataNodes()).thenReturn(new DiscoveryNode[] { node1, node2 });
ActionListener<JobResponse> listener = mock(ActionListener.class);
adTaskManager.scaleTaskLaneOnCoordinatingNode(adTask, 2, transportService, listener);
try {
// bring up real transport service as mockito cannot mock final method
// and transportService.sendRequest is called. A lot of null pointer
// exception will be thrown if we use mocked transport service.
setUpThreadPool(ADTaskManagerTests.class.getSimpleName());
setupTestNodes(AnomalyDetectorSettings.AD_MAX_ENTITIES_PER_QUERY, AnomalyDetectorSettings.AD_PAGE_SIZE);
when(adTask.getCoordinatingNode()).thenReturn(testNodes[1].getNodeId());
when(nodeFilter.getEligibleDataNodes())
.thenReturn(new DiscoveryNode[] { testNodes[0].discoveryNode(), testNodes[1].discoveryNode() });
ActionListener<JobResponse> listener = mock(ActionListener.class);

adTaskManager.scaleTaskLaneOnCoordinatingNode(adTask, 2, testNodes[1].transportService, listener);
} finally {
tearDownTestNodes();
tearDownThreadPool();
}
}

@SuppressWarnings("unchecked")
Expand Down
Loading

0 comments on commit 0dd8f56

Please sign in to comment.