diff --git a/build.gradle b/build.gradle index 06bade6d4..40d7142f9 100644 --- a/build.gradle +++ b/build.gradle @@ -126,9 +126,9 @@ dependencies { implementation group: 'com.yahoo.datasketches', name: 'memory', version: '0.12.2' implementation group: 'commons-lang', name: 'commons-lang', version: '2.6' implementation group: 'org.apache.commons', name: 'commons-pool2', version: '2.11.1' - implementation 'software.amazon.randomcutforest:randomcutforest-serialization:3.8.0' - implementation 'software.amazon.randomcutforest:randomcutforest-parkservices:3.8.0' - implementation 'software.amazon.randomcutforest:randomcutforest-core:3.8.0' + implementation 'software.amazon.randomcutforest:randomcutforest-serialization:4.0.0' + implementation 'software.amazon.randomcutforest:randomcutforest-parkservices:4.0.0' + implementation 'software.amazon.randomcutforest:randomcutforest-core:4.0.0' // we inherit jackson-core from opensearch core implementation "com.fasterxml.jackson.core:jackson-databind:2.16.1" @@ -149,6 +149,9 @@ dependencies { exclude group: 'org.ow2.asm', module: 'asm-tree' } + // used for output encoding of config descriptions + implementation group: 'org.owasp.encoder' , name: 'encoder', version: '1.2.3' + testImplementation group: 'pl.pragmatists', name: 'JUnitParams', version: '1.1.1' testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.9.0' testImplementation group: 'org.objenesis', name: 'objenesis', version: '3.3' @@ -408,6 +411,7 @@ testClusters.integTest { @Override File getAsFile() { return configurations.zipArchive.asFileTree.getSingleFile() + //return fileTree("src/test/resources/job-scheduler").getSingleFile() } } } diff --git a/src/main/java/org/opensearch/ad/ADEntityProfileRunner.java b/src/main/java/org/opensearch/ad/ADEntityProfileRunner.java new file mode 100644 index 000000000..897c853f1 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ADEntityProfileRunner.java @@ -0,0 +1,46 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad; + +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.settings.ADNumericSetting; +import org.opensearch.ad.transport.ADEntityProfileAction; +import org.opensearch.client.Client; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.EntityProfileRunner; +import org.opensearch.timeseries.util.SecurityClientUtil; + +public class ADEntityProfileRunner extends EntityProfileRunner { + + public ADEntityProfileRunner( + Client client, + SecurityClientUtil clientUtil, + NamedXContentRegistry xContentRegistry, + long requiredSamples + ) { + super( + client, + clientUtil, + xContentRegistry, + requiredSamples, + AnomalyDetector::parse, + ADNumericSetting.maxCategoricalFields(), + AnalysisType.AD, + ADEntityProfileAction.INSTANCE, + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, + AnomalyResult.DETECTOR_ID_FIELD + ); + } +} diff --git a/src/main/java/org/opensearch/ad/ADJobProcessor.java b/src/main/java/org/opensearch/ad/ADJobProcessor.java new file mode 100644 index 000000000..4492d3708 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ADJobProcessor.java @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad; + +import java.time.Instant; +import java.util.List; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskType; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.rest.handler.ADIndexJobActionHandler; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.task.ADTaskCacheManager; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.ADProfileAction; +import org.opensearch.ad.transport.AnomalyResultAction; +import org.opensearch.ad.transport.AnomalyResultRequest; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.jobscheduler.spi.LockModel; +import org.opensearch.jobscheduler.spi.utils.LockService; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.JobProcessor; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.transport.ResultRequest; + +public class ADJobProcessor extends + JobProcessor { + + private static final Logger log = LogManager.getLogger(ADJobProcessor.class); + + private static ADJobProcessor INSTANCE; + + public static ADJobProcessor getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (ADJobProcessor.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new ADJobProcessor(); + return INSTANCE; + } + } + + private ADJobProcessor() { + // Singleton class, use getJobRunnerInstance method instead of constructor + super(AnalysisType.AD, TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, AnomalyResultAction.INSTANCE); + } + + public void registerSettings(Settings settings) { + super.registerSettings(settings, AnomalyDetectorSettings.AD_MAX_RETRY_FOR_END_RUN_EXCEPTION); + } + + @Override + protected ResultRequest createResultRequest(String configId, long start, long end) { + return new AnomalyResultRequest(configId, start, end); + } + + @Override + protected void validateResultIndexAndRunJob( + Job jobParameter, + LockService lockService, + LockModel lock, + Instant executionStartTime, + Instant executionEndTime, + String configId, + String user, + List roles, + ExecuteADResultResponseRecorder recorder, + Config detector + ) { + String resultIndex = jobParameter.getCustomResultIndex(); + if (resultIndex == null) { + runJob(jobParameter, lockService, lock, executionStartTime, executionEndTime, configId, user, roles, recorder, detector); + return; + } + ActionListener listener = ActionListener.wrap(r -> { log.debug("Custom index is valid"); }, e -> { + Exception exception = new EndRunException(configId, e.getMessage(), false); + handleException(jobParameter, lockService, lock, executionStartTime, executionEndTime, exception, recorder, detector); + }); + indexManagement.validateCustomIndexForBackendJob(resultIndex, configId, user, roles, () -> { + listener.onResponse(true); + runJob(jobParameter, lockService, lock, executionStartTime, executionEndTime, configId, user, roles, recorder, detector); + }, listener); + } +} diff --git a/src/main/java/org/opensearch/ad/ADTaskProfileRunner.java b/src/main/java/org/opensearch/ad/ADTaskProfileRunner.java new file mode 100644 index 000000000..6bad4935c --- /dev/null +++ b/src/main/java/org/opensearch/ad/ADTaskProfileRunner.java @@ -0,0 +1,86 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskProfile; +import org.opensearch.ad.transport.ADTaskProfileAction; +import org.opensearch.ad.transport.ADTaskProfileNodeResponse; +import org.opensearch.ad.transport.ADTaskProfileRequest; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.timeseries.TaskProfileRunner; +import org.opensearch.timeseries.cluster.HashRing; +import org.opensearch.timeseries.model.EntityTaskProfile; + +public class ADTaskProfileRunner implements TaskProfileRunner { + public final Logger logger = LogManager.getLogger(ADTaskProfileRunner.class); + + private final HashRing hashRing; + private final Client client; + + public ADTaskProfileRunner(HashRing hashRing, Client client) { + this.hashRing = hashRing; + this.client = client; + } + + @Override + public void getTaskProfile(ADTask configLevelTask, ActionListener listener) { + String detectorId = configLevelTask.getConfigId(); + + hashRing.getAllEligibleDataNodesWithKnownVersion(dataNodes -> { + ADTaskProfileRequest adTaskProfileRequest = new ADTaskProfileRequest(detectorId, dataNodes); + client.execute(ADTaskProfileAction.INSTANCE, adTaskProfileRequest, ActionListener.wrap(response -> { + if (response.hasFailures()) { + listener.onFailure(response.failures().get(0)); + return; + } + + List adEntityTaskProfiles = new ArrayList<>(); + ADTaskProfile detectorTaskProfile = new ADTaskProfile(configLevelTask); + for (ADTaskProfileNodeResponse node : response.getNodes()) { + ADTaskProfile taskProfile = node.getAdTaskProfile(); + if (taskProfile != null) { + if (taskProfile.getNodeId() != null) { + // HC detector: task profile from coordinating node + // Single entity detector: task profile from worker node + detectorTaskProfile.setTaskId(taskProfile.getTaskId()); + detectorTaskProfile.setShingleSize(taskProfile.getShingleSize()); + detectorTaskProfile.setRcfTotalUpdates(taskProfile.getRcfTotalUpdates()); + detectorTaskProfile.setThresholdModelTrained(taskProfile.getThresholdModelTrained()); + detectorTaskProfile.setThresholdModelTrainingDataSize(taskProfile.getThresholdModelTrainingDataSize()); + detectorTaskProfile.setModelSizeInBytes(taskProfile.getModelSizeInBytes()); + detectorTaskProfile.setNodeId(taskProfile.getNodeId()); + detectorTaskProfile.setTotalEntitiesCount(taskProfile.getTotalEntitiesCount()); + detectorTaskProfile.setDetectorTaskSlots(taskProfile.getDetectorTaskSlots()); + detectorTaskProfile.setPendingEntitiesCount(taskProfile.getPendingEntitiesCount()); + detectorTaskProfile.setRunningEntitiesCount(taskProfile.getRunningEntitiesCount()); + detectorTaskProfile.setRunningEntities(taskProfile.getRunningEntities()); + detectorTaskProfile.setTaskType(taskProfile.getTaskType()); + } + if (taskProfile.getEntityTaskProfiles() != null) { + adEntityTaskProfiles.addAll(taskProfile.getEntityTaskProfiles()); + } + } + } + if (adEntityTaskProfiles != null && adEntityTaskProfiles.size() > 0) { + detectorTaskProfile.setEntityTaskProfiles(adEntityTaskProfiles); + } + listener.onResponse(detectorTaskProfile); + }, e -> { + logger.error("Failed to get task profile for task " + configLevelTask.getTaskId(), e); + listener.onFailure(e); + })); + }, listener); + + } + +} diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java b/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java deleted file mode 100644 index 98135e1ee..000000000 --- a/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java +++ /dev/null @@ -1,653 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad; - -import static org.opensearch.action.DocWriteResponse.Result.CREATED; -import static org.opensearch.action.DocWriteResponse.Result.UPDATED; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME; -import static org.opensearch.timeseries.util.RestHandlerUtils.XCONTENT_WITH_TYPE; - -import java.io.IOException; -import java.time.Instant; -import java.util.List; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ExecutorService; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.message.ParameterizedMessage; -import org.opensearch.action.get.GetRequest; -import org.opensearch.action.get.GetResponse; -import org.opensearch.action.index.IndexRequest; -import org.opensearch.action.support.WriteRequest; -import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.task.ADTaskManager; -import org.opensearch.ad.transport.AnomalyResultAction; -import org.opensearch.ad.transport.AnomalyResultRequest; -import org.opensearch.ad.transport.AnomalyResultResponse; -import org.opensearch.ad.transport.AnomalyResultTransportAction; -import org.opensearch.client.Client; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.xcontent.LoggingDeprecationHandler; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.commons.InjectSecurity; -import org.opensearch.commons.authuser.User; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.jobscheduler.spi.JobExecutionContext; -import org.opensearch.jobscheduler.spi.LockModel; -import org.opensearch.jobscheduler.spi.ScheduledJobParameter; -import org.opensearch.jobscheduler.spi.ScheduledJobRunner; -import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; -import org.opensearch.jobscheduler.spi.utils.LockService; -import org.opensearch.threadpool.ThreadPool; -import org.opensearch.timeseries.AnalysisType; -import org.opensearch.timeseries.NodeStateManager; -import org.opensearch.timeseries.common.exception.EndRunException; -import org.opensearch.timeseries.common.exception.InternalFailure; -import org.opensearch.timeseries.common.exception.TimeSeriesException; -import org.opensearch.timeseries.constant.CommonName; -import org.opensearch.timeseries.function.ExecutorFunction; -import org.opensearch.timeseries.model.Job; -import org.opensearch.timeseries.model.TaskState; -import org.opensearch.timeseries.util.SecurityUtil; - -import com.google.common.base.Throwables; - -/** - * JobScheduler will call AD job runner to get anomaly result periodically - */ -public class AnomalyDetectorJobRunner implements ScheduledJobRunner { - private static final Logger log = LogManager.getLogger(AnomalyDetectorJobRunner.class); - private static AnomalyDetectorJobRunner INSTANCE; - private Settings settings; - private int maxRetryForEndRunException; - private Client client; - private ThreadPool threadPool; - private ConcurrentHashMap detectorEndRunExceptionCount; - private ADIndexManagement anomalyDetectionIndices; - private ADTaskManager adTaskManager; - private NodeStateManager nodeStateManager; - private ExecuteADResultResponseRecorder recorder; - - public static AnomalyDetectorJobRunner getJobRunnerInstance() { - if (INSTANCE != null) { - return INSTANCE; - } - synchronized (AnomalyDetectorJobRunner.class) { - if (INSTANCE != null) { - return INSTANCE; - } - INSTANCE = new AnomalyDetectorJobRunner(); - return INSTANCE; - } - } - - private AnomalyDetectorJobRunner() { - // Singleton class, use getJobRunnerInstance method instead of constructor - this.detectorEndRunExceptionCount = new ConcurrentHashMap<>(); - } - - public void setClient(Client client) { - this.client = client; - } - - public void setThreadPool(ThreadPool threadPool) { - this.threadPool = threadPool; - } - - public void setSettings(Settings settings) { - this.settings = settings; - this.maxRetryForEndRunException = AnomalyDetectorSettings.AD_MAX_RETRY_FOR_END_RUN_EXCEPTION.get(settings); - } - - public void setAdTaskManager(ADTaskManager adTaskManager) { - this.adTaskManager = adTaskManager; - } - - public void setAnomalyDetectionIndices(ADIndexManagement anomalyDetectionIndices) { - this.anomalyDetectionIndices = anomalyDetectionIndices; - } - - public void setNodeStateManager(NodeStateManager nodeStateManager) { - this.nodeStateManager = nodeStateManager; - } - - public void setExecuteADResultResponseRecorder(ExecuteADResultResponseRecorder recorder) { - this.recorder = recorder; - } - - @Override - public void runJob(ScheduledJobParameter scheduledJobParameter, JobExecutionContext context) { - String detectorId = scheduledJobParameter.getName(); - log.info("Start to run AD job {}", detectorId); - adTaskManager.refreshRealtimeJobRunTime(detectorId); - if (!(scheduledJobParameter instanceof Job)) { - throw new IllegalArgumentException( - "Job parameter is not instance of Job, type: " + scheduledJobParameter.getClass().getCanonicalName() - ); - } - Job jobParameter = (Job) scheduledJobParameter; - Instant executionStartTime = Instant.now(); - IntervalSchedule schedule = (IntervalSchedule) jobParameter.getSchedule(); - Instant detectionStartTime = executionStartTime.minus(schedule.getInterval(), schedule.getUnit()); - - final LockService lockService = context.getLockService(); - - Runnable runnable = () -> { - try { - nodeStateManager.getConfig(detectorId, AnalysisType.AD, ActionListener.wrap(detectorOptional -> { - if (!detectorOptional.isPresent()) { - log.error(new ParameterizedMessage("fail to get detector [{}]", detectorId)); - return; - } - AnomalyDetector detector = (AnomalyDetector) detectorOptional.get(); - - if (jobParameter.getLockDurationSeconds() != null) { - lockService - .acquireLock( - jobParameter, - context, - ActionListener - .wrap( - lock -> runAdJob( - jobParameter, - lockService, - lock, - detectionStartTime, - executionStartTime, - recorder, - detector - ), - exception -> { - indexAnomalyResultException( - jobParameter, - lockService, - null, - detectionStartTime, - executionStartTime, - exception, - false, - recorder, - detector - ); - throw new IllegalStateException("Failed to acquire lock for AD job: " + detectorId); - } - ) - ); - } else { - log.warn("Can't get lock for AD job: " + detectorId); - } - - }, e -> log.error(new ParameterizedMessage("fail to get detector [{}]", detectorId), e))); - } catch (Exception e) { - // os log won't show anything if there is an exception happens (maybe due to running on a ExecutorService) - // we at least log the error. - log.error("Can't start AD job: " + detectorId, e); - throw e; - } - }; - - ExecutorService executor = threadPool.executor(AD_THREAD_POOL_NAME); - executor.submit(runnable); - } - - /** - * Get anomaly result, index result or handle exception if failed. - * - * @param jobParameter scheduled job parameter - * @param lockService lock service - * @param lock lock to run job - * @param detectionStartTime detection start time - * @param executionStartTime detection end time - * @param recorder utility to record job execution result - * @param detector associated detector accessor - */ - protected void runAdJob( - Job jobParameter, - LockService lockService, - LockModel lock, - Instant detectionStartTime, - Instant executionStartTime, - ExecuteADResultResponseRecorder recorder, - AnomalyDetector detector - ) { - String detectorId = jobParameter.getName(); - if (lock == null) { - indexAnomalyResultException( - jobParameter, - lockService, - lock, - detectionStartTime, - executionStartTime, - "Can't run AD job due to null lock", - false, - recorder, - detector - ); - return; - } - anomalyDetectionIndices.update(); - - User userInfo = SecurityUtil.getUserFromJob(jobParameter, settings); - - String user = userInfo.getName(); - List roles = userInfo.getRoles(); - - String resultIndex = jobParameter.getCustomResultIndex(); - if (resultIndex == null) { - runAnomalyDetectionJob( - jobParameter, - lockService, - lock, - detectionStartTime, - executionStartTime, - detectorId, - user, - roles, - recorder, - detector - ); - return; - } - ActionListener listener = ActionListener.wrap(r -> { log.debug("Custom index is valid"); }, e -> { - Exception exception = new EndRunException(detectorId, e.getMessage(), true); - handleAdException(jobParameter, lockService, lock, detectionStartTime, executionStartTime, exception, recorder, detector); - }); - anomalyDetectionIndices.validateCustomIndexForBackendJob(resultIndex, detectorId, user, roles, () -> { - listener.onResponse(true); - runAnomalyDetectionJob( - jobParameter, - lockService, - lock, - detectionStartTime, - executionStartTime, - detectorId, - user, - roles, - recorder, - detector - ); - }, listener); - } - - private void runAnomalyDetectionJob( - Job jobParameter, - LockService lockService, - LockModel lock, - Instant detectionStartTime, - Instant executionStartTime, - String detectorId, - String user, - List roles, - ExecuteADResultResponseRecorder recorder, - AnomalyDetector detector - ) { - // using one thread in the write threadpool - try (InjectSecurity injectSecurity = new InjectSecurity(detectorId, settings, client.threadPool().getThreadContext())) { - // Injecting user role to verify if the user has permissions for our API. - injectSecurity.inject(user, roles); - - AnomalyResultRequest request = new AnomalyResultRequest( - detectorId, - detectionStartTime.toEpochMilli(), - executionStartTime.toEpochMilli() - ); - client.execute(AnomalyResultAction.INSTANCE, request, ActionListener.wrap(response -> { - indexAnomalyResult(jobParameter, lockService, lock, detectionStartTime, executionStartTime, response, recorder, detector); - }, exception -> { - handleAdException(jobParameter, lockService, lock, detectionStartTime, executionStartTime, exception, recorder, detector); - })); - } catch (Exception e) { - indexAnomalyResultException( - jobParameter, - lockService, - lock, - detectionStartTime, - executionStartTime, - e, - true, - recorder, - detector - ); - log.error("Failed to execute AD job " + detectorId, e); - } - } - - /** - * Handle exception from anomaly result action. - * - * 1. If exception is {@link EndRunException} - * a). if isEndNow == true, stop AD job and store exception in anomaly result - * b). if isEndNow == false, record count of {@link EndRunException} for this - * detector. If count of {@link EndRunException} exceeds upper limit, will - * stop AD job and store exception in anomaly result; otherwise, just - * store exception in anomaly result, not stop AD job for the detector. - * - * 2. If exception is not {@link EndRunException}, decrease count of - * {@link EndRunException} for the detector and index eception in Anomaly - * result. If exception is {@link InternalFailure}, will not log exception - * stack trace as already logged in {@link AnomalyResultTransportAction}. - * - * TODO: Handle finer granularity exception such as some exception may be - * transient and retry in current job may succeed. Currently, we don't - * know which exception is transient and retryable in - * {@link AnomalyResultTransportAction}. So we don't add backoff retry - * now to avoid bring extra load to cluster, expecially the code start - * process is relatively heavy by sending out 24 queries, initializing - * models, and saving checkpoints. - * Sometimes missing anomaly and notification is not acceptable. For example, - * current detection interval is 1hour, and there should be anomaly in - * current interval, some transient exception may fail current AD job, - * so no anomaly found and user never know it. Then we start next AD job, - * maybe there is no anomaly in next 1hour, user will never know something - * wrong happened. In one word, this is some tradeoff between protecting - * our performance, user experience and what we can do currently. - * - * @param jobParameter scheduled job parameter - * @param lockService lock service - * @param lock lock to run job - * @param detectionStartTime detection start time - * @param executionStartTime detection end time - * @param exception exception - * @param recorder utility to record job execution result - * @param detector associated detector accessor - */ - protected void handleAdException( - Job jobParameter, - LockService lockService, - LockModel lock, - Instant detectionStartTime, - Instant executionStartTime, - Exception exception, - ExecuteADResultResponseRecorder recorder, - AnomalyDetector detector - ) { - String detectorId = jobParameter.getName(); - if (exception instanceof EndRunException) { - log.error("EndRunException happened when executing anomaly result action for " + detectorId, exception); - - if (((EndRunException) exception).isEndNow()) { - // Stop AD job if EndRunException shows we should end job now. - log.info("JobRunner will stop AD job due to EndRunException for {}", detectorId); - stopAdJobForEndRunException( - jobParameter, - lockService, - lock, - detectionStartTime, - executionStartTime, - (EndRunException) exception, - recorder, - detector - ); - } else { - detectorEndRunExceptionCount.compute(detectorId, (k, v) -> { - if (v == null) { - return 1; - } else { - return v + 1; - } - }); - log.info("EndRunException happened for {}", detectorId); - // if AD job failed consecutively due to EndRunException and failed times exceeds upper limit, will stop AD job - if (detectorEndRunExceptionCount.get(detectorId) > maxRetryForEndRunException) { - log - .info( - "JobRunner will stop AD job due to EndRunException retry exceeds upper limit {} for {}", - maxRetryForEndRunException, - detectorId - ); - stopAdJobForEndRunException( - jobParameter, - lockService, - lock, - detectionStartTime, - executionStartTime, - (EndRunException) exception, - recorder, - detector - ); - return; - } - indexAnomalyResultException( - jobParameter, - lockService, - lock, - detectionStartTime, - executionStartTime, - exception.getMessage(), - true, - recorder, - detector - ); - } - } else { - detectorEndRunExceptionCount.remove(detectorId); - if (exception instanceof InternalFailure) { - log.error("InternalFailure happened when executing anomaly result action for " + detectorId, exception); - } else { - log.error("Failed to execute anomaly result action for " + detectorId, exception); - } - indexAnomalyResultException( - jobParameter, - lockService, - lock, - detectionStartTime, - executionStartTime, - exception, - true, - recorder, - detector - ); - } - } - - private void stopAdJobForEndRunException( - Job jobParameter, - LockService lockService, - LockModel lock, - Instant detectionStartTime, - Instant executionStartTime, - EndRunException exception, - ExecuteADResultResponseRecorder recorder, - AnomalyDetector detector - ) { - String detectorId = jobParameter.getName(); - detectorEndRunExceptionCount.remove(detectorId); - String errorPrefix = exception.isEndNow() - ? "Stopped detector: " - : "Stopped detector as job failed consecutively for more than " + this.maxRetryForEndRunException + " times: "; - String error = errorPrefix + exception.getMessage(); - stopAdJob( - detectorId, - () -> indexAnomalyResultException( - jobParameter, - lockService, - lock, - detectionStartTime, - executionStartTime, - error, - true, - TaskState.STOPPED.name(), - recorder, - detector - ) - ); - } - - private void stopAdJob(String detectorId, ExecutorFunction function) { - GetRequest getRequest = new GetRequest(CommonName.JOB_INDEX).id(detectorId); - ActionListener listener = ActionListener.wrap(response -> { - if (response.isExists()) { - try ( - XContentParser parser = XContentType.JSON - .xContent() - .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, response.getSourceAsString()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Job job = Job.parse(parser); - if (job.isEnabled()) { - Job newJob = new Job( - job.getName(), - job.getSchedule(), - job.getWindowDelay(), - false, - job.getEnabledTime(), - Instant.now(), - Instant.now(), - job.getLockDurationSeconds(), - job.getUser(), - job.getCustomResultIndex() - ); - IndexRequest indexRequest = new IndexRequest(CommonName.JOB_INDEX) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .source(newJob.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), XCONTENT_WITH_TYPE)) - .id(detectorId); - - client.index(indexRequest, ActionListener.wrap(indexResponse -> { - if (indexResponse != null && (indexResponse.getResult() == CREATED || indexResponse.getResult() == UPDATED)) { - log.info("AD Job was disabled by JobRunner for " + detectorId); - // function.execute(); - } else { - log.warn("Failed to disable AD job for " + detectorId); - } - }, exception -> { log.error("JobRunner failed to update AD job as disabled for " + detectorId, exception); })); - } else { - log.info("AD Job was disabled for " + detectorId); - } - } catch (IOException e) { - log.error("JobRunner failed to stop detector job " + detectorId, e); - } - } else { - log.info("AD Job was not found for " + detectorId); - } - }, exception -> log.error("JobRunner failed to get detector job " + detectorId, exception)); - - client.get(getRequest, ActionListener.runAfter(listener, () -> function.execute())); - } - - private void indexAnomalyResult( - Job jobParameter, - LockService lockService, - LockModel lock, - Instant detectionStartTime, - Instant executionStartTime, - AnomalyResultResponse response, - ExecuteADResultResponseRecorder recorder, - AnomalyDetector detector - ) { - String detectorId = jobParameter.getName(); - detectorEndRunExceptionCount.remove(detectorId); - try { - recorder.indexAnomalyResult(detectionStartTime, executionStartTime, response, detector); - } catch (EndRunException e) { - handleAdException(jobParameter, lockService, lock, detectionStartTime, executionStartTime, e, recorder, detector); - } catch (Exception e) { - log.error("Failed to index anomaly result for " + detectorId, e); - } finally { - releaseLock(jobParameter, lockService, lock); - } - - } - - private void indexAnomalyResultException( - Job jobParameter, - LockService lockService, - LockModel lock, - Instant detectionStartTime, - Instant executionStartTime, - Exception exception, - boolean releaseLock, - ExecuteADResultResponseRecorder recorder, - AnomalyDetector detector - ) { - try { - String errorMessage = exception instanceof TimeSeriesException - ? exception.getMessage() - : Throwables.getStackTraceAsString(exception); - indexAnomalyResultException( - jobParameter, - lockService, - lock, - detectionStartTime, - executionStartTime, - errorMessage, - releaseLock, - recorder, - detector - ); - } catch (Exception e) { - log.error("Failed to index anomaly result for " + jobParameter.getName(), e); - } - } - - private void indexAnomalyResultException( - Job jobParameter, - LockService lockService, - LockModel lock, - Instant detectionStartTime, - Instant executionStartTime, - String errorMessage, - boolean releaseLock, - ExecuteADResultResponseRecorder recorder, - AnomalyDetector detector - ) { - indexAnomalyResultException( - jobParameter, - lockService, - lock, - detectionStartTime, - executionStartTime, - errorMessage, - releaseLock, - null, - recorder, - detector - ); - } - - private void indexAnomalyResultException( - Job jobParameter, - LockService lockService, - LockModel lock, - Instant detectionStartTime, - Instant executionStartTime, - String errorMessage, - boolean releaseLock, - String taskState, - ExecuteADResultResponseRecorder recorder, - AnomalyDetector detector - ) { - try { - recorder.indexAnomalyResultException(detectionStartTime, executionStartTime, errorMessage, taskState, detector); - } finally { - if (releaseLock) { - releaseLock(jobParameter, lockService, lock); - } - } - } - - private void releaseLock(Job jobParameter, LockService lockService, LockModel lock) { - lockService - .release( - lock, - ActionListener.wrap(released -> { log.info("Released lock for AD job {}", jobParameter.getName()); }, exception -> { - log.error("Failed to release lock for AD job: " + jobParameter.getName(), exception); - }) - ); - } -} diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java b/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java index 582f3e39a..5119220b9 100644 --- a/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java +++ b/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java @@ -11,80 +11,56 @@ package org.opensearch.ad; -import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; -import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG; -import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_PARSE_CONFIG_MSG; -import java.util.List; -import java.util.Map; import java.util.Set; -import java.util.stream.Collectors; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.core.util.Throwables; import org.apache.logging.log4j.message.ParameterizedMessage; -import org.opensearch.OpenSearchStatusException; import org.opensearch.action.get.GetRequest; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskProfile; import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.DetectorProfile; -import org.opensearch.ad.model.DetectorProfileName; -import org.opensearch.ad.model.DetectorState; -import org.opensearch.ad.model.InitProgressProfile; import org.opensearch.ad.settings.ADNumericSetting; +import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; -import org.opensearch.ad.transport.ProfileAction; -import org.opensearch.ad.transport.ProfileRequest; -import org.opensearch.ad.transport.ProfileResponse; +import org.opensearch.ad.transport.ADProfileAction; import org.opensearch.ad.transport.RCFPollingAction; import org.opensearch.ad.transport.RCFPollingRequest; import org.opensearch.ad.transport.RCFPollingResponse; import org.opensearch.client.Client; -import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.search.SearchHits; -import org.opensearch.search.aggregations.Aggregation; -import org.opensearch.search.aggregations.AggregationBuilder; -import org.opensearch.search.aggregations.AggregationBuilders; -import org.opensearch.search.aggregations.Aggregations; -import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation; -import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; -import org.opensearch.search.aggregations.metrics.CardinalityAggregationBuilder; -import org.opensearch.search.aggregations.metrics.InternalCardinality; -import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.ProfileRunner; import org.opensearch.timeseries.common.exception.NotSerializedExceptionName; import org.opensearch.timeseries.common.exception.ResourceNotFoundException; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; -import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.ConfigState; import org.opensearch.timeseries.model.Job; -import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.model.ProfileName; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import org.opensearch.timeseries.util.ExceptionUtil; import org.opensearch.timeseries.util.MultiResponsesDelegateActionListener; import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.TransportService; -public class AnomalyDetectorProfileRunner extends AbstractProfileRunner { +public class AnomalyDetectorProfileRunner extends + ProfileRunner { + private final Logger logger = LogManager.getLogger(AnomalyDetectorProfileRunner.class); - private Client client; - private SecurityClientUtil clientUtil; - private NamedXContentRegistry xContentRegistry; - private DiscoveryNodeFilterer nodeFilter; - private final TransportService transportService; - private final ADTaskManager adTaskManager; - private final int maxTotalEntitiesToTrack; public AnomalyDetectorProfileRunner( Client client, @@ -93,300 +69,133 @@ public AnomalyDetectorProfileRunner( DiscoveryNodeFilterer nodeFilter, long requiredSamples, TransportService transportService, - ADTaskManager adTaskManager + ADTaskManager adTaskManager, + ADTaskProfileRunner taskProfileRunner ) { - super(requiredSamples); - this.client = client; - this.clientUtil = clientUtil; - this.xContentRegistry = xContentRegistry; - this.nodeFilter = nodeFilter; - if (requiredSamples <= 0) { - throw new IllegalArgumentException("required samples should be a positive number, but was " + requiredSamples); - } - this.transportService = transportService; - this.adTaskManager = adTaskManager; - this.maxTotalEntitiesToTrack = TimeSeriesSettings.MAX_TOTAL_ENTITIES_TO_TRACK; + super( + client, + clientUtil, + xContentRegistry, + nodeFilter, + requiredSamples, + transportService, + adTaskManager, + AnalysisType.AD, + ADTaskType.REALTIME_TASK_TYPES, + ADTaskType.HISTORICAL_DETECTOR_TASK_TYPES, + ADNumericSetting.maxCategoricalFields(), + ProfileName.AD_TASK, + ADProfileAction.INSTANCE, + AnomalyDetector::parse, + taskProfileRunner + ); } - public void profile(String detectorId, ActionListener listener, Set profilesToCollect) { - if (profilesToCollect.isEmpty()) { - listener.onFailure(new IllegalArgumentException(CommonMessages.EMPTY_PROFILES_COLLECT)); - return; - } - calculateTotalResponsesToWait(detectorId, profilesToCollect, listener); - } - - private void calculateTotalResponsesToWait( - String detectorId, - Set profilesToCollect, - ActionListener listener - ) { - GetRequest getDetectorRequest = new GetRequest(CommonName.CONFIG_INDEX, detectorId); - client.get(getDetectorRequest, ActionListener.wrap(getDetectorResponse -> { - if (getDetectorResponse != null && getDetectorResponse.isExists()) { - try ( - XContentParser xContentParser = XContentType.JSON - .xContent() - .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, getDetectorResponse.getSourceAsString()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, xContentParser.nextToken(), xContentParser); - AnomalyDetector detector = AnomalyDetector.parse(xContentParser, detectorId); - prepareProfile(detector, listener, profilesToCollect); - } catch (Exception e) { - logger.error(FAIL_TO_PARSE_CONFIG_MSG + detectorId, e); - listener.onFailure(new OpenSearchStatusException(FAIL_TO_PARSE_CONFIG_MSG + detectorId, BAD_REQUEST)); - } - } else { - listener.onFailure(new OpenSearchStatusException(FAIL_TO_FIND_CONFIG_MSG + detectorId, BAD_REQUEST)); - } - }, exception -> { - logger.error(FAIL_TO_FIND_CONFIG_MSG + detectorId, exception); - listener.onFailure(new OpenSearchStatusException(FAIL_TO_FIND_CONFIG_MSG + detectorId, INTERNAL_SERVER_ERROR)); - })); + @Override + protected DetectorProfile.Builder createProfileBuilder() { + return new DetectorProfile.Builder(); } - private void prepareProfile( - AnomalyDetector detector, - ActionListener listener, - Set profilesToCollect - ) { - String detectorId = detector.getId(); - GetRequest getRequest = new GetRequest(CommonName.JOB_INDEX, detectorId); - client.get(getRequest, ActionListener.wrap(getResponse -> { - if (getResponse != null && getResponse.isExists()) { - try ( - XContentParser parser = XContentType.JSON - .xContent() - .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, getResponse.getSourceAsString()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Job job = Job.parse(parser); - long enabledTimeMs = job.getEnabledTime().toEpochMilli(); - - boolean isMultiEntityDetector = detector.isHighCardinality(); - - int totalResponsesToWait = 0; - if (profilesToCollect.contains(DetectorProfileName.ERROR)) { - totalResponsesToWait++; - } - - // total number of listeners we need to define. Needed by MultiResponsesDelegateActionListener to decide - // when to consolidate results and return to users - if (isMultiEntityDetector) { - if (profilesToCollect.contains(DetectorProfileName.TOTAL_ENTITIES)) { - totalResponsesToWait++; - } - if (profilesToCollect.contains(DetectorProfileName.COORDINATING_NODE) - || profilesToCollect.contains(DetectorProfileName.SHINGLE_SIZE) - || profilesToCollect.contains(DetectorProfileName.TOTAL_SIZE_IN_BYTES) - || profilesToCollect.contains(DetectorProfileName.MODELS) - || profilesToCollect.contains(DetectorProfileName.ACTIVE_ENTITIES) - || profilesToCollect.contains(DetectorProfileName.INIT_PROGRESS) - || profilesToCollect.contains(DetectorProfileName.STATE)) { - totalResponsesToWait++; - } - if (profilesToCollect.contains(DetectorProfileName.AD_TASK)) { + @Override + protected void prepareProfile(Config config, ActionListener listener, Set profilesToCollect) { + boolean isHC = config.isHighCardinality(); + if (isHC) { + super.prepareProfile(config, listener, profilesToCollect); + } else { + String configId = config.getId(); + GetRequest getRequest = new GetRequest(CommonName.JOB_INDEX, configId); + client.get(getRequest, ActionListener.wrap(getResponse -> { + if (getResponse != null && getResponse.isExists()) { + try ( + XContentParser parser = XContentType.JSON + .xContent() + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, getResponse.getSourceAsString()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Job job = Job.parse(parser); + long enabledTimeMs = job.getEnabledTime().toEpochMilli(); + + int totalResponsesToWait = 0; + if (profilesToCollect.contains(ProfileName.ERROR)) { totalResponsesToWait++; } - } else { - if (profilesToCollect.contains(DetectorProfileName.STATE) - || profilesToCollect.contains(DetectorProfileName.INIT_PROGRESS)) { + + // total number of listeners we need to define. Needed by MultiResponsesDelegateActionListener to decide + // when to consolidate results and return to users + + if (profilesToCollect.contains(ProfileName.STATE) || profilesToCollect.contains(ProfileName.INIT_PROGRESS)) { totalResponsesToWait++; } - if (profilesToCollect.contains(DetectorProfileName.COORDINATING_NODE) - || profilesToCollect.contains(DetectorProfileName.SHINGLE_SIZE) - || profilesToCollect.contains(DetectorProfileName.TOTAL_SIZE_IN_BYTES) - || profilesToCollect.contains(DetectorProfileName.MODELS)) { + if (profilesToCollect.contains(ProfileName.COORDINATING_NODE) + || profilesToCollect.contains(ProfileName.SHINGLE_SIZE) + || profilesToCollect.contains(ProfileName.TOTAL_SIZE_IN_BYTES) + || profilesToCollect.contains(ProfileName.MODELS)) { totalResponsesToWait++; } - if (profilesToCollect.contains(DetectorProfileName.AD_TASK)) { + if (profilesToCollect.contains(ProfileName.AD_TASK)) { totalResponsesToWait++; } - } - - MultiResponsesDelegateActionListener delegateListener = - new MultiResponsesDelegateActionListener( - listener, - totalResponsesToWait, - CommonMessages.FAIL_FETCH_ERR_MSG + detectorId, - false - ); - if (profilesToCollect.contains(DetectorProfileName.ERROR)) { - adTaskManager.getAndExecuteOnLatestDetectorLevelTask(detectorId, ADTaskType.REALTIME_TASK_TYPES, adTask -> { - DetectorProfile.Builder profileBuilder = new DetectorProfile.Builder(); - if (adTask.isPresent()) { - long lastUpdateTimeMs = adTask.get().getLastUpdateTime().toEpochMilli(); - // if state index hasn't been updated, we should not use the error field - // For example, before a detector is enabled, if the error message contains - // the phrase "stopped due to blah", we should not show this when the detector - // is enabled. - if (lastUpdateTimeMs > enabledTimeMs && adTask.get().getError() != null) { - profileBuilder.error(adTask.get().getError()); + MultiResponsesDelegateActionListener delegateListener = + new MultiResponsesDelegateActionListener( + listener, + totalResponsesToWait, + CommonMessages.FAIL_FETCH_ERR_MSG + configId, + false + ); + if (profilesToCollect.contains(ProfileName.ERROR)) { + taskManager.getAndExecuteOnLatestConfigLevelTask(configId, realTimeTaskTypes, task -> { + DetectorProfile.Builder profileBuilder = createProfileBuilder(); + if (task.isPresent()) { + long lastUpdateTimeMs = task.get().getLastUpdateTime().toEpochMilli(); + + // if state index hasn't been updated, we should not use the error field + // For example, before a detector is enabled, if the error message contains + // the phrase "stopped due to blah", we should not show this when the detector + // is enabled. + if (lastUpdateTimeMs > enabledTimeMs && task.get().getError() != null) { + profileBuilder.error(task.get().getError()); + } + delegateListener.onResponse(profileBuilder.build()); + } else { + // detector state for this detector does not exist + delegateListener.onResponse(profileBuilder.build()); } - delegateListener.onResponse(profileBuilder.build()); - } else { - // detector state for this detector does not exist - delegateListener.onResponse(profileBuilder.build()); - } - }, transportService, false, delegateListener); - } - - // total number of listeners we need to define. Needed by MultiResponsesDelegateActionListener to decide - // when to consolidate results and return to users - if (isMultiEntityDetector) { - if (profilesToCollect.contains(DetectorProfileName.TOTAL_ENTITIES)) { - profileEntityStats(delegateListener, detector); - } - if (profilesToCollect.contains(DetectorProfileName.COORDINATING_NODE) - || profilesToCollect.contains(DetectorProfileName.SHINGLE_SIZE) - || profilesToCollect.contains(DetectorProfileName.TOTAL_SIZE_IN_BYTES) - || profilesToCollect.contains(DetectorProfileName.MODELS) - || profilesToCollect.contains(DetectorProfileName.ACTIVE_ENTITIES) - || profilesToCollect.contains(DetectorProfileName.INIT_PROGRESS) - || profilesToCollect.contains(DetectorProfileName.STATE)) { - profileModels(detector, profilesToCollect, job, true, delegateListener); + }, transportService, false, delegateListener); } - if (profilesToCollect.contains(DetectorProfileName.AD_TASK)) { - adTaskManager.getLatestHistoricalTaskProfile(detectorId, transportService, null, delegateListener); - } - } else { - if (profilesToCollect.contains(DetectorProfileName.STATE) - || profilesToCollect.contains(DetectorProfileName.INIT_PROGRESS)) { - profileStateRelated(detector, delegateListener, job.isEnabled(), profilesToCollect); + + // total number of listeners we need to define. Needed by MultiResponsesDelegateActionListener to decide + // when to consolidate results and return to users + + if (profilesToCollect.contains(ProfileName.STATE) || profilesToCollect.contains(ProfileName.INIT_PROGRESS)) { + profileStateRelated(config, delegateListener, job.isEnabled(), profilesToCollect); } - if (profilesToCollect.contains(DetectorProfileName.COORDINATING_NODE) - || profilesToCollect.contains(DetectorProfileName.SHINGLE_SIZE) - || profilesToCollect.contains(DetectorProfileName.TOTAL_SIZE_IN_BYTES) - || profilesToCollect.contains(DetectorProfileName.MODELS)) { - profileModels(detector, profilesToCollect, job, false, delegateListener); + if (profilesToCollect.contains(ProfileName.COORDINATING_NODE) + || profilesToCollect.contains(ProfileName.SHINGLE_SIZE) + || profilesToCollect.contains(ProfileName.TOTAL_SIZE_IN_BYTES) + || profilesToCollect.contains(ProfileName.MODELS)) { + profileModels(config, profilesToCollect, job, false, delegateListener); } - if (profilesToCollect.contains(DetectorProfileName.AD_TASK)) { - adTaskManager.getLatestHistoricalTaskProfile(detectorId, transportService, null, delegateListener); + if (profilesToCollect.contains(ProfileName.AD_TASK)) { + getLatestHistoricalTaskProfile(configId, transportService, null, delegateListener); } - } - - } catch (Exception e) { - logger.error(CommonMessages.FAIL_TO_GET_PROFILE_MSG, e); - listener.onFailure(e); - } - } else { - onGetDetectorForPrepare(detectorId, listener, profilesToCollect); - } - }, exception -> { - if (ExceptionUtil.isIndexNotAvailable(exception)) { - logger.info(exception.getMessage()); - onGetDetectorForPrepare(detectorId, listener, profilesToCollect); - } else { - logger.error(CommonMessages.FAIL_TO_GET_PROFILE_MSG + detectorId); - listener.onFailure(exception); - } - })); - } - private void profileEntityStats(MultiResponsesDelegateActionListener listener, AnomalyDetector detector) { - List categoryField = detector.getCategoryFields(); - if (!detector.isHighCardinality() || categoryField.size() > ADNumericSetting.maxCategoricalFields()) { - listener.onResponse(new DetectorProfile.Builder().build()); - } else { - if (categoryField.size() == 1) { - // Run a cardinality aggregation to count the cardinality of single category fields - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - CardinalityAggregationBuilder aggBuilder = new CardinalityAggregationBuilder(ADCommonName.TOTAL_ENTITIES); - aggBuilder.field(categoryField.get(0)); - searchSourceBuilder.aggregation(aggBuilder); - - SearchRequest request = new SearchRequest(detector.getIndices().toArray(new String[0]), searchSourceBuilder); - final ActionListener searchResponseListener = ActionListener.wrap(searchResponse -> { - Map aggMap = searchResponse.getAggregations().asMap(); - InternalCardinality totalEntities = (InternalCardinality) aggMap.get(ADCommonName.TOTAL_ENTITIES); - long value = totalEntities.getValue(); - DetectorProfile.Builder profileBuilder = new DetectorProfile.Builder(); - DetectorProfile profile = profileBuilder.totalEntities(value).build(); - listener.onResponse(profile); - }, searchException -> { - logger.warn(CommonMessages.FAIL_TO_GET_TOTAL_ENTITIES + detector.getId()); - listener.onFailure(searchException); - }); - // using the original context in listener as user roles have no permissions for internal operations like fetching a - // checkpoint - clientUtil - .asyncRequestWithInjectedSecurity( - request, - client::search, - detector.getId(), - client, - AnalysisType.AD, - searchResponseListener - ); - } else { - // Run a composite query and count the number of buckets to decide cardinality of multiple category fields - AggregationBuilder bucketAggs = AggregationBuilders - .composite( - ADCommonName.TOTAL_ENTITIES, - detector - .getCategoryFields() - .stream() - .map(f -> new TermsValuesSourceBuilder(f).field(f)) - .collect(Collectors.toList()) - ) - .size(maxTotalEntitiesToTrack); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().aggregation(bucketAggs).trackTotalHits(false).size(0); - SearchRequest searchRequest = new SearchRequest() - .indices(detector.getIndices().toArray(new String[0])) - .source(searchSourceBuilder); - final ActionListener searchResponseListener = ActionListener.wrap(searchResponse -> { - DetectorProfile.Builder profileBuilder = new DetectorProfile.Builder(); - Aggregations aggs = searchResponse.getAggregations(); - if (aggs == null) { - // This would indicate some bug or some opensearch core changes that we are not aware of (we don't keep up-to-date - // with - // the large amounts of changes there). For example, they may change to if there are results return it; otherwise - // return - // null instead of an empty Aggregations as they currently do. - logger.warn("Unexpected null aggregation."); - listener.onResponse(profileBuilder.totalEntities(0L).build()); - return; + } catch (Exception e) { + logger.error(CommonMessages.FAIL_TO_GET_PROFILE_MSG, e); + listener.onFailure(e); } - - Aggregation aggrResult = aggs.get(ADCommonName.TOTAL_ENTITIES); - if (aggrResult == null) { - listener.onFailure(new IllegalArgumentException("Fail to find valid aggregation result")); - return; - } - - CompositeAggregation compositeAgg = (CompositeAggregation) aggrResult; - DetectorProfile profile = profileBuilder.totalEntities(Long.valueOf(compositeAgg.getBuckets().size())).build(); - listener.onResponse(profile); - }, searchException -> { - logger.warn(CommonMessages.FAIL_TO_GET_TOTAL_ENTITIES + detector.getId()); - listener.onFailure(searchException); - }); - // using the original context in listener as user roles have no permissions for internal operations like fetching a - // checkpoint - clientUtil - .asyncRequestWithInjectedSecurity( - searchRequest, - client::search, - detector.getId(), - client, - AnalysisType.AD, - searchResponseListener - ); - } - - } - } - - private void onGetDetectorForPrepare(String detectorId, ActionListener listener, Set profiles) { - DetectorProfile.Builder profileBuilder = new DetectorProfile.Builder(); - if (profiles.contains(DetectorProfileName.STATE)) { - profileBuilder.state(DetectorState.DISABLED); - } - if (profiles.contains(DetectorProfileName.AD_TASK)) { - adTaskManager.getLatestHistoricalTaskProfile(detectorId, transportService, profileBuilder.build(), listener); - } else { - listener.onResponse(profileBuilder.build()); + } else { + onGetDetectorForPrepare(configId, listener, profilesToCollect); + } + }, exception -> { + if (ExceptionUtil.isIndexNotAvailable(exception)) { + logger.info(exception.getMessage()); + onGetDetectorForPrepare(configId, listener, profilesToCollect); + } else { + logger.error(CommonMessages.FAIL_TO_GET_PROFILE_MSG + configId); + listener.onFailure(exception); + } + })); } } @@ -395,141 +204,29 @@ private void onGetDetectorForPrepare(String detectorId, ActionListener listener, boolean enabled, - Set profilesToCollect + Set profilesToCollect ) { if (enabled) { - RCFPollingRequest request = new RCFPollingRequest(detector.getId()); - client.execute(RCFPollingAction.INSTANCE, request, onPollRCFUpdates(detector, profilesToCollect, listener)); + RCFPollingRequest request = new RCFPollingRequest(config.getId()); + client.execute(RCFPollingAction.INSTANCE, request, onPollRCFUpdates(config, profilesToCollect, listener)); } else { DetectorProfile.Builder builder = new DetectorProfile.Builder(); - if (profilesToCollect.contains(DetectorProfileName.STATE)) { - builder.state(DetectorState.DISABLED); + if (profilesToCollect.contains(ProfileName.STATE)) { + builder.state(ConfigState.DISABLED); } listener.onResponse(builder.build()); } } - private void profileModels( - AnomalyDetector detector, - Set profiles, - Job job, - boolean forMultiEntityDetector, - MultiResponsesDelegateActionListener listener - ) { - DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes(); - ProfileRequest profileRequest = new ProfileRequest(detector.getId(), profiles, forMultiEntityDetector, dataNodes); - client.execute(ProfileAction.INSTANCE, profileRequest, onModelResponse(detector, profiles, job, listener));// get init progress - } - - private ActionListener onModelResponse( - AnomalyDetector detector, - Set profilesToCollect, - Job job, - MultiResponsesDelegateActionListener listener - ) { - boolean isMultientityDetector = detector.isHighCardinality(); - return ActionListener.wrap(profileResponse -> { - DetectorProfile.Builder profile = new DetectorProfile.Builder(); - if (profilesToCollect.contains(DetectorProfileName.COORDINATING_NODE)) { - profile.coordinatingNode(profileResponse.getCoordinatingNode()); - } - if (profilesToCollect.contains(DetectorProfileName.SHINGLE_SIZE)) { - profile.shingleSize(profileResponse.getShingleSize()); - } - if (profilesToCollect.contains(DetectorProfileName.TOTAL_SIZE_IN_BYTES)) { - profile.totalSizeInBytes(profileResponse.getTotalSizeInBytes()); - } - if (profilesToCollect.contains(DetectorProfileName.MODELS)) { - profile.modelProfile(profileResponse.getModelProfile()); - profile.modelCount(profileResponse.getModelCount()); - } - if (isMultientityDetector && profilesToCollect.contains(DetectorProfileName.ACTIVE_ENTITIES)) { - profile.activeEntities(profileResponse.getActiveEntities()); - } - - if (isMultientityDetector - && (profilesToCollect.contains(DetectorProfileName.INIT_PROGRESS) - || profilesToCollect.contains(DetectorProfileName.STATE))) { - profileMultiEntityDetectorStateRelated(job, profilesToCollect, profileResponse, profile, detector, listener); - } else { - listener.onResponse(profile.build()); - } - }, listener::onFailure); - } - - private void profileMultiEntityDetectorStateRelated( - Job job, - Set profilesToCollect, - ProfileResponse profileResponse, - DetectorProfile.Builder profileBuilder, - AnomalyDetector detector, - MultiResponsesDelegateActionListener listener - ) { - if (job.isEnabled()) { - if (profileResponse.getTotalUpdates() < requiredSamples) { - // need to double check since what ProfileResponse returns is the highest priority entity currently in memory, but - // another entity might have already been initialized and sit somewhere else (in memory or on disk). - long enabledTime = job.getEnabledTime().toEpochMilli(); - long totalUpdates = profileResponse.getTotalUpdates(); - ProfileUtil - .confirmDetectorRealtimeInitStatus( - detector, - enabledTime, - client, - onInittedEver(enabledTime, profileBuilder, profilesToCollect, detector, totalUpdates, listener) - ); - } else { - createRunningStateAndInitProgress(profilesToCollect, profileBuilder); - listener.onResponse(profileBuilder.build()); - } - } else { - if (profilesToCollect.contains(DetectorProfileName.STATE)) { - profileBuilder.state(DetectorState.DISABLED); - } - listener.onResponse(profileBuilder.build()); - } - } - - private ActionListener onInittedEver( - long lastUpdateTimeMs, - DetectorProfile.Builder profileBuilder, - Set profilesToCollect, - AnomalyDetector detector, - long totalUpdates, - MultiResponsesDelegateActionListener listener - ) { - return ActionListener.wrap(searchResponse -> { - SearchHits hits = searchResponse.getHits(); - if (hits.getTotalHits().value == 0L) { - processInitResponse(detector, profilesToCollect, totalUpdates, false, profileBuilder, listener); - } else { - createRunningStateAndInitProgress(profilesToCollect, profileBuilder); - listener.onResponse(profileBuilder.build()); - } - }, exception -> { - if (ExceptionUtil.isIndexNotAvailable(exception)) { - // anomaly result index is not created yet - processInitResponse(detector, profilesToCollect, totalUpdates, false, profileBuilder, listener); - } else { - logger - .error( - "Fail to find any anomaly result with anomaly score larger than 0 after AD job enabled time for detector {}", - detector.getId() - ); - listener.onFailure(exception); - } - }); - } - /** * Listener for polling rcf updates through transport messaging * @param detector anomaly detector @@ -538,8 +235,8 @@ private ActionListener onInittedEver( * @return Listener for polling rcf updates through transport messaging */ private ActionListener onPollRCFUpdates( - AnomalyDetector detector, - Set profilesToCollect, + Config detector, + Set profilesToCollect, MultiResponsesDelegateActionListener listener ) { return ActionListener.wrap(rcfPollResponse -> { @@ -547,7 +244,7 @@ private ActionListener onPollRCFUpdates( if (totalUpdates < requiredSamples) { processInitResponse(detector, profilesToCollect, totalUpdates, false, new DetectorProfile.Builder(), listener); } else { - DetectorProfile.Builder builder = new DetectorProfile.Builder(); + DetectorProfile.Builder builder = createProfileBuilder(); createRunningStateAndInitProgress(profilesToCollect, builder); listener.onResponse(builder.build()); } @@ -570,7 +267,7 @@ private ActionListener onPollRCFUpdates( // a detector before cold start finishes, where the actual // initialization time may be much shorter if sufficient historical // data exists. - processInitResponse(detector, profilesToCollect, 0L, true, new DetectorProfile.Builder(), listener); + processInitResponse(detector, profilesToCollect, 0L, true, createProfileBuilder(), listener); } else { logger.error(new ParameterizedMessage("Fail to get init progress through messaging for {}", detector.getId()), exception); listener.onFailure(exception); @@ -578,40 +275,4 @@ private ActionListener onPollRCFUpdates( }); } - private void createRunningStateAndInitProgress(Set profilesToCollect, DetectorProfile.Builder builder) { - if (profilesToCollect.contains(DetectorProfileName.STATE)) { - builder.state(DetectorState.RUNNING).build(); - } - - if (profilesToCollect.contains(DetectorProfileName.INIT_PROGRESS)) { - InitProgressProfile initProgress = new InitProgressProfile("100%", 0, 0); - builder.initProgress(initProgress); - } - } - - private void processInitResponse( - AnomalyDetector detector, - Set profilesToCollect, - long totalUpdates, - boolean hideMinutesLeft, - DetectorProfile.Builder builder, - MultiResponsesDelegateActionListener listener - ) { - if (profilesToCollect.contains(DetectorProfileName.STATE)) { - builder.state(DetectorState.INIT); - } - - if (profilesToCollect.contains(DetectorProfileName.INIT_PROGRESS)) { - if (hideMinutesLeft) { - InitProgressProfile initProgress = computeInitProgressProfile(totalUpdates, 0); - builder.initProgress(initProgress); - } else { - long intervalMins = ((IntervalTimeConfiguration) detector.getInterval()).toDuration().toMinutes(); - InitProgressProfile initProgress = computeInitProgressProfile(totalUpdates, intervalMins); - builder.initProgress(initProgress); - } - } - - listener.onResponse(builder.build()); - } } diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorRunner.java b/src/main/java/org/opensearch/ad/AnomalyDetectorRunner.java index c5336316c..169afe7b4 100644 --- a/src/main/java/org/opensearch/ad/AnomalyDetectorRunner.java +++ b/src/main/java/org/opensearch/ad/AnomalyDetectorRunner.java @@ -24,16 +24,16 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.OpenSearchSecurityException; -import org.opensearch.ad.constant.CommonValue; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.feature.Features; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.ml.ThresholdingResult; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyResult; import org.opensearch.ad.model.EntityAnomalyResult; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.timeseries.constant.CommonValue; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.feature.Features; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.Feature; import org.opensearch.timeseries.model.FeatureData; @@ -45,11 +45,11 @@ public final class AnomalyDetectorRunner { private final Logger logger = LogManager.getLogger(AnomalyDetectorRunner.class); - private final ModelManager modelManager; + private final ADModelManager modelManager; private final FeatureManager featureManager; private final int maxPreviewResults; - public AnomalyDetectorRunner(ModelManager modelManager, FeatureManager featureManager, int maxPreviewResults) { + public AnomalyDetectorRunner(ADModelManager modelManager, FeatureManager featureManager, int maxPreviewResults) { this.modelManager = modelManager; this.featureManager = featureManager; this.maxPreviewResults = maxPreviewResults; @@ -103,7 +103,7 @@ public void executeDetector( endTime.toEpochMilli(), ActionListener.wrap(features -> { List entityResults = modelManager - .getPreviewResults(features.getProcessedFeatures(), detector.getShingleSize()); + .getPreviewResults(features.getProcessedFeatures(), detector.getShingleSize(), detector.getTimeDecay()); List sampledEntityResults = sample( parsePreviewResult(detector, features, entityResults, entity), maxPreviewResults @@ -117,7 +117,7 @@ public void executeDetector( featureManager.getPreviewFeatures(detector, startTime.toEpochMilli(), endTime.toEpochMilli(), ActionListener.wrap(features -> { try { List results = modelManager - .getPreviewResults(features.getProcessedFeatures(), detector.getShingleSize()); + .getPreviewResults(features.getProcessedFeatures(), detector.getShingleSize(), detector.getTimeDecay()); listener.onResponse(sample(parsePreviewResult(detector, features, results, null), maxPreviewResults)); } catch (Exception e) { onFailure(e, listener, detector.getId()); @@ -166,24 +166,24 @@ private List parsePreviewResult( AnomalyResult result; if (results != null && results.size() > i) { - ThresholdingResult thresholdingResult = results.get(i); - List resultsToSave = thresholdingResult - .toIndexableResults( - detector, - Instant.ofEpochMilli(timeRange.getKey()), - Instant.ofEpochMilli(timeRange.getValue()), - null, - null, - featureDatas, - Optional.ofNullable(entity), - CommonValue.NO_SCHEMA_VERSION, - null, - null, - null + anomalyResults + .addAll( + results + .get(i) + .toIndexableResults( + detector, + Instant.ofEpochMilli(timeRange.getKey()), + Instant.ofEpochMilli(timeRange.getValue()), + null, + null, + featureDatas, + Optional.ofNullable(entity), + CommonValue.NO_SCHEMA_VERSION, + null, + null, + null + ) ); - for (AnomalyResult r : resultsToSave) { - anomalyResults.add(r); - } } else { result = new AnomalyResult( detector.getId(), diff --git a/src/main/java/org/opensearch/ad/ExecuteADResultResponseRecorder.java b/src/main/java/org/opensearch/ad/ExecuteADResultResponseRecorder.java index 6590cead6..e1d042267 100644 --- a/src/main/java/org/opensearch/ad/ExecuteADResultResponseRecorder.java +++ b/src/main/java/org/opensearch/ad/ExecuteADResultResponseRecorder.java @@ -11,373 +11,121 @@ package org.opensearch.ad; -import static org.opensearch.timeseries.constant.CommonMessages.CAN_NOT_FIND_LATEST_TASK; - import java.time.Instant; import java.util.ArrayList; -import java.util.HashSet; import java.util.Optional; -import java.util.Set; -import java.util.concurrent.TimeUnit; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.action.update.UpdateResponse; -import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.ad.model.DetectorProfileName; import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; -import org.opensearch.ad.transport.AnomalyResultResponse; -import org.opensearch.ad.transport.ProfileAction; -import org.opensearch.ad.transport.ProfileRequest; -import org.opensearch.ad.transport.RCFPollingAction; -import org.opensearch.ad.transport.RCFPollingRequest; -import org.opensearch.ad.transport.handler.AnomalyIndexHandler; +import org.opensearch.ad.transport.ADProfileAction; import org.opensearch.client.Client; -import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.common.unit.TimeValue; import org.opensearch.commons.authuser.User; -import org.opensearch.core.action.ActionListener; -import org.opensearch.search.SearchHits; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.ExecuteResultResponseRecorder; import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; -import org.opensearch.timeseries.common.exception.EndRunException; -import org.opensearch.timeseries.common.exception.ResourceNotFoundException; -import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.model.FeatureData; -import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.transport.ResultResponse; +import org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; -import org.opensearch.timeseries.util.ExceptionUtil; -public class ExecuteADResultResponseRecorder { - private static final Logger log = LogManager.getLogger(ExecuteADResultResponseRecorder.class); +public class ExecuteADResultResponseRecorder extends + ExecuteResultResponseRecorder { - private ADIndexManagement anomalyDetectionIndices; - private AnomalyIndexHandler anomalyResultHandler; - private ADTaskManager adTaskManager; - private DiscoveryNodeFilterer nodeFilter; - private ThreadPool threadPool; - private Client client; - private NodeStateManager nodeStateManager; - private ADTaskCacheManager adTaskCacheManager; - private int rcfMinSamples; + private static final Logger log = LogManager.getLogger(ExecuteADResultResponseRecorder.class); public ExecuteADResultResponseRecorder( - ADIndexManagement anomalyDetectionIndices, - AnomalyIndexHandler anomalyResultHandler, - ADTaskManager adTaskManager, + ADIndexManagement indexManagement, + ResultBulkIndexingHandler resultHandler, + ADTaskManager taskManager, DiscoveryNodeFilterer nodeFilter, ThreadPool threadPool, Client client, NodeStateManager nodeStateManager, - ADTaskCacheManager adTaskCacheManager, + ADTaskCacheManager taskCacheManager, int rcfMinSamples ) { - this.anomalyDetectionIndices = anomalyDetectionIndices; - this.anomalyResultHandler = anomalyResultHandler; - this.adTaskManager = adTaskManager; - this.nodeFilter = nodeFilter; - this.threadPool = threadPool; - this.client = client; - this.nodeStateManager = nodeStateManager; - this.adTaskCacheManager = adTaskCacheManager; - this.rcfMinSamples = rcfMinSamples; + super( + indexManagement, + resultHandler, + taskManager, + nodeFilter, + threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, + client, + nodeStateManager, + taskCacheManager, + rcfMinSamples, + ADIndex.RESULT, + AnalysisType.AD, + ADProfileAction.INSTANCE + ); } - public void indexAnomalyResult( - Instant detectionStartTime, - Instant executionStartTime, - AnomalyResultResponse response, - AnomalyDetector detector + @Override + protected AnomalyResult createErrorResult( + String configId, + Instant dataStartTime, + Instant dataEndTime, + Instant executeEndTime, + String errorMessage, + User user ) { - String detectorId = detector.getId(); - try { - // skipping writing to the result index if not necessary - // For a single-entity detector, the result is not useful if error is null - // and rcf score (thus anomaly grade/confidence) is null. - // For a HCAD detector, we don't need to save on the detector level. - // We return 0 or Double.NaN rcf score if there is no error. - if ((response.getAnomalyScore() <= 0 || Double.isNaN(response.getAnomalyScore())) && response.getError() == null) { - updateRealtimeTask(response, detectorId); - return; - } - IntervalTimeConfiguration windowDelay = (IntervalTimeConfiguration) detector.getWindowDelay(); - Instant dataStartTime = detectionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit()); - Instant dataEndTime = executionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit()); - User user = detector.getUser(); - - if (response.getError() != null) { - log.info("Anomaly result action run successfully for {} with error {}", detectorId, response.getError()); - } - - AnomalyResult anomalyResult = response - .toAnomalyResult( - detectorId, - dataStartTime, - dataEndTime, - executionStartTime, - Instant.now(), - anomalyDetectionIndices.getSchemaVersion(ADIndex.RESULT), - user, - response.getError() - ); - - String resultIndex = detector.getCustomResultIndex(); - anomalyResultHandler.index(anomalyResult, detectorId, resultIndex); - updateRealtimeTask(response, detectorId); - } catch (EndRunException e) { - throw e; - } catch (Exception e) { - log.error("Failed to index anomaly result for " + detectorId, e); - } + return new AnomalyResult( + configId, + null, // no task id + new ArrayList(), + dataStartTime, + dataEndTime, + executeEndTime, + Instant.now(), + errorMessage, + Optional.empty(), // single-stream detectors have no entity + user, + indexManagement.getSchemaVersion(resultIndex), + null // no model id + ); } /** * Update real time task (one document per detector in state index). If the real-time task has no changes compared with local cache, - * the task won't update. Task only updates when the state changed, or any error happened, or AD job stopped. Task is mainly consumed - * by the front-end to track detector status. For single-stream detectors, we embed model total updates in AnomalyResultResponse and - * update state accordingly. For HCAD, we won't wait for model finishing updating before returning a response to the job scheduler + * the task won't update. Task only updates when the state changed, or any error happened, or job stopped. Task is mainly consumed + * by the front-end to track analysis status. For single-stream analyses, we embed model total updates in ResultResponse and + * update state accordingly. For HC analysis, we won't wait for model finishing updating before returning a response to the job scheduler * since it might be long before all entities finish execution. So we don't embed model total updates in AnomalyResultResponse. * Instead, we issue a profile request to poll each model node and get the maximum total updates among all models. * @param response response returned from executing AnomalyResultAction - * @param detectorId Detector Id + * @param configId config Id */ - private void updateRealtimeTask(AnomalyResultResponse response, String detectorId) { - if (response.isHCDetector() != null && response.isHCDetector()) { - if (adTaskManager.skipUpdateHCRealtimeTask(detectorId, response.getError())) { + @Override + protected void updateRealtimeTask(ResultResponse response, String configId) { + if (response.isHC() != null && response.isHC()) { + if (taskManager.skipUpdateRealtimeTask(configId, response.getError())) { return; } - DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes(); - Set profiles = new HashSet<>(); - profiles.add(DetectorProfileName.INIT_PROGRESS); - ProfileRequest profileRequest = new ProfileRequest(detectorId, profiles, true, dataNodes); - Runnable profileHCInitProgress = () -> { - client.execute(ProfileAction.INSTANCE, profileRequest, ActionListener.wrap(r -> { - log.debug("Update latest realtime task for HC detector {}, total updates: {}", detectorId, r.getTotalUpdates()); - updateLatestRealtimeTask(detectorId, null, r.getTotalUpdates(), response.getIntervalInMinutes(), response.getError()); - }, e -> { log.error("Failed to update latest realtime task for " + detectorId, e); })); - }; - if (!adTaskManager.isHCRealtimeTaskStartInitializing(detectorId)) { - // real time init progress is 0 may mean this is a newly started detector - // Delay real time cache update by one minute. If we are in init status, the delay may give the model training time to - // finish. We can change the detector running immediately instead of waiting for the next interval. - threadPool - .schedule(profileHCInitProgress, new TimeValue(60, TimeUnit.SECONDS), TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME); - } else { - profileHCInitProgress.run(); - } - + delayedUpdate(response, configId); } else { log .debug( "Update latest realtime task for single stream detector {}, total updates: {}", - detectorId, + configId, response.getRcfTotalUpdates() ); - updateLatestRealtimeTask(detectorId, null, response.getRcfTotalUpdates(), response.getIntervalInMinutes(), response.getError()); - } - } - - private void updateLatestRealtimeTask( - String detectorId, - String taskState, - Long rcfTotalUpdates, - Long detectorIntervalInMinutes, - String error - ) { - // Don't need info as this will be printed repeatedly in each interval - ActionListener listener = ActionListener.wrap(r -> { - if (r != null) { - log.debug("Updated latest realtime task successfully for detector {}, taskState: {}", detectorId, taskState); - } - }, e -> { - if ((e instanceof ResourceNotFoundException) && e.getMessage().contains(CAN_NOT_FIND_LATEST_TASK)) { - // Clear realtime task cache, will recreate AD task in next run, check AnomalyResultTransportAction. - log.error("Can't find latest realtime task of detector " + detectorId); - adTaskManager.removeRealtimeTaskCache(detectorId); - } else { - log.error("Failed to update latest realtime task for detector " + detectorId, e); - } - }); - - // rcfTotalUpdates is null when we save exception messages - if (!adTaskCacheManager.hasQueriedResultIndex(detectorId) && rcfTotalUpdates != null && rcfTotalUpdates < rcfMinSamples) { - // confirm the total updates number since it is possible that we have already had results after job enabling time - // If yes, total updates should be at least rcfMinSamples so that the init progress reaches 100%. - confirmTotalRCFUpdatesFound( - detectorId, - taskState, - rcfTotalUpdates, - detectorIntervalInMinutes, - error, - ActionListener - .wrap( - r -> adTaskManager - .updateLatestRealtimeTaskOnCoordinatingNode( - detectorId, - taskState, - r, - detectorIntervalInMinutes, - error, - listener - ), - e -> { - log.error("Fail to confirm rcf update", e); - adTaskManager - .updateLatestRealtimeTaskOnCoordinatingNode( - detectorId, - taskState, - rcfTotalUpdates, - detectorIntervalInMinutes, - error, - listener - ); - } - ) - ); - } else { - adTaskManager - .updateLatestRealtimeTaskOnCoordinatingNode( - detectorId, - taskState, - rcfTotalUpdates, - detectorIntervalInMinutes, - error, - listener - ); - } - } - - /** - * The function is not only indexing the result with the exception, but also updating the task state after - * 60s if the exception is related to cold start (index not found exceptions) for a single stream detector. - * - * @param detectionStartTime execution start time - * @param executionStartTime execution end time - * @param errorMessage Error message to record - * @param taskState AD task state (e.g., stopped) - * @param detector Detector config accessor - */ - public void indexAnomalyResultException( - Instant detectionStartTime, - Instant executionStartTime, - String errorMessage, - String taskState, - AnomalyDetector detector - ) { - String detectorId = detector.getId(); - try { - IntervalTimeConfiguration windowDelay = (IntervalTimeConfiguration) detector.getWindowDelay(); - Instant dataStartTime = detectionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit()); - Instant dataEndTime = executionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit()); - User user = detector.getUser(); - - AnomalyResult anomalyResult = new AnomalyResult( - detectorId, - null, // no task id - new ArrayList(), - dataStartTime, - dataEndTime, - executionStartTime, - Instant.now(), - errorMessage, - Optional.empty(), // single-stream detectors have no entity - user, - anomalyDetectionIndices.getSchemaVersion(ADIndex.RESULT), - null // no model id + updateLatestRealtimeTask( + configId, + null, + response.getRcfTotalUpdates(), + response.getConfigIntervalInMinutes(), + response.getError() ); - String resultIndex = detector.getCustomResultIndex(); - if (resultIndex != null && !anomalyDetectionIndices.doesIndexExist(resultIndex)) { - // Set result index as null, will write exception to default result index. - anomalyResultHandler.index(anomalyResult, detectorId, null); - } else { - anomalyResultHandler.index(anomalyResult, detectorId, resultIndex); - } - - if (errorMessage.contains(ADCommonMessages.NO_MODEL_ERR_MSG) && !detector.isHighCardinality()) { - // single stream detector raises ResourceNotFoundException containing CommonErrorMessages.NO_CHECKPOINT_ERR_MSG - // when there is no checkpoint. - // Delay real time cache update by one minute so we will have trained models by then and update the state - // document accordingly. - threadPool.schedule(() -> { - RCFPollingRequest request = new RCFPollingRequest(detectorId); - client.execute(RCFPollingAction.INSTANCE, request, ActionListener.wrap(rcfPollResponse -> { - long totalUpdates = rcfPollResponse.getTotalUpdates(); - // if there are updates, don't record failures - updateLatestRealtimeTask( - detectorId, - taskState, - totalUpdates, - detector.getIntervalInMinutes(), - totalUpdates > 0 ? "" : errorMessage - ); - }, e -> { - log.error("Fail to execute RCFRollingAction", e); - updateLatestRealtimeTask(detectorId, taskState, null, null, errorMessage); - })); - }, new TimeValue(60, TimeUnit.SECONDS), TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME); - } else { - updateLatestRealtimeTask(detectorId, taskState, null, null, errorMessage); - } - - } catch (Exception e) { - log.error("Failed to index anomaly result for " + detectorId, e); } } - - private void confirmTotalRCFUpdatesFound( - String detectorId, - String taskState, - Long rcfTotalUpdates, - Long detectorIntervalInMinutes, - String error, - ActionListener listener - ) { - nodeStateManager.getConfig(detectorId, AnalysisType.AD, ActionListener.wrap(detectorOptional -> { - if (!detectorOptional.isPresent()) { - listener.onFailure(new TimeSeriesException(detectorId, "fail to get detector")); - return; - } - nodeStateManager.getJob(detectorId, ActionListener.wrap(jobOptional -> { - if (!jobOptional.isPresent()) { - listener.onFailure(new TimeSeriesException(detectorId, "fail to get job")); - return; - } - - ProfileUtil - .confirmDetectorRealtimeInitStatus( - (AnomalyDetector) detectorOptional.get(), - jobOptional.get().getEnabledTime().toEpochMilli(), - client, - ActionListener.wrap(searchResponse -> { - ActionListener.completeWith(listener, () -> { - SearchHits hits = searchResponse.getHits(); - Long correctedTotalUpdates = rcfTotalUpdates; - if (hits.getTotalHits().value > 0L) { - // correct the number if we have already had results after job enabling time - // so that the detector won't stay initialized - correctedTotalUpdates = Long.valueOf(rcfMinSamples); - } - adTaskCacheManager.markResultIndexQueried(detectorId); - return correctedTotalUpdates; - }); - }, exception -> { - if (ExceptionUtil.isIndexNotAvailable(exception)) { - // anomaly result index is not created yet - adTaskCacheManager.markResultIndexQueried(detectorId); - listener.onResponse(0L); - } else { - listener.onFailure(exception); - } - }) - ); - }, e -> listener.onFailure(new TimeSeriesException(detectorId, "fail to get job")))); - }, e -> listener.onFailure(new TimeSeriesException(detectorId, "fail to get detector")))); - } } diff --git a/src/main/java/org/opensearch/ad/ProfileUtil.java b/src/main/java/org/opensearch/ad/ProfileUtil.java deleted file mode 100644 index 3d77924d0..000000000 --- a/src/main/java/org/opensearch/ad/ProfileUtil.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad; - -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.client.Client; -import org.opensearch.core.action.ActionListener; -import org.opensearch.index.query.BoolQueryBuilder; -import org.opensearch.index.query.ExistsQueryBuilder; -import org.opensearch.index.query.QueryBuilders; -import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.timeseries.constant.CommonName; - -public class ProfileUtil { - /** - * Create search request to check if we have at least 1 anomaly score larger than 0 after AD job enabled time. - * Note this function is only meant to check for status of real time analysis. - * - * @param detectorId detector id - * @param enabledTime the time when AD job is enabled in milliseconds - * @return the search request - */ - private static SearchRequest createRealtimeInittedEverRequest(String detectorId, long enabledTime, String resultIndex) { - BoolQueryBuilder filterQuery = new BoolQueryBuilder(); - filterQuery.filter(QueryBuilders.termQuery(AnomalyResult.DETECTOR_ID_FIELD, detectorId)); - filterQuery.filter(QueryBuilders.rangeQuery(CommonName.EXECUTION_END_TIME_FIELD).gte(enabledTime)); - filterQuery.filter(QueryBuilders.rangeQuery(AnomalyResult.ANOMALY_SCORE_FIELD).gt(0)); - // Historical analysis result also stored in result index, which has non-null task_id. - // For realtime detection result, we should filter task_id == null - ExistsQueryBuilder taskIdExistsFilter = QueryBuilders.existsQuery(CommonName.TASK_ID_FIELD); - filterQuery.mustNot(taskIdExistsFilter); - - SearchSourceBuilder source = new SearchSourceBuilder().query(filterQuery).size(1); - - SearchRequest request = new SearchRequest(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS); - request.source(source); - if (resultIndex != null) { - request.indices(resultIndex); - } - return request; - } - - public static void confirmDetectorRealtimeInitStatus( - AnomalyDetector detector, - long enabledTime, - Client client, - ActionListener listener - ) { - SearchRequest searchLatestResult = createRealtimeInittedEverRequest(detector.getId(), enabledTime, detector.getCustomResultIndex()); - client.search(searchLatestResult, listener); - } -} diff --git a/src/main/java/org/opensearch/ad/caching/ADCacheBuffer.java b/src/main/java/org/opensearch/ad/caching/ADCacheBuffer.java new file mode 100644 index 000000000..828146516 --- /dev/null +++ b/src/main/java/org/opensearch/ad/caching/ADCacheBuffer.java @@ -0,0 +1,75 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.caching; + +import java.time.Clock; +import java.time.Duration; + +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ratelimit.ADCheckpointMaintainWorker; +import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker; +import org.opensearch.timeseries.MemoryTracker; +import org.opensearch.timeseries.MemoryTracker.Origin; +import org.opensearch.timeseries.caching.CacheBuffer; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * We use a layered cache to manage active entities’ states. We have a two-level + * cache that stores active entity states in each node. Each detector has its + * dedicated cache that stores ten (dynamically adjustable) entities’ states per + * node. A detector’s hottest entities load their states in the dedicated cache. + * If less than 10 entities use the dedicated cache, the secondary cache can use + * the rest of the free memory available to AD. The secondary cache is a shared + * memory among all detectors for the long tail. The shared cache size is 10% + * heap minus all of the dedicated cache consumed by single-entity and multi-entity + * detectors. The shared cache’s size shrinks as the dedicated cache is filled + * up or more detectors are started. + * + * Implementation-wise, both dedicated cache and shared cache are stored in items + * and minimumCapacity controls the boundary. If items size is equals to or less + * than minimumCapacity, consider items as dedicated cache; otherwise, consider + * top minimumCapacity active entities (last X entities in priorityList) as in dedicated + * cache and all others in shared cache. + */ +public class ADCacheBuffer extends + CacheBuffer { + + public ADCacheBuffer( + int minimumCapacity, + Clock clock, + MemoryTracker memoryTracker, + int checkpointIntervalHrs, + Duration modelTtl, + long memoryConsumptionPerEntity, + ADCheckpointWriteWorker checkpointWriteQueue, + ADCheckpointMaintainWorker checkpointMaintainQueue, + String configId, + long intervalSecs + ) { + super( + minimumCapacity, + clock, + memoryTracker, + checkpointIntervalHrs, + modelTtl, + memoryConsumptionPerEntity, + checkpointWriteQueue, + checkpointMaintainQueue, + configId, + intervalSecs, + Origin.REAL_TIME_DETECTOR + ); + } +} diff --git a/src/main/java/org/opensearch/ad/caching/ADCacheProvider.java b/src/main/java/org/opensearch/ad/caching/ADCacheProvider.java new file mode 100644 index 000000000..e71c89962 --- /dev/null +++ b/src/main/java/org/opensearch/ad/caching/ADCacheProvider.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.caching; + +import org.opensearch.timeseries.caching.CacheProvider; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * Allows Guice dependency based on types. Otherwise, Guice cannot + * decide which instance to inject based on generic types of CacheProvider + * + */ +public class ADCacheProvider extends CacheProvider { + +} diff --git a/src/main/java/org/opensearch/ad/caching/ADPriorityCache.java b/src/main/java/org/opensearch/ad/caching/ADPriorityCache.java new file mode 100644 index 000000000..a028333c5 --- /dev/null +++ b/src/main/java/org/opensearch/ad/caching/ADPriorityCache.java @@ -0,0 +1,117 @@ +/* +f * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.caching; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_DEDICATED_CACHE_SIZE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MODEL_MAX_SIZE_PERCENTAGE; +import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME; + +import java.time.Clock; +import java.time.Duration; +import java.util.ArrayDeque; +import java.util.Optional; + +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ratelimit.ADCheckpointMaintainWorker; +import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.MemoryTracker; +import org.opensearch.timeseries.MemoryTracker.Origin; +import org.opensearch.timeseries.caching.PriorityCache; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.model.Config; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public class ADPriorityCache extends + PriorityCache { + private ADCheckpointWriteWorker checkpointWriteQueue; + private ADCheckpointMaintainWorker checkpointMaintainQueue; + + public ADPriorityCache( + ADCheckpointDao checkpointDao, + int hcDedicatedCacheSize, + Setting checkpointTtl, + int maxInactiveStates, + MemoryTracker memoryTracker, + int numberOfTrees, + Clock clock, + ClusterService clusterService, + Duration modelTtl, + ThreadPool threadPool, + int maintenanceFreqConstant, + Settings settings, + Setting checkpointSavingFreq, + ADCheckpointWriteWorker checkpointWriteQueue, + ADCheckpointMaintainWorker checkpointMaintainQueue + ) { + super( + checkpointDao, + hcDedicatedCacheSize, + checkpointTtl, + maxInactiveStates, + memoryTracker, + numberOfTrees, + clock, + clusterService, + modelTtl, + threadPool, + AD_THREAD_POOL_NAME, + maintenanceFreqConstant, + settings, + checkpointSavingFreq, + Origin.REAL_TIME_DETECTOR, + AD_DEDICATED_CACHE_SIZE, + AD_MODEL_MAX_SIZE_PERCENTAGE + ); + + this.checkpointWriteQueue = checkpointWriteQueue; + this.checkpointMaintainQueue = checkpointMaintainQueue; + } + + @Override + protected ADCacheBuffer createEmptyCacheBuffer(Config detector, long memoryConsumptionPerEntity) { + return new ADCacheBuffer( + detector.isHighCardinality() ? hcDedicatedCacheSize : 1, + clock, + memoryTracker, + checkpointIntervalHrs, + modelTtl, + memoryConsumptionPerEntity, + checkpointWriteQueue, + checkpointMaintainQueue, + detector.getId(), + detector.getIntervalInSeconds() + ); + } + + @Override + protected ModelState createEmptyModelState(String modelId, String detectorId) { + return new ModelState<>( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + 0, + Optional.empty(), + new ArrayDeque<>() + ); + } +} diff --git a/src/main/java/org/opensearch/ad/caching/CacheProvider.java b/src/main/java/org/opensearch/ad/caching/CacheProvider.java deleted file mode 100644 index ab8fd191c..000000000 --- a/src/main/java/org/opensearch/ad/caching/CacheProvider.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.caching; - -import org.opensearch.common.inject.Provider; - -/** - * A wrapper to call concrete implementation of caching. Used in transport - * action. Don't use interface because transport action handler constructor - * requires a concrete class as input. - * - */ -public class CacheProvider implements Provider { - private EntityCache cache; - - public CacheProvider() { - - } - - @Override - public EntityCache get() { - return cache; - } - - public void set(EntityCache cache) { - this.cache = cache; - } -} diff --git a/src/main/java/org/opensearch/ad/caching/EntityCache.java b/src/main/java/org/opensearch/ad/caching/EntityCache.java deleted file mode 100644 index 287994efd..000000000 --- a/src/main/java/org/opensearch/ad/caching/EntityCache.java +++ /dev/null @@ -1,157 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.caching; - -import java.util.Collection; -import java.util.List; -import java.util.Optional; - -import org.apache.commons.lang3.tuple.Pair; -import org.opensearch.ad.DetectorModelSize; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelState; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.ModelProfile; -import org.opensearch.timeseries.CleanState; -import org.opensearch.timeseries.MaintenanceState; -import org.opensearch.timeseries.model.Entity; - -public interface EntityCache extends MaintenanceState, CleanState, DetectorModelSize { - /** - * Get the ModelState associated with the entity. May or may not load the - * ModelState depending on the underlying cache's eviction policy. - * - * @param modelId Model Id - * @param detector Detector config object - * @return the ModelState associated with the model or null if no cached item - * for the entity - */ - ModelState get(String modelId, AnomalyDetector detector); - - /** - * Get the number of active entities of a detector - * @param detector Detector Id - * @return The number of active entities - */ - int getActiveEntities(String detector); - - /** - * - * @return total active entities in the cache - */ - int getTotalActiveEntities(); - - /** - * Whether an entity is active or not - * @param detectorId The Id of the detector that an entity belongs to - * @param entityModelId Entity model Id - * @return Whether an entity is active or not - */ - boolean isActive(String detectorId, String entityModelId); - - /** - * Get total updates of detector's most active entity's RCF model. - * - * @param detectorId detector id - * @return RCF model total updates of most active entity. - */ - long getTotalUpdates(String detectorId); - - /** - * Get RCF model total updates of specific entity - * - * @param detectorId detector id - * @param entityModelId entity model id - * @return RCF model total updates of specific entity. - */ - long getTotalUpdates(String detectorId, String entityModelId); - - /** - * Gets modelStates of all model hosted on a node - * - * @return list of modelStates - */ - List> getAllModels(); - - /** - * Return when the last active time of an entity's state. - * - * If the entity's state is active in the cache, the value indicates when the cache - * is lastly accessed (get/put). If the entity's state is inactive in the cache, - * the value indicates when the cache state is created or when the entity is evicted - * from active entity cache. - * - * @param detectorId The Id of the detector that an entity belongs to - * @param entityModelId Entity's Model Id - * @return if the entity is in the cache, return the timestamp in epoch - * milliseconds when the entity's state is lastly used. Otherwise, return -1. - */ - long getLastActiveMs(String detectorId, String entityModelId); - - /** - * Release memory when memory circuit breaker is open - */ - void releaseMemoryForOpenCircuitBreaker(); - - /** - * Select candidate entities for which we can load models - * @param cacheMissEntities Cache miss entities - * @param detectorId Detector Id - * @param detector Detector object - * @return A list of entities that are admitted into the cache as a result of the - * update and the left-over entities - */ - Pair, List> selectUpdateCandidate( - Collection cacheMissEntities, - String detectorId, - AnomalyDetector detector - ); - - /** - * - * @param detector Detector config - * @param toUpdate Model state candidate - * @return if we can host the given model state - */ - boolean hostIfPossible(AnomalyDetector detector, ModelState toUpdate); - - /** - * - * @param detectorId Detector Id - * @return a detector's model information - */ - List getAllModelProfile(String detectorId); - - /** - * Gets an entity's model sizes - * - * @param detectorId Detector Id - * @param entityModelId Entity's model Id - * @return the entity's memory size - */ - Optional getModelProfile(String detectorId, String entityModelId); - - /** - * Get a model state without incurring priority update. Used in maintenance. - * @param detectorId Detector Id - * @param modelId Model Id - * @return Model state - */ - Optional> getForMaintainance(String detectorId, String modelId); - - /** - * Remove entity model from active entity buffer and delete checkpoint. Used to clean corrupted model. - * @param detectorId Detector Id - * @param entityModelId Model Id - */ - void removeEntityModel(String detectorId, String entityModelId); -} diff --git a/src/main/java/org/opensearch/ad/client/AnomalyDetectionClient.java b/src/main/java/org/opensearch/ad/client/AnomalyDetectionClient.java index 7dfd223b9..eda7192c3 100644 --- a/src/main/java/org/opensearch/ad/client/AnomalyDetectionClient.java +++ b/src/main/java/org/opensearch/ad/client/AnomalyDetectionClient.java @@ -8,10 +8,10 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.ad.transport.GetAnomalyDetectorRequest; import org.opensearch.ad.transport.GetAnomalyDetectorResponse; import org.opensearch.common.action.ActionFuture; import org.opensearch.core.action.ActionListener; +import org.opensearch.timeseries.transport.GetConfigRequest; /** * A client to provide interfaces for anomaly detection functionality. This will be used by other plugins. @@ -58,7 +58,7 @@ default ActionFuture searchAnomalyResults(SearchRequest searchRe * @param profileRequest request to fetch the detector profile * @return ActionFuture of GetAnomalyDetectorResponse */ - default ActionFuture getDetectorProfile(GetAnomalyDetectorRequest profileRequest) { + default ActionFuture getDetectorProfile(GetConfigRequest profileRequest) { PlainActionFuture actionFuture = PlainActionFuture.newFuture(); getDetectorProfile(profileRequest, actionFuture); return actionFuture; @@ -69,6 +69,6 @@ default ActionFuture getDetectorProfile(GetAnomalyDe * @param profileRequest request to fetch the detector profile * @param listener a listener to be notified of the result */ - void getDetectorProfile(GetAnomalyDetectorRequest profileRequest, ActionListener listener); + void getDetectorProfile(GetConfigRequest profileRequest, ActionListener listener); } diff --git a/src/main/java/org/opensearch/ad/client/AnomalyDetectionNodeClient.java b/src/main/java/org/opensearch/ad/client/AnomalyDetectionNodeClient.java index 60bb274ab..051aae2c6 100644 --- a/src/main/java/org/opensearch/ad/client/AnomalyDetectionNodeClient.java +++ b/src/main/java/org/opensearch/ad/client/AnomalyDetectionNodeClient.java @@ -10,13 +10,13 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.ad.transport.GetAnomalyDetectorAction; -import org.opensearch.ad.transport.GetAnomalyDetectorRequest; import org.opensearch.ad.transport.GetAnomalyDetectorResponse; import org.opensearch.ad.transport.SearchAnomalyDetectorAction; import org.opensearch.ad.transport.SearchAnomalyResultAction; import org.opensearch.client.Client; import org.opensearch.core.action.ActionListener; import org.opensearch.core.action.ActionResponse; +import org.opensearch.timeseries.transport.GetConfigRequest; public class AnomalyDetectionNodeClient implements AnomalyDetectionClient { private final Client client; @@ -40,7 +40,7 @@ public void searchAnomalyResults(SearchRequest searchRequest, ActionListener listener) { + public void getDetectorProfile(GetConfigRequest profileRequest, ActionListener listener) { this.client.execute(GetAnomalyDetectorAction.INSTANCE, profileRequest, getAnomalyDetectorResponseActionListener(listener)); } diff --git a/src/main/java/org/opensearch/ad/cluster/diskcleanup/ADCheckpointIndexRetention.java b/src/main/java/org/opensearch/ad/cluster/diskcleanup/ADCheckpointIndexRetention.java new file mode 100644 index 000000000..6cf8c2385 --- /dev/null +++ b/src/main/java/org/opensearch/ad/cluster/diskcleanup/ADCheckpointIndexRetention.java @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.cluster.diskcleanup; + +import java.time.Clock; +import java.time.Duration; + +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.timeseries.cluster.diskcleanup.BaseModelCheckpointIndexRetention; +import org.opensearch.timeseries.cluster.diskcleanup.IndexCleanup; + +public class ADCheckpointIndexRetention extends BaseModelCheckpointIndexRetention { + + public ADCheckpointIndexRetention(Duration defaultCheckpointTtl, Clock clock, IndexCleanup indexCleanup) { + super(defaultCheckpointTtl, clock, indexCleanup, ADCommonName.CHECKPOINT_INDEX_NAME); + } + +} diff --git a/src/main/java/org/opensearch/ad/constant/ADCommonMessages.java b/src/main/java/org/opensearch/ad/constant/ADCommonMessages.java index 1f186f647..d782c5e4c 100644 --- a/src/main/java/org/opensearch/ad/constant/ADCommonMessages.java +++ b/src/main/java/org/opensearch/ad/constant/ADCommonMessages.java @@ -26,7 +26,6 @@ public class ADCommonMessages { public static String DETECTOR_MISSING = "Detector is missing"; public static String AD_TASK_ACTION_MISSING = "AD task action is missing"; public static final String INDEX_NOT_FOUND = "index does not exist"; - public static final String NOT_EXISTENT_VALIDATION_TYPE = "The given validation type doesn't exist"; public static final String UNSUPPORTED_PROFILE_TYPE = "Unsupported profile types"; public static final String REQUEST_THROTTLED_MSG = "Request throttled. Please try again later."; diff --git a/src/main/java/org/opensearch/ad/constant/ADCommonName.java b/src/main/java/org/opensearch/ad/constant/ADCommonName.java index 3a97db889..260d162f1 100644 --- a/src/main/java/org/opensearch/ad/constant/ADCommonName.java +++ b/src/main/java/org/opensearch/ad/constant/ADCommonName.java @@ -11,8 +11,6 @@ package org.opensearch.ad.constant; -import org.opensearch.timeseries.stats.StatNames; - public class ADCommonName { // ====================================== // Index name @@ -25,46 +23,11 @@ public class ADCommonName { // The alias of the index in which to write AD result history public static final String ANOMALY_RESULT_INDEX_ALIAS = ".opendistro-anomaly-results"; - // ====================================== - // Format name - // ====================================== - public static final String EPOCH_MILLIS_FORMAT = "epoch_millis"; - // ====================================== // Anomaly Detector name for X-Opaque-Id header // ====================================== public static final String ANOMALY_DETECTOR = "[Anomaly Detector]"; - // ====================================== - // Ultrawarm node attributes - // ====================================== - - // hot node - public static String HOT_BOX_TYPE = "hot"; - - // warm node - public static String WARM_BOX_TYPE = "warm"; - - // box type - public static final String BOX_TYPE_KEY = "box_type"; - - // ====================================== - // Profile name - // ====================================== - public static final String STATE = "state"; - public static final String ERROR = "error"; - public static final String COORDINATING_NODE = "coordinating_node"; - public static final String SHINGLE_SIZE = "shingle_size"; - public static final String TOTAL_SIZE_IN_BYTES = "total_size_in_bytes"; - public static final String MODELS = "models"; - public static final String MODEL = "model"; - public static final String INIT_PROGRESS = "init_progress"; - public static final String CATEGORICAL_FIELD = "category_field"; - public static final String TOTAL_ENTITIES = "total_entities"; - public static final String ACTIVE_ENTITIES = "active_entities"; - public static final String ENTITY_INFO = "entity_info"; - public static final String TOTAL_UPDATES = "total_updates"; - public static final String MODEL_COUNT = StatNames.MODEL_COUNT.getName(); // ====================================== // Historical detectors // ====================================== @@ -87,11 +50,8 @@ public class ADCommonName { public static final String CONFIDENCE_JSON_KEY = "confidence"; public static final String ANOMALY_GRADE_JSON_KEY = "anomalyGrade"; public static final String QUEUE_JSON_KEY = "queue"; - // ====================================== - // Used for backward-compatibility in messaging - // ====================================== - public static final String EMPTY_FIELD = ""; + // ====================================== // Validation // ====================================== // detector validation aspect diff --git a/src/main/java/org/opensearch/ad/constant/CommonValue.java b/src/main/java/org/opensearch/ad/constant/ADCommonValue.java similarity index 81% rename from src/main/java/org/opensearch/ad/constant/CommonValue.java rename to src/main/java/org/opensearch/ad/constant/ADCommonValue.java index f5d5b15eb..91b9f72f7 100644 --- a/src/main/java/org/opensearch/ad/constant/CommonValue.java +++ b/src/main/java/org/opensearch/ad/constant/ADCommonValue.java @@ -11,9 +11,7 @@ package org.opensearch.ad.constant; -public class CommonValue { - // unknown or no schema version - public static Integer NO_SCHEMA_VERSION = 0; +public class ADCommonValue { public static String INTERNAL_ACTION_PREFIX = "cluster:admin/opendistro/adinternal/"; public static String EXTERNAL_ACTION_PREFIX = "cluster:admin/opendistro/ad/"; } diff --git a/src/main/java/org/opensearch/ad/ml/CheckpointDao.java b/src/main/java/org/opensearch/ad/ml/ADCheckpointDao.java similarity index 56% rename from src/main/java/org/opensearch/ad/ml/CheckpointDao.java rename to src/main/java/org/opensearch/ad/ml/ADCheckpointDao.java index adb097cb6..e2edcfdd0 100644 --- a/src/main/java/org/opensearch/ad/ml/CheckpointDao.java +++ b/src/main/java/org/opensearch/ad/ml/ADCheckpointDao.java @@ -12,45 +12,29 @@ package org.opensearch.ad.ml; import java.io.IOException; +import java.lang.reflect.Type; import java.security.AccessController; import java.security.PrivilegedAction; import java.time.Clock; -import java.time.Duration; import java.time.Instant; import java.time.ZoneOffset; -import java.time.ZonedDateTime; -import java.util.AbstractMap.SimpleImmutableEntry; import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; +import java.util.Deque; import java.util.HashMap; import java.util.List; -import java.util.Locale; import java.util.Map; -import java.util.Map.Entry; import java.util.Optional; import org.apache.commons.pool2.impl.GenericObjectPool; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; -import org.opensearch.ExceptionsHelper; -import org.opensearch.ResourceAlreadyExistsException; -import org.opensearch.action.bulk.BulkAction; -import org.opensearch.action.bulk.BulkItemResponse; -import org.opensearch.action.bulk.BulkRequest; -import org.opensearch.action.bulk.BulkResponse; -import org.opensearch.action.delete.DeleteRequest; -import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; -import org.opensearch.action.get.MultiGetAction; -import org.opensearch.action.get.MultiGetRequest; -import org.opensearch.action.get.MultiGetResponse; import org.opensearch.action.support.IndicesOptions; -import org.opensearch.action.update.UpdateRequest; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; @@ -58,13 +42,13 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.MatchQueryBuilder; -import org.opensearch.index.reindex.BulkByScrollResponse; -import org.opensearch.index.reindex.DeleteByQueryAction; import org.opensearch.index.reindex.DeleteByQueryRequest; -import org.opensearch.index.reindex.ScrollableHitSource; import org.opensearch.timeseries.common.exception.ResourceNotFoundException; -import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; import org.opensearch.timeseries.ml.SingleStreamModelIdMapper; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.util.ClientUtil; @@ -80,6 +64,7 @@ import com.google.gson.Gson; import com.google.gson.JsonObject; import com.google.gson.JsonParser; +import com.google.gson.reflect.TypeToken; import io.protostuff.LinkedBuffer; import io.protostuff.ProtostuffIOUtil; @@ -88,30 +73,18 @@ /** * DAO for model checkpoints. */ -public class CheckpointDao { - - private static final Logger logger = LogManager.getLogger(CheckpointDao.class); - static final String TIMEOUT_LOG_MSG = "Timeout while deleting checkpoints of"; - static final String BULK_FAILURE_LOG_MSG = "Bulk failure while deleting checkpoints of"; - static final String SEARCH_FAILURE_LOG_MSG = "Search failure while deleting checkpoints of"; - static final String DOC_GOT_DELETED_LOG_MSG = "checkpoints docs get deleted"; - static final String INDEX_DELETED_LOG_MSG = "Checkpoint index has been deleted. Has nothing to do:"; - static final String NOT_ABLE_TO_DELETE_LOG_MSG = "Cannot delete all checkpoints of detector"; +public class ADCheckpointDao extends CheckpointDao { + private static final Logger logger = LogManager.getLogger(ADCheckpointDao.class); + // ====================================== + // Model serialization/deserialization + // ====================================== public static final String ENTITY_RCF = "rcf"; public static final String ENTITY_THRESHOLD = "th"; public static final String ENTITY_TRCF = "trcf"; public static final String FIELD_MODELV2 = "modelV2"; public static final String DETECTOR_ID = "detectorId"; - // dependencies - private final Client client; - private final ClientUtil clientUtil; - - // configuration - private final String indexName; - - private Gson gson; private RandomCutForestMapper mapper; // For further reference v1, v2 and v3 refer to the different variations of RCF models @@ -129,20 +102,17 @@ public class CheckpointDao { private final ADIndexManagement indexUtil; private final JsonParser parser = new JsonParser(); - // we won't read/write a checkpoint larger than a threshold - private final int maxCheckpointBytes; - private final GenericObjectPool serializeRCFBufferPool; - private final int serializeRCFBufferSize; // anomaly rate private double anomalyRate; + // Use TypeToken to properly deserialize the double array + private final Type doubleArrayType; /** * Constructor with dependencies and configuration. * * @param client ES search client * @param clientUtil utility with ES client - * @param indexName name of the index for model checkpoints * @param gson accessor to Gson functionality * @param mapper RCF model serialization utility * @param converter converter from rcf v1 serde to protostuff based format @@ -155,10 +125,9 @@ public class CheckpointDao { * @param serializeRCFBufferSize the size of the buffer for RCF serialization * @param anomalyRate anomaly rate */ - public CheckpointDao( + public ADCheckpointDao( Client client, ClientUtil clientUtil, - String indexName, Gson gson, RandomCutForestMapper mapper, V1JsonToV3StateConverter converter, @@ -169,30 +138,29 @@ public CheckpointDao( int maxCheckpointBytes, GenericObjectPool serializeRCFBufferPool, int serializeRCFBufferSize, - double anomalyRate + double anomalyRate, + Clock clock ) { - this.client = client; - this.clientUtil = clientUtil; - this.indexName = indexName; - this.gson = gson; + super( + client, + clientUtil, + ADCommonName.CHECKPOINT_INDEX_NAME, + gson, + maxCheckpointBytes, + serializeRCFBufferPool, + serializeRCFBufferSize, + indexUtil, + clock + ); this.mapper = mapper; this.converter = converter; this.trcfMapper = trcfMapper; this.trcfSchema = trcfSchema; this.thresholdingModelClass = thresholdingModelClass; this.indexUtil = indexUtil; - this.maxCheckpointBytes = maxCheckpointBytes; - this.serializeRCFBufferPool = serializeRCFBufferPool; - this.serializeRCFBufferSize = serializeRCFBufferSize; this.anomalyRate = anomalyRate; - } - - private void putModelCheckpoint(String modelId, Map source, ActionListener listener) { - if (indexUtil.doesCheckpointIndexExist()) { - saveModelCheckpointAsync(source, modelId, listener); - } else { - onCheckpointNotExist(source, modelId, listener); - } + this.doubleArrayType = new TypeToken() { + }.getType(); } /** @@ -207,7 +175,7 @@ public void putTRCFCheckpoint(String modelId, ThresholdedRandomCutForest forest, String modelCheckpoint = toCheckpoint(forest); if (modelCheckpoint != null) { source.put(FIELD_MODELV2, modelCheckpoint); - source.put(CommonName.TIMESTAMP, ZonedDateTime.now(ZoneOffset.UTC)); + source.put(CommonName.TIMESTAMP, clock.instant().atZone(ZoneOffset.UTC)); putModelCheckpoint(modelId, source, listener); } else { listener.onFailure(new RuntimeException("Fail to create checkpoint to save")); @@ -225,80 +193,58 @@ public void putThresholdCheckpoint(String modelId, ThresholdingModel threshold, String modelCheckpoint = AccessController.doPrivileged((PrivilegedAction) () -> gson.toJson(threshold)); Map source = new HashMap<>(); source.put(CommonName.FIELD_MODEL, modelCheckpoint); - source.put(CommonName.TIMESTAMP, ZonedDateTime.now(ZoneOffset.UTC)); + source.put(CommonName.TIMESTAMP, clock.instant().atZone(ZoneOffset.UTC)); putModelCheckpoint(modelId, source, listener); } - private void onCheckpointNotExist(Map source, String modelId, ActionListener listener) { - indexUtil.initCheckpointIndex(ActionListener.wrap(initResponse -> { - if (initResponse.isAcknowledged()) { - saveModelCheckpointAsync(source, modelId, listener); - } else { - throw new RuntimeException("Creating checkpoint with mappings call not acknowledged."); - } - }, exception -> { - if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { - // It is possible the index has been created while we sending the create request - saveModelCheckpointAsync(source, modelId, listener); - } else { - logger.error(String.format(Locale.ROOT, "Unexpected error creating index %s", indexName), exception); - } - })); - } - - /** - * Update the model doc using fields in source. This ensures we won't touch - * the old checkpoint and nodes with old/new logic can coexist in a cluster. - * This is useful for introducing compact rcf new model format. - * - * @param source fields to update - * @param modelId model Id, used as doc id in the checkpoint index - * @param listener Listener to return response - */ - private void saveModelCheckpointAsync(Map source, String modelId, ActionListener listener) { - - UpdateRequest updateRequest = new UpdateRequest(indexName, modelId); - updateRequest.doc(source); - // If the document does not already exist, the contents of the upsert element are inserted as a new document. - // If the document exists, update fields in the map - updateRequest.docAsUpsert(true); - clientUtil - .asyncRequest( - updateRequest, - client::update, - ActionListener.wrap(r -> listener.onResponse(null), listener::onFailure) - ); - } - /** * Prepare for index request using the contents of the given model state * @param modelState an entity model state * @return serialized JSON map or empty map if the state is too bloated * @throws IOException when serialization fails */ - public Map toIndexSource(ModelState modelState) throws IOException { + @Override + public Map toIndexSource(ModelState modelState) throws IOException { String modelId = modelState.getModelId(); Map source = new HashMap<>(); - EntityModel model = modelState.getModel(); - Optional serializedModel = toCheckpoint(model, modelId); - if (!serializedModel.isPresent() || serializedModel.get().length() > maxCheckpointBytes) { - logger - .warn( - new ParameterizedMessage( - "[{}]'s model is empty or too large: [{}] bytes", - modelState.getModelId(), - serializedModel.isPresent() ? serializedModel.get().length() : 0 - ) - ); + + Optional model = modelState.getModel(); + if (model.isPresent()) { + ThresholdedRandomCutForest entityModel = model.get(); + + Optional serializedModel = toCheckpoint(entityModel, modelId); + if (!serializedModel.isPresent() || serializedModel.get().length() > maxCheckpointBytes) { + logger + .warn( + new ParameterizedMessage( + "[{}]'s model is empty or too large: [{}] bytes", + modelState.getModelId(), + serializedModel.isPresent() ? serializedModel.get().length() : 0 + ) + ); + return source; + } + source.put(FIELD_MODELV2, serializedModel.get()); + } + + Optional samples = toCheckpoint(modelState.getSamples()); + if (samples.isPresent()) { + source.put(CommonName.SAMPLE_QUEUE, samples.get()); + } + + // if there are no samples and no model, no need to index as other information are meta data + if (!source.containsKey(CommonName.SAMPLE_QUEUE) && !source.containsKey(FIELD_MODELV2)) { return source; } - String detectorId = modelState.getId(); + + String detectorId = modelState.getConfigId(); source.put(DETECTOR_ID, detectorId); // we cannot pass Optional as OpenSearch does not know how to serialize an Optional value - source.put(FIELD_MODELV2, serializedModel.get()); - source.put(CommonName.TIMESTAMP, ZonedDateTime.now(ZoneOffset.UTC)); - source.put(CommonName.SCHEMA_VERSION_FIELD, indexUtil.getSchemaVersion(ADIndex.CHECKPOINT)); - Optional entity = model.getEntity(); + + source.put(CommonName.TIMESTAMP, clock.instant().atZone(ZoneOffset.UTC)); + source.put(org.opensearch.timeseries.constant.CommonName.SCHEMA_VERSION_FIELD, indexUtil.getSchemaVersion(ADIndex.CHECKPOINT)); + + Optional entity = modelState.getEntity(); if (entity.isPresent()) { source.put(CommonName.ENTITY_KEY, entity.get()); } @@ -312,7 +258,7 @@ public Map toIndexSource(ModelState modelState) thr * @param modelId model id * @return serialized string */ - public Optional toCheckpoint(EntityModel model, String modelId) { + public Optional toCheckpoint(ThresholdedRandomCutForest model, String modelId) { return AccessController.doPrivileged((PrivilegedAction>) () -> { if (model == null) { logger.warn("Empty model"); @@ -320,11 +266,8 @@ public Optional toCheckpoint(EntityModel model, String modelId) { } try { JsonObject json = new JsonObject(); - if (model.getSamples() != null && !(model.getSamples().isEmpty())) { - json.add(CommonName.ENTITY_SAMPLE, gson.toJsonTree(model.getSamples())); - } - if (model.getTrcf().isPresent()) { - json.addProperty(ENTITY_TRCF, toCheckpoint(model.getTrcf().get())); + if (model != null) { + json.addProperty(ENTITY_TRCF, toCheckpoint(model)); } // if json is empty, it will be an empty Json string {}. No need to save it on disk. return json.entrySet().isEmpty() ? Optional.empty() : Optional.ofNullable(gson.toJson(json)); @@ -335,7 +278,7 @@ public Optional toCheckpoint(EntityModel model, String modelId) { }); } - private String toCheckpoint(ThresholdedRandomCutForest trcf) { + String toCheckpoint(ThresholdedRandomCutForest trcf) { String checkpoint = null; Map.Entry result = checkoutOrNewBuffer(); LinkedBuffer buffer = result.getKey(); @@ -369,21 +312,6 @@ private String toCheckpoint(ThresholdedRandomCutForest trcf) { return checkpoint; } - private Map.Entry checkoutOrNewBuffer() { - LinkedBuffer buffer = null; - boolean isCheckout = true; - try { - buffer = serializeRCFBufferPool.borrowObject(); - } catch (Exception e) { - logger.warn("Failed to borrow a buffer from pool", e); - } - if (buffer == null) { - buffer = LinkedBuffer.allocate(serializeRCFBufferSize); - isCheckout = false; - } - return new SimpleImmutableEntry(buffer, isCheckout); - } - private String toCheckpoint(ThresholdedRandomCutForest trcf, LinkedBuffer buffer) { try { byte[] bytes = AccessController.doPrivileged((PrivilegedAction) () -> { @@ -396,73 +324,6 @@ private String toCheckpoint(ThresholdedRandomCutForest trcf, LinkedBuffer buffer } } - /** - * Deletes the model checkpoint for the model. - * - * @param modelId id of the model - * @param listener onReponse is called with null when the operation is completed - */ - public void deleteModelCheckpoint(String modelId, ActionListener listener) { - clientUtil - .asyncRequest( - new DeleteRequest(indexName, modelId), - client::delete, - ActionListener.wrap(r -> listener.onResponse(null), listener::onFailure) - ); - } - - /** - * Delete checkpoints associated with a detector. Used in multi-entity detector. - * @param detectorID Detector Id - */ - public void deleteModelCheckpointByDetectorId(String detectorID) { - // A bulk delete request is performed for each batch of matching documents. If a - // search or bulk request is rejected, the requests are retried up to 10 times, - // with exponential back off. If the maximum retry limit is reached, processing - // halts and all failed requests are returned in the response. Any delete - // requests that completed successfully still stick, they are not rolled back. - DeleteByQueryRequest deleteRequest = new DeleteByQueryRequest(ADCommonName.CHECKPOINT_INDEX_NAME) - .setQuery(new MatchQueryBuilder(DETECTOR_ID, detectorID)) - .setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN) - .setAbortOnVersionConflict(false) // when current delete happens, previous might not finish. - // Retry in this case - .setRequestsPerSecond(500); // throttle delete requests - logger.info("Delete checkpoints of detector {}", detectorID); - client.execute(DeleteByQueryAction.INSTANCE, deleteRequest, ActionListener.wrap(response -> { - if (response.isTimedOut() || !response.getBulkFailures().isEmpty() || !response.getSearchFailures().isEmpty()) { - logFailure(response, detectorID); - } - // can return 0 docs get deleted because: - // 1) we cannot find matching docs - // 2) bad stats from OpenSearch. In this case, docs are deleted, but - // OpenSearch says deleted is 0. - logger.info("{} " + DOC_GOT_DELETED_LOG_MSG, response.getDeleted()); - }, exception -> { - if (exception instanceof IndexNotFoundException) { - logger.info(INDEX_DELETED_LOG_MSG + " {}", detectorID); - } else { - // Gonna eventually delete in daily cron. - logger.error(NOT_ABLE_TO_DELETE_LOG_MSG, exception); - } - })); - } - - private void logFailure(BulkByScrollResponse response, String detectorID) { - if (response.isTimedOut()) { - logger.warn(TIMEOUT_LOG_MSG + " {}", detectorID); - } else if (!response.getBulkFailures().isEmpty()) { - logger.warn(BULK_FAILURE_LOG_MSG + " {}", detectorID); - for (BulkItemResponse.Failure bulkFailure : response.getBulkFailures()) { - logger.warn(bulkFailure); - } - } else { - logger.warn(SEARCH_FAILURE_LOG_MSG + " {}", detectorID); - for (ScrollableHitSource.SearchFailure searchFailure : response.getSearchFailures()) { - logger.warn(searchFailure); - } - } - } - /** * Load json checkpoint into models * @@ -471,9 +332,14 @@ private void logFailure(BulkByScrollResponse response, String detectorID) { * @return a pair of entity model and its last checkpoint time; or empty if * the raw checkpoint is too large */ - public Optional> fromEntityModelCheckpoint(Map checkpoint, String modelId) { + @Override + protected ModelState fromEntityModelCheckpoint( + Map checkpoint, + String modelId, + String configId + ) { try { - return AccessController.doPrivileged((PrivilegedAction>>) () -> { + return AccessController.doPrivileged((PrivilegedAction>) () -> { Object modelObj = checkpoint.get(FIELD_MODELV2); if (modelObj == null) { // in case there is old -format checkpoint @@ -481,24 +347,14 @@ public Optional> fromEntityModelCheckpoint(Map maxCheckpointBytes) { logger.warn(new ParameterizedMessage("[{}]'s model too large: [{}] bytes", modelId, model.length())); - return Optional.empty(); + return null; } JsonObject json = parser.parse(model).getAsJsonObject(); - ArrayDeque samples = null; - if (json.has(CommonName.ENTITY_SAMPLE)) { - // verified, don't need privileged call to get permission - samples = new ArrayDeque<>( - Arrays.asList(this.gson.fromJson(json.getAsJsonArray(CommonName.ENTITY_SAMPLE), new double[0][0].getClass())) - ); - } else { - // avoid possible null pointer exception - samples = new ArrayDeque<>(); - } ThresholdedRandomCutForest trcf = null; if (json.has(ENTITY_TRCF)) { @@ -518,15 +374,19 @@ public Optional> fromEntityModelCheckpoint(Map convertedTRCF = convertToTRCF(rcf, threshold); - // if checkpoint is corrupted (e.g., some unexpected checkpoint when we missed - // the mark in backward compatibility), we are not gonna load the model part - // the model will have to use live data to initialize - if (convertedTRCF.isPresent()) { - trcf = convertedTRCF.get(); + if (rcf.isPresent()) { + Optional convertedTRCF = convertToTRCF(rcf.get(), threshold); + // if checkpoint is corrupted (e.g., some unexpected checkpoint when we missed + // the mark in backward compatibility), we are not gonna load the model part + // the model will have to use live data to initialize + if (convertedTRCF.isPresent()) { + trcf = convertedTRCF.get(); + } } } + Deque sampleQueue = processSampleQueue(json, checkpoint, modelId); + String lastCheckpointTimeString = (String) (checkpoint.get(CommonName.TIMESTAMP)); Instant timestamp = Instant.parse(lastCheckpointTimeString); Entity entity = null; @@ -538,17 +398,43 @@ public Optional> fromEntityModelCheckpoint(Map(entityModel, timestamp)); + + ModelState modelState = new ModelState( + trcf, + modelId, + configId, + ModelManager.ModelType.TRCF.getName(), + clock, + 0, + Optional.ofNullable(entity), + sampleQueue + ); + modelState.setLastCheckpointTime(timestamp); + return modelState; }); } catch (Exception e) { logger.warn("Exception while deserializing checkpoint " + modelId, e); // checkpoint corrupted (e.g., a checkpoint not recognized by current code // due to bugs). Better redo training. - return Optional.empty(); + return null; } } + private Deque processSampleQueue(JsonObject json, Map checkpoint, String modelId) { + Deque sampleQueue = new ArrayDeque<>(); + if (json.has(CommonName.ENTITY_SAMPLE)) { + double[][] samplesArray = this.gson.fromJson(json.getAsJsonArray(CommonName.ENTITY_SAMPLE), doubleArrayType); + // this branch exists for bwc. Since we didn't record start and end time, we have to give a default 0. + Arrays + .stream(samplesArray) + .map(sampleArray -> new Sample(sampleArray, Instant.ofEpochMilli(0), Instant.ofEpochMilli(0))) + .forEach(sampleQueue::add); + } else { + sampleQueue = loadSampleQueue(checkpoint, modelId); + } + return sampleQueue; + } + ThresholdedRandomCutForest toTrcf(String checkpoint) { ThresholdedRandomCutForest trcf = null; if (checkpoint != null && !checkpoint.isEmpty()) { @@ -604,7 +490,7 @@ private void deserializeTRCFModel( String thresholdingModelId = SingleStreamModelIdMapper.getThresholdModelIdFromRCFModelId(rcfModelId); // query for threshold model and combinne rcf and threshold model into a ThresholdedRandomCutForest getThresholdModel(thresholdingModelId, ActionListener.wrap(thresholdingModel -> { - listener.onResponse(convertToTRCF(forest, thresholdingModel)); + listener.onResponse(convertToTRCF(forest.get(), thresholdingModel)); }, listener::onFailure)); } } catch (Exception e) { @@ -616,30 +502,14 @@ private void deserializeTRCFModel( } } - /** - * Read a checkpoint from the index and return the EntityModel object - * @param modelId Model Id - * @param listener Listener to return a pair of entity model and its last checkpoint time - */ - public void deserializeModelCheckpoint(String modelId, ActionListener>> listener) { - clientUtil.asyncRequest(new GetRequest(indexName, modelId), client::get, ActionListener.wrap(response -> { - listener.onResponse(processGetResponse(response, modelId)); - }, listener::onFailure)); - } - - /** - * Process a checkpoint GetResponse and return the EntityModel object - * @param response Checkpoint Index GetResponse - * @param modelId Model Id - * @return a pair of entity model and its last checkpoint time - */ - public Optional> processGetResponse(GetResponse response, String modelId) { - Optional> checkpointString = processRawCheckpoint(response); - if (checkpointString.isPresent()) { - return fromEntityModelCheckpoint(checkpointString.get(), modelId); - } else { - return Optional.empty(); - } + @Override + protected ModelState fromSingleStreamModelCheckpoint( + Map checkpoint, + String modelId, + String configId + ) { + // single stream AD code path is still using old way + throw new UnsupportedOperationException("This method is not supported"); } /** @@ -703,39 +573,8 @@ private Optional processThresholdModelCheckpoint(GetResponse response) { .map(source -> source.get(CommonName.FIELD_MODEL)); } - private Optional> processRawCheckpoint(GetResponse response) { - return Optional.ofNullable(response).filter(GetResponse::isExists).map(GetResponse::getSource); - } - - public void batchRead(MultiGetRequest request, ActionListener listener) { - clientUtil.execute(MultiGetAction.INSTANCE, request, listener); - } - - public void batchWrite(BulkRequest request, ActionListener listener) { - if (indexUtil.doesCheckpointIndexExist()) { - clientUtil.execute(BulkAction.INSTANCE, request, listener); - } else { - indexUtil.initCheckpointIndex(ActionListener.wrap(initResponse -> { - if (initResponse.isAcknowledged()) { - clientUtil.execute(BulkAction.INSTANCE, request, listener); - } else { - // create index failure. Notify callers using listener. - listener.onFailure(new TimeSeriesException("Creating checkpoint with mappings call not acknowledged.")); - } - }, exception -> { - if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { - // It is possible the index has been created while we sending the create request - clientUtil.execute(BulkAction.INSTANCE, request, listener); - } else { - logger.error(String.format(Locale.ROOT, "Unexpected error creating checkpoint index"), exception); - listener.onFailure(exception); - } - })); - } - } - - private Optional convertToTRCF(Optional rcf, Optional kllThreshold) { - if (!rcf.isPresent()) { + private Optional convertToTRCF(RandomCutForest rcf, Optional kllThreshold) { + if (rcf == null) { return Optional.empty(); } // if there is no threshold model (e.g., threshold model is deleted by HourlyCron), we are gonna @@ -744,20 +583,17 @@ private Optional convertToTRCF(Optional { + private static final Logger logger = LogManager.getLogger(ADColdStart.class); + + /** + * Constructor + * + * @param clock UTC clock + * @param threadPool Accessor to different threadpools + * @param nodeStateManager Storing node state + * @param rcfSampleSize The sample size used by stream samplers in this forest + * @param numberOfTrees The number of trees in this forest. + * @param numMinSamples The number of points required by stream samplers before + * results are returned. + * @param defaultSampleStride default sample distances measured in detector intervals. + * @param defaultTrainSamples Default train samples to collect. + * @param searchFeatureDao Used to issue OS queries. + * @param thresholdMinPvalue min P-value for thresholding + * @param featureManager Used to create features for models. + * @param modelTtl time-to-live before last access time of the cold start cache. + * We have a cache to record entities that have run cold starts to avoid + * repeated unsuccessful cold start. + * @param checkpointWriteWorker queue to insert model checkpoints + * @param rcfSeed rcf random seed + * @param maxRoundofColdStart max number of rounds of cold start + * @param coolDownMinutes cool down minutes when OpenSearch is overloaded + */ + public ADColdStart( + Clock clock, + ThreadPool threadPool, + NodeStateManager nodeStateManager, + int rcfSampleSize, + int numberOfTrees, + int numMinSamples, + int defaultSampleStride, + int defaultTrainSamples, + SearchFeatureDao searchFeatureDao, + double thresholdMinPvalue, + FeatureManager featureManager, + Duration modelTtl, + ADCheckpointWriteWorker checkpointWriteWorker, + long rcfSeed, + int maxRoundofColdStart, + int coolDownMinutes + ) { + super( + modelTtl, + coolDownMinutes, + clock, + threadPool, + numMinSamples, + checkpointWriteWorker, + rcfSeed, + numberOfTrees, + rcfSampleSize, + thresholdMinPvalue, + nodeStateManager, + defaultSampleStride, + defaultTrainSamples, + searchFeatureDao, + featureManager, + maxRoundofColdStart, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, + AnalysisType.AD + ); + } + + public ADColdStart( + Clock clock, + ThreadPool threadPool, + NodeStateManager nodeStateManager, + int rcfSampleSize, + int numberOfTrees, + int numMinSamples, + int maxSampleStride, + int maxTrainSamples, + SearchFeatureDao searchFeatureDao, + double thresholdMinPvalue, + FeatureManager featureManager, + Duration modelTtl, + ADCheckpointWriteWorker checkpointWriteQueue, + int maxRoundofColdStart, + int coolDownMinutes + ) { + this( + clock, + threadPool, + nodeStateManager, + rcfSampleSize, + numberOfTrees, + numMinSamples, + maxSampleStride, + maxTrainSamples, + searchFeatureDao, + thresholdMinPvalue, + featureManager, + modelTtl, + checkpointWriteQueue, + -1, + maxRoundofColdStart, + coolDownMinutes + ); + } + + /** + * Train model using given data points and save the trained model. + * + * @param pointSamples A pair consisting of a queue of continuous data points, + * in ascending order of timestamps and last seen sample. + * @param entity Entity instance + * @param entityState Entity state associated with the model Id + * @return the training samples. We can save the + * training data in result index so that the frontend can plot it. + */ + @Override + protected List trainModelFromDataSegments( + List pointSamples, + Optional entity, + ModelState entityState, + Config config, + String taskId + ) { + if (entity.isEmpty()) { + throw new IllegalArgumentException("We offer only HC cold start"); + } + + if (pointSamples == null || pointSamples.size() == 0) { + logger.info("Return early since data points must not be empty."); + return null; + } + + double[] firstPoint = pointSamples.get(0).getValueList(); + if (firstPoint == null || firstPoint.length == 0) { + logger.info("Return early since data points must not be empty."); + return null; + } + + int shingleSize = config.getShingleSize(); + int baseDimension = firstPoint.length; + int dimensions = baseDimension * shingleSize; + ThresholdedRandomCutForest.Builder rcfBuilder = ThresholdedRandomCutForest + .builder() + .dimensions(dimensions) + .sampleSize(rcfSampleSize) + .numberOfTrees(numberOfTrees) + .timeDecay(config.getTimeDecay()) + .transformDecay(config.getTimeDecay()) + .outputAfter(Math.max(shingleSize, numMinSamples)) + .initialAcceptFraction(initialAcceptFraction) + .parallelExecutionEnabled(false) + .compact(true) + .precision(Precision.FLOAT_32) + .boundingBoxCacheFraction(TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO) + // same with dimension for opportunistic memory saving + // Usually, we use it as shingleSize(dimension). When a new point comes in, we will + // look at the point store if there is any overlapping. Say the previously-stored + // vector is x1, x2, x3, x4, now we add x3, x4, x5, x6. RCF will recognize + // overlapping x3, x4, and only store x5, x6. + .shingleSize(shingleSize) + .internalShinglingEnabled(true) + .anomalyRate(1 - this.thresholdMinPvalue) + .transformMethod(TransformMethod.NORMALIZE) + .alertOnce(true) + .autoAdjust(true); + + if (shingleSize > 1) { + rcfBuilder.forestMode(ForestMode.STREAMING_IMPUTE); + rcfBuilder = applyImputationMethod(config, rcfBuilder); + } else { + // imputation with shingle size 1 is not meaningful + rcfBuilder.forestMode(ForestMode.STANDARD); + } + + if (rcfSeed > 0) { + rcfBuilder.randomSeed(rcfSeed); + } + + AnomalyDetector detector = (AnomalyDetector) config; + ThresholdArrays thresholdArrays = IgnoreSimilarExtractor.processDetectorRules(detector); + + if (thresholdArrays.ignoreSimilarFromAbove != null && thresholdArrays.ignoreSimilarFromAbove.length > 0) { + rcfBuilder.ignoreNearExpectedFromAbove(thresholdArrays.ignoreSimilarFromAbove); + } + + if (thresholdArrays.ignoreSimilarFromBelow != null && thresholdArrays.ignoreSimilarFromBelow.length > 0) { + rcfBuilder.ignoreNearExpectedFromBelow(thresholdArrays.ignoreSimilarFromBelow); + } + + if (thresholdArrays.ignoreSimilarFromAboveByRatio != null && thresholdArrays.ignoreSimilarFromAboveByRatio.length > 0) { + rcfBuilder.ignoreNearExpectedFromAboveByRatio(thresholdArrays.ignoreSimilarFromAboveByRatio); + } + + if (thresholdArrays.ignoreSimilarFromBelowByRatio != null && thresholdArrays.ignoreSimilarFromBelowByRatio.length > 0) { + rcfBuilder.ignoreNearExpectedFromBelowByRatio(thresholdArrays.ignoreSimilarFromBelowByRatio); + } + + // use build instead of new TRCF(Builder) because build method did extra validation and initialization + ThresholdedRandomCutForest trcf = rcfBuilder.build(); + + for (int i = 0; i < pointSamples.size(); i++) { + Sample dataSample = pointSamples.get(i); + double[] dataValue = dataSample.getValueList(); + trcf.process(dataValue, dataSample.getDataEndTime().getEpochSecond()); + } + + entityState.setModel(trcf); + + entityState.setLastUsedTime(clock.instant()); + + // save to checkpoint + checkpointWriteWorker.write(entityState, true, RequestPriority.MEDIUM); + + return pointSamples; + } +} diff --git a/src/main/java/org/opensearch/ad/ml/ModelManager.java b/src/main/java/org/opensearch/ad/ml/ADModelManager.java similarity index 71% rename from src/main/java/org/opensearch/ad/ml/ModelManager.java rename to src/main/java/org/opensearch/ad/ml/ADModelManager.java index 14f935aae..6d26a5448 100644 --- a/src/main/java/org/opensearch/ad/ml/ModelManager.java +++ b/src/main/java/org/opensearch/ad/ml/ADModelManager.java @@ -14,7 +14,6 @@ import java.time.Clock; import java.time.Duration; import java.time.Instant; -import java.util.ArrayDeque; import java.util.Arrays; import java.util.HashMap; import java.util.Iterator; @@ -23,7 +22,6 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Optional; -import java.util.Queue; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; @@ -31,22 +29,28 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.message.ParameterizedMessage; -import org.opensearch.ad.DetectorModelSize; import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.util.DateUtils; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; +import org.opensearch.timeseries.AnalysisModelSize; import org.opensearch.timeseries.MemoryTracker; import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.ml.MemoryAwareConcurrentHashmap; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; import org.opensearch.timeseries.ml.SingleStreamModelIdMapper; -import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.util.DateUtils; import com.amazon.randomcutforest.RandomCutForest; import com.amazon.randomcutforest.config.Precision; @@ -57,51 +61,27 @@ /** * A facade managing ML operations and models. */ -public class ModelManager implements DetectorModelSize { +public class ADModelManager extends + ModelManager + implements + AnalysisModelSize { protected static final String ENTITY_SAMPLE = "sp"; protected static final String ENTITY_RCF = "rcf"; protected static final String ENTITY_THRESHOLD = "th"; - public enum ModelType { - RCF("rcf"), - THRESHOLD("threshold"), - ENTITY("entity"); - - private String name; - - ModelType(String name) { - this.name = name; - } - - public String getName() { - return name; - } - } - - private static final Logger logger = LogManager.getLogger(ModelManager.class); + private static final Logger logger = LogManager.getLogger(ADModelManager.class); // states - private TRCFMemoryAwareConcurrentHashmap forests; + private MemoryAwareConcurrentHashmap forests; private Map> thresholds; // configuration - private final int rcfNumTrees; - private final int rcfNumSamplesInTree; - private final double rcfTimeDecay; - private final int rcfNumMinSamples; + private final double thresholdMinPvalue; private final int minPreviewSize; private final Duration modelTtl; private Duration checkpointInterval; - // dependencies - private final CheckpointDao checkpointDao; - private final Clock clock; - public FeatureManager featureManager; - - private EntityColdStarter entityColdStarter; - private MemoryTracker memoryTracker; - private final double initialAcceptFraction; /** @@ -111,7 +91,6 @@ public String getName() { * @param clock clock for system time * @param rcfNumTrees number of trees used in RCF * @param rcfNumSamplesInTree number of samples in a RCF tree - * @param rcfTimeDecay time decay for RCF * @param rcfNumMinSamples minimum samples for RCF to score * @param thresholdMinPvalue min P-value for thresholding * @param minPreviewSize minimum number of data points for preview @@ -123,29 +102,24 @@ public String getName() { * @param settings Node settings * @param clusterService Cluster service accessor */ - public ModelManager( - CheckpointDao checkpointDao, + public ADModelManager( + ADCheckpointDao checkpointDao, Clock clock, int rcfNumTrees, int rcfNumSamplesInTree, - double rcfTimeDecay, int rcfNumMinSamples, double thresholdMinPvalue, int minPreviewSize, Duration modelTtl, Setting checkpointIntervalSetting, - EntityColdStarter entityColdStarter, + ADColdStart entityColdStarter, FeatureManager featureManager, MemoryTracker memoryTracker, Settings settings, ClusterService clusterService ) { - this.checkpointDao = checkpointDao; - this.clock = clock; - this.rcfNumTrees = rcfNumTrees; - this.rcfNumSamplesInTree = rcfNumSamplesInTree; - this.rcfTimeDecay = rcfTimeDecay; - this.rcfNumMinSamples = rcfNumMinSamples; + super(rcfNumTrees, rcfNumSamplesInTree, rcfNumMinSamples, entityColdStarter, memoryTracker, clock, featureManager, checkpointDao); + this.thresholdMinPvalue = thresholdMinPvalue; this.minPreviewSize = minPreviewSize; this.modelTtl = modelTtl; @@ -156,12 +130,9 @@ public ModelManager( .addSettingsUpdateConsumer(checkpointIntervalSetting, it -> this.checkpointInterval = DateUtils.toDuration(it)); } - this.forests = new TRCFMemoryAwareConcurrentHashmap<>(memoryTracker); + this.forests = new MemoryAwareConcurrentHashmap<>(memoryTracker); this.thresholds = new ConcurrentHashMap<>(); - this.entityColdStarter = entityColdStarter; - this.featureManager = featureManager; - this.memoryTracker = memoryTracker; this.initialAcceptFraction = rcfNumMinSamples * 1.0d / rcfNumSamplesInTree; } @@ -198,10 +169,14 @@ private void getTRcfResult( ) { modelState.setLastUsedTime(clock.instant()); - ThresholdedRandomCutForest trcf = modelState.getModel(); + Optional trcfOptional = modelState.getModel(); + if (trcfOptional.isEmpty()) { + listener.onFailure(new TimeSeriesException("empty model")); + return; + } try { - AnomalyDescriptor result = trcf.process(point, 0); - double[] attribution = normalizeAttribution(trcf.getForest(), result.getRelevantAttribution()); + AnomalyDescriptor result = trcfOptional.get().process(point, 0); + double[] attribution = normalizeAttribution(trcfOptional.get().getForest(), result.getRelevantAttribution()); listener .onResponse( new ThresholdingResult( @@ -276,7 +251,7 @@ private double[] createEmptyAttribution(RandomCutForest forest) { return new double[baseDimensions]; } - private Optional> restoreModelState( + Optional> restoreModelState( Optional rcfModel, String modelId, String detectorId @@ -286,7 +261,7 @@ private Optional> restoreModelState( } return rcfModel .filter(rcf -> memoryTracker.isHostingAllowed(detectorId, rcf)) - .map(rcf -> ModelState.createSingleEntityModelState(rcf, modelId, detectorId, ModelType.RCF.getName(), clock)); + .map(rcf -> new ModelState(rcf, modelId, detectorId, ModelManager.ModelType.TRCF.getName(), clock)); } private void processRestoredTRcf( @@ -320,13 +295,16 @@ private void processRestoredCheckpoint( ) { logger.info("Restoring checkpoint for {}", modelId); Optional> model = restoreModelState(checkpointModel, modelId, detectorId); - if (model.isPresent()) { - forests.put(modelId, model.get()); - if (model.get().getModel() != null && model.get().getModel().getForest() != null) - listener.onResponse(model.get().getModel().getForest().getTotalUpdates()); - } else { - listener.onFailure(new ResourceNotFoundException(detectorId, ADCommonMessages.NO_CHECKPOINT_ERR_MSG + modelId)); - } + model.ifPresentOrElse(modelState -> { + forests.put(modelId, modelState); + modelState.getModel().ifPresent(trcf -> { + if (trcf.getForest() != null) { + listener.onResponse(trcf.getForest().getTotalUpdates()); + } else { + listener.onFailure(new ResourceNotFoundException(detectorId, ADCommonMessages.NO_CHECKPOINT_ERR_MSG + modelId)); + } + }); + }, () -> listener.onFailure(new ResourceNotFoundException(detectorId, ADCommonMessages.NO_CHECKPOINT_ERR_MSG + modelId))); } /** @@ -356,14 +334,26 @@ private void getThresholdingResult( double score, ActionListener listener ) { - ThresholdingModel threshold = modelState.getModel(); - double grade = threshold.grade(score); - double confidence = threshold.confidence(); - if (score > 0) { - threshold.update(score); + Optional thresholdOptional = modelState.getModel(); + if (thresholdOptional.isPresent()) { + ThresholdingModel threshold = thresholdOptional.get(); + double grade = threshold.grade(score); + double confidence = threshold.confidence(); + if (score > 0) { + threshold.update(score); + } + modelState.setLastUsedTime(clock.instant()); + listener.onResponse(new ThresholdingResult(grade, confidence, score)); + } else { + listener + .onFailure( + new ResourceNotFoundException( + modelState.getConfigId(), + ADCommonMessages.NO_CHECKPOINT_ERR_MSG + modelState.getModelId() + ) + ); } - modelState.setLastUsedTime(clock.instant()); - listener.onResponse(new ThresholdingResult(grade, confidence, score)); + } private void processThresholdCheckpoint( @@ -374,9 +364,7 @@ private void processThresholdCheckpoint( ActionListener listener ) { Optional> model = thresholdModel - .map( - threshold -> ModelState.createSingleEntityModelState(threshold, modelId, detectorId, ModelType.THRESHOLD.getName(), clock) - ); + .map(threshold -> new ModelState<>(threshold, modelId, detectorId, ModelManager.ModelType.THRESHOLD.getName(), clock)); if (model.isPresent()) { thresholds.put(modelId, model.get()); getThresholdingResult(model.get(), score, listener); @@ -424,8 +412,8 @@ private void stopModel(Map> models, String modelId, Ac Optional> modelState = Optional .ofNullable(models.remove(modelId)) .filter(model -> model.getLastCheckpointTime().plus(checkpointInterval).isBefore(now)); - if (modelState.isPresent()) { - T model = modelState.get().getModel(); + if (modelState.isPresent() && modelState.get().getModel().isPresent()) { + T model = modelState.get().getModel().get(); if (model instanceof ThresholdedRandomCutForest) { checkpointDao .putTRCFCheckpoint( @@ -460,29 +448,6 @@ public void clear(String detectorId, ActionListener listener) { clearModels(detectorId, forests, ActionListener.wrap(r -> clearModels(detectorId, thresholds, listener), listener::onFailure)); } - private void clearModels(String detectorId, Map models, ActionListener listener) { - Iterator id = models.keySet().iterator(); - clearModelForIterator(detectorId, models, id, listener); - } - - private void clearModelForIterator(String detectorId, Map models, Iterator idIter, ActionListener listener) { - if (idIter.hasNext()) { - String modelId = idIter.next(); - if (SingleStreamModelIdMapper.getDetectorIdForModelId(modelId).equals(detectorId)) { - models.remove(modelId); - checkpointDao - .deleteModelCheckpoint( - modelId, - ActionListener.wrap(r -> clearModelForIterator(detectorId, models, idIter, listener), listener::onFailure) - ); - } else { - clearModelForIterator(detectorId, models, idIter, listener); - } - } else { - listener.onResponse(null); - } - } - /** * Trains and saves cold-start AD models. * @@ -523,7 +488,7 @@ private void trainModelForStep( .dimensions(rcfNumFeatures) .sampleSize(rcfNumSamplesInTree) .numberOfTrees(rcfNumTrees) - .timeDecay(rcfTimeDecay) + .timeDecay(detector.getTimeDecay()) .outputAfter(rcfNumMinSamples) .initialAcceptFraction(initialAcceptFraction) .parallelExecutionEnabled(false) @@ -579,13 +544,18 @@ private void maintenanceForIterator( logger.warn("Failed to finish maintenance for model id " + modelId, e); maintenanceForIterator(models, iter, listener); }); - T model = modelState.getModel(); - if (model instanceof ThresholdedRandomCutForest) { - checkpointDao.putTRCFCheckpoint(modelId, (ThresholdedRandomCutForest) model, checkpointListener); - } else if (model instanceof ThresholdingModel) { - checkpointDao.putThresholdCheckpoint(modelId, (ThresholdingModel) model, checkpointListener); + Optional modelOptional = modelState.getModel(); + if (modelOptional.isPresent()) { + T model = modelOptional.get(); + if (model instanceof ThresholdedRandomCutForest) { + checkpointDao.putTRCFCheckpoint(modelId, (ThresholdedRandomCutForest) model, checkpointListener); + } else if (model instanceof ThresholdingModel) { + checkpointDao.putThresholdCheckpoint(modelId, (ThresholdingModel) model, checkpointListener); + } else { + checkpointListener.onFailure(new IllegalArgumentException("Unexpected model type")); + } } else { - checkpointListener.onFailure(new IllegalArgumentException("Unexpected model type")); + maintenanceForIterator(models, iter, listener); } } else { maintenanceForIterator(models, iter, listener); @@ -603,7 +573,7 @@ private void maintenanceForIterator( * @return thresholding results of preview data points * @throws IllegalArgumentException when preview data points are not valid */ - public List getPreviewResults(double[][] dataPoints, int shingleSize) { + public List getPreviewResults(double[][] dataPoints, int shingleSize, double rcfTimeDecay) { if (dataPoints.length < minPreviewSize) { throw new IllegalArgumentException("Insufficient data for preview results. Minimum required: " + minPreviewSize); } @@ -657,17 +627,11 @@ public List getPreviewResults(double[][] dataPoints, int shi @Override public Map getModelSize(String detectorId) { Map res = new HashMap<>(); - forests - .entrySet() - .stream() - .filter(entry -> SingleStreamModelIdMapper.getDetectorIdForModelId(entry.getKey()).equals(detectorId)) - .forEach(entry -> { - res.put(entry.getKey(), memoryTracker.estimateTRCFModelSize(entry.getValue().getModel())); - }); + res.putAll(forests.getModelSize(detectorId)); thresholds .entrySet() .stream() - .filter(entry -> SingleStreamModelIdMapper.getDetectorIdForModelId(entry.getKey()).equals(detectorId)) + .filter(entry -> SingleStreamModelIdMapper.getConfigIdForModelId(entry.getKey()).equals(detectorId)) .forEach(entry -> { res.put(entry.getKey(), (long) memoryTracker.getThresholdModelBytes()); }); @@ -683,8 +647,8 @@ public Map getModelSize(String detectorId) { public void getTotalUpdates(String modelId, String detectorId, ActionListener listener) { ModelState model = forests.get(modelId); if (model != null) { - if (model.getModel() != null && model.getModel().getForest() != null) { - listener.onResponse(model.getModel().getForest().getTotalUpdates()); + if (model.getModel().isPresent() && model.getModel().get().getForest() != null) { + listener.onResponse(model.getModel().get().getForest().getTotalUpdates()); } else { listener.onResponse(0L); } @@ -698,131 +662,13 @@ public void getTotalUpdates(String modelId, String detectorId, ActionListener modelState, - String modelId, - Entity entity, - int shingleSize - ) { - ThresholdingResult result = new ThresholdingResult(0, 0, 0); - if (modelState != null) { - EntityModel entityModel = modelState.getModel(); - - if (entityModel == null) { - entityModel = new EntityModel(entity, new ArrayDeque<>(), null); - modelState.setModel(entityModel); - } - - if (!entityModel.getTrcf().isPresent()) { - entityColdStarter.trainModelFromExistingSamples(modelState, shingleSize); - } - - if (entityModel.getTrcf().isPresent()) { - result = score(datapoint, modelId, modelState); - } else { - entityModel.addSample(datapoint); - } - } - return result; - } - - public ThresholdingResult score(double[] feature, String modelId, ModelState modelState) { - ThresholdingResult result = new ThresholdingResult(0, 0, 0); - EntityModel model = modelState.getModel(); - try { - if (model != null && model.getTrcf().isPresent()) { - ThresholdedRandomCutForest trcf = model.getTrcf().get(); - Optional.ofNullable(model.getSamples()).ifPresent(q -> { - q.stream().forEach(s -> trcf.process(s, 0)); - q.clear(); - }); - result = toResult(trcf.getForest(), trcf.process(feature, 0)); - } - } catch (Exception e) { - logger - .error( - new ParameterizedMessage( - "Fail to score for [{}]: model Id [{}], feature [{}]", - modelState.getModel().getEntity(), - modelId, - Arrays.toString(feature) - ), - e - ); - throw e; - } finally { - modelState.setLastUsedTime(clock.instant()); - } - return result; - } - - /** - * Instantiate an entity state out of checkpoint. Train models if there are - * enough samples. - * @param checkpoint Checkpoint loaded from index - * @param entity objects to access Entity attributes - * @param modelId Model Id - * @param detectorId Detector Id - * @param shingleSize Shingle size - * - * @return updated model state - * - */ - public ModelState processEntityCheckpoint( - Optional> checkpoint, - Entity entity, - String modelId, - String detectorId, - int shingleSize - ) { - // entity state to instantiate - ModelState modelState = new ModelState<>( - new EntityModel(entity, new ArrayDeque<>(), null), - modelId, - detectorId, - ModelType.ENTITY.getName(), - clock, - 0 - ); - - if (checkpoint.isPresent()) { - Entry modelToTime = checkpoint.get(); - EntityModel restoredModel = modelToTime.getKey(); - combineSamples(modelState.getModel(), restoredModel); - modelState.setModel(restoredModel); - modelState.setLastCheckpointTime(modelToTime.getValue()); - } - EntityModel model = modelState.getModel(); - if (model == null) { - model = new EntityModel(null, new ArrayDeque<>(), null); - modelState.setModel(model); - } - - if (!model.getTrcf().isPresent() && model.getSamples() != null && model.getSamples().size() >= rcfNumMinSamples) { - entityColdStarter.trainModelFromExistingSamples(modelState, shingleSize); - } - return modelState; - } - - private void combineSamples(EntityModel fromModel, EntityModel toModel) { - Queue samples = fromModel.getSamples(); - while (samples.peek() != null) { - toModel.addSample(samples.poll()); - } + @Override + protected ThresholdingResult createEmptyResult() { + return new ThresholdingResult(0, 0, 0); } - private ThresholdingResult toResult(RandomCutForest rcf, AnomalyDescriptor anomalyDescriptor) { + @Override + protected ThresholdingResult toResult(RandomCutForest rcf, AnomalyDescriptor anomalyDescriptor) { return new ThresholdingResult( anomalyDescriptor.getAnomalyGrade(), anomalyDescriptor.getDataConfidence(), diff --git a/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java b/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java deleted file mode 100644 index 1044b84ce..000000000 --- a/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java +++ /dev/null @@ -1,758 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.ml; - -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_COOLDOWN_MINUTES; - -import java.time.Clock; -import java.time.Duration; -import java.time.Instant; -import java.util.AbstractMap.SimpleImmutableEntry; -import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Map.Entry; -import java.util.Optional; -import java.util.Queue; -import java.util.concurrent.ConcurrentHashMap; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -import org.apache.commons.lang3.tuple.Pair; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.core.util.Throwables; -import org.apache.logging.log4j.message.ParameterizedMessage; -import org.opensearch.action.support.ThreadedActionListener; -import org.opensearch.ad.caching.DoorKeeper; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.ratelimit.CheckpointWriteWorker; -import org.opensearch.ad.ratelimit.RequestPriority; -import org.opensearch.ad.settings.ADEnabledSetting; -import org.opensearch.common.settings.Settings; -import org.opensearch.core.action.ActionListener; -import org.opensearch.threadpool.ThreadPool; -import org.opensearch.timeseries.AnalysisType; -import org.opensearch.timeseries.CleanState; -import org.opensearch.timeseries.MaintenanceState; -import org.opensearch.timeseries.NodeStateManager; -import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; -import org.opensearch.timeseries.common.exception.EndRunException; -import org.opensearch.timeseries.common.exception.TimeSeriesException; -import org.opensearch.timeseries.dataprocessor.Imputer; -import org.opensearch.timeseries.feature.SearchFeatureDao; -import org.opensearch.timeseries.model.Config; -import org.opensearch.timeseries.model.Entity; -import org.opensearch.timeseries.model.IntervalTimeConfiguration; -import org.opensearch.timeseries.settings.TimeSeriesSettings; -import org.opensearch.timeseries.util.ExceptionUtil; - -import com.amazon.randomcutforest.config.Precision; -import com.amazon.randomcutforest.config.TransformMethod; -import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; - -/** - * Training models for HCAD detectors - * - */ -public class EntityColdStarter implements MaintenanceState, CleanState { - private static final Logger logger = LogManager.getLogger(EntityColdStarter.class); - private final Clock clock; - private final ThreadPool threadPool; - private final NodeStateManager nodeStateManager; - private final int rcfSampleSize; - private final int numberOfTrees; - private final double rcfTimeDecay; - private final int numMinSamples; - private final double thresholdMinPvalue; - private final int defaulStrideLength; - private final int defaultNumberOfSamples; - private final Imputer imputer; - private final SearchFeatureDao searchFeatureDao; - private Instant lastThrottledColdStartTime; - private final FeatureManager featureManager; - private int coolDownMinutes; - // A bloom filter checked before cold start to ensure we don't repeatedly - // retry cold start of the same model. - // keys are detector ids. - private Map doorKeepers; - private final Duration modelTtl; - private final CheckpointWriteWorker checkpointWriteQueue; - // make sure rcf use a specific random seed. Otherwise, we will use a random random (not a typo) seed. - // this is mainly used for testing to make sure the model we trained and the reference rcf produce - // the same results - private final long rcfSeed; - private final int maxRoundofColdStart; - private final double initialAcceptFraction; - - /** - * Constructor - * - * @param clock UTC clock - * @param threadPool Accessor to different threadpools - * @param nodeStateManager Storing node state - * @param rcfSampleSize The sample size used by stream samplers in this forest - * @param numberOfTrees The number of trees in this forest. - * @param rcfTimeDecay rcf samples time decay constant - * @param numMinSamples The number of points required by stream samplers before - * results are returned. - * @param defaultSampleStride default sample distances measured in detector intervals. - * @param defaultTrainSamples Default train samples to collect. - * @param imputer Used to generate data points between samples. - * @param searchFeatureDao Used to issue ES queries. - * @param thresholdMinPvalue min P-value for thresholding - * @param featureManager Used to create features for models. - * @param settings ES settings accessor - * @param modelTtl time-to-live before last access time of the cold start cache. - * We have a cache to record entities that have run cold starts to avoid - * repeated unsuccessful cold start. - * @param checkpointWriteQueue queue to insert model checkpoints - * @param rcfSeed rcf random seed - * @param maxRoundofColdStart max number of rounds of cold start - */ - public EntityColdStarter( - Clock clock, - ThreadPool threadPool, - NodeStateManager nodeStateManager, - int rcfSampleSize, - int numberOfTrees, - double rcfTimeDecay, - int numMinSamples, - int defaultSampleStride, - int defaultTrainSamples, - Imputer imputer, - SearchFeatureDao searchFeatureDao, - double thresholdMinPvalue, - FeatureManager featureManager, - Settings settings, - Duration modelTtl, - CheckpointWriteWorker checkpointWriteQueue, - long rcfSeed, - int maxRoundofColdStart - ) { - this.clock = clock; - this.lastThrottledColdStartTime = Instant.MIN; - this.threadPool = threadPool; - this.nodeStateManager = nodeStateManager; - this.rcfSampleSize = rcfSampleSize; - this.numberOfTrees = numberOfTrees; - this.rcfTimeDecay = rcfTimeDecay; - this.numMinSamples = numMinSamples; - this.defaulStrideLength = defaultSampleStride; - this.defaultNumberOfSamples = defaultTrainSamples; - this.imputer = imputer; - this.searchFeatureDao = searchFeatureDao; - this.thresholdMinPvalue = thresholdMinPvalue; - this.featureManager = featureManager; - this.coolDownMinutes = (int) (AD_COOLDOWN_MINUTES.get(settings).getMinutes()); - this.doorKeepers = new ConcurrentHashMap<>(); - this.modelTtl = modelTtl; - this.checkpointWriteQueue = checkpointWriteQueue; - this.rcfSeed = rcfSeed; - this.maxRoundofColdStart = maxRoundofColdStart; - this.initialAcceptFraction = numMinSamples * 1.0d / rcfSampleSize; - } - - public EntityColdStarter( - Clock clock, - ThreadPool threadPool, - NodeStateManager nodeStateManager, - int rcfSampleSize, - int numberOfTrees, - double rcfTimeDecay, - int numMinSamples, - int maxSampleStride, - int maxTrainSamples, - Imputer imputer, - SearchFeatureDao searchFeatureDao, - double thresholdMinPvalue, - FeatureManager featureManager, - Settings settings, - Duration modelTtl, - CheckpointWriteWorker checkpointWriteQueue, - int maxRoundofColdStart - ) { - this( - clock, - threadPool, - nodeStateManager, - rcfSampleSize, - numberOfTrees, - rcfTimeDecay, - numMinSamples, - maxSampleStride, - maxTrainSamples, - imputer, - searchFeatureDao, - thresholdMinPvalue, - featureManager, - settings, - modelTtl, - checkpointWriteQueue, - -1, - maxRoundofColdStart - ); - } - - /** - * Training model for an entity - * @param modelId model Id corresponding to the entity - * @param entity the entity's information - * @param detectorId the detector Id corresponding to the entity - * @param modelState model state associated with the entity - * @param listener call back to call after cold start - */ - private void coldStart( - String modelId, - Entity entity, - String detectorId, - ModelState modelState, - AnomalyDetector detector, - ActionListener listener - ) { - logger.debug("Trigger cold start for {}", modelId); - - if (modelState == null || entity == null) { - listener - .onFailure( - new IllegalArgumentException( - String - .format( - Locale.ROOT, - "Cannot have empty model state or entity: model state [%b], entity [%b]", - modelState == null, - entity == null - ) - ) - ); - return; - } - - if (lastThrottledColdStartTime.plus(Duration.ofMinutes(coolDownMinutes)).isAfter(clock.instant())) { - listener.onResponse(null); - return; - } - - boolean earlyExit = true; - try { - DoorKeeper doorKeeper = doorKeepers.computeIfAbsent(detectorId, id -> { - // reset every 60 intervals - return new DoorKeeper( - TimeSeriesSettings.DOOR_KEEPER_FOR_COLD_STARTER_MAX_INSERTION, - TimeSeriesSettings.DOOR_KEEPER_FALSE_POSITIVE_RATE, - detector.getIntervalDuration().multipliedBy(TimeSeriesSettings.DOOR_KEEPER_MAINTENANCE_FREQ), - clock - ); - }); - - // Won't retry cold start within 60 intervals for an entity - if (doorKeeper.mightContain(modelId)) { - return; - } - - doorKeeper.put(modelId); - - ActionListener>> coldStartCallBack = ActionListener.wrap(trainingData -> { - try { - if (trainingData.isPresent()) { - List dataPoints = trainingData.get(); - extractTrainSamples(dataPoints, modelId, modelState); - Queue samples = modelState.getModel().getSamples(); - // only train models if we have enough samples - if (samples.size() >= numMinSamples) { - // The function trainModelFromDataSegments will save a trained a model. trainModelFromDataSegments is called by - // multiple places so I want to make the saving model implicit just in case I forgot. - trainModelFromDataSegments(samples, entity, modelState, detector.getShingleSize()); - logger.info("Succeeded in training entity: {}", modelId); - } else { - // save to checkpoint - checkpointWriteQueue.write(modelState, true, RequestPriority.MEDIUM); - logger.info("Not enough data to train entity: {}, currently we have {}", modelId, samples.size()); - } - } else { - logger.info("Cannot get training data for {}", modelId); - } - listener.onResponse(null); - } catch (Exception e) { - listener.onFailure(e); - } - }, exception -> { - try { - logger.error(new ParameterizedMessage("Error while cold start {}", modelId), exception); - Throwable cause = Throwables.getRootCause(exception); - if (ExceptionUtil.isOverloaded(cause)) { - logger.error("too many requests"); - lastThrottledColdStartTime = Instant.now(); - } else if (cause instanceof TimeSeriesException || exception instanceof TimeSeriesException) { - // e.g., cannot find anomaly detector - nodeStateManager.setException(detectorId, exception); - } else { - nodeStateManager.setException(detectorId, new TimeSeriesException(detectorId, cause)); - } - listener.onFailure(exception); - } catch (Exception e) { - listener.onFailure(e); - } - }); - - threadPool - .executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME) - .execute( - () -> getEntityColdStartData( - detectorId, - entity, - new ThreadedActionListener<>( - logger, - threadPool, - TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, - coldStartCallBack, - false - ) - ) - ); - earlyExit = false; - } finally { - if (earlyExit) { - listener.onResponse(null); - } - } - } - - /** - * Train model using given data points and save the trained model. - * - * @param dataPoints Queue of continuous data points, in ascending order of timestamps - * @param entity Entity instance - * @param entityState Entity state associated with the model Id - */ - private void trainModelFromDataSegments( - Queue dataPoints, - Entity entity, - ModelState entityState, - int shingleSize - ) { - if (dataPoints == null || dataPoints.size() == 0) { - throw new IllegalArgumentException("Data points must not be empty."); - } - - double[] firstPoint = dataPoints.peek(); - if (firstPoint == null || firstPoint.length == 0) { - throw new IllegalArgumentException("Data points must not be empty."); - } - int dimensions = firstPoint.length * shingleSize; - ThresholdedRandomCutForest.Builder rcfBuilder = ThresholdedRandomCutForest - .builder() - .dimensions(dimensions) - .sampleSize(rcfSampleSize) - .numberOfTrees(numberOfTrees) - .timeDecay(rcfTimeDecay) - .outputAfter(numMinSamples) - .initialAcceptFraction(initialAcceptFraction) - .parallelExecutionEnabled(false) - .compact(true) - .precision(Precision.FLOAT_32) - .boundingBoxCacheFraction(TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO) - // same with dimension for opportunistic memory saving - // Usually, we use it as shingleSize(dimension). When a new point comes in, we will - // look at the point store if there is any overlapping. Say the previously-stored - // vector is x1, x2, x3, x4, now we add x3, x4, x5, x6. RCF will recognize - // overlapping x3, x4, and only store x5, x6. - .shingleSize(shingleSize) - .internalShinglingEnabled(true) - .anomalyRate(1 - this.thresholdMinPvalue) - .transformMethod(TransformMethod.NORMALIZE) - .alertOnce(true) - .autoAdjust(true); - - if (rcfSeed > 0) { - rcfBuilder.randomSeed(rcfSeed); - } - ThresholdedRandomCutForest trcf = new ThresholdedRandomCutForest(rcfBuilder); - while (!dataPoints.isEmpty()) { - trcf.process(dataPoints.poll(), 0); - } - EntityModel model = entityState.getModel(); - if (model == null) { - model = new EntityModel(entity, new ArrayDeque<>(), null); - } - model.setTrcf(trcf); - - entityState.setLastUsedTime(clock.instant()); - - // save to checkpoint - checkpointWriteQueue.write(entityState, true, RequestPriority.MEDIUM); - } - - /** - * Get training data for an entity. - * - * We first note the maximum and minimum timestamp, and sample at most 24 points - * (with 60 points apart between two neighboring samples) between those minimum - * and maximum timestamps. Samples can be missing. We only interpolate points - * between present neighboring samples. We then transform samples and interpolate - * points to shingles. Finally, full shingles will be used for cold start. - * - * @param detectorId detector Id - * @param entity the entity's information - * @param listener listener to return training data - */ - private void getEntityColdStartData(String detectorId, Entity entity, ActionListener>> listener) { - ActionListener> getDetectorListener = ActionListener.wrap(detectorOp -> { - if (!detectorOp.isPresent()) { - listener.onFailure(new EndRunException(detectorId, "AnomalyDetector is not available.", false)); - return; - } - List coldStartData = new ArrayList<>(); - AnomalyDetector detector = (AnomalyDetector) detectorOp.get(); - - ActionListener> minTimeListener = ActionListener.wrap(earliest -> { - if (earliest.isPresent()) { - long startTimeMs = earliest.get().longValue(); - - // End time uses milliseconds as start time is assumed to be in milliseconds. - // Opensearch uses a set of preconfigured formats to recognize and parse these - // strings into a long value - // representing milliseconds-since-the-epoch in UTC. - // More on https://tinyurl.com/wub4fk92 - - long endTimeMs = clock.millis(); - Pair params = selectRangeParam(detector); - int stride = params.getLeft(); - int numberOfSamples = params.getRight(); - - // we start with round 0 - getFeatures(listener, 0, coldStartData, detector, entity, stride, numberOfSamples, startTimeMs, endTimeMs); - } else { - listener.onResponse(Optional.empty()); - } - }, listener::onFailure); - - searchFeatureDao - .getMinDataTime( - detector, - Optional.ofNullable(entity), - AnalysisType.AD, - new ThreadedActionListener<>(logger, threadPool, TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, minTimeListener, false) - ); - - }, listener::onFailure); - - nodeStateManager - .getConfig( - detectorId, - AnalysisType.AD, - new ThreadedActionListener<>(logger, threadPool, TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, getDetectorListener, false) - ); - } - - private void getFeatures( - ActionListener>> listener, - int round, - List lastRoundColdStartData, - AnomalyDetector detector, - Entity entity, - int stride, - int numberOfSamples, - long startTimeMs, - long endTimeMs - ) { - if (startTimeMs >= endTimeMs || endTimeMs - startTimeMs < detector.getIntervalInMilliseconds()) { - listener.onResponse(Optional.of(lastRoundColdStartData)); - return; - } - - // create ranges in desending order, we will reorder it in ascending order - // in Opensearch's response - List> sampleRanges = getTrainSampleRanges(detector, startTimeMs, endTimeMs, stride, numberOfSamples); - - if (sampleRanges.isEmpty()) { - listener.onResponse(Optional.of(lastRoundColdStartData)); - return; - } - - ActionListener>> getFeaturelistener = ActionListener.wrap(featureSamples -> { - // storing lastSample = null; - List currentRoundColdStartData = new ArrayList<>(); - - // featuresSamples are in ascending order of time. - for (int i = 0; i < featureSamples.size(); i++) { - Optional featuresOptional = featureSamples.get(i); - if (featuresOptional.isPresent()) { - // we only need the most recent two samples - // For the missing samples we use linear interpolation as well. - // Denote the Samples S0, S1, ... as samples in reverse order of time. - // Each [Si​,Si−1​]corresponds to strideLength * detector interval. - // If we got samples for S0, S1, S4 (both S2 and S3 are missing), then - // we interpolate the [S4,S1] into 3*strideLength pieces. - if (lastSample != null) { - // right sample has index i and feature featuresOptional.get() - int numInterpolants = (i - lastSample.getLeft()) * stride + 1; - double[][] points = featureManager - .transpose( - imputer - .impute( - featureManager.transpose(new double[][] { lastSample.getRight(), featuresOptional.get() }), - numInterpolants - ) - ); - // the last point will be included in the next iteration or we process - // it in the end. We don't want to repeatedly include the samples twice. - currentRoundColdStartData.add(Arrays.copyOfRange(points, 0, points.length - 1)); - } - lastSample = Pair.of(i, featuresOptional.get()); - } - } - - if (lastSample != null) { - currentRoundColdStartData.add(new double[][] { lastSample.getRight() }); - } - if (lastRoundColdStartData.size() > 0) { - currentRoundColdStartData.addAll(lastRoundColdStartData); - } - - // If the first round of probe provides (32+shingleSize) points (note that if S0 is - // missing or all Si​ for some i > N is missing then we would miss a lot of points. - // Otherwise we can issue another round of query — if there is any sample in the - // second round then we would have 32 + shingleSize points. If there is no sample - // in the second round then we should wait for real data. - if (calculateColdStartDataSize(currentRoundColdStartData) >= detector.getShingleSize() + numMinSamples - || round + 1 >= maxRoundofColdStart) { - listener.onResponse(Optional.of(currentRoundColdStartData)); - } else { - // the last sample's start time is the endTimeMs of next round of probe. - long lastSampleStartTime = sampleRanges.get(sampleRanges.size() - 1).getKey(); - getFeatures( - listener, - round + 1, - currentRoundColdStartData, - detector, - entity, - stride, - numberOfSamples, - startTimeMs, - lastSampleStartTime - ); - } - }, listener::onFailure); - - try { - searchFeatureDao - .getColdStartSamplesForPeriods( - detector, - sampleRanges, - Optional.ofNullable(entity), - // Accept empty bucket. - // 0, as returned by the engine should constitute a valid answer, “null” is a missing answer — it may be that 0 - // is meaningless in some case, but 0 is also meaningful in some cases. It may be that the query defining the - // metric is ill-formed, but that cannot be solved by cold-start strategy of the AD plugin — if we attempt to do - // that, we will have issues with legitimate interpretations of 0. - true, - AnalysisType.AD, - new ThreadedActionListener<>( - logger, - threadPool, - TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, - getFeaturelistener, - false - ) - ); - } catch (Exception e) { - listener.onFailure(e); - } - } - - private int calculateColdStartDataSize(List coldStartData) { - int size = 0; - for (int i = 0; i < coldStartData.size(); i++) { - size += coldStartData.get(i).length; - } - return size; - } - - /** - * Select strideLength and numberOfSamples, where stride is the number of intervals - * between two samples and trainSamples is training samples to fetch. If we disable - * interpolation, strideLength is 1 and numberOfSamples is shingleSize + numMinSamples; - * - * Algorithm: - * - * delta is the length of the detector interval in minutes. - * - * 1. Suppose delta ≤ 30 and divides 60. Then set numberOfSamples = ceil ( (shingleSize + 32)/ 24 )*24 - * and strideLength = 60/delta. Note that if there is enough data — we may have lot more than shingleSize+32 - * points — which is only good. This step tries to match data with hourly pattern. - * 2. otherwise, set numberOfSamples = (shingleSize + 32) and strideLength = 1. - * This should be an uncommon case as we are assuming most users think in terms of multiple of 5 minutes - *(say 10 or 30 minutes). But if someone wants a 23 minutes interval —- and the system permits -- - * we give it to them. In this case, we disable interpolation as we want to interpolate based on the hourly pattern. - * That's why we use 60 as a dividend in case 1. The 23 minute case does not fit that pattern. - * Note the smallest delta that does not divide 60 is 7 which is quite large to wait for one data point. - * @return the chosen strideLength and numberOfSamples - */ - private Pair selectRangeParam(AnomalyDetector detector) { - int shingleSize = detector.getShingleSize(); - if (ADEnabledSetting.isInterpolationInColdStartEnabled()) { - long delta = detector.getIntervalInMinutes(); - - int strideLength = defaulStrideLength; - int numberOfSamples = defaultNumberOfSamples; - if (delta <= 30 && 60 % delta == 0) { - strideLength = (int) (60 / delta); - numberOfSamples = (int) Math.ceil((shingleSize + numMinSamples) / 24.0d) * 24; - } else { - strideLength = 1; - numberOfSamples = shingleSize + numMinSamples; - } - return Pair.of(strideLength, numberOfSamples); - } else { - return Pair.of(1, shingleSize + numMinSamples); - } - - } - - /** - * Get train samples within a time range. - * - * @param detector accessor to detector config - * @param startMilli range start - * @param endMilli range end - * @param stride the number of intervals between two samples - * @param numberOfSamples maximum training samples to fetch - * @return list of sample time ranges - */ - private List> getTrainSampleRanges( - AnomalyDetector detector, - long startMilli, - long endMilli, - int stride, - int numberOfSamples - ) { - long bucketSize = ((IntervalTimeConfiguration) detector.getInterval()).toDuration().toMillis(); - int numBuckets = (int) Math.floor((endMilli - startMilli) / (double) bucketSize); - // adjust if numStrides is more than the max samples - int numStrides = Math.min((int) Math.floor(numBuckets / (double) stride), numberOfSamples); - List> sampleRanges = Stream - .iterate(endMilli, i -> i - stride * bucketSize) - .limit(numStrides) - .map(time -> new SimpleImmutableEntry<>(time - bucketSize, time)) - .collect(Collectors.toList()); - return sampleRanges; - } - - /** - * Train models for the given entity - * @param entity The entity info - * @param detectorId Detector Id - * @param modelState Model state associated with the entity - * @param listener callback before the method returns whenever EntityColdStarter - * finishes training or encounters exceptions. The listener helps notify the - * cold start queue to pull another request (if any) to execute. - */ - public void trainModel(Entity entity, String detectorId, ModelState modelState, ActionListener listener) { - nodeStateManager.getConfig(detectorId, AnalysisType.AD, ActionListener.wrap(detectorOptional -> { - if (false == detectorOptional.isPresent()) { - logger.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", detectorId)); - listener.onFailure(new TimeSeriesException(detectorId, "fail to find detector")); - return; - } - - AnomalyDetector detector = (AnomalyDetector) detectorOptional.get(); - - Queue samples = modelState.getModel().getSamples(); - String modelId = modelState.getModelId(); - - if (samples.size() < this.numMinSamples) { - // we cannot get last RCF score since cold start happens asynchronously - coldStart(modelId, entity, detectorId, modelState, detector, listener); - } else { - try { - trainModelFromDataSegments(samples, entity, modelState, detector.getShingleSize()); - listener.onResponse(null); - } catch (Exception e) { - listener.onFailure(e); - } - } - - }, listener::onFailure)); - } - - public void trainModelFromExistingSamples(ModelState modelState, int shingleSize) { - if (modelState == null || modelState.getModel() == null || modelState.getModel().getSamples() == null) { - return; - } - - EntityModel model = modelState.getModel(); - Queue samples = model.getSamples(); - if (samples.size() >= this.numMinSamples) { - try { - trainModelFromDataSegments(samples, model.getEntity().orElse(null), modelState, shingleSize); - } catch (Exception e) { - // e.g., exception from rcf. We can do nothing except logging the error - // We won't retry training for the same entity in the cooldown period - // (60 detector intervals). - logger.error("Unexpected training error", e); - } - - } - } - - /** - * Extract training data and put them into ModelState - * - * @param coldstartDatapoints training data generated from cold start - * @param modelId model Id - * @param modelState entity State - */ - private void extractTrainSamples(List coldstartDatapoints, String modelId, ModelState modelState) { - if (coldstartDatapoints == null || coldstartDatapoints.size() == 0 || modelState == null) { - return; - } - - EntityModel model = modelState.getModel(); - if (model == null) { - model = new EntityModel(null, new ArrayDeque<>(), null); - modelState.setModel(model); - } - - Queue newSamples = new ArrayDeque<>(); - for (double[][] consecutivePoints : coldstartDatapoints) { - for (int i = 0; i < consecutivePoints.length; i++) { - newSamples.add(consecutivePoints[i]); - } - } - - model.setSamples(newSamples); - } - - @Override - public void maintenance() { - doorKeepers.entrySet().stream().forEach(doorKeeperEntry -> { - String detectorId = doorKeeperEntry.getKey(); - DoorKeeper doorKeeper = doorKeeperEntry.getValue(); - if (doorKeeper.expired(modelTtl)) { - doorKeepers.remove(detectorId); - } else { - doorKeeper.maintenance(); - } - }); - } - - @Override - public void clear(String detectorId) { - doorKeepers.remove(detectorId); - } -} diff --git a/src/main/java/org/opensearch/ad/ml/EntityModel.java b/src/main/java/org/opensearch/ad/ml/EntityModel.java deleted file mode 100644 index 348ad8c6e..000000000 --- a/src/main/java/org/opensearch/ad/ml/EntityModel.java +++ /dev/null @@ -1,92 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.ml; - -import java.util.ArrayDeque; -import java.util.Optional; -import java.util.Queue; - -import org.opensearch.timeseries.model.Entity; - -import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; - -public class EntityModel { - private Entity entity; - // TODO: sample should record timestamp - private Queue samples; - - private ThresholdedRandomCutForest trcf; - - /** - * Constructor with TRCF. - * - * @param entity entity if any - * @param samples samples with the model - * @param trcf thresholded rcf model - */ - public EntityModel(Entity entity, Queue samples, ThresholdedRandomCutForest trcf) { - this.entity = entity; - this.samples = samples; - this.trcf = trcf; - } - - /** - * In old checkpoint mapping, we don't have entity. It's fine we are missing - * entity as it is mostly used for debugging. - * @return entity - */ - public Optional getEntity() { - return Optional.ofNullable(entity); - } - - public Queue getSamples() { - return this.samples; - } - - public void setSamples(Queue samples) { - this.samples = samples; - } - - public void addSample(double[] sample) { - if (this.samples == null) { - this.samples = new ArrayDeque<>(); - } - if (sample != null && sample.length != 0) { - this.samples.add(sample); - } - } - - /** - * Sets an trcf model. - * - * @param trcf an trcf model - */ - public void setTrcf(ThresholdedRandomCutForest trcf) { - this.trcf = trcf; - } - - /** - * Returns optional trcf model. - * - * @return the trcf model or empty - */ - public Optional getTrcf() { - return Optional.ofNullable(this.trcf); - } - - public void clear() { - if (samples != null) { - samples.clear(); - } - trcf = null; - } -} diff --git a/src/main/java/org/opensearch/ad/ml/IgnoreSimilarExtractor.java b/src/main/java/org/opensearch/ad/ml/IgnoreSimilarExtractor.java new file mode 100644 index 000000000..bd15b64c7 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ml/IgnoreSimilarExtractor.java @@ -0,0 +1,127 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.ml; + +import java.util.List; + +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.Condition; +import org.opensearch.ad.model.Rule; +import org.opensearch.ad.model.ThresholdType; + +/** + * The IgnoreSimilarExtractor class provides functionality to process anomaly detection rules, + * specifically focusing on extracting threshold values for various conditions to ignore + * similar anomalies. It supports handling conditions based on absolute values and ratios, + * distinguishing between thresholds that specify ignoring similarities from above or below + * a certain value. + */ +public class IgnoreSimilarExtractor { + + // Define a class to hold the arrays + public static class ThresholdArrays { + public double[] ignoreSimilarFromAbove; + public double[] ignoreSimilarFromBelow; + public double[] ignoreSimilarFromAboveByRatio; + public double[] ignoreSimilarFromBelowByRatio; + + public ThresholdArrays( + double[] ignoreSimilarFromAbove, + double[] ignoreSimilarFromBelow, + double[] ignoreSimilarFromAboveByRatio, + double[] ignoreSimilarFromBelowByRatio + ) { + this.ignoreSimilarFromAbove = ignoreSimilarFromAbove; + this.ignoreSimilarFromBelow = ignoreSimilarFromBelow; + this.ignoreSimilarFromAboveByRatio = ignoreSimilarFromAboveByRatio; + this.ignoreSimilarFromBelowByRatio = ignoreSimilarFromBelowByRatio; + } + } + + public static ThresholdArrays processDetectorRules(AnomalyDetector detector) { + List featureNames = detector.getEnabledFeatureNames(); + int baseDimension = featureNames.size(); + Ref ignoreSimilarFromAbove = Ref.of(null); + Ref ignoreSimilarFromBelow = Ref.of(null); + Ref ignoreSimilarFromAboveByRatio = Ref.of(null); + Ref ignoreSimilarFromBelowByRatio = Ref.of(null); + + List rules = detector.getRules(); + for (Rule rule : rules) { + for (Condition condition : rule.getConditions()) { + processCondition( + condition, + featureNames, + baseDimension, + ignoreSimilarFromAbove, + ignoreSimilarFromBelow, + ignoreSimilarFromAboveByRatio, + ignoreSimilarFromBelowByRatio + ); + } + } + + // Return a new ThresholdArrays instance containing the processed arrays + return new ThresholdArrays( + ignoreSimilarFromAbove.value, + ignoreSimilarFromBelow.value, + ignoreSimilarFromAboveByRatio.value, + ignoreSimilarFromBelowByRatio.value + ); + } + + private static class Ref { + public T value; + + private Ref(T value) { + this.value = value; + } + + public static Ref of(T value) { + return new Ref<>(value); + } + } + + private static void processCondition( + Condition condition, + List featureNames, + int baseDimension, + Ref ignoreSimilarFromAbove, + Ref ignoreSimilarFromBelow, + Ref ignoreSimilarFromAboveByRatio, + Ref ignoreSimilarFromBelowByRatio + ) { + String featureName = condition.getFeatureName(); + int featureIndex = featureNames.indexOf(featureName); + + ThresholdType thresholdType = condition.getThresholdType(); + double value = condition.getValue(); + + switch (thresholdType) { + case ACTUAL_OVER_EXPECTED_MARGIN: + updateThresholdValue(baseDimension, ignoreSimilarFromAbove, featureIndex, value); + break; + case EXPECTED_OVER_ACTUAL_MARGIN: + updateThresholdValue(baseDimension, ignoreSimilarFromBelow, featureIndex, value); + break; + case ACTUAL_OVER_EXPECTED_RATIO: + updateThresholdValue(baseDimension, ignoreSimilarFromAboveByRatio, featureIndex, value); + break; + case EXPECTED_OVER_ACTUAL_RATIO: + updateThresholdValue(baseDimension, ignoreSimilarFromBelowByRatio, featureIndex, value); + break; + default: + break; + } + } + + private static void updateThresholdValue(int baseDimension, Ref thresholdArrayRef, int featureIndex, double value) { + if (thresholdArrayRef.value == null) { + thresholdArrayRef.value = new double[baseDimension]; + } + thresholdArrayRef.value[featureIndex] = value; + } +} diff --git a/src/main/java/org/opensearch/ad/ml/TRCFMemoryAwareConcurrentHashmap.java b/src/main/java/org/opensearch/ad/ml/TRCFMemoryAwareConcurrentHashmap.java deleted file mode 100644 index 2380173b0..000000000 --- a/src/main/java/org/opensearch/ad/ml/TRCFMemoryAwareConcurrentHashmap.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.ml; - -import java.util.concurrent.ConcurrentHashMap; - -import org.opensearch.timeseries.MemoryTracker; -import org.opensearch.timeseries.MemoryTracker.Origin; - -import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; - -/** - * A customized ConcurrentHashMap that can automatically consume and release memory. - * This enables minimum change to our single-entity code as we just have to replace - * the map implementation. - * - * Note: this is mainly used for single-entity detectors. - */ -public class TRCFMemoryAwareConcurrentHashmap extends ConcurrentHashMap> { - private final MemoryTracker memoryTracker; - - public TRCFMemoryAwareConcurrentHashmap(MemoryTracker memoryTracker) { - this.memoryTracker = memoryTracker; - } - - @Override - public ModelState remove(Object key) { - ModelState deletedModelState = super.remove(key); - if (deletedModelState != null && deletedModelState.getModel() != null) { - long memoryToRelease = memoryTracker.estimateTRCFModelSize(deletedModelState.getModel()); - memoryTracker.releaseMemory(memoryToRelease, true, Origin.REAL_TIME_DETECTOR); - } - return deletedModelState; - } - - @Override - public ModelState put(K key, ModelState value) { - ModelState previousAssociatedState = super.put(key, value); - if (value != null && value.getModel() != null) { - long memoryToConsume = memoryTracker.estimateTRCFModelSize(value.getModel()); - memoryTracker.consumeMemory(memoryToConsume, true, Origin.REAL_TIME_DETECTOR); - } - return previousAssociatedState; - } -} diff --git a/src/main/java/org/opensearch/ad/model/ADTask.java b/src/main/java/org/opensearch/ad/model/ADTask.java index 93566a0f0..96061219a 100644 --- a/src/main/java/org/opensearch/ad/model/ADTask.java +++ b/src/main/java/org/opensearch/ad/model/ADTask.java @@ -141,7 +141,7 @@ public static Builder builder() { } @Override - public boolean isEntityTask() { + public boolean isHistoricalEntityTask() { return ADTaskType.HISTORICAL_HC_ENTITY.name().equals(taskType); } @@ -337,7 +337,12 @@ public static ADTask parse(XContentParser parser, String taskId) throws IOExcept detector.getCategoryFields(), detector.getUser(), detector.getCustomResultIndex(), - detector.getImputationOption() + detector.getImputationOption(), + detector.getRecencyEmphasis(), + detector.getSeasonIntervals(), + detector.getHistoryIntervals(), + detector.getRules() + ); return new Builder() .taskId(parsedTaskId) @@ -369,10 +374,12 @@ public static ADTask parse(XContentParser parser, String taskId) throws IOExcept @Generated @Override public boolean equals(Object other) { - if (this == other) + if (this == other) { return true; - if (other == null || getClass() != other.getClass()) + } + if (other == null || getClass() != other.getClass()) { return false; + } ADTask that = (ADTask) other; return super.equals(that) && Objects.equal(getDetector(), that.getDetector()) diff --git a/src/main/java/org/opensearch/ad/model/ADTaskProfile.java b/src/main/java/org/opensearch/ad/model/ADTaskProfile.java index cd6eaeaa0..0a31d5d95 100644 --- a/src/main/java/org/opensearch/ad/model/ADTaskProfile.java +++ b/src/main/java/org/opensearch/ad/model/ADTaskProfile.java @@ -21,44 +21,30 @@ import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.common.io.stream.Writeable; -import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.timeseries.TaskProfile; import org.opensearch.timeseries.annotation.Generated; +import org.opensearch.timeseries.model.EntityTaskProfile; /** * One anomaly detection task means one detector starts to run until stopped. */ -public class ADTaskProfile implements ToXContentObject, Writeable { +public class ADTaskProfile extends TaskProfile { public static final String AD_TASK_FIELD = "ad_task"; - public static final String SHINGLE_SIZE_FIELD = "shingle_size"; - public static final String RCF_TOTAL_UPDATES_FIELD = "rcf_total_updates"; public static final String THRESHOLD_MODEL_TRAINED_FIELD = "threshold_model_trained"; public static final String THRESHOLD_MODEL_TRAINING_DATA_SIZE_FIELD = "threshold_model_training_data_size"; - public static final String MODEL_SIZE_IN_BYTES = "model_size_in_bytes"; - public static final String NODE_ID_FIELD = "node_id"; - public static final String TASK_ID_FIELD = "task_id"; - public static final String AD_TASK_TYPE_FIELD = "task_type"; public static final String DETECTOR_TASK_SLOTS_FIELD = "detector_task_slots"; public static final String TOTAL_ENTITIES_INITED_FIELD = "total_entities_inited"; public static final String TOTAL_ENTITIES_COUNT_FIELD = "total_entities_count"; public static final String PENDING_ENTITIES_COUNT_FIELD = "pending_entities_count"; public static final String RUNNING_ENTITIES_COUNT_FIELD = "running_entities_count"; public static final String RUNNING_ENTITIES_FIELD = "running_entities"; - public static final String ENTITY_TASK_PROFILE_FIELD = "entity_task_profiles"; public static final String LATEST_HC_TASK_RUN_TIME_FIELD = "latest_hc_task_run_time"; - private ADTask adTask; - private Integer shingleSize; - private Long rcfTotalUpdates; private Boolean thresholdModelTrained; private Integer thresholdModelTrainingDataSize; - private Long modelSizeInBytes; - private String nodeId; - private String taskId; - private String adTaskType; private Integer detectorTaskSlots; private Boolean totalEntitiesInited; private Integer totalEntitiesCount; @@ -66,15 +52,14 @@ public class ADTaskProfile implements ToXContentObject, Writeable { private Integer runningEntitiesCount; private List runningEntities; private Long latestHCTaskRunTime; - - private List entityTaskProfiles; + protected List entityTaskProfiles; public ADTaskProfile() { } public ADTaskProfile(ADTask adTask) { - this.adTask = adTask; + super(adTask); } public ADTaskProfile( @@ -86,13 +71,9 @@ public ADTaskProfile( long modelSizeInBytes, String nodeId ) { - this.taskId = taskId; - this.shingleSize = shingleSize; - this.rcfTotalUpdates = rcfTotalUpdates; + super(taskId, shingleSize, rcfTotalUpdates, modelSizeInBytes, nodeId); this.thresholdModelTrained = thresholdModelTrained; this.thresholdModelTrainingDataSize = thresholdModelTrainingDataSize; - this.modelSizeInBytes = modelSizeInBytes; - this.nodeId = nodeId; } public ADTaskProfile( @@ -113,15 +94,9 @@ public ADTaskProfile( List runningEntities, Long latestHCTaskRunTime ) { - this.adTask = adTask; - this.shingleSize = shingleSize; - this.rcfTotalUpdates = rcfTotalUpdates; + super(adTask, shingleSize, rcfTotalUpdates, modelSizeInBytes, nodeId, taskId, adTaskType); this.thresholdModelTrained = thresholdModelTrained; this.thresholdModelTrainingDataSize = thresholdModelTrainingDataSize; - this.modelSizeInBytes = modelSizeInBytes; - this.nodeId = nodeId; - this.taskId = taskId; - this.adTaskType = adTaskType; this.detectorTaskSlots = detectorTaskSlots; this.totalEntitiesInited = totalEntitiesInited; this.totalEntitiesCount = totalEntitiesCount; @@ -133,9 +108,9 @@ public ADTaskProfile( public ADTaskProfile(StreamInput input) throws IOException { if (input.readBoolean()) { - this.adTask = new ADTask(input); + this.task = new ADTask(input); } else { - this.adTask = null; + this.task = null; } this.shingleSize = input.readOptionalInt(); this.rcfTotalUpdates = input.readOptionalLong(); @@ -145,7 +120,7 @@ public ADTaskProfile(StreamInput input) throws IOException { this.nodeId = input.readOptionalString(); if (input.available() > 0) { this.taskId = input.readOptionalString(); - this.adTaskType = input.readOptionalString(); + this.taskType = input.readOptionalString(); this.detectorTaskSlots = input.readOptionalInt(); this.totalEntitiesInited = input.readOptionalBoolean(); this.totalEntitiesCount = input.readOptionalInt(); @@ -155,7 +130,7 @@ public ADTaskProfile(StreamInput input) throws IOException { this.runningEntities = input.readStringList(); } if (input.readBoolean()) { - this.entityTaskProfiles = input.readList(ADEntityTaskProfile::new); + this.entityTaskProfiles = input.readList(EntityTaskProfile::new); } this.latestHCTaskRunTime = input.readOptionalLong(); } @@ -167,9 +142,9 @@ public void writeTo(StreamOutput out) throws IOException { } public void writeTo(StreamOutput out, Version adVersion) throws IOException { - if (adTask != null) { + if (task != null) { out.writeBoolean(true); - adTask.writeTo(out); + task.writeTo(out); } else { out.writeBoolean(false); } @@ -182,7 +157,7 @@ public void writeTo(StreamOutput out, Version adVersion) throws IOException { out.writeOptionalString(nodeId); if (adVersion != null) { out.writeOptionalString(taskId); - out.writeOptionalString(adTaskType); + out.writeOptionalString(taskType); out.writeOptionalInt(detectorTaskSlots); out.writeOptionalBoolean(totalEntitiesInited); out.writeOptionalInt(totalEntitiesCount); @@ -207,33 +182,13 @@ public void writeTo(StreamOutput out, Version adVersion) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { XContentBuilder xContentBuilder = builder.startObject(); - if (adTask != null) { - xContentBuilder.field(AD_TASK_FIELD, adTask); - } - if (shingleSize != null) { - xContentBuilder.field(SHINGLE_SIZE_FIELD, shingleSize); - } - if (rcfTotalUpdates != null) { - xContentBuilder.field(RCF_TOTAL_UPDATES_FIELD, rcfTotalUpdates); - } + super.toXContent(xContentBuilder); if (thresholdModelTrained != null) { xContentBuilder.field(THRESHOLD_MODEL_TRAINED_FIELD, thresholdModelTrained); } if (thresholdModelTrainingDataSize != null) { xContentBuilder.field(THRESHOLD_MODEL_TRAINING_DATA_SIZE_FIELD, thresholdModelTrainingDataSize); } - if (modelSizeInBytes != null) { - xContentBuilder.field(MODEL_SIZE_IN_BYTES, modelSizeInBytes); - } - if (nodeId != null) { - xContentBuilder.field(NODE_ID_FIELD, nodeId); - } - if (taskId != null) { - xContentBuilder.field(TASK_ID_FIELD, taskId); - } - if (adTaskType != null) { - xContentBuilder.field(AD_TASK_TYPE_FIELD, adTaskType); - } if (detectorTaskSlots != null) { xContentBuilder.field(DETECTOR_TASK_SLOTS_FIELD, detectorTaskSlots); } @@ -252,12 +207,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (runningEntities != null) { xContentBuilder.field(RUNNING_ENTITIES_FIELD, runningEntities); } - if (entityTaskProfiles != null && entityTaskProfiles.size() > 0) { - xContentBuilder.field(ENTITY_TASK_PROFILE_FIELD, entityTaskProfiles.toArray()); - } if (latestHCTaskRunTime != null) { xContentBuilder.field(LATEST_HC_TASK_RUN_TIME_FIELD, latestHCTaskRunTime); } + if (entityTaskProfiles != null && entityTaskProfiles.size() > 0) { + xContentBuilder.field(ENTITY_TASK_PROFILE_FIELD, entityTaskProfiles.toArray()); + } return xContentBuilder.endObject(); } @@ -277,7 +232,7 @@ public static ADTaskProfile parse(XContentParser parser) throws IOException { Integer pendingEntitiesCount = null; Integer runningEntitiesCount = null; List runningEntities = null; - List entityTaskProfiles = null; + List entityTaskProfiles = null; Long latestHCTaskRunTime = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); @@ -310,7 +265,7 @@ public static ADTaskProfile parse(XContentParser parser) throws IOException { case TASK_ID_FIELD: taskId = parser.text(); break; - case AD_TASK_TYPE_FIELD: + case TASK_TYPE_FIELD: taskType = parser.text(); break; case DETECTOR_TASK_SLOTS_FIELD: @@ -339,7 +294,7 @@ public static ADTaskProfile parse(XContentParser parser) throws IOException { entityTaskProfiles = new ArrayList<>(); ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_ARRAY) { - entityTaskProfiles.add(ADEntityTaskProfile.parse(parser)); + entityTaskProfiles.add(EntityTaskProfile.parse(parser)); } break; case LATEST_HC_TASK_RUN_TIME_FIELD: @@ -370,30 +325,6 @@ public static ADTaskProfile parse(XContentParser parser) throws IOException { ); } - public ADTask getAdTask() { - return adTask; - } - - public void setAdTask(ADTask adTask) { - this.adTask = adTask; - } - - public Integer getShingleSize() { - return shingleSize; - } - - public void setShingleSize(Integer shingleSize) { - this.shingleSize = shingleSize; - } - - public Long getRcfTotalUpdates() { - return rcfTotalUpdates; - } - - public void setRcfTotalUpdates(Long rcfTotalUpdates) { - this.rcfTotalUpdates = rcfTotalUpdates; - } - public Boolean getThresholdModelTrained() { return thresholdModelTrained; } @@ -410,38 +341,6 @@ public void setThresholdModelTrainingDataSize(Integer thresholdModelTrainingData this.thresholdModelTrainingDataSize = thresholdModelTrainingDataSize; } - public Long getModelSizeInBytes() { - return modelSizeInBytes; - } - - public void setModelSizeInBytes(Long modelSizeInBytes) { - this.modelSizeInBytes = modelSizeInBytes; - } - - public String getNodeId() { - return nodeId; - } - - public void setNodeId(String nodeId) { - this.nodeId = nodeId; - } - - public String getTaskId() { - return taskId; - } - - public void setTaskId(String taskId) { - this.taskId = taskId; - } - - public String getAdTaskType() { - return adTaskType; - } - - public void setAdTaskType(String adTaskType) { - this.adTaskType = adTaskType; - } - public boolean getTotalEntitiesInited() { return totalEntitiesInited != null && totalEntitiesInited.booleanValue(); } @@ -498,11 +397,11 @@ public void setRunningEntities(List runningEntities) { this.runningEntities = runningEntities; } - public List getEntityTaskProfiles() { + public List getEntityTaskProfiles() { return entityTaskProfiles; } - public void setEntityTaskProfiles(List entityTaskProfiles) { + public void setEntityTaskProfiles(List entityTaskProfiles) { this.entityTaskProfiles = entityTaskProfiles; } @@ -514,15 +413,9 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; ADTaskProfile that = (ADTaskProfile) o; - return Objects.equals(adTask, that.adTask) - && Objects.equals(shingleSize, that.shingleSize) - && Objects.equals(rcfTotalUpdates, that.rcfTotalUpdates) + return super.equals(o) && Objects.equals(thresholdModelTrained, that.thresholdModelTrained) && Objects.equals(thresholdModelTrainingDataSize, that.thresholdModelTrainingDataSize) - && Objects.equals(modelSizeInBytes, that.modelSizeInBytes) - && Objects.equals(nodeId, that.nodeId) - && Objects.equals(taskId, that.taskId) - && Objects.equals(adTaskType, that.adTaskType) && Objects.equals(detectorTaskSlots, that.detectorTaskSlots) && Objects.equals(totalEntitiesInited, that.totalEntitiesInited) && Objects.equals(totalEntitiesCount, that.totalEntitiesCount) @@ -536,17 +429,11 @@ public boolean equals(Object o) { @Generated @Override public int hashCode() { - return Objects + int hash = super.hashCode(); + hash = 89 * hash + Objects .hash( - adTask, - shingleSize, - rcfTotalUpdates, thresholdModelTrained, thresholdModelTrainingDataSize, - modelSizeInBytes, - nodeId, - taskId, - adTaskType, detectorTaskSlots, totalEntitiesInited, totalEntitiesCount, @@ -554,15 +441,17 @@ public int hashCode() { runningEntitiesCount, runningEntities, entityTaskProfiles, - latestHCTaskRunTime + latestHCTaskRunTime, + entityTaskProfiles ); + return hash; } @Override public String toString() { return "ADTaskProfile{" + "adTask=" - + adTask + + task + ", shingleSize=" + shingleSize + ", rcfTotalUpdates=" @@ -580,7 +469,7 @@ public String toString() { + taskId + '\'' + ", adTaskType='" - + adTaskType + + taskType + '\'' + ", detectorTaskSlots=" + detectorTaskSlots @@ -600,4 +489,9 @@ public String toString() { + entityTaskProfiles + '}'; } + + @Override + protected String getTaskFieldName() { + return AD_TASK_FIELD; + } } diff --git a/src/main/java/org/opensearch/ad/model/ADTaskType.java b/src/main/java/org/opensearch/ad/model/ADTaskType.java index d235bad7e..a26e73f80 100644 --- a/src/main/java/org/opensearch/ad/model/ADTaskType.java +++ b/src/main/java/org/opensearch/ad/model/ADTaskType.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableList; +// enum names need to start with REALTIME or HISTORICAL we use prefix in TaskManager to check if a task is of certain type (e.g., historical) public enum ADTaskType implements TaskType { @Deprecated HISTORICAL, @@ -31,7 +32,7 @@ public enum ADTaskType implements TaskType { public static List HISTORICAL_DETECTOR_TASK_TYPES = ImmutableList .of(ADTaskType.HISTORICAL_HC_DETECTOR, ADTaskType.HISTORICAL_SINGLE_ENTITY, ADTaskType.HISTORICAL); public static List ALL_HISTORICAL_TASK_TYPES = ImmutableList - .of(ADTaskType.HISTORICAL_HC_DETECTOR, ADTaskType.HISTORICAL_SINGLE_ENTITY, ADTaskType.HISTORICAL_HC_ENTITY, ADTaskType.HISTORICAL); + .of(ADTaskType.HISTORICAL_HC_DETECTOR, ADTaskType.REALTIME_SINGLE_ENTITY, ADTaskType.HISTORICAL_HC_ENTITY, ADTaskType.HISTORICAL); public static List REALTIME_TASK_TYPES = ImmutableList .of(ADTaskType.REALTIME_SINGLE_ENTITY, ADTaskType.REALTIME_HC_DETECTOR); public static List ALL_DETECTOR_TASK_TYPES = ImmutableList diff --git a/src/main/java/org/opensearch/ad/model/Action.java b/src/main/java/org/opensearch/ad/model/Action.java new file mode 100644 index 000000000..ecd7fe11a --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/Action.java @@ -0,0 +1,11 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.model; + +public enum Action { + // ignore anomaly if found + IGNORE_ANOMALY; +} diff --git a/src/main/java/org/opensearch/ad/model/AnomalyDetector.java b/src/main/java/org/opensearch/ad/model/AnomalyDetector.java index aa86fa842..e72242b58 100644 --- a/src/main/java/org/opensearch/ad/model/AnomalyDetector.java +++ b/src/main/java/org/opensearch/ad/model/AnomalyDetector.java @@ -47,11 +47,15 @@ import org.opensearch.timeseries.model.DateRange; import org.opensearch.timeseries.model.Feature; import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.ShingleGetter; import org.opensearch.timeseries.model.TimeConfiguration; import org.opensearch.timeseries.model.ValidationAspect; import org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.util.ParseUtils; +import com.google.common.base.Objects; + /** * An AnomalyDetector is used to represent anomaly detection model(RCF) related parameters. * NOTE: If change detector config index mapping, you should change AD task index mapping as well. @@ -59,6 +63,34 @@ * in code rather than config it in anomaly-detection-state.json file. */ public class AnomalyDetector extends Config { + static class ADShingleGetter implements ShingleGetter { + private Integer seasonIntervals; + + public ADShingleGetter(Integer seasonIntervals) { + this.seasonIntervals = seasonIntervals; + } + + /** + * If the given shingle size not null, return given shingle size; + * if seasonality not null, return max(seasonality hint / 2, horizon / 3); + * otherwise, return default shingle size. + * + * @param customShingleSize Given shingle size + * @return Shingle size + */ + @Override + public Integer getShingleSize(Integer customShingleSize) { + if (customShingleSize != null) { + return customShingleSize; + } + + if (seasonIntervals != null) { + return seasonIntervals / TimeSeriesSettings.SEASONALITY_TO_SHINGLE_RATIO; + } + + return TimeSeriesSettings.DEFAULT_SHINGLE_SIZE; + } + } public static final String PARSE_FIELD_NAME = "AnomalyDetector"; public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( @@ -73,6 +105,7 @@ public class AnomalyDetector extends Config { public static final String DETECTOR_TYPE_FIELD = "detector_type"; @Deprecated public static final String DETECTION_DATE_RANGE_FIELD = "detection_date_range"; + public static final String RULES_FIELD = "rules"; protected String detectorType; @@ -84,6 +117,8 @@ public class AnomalyDetector extends Config { + MAX_RESULT_INDEX_NAME_SIZE + " characters"; + private List rules; + /** * Constructor function. * @@ -105,6 +140,10 @@ public class AnomalyDetector extends Config { * @param user user to which detector is associated * @param resultIndex result index * @param imputationOption interpolation method and optional default values + * @param recencyEmphasis Aggregation period to smooth the emphasis on the most recent data. + * @param seasonIntervals seasonality in terms of intervals + * @param historyIntervals history intervals we look back during cold start + * @param rules custom rules to filter out AD results */ public AnomalyDetector( String detectorId, @@ -124,7 +163,11 @@ public AnomalyDetector( List categoryFields, User user, String resultIndex, - ImputationOption imputationOption + ImputationOption imputationOption, + Integer recencyEmphasis, + Integer seasonIntervals, + Integer historyIntervals, + List rules ) { super( detectorId, @@ -144,7 +187,11 @@ public AnomalyDetector( user, resultIndex, detectionInterval, - imputationOption + imputationOption, + recencyEmphasis, + seasonIntervals, + new ADShingleGetter(seasonIntervals), + historyIntervals ); checkAndThrowValidationErrors(ValidationAspect.DETECTOR); @@ -166,6 +213,8 @@ public AnomalyDetector( checkAndThrowValidationErrors(ValidationAspect.DETECTOR); this.detectorType = isHC(categoryFields) ? MULTI_ENTITY.name() : SINGLE_ENTITY.name(); + + this.rules = rules; } /* @@ -210,7 +259,12 @@ public AnomalyDetector(StreamInput input) throws IOException { } else { this.imputationOption = null; } - this.imputer = createImputer(); + this.recencyEmphasis = input.readInt(); + this.seasonIntervals = input.readInt(); + this.historyIntervals = input.readInt(); + if (input.readBoolean()) { + this.rules = input.readList(Rule::new); + } } public XContentBuilder toXContent(XContentBuilder builder) throws IOException { @@ -264,6 +318,15 @@ public void writeTo(StreamOutput output) throws IOException { } else { output.writeBoolean(false); } + output.writeInt(recencyEmphasis); + output.writeInt(seasonIntervals); + output.writeInt(historyIntervals); + if (rules != null) { + output.writeBoolean(true); + output.writeList(rules); + } else { + output.writeBoolean(false); + } } @Override @@ -278,6 +341,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (detectionDateRange != null) { xContentBuilder.field(DETECTION_DATE_RANGE_FIELD, detectionDateRange); } + if (rules != null) { + xContentBuilder.field(RULES_FIELD, rules.toArray()); + } return xContentBuilder.endObject(); } @@ -350,6 +416,11 @@ public static AnomalyDetector parse( List categoryField = null; ImputationOption imputationOption = null; + Integer recencyEmphasis = null; + Integer seasonality = null; + Integer historyIntervals = null; + + List rules = new ArrayList<>(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -460,6 +531,21 @@ public static AnomalyDetector parse( case IMPUTATION_OPTION_FIELD: imputationOption = ImputationOption.parse(parser); break; + case RECENCY_EMPHASIS_FIELD: + recencyEmphasis = parser.intValue(); + break; + case SEASONALITY_FIELD: + seasonality = parser.currentToken() == XContentParser.Token.VALUE_NULL ? null : parser.intValue(); + break; + case HISTORY_INTERVAL_FIELD: + historyIntervals = parser.intValue(); + break; + case RULES_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + rules.add(Rule.parse(parser)); + } + break; default: parser.skipChildren(); break; @@ -476,14 +562,18 @@ public static AnomalyDetector parse( filterQuery, detectionInterval, windowDelay, - getShingleSize(shingleSize), + shingleSize, uiMetadata, schemaVersion, lastUpdateTime, categoryField, user, resultIndex, - imputationOption + imputationOption, + recencyEmphasis, + seasonality, + historyIntervals, + rules ); detector.setDetectionDateRange(detectionDateRange); return detector; @@ -501,6 +591,10 @@ public DateRange getDetectionDateRange() { return detectionDateRange; } + public List getRules() { + return rules; + } + @Override protected ValidationAspect getConfigValidationAspect() { return ValidationAspect.DETECTOR; @@ -513,4 +607,24 @@ public String validateCustomResultIndex(String resultIndex) { } return super.validateCustomResultIndex(resultIndex); } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + AnomalyDetector detector = (AnomalyDetector) o; + return super.equals(o) && Objects.equal(rules, detector.rules); + } + + @Override + public int hashCode() { + final int prime = 31; + int result = super.hashCode(); + result = prime * result + Objects.hashCode(rules); + return result; + } } diff --git a/src/main/java/org/opensearch/ad/model/AnomalyResult.java b/src/main/java/org/opensearch/ad/model/AnomalyResult.java index 4ee4e0ee7..bdfe4eb3c 100644 --- a/src/main/java/org/opensearch/ad/model/AnomalyResult.java +++ b/src/main/java/org/opensearch/ad/model/AnomalyResult.java @@ -182,6 +182,16 @@ So if we detect anomaly late, we get the baseDimension values from the past (cur // rcf score threshold at the time of writing a result private final Double threshold; protected final Double confidence; + /* + * model id for easy aggregations of entities. The front end needs to query + * for entities ordered by the descending/ascending order of feature values. + * After supporting multi-category fields, it is hard to write such queries + * since the entity information is stored in a nested object array. + * Also, the front end has all code/queries/ helper functions in place to + * rely on a single key per entity combo. Adding model id to forecast result + * to help the transition to multi-categorical field less painful. + */ + private final String modelId; // used when indexing exception or error or an empty result public AnomalyResult( @@ -255,12 +265,12 @@ public AnomalyResult( entity, user, schemaVersion, - modelId, taskId ); this.confidence = confidence; this.anomalyScore = anomalyScore; this.anomalyGrade = anomalyGrade; + this.modelId = modelId; this.approxAnomalyStartTime = approxAnomalyStartTime; this.relevantAttribution = relevantAttribution; this.pastValues = pastValues; @@ -422,6 +432,7 @@ public static AnomalyResult fromRawTRCFResult( public AnomalyResult(StreamInput input) throws IOException { super(input); + this.modelId = input.readOptionalString(); this.confidence = input.readDouble(); this.anomalyScore = input.readDouble(); this.anomalyGrade = input.readDouble(); @@ -502,7 +513,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws xContentBuilder.field(CommonName.ERROR_FIELD, error); } if (optionalEntity.isPresent()) { - xContentBuilder.field(CommonName.ENTITY_FIELD, optionalEntity.get()); + xContentBuilder.field(CommonName.ENTITY_KEY, optionalEntity.get()); } if (user != null) { xContentBuilder.field(CommonName.USER_FIELD, user); @@ -598,7 +609,7 @@ public static AnomalyResult parse(XContentParser parser) throws IOException { case CommonName.ERROR_FIELD: error = parser.text(); break; - case CommonName.ENTITY_FIELD: + case CommonName.ENTITY_KEY: entity = Entity.parse(parser); break; case CommonName.USER_FIELD: @@ -675,7 +686,8 @@ public boolean equals(Object o) { if (getClass() != o.getClass()) return false; AnomalyResult that = (AnomalyResult) o; - return Objects.equal(confidence, that.confidence) + return Objects.equal(modelId, that.modelId) + && Objects.equal(confidence, that.confidence) && Objects.equal(anomalyScore, that.anomalyScore) && Objects.equal(anomalyGrade, that.anomalyGrade) && Objects.equal(approxAnomalyStartTime, that.approxAnomalyStartTime) @@ -692,6 +704,7 @@ public int hashCode() { int result = super.hashCode(); result = prime * result + Objects .hashCode( + modelId, confidence, anomalyScore, anomalyGrade, @@ -710,6 +723,7 @@ public String toString() { return super.toString() + ", " + new ToStringBuilder(this) + .append("modelId", modelId) .append("confidence", confidence) .append("anomalyScore", anomalyScore) .append("anomalyGrade", anomalyGrade) @@ -757,6 +771,10 @@ public Double getThreshold() { return threshold; } + public String getModelId() { + return modelId; + } + /** * Anomaly result index consists of overwhelmingly (99.5%) zero-grade non-error documents. * This function exclude the majority case. @@ -772,6 +790,7 @@ public boolean isHighPriority() { @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); + out.writeOptionalString(modelId); out.writeDouble(confidence); out.writeDouble(anomalyScore); out.writeDouble(anomalyGrade); diff --git a/src/main/java/org/opensearch/ad/model/Condition.java b/src/main/java/org/opensearch/ad/model/Condition.java new file mode 100644 index 000000000..e69d2fa49 --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/Condition.java @@ -0,0 +1,117 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.model; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.Locale; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +public class Condition implements Writeable, ToXContentObject { + private static final String FEATURE_NAME_FIELD = "feature_Name"; + private static final String THRESHOLD_TYPE_FIELD = "threshsold_type"; + private static final String OPERATOR_FIELD = "operator"; + private static final String VALUE_FIELD = "value"; + + private String featureName; + private ThresholdType thresholdType; + private Operator operator; + private double value; + + public Condition(String featureName, ThresholdType thresholdType, Operator operator, double value) { + this.featureName = featureName; + this.thresholdType = thresholdType; + this.operator = operator; + this.value = value; + } + + public Condition(StreamInput input) throws IOException { + this.featureName = input.readString(); + this.thresholdType = input.readEnum(ThresholdType.class); + this.operator = input.readEnum(Operator.class); + this.value = input.readDouble(); + } + + /** + * Parse raw json content into rule instance. + * + * @param parser json based content parser + * @return rule instance + * @throws IOException IOException if content can't be parsed correctly + */ + public static Condition parse(XContentParser parser) throws IOException { + String featureName = null; + ThresholdType thresholdType = null; + Operator operator = null; + Double value = 0d; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + + parser.nextToken(); + switch (fieldName) { + case FEATURE_NAME_FIELD: + featureName = parser.text(); + break; + case THRESHOLD_TYPE_FIELD: + thresholdType = ThresholdType.valueOf(parser.text().toUpperCase(Locale.ROOT)); + break; + case OPERATOR_FIELD: + operator = Operator.valueOf(parser.text().toUpperCase(Locale.ROOT)); + break; + case VALUE_FIELD: + value = parser.doubleValue(); + break; + default: + break; + } + } + return new Condition(featureName, thresholdType, operator, value); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder + .startObject() + .field(FEATURE_NAME_FIELD, featureName) + .field(THRESHOLD_TYPE_FIELD, thresholdType) + .field(OPERATOR_FIELD, operator) + .field(VALUE_FIELD, value); + return xContentBuilder.endObject(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(featureName); + out.writeEnum(thresholdType); + out.writeEnum(operator); + out.writeDouble(value); + } + + public String getFeatureName() { + return featureName; + } + + public ThresholdType getThresholdType() { + return thresholdType; + } + + public Operator getOperator() { + return operator; + } + + public double getValue() { + return value; + } +} diff --git a/src/main/java/org/opensearch/ad/model/DetectorProfile.java b/src/main/java/org/opensearch/ad/model/DetectorProfile.java index 77418552e..7ffd8c9f6 100644 --- a/src/main/java/org/opensearch/ad/model/DetectorProfile.java +++ b/src/main/java/org/opensearch/ad/model/DetectorProfile.java @@ -13,131 +13,28 @@ import java.io.IOException; -import org.apache.commons.lang.builder.EqualsBuilder; -import org.apache.commons.lang.builder.HashCodeBuilder; -import org.apache.commons.lang.builder.ToStringBuilder; import org.opensearch.ad.constant.ADCommonName; import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.common.io.stream.Writeable; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.ToXContentObject; -import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.model.ConfigProfile; -public class DetectorProfile implements Writeable, ToXContentObject, Mergeable { - private DetectorState state; - private String error; - private ModelProfileOnNode[] modelProfile; - private int shingleSize; - private String coordinatingNode; - private long totalSizeInBytes; - private InitProgressProfile initProgress; - private Long totalEntities; - private Long activeEntities; - private ADTaskProfile adTaskProfile; - private long modelCount; +public class DetectorProfile extends ConfigProfile { - public XContentBuilder toXContent(XContentBuilder builder) throws IOException { - return toXContent(builder, ToXContent.EMPTY_PARAMS); - } - - public DetectorProfile(StreamInput in) throws IOException { - if (in.readBoolean()) { - this.state = in.readEnum(DetectorState.class); - } - - this.error = in.readOptionalString(); - this.modelProfile = in.readOptionalArray(ModelProfileOnNode::new, ModelProfileOnNode[]::new); - this.shingleSize = in.readOptionalInt(); - this.coordinatingNode = in.readOptionalString(); - this.totalSizeInBytes = in.readOptionalLong(); - this.totalEntities = in.readOptionalLong(); - this.activeEntities = in.readOptionalLong(); - if (in.readBoolean()) { - this.initProgress = new InitProgressProfile(in); - } - if (in.readBoolean()) { - this.adTaskProfile = new ADTaskProfile(in); - } - this.modelCount = in.readVLong(); - } - - private DetectorProfile() {} - - public static class Builder { - private DetectorState state = null; - private String error = null; - private ModelProfileOnNode[] modelProfile = null; - private int shingleSize = -1; - private String coordinatingNode = null; - private long totalSizeInBytes = -1; - private InitProgressProfile initProgress = null; - private Long totalEntities; - private Long activeEntities; + public static class Builder extends ConfigProfile.Builder { private ADTaskProfile adTaskProfile; - private long modelCount = 0; public Builder() {} - public Builder state(DetectorState state) { - this.state = state; - return this; - } - - public Builder error(String error) { - this.error = error; - return this; - } - - public Builder modelProfile(ModelProfileOnNode[] modelProfile) { - this.modelProfile = modelProfile; - return this; - } - - public Builder modelCount(long modelCount) { - this.modelCount = modelCount; - return this; - } - - public Builder shingleSize(int shingleSize) { - this.shingleSize = shingleSize; - return this; - } - - public Builder coordinatingNode(String coordinatingNode) { - this.coordinatingNode = coordinatingNode; - return this; - } - - public Builder totalSizeInBytes(long totalSizeInBytes) { - this.totalSizeInBytes = totalSizeInBytes; - return this; - } - - public Builder initProgress(InitProgressProfile initProgress) { - this.initProgress = initProgress; - return this; - } - - public Builder totalEntities(Long totalEntities) { - this.totalEntities = totalEntities; - return this; - } - - public Builder activeEntities(Long activeEntities) { - this.activeEntities = activeEntities; - return this; - } - - public Builder adTaskProfile(ADTaskProfile adTaskProfile) { + @Override + public Builder taskProfile(ADTaskProfile adTaskProfile) { this.adTaskProfile = adTaskProfile; return this; } + @Override public DetectorProfile build() { DetectorProfile profile = new DetectorProfile(); - profile.state = this.state; - profile.error = this.error; + profile.state = state; + profile.error = error; profile.modelProfile = modelProfile; profile.modelCount = modelCount; profile.shingleSize = shingleSize; @@ -146,320 +43,25 @@ public DetectorProfile build() { profile.initProgress = initProgress; profile.totalEntities = totalEntities; profile.activeEntities = activeEntities; - profile.adTaskProfile = adTaskProfile; + profile.taskProfile = adTaskProfile; return profile; } } - @Override - public void writeTo(StreamOutput out) throws IOException { - if (state == null) { - out.writeBoolean(false); - } else { - out.writeBoolean(true); - out.writeEnum(state); - } - - out.writeOptionalString(error); - out.writeOptionalArray(modelProfile); - out.writeOptionalInt(shingleSize); - out.writeOptionalString(coordinatingNode); - out.writeOptionalLong(totalSizeInBytes); - out.writeOptionalLong(totalEntities); - out.writeOptionalLong(activeEntities); - if (initProgress == null) { - out.writeBoolean(false); - } else { - out.writeBoolean(true); - initProgress.writeTo(out); - } - if (adTaskProfile == null) { - out.writeBoolean(false); - } else { - out.writeBoolean(true); - adTaskProfile.writeTo(out); - } - out.writeVLong(modelCount); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - XContentBuilder xContentBuilder = builder.startObject(); - - if (state != null) { - xContentBuilder.field(ADCommonName.STATE, state); - } - if (error != null) { - xContentBuilder.field(ADCommonName.ERROR, error); - } - if (modelProfile != null && modelProfile.length > 0) { - xContentBuilder.startArray(ADCommonName.MODELS); - for (ModelProfileOnNode profile : modelProfile) { - profile.toXContent(xContentBuilder, params); - } - xContentBuilder.endArray(); - } - if (shingleSize != -1) { - xContentBuilder.field(ADCommonName.SHINGLE_SIZE, shingleSize); - } - if (coordinatingNode != null && !coordinatingNode.isEmpty()) { - xContentBuilder.field(ADCommonName.COORDINATING_NODE, coordinatingNode); - } - if (totalSizeInBytes != -1) { - xContentBuilder.field(ADCommonName.TOTAL_SIZE_IN_BYTES, totalSizeInBytes); - } - if (initProgress != null) { - xContentBuilder.field(ADCommonName.INIT_PROGRESS, initProgress); - } - if (totalEntities != null) { - xContentBuilder.field(ADCommonName.TOTAL_ENTITIES, totalEntities); - } - if (activeEntities != null) { - xContentBuilder.field(ADCommonName.ACTIVE_ENTITIES, activeEntities); - } - if (adTaskProfile != null) { - xContentBuilder.field(ADCommonName.AD_TASK, adTaskProfile); - } - if (modelCount > 0) { - xContentBuilder.field(ADCommonName.MODEL_COUNT, modelCount); - } - return xContentBuilder.endObject(); - } - - public DetectorState getState() { - return state; - } - - public void setState(DetectorState state) { - this.state = state; - } - - public String getError() { - return error; - } - - public void setError(String error) { - this.error = error; - } - - public ModelProfileOnNode[] getModelProfile() { - return modelProfile; - } - - public void setModelProfile(ModelProfileOnNode[] modelProfile) { - this.modelProfile = modelProfile; - } - - public int getShingleSize() { - return shingleSize; - } - - public void setShingleSize(int shingleSize) { - this.shingleSize = shingleSize; - } - - public String getCoordinatingNode() { - return coordinatingNode; - } - - public void setCoordinatingNode(String coordinatingNode) { - this.coordinatingNode = coordinatingNode; - } - - public long getTotalSizeInBytes() { - return totalSizeInBytes; - } - - public void setTotalSizeInBytes(long totalSizeInBytes) { - this.totalSizeInBytes = totalSizeInBytes; - } - - public InitProgressProfile getInitProgress() { - return initProgress; - } - - public void setInitProgress(InitProgressProfile initProgress) { - this.initProgress = initProgress; - } - - public Long getTotalEntities() { - return totalEntities; - } - - public void setTotalEntities(Long totalEntities) { - this.totalEntities = totalEntities; - } - - public Long getActiveEntities() { - return activeEntities; - } - - public void setActiveEntities(Long activeEntities) { - this.activeEntities = activeEntities; - } - - public ADTaskProfile getAdTaskProfile() { - return adTaskProfile; - } - - public void setAdTaskProfile(ADTaskProfile adTaskProfile) { - this.adTaskProfile = adTaskProfile; - } - - public long getModelCount() { - return modelCount; - } - - public void setModelCount(long modelCount) { - this.modelCount = modelCount; - } - - @Override - public void merge(Mergeable other) { - if (this == other || other == null || getClass() != other.getClass()) { - return; - } - DetectorProfile otherProfile = (DetectorProfile) other; - if (otherProfile.getState() != null) { - this.state = otherProfile.getState(); - } - if (otherProfile.getError() != null) { - this.error = otherProfile.getError(); - } - if (otherProfile.getCoordinatingNode() != null) { - this.coordinatingNode = otherProfile.getCoordinatingNode(); - } - if (otherProfile.getShingleSize() != -1) { - this.shingleSize = otherProfile.getShingleSize(); - } - if (otherProfile.getModelProfile() != null) { - this.modelProfile = otherProfile.getModelProfile(); - } - if (otherProfile.getTotalSizeInBytes() != -1) { - this.totalSizeInBytes = otherProfile.getTotalSizeInBytes(); - } - if (otherProfile.getInitProgress() != null) { - this.initProgress = otherProfile.getInitProgress(); - } - if (otherProfile.getTotalEntities() != null) { - this.totalEntities = otherProfile.getTotalEntities(); - } - if (otherProfile.getActiveEntities() != null) { - this.activeEntities = otherProfile.getActiveEntities(); - } - if (otherProfile.getAdTaskProfile() != null) { - this.adTaskProfile = otherProfile.getAdTaskProfile(); - } - if (otherProfile.getModelCount() > 0) { - this.modelCount = otherProfile.getModelCount(); - } - } - - @Override - public boolean equals(Object obj) { - if (this == obj) - return true; - if (obj == null) - return false; - if (getClass() != obj.getClass()) - return false; - if (obj instanceof DetectorProfile) { - DetectorProfile other = (DetectorProfile) obj; + public DetectorProfile() {} - EqualsBuilder equalsBuilder = new EqualsBuilder(); - if (state != null) { - equalsBuilder.append(state, other.state); - } - if (error != null) { - equalsBuilder.append(error, other.error); - } - if (modelProfile != null && modelProfile.length > 0) { - equalsBuilder.append(modelProfile, other.modelProfile); - } - if (shingleSize != -1) { - equalsBuilder.append(shingleSize, other.shingleSize); - } - if (coordinatingNode != null) { - equalsBuilder.append(coordinatingNode, other.coordinatingNode); - } - if (totalSizeInBytes != -1) { - equalsBuilder.append(totalSizeInBytes, other.totalSizeInBytes); - } - if (initProgress != null) { - equalsBuilder.append(initProgress, other.initProgress); - } - if (totalEntities != null) { - equalsBuilder.append(totalEntities, other.totalEntities); - } - if (activeEntities != null) { - equalsBuilder.append(activeEntities, other.activeEntities); - } - if (adTaskProfile != null) { - equalsBuilder.append(adTaskProfile, other.adTaskProfile); - } - if (modelCount > 0) { - equalsBuilder.append(modelCount, other.modelCount); - } - return equalsBuilder.isEquals(); - } - return false; + public DetectorProfile(StreamInput in) throws IOException { + super(in); } @Override - public int hashCode() { - return new HashCodeBuilder() - .append(state) - .append(error) - .append(modelProfile) - .append(shingleSize) - .append(coordinatingNode) - .append(totalSizeInBytes) - .append(initProgress) - .append(totalEntities) - .append(activeEntities) - .append(adTaskProfile) - .append(modelCount) - .toHashCode(); + protected ADTaskProfile createTaskProfile(StreamInput in) throws IOException { + return new ADTaskProfile(in); } @Override - public String toString() { - ToStringBuilder toStringBuilder = new ToStringBuilder(this); - - if (state != null) { - toStringBuilder.append(ADCommonName.STATE, state); - } - if (error != null) { - toStringBuilder.append(ADCommonName.ERROR, error); - } - if (modelProfile != null && modelProfile.length > 0) { - toStringBuilder.append(modelProfile); - } - if (shingleSize != -1) { - toStringBuilder.append(ADCommonName.SHINGLE_SIZE, shingleSize); - } - if (coordinatingNode != null) { - toStringBuilder.append(ADCommonName.COORDINATING_NODE, coordinatingNode); - } - if (totalSizeInBytes != -1) { - toStringBuilder.append(ADCommonName.TOTAL_SIZE_IN_BYTES, totalSizeInBytes); - } - if (initProgress != null) { - toStringBuilder.append(ADCommonName.INIT_PROGRESS, initProgress); - } - if (totalEntities != null) { - toStringBuilder.append(ADCommonName.TOTAL_ENTITIES, totalEntities); - } - if (activeEntities != null) { - toStringBuilder.append(ADCommonName.ACTIVE_ENTITIES, activeEntities); - } - if (adTaskProfile != null) { - toStringBuilder.append(ADCommonName.AD_TASK, adTaskProfile); - } - if (modelCount > 0) { - toStringBuilder.append(ADCommonName.MODEL_COUNT, modelCount); - } - return toStringBuilder.toString(); + protected String getTaskFieldName() { + return ADCommonName.AD_TASK; } } diff --git a/src/main/java/org/opensearch/ad/model/DetectorProfileName.java b/src/main/java/org/opensearch/ad/model/DetectorProfileName.java deleted file mode 100644 index 443066ac8..000000000 --- a/src/main/java/org/opensearch/ad/model/DetectorProfileName.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.model; - -import java.util.Collection; -import java.util.Set; - -import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.timeseries.Name; - -public enum DetectorProfileName implements Name { - STATE(ADCommonName.STATE), - ERROR(ADCommonName.ERROR), - COORDINATING_NODE(ADCommonName.COORDINATING_NODE), - SHINGLE_SIZE(ADCommonName.SHINGLE_SIZE), - TOTAL_SIZE_IN_BYTES(ADCommonName.TOTAL_SIZE_IN_BYTES), - MODELS(ADCommonName.MODELS), - INIT_PROGRESS(ADCommonName.INIT_PROGRESS), - TOTAL_ENTITIES(ADCommonName.TOTAL_ENTITIES), - ACTIVE_ENTITIES(ADCommonName.ACTIVE_ENTITIES), - AD_TASK(ADCommonName.AD_TASK); - - private String name; - - DetectorProfileName(String name) { - this.name = name; - } - - /** - * Get profile name - * - * @return name - */ - @Override - public String getName() { - return name; - } - - public static DetectorProfileName getName(String name) { - switch (name) { - case ADCommonName.STATE: - return STATE; - case ADCommonName.ERROR: - return ERROR; - case ADCommonName.COORDINATING_NODE: - return COORDINATING_NODE; - case ADCommonName.SHINGLE_SIZE: - return SHINGLE_SIZE; - case ADCommonName.TOTAL_SIZE_IN_BYTES: - return TOTAL_SIZE_IN_BYTES; - case ADCommonName.MODELS: - return MODELS; - case ADCommonName.INIT_PROGRESS: - return INIT_PROGRESS; - case ADCommonName.TOTAL_ENTITIES: - return TOTAL_ENTITIES; - case ADCommonName.ACTIVE_ENTITIES: - return ACTIVE_ENTITIES; - case ADCommonName.AD_TASK: - return AD_TASK; - default: - throw new IllegalArgumentException(ADCommonMessages.UNSUPPORTED_PROFILE_TYPE); - } - } - - public static Set getNames(Collection names) { - return Name.getNameFromCollection(names, DetectorProfileName::getName); - } -} diff --git a/src/main/java/org/opensearch/ad/model/EntityAnomalyResult.java b/src/main/java/org/opensearch/ad/model/EntityAnomalyResult.java index 7eeb02e6c..2d58c2be8 100644 --- a/src/main/java/org/opensearch/ad/model/EntityAnomalyResult.java +++ b/src/main/java/org/opensearch/ad/model/EntityAnomalyResult.java @@ -13,6 +13,8 @@ import java.util.List; +import org.opensearch.timeseries.model.Mergeable; + public class EntityAnomalyResult implements Mergeable { private List anomalyResults; diff --git a/src/main/java/org/opensearch/ad/model/Operator.java b/src/main/java/org/opensearch/ad/model/Operator.java new file mode 100644 index 000000000..ecb83ad86 --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/Operator.java @@ -0,0 +1,11 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.model; + +// incomplete list. Will add more when needed. +public enum Operator { + LTE +} diff --git a/src/main/java/org/opensearch/ad/model/Rule.java b/src/main/java/org/opensearch/ad/model/Rule.java new file mode 100644 index 000000000..2acff353c --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/Rule.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.model; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +public class Rule implements Writeable, ToXContentObject { + private static final String ACTION_FIELD = "action"; + private static final String CONDITIONS_FIELD = "conditions"; + + private Action action; + private List conditions; + + public Rule(Action action, List conditions) { + this.action = action; + this.conditions = conditions; + } + + public Rule(StreamInput input) throws IOException { + this.action = input.readEnum(Action.class); + this.conditions = input.readList(Condition::new); + } + + /** + * Parse raw json content into rule instance. + * + * @param parser json based content parser + * @return rule instance + * @throws IOException IOException if content can't be parsed correctly + */ + public static Rule parse(XContentParser parser) throws IOException { + Action action = null; + List conditions = new ArrayList<>(); + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + + parser.nextToken(); + switch (fieldName) { + case ACTION_FIELD: + action = Action.valueOf(parser.text().toUpperCase(Locale.ROOT)); + break; + case CONDITIONS_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + conditions.add(Condition.parse(parser)); + } + break; + default: + break; + } + } + return new Rule(action, conditions); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject().field(ACTION_FIELD, action).field(CONDITIONS_FIELD, conditions.toArray()); + return xContentBuilder.endObject(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeEnum(action); + out.writeList(conditions); + } + + public Action getAction() { + return action; + } + + public List getConditions() { + return conditions; + } +} diff --git a/src/main/java/org/opensearch/ad/model/ThresholdType.java b/src/main/java/org/opensearch/ad/model/ThresholdType.java new file mode 100644 index 000000000..dd17751eb --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/ThresholdType.java @@ -0,0 +1,81 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.model; + +/** + * Enumerates the types of thresholds used in anomaly detection. + * + * This enumeration defines various types of thresholds that dictate + * how anomalies are identified in comparison to actual and expected values. + * These thresholds include direct comparisons of actual and expected values, + * as well as differences and ratios between these values from specified directions. + */ +public enum ThresholdType { + /** + * Specifies a threshold for ignoring anomalies where the actual value + * exceeds the expected value by a certain margin. + * + * Assume a represents the actual value and b signifies the expected value. + * IGNORE_SIMILAR_FROM_ABOVE implies the anomaly should be disregarded if a-b + * is less than or equal to ignoreSimilarFromAbove. + */ + ACTUAL_OVER_EXPECTED_MARGIN("a margin by which the actual values exceed the expected one"), + + /** + * Specifies a threshold for ignoring anomalies where the actual value + * is below the expected value by a certain margin. + * + * Assume a represents the actual value and b signifies the expected value. + * Likewise, IGNORE_SIMILAR_FROM_BELOW + * implies the anomaly should be disregarded if b-a is less than or equal to + * ignoreSimilarFromBelow. + */ + EXPECTED_OVER_ACTUAL_MARGIN("a margin by which expected values exceed actual ones"), + + /** + * Specifies a threshold for ignoring anomalies based on the ratio of + * the difference to the actual value when the actual value exceeds + * the expected value. + * + * Assume a represents the actual value and b signifies the expected value. + * The variable IGNORE_NEAR_EXPECTED_FROM_ABOVE_BY_RATIO presumably implies the + * anomaly should be disregarded if the ratio of the deviation from the actual + * to the expected (a-b)/|a| is less than or equal to IGNORE_NEAR_EXPECTED_FROM_ABOVE_BY_RATIO. + */ + ACTUAL_OVER_EXPECTED_RATIO("the ratio of the actual value over the expected value"), + + /** + * Specifies a threshold for ignoring anomalies based on the ratio of + * the difference to the actual value when the actual value is below + * the expected value. + * + * Assume a represents the actual value and b signifies the expected value. + * Likewise, IGNORE_NEAR_EXPECTED_FROM_BELOW_BY_RATIO appears to indicate that the anomaly + * should be ignored if the ratio of the deviation from the expected to the actual + * (b-a)/|a| is less than or equal to ignoreNearExpectedFromBelowByRatio. + */ + EXPECTED_OVER_ACTUAL_RATIO("the ratio of the expected value over the actual value"); + + private final String description; + + /** + * Constructs a ThresholdType with a descriptive name. + * + * @param description The human-readable description of the threshold type. + */ + ThresholdType(String description) { + this.description = description; + } + + /** + * Retrieves the description of the threshold type. + * + * @return A string describing the threshold type. + */ + public String getDescription() { + return description; + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/CheckpointMaintainWorker.java b/src/main/java/org/opensearch/ad/ratelimit/ADCheckpointMaintainWorker.java similarity index 70% rename from src/main/java/org/opensearch/ad/ratelimit/CheckpointMaintainWorker.java rename to src/main/java/org/opensearch/ad/ratelimit/ADCheckpointMaintainWorker.java index 05f9480a7..f2e5e3d8a 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/CheckpointMaintainWorker.java +++ b/src/main/java/org/opensearch/ad/ratelimit/ADCheckpointMaintainWorker.java @@ -16,28 +16,27 @@ import java.time.Clock; import java.time.Duration; -import java.util.ArrayList; -import java.util.List; import java.util.Optional; import java.util.Random; +import java.util.function.Function; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.CheckpointMaintainRequest; +import org.opensearch.timeseries.ratelimit.CheckpointMaintainWorker; +import org.opensearch.timeseries.ratelimit.CheckpointWriteRequest; -public class CheckpointMaintainWorker extends ScheduledWorker { - private static final Logger LOG = LogManager.getLogger(CheckpointMaintainWorker.class); - public static final String WORKER_NAME = "checkpoint-maintain"; +public class ADCheckpointMaintainWorker extends CheckpointMaintainWorker { + public static final String WORKER_NAME = "ad-checkpoint-maintain"; - private CheckPointMaintainRequestAdapter adapter; - - public CheckpointMaintainWorker( + public ADCheckpointMaintainWorker( long heapSizeInBytes, int singleRequestSizeInBytes, Setting maxHeapPercentForQueueSetting, @@ -51,10 +50,10 @@ public CheckpointMaintainWorker( float mediumSegmentPruneRatio, float lowSegmentPruneRatio, int maintenanceFreqConstant, - CheckpointWriteWorker checkpointWriteQueue, + ADCheckpointWriteWorker checkpointWriteQueue, Duration stateTtl, NodeStateManager nodeStateManager, - CheckPointMaintainRequestAdapter adapter + Function> converter ) { super( WORKER_NAME, @@ -65,6 +64,7 @@ public CheckpointMaintainWorker( random, adCircuitBreakerService, threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, settings, maxQueuedTaskRatio, clock, @@ -73,7 +73,9 @@ public CheckpointMaintainWorker( maintenanceFreqConstant, checkpointWriteQueue, stateTtl, - nodeStateManager + nodeStateManager, + converter, + AnalysisType.AD ); this.batchSize = AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE.get(settings); @@ -87,18 +89,5 @@ public CheckpointMaintainWorker( AD_EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_MILLISECS, it -> this.expectedExecutionTimeInMilliSecsPerRequest = it ); - this.adapter = adapter; - } - - @Override - protected List transformRequests(List requests) { - List allRequests = new ArrayList<>(); - for (CheckpointMaintainRequest request : requests) { - Optional converted = adapter.convert(request); - if (!converted.isEmpty()) { - allRequests.add(converted.get()); - } - } - return allRequests; } } diff --git a/src/main/java/org/opensearch/ad/ratelimit/ADCheckpointReadWorker.java b/src/main/java/org/opensearch/ad/ratelimit/ADCheckpointReadWorker.java new file mode 100644 index 000000000..40ea61ae6 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/ADCheckpointReadWorker.java @@ -0,0 +1,121 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.util.Random; + +import org.opensearch.ad.caching.ADPriorityCache; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADColdStart; +import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.ad.ml.ThresholdingResult; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.stats.ADStats; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Provider; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.CheckpointReadWorker; +import org.opensearch.timeseries.stats.StatNames; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * a queue for loading model checkpoint. The read is a multi-get query. Possible results are: + * a). If a checkpoint is not found, we forward that request to the cold start queue. + * b). When a request gets errors, the queue does not change its expiry time and puts + * that request to the end of the queue and automatically retries them before they expire. + * c) When a checkpoint is found, we load that point to memory and score the input + * data point and save the result if a complete model exists. Otherwise, we enqueue + * the sample. If we can host that model in memory (e.g., there is enough memory), + * we put the loaded model to cache. Otherwise (e.g., a cold entity), we write the + * updated checkpoint back to disk. + * + */ +public class ADCheckpointReadWorker extends + CheckpointReadWorker { + public static final String WORKER_NAME = "ad-checkpoint-read"; + + public ADCheckpointReadWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + ADModelManager modelManager, + ADCheckpointDao checkpointDao, + ADColdStartWorker entityColdStartQueue, + NodeStateManager stateManager, + ADIndexManagement indexUtil, + Provider cacheProvider, + Duration stateTtl, + ADCheckpointWriteWorker checkpointWriteQueue, + ADStats adStats, + ADSaveResultStrategy resultWriteWorker + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + executionTtl, + modelManager, + checkpointDao, + entityColdStartQueue, + stateManager, + indexUtil, + cacheProvider, + stateTtl, + checkpointWriteQueue, + adStats, + AD_CHECKPOINT_READ_QUEUE_CONCURRENCY, + AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE, + ADCommonName.CHECKPOINT_INDEX_NAME, + StatNames.AD_MODEL_CORRUTPION_COUNT, + AnalysisType.AD, + resultWriteWorker + ); + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ADCheckpointWriteWorker.java b/src/main/java/org/opensearch/ad/ratelimit/ADCheckpointWriteWorker.java new file mode 100644 index 000000000..fd0bc3a66 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/ADCheckpointWriteWorker.java @@ -0,0 +1,88 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.util.Random; + +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public class ADCheckpointWriteWorker extends + CheckpointWriteWorker { + public static final String WORKER_NAME = "ad-checkpoint-write"; + + public ADCheckpointWriteWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + ADCheckpointDao checkpoint, + String indexName, + Duration checkpointInterval, + NodeStateManager adNodeStateManager, + Duration stateTtl + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + AD_CHECKPOINT_WRITE_QUEUE_CONCURRENCY, + executionTtl, + AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE, + stateTtl, + adNodeStateManager, + checkpoint, + indexName, + checkpointInterval, + AnalysisType.AD + ); + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ColdEntityWorker.java b/src/main/java/org/opensearch/ad/ratelimit/ADColdEntityWorker.java similarity index 68% rename from src/main/java/org/opensearch/ad/ratelimit/ColdEntityWorker.java rename to src/main/java/org/opensearch/ad/ratelimit/ADColdEntityWorker.java index 701fc25d4..0abd7527d 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/ColdEntityWorker.java +++ b/src/main/java/org/opensearch/ad/ratelimit/ADColdEntityWorker.java @@ -16,17 +16,27 @@ import java.time.Clock; import java.time.Duration; -import java.util.List; import java.util.Random; -import java.util.stream.Collectors; -import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.caching.ADPriorityCache; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADColdStart; +import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.ad.ml.ThresholdingResult; +import org.opensearch.ad.model.AnomalyResult; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.ColdEntityWorker; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; /** * A queue slowly releasing low-priority requests to CheckpointReadQueue @@ -43,10 +53,11 @@ * entity requests.  * */ -public class ColdEntityWorker extends ScheduledWorker { - public static final String WORKER_NAME = "cold-entity"; +public class ADColdEntityWorker extends + ColdEntityWorker { + public static final String WORKER_NAME = "ad-cold-entity"; - public ColdEntityWorker( + public ADColdEntityWorker( long heapSizeInBytes, int singleRequestSizeInBytes, Setting maxHeapPercentForQueueSetting, @@ -60,7 +71,7 @@ public ColdEntityWorker( float mediumSegmentPruneRatio, float lowSegmentPruneRatio, int maintenanceFreqConstant, - CheckpointReadWorker checkpointReadQueue, + ADCheckpointReadWorker checkpointReadQueue, Duration stateTtl, NodeStateManager nodeStateManager ) { @@ -73,6 +84,7 @@ public ColdEntityWorker( random, adCircuitBreakerService, threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, settings, maxQueuedTaskRatio, clock, @@ -81,25 +93,10 @@ public ColdEntityWorker( maintenanceFreqConstant, checkpointReadQueue, stateTtl, - nodeStateManager + nodeStateManager, + AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE, + AD_EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS, + AnalysisType.AD ); - - this.batchSize = AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE, it -> this.batchSize = it); - - this.expectedExecutionTimeInMilliSecsPerRequest = AnomalyDetectorSettings.AD_EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS - .get(settings); - clusterService - .getClusterSettings() - .addSettingsUpdateConsumer( - AD_EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS, - it -> this.expectedExecutionTimeInMilliSecsPerRequest = it - ); - } - - @Override - protected List transformRequests(List requests) { - // guarantee we only send low priority requests - return requests.stream().filter(request -> request.priority == RequestPriority.LOW).collect(Collectors.toList()); } } diff --git a/src/main/java/org/opensearch/ad/ratelimit/ADColdStartWorker.java b/src/main/java/org/opensearch/ad/ratelimit/ADColdStartWorker.java new file mode 100644 index 000000000..09cd82a4a --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/ADColdStartWorker.java @@ -0,0 +1,146 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_ENTITY_COLD_START_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayDeque; +import java.util.Optional; +import java.util.Random; + +import org.opensearch.ad.caching.ADPriorityCache; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADColdStart; +import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.ad.ml.ThresholdingResult; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.ratelimit.ColdStartWorker; +import org.opensearch.timeseries.ratelimit.FeatureRequest; +import org.opensearch.timeseries.util.ParseUtils; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * A queue for HCAD model training (a.k.a. cold start). As model training is a + * pretty expensive operation, we pull cold start requests from the queue in a + * serial fashion. Each detector has an equal chance of being pulled. The equal + * probability is achieved by putting model training requests for different + * detectors into different segments and pulling requests from segments in a + * round-robin fashion. + * + */ + +// suppress warning due to the use of generic type ModelState +public class ADColdStartWorker extends + ColdStartWorker { + public static final String WORKER_NAME = "ad-cold-start"; + + public ADColdStartWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + ADColdStart entityColdStarter, + Duration stateTtl, + NodeStateManager nodeStateManager, + ADPriorityCache cacheProvider, + ADModelManager modelManager, + ADSaveResultStrategy saveStrategy + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + AD_ENTITY_COLD_START_QUEUE_CONCURRENCY, + executionTtl, + entityColdStarter, + stateTtl, + nodeStateManager, + cacheProvider, + AnalysisType.AD, + modelManager, + saveStrategy + ); + } + + @Override + protected ModelState createEmptyState(FeatureRequest request, String modelId, String configId) { + return new ModelState( + null, + modelId, + configId, + ModelManager.ModelType.RCFCASTER.getName(), + clock, + 0, + request.getEntity(), + new ArrayDeque<>() + ); + } + + @Override + protected AnomalyResult createIndexableResult(Config config, String taskId, String modelId, Sample entry, Optional entity) { + return new AnomalyResult( + config.getId(), + taskId, + ParseUtils.getFeatureData(entry.getValueList(), config), + entry.getDataStartTime(), + entry.getDataEndTime(), + Instant.now(), + Instant.now(), + "", + entity, + config.getUser(), + config.getSchemaVersion(), + modelId + ); + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ADResultWriteRequest.java b/src/main/java/org/opensearch/ad/ratelimit/ADResultWriteRequest.java new file mode 100644 index 000000000..912396ebd --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/ADResultWriteRequest.java @@ -0,0 +1,36 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import java.io.IOException; + +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.ratelimit.ResultWriteRequest; + +public class ADResultWriteRequest extends ResultWriteRequest { + + public ADResultWriteRequest( + long expirationEpochMs, + String detectorId, + RequestPriority priority, + AnomalyResult result, + String resultIndex + ) { + super(expirationEpochMs, detectorId, priority, result, resultIndex); + } + + public ADResultWriteRequest(StreamInput in) throws IOException { + super(in, AnomalyResult::new); + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ADResultWriteWorker.java b/src/main/java/org/opensearch/ad/ratelimit/ADResultWriteWorker.java new file mode 100644 index 000000000..b57e99f1c --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/ADResultWriteWorker.java @@ -0,0 +1,110 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_BATCH_SIZE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.util.List; +import java.util.Random; + +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.transport.ADResultBulkRequest; +import org.opensearch.ad.transport.handler.ADIndexMemoryPressureAwareResultHandler; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.ratelimit.ResultWriteWorker; + +public class ADResultWriteWorker extends + ResultWriteWorker { + public static final String WORKER_NAME = "ad-result-write"; + + public ADResultWriteWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + ADIndexMemoryPressureAwareResultHandler resultHandler, + NamedXContentRegistry xContentRegistry, + NodeStateManager stateManager, + Duration stateTtl + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + AD_RESULT_WRITE_QUEUE_CONCURRENCY, + executionTtl, + AD_RESULT_WRITE_QUEUE_BATCH_SIZE, + stateTtl, + stateManager, + resultHandler, + xContentRegistry, + AnomalyResult::parse, + AnalysisType.AD + ); + } + + @Override + protected ADResultBulkRequest toBatchRequest(List toProcess) { + final ADResultBulkRequest bulkRequest = new ADResultBulkRequest(); + for (ADResultWriteRequest request : toProcess) { + bulkRequest.add(request); + } + return bulkRequest; + } + + @Override + protected ADResultWriteRequest createResultWriteRequest( + long expirationEpochMs, + String configId, + RequestPriority priority, + AnomalyResult result, + String resultIndex + ) { + return new ADResultWriteRequest(expirationEpochMs, configId, priority, result, resultIndex); + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ADSaveResultStrategy.java b/src/main/java/org/opensearch/ad/ratelimit/ADSaveResultStrategy.java new file mode 100644 index 000000000..aeb265072 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/ADSaveResultStrategy.java @@ -0,0 +1,95 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.ratelimit; + +import java.time.Instant; +import java.util.List; +import java.util.Optional; + +import org.opensearch.ad.ml.ThresholdingResult; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.ratelimit.FeatureRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.ratelimit.SaveResultStrategy; +import org.opensearch.timeseries.util.ParseUtils; + +public class ADSaveResultStrategy implements SaveResultStrategy { + private int resultMappingVersion; + private ADResultWriteWorker resultWriteWorker; + + public ADSaveResultStrategy(int resultMappingVersion, ADResultWriteWorker resultWriteWorker) { + this.resultMappingVersion = resultMappingVersion; + this.resultWriteWorker = resultWriteWorker; + } + + @Override + public void saveResult(ThresholdingResult result, Config config, FeatureRequest origRequest, String modelId) { + // result.getRcfScore() = 0 means the model is not initialized + // result.getGrade() = 0 means it is not an anomaly + saveResult( + result, + config, + Instant.ofEpochMilli(origRequest.getDataStartTimeMillis()), + Instant.ofEpochMilli(origRequest.getDataStartTimeMillis() + config.getIntervalInMilliseconds()), + modelId, + origRequest.getCurrentFeature(), + origRequest.getEntity(), + origRequest.getTaskId() + ); + } + + @Override + public void saveResult( + ThresholdingResult result, + Config config, + Instant dataStart, + Instant dataEnd, + String modelId, + double[] currentData, + Optional entity, + String taskId + ) { + // result.getRcfScore() = 0 means the model is not initialized + // result.getGrade() = 0 means it is not an anomaly + if (result != null && result.getRcfScore() > 0) { + List indexableResults = result + .toIndexableResults( + config, + dataStart, + dataEnd, + Instant.now(), + Instant.now(), + ParseUtils.getFeatureData(currentData, config), + entity, + resultMappingVersion, + modelId, + taskId, + null + ); + + for (AnomalyResult r : indexableResults) { + saveResult(r, config); + } + } + } + + @Override + public void saveResult(AnomalyResult result, Config config) { + resultWriteWorker + .put( + new ADResultWriteRequest( + System.currentTimeMillis() + config.getIntervalInMilliseconds(), + config.getId(), + result.getAnomalyGrade() > 0 ? RequestPriority.HIGH : RequestPriority.MEDIUM, + result, + config.getCustomResultIndex() + ) + ); + } + +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/EntityColdStartWorker.java b/src/main/java/org/opensearch/ad/ratelimit/EntityColdStartWorker.java deleted file mode 100644 index 72011e156..000000000 --- a/src/main/java/org/opensearch/ad/ratelimit/EntityColdStartWorker.java +++ /dev/null @@ -1,162 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.ratelimit; - -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_ENTITY_COLD_START_QUEUE_CONCURRENCY; - -import java.time.Clock; -import java.time.Duration; -import java.util.ArrayDeque; -import java.util.Locale; -import java.util.Optional; -import java.util.Random; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.message.ParameterizedMessage; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.ml.EntityColdStarter; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelManager.ModelType; -import org.opensearch.ad.ml.ModelState; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.Setting; -import org.opensearch.common.settings.Settings; -import org.opensearch.core.action.ActionListener; -import org.opensearch.threadpool.ThreadPool; -import org.opensearch.timeseries.AnalysisType; -import org.opensearch.timeseries.NodeStateManager; -import org.opensearch.timeseries.breaker.CircuitBreakerService; -import org.opensearch.timeseries.util.ExceptionUtil; - -/** - * A queue for HCAD model training (a.k.a. cold start). As model training is a - * pretty expensive operation, we pull cold start requests from the queue in a - * serial fashion. Each detector has an equal chance of being pulled. The equal - * probability is achieved by putting model training requests for different - * detectors into different segments and pulling requests from segments in a - * round-robin fashion. - * - */ -public class EntityColdStartWorker extends SingleRequestWorker { - private static final Logger LOG = LogManager.getLogger(EntityColdStartWorker.class); - public static final String WORKER_NAME = "cold-start"; - - private final EntityColdStarter entityColdStarter; - private final CacheProvider cacheProvider; - - public EntityColdStartWorker( - long heapSizeInBytes, - int singleRequestSizeInBytes, - Setting maxHeapPercentForQueueSetting, - ClusterService clusterService, - Random random, - CircuitBreakerService adCircuitBreakerService, - ThreadPool threadPool, - Settings settings, - float maxQueuedTaskRatio, - Clock clock, - float mediumSegmentPruneRatio, - float lowSegmentPruneRatio, - int maintenanceFreqConstant, - Duration executionTtl, - EntityColdStarter entityColdStarter, - Duration stateTtl, - NodeStateManager nodeStateManager, - CacheProvider cacheProvider - ) { - super( - WORKER_NAME, - heapSizeInBytes, - singleRequestSizeInBytes, - maxHeapPercentForQueueSetting, - clusterService, - random, - adCircuitBreakerService, - threadPool, - settings, - maxQueuedTaskRatio, - clock, - mediumSegmentPruneRatio, - lowSegmentPruneRatio, - maintenanceFreqConstant, - AD_ENTITY_COLD_START_QUEUE_CONCURRENCY, - executionTtl, - stateTtl, - nodeStateManager - ); - this.entityColdStarter = entityColdStarter; - this.cacheProvider = cacheProvider; - } - - @Override - protected void executeRequest(EntityRequest coldStartRequest, ActionListener listener) { - String detectorId = coldStartRequest.getId(); - - Optional modelId = coldStartRequest.getModelId(); - - if (false == modelId.isPresent()) { - String error = String.format(Locale.ROOT, "Fail to get model id for request %s", coldStartRequest); - LOG.warn(error); - listener.onFailure(new RuntimeException(error)); - return; - } - - ModelState modelState = new ModelState<>( - new EntityModel(coldStartRequest.getEntity(), new ArrayDeque<>(), null), - modelId.get(), - detectorId, - ModelType.ENTITY.getName(), - clock, - 0 - ); - - ActionListener coldStartListener = ActionListener.wrap(r -> { - nodeStateManager.getConfig(detectorId, AnalysisType.AD, ActionListener.wrap(detectorOptional -> { - try { - if (!detectorOptional.isPresent()) { - LOG - .error( - new ParameterizedMessage( - "fail to load trained model [{}] to cache due to the detector not being found.", - modelState.getModelId() - ) - ); - return; - } - AnomalyDetector detector = (AnomalyDetector) detectorOptional.get(); - EntityModel model = modelState.getModel(); - // load to cache if cold start succeeds - if (model != null && model.getTrcf() != null) { - cacheProvider.get().hostIfPossible(detector, modelState); - } - } finally { - listener.onResponse(null); - } - }, listener::onFailure)); - - }, e -> { - try { - if (ExceptionUtil.isOverloaded(e)) { - LOG.error("OpenSearch is overloaded"); - setCoolDownStart(); - } - nodeStateManager.setException(detectorId, e); - } finally { - listener.onFailure(e); - } - }); - - entityColdStarter.trainModel(coldStartRequest.getEntity(), detectorId, modelState, coldStartListener); - } -} diff --git a/src/main/java/org/opensearch/ad/ratelimit/EntityFeatureRequest.java b/src/main/java/org/opensearch/ad/ratelimit/EntityFeatureRequest.java deleted file mode 100644 index 875974dbb..000000000 --- a/src/main/java/org/opensearch/ad/ratelimit/EntityFeatureRequest.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.ratelimit; - -import org.opensearch.timeseries.model.Entity; - -public class EntityFeatureRequest extends EntityRequest { - private final double[] currentFeature; - private final long dataStartTimeMillis; - - public EntityFeatureRequest( - long expirationEpochMs, - String detectorId, - RequestPriority priority, - Entity entity, - double[] currentFeature, - long dataStartTimeMs - ) { - super(expirationEpochMs, detectorId, priority, entity); - this.currentFeature = currentFeature; - this.dataStartTimeMillis = dataStartTimeMs; - } - - public double[] getCurrentFeature() { - return currentFeature; - } - - public long getDataStartTimeMillis() { - return dataStartTimeMillis; - } -} diff --git a/src/main/java/org/opensearch/ad/ratelimit/EntityRequest.java b/src/main/java/org/opensearch/ad/ratelimit/EntityRequest.java deleted file mode 100644 index 7acf2652a..000000000 --- a/src/main/java/org/opensearch/ad/ratelimit/EntityRequest.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.ratelimit; - -import java.util.Optional; - -import org.opensearch.timeseries.model.Entity; - -public class EntityRequest extends QueuedRequest { - private final Entity entity; - - /** - * - * @param expirationEpochMs Expiry time of the request - * @param detectorId Detector Id - * @param priority the entity's priority - * @param entity the entity's attributes - */ - public EntityRequest(long expirationEpochMs, String detectorId, RequestPriority priority, Entity entity) { - super(expirationEpochMs, detectorId, priority); - this.entity = entity; - } - - public Entity getEntity() { - return entity; - } - - public Optional getModelId() { - return entity.getModelId(detectorId); - } -} diff --git a/src/main/java/org/opensearch/ad/rest/AbstractADSearchAction.java b/src/main/java/org/opensearch/ad/rest/AbstractADSearchAction.java new file mode 100644 index 000000000..ef901f40c --- /dev/null +++ b/src/main/java/org/opensearch/ad/rest/AbstractADSearchAction.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.rest; + +import java.util.List; + +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.action.ActionType; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.settings.ADEnabledSetting; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.timeseries.AbstractSearchAction; + +public abstract class AbstractADSearchAction extends AbstractSearchAction { + + public AbstractADSearchAction( + List urlPaths, + List> deprecatedPaths, + String index, + Class clazz, + ActionType actionType + ) { + super(urlPaths, deprecatedPaths, index, clazz, actionType, ADEnabledSetting::isADEnabled, ADCommonMessages.DISABLED_ERR_MSG); + } +} diff --git a/src/main/java/org/opensearch/ad/rest/AbstractAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/AbstractAnomalyDetectorAction.java index ee0d410f5..4a10b3ad9 100644 --- a/src/main/java/org/opensearch/ad/rest/AbstractAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/AbstractAnomalyDetectorAction.java @@ -18,6 +18,7 @@ import static org.opensearch.ad.settings.AnomalyDetectorSettings.DETECTION_WINDOW_DELAY; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_ANOMALY_FEATURES; +import org.opensearch.ad.settings.ADNumericSetting; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; @@ -31,6 +32,7 @@ public abstract class AbstractAnomalyDetectorAction extends BaseRestHandler { protected volatile Integer maxSingleEntityDetectors; protected volatile Integer maxMultiEntityDetectors; protected volatile Integer maxAnomalyFeatures; + protected volatile Integer maxCategoricalFields; public AbstractAnomalyDetectorAction(Settings settings, ClusterService clusterService) { this.requestTimeout = AD_REQUEST_TIMEOUT.get(settings); @@ -39,6 +41,7 @@ public AbstractAnomalyDetectorAction(Settings settings, ClusterService clusterSe this.maxSingleEntityDetectors = AD_MAX_SINGLE_ENTITY_ANOMALY_DETECTORS.get(settings); this.maxMultiEntityDetectors = AD_MAX_HC_ANOMALY_DETECTORS.get(settings); this.maxAnomalyFeatures = MAX_ANOMALY_FEATURES.get(settings); + this.maxCategoricalFields = ADNumericSetting.maxCategoricalFields(); // TODO: will add more cluster setting consumer later // TODO: inject ClusterSettings only if clusterService is only used to get ClusterSettings clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_REQUEST_TIMEOUT, it -> requestTimeout = it); diff --git a/src/main/java/org/opensearch/ad/rest/RestAnomalyDetectorJobAction.java b/src/main/java/org/opensearch/ad/rest/RestAnomalyDetectorJobAction.java index 175ac02e7..14ef4c652 100644 --- a/src/main/java/org/opensearch/ad/rest/RestAnomalyDetectorJobAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestAnomalyDetectorJobAction.java @@ -12,10 +12,7 @@ package org.opensearch.ad.rest; import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_REQUEST_TIMEOUT; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.timeseries.util.RestHandlerUtils.DETECTOR_ID; -import static org.opensearch.timeseries.util.RestHandlerUtils.IF_PRIMARY_TERM; -import static org.opensearch.timeseries.util.RestHandlerUtils.IF_SEQ_NO; import static org.opensearch.timeseries.util.RestHandlerUtils.START_JOB; import static org.opensearch.timeseries.util.RestHandlerUtils.STOP_JOB; @@ -26,25 +23,23 @@ import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.transport.AnomalyDetectorJobAction; -import org.opensearch.ad.transport.AnomalyDetectorJobRequest; import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.index.seqno.SequenceNumbers; -import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.rest.RestJobAction; +import org.opensearch.timeseries.transport.JobRequest; import com.google.common.collect.ImmutableList; /** * This class consists of the REST handler to handle request to start/stop AD job. */ -public class RestAnomalyDetectorJobAction extends BaseRestHandler { +public class RestAnomalyDetectorJobAction extends RestJobAction { public static final String AD_JOB_ACTION = "anomaly_detector_job_action"; private volatile TimeValue requestTimeout; @@ -66,40 +61,16 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } String detectorId = request.param(DETECTOR_ID); - long seqNo = request.paramAsLong(IF_SEQ_NO, SequenceNumbers.UNASSIGNED_SEQ_NO); - long primaryTerm = request.paramAsLong(IF_PRIMARY_TERM, SequenceNumbers.UNASSIGNED_PRIMARY_TERM); boolean historical = request.paramAsBoolean("historical", false); String rawPath = request.rawPath(); - DateRange detectionDateRange = parseDetectionDateRange(request); + DateRange detectionDateRange = parseInputDateRange(request); - AnomalyDetectorJobRequest anomalyDetectorJobRequest = new AnomalyDetectorJobRequest( - detectorId, - detectionDateRange, - historical, - seqNo, - primaryTerm, - rawPath - ); + JobRequest anomalyDetectorJobRequest = new JobRequest(detectorId, detectionDateRange, historical, rawPath); return channel -> client .execute(AnomalyDetectorJobAction.INSTANCE, anomalyDetectorJobRequest, new RestToXContentListener<>(channel)); } - private DateRange parseDetectionDateRange(RestRequest request) throws IOException { - if (!request.hasContent()) { - return null; - } - XContentParser parser = request.contentParser(); - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - DateRange dateRange = DateRange.parse(parser); - return dateRange; - } - - @Override - public List routes() { - return ImmutableList.of(); - } - @Override public List replacedRoutes() { return ImmutableList diff --git a/src/main/java/org/opensearch/ad/rest/RestDeleteAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestDeleteAnomalyDetectorAction.java index b7a3aae6c..1ad4f0a9a 100644 --- a/src/main/java/org/opensearch/ad/rest/RestDeleteAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestDeleteAnomalyDetectorAction.java @@ -17,18 +17,15 @@ import java.util.List; import java.util.Locale; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.rest.handler.AnomalyDetectorActionHandler; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.transport.DeleteAnomalyDetectorAction; -import org.opensearch.ad.transport.DeleteAnomalyDetectorRequest; import org.opensearch.client.node.NodeClient; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.transport.DeleteConfigRequest; import com.google.common.collect.ImmutableList; @@ -39,9 +36,6 @@ public class RestDeleteAnomalyDetectorAction extends BaseRestHandler { public static final String DELETE_ANOMALY_DETECTOR_ACTION = "delete_anomaly_detector"; - private static final Logger logger = LogManager.getLogger(RestDeleteAnomalyDetectorAction.class); - private final AnomalyDetectorActionHandler handler = new AnomalyDetectorActionHandler(); - public RestDeleteAnomalyDetectorAction() {} @Override @@ -56,7 +50,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } String detectorId = request.param(DETECTOR_ID); - DeleteAnomalyDetectorRequest deleteAnomalyDetectorRequest = new DeleteAnomalyDetectorRequest(detectorId); + DeleteConfigRequest deleteAnomalyDetectorRequest = new DeleteConfigRequest(detectorId); return channel -> client .execute(DeleteAnomalyDetectorAction.INSTANCE, deleteAnomalyDetectorRequest, new RestToXContentListener<>(channel)); } diff --git a/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java index 315ba0410..8ef0bb473 100644 --- a/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java @@ -18,24 +18,20 @@ import java.io.IOException; import java.util.List; import java.util.Locale; -import java.util.Optional; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.transport.GetAnomalyDetectorAction; -import org.opensearch.ad.transport.GetAnomalyDetectorRequest; import org.opensearch.client.node.NodeClient; -import org.opensearch.core.common.Strings; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestActions; import org.opensearch.rest.action.RestToXContentListener; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; -import org.opensearch.timeseries.constant.CommonName; -import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.transport.GetConfigRequest; +import org.opensearch.timeseries.util.RestHandlerUtils; import com.google.common.collect.ImmutableList; @@ -66,7 +62,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli boolean returnJob = request.paramAsBoolean("job", false); boolean returnTask = request.paramAsBoolean("task", false); boolean all = request.paramAsBoolean("_all", false); - GetAnomalyDetectorRequest getConfigRequest = new GetAnomalyDetectorRequest( + GetConfigRequest getConfigRequest = new GetConfigRequest( detectorId, RestActions.parseVersion(request), returnJob, @@ -74,7 +70,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli typesStr, rawPath, all, - buildEntity(request, detectorId) + RestHandlerUtils.buildEntity(request, detectorId) ); return channel -> client.execute(GetAnomalyDetectorAction.INSTANCE, getConfigRequest, new RestToXContentListener<>(channel)); @@ -137,35 +133,4 @@ public List replacedRoutes() { ) ); } - - private Entity buildEntity(RestRequest request, String detectorId) throws IOException { - if (Strings.isEmpty(detectorId)) { - throw new IllegalStateException(ADCommonMessages.AD_ID_MISSING_MSG); - } - - String entityName = request.param(ADCommonName.CATEGORICAL_FIELD); - String entityValue = request.param(CommonName.ENTITY_KEY); - - if (entityName != null && entityValue != null) { - // single-stream profile request: - // GET _plugins/_anomaly_detection/detectors//_profile/init_progress?category_field=&entity= - return Entity.createSingleAttributeEntity(entityName, entityValue); - } else if (request.hasContent()) { - /* HCAD profile request: - * GET _plugins/_anomaly_detection/detectors//_profile/init_progress - * { - * "entity": [{ - * "name": "clientip", - * "value": "13.24.0.0" - * }] - * } - */ - Optional entity = Entity.fromJsonObject(request.contentParser()); - if (entity.isPresent()) { - return entity.get(); - } - } - // not a valid profile request with correct entity information - return null; - } } diff --git a/src/main/java/org/opensearch/ad/rest/RestIndexAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestIndexAnomalyDetectorAction.java index 6231d8e11..66981d54c 100644 --- a/src/main/java/org/opensearch/ad/rest/RestIndexAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestIndexAnomalyDetectorAction.java @@ -94,7 +94,8 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli requestTimeout, maxSingleEntityDetectors, maxMultiEntityDetectors, - maxAnomalyFeatures + maxAnomalyFeatures, + maxCategoricalFields ); return channel -> client diff --git a/src/main/java/org/opensearch/ad/rest/RestSearchADTasksAction.java b/src/main/java/org/opensearch/ad/rest/RestSearchADTasksAction.java index 6a1bfce58..a858d46aa 100644 --- a/src/main/java/org/opensearch/ad/rest/RestSearchADTasksAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestSearchADTasksAction.java @@ -22,7 +22,7 @@ /** * This class consists of the REST handler to search AD tasks. */ -public class RestSearchADTasksAction extends AbstractSearchAction { +public class RestSearchADTasksAction extends AbstractADSearchAction { private static final String LEGACY_URL_PATH = TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI + "/tasks/_search"; private static final String URL_PATH = TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI + "/tasks/_search"; diff --git a/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorAction.java index 214fa8b2c..a5c1551e7 100644 --- a/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorAction.java @@ -22,7 +22,7 @@ /** * This class consists of the REST handler to search anomaly detectors. */ -public class RestSearchAnomalyDetectorAction extends AbstractSearchAction { +public class RestSearchAnomalyDetectorAction extends AbstractADSearchAction { private static final String LEGACY_URL_PATH = TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI + "/_search"; private static final String URL_PATH = TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI + "/_search"; diff --git a/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorInfoAction.java b/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorInfoAction.java index 1f2ade113..0b7f748c7 100644 --- a/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorInfoAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorInfoAction.java @@ -23,12 +23,12 @@ import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.transport.SearchAnomalyDetectorInfoAction; -import org.opensearch.ad.transport.SearchAnomalyDetectorInfoRequest; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.transport.SearchConfigInfoRequest; import com.google.common.collect.ImmutableList; @@ -54,7 +54,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, org.opensearch String detectorName = request.param("name", null); String rawPath = request.rawPath(); - SearchAnomalyDetectorInfoRequest searchAnomalyDetectorInfoRequest = new SearchAnomalyDetectorInfoRequest(detectorName, rawPath); + SearchConfigInfoRequest searchAnomalyDetectorInfoRequest = new SearchConfigInfoRequest(detectorName, rawPath); return channel -> client .execute(SearchAnomalyDetectorInfoAction.INSTANCE, searchAnomalyDetectorInfoRequest, new RestToXContentListener<>(channel)); } diff --git a/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyResultAction.java b/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyResultAction.java index 9db521595..b014ca753 100644 --- a/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyResultAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyResultAction.java @@ -35,7 +35,7 @@ /** * This class consists of the REST handler to search anomaly results. */ -public class RestSearchAnomalyResultAction extends AbstractSearchAction { +public class RestSearchAnomalyResultAction extends AbstractADSearchAction { private static final String LEGACY_URL_PATH = TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI + "/results/_search"; private static final String URL_PATH = TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI + "/results/_search"; public static final String SEARCH_ANOMALY_RESULT_ACTION = "search_anomaly_result"; diff --git a/src/main/java/org/opensearch/ad/rest/RestStatsAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestStatsAnomalyDetectorAction.java index 65b936e98..ddceab44a 100644 --- a/src/main/java/org/opensearch/ad/rest/RestStatsAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestStatsAnomalyDetectorAction.java @@ -14,47 +14,36 @@ import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.AD_BASE_URI; import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.LEGACY_AD_BASE; -import java.util.Arrays; -import java.util.HashSet; import java.util.List; -import java.util.Set; -import java.util.TreeSet; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.stats.ADStats; -import org.opensearch.ad.transport.ADStatsRequest; import org.opensearch.ad.transport.StatsAnomalyDetectorAction; import org.opensearch.client.node.NodeClient; -import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.core.common.Strings; -import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.rest.RestStatsAction; +import org.opensearch.timeseries.transport.StatsRequest; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import com.google.common.collect.ImmutableList; /** - * RestStatsAnomalyDetectorAction consists of the REST handler to get the stats from the anomaly detector plugin. + * RestStatsAnomalyDetectorAction consists of the REST handler to get the stats from AD. */ -public class RestStatsAnomalyDetectorAction extends BaseRestHandler { +public class RestStatsAnomalyDetectorAction extends RestStatsAction { private static final String STATS_ANOMALY_DETECTOR_ACTION = "stats_anomaly_detector"; - private ADStats adStats; - private ClusterService clusterService; - private DiscoveryNodeFilterer nodeFilter; /** * Constructor * - * @param adStats ADStats object + * @param timeSeriesStats TimeSeriesStats object * @param nodeFilter util class to get eligible data nodes */ - public RestStatsAnomalyDetectorAction(ADStats adStats, DiscoveryNodeFilterer nodeFilter) { - this.adStats = adStats; - this.nodeFilter = nodeFilter; + public RestStatsAnomalyDetectorAction(ADStats timeSeriesStats, DiscoveryNodeFilterer nodeFilter) { + super(timeSeriesStats, nodeFilter); } @Override @@ -67,64 +56,10 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli if (!ADEnabledSetting.isADEnabled()) { throw new IllegalStateException(ADCommonMessages.DISABLED_ERR_MSG); } - ADStatsRequest adStatsRequest = getRequest(request); + StatsRequest adStatsRequest = getRequest(request); return channel -> client.execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest, new RestToXContentListener<>(channel)); } - /** - * Creates a ADStatsRequest from a RestRequest - * - * @param request RestRequest - * @return ADStatsRequest Request containing stats to be retrieved - */ - private ADStatsRequest getRequest(RestRequest request) { - // parse the nodes the user wants to query the stats for - String nodesIdsStr = request.param("nodeId"); - Set validStats = adStats.getStats().keySet(); - - ADStatsRequest adStatsRequest = null; - if (!Strings.isEmpty(nodesIdsStr)) { - String[] nodeIdsArr = nodesIdsStr.split(","); - adStatsRequest = new ADStatsRequest(nodeIdsArr); - } else { - DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes(); - adStatsRequest = new ADStatsRequest(dataNodes); - } - - adStatsRequest.timeout(request.param("timeout")); - - // parse the stats the user wants to see - HashSet statsSet = null; - String statsStr = request.param("stat"); - if (!Strings.isEmpty(statsStr)) { - statsSet = new HashSet<>(Arrays.asList(statsStr.split(","))); - } - - if (statsSet == null) { - adStatsRequest.addAll(validStats); // retrieve all stats if none are specified - } else if (statsSet.size() == 1 && statsSet.contains(ADStatsRequest.ALL_STATS_KEY)) { - adStatsRequest.addAll(validStats); - } else if (statsSet.contains(ADStatsRequest.ALL_STATS_KEY)) { - throw new IllegalArgumentException( - "Request " + request.path() + " contains " + ADStatsRequest.ALL_STATS_KEY + " and individual stats" - ); - } else { - Set invalidStats = new TreeSet<>(); - for (String stat : statsSet) { - if (validStats.contains(stat)) { - adStatsRequest.addStat(stat); - } else { - invalidStats.add(stat); - } - } - - if (!invalidStats.isEmpty()) { - throw new IllegalArgumentException(unrecognized(request, invalidStats, adStatsRequest.getStatsToBeRetrieved(), "stat")); - } - } - return adStatsRequest; - } - @Override public List routes() { return ImmutableList.of(); diff --git a/src/main/java/org/opensearch/ad/rest/RestValidateAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestValidateAnomalyDetectorAction.java index e728889f8..91d72dcf9 100644 --- a/src/main/java/org/opensearch/ad/rest/RestValidateAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestValidateAnomalyDetectorAction.java @@ -16,35 +16,25 @@ import static org.opensearch.timeseries.util.RestHandlerUtils.VALIDATE; import java.io.IOException; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashSet; import java.util.List; import java.util.Locale; -import java.util.Set; -import java.util.stream.Collectors; -import org.apache.commons.lang3.StringUtils; import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.DetectorValidationIssue; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.transport.ValidateAnomalyDetectorAction; -import org.opensearch.ad.transport.ValidateAnomalyDetectorRequest; -import org.opensearch.ad.transport.ValidateAnomalyDetectorResponse; import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; -import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.rest.BaseRestHandler; -import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.common.exception.ValidationException; -import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.model.ConfigValidationIssue; +import org.opensearch.timeseries.rest.RestValidateAction; +import org.opensearch.timeseries.transport.ValidateConfigRequest; import com.google.common.collect.ImmutableList; @@ -54,14 +44,18 @@ public class RestValidateAnomalyDetectorAction extends AbstractAnomalyDetectorAction { private static final String VALIDATE_ANOMALY_DETECTOR_ACTION = "validate_anomaly_detector_action"; - public static final Set ALL_VALIDATION_ASPECTS_STRS = Arrays - .asList(ValidationAspect.values()) - .stream() - .map(aspect -> aspect.getName()) - .collect(Collectors.toSet()); + private RestValidateAction validateAction; public RestValidateAnomalyDetectorAction(Settings settings, ClusterService clusterService) { super(settings, clusterService); + this.validateAction = new RestValidateAction( + AnalysisType.FORECAST, + maxSingleEntityDetectors, + maxMultiEntityDetectors, + maxAnomalyFeatures, + maxCategoricalFields, + requestTimeout + ); } @Override @@ -84,66 +78,35 @@ public List routes() { ); } - protected void sendAnomalyDetectorValidationParseResponse(DetectorValidationIssue issue, RestChannel channel) throws IOException { - try { - BytesRestResponse restResponse = new BytesRestResponse( - RestStatus.OK, - new ValidateAnomalyDetectorResponse(issue).toXContent(channel.newBuilder()) - ); - channel.sendResponse(restResponse); - } catch (Exception e) { - channel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, e.getMessage())); - } - } - - private Boolean validationTypesAreAccepted(String validationType) { - Set typesInRequest = new HashSet<>(Arrays.asList(validationType.split(","))); - return (!Collections.disjoint(typesInRequest, ALL_VALIDATION_ASPECTS_STRS)); - } - @Override protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { if (!ADEnabledSetting.isADEnabled()) { throw new IllegalStateException(ADCommonMessages.DISABLED_ERR_MSG); } + XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + // we have to get the param from a subclass of BaseRestHandler. Otherwise, we cannot parse the type out of request params String typesStr = request.param(TYPE); - // if type param isn't blank and isn't a part of possible validation types throws exception - if (!StringUtils.isBlank(typesStr)) { - if (!validationTypesAreAccepted(typesStr)) { - throw new IllegalStateException(ADCommonMessages.NOT_EXISTENT_VALIDATION_TYPE); - } - } - return channel -> { - AnomalyDetector detector; try { - detector = AnomalyDetector.parse(parser); + ValidateConfigRequest validateAnomalyDetectorRequest = validateAction.prepareRequest(request, client, typesStr); + client + .execute(ValidateAnomalyDetectorAction.INSTANCE, validateAnomalyDetectorRequest, new RestToXContentListener<>(channel)); } catch (Exception ex) { if (ex instanceof ValidationException) { - ValidationException ADException = (ValidationException) ex; - DetectorValidationIssue issue = new DetectorValidationIssue( - ADException.getAspect(), - ADException.getType(), - ADException.getMessage() + ValidationException adException = (ValidationException) ex; + ConfigValidationIssue issue = new ConfigValidationIssue( + adException.getAspect(), + adException.getType(), + adException.getMessage() ); - sendAnomalyDetectorValidationParseResponse(issue, channel); - return; + validateAction.sendValidationParseResponse(issue, channel); } else { throw ex; } } - ValidateAnomalyDetectorRequest validateAnomalyDetectorRequest = new ValidateAnomalyDetectorRequest( - detector, - typesStr, - maxSingleEntityDetectors, - maxMultiEntityDetectors, - maxAnomalyFeatures, - requestTimeout - ); - client.execute(ValidateAnomalyDetectorAction.INSTANCE, validateAnomalyDetectorRequest, new RestToXContentListener<>(channel)); }; } } diff --git a/src/main/java/org/opensearch/ad/rest/handler/ADIndexJobActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/ADIndexJobActionHandler.java new file mode 100644 index 000000000..fc0eb5cfe --- /dev/null +++ b/src/main/java/org/opensearch/ad/rest/handler/ADIndexJobActionHandler.java @@ -0,0 +1,122 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.rest.handler; + +import static org.opensearch.ad.constant.ADCommonName.DETECTION_STATE_INDEX; +import static org.opensearch.ad.model.ADTaskType.HISTORICAL_DETECTOR_TASK_TYPES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_REQUEST_TIMEOUT; + +import java.util.List; + +import org.opensearch.OpenSearchStatusException; +import org.opensearch.ad.ExecuteADResultResponseRecorder; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskType; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.task.ADTaskCacheManager; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.ADProfileAction; +import org.opensearch.ad.transport.AnomalyResultAction; +import org.opensearch.ad.transport.AnomalyResultRequest; +import org.opensearch.ad.transport.StopDetectorAction; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.rest.handler.IndexJobActionHandler; +import org.opensearch.timeseries.transport.JobResponse; +import org.opensearch.timeseries.transport.ResultRequest; +import org.opensearch.transport.TransportService; + +public class ADIndexJobActionHandler extends + IndexJobActionHandler { + + public ADIndexJobActionHandler( + Client client, + ADIndexManagement indexManagement, + NamedXContentRegistry xContentRegistry, + ADTaskManager adTaskManager, + ExecuteADResultResponseRecorder recorder, + NodeStateManager nodeStateManager, + Settings settings + ) { + super( + client, + indexManagement, + xContentRegistry, + adTaskManager, + recorder, + AnomalyResultAction.INSTANCE, + AnalysisType.AD, + DETECTION_STATE_INDEX, + StopDetectorAction.INSTANCE, + nodeStateManager, + settings, + AD_REQUEST_TIMEOUT + ); + } + + @Override + protected ResultRequest createResultRequest(String configID, long start, long end) { + return new AnomalyResultRequest(configID, start, end); + } + + @Override + protected List getBatchConfigTaskTypes() { + return HISTORICAL_DETECTOR_TASK_TYPES; + } + + /** + * Stop config. + * For realtime config, will set job as disabled. + * For historical config, will set its task as cancelled. + * + * @param configId config id + * @param historical stop historical analysis or not + * @param user user + * @param transportService transport service + * @param listener action listener + */ + @Override + public void stopConfig( + String configId, + boolean historical, + User user, + TransportService transportService, + ActionListener listener + ) { + // make sure detector exists + nodeStateManager.getConfig(configId, AnalysisType.AD, (config) -> { + if (!config.isPresent()) { + listener.onFailure(new OpenSearchStatusException(CommonMessages.FAIL_TO_FIND_CONFIG_MSG + configId, RestStatus.NOT_FOUND)); + return; + } + if (historical) { + // stop historical analyis + taskManager + .getAndExecuteOnLatestConfigLevelTask( + configId, + getBatchConfigTaskTypes(), + (task) -> taskManager.stopHistoricalAnalysis(configId, task, user, listener), + transportService, + true,// reset task state when stop config + listener + ); + } else { + // stop realtime detector job + stopJob(configId, transportService, listener); + } + }, listener); + } + +} diff --git a/src/main/java/org/opensearch/ad/rest/handler/ADModelValidationActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/ADModelValidationActionHandler.java new file mode 100644 index 000000000..78a1dfbe9 --- /dev/null +++ b/src/main/java/org/opensearch/ad/rest/handler/ADModelValidationActionHandler.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.rest.handler; + +import java.time.Clock; + +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.rest.handler.ModelValidationActionHandler; +import org.opensearch.timeseries.transport.ValidateConfigResponse; +import org.opensearch.timeseries.util.SecurityClientUtil; + +public class ADModelValidationActionHandler extends ModelValidationActionHandler { + + public ADModelValidationActionHandler( + ClusterService clusterService, + Client client, + SecurityClientUtil clientUtil, + ActionListener listener, + AnomalyDetector config, + TimeValue requestTimeout, + NamedXContentRegistry xContentRegistry, + SearchFeatureDao searchFeatureDao, + String validationType, + Clock clock, + Settings settings, + User user + ) { + super( + clusterService, + client, + clientUtil, + listener, + config, + requestTimeout, + xContentRegistry, + searchFeatureDao, + validationType, + clock, + settings, + user, + AnalysisType.AD + ); + } + +} diff --git a/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java index 614d47bee..13e0ab1e3 100644 --- a/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java +++ b/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java @@ -12,80 +12,46 @@ package org.opensearch.ad.rest.handler; import static org.opensearch.ad.model.ADTaskType.HISTORICAL_DETECTOR_TASK_TYPES; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG; -import static org.opensearch.timeseries.util.ParseUtils.listEqualsWithoutConsideringOrder; -import static org.opensearch.timeseries.util.ParseUtils.parseAggregators; -import static org.opensearch.timeseries.util.RestHandlerUtils.XCONTENT_WITH_TYPE; -import static org.opensearch.timeseries.util.RestHandlerUtils.isExceptionCausedByInvalidQuery; import java.io.IOException; import java.time.Clock; import java.time.Instant; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashSet; -import java.util.List; import java.util.Locale; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; import java.util.Set; -import java.util.stream.Collectors; -import org.apache.commons.lang.StringUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.OpenSearchStatusException; -import org.opensearch.action.admin.indices.create.CreateIndexResponse; -import org.opensearch.action.admin.indices.mapping.get.GetFieldMappingsAction; -import org.opensearch.action.admin.indices.mapping.get.GetFieldMappingsRequest; -import org.opensearch.action.admin.indices.mapping.get.GetFieldMappingsResponse; -import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; -import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.action.support.IndicesOptions; import org.opensearch.action.support.WriteRequest; -import org.opensearch.action.support.replication.ReplicationResponse; +import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.rest.RestValidateAnomalyDetectorAction; -import org.opensearch.ad.settings.ADNumericSetting; +import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.transport.IndexAnomalyDetectorResponse; -import org.opensearch.ad.transport.ValidateAnomalyDetectorResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.action.ActionResponse; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.index.query.BoolQueryBuilder; -import org.opensearch.index.query.QueryBuilder; -import org.opensearch.index.query.QueryBuilders; import org.opensearch.rest.RestRequest; -import org.opensearch.search.aggregations.AggregatorFactories; -import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.common.exception.ValidationException; -import org.opensearch.timeseries.constant.CommonMessages; -import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.feature.SearchFeatureDao; -import org.opensearch.timeseries.model.Feature; -import org.opensearch.timeseries.model.MergeableList; +import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.ValidationAspect; import org.opensearch.timeseries.model.ValidationIssueType; -import org.opensearch.timeseries.util.MultiResponsesDelegateActionListener; -import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.timeseries.rest.handler.AbstractTimeSeriesActionHandler; +import org.opensearch.timeseries.transport.ValidateConfigResponse; import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.TransportService; @@ -116,48 +82,17 @@ * instantiate the ModelValidationActionHandler class and run the non-blocker validation logic

* */ -public abstract class AbstractAnomalyDetectorActionHandler { - public static final String EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG = "Can't create more than %d multi-entity anomaly detectors."; - public static final String EXCEEDED_MAX_SINGLE_ENTITY_DETECTORS_PREFIX_MSG = - "Can't create more than %d single-entity anomaly detectors."; - public static final String NO_DOCS_IN_USER_INDEX_MSG = "Can't create anomaly detector as no document is found in the indices: "; - public static final String ONLY_ONE_CATEGORICAL_FIELD_ERR_MSG = "We can have only one categorical field."; - public static final String CATEGORICAL_FIELD_TYPE_ERR_MSG = "A categorical field must be of type keyword or ip."; - public static final String CATEGORY_NOT_FOUND_ERR_MSG = "Can't find the categorical field %s"; - public static final String DUPLICATE_DETECTOR_MSG = "Cannot create anomaly detector with name [%s] as it's already used by detector %s"; - public static final String NAME_REGEX = "[a-zA-Z0-9._-]+"; - public static final Integer MAX_DETECTOR_NAME_SIZE = 64; - private static final Set DEFAULT_VALIDATION_ASPECTS = Sets.newHashSet(ValidationAspect.DETECTOR); - - public static String INVALID_NAME_SIZE = "Name should be shortened. The maximum limit is " + MAX_DETECTOR_NAME_SIZE + " characters."; - - protected final ADIndexManagement anomalyDetectionIndices; - protected final String detectorId; - protected final Long seqNo; - protected final Long primaryTerm; - protected final WriteRequest.RefreshPolicy refreshPolicy; - protected final AnomalyDetector anomalyDetector; - protected final ClusterService clusterService; - +public abstract class AbstractAnomalyDetectorActionHandler extends + AbstractTimeSeriesActionHandler { protected final Logger logger = LogManager.getLogger(AbstractAnomalyDetectorActionHandler.class); - protected final TimeValue requestTimeout; - protected final Integer maxSingleEntityAnomalyDetectors; - protected final Integer maxMultiEntityAnomalyDetectors; - protected final Integer maxAnomalyFeatures; - protected final AnomalyDetectorActionHandler handler = new AnomalyDetectorActionHandler(); - protected final RestRequest.Method method; - protected final Client client; - protected final SecurityClientUtil clientUtil; - protected final TransportService transportService; - protected final NamedXContentRegistry xContentRegistry; - protected final ActionListener listener; - protected final User user; - protected final ADTaskManager adTaskManager; - protected final SearchFeatureDao searchFeatureDao; - protected final boolean isDryRun; - protected final Clock clock; - protected final String validationType; - protected final Settings settings; + + public static final String EXCEEDED_MAX_HC_DETECTORS_PREFIX_MSG = "Can't create more than %d HC anomaly detectors."; + public static final String EXCEEDED_MAX_SINGLE_STREAM_DETECTORS_PREFIX_MSG = + "Can't create more than %d single-stream anomaly detectors."; + public static final String NO_DOCS_IN_USER_INDEX_MSG = "Can't create anomaly detector as no document is found in the indices: "; + public static final String DUPLICATE_DETECTOR_MSG = + "Cannot create anomaly detector with name [%s] as it's already used by another detector"; + public static final String VALIDATION_FEATURE_FAILURE = "Validation failed for feature(s) of detector %s"; /** * Constructor function. @@ -166,7 +101,6 @@ public abstract class AbstractAnomalyDetectorActionHandler listener, ADIndexManagement anomalyDetectionIndices, String detectorId, Long seqNo, Long primaryTerm, WriteRequest.RefreshPolicy refreshPolicy, - AnomalyDetector anomalyDetector, + Config anomalyDetector, TimeValue requestTimeout, - Integer maxSingleEntityAnomalyDetectors, - Integer maxMultiEntityAnomalyDetectors, - Integer maxAnomalyFeatures, + Integer maxSingleStreamAnomalyDetectors, + Integer maxHCAnomalyDetectors, + Integer maxFeatures, + Integer maxCategoricalFields, RestRequest.Method method, NamedXContentRegistry xContentRegistry, User user, @@ -213,746 +148,136 @@ public AbstractAnomalyDetectorActionHandler( Clock clock, Settings settings ) { - this.clusterService = clusterService; - this.client = client; - this.clientUtil = clientUtil; - this.transportService = transportService; - this.anomalyDetectionIndices = anomalyDetectionIndices; - this.listener = listener; - this.detectorId = detectorId; - this.seqNo = seqNo; - this.primaryTerm = primaryTerm; - this.refreshPolicy = refreshPolicy; - this.anomalyDetector = anomalyDetector; - this.requestTimeout = requestTimeout; - this.maxSingleEntityAnomalyDetectors = maxSingleEntityAnomalyDetectors; - this.maxMultiEntityAnomalyDetectors = maxMultiEntityAnomalyDetectors; - this.maxAnomalyFeatures = maxAnomalyFeatures; - this.method = method; - this.xContentRegistry = xContentRegistry; - this.user = user; - this.adTaskManager = adTaskManager; - this.searchFeatureDao = searchFeatureDao; - this.validationType = validationType; - this.isDryRun = isDryRun; - this.clock = clock; - this.settings = settings; - } - - /** - * Start function to process create/update/validate anomaly detector request. - * If detector is not using custom result index, check if anomaly detector - * index exist first, if not, will create first. Otherwise, check if custom - * result index exists or not. If exists, will check if index mapping matches - * AD result index mapping and if user has correct permission to write index. - * If doesn't exist, will create custom result index with AD result index - * mapping. - */ - public void start() { - String resultIndex = anomalyDetector.getCustomResultIndex(); - // use default detector result index which is system index - if (resultIndex == null) { - createOrUpdateDetector(); - return; - } - - if (this.isDryRun) { - if (anomalyDetectionIndices.doesIndexExist(resultIndex)) { - anomalyDetectionIndices - .validateCustomResultIndexAndExecute( - resultIndex, - () -> createOrUpdateDetector(), - ActionListener.wrap(r -> createOrUpdateDetector(), ex -> { - logger.error(ex); - listener - .onFailure( - new ValidationException(ex.getMessage(), ValidationIssueType.RESULT_INDEX, ValidationAspect.DETECTOR) - ); - return; - }) - ); - return; - } else { - createOrUpdateDetector(); - return; - } - } - // use custom result index if not validating and resultIndex not null - anomalyDetectionIndices.initCustomResultIndexAndExecute(resultIndex, () -> createOrUpdateDetector(), listener); - } - - // if isDryRun is true then this method is being executed through Validation API meaning actual - // index won't be created, only validation checks will be executed throughout the class - private void createOrUpdateDetector() { - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - if (!anomalyDetectionIndices.doesConfigIndexExist() && !this.isDryRun) { - logger.info("AnomalyDetector Indices do not exist"); - anomalyDetectionIndices - .initConfigIndex( - ActionListener - .wrap(response -> onCreateMappingsResponse(response, false), exception -> listener.onFailure(exception)) - ); - } else { - logger.info("AnomalyDetector Indices do exist, calling prepareAnomalyDetectorIndexing"); - logger.info("DryRun variable " + this.isDryRun); - validateDetectorName(this.isDryRun); - } - } catch (Exception e) { - logger.error("Failed to create or update detector " + detectorId, e); - listener.onFailure(e); - } - } - - // These validation checks are executed here and not in AnomalyDetector.parse() - // in order to not break any past detectors that were made with invalid names - // because it was never check on the backend in the past - protected void validateDetectorName(boolean indexingDryRun) { - if (!anomalyDetector.getName().matches(NAME_REGEX)) { - listener.onFailure(new ValidationException(CommonMessages.INVALID_NAME, ValidationIssueType.NAME, ValidationAspect.DETECTOR)); - return; - - } - if (anomalyDetector.getName().length() > MAX_DETECTOR_NAME_SIZE) { - listener.onFailure(new ValidationException(INVALID_NAME_SIZE, ValidationIssueType.NAME, ValidationAspect.DETECTOR)); - return; - } - validateTimeField(indexingDryRun); - } - - protected void validateTimeField(boolean indexingDryRun) { - String givenTimeField = anomalyDetector.getTimeField(); - GetFieldMappingsRequest getMappingsRequest = new GetFieldMappingsRequest(); - getMappingsRequest.indices(anomalyDetector.getIndices().toArray(new String[0])).fields(givenTimeField); - getMappingsRequest.indicesOptions(IndicesOptions.strictExpand()); - - // comments explaining fieldMappingResponse parsing can be found inside following method: - // AbstractAnomalyDetectorActionHandler.validateCategoricalField(String, boolean) - ActionListener mappingsListener = ActionListener.wrap(getMappingsResponse -> { - boolean foundField = false; - Map> mappingsByIndex = getMappingsResponse.mappings(); - - for (Map mappingsByField : mappingsByIndex.values()) { - for (Map.Entry field2Metadata : mappingsByField.entrySet()) { - - GetFieldMappingsResponse.FieldMappingMetadata fieldMetadata = field2Metadata.getValue(); - if (fieldMetadata != null) { - // sourceAsMap returns sth like {host2={type=keyword}} with host2 being a nested field - Map fieldMap = fieldMetadata.sourceAsMap(); - if (fieldMap != null) { - for (Object type : fieldMap.values()) { - if (type instanceof Map) { - foundField = true; - Map metadataMap = (Map) type; - String typeName = (String) metadataMap.get(CommonName.TYPE); - if (!typeName.equals(CommonName.DATE_TYPE)) { - listener - .onFailure( - new ValidationException( - String.format(Locale.ROOT, CommonMessages.INVALID_TIMESTAMP, givenTimeField), - ValidationIssueType.TIMEFIELD_FIELD, - ValidationAspect.DETECTOR - ) - ); - return; - } - } - } - } - } - } - } - if (!foundField) { - listener - .onFailure( - new ValidationException( - String.format(Locale.ROOT, CommonMessages.NON_EXISTENT_TIMESTAMP, givenTimeField), - ValidationIssueType.TIMEFIELD_FIELD, - ValidationAspect.DETECTOR - ) - ); - return; - } - prepareAnomalyDetectorIndexing(indexingDryRun); - }, error -> { - String message = String.format(Locale.ROOT, "Fail to get the index mapping of %s", anomalyDetector.getIndices()); - logger.error(message, error); - listener.onFailure(new IllegalArgumentException(message)); - }); - clientUtil - .executeWithInjectedSecurity( - GetFieldMappingsAction.INSTANCE, - getMappingsRequest, - user, - client, - AnalysisType.AD, - mappingsListener - ); - } - - /** - * Prepare for indexing a new anomaly detector. - * @param indexingDryRun if this is dryrun for indexing; when validation, it is true; when create/update, it is false - */ - protected void prepareAnomalyDetectorIndexing(boolean indexingDryRun) { - if (method == RestRequest.Method.PUT) { - handler - .getDetectorJob( - clusterService, - client, - detectorId, - listener, - () -> updateAnomalyDetector(detectorId, indexingDryRun), - xContentRegistry - ); - } else { - createAnomalyDetector(indexingDryRun); - } - } - - protected void updateAnomalyDetector(String detectorId, boolean indexingDryRun) { - GetRequest request = new GetRequest(CommonName.CONFIG_INDEX, detectorId); - client - .get( - request, - ActionListener - .wrap( - response -> onGetAnomalyDetectorResponse(response, indexingDryRun, detectorId), - exception -> listener.onFailure(exception) - ) - ); - } - - private void onGetAnomalyDetectorResponse(GetResponse response, boolean indexingDryRun, String detectorId) { - if (!response.isExists()) { - listener.onFailure(new OpenSearchStatusException(FAIL_TO_FIND_CONFIG_MSG + detectorId, RestStatus.NOT_FOUND)); - return; - } - try (XContentParser parser = RestHandlerUtils.createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - AnomalyDetector existingDetector = AnomalyDetector.parse(parser, response.getId(), response.getVersion()); - // If detector category field changed, frontend may not be able to render AD result for different detector types correctly. - // For example, if detector changed from HC to single entity detector, AD result page may show multiple anomaly - // result points on the same time point if there are multiple entities have anomaly results. - // If single-category HC changed category field from IP to error type, the AD result page may show both IP and error type - // in top N entities list. That's confusing. - // So we decide to block updating detector category field. - if (!listEqualsWithoutConsideringOrder(existingDetector.getCategoryFields(), anomalyDetector.getCategoryFields())) { - listener.onFailure(new OpenSearchStatusException(CommonMessages.CAN_NOT_CHANGE_CATEGORY_FIELD, RestStatus.BAD_REQUEST)); - return; - } - if (!Objects.equals(existingDetector.getCustomResultIndex(), anomalyDetector.getCustomResultIndex())) { - listener - .onFailure(new OpenSearchStatusException(CommonMessages.CAN_NOT_CHANGE_CUSTOM_RESULT_INDEX, RestStatus.BAD_REQUEST)); - return; - } - - adTaskManager.getAndExecuteOnLatestDetectorLevelTask(detectorId, HISTORICAL_DETECTOR_TASK_TYPES, (adTask) -> { - if (adTask.isPresent() && !adTask.get().isDone()) { - // can't update detector if there is AD task running - listener.onFailure(new OpenSearchStatusException("Detector is running", RestStatus.INTERNAL_SERVER_ERROR)); - } else { - validateExistingDetector(existingDetector, indexingDryRun); - } - }, transportService, true, listener); - } catch (IOException e) { - String message = "Failed to parse anomaly detector " + detectorId; - logger.error(message, e); - listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); - } - - } - - protected void validateExistingDetector(AnomalyDetector existingDetector, boolean indexingDryRun) { - if (!hasCategoryField(existingDetector) && hasCategoryField(this.anomalyDetector)) { - validateAgainstExistingMultiEntityAnomalyDetector(detectorId, indexingDryRun); - } else { - validateCategoricalField(detectorId, indexingDryRun); - } - } - - protected boolean hasCategoryField(AnomalyDetector detector) { - return detector.getCategoryFields() != null && !detector.getCategoryFields().isEmpty(); - } - - protected void validateAgainstExistingMultiEntityAnomalyDetector(String detectorId, boolean indexingDryRun) { - if (anomalyDetectionIndices.doesConfigIndexExist()) { - QueryBuilder query = QueryBuilders.boolQuery().filter(QueryBuilders.existsQuery(AnomalyDetector.CATEGORY_FIELD)); - - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeout); - - SearchRequest searchRequest = new SearchRequest(CommonName.CONFIG_INDEX).source(searchSourceBuilder); - client - .search( - searchRequest, - ActionListener - .wrap( - response -> onSearchMultiEntityAdResponse(response, detectorId, indexingDryRun), - exception -> listener.onFailure(exception) - ) - ); - } else { - validateCategoricalField(detectorId, indexingDryRun); - } - - } - - protected void createAnomalyDetector(boolean indexingDryRun) { - try { - List categoricalFields = anomalyDetector.getCategoryFields(); - if (categoricalFields != null && categoricalFields.size() > 0) { - validateAgainstExistingMultiEntityAnomalyDetector(null, indexingDryRun); - } else { - if (anomalyDetectionIndices.doesConfigIndexExist()) { - QueryBuilder query = QueryBuilders.matchAllQuery(); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeout); - - SearchRequest searchRequest = new SearchRequest(CommonName.CONFIG_INDEX).source(searchSourceBuilder); - - client - .search( - searchRequest, - ActionListener - .wrap( - response -> onSearchSingleEntityAdResponse(response, indexingDryRun), - exception -> listener.onFailure(exception) - ) - ); - } else { - searchAdInputIndices(null, indexingDryRun); - } - - } - } catch (Exception e) { - listener.onFailure(e); - } - } - - protected void onSearchSingleEntityAdResponse(SearchResponse response, boolean indexingDryRun) throws IOException { - if (response.getHits().getTotalHits().value >= maxSingleEntityAnomalyDetectors) { - String errorMsgSingleEntity = String - .format(Locale.ROOT, EXCEEDED_MAX_SINGLE_ENTITY_DETECTORS_PREFIX_MSG, maxSingleEntityAnomalyDetectors); - logger.error(errorMsgSingleEntity); - if (indexingDryRun) { - listener - .onFailure( - new ValidationException(errorMsgSingleEntity, ValidationIssueType.GENERAL_SETTINGS, ValidationAspect.DETECTOR) - ); - return; - } - listener.onFailure(new IllegalArgumentException(errorMsgSingleEntity)); - } else { - searchAdInputIndices(null, indexingDryRun); - } - } - - protected void onSearchMultiEntityAdResponse(SearchResponse response, String detectorId, boolean indexingDryRun) throws IOException { - if (response.getHits().getTotalHits().value >= maxMultiEntityAnomalyDetectors) { - String errorMsg = String.format(Locale.ROOT, EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG, maxMultiEntityAnomalyDetectors); - logger.error(errorMsg); - if (indexingDryRun) { - listener.onFailure(new ValidationException(errorMsg, ValidationIssueType.GENERAL_SETTINGS, ValidationAspect.DETECTOR)); - return; - } - listener.onFailure(new IllegalArgumentException(errorMsg)); - } else { - validateCategoricalField(detectorId, indexingDryRun); - } - } - - @SuppressWarnings("unchecked") - protected void validateCategoricalField(String detectorId, boolean indexingDryRun) { - List categoryField = anomalyDetector.getCategoryFields(); - - if (categoryField == null) { - searchAdInputIndices(detectorId, indexingDryRun); - return; - } - - // we only support a certain number of categorical field - // If there is more fields than required, AnomalyDetector's constructor - // throws ADValidationException before reaching this line - int maxCategoryFields = ADNumericSetting.maxCategoricalFields(); - if (categoryField.size() > maxCategoryFields) { - listener - .onFailure( - new ValidationException( - CommonMessages.getTooManyCategoricalFieldErr(maxCategoryFields), - ValidationIssueType.CATEGORY, - ValidationAspect.DETECTOR - ) - ); - return; - } - - String categoryField0 = categoryField.get(0); - - GetFieldMappingsRequest getMappingsRequest = new GetFieldMappingsRequest(); - getMappingsRequest.indices(anomalyDetector.getIndices().toArray(new String[0])).fields(categoryField.toArray(new String[0])); - getMappingsRequest.indicesOptions(IndicesOptions.strictExpand()); - - ActionListener mappingsListener = ActionListener.wrap(getMappingsResponse -> { - // example getMappingsResponse: - // GetFieldMappingsResponse{mappings={server-metrics={_doc={service=FieldMappingMetadata{fullName='service', - // source=org.opensearch.core.common.bytes.BytesArray@7ba87dbd}}}}} - // for nested field, it would be - // GetFieldMappingsResponse{mappings={server-metrics={_doc={host_nest.host2=FieldMappingMetadata{fullName='host_nest.host2', - // source=org.opensearch.core.common.bytes.BytesArray@8fb4de08}}}}} - boolean foundField = false; - - // Review why the change from FieldMappingMetadata to GetFieldMappingsResponse.FieldMappingMetadata - Map> mappingsByIndex = getMappingsResponse.mappings(); - - for (Map mappingsByField : mappingsByIndex.values()) { - for (Map.Entry field2Metadata : mappingsByField.entrySet()) { - // example output: - // host_nest.host2=FieldMappingMetadata{fullName='host_nest.host2', - // source=org.opensearch.core.common.bytes.BytesArray@8fb4de08} - - // Review why the change from FieldMappingMetadata to GetFieldMappingsResponse.FieldMappingMetadata - - GetFieldMappingsResponse.FieldMappingMetadata fieldMetadata = field2Metadata.getValue(); - - if (fieldMetadata != null) { - // sourceAsMap returns sth like {host2={type=keyword}} with host2 being a nested field - Map fieldMap = fieldMetadata.sourceAsMap(); - if (fieldMap != null) { - for (Object type : fieldMap.values()) { - if (type != null && type instanceof Map) { - foundField = true; - Map metadataMap = (Map) type; - String typeName = (String) metadataMap.get(CommonName.TYPE); - if (!typeName.equals(CommonName.KEYWORD_TYPE) && !typeName.equals(CommonName.IP_TYPE)) { - listener - .onFailure( - new ValidationException( - CATEGORICAL_FIELD_TYPE_ERR_MSG, - ValidationIssueType.CATEGORY, - ValidationAspect.DETECTOR - ) - ); - return; - } - } - } - } - - } - } - } - - if (foundField == false) { - listener - .onFailure( - new ValidationException( - String.format(Locale.ROOT, CATEGORY_NOT_FOUND_ERR_MSG, categoryField0), - ValidationIssueType.CATEGORY, - ValidationAspect.DETECTOR - ) - ); - return; - } - - searchAdInputIndices(detectorId, indexingDryRun); - }, error -> { - String message = String.format(Locale.ROOT, "Fail to get the index mapping of %s", anomalyDetector.getIndices()); - logger.error(message, error); - listener.onFailure(new IllegalArgumentException(message)); - }); + super( + anomalyDetector, + anomalyDetectionIndices, + isDryRun, + client, + detectorId, + clientUtil, + user, + method, + clusterService, + xContentRegistry, + transportService, + requestTimeout, + refreshPolicy, + seqNo, + primaryTerm, + validationType, + searchFeatureDao, + maxFeatures, + maxCategoricalFields, + AnalysisType.AD, + adTaskManager, + HISTORICAL_DETECTOR_TASK_TYPES, + false, + maxSingleStreamAnomalyDetectors, + maxHCAnomalyDetectors, + clock, + settings + ); - clientUtil - .executeWithInjectedSecurity( - GetFieldMappingsAction.INSTANCE, - getMappingsRequest, - user, - client, - AnalysisType.AD, - mappingsListener - ); } - protected void searchAdInputIndices(String detectorId, boolean indexingDryRun) { - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() - .query(QueryBuilders.matchAllQuery()) - .size(0) - .timeout(requestTimeout); - - SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0])).source(searchSourceBuilder); - - ActionListener searchResponseListener = ActionListener - .wrap( - searchResponse -> onSearchAdInputIndicesResponse(searchResponse, detectorId, indexingDryRun), - exception -> listener.onFailure(exception) - ); - - clientUtil.asyncRequestWithInjectedSecurity(searchRequest, client::search, user, client, AnalysisType.AD, searchResponseListener); + @Override + protected TimeSeriesException createValidationException(String msg, ValidationIssueType type) { + return new ValidationException(msg, type, ValidationAspect.DETECTOR); } - protected void onSearchAdInputIndicesResponse(SearchResponse response, String detectorId, boolean indexingDryRun) throws IOException { - if (response.getHits().getTotalHits().value == 0) { - String errorMsg = NO_DOCS_IN_USER_INDEX_MSG + Arrays.toString(anomalyDetector.getIndices().toArray(new String[0])); - logger.error(errorMsg); - if (indexingDryRun) { - listener.onFailure(new ValidationException(errorMsg, ValidationIssueType.INDICES, ValidationAspect.DETECTOR)); - return; - } - listener.onFailure(new IllegalArgumentException(errorMsg)); - } else { - validateAnomalyDetectorFeatures(detectorId, indexingDryRun); - } + @Override + protected AnomalyDetector parse(XContentParser parser, GetResponse response) throws IOException { + return AnomalyDetector.parse(parser, response.getId(), response.getVersion()); } - protected void checkADNameExists(String detectorId, boolean indexingDryRun) throws IOException { - if (anomalyDetectionIndices.doesConfigIndexExist()) { - BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - // src/main/resources/mappings/anomaly-detectors.json#L14 - boolQueryBuilder.must(QueryBuilders.termQuery("name.keyword", anomalyDetector.getName())); - if (StringUtils.isNotBlank(detectorId)) { - boolQueryBuilder.mustNot(QueryBuilders.termQuery(RestHandlerUtils._ID, detectorId)); - } - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(boolQueryBuilder).timeout(requestTimeout); - SearchRequest searchRequest = new SearchRequest(CommonName.CONFIG_INDEX).source(searchSourceBuilder); - client - .search( - searchRequest, - ActionListener - .wrap( - searchResponse -> onSearchADNameResponse(searchResponse, detectorId, anomalyDetector.getName(), indexingDryRun), - exception -> listener.onFailure(exception) - ) - ); - } else { - tryIndexingAnomalyDetector(indexingDryRun); - } - + @Override + protected String getExceedMaxSingleStreamConfigsErrorMsg(int maxSingleStreamConfigs) { + return String.format(Locale.ROOT, EXCEEDED_MAX_SINGLE_STREAM_DETECTORS_PREFIX_MSG, getMaxSingleStreamConfigs()); } - protected void onSearchADNameResponse(SearchResponse response, String detectorId, String name, boolean indexingDryRun) - throws IOException { - if (response.getHits().getTotalHits().value > 0) { - String errorMsg = String - .format( - Locale.ROOT, - DUPLICATE_DETECTOR_MSG, - name, - Arrays.stream(response.getHits().getHits()).map(hit -> hit.getId()).collect(Collectors.toList()) - ); - logger.warn(errorMsg); - listener.onFailure(new ValidationException(errorMsg, ValidationIssueType.NAME, ValidationAspect.DETECTOR)); - } else { - tryIndexingAnomalyDetector(indexingDryRun); - } + @Override + protected String getExceedMaxHCConfigsErrorMsg(int maxHCConfigs) { + return String.format(Locale.ROOT, EXCEEDED_MAX_HC_DETECTORS_PREFIX_MSG, getMaxHCConfigs()); } - protected void tryIndexingAnomalyDetector(boolean indexingDryRun) throws IOException { - if (!indexingDryRun) { - indexAnomalyDetector(detectorId); - } else { - finishDetectorValidationOrContinueToModelValidation(); - } + @Override + protected String getNoDocsInUserIndexErrorMsg(String suppliedIndices) { + return String.format(Locale.ROOT, NO_DOCS_IN_USER_INDEX_MSG, suppliedIndices); } - protected Set getValidationTypes(String validationType) { - if (StringUtils.isBlank(validationType)) { - return DEFAULT_VALIDATION_ASPECTS; - } else { - Set typesInRequest = new HashSet<>(Arrays.asList(validationType.split(","))); - return ValidationAspect - .getNames(Sets.intersection(RestValidateAnomalyDetectorAction.ALL_VALIDATION_ASPECTS_STRS, typesInRequest)); - } + @Override + protected String getDuplicateConfigErrorMsg(String name) { + return String.format(Locale.ROOT, DUPLICATE_DETECTOR_MSG, name); } - protected void finishDetectorValidationOrContinueToModelValidation() { - logger.info("Skipping indexing detector. No blocking issue found so far."); - if (!getValidationTypes(validationType).contains(ValidationAspect.MODEL)) { - listener.onResponse(null); - } else { - ModelValidationActionHandler modelValidationActionHandler = new ModelValidationActionHandler( - clusterService, - client, - clientUtil, - (ActionListener) listener, - anomalyDetector, - requestTimeout, - xContentRegistry, - searchFeatureDao, - validationType, - clock, - settings, - user - ); - modelValidationActionHandler.checkIfMultiEntityDetector(); - } - } - - @SuppressWarnings("unchecked") - protected void indexAnomalyDetector(String detectorId) throws IOException { - AnomalyDetector detector = new AnomalyDetector( - anomalyDetector.getId(), - anomalyDetector.getVersion(), - anomalyDetector.getName(), - anomalyDetector.getDescription(), - anomalyDetector.getTimeField(), - anomalyDetector.getIndices(), - anomalyDetector.getFeatureAttributes(), - anomalyDetector.getFilterQuery(), - anomalyDetector.getInterval(), - anomalyDetector.getWindowDelay(), - anomalyDetector.getShingleSize(), - anomalyDetector.getUiMetadata(), - anomalyDetector.getSchemaVersion(), + @Override + protected AnomalyDetector copyConfig(User user, Config config) { + AnomalyDetector detector = (AnomalyDetector) config; + return new AnomalyDetector( + config.getId(), + config.getVersion(), + config.getName(), + config.getDescription(), + config.getTimeField(), + config.getIndices(), + config.getFeatureAttributes(), + config.getFilterQuery(), + config.getInterval(), + config.getWindowDelay(), + config.getShingleSize(), + config.getUiMetadata(), + config.getSchemaVersion(), Instant.now(), - anomalyDetector.getCategoryFields(), + config.getCategoryFields(), user, - anomalyDetector.getCustomResultIndex(), - anomalyDetector.getImputationOption() + config.getCustomResultIndex(), + config.getImputationOption(), + config.getRecencyEmphasis(), + config.getSeasonIntervals(), + config.getHistoryIntervals(), + detector.getRules() ); - IndexRequest indexRequest = new IndexRequest(CommonName.CONFIG_INDEX) - .setRefreshPolicy(refreshPolicy) - .source(detector.toXContent(XContentFactory.jsonBuilder(), XCONTENT_WITH_TYPE)) - .setIfSeqNo(seqNo) - .setIfPrimaryTerm(primaryTerm) - .timeout(requestTimeout); - if (StringUtils.isNotBlank(detectorId)) { - indexRequest.id(detectorId); - } - - client.index(indexRequest, new ActionListener() { - @Override - public void onResponse(IndexResponse indexResponse) { - String errorMsg = checkShardsFailure(indexResponse); - if (errorMsg != null) { - listener.onFailure(new OpenSearchStatusException(errorMsg, indexResponse.status())); - return; - } - listener - .onResponse( - (T) new IndexAnomalyDetectorResponse( - indexResponse.getId(), - indexResponse.getVersion(), - indexResponse.getSeqNo(), - indexResponse.getPrimaryTerm(), - detector, - RestStatus.CREATED - ) - ); - } - - @Override - public void onFailure(Exception e) { - logger.warn("Failed to update detector", e); - if (e.getMessage() != null && e.getMessage().contains("version conflict")) { - listener - .onFailure( - new IllegalArgumentException("There was a problem updating the historical detector:[" + detectorId + "]") - ); - } else { - listener.onFailure(e); - } - } - }); } - protected void onCreateMappingsResponse(CreateIndexResponse response, boolean indexingDryRun) throws IOException { - if (response.isAcknowledged()) { - logger.info("Created {} with mappings.", CommonName.CONFIG_INDEX); - prepareAnomalyDetectorIndexing(indexingDryRun); - } else { - logger.warn("Created {} with mappings call not acknowledged.", CommonName.CONFIG_INDEX); - listener - .onFailure( - new OpenSearchStatusException( - "Created " + CommonName.CONFIG_INDEX + "with mappings call not acknowledged.", - RestStatus.INTERNAL_SERVER_ERROR - ) - ); - } + @SuppressWarnings("unchecked") + @Override + protected T createIndexConfigResponse(IndexResponse indexResponse, Config config) { + return (T) new IndexAnomalyDetectorResponse( + indexResponse.getId(), + indexResponse.getVersion(), + indexResponse.getSeqNo(), + indexResponse.getPrimaryTerm(), + (AnomalyDetector) config, + RestStatus.CREATED + ); } - protected String checkShardsFailure(IndexResponse response) { - StringBuilder failureReasons = new StringBuilder(); - if (response.getShardInfo().getFailed() > 0) { - for (ReplicationResponse.ShardInfo.Failure failure : response.getShardInfo().getFailures()) { - failureReasons.append(failure); - } - return failureReasons.toString(); - } - return null; + @Override + protected Set getDefaultValidationType() { + return Sets.newHashSet(ValidationAspect.DETECTOR); } - /** - * Validate config/syntax, and runtime error of detector features - * @param detectorId detector id - * @param indexingDryRun if false, then will eventually index detector; true, skip indexing detector - * @throws IOException when fail to parse feature aggregation - */ - // TODO: move this method to util class so that it can be re-usable for more use cases - // https://github.com/opensearch-project/anomaly-detection/issues/39 - protected void validateAnomalyDetectorFeatures(String detectorId, boolean indexingDryRun) throws IOException { - if (anomalyDetector != null - && (anomalyDetector.getFeatureAttributes() == null || anomalyDetector.getFeatureAttributes().isEmpty())) { - checkADNameExists(detectorId, indexingDryRun); - return; - } - // checking configuration/syntax error of detector features - String error = RestHandlerUtils.checkFeaturesSyntax(anomalyDetector, maxAnomalyFeatures); - if (StringUtils.isNotBlank(error)) { - if (indexingDryRun) { - listener.onFailure(new ValidationException(error, ValidationIssueType.FEATURE_ATTRIBUTES, ValidationAspect.DETECTOR)); - return; - } - listener.onFailure(new OpenSearchStatusException(error, RestStatus.BAD_REQUEST)); - return; - } - // checking runtime error from feature query - ActionListener>> validateFeatureQueriesListener = ActionListener.wrap(response -> { - checkADNameExists(detectorId, indexingDryRun); - }, exception -> { - listener - .onFailure( - new ValidationException(exception.getMessage(), ValidationIssueType.FEATURE_ATTRIBUTES, ValidationAspect.DETECTOR) - ); - }); - MultiResponsesDelegateActionListener>> multiFeatureQueriesResponseListener = - new MultiResponsesDelegateActionListener>>( - validateFeatureQueriesListener, - anomalyDetector.getFeatureAttributes().size(), - String.format(Locale.ROOT, "Validation failed for feature(s) of detector %s", anomalyDetector.getName()), - false - ); + @Override + protected String getFeatureErrorMsg(String name) { + return String.format(Locale.ROOT, VALIDATION_FEATURE_FAILURE, name); + } - for (Feature feature : anomalyDetector.getFeatureAttributes()) { - SearchSourceBuilder ssb = new SearchSourceBuilder().size(1).query(QueryBuilders.matchAllQuery()); - AggregatorFactories.Builder internalAgg = parseAggregators( - feature.getAggregation().toString(), - xContentRegistry, - feature.getId() - ); - ssb.aggregation(internalAgg.getAggregatorFactories().iterator().next()); - SearchRequest searchRequest = new SearchRequest().indices(anomalyDetector.getIndices().toArray(new String[0])).source(ssb); - ActionListener searchResponseListener = ActionListener.wrap(response -> { - Optional aggFeatureResult = searchFeatureDao.parseResponse(response, Arrays.asList(feature.getId())); - if (aggFeatureResult.isPresent()) { - multiFeatureQueriesResponseListener - .onResponse( - new MergeableList>(new ArrayList>(Arrays.asList(aggFeatureResult))) - ); - } else { - String errorMessage = CommonMessages.FEATURE_WITH_EMPTY_DATA_MSG + feature.getName(); - logger.error(errorMessage); - multiFeatureQueriesResponseListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.BAD_REQUEST)); - } - }, e -> { - String errorMessage; - if (isExceptionCausedByInvalidQuery(e)) { - errorMessage = CommonMessages.FEATURE_WITH_INVALID_QUERY_MSG + feature.getName(); - } else { - errorMessage = CommonMessages.UNKNOWN_SEARCH_QUERY_EXCEPTION_MSG + feature.getName(); - } - logger.error(errorMessage, e); - multiFeatureQueriesResponseListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.BAD_REQUEST, e)); - }); - clientUtil - .asyncRequestWithInjectedSecurity(searchRequest, client::search, user, client, AnalysisType.AD, searchResponseListener); - } + @Override + protected void validateModel(ActionListener listener) { + ADModelValidationActionHandler modelValidationActionHandler = new ADModelValidationActionHandler( + clusterService, + client, + clientUtil, + (ActionListener) listener, + (AnomalyDetector) config, + requestTimeout, + xContentRegistry, + searchFeatureDao, + validationType, + clock, + settings, + user + ); + modelValidationActionHandler.start(); } } diff --git a/src/main/java/org/opensearch/ad/rest/handler/AnomalyDetectorActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/AnomalyDetectorActionHandler.java deleted file mode 100644 index 28e68d0fb..000000000 --- a/src/main/java/org/opensearch/ad/rest/handler/AnomalyDetectorActionHandler.java +++ /dev/null @@ -1,105 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.rest.handler; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; - -import java.io.IOException; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.OpenSearchStatusException; -import org.opensearch.action.get.GetRequest; -import org.opensearch.action.get.GetResponse; -import org.opensearch.client.Client; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.rest.RestStatus; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.timeseries.constant.CommonName; -import org.opensearch.timeseries.function.ExecutorFunction; -import org.opensearch.timeseries.model.Job; -import org.opensearch.timeseries.util.RestHandlerUtils; - -/** - * Common handler to process AD request. - */ -public class AnomalyDetectorActionHandler { - - private final Logger logger = LogManager.getLogger(AnomalyDetectorActionHandler.class); - - /** - * Get detector job for update/delete AD job. - * If AD job exist, will return error message; otherwise, execute function. - * - * @param clusterService ES cluster service - * @param client ES node client - * @param detectorId detector identifier - * @param listener Listener to send response - * @param function AD function - * @param xContentRegistry Registry which is used for XContentParser - */ - public void getDetectorJob( - ClusterService clusterService, - Client client, - String detectorId, - ActionListener listener, - ExecutorFunction function, - NamedXContentRegistry xContentRegistry - ) { - if (clusterService.state().metadata().indices().containsKey(CommonName.JOB_INDEX)) { - GetRequest request = new GetRequest(CommonName.JOB_INDEX).id(detectorId); - client - .get( - request, - ActionListener - .wrap(response -> onGetAdJobResponseForWrite(response, listener, function, xContentRegistry), exception -> { - logger.error("Fail to get anomaly detector job: " + detectorId, exception); - listener.onFailure(exception); - }) - ); - } else { - function.execute(); - } - } - - private void onGetAdJobResponseForWrite( - GetResponse response, - ActionListener listener, - ExecutorFunction function, - NamedXContentRegistry xContentRegistry - ) { - if (response.isExists()) { - String adJobId = response.getId(); - if (adJobId != null) { - // check if AD job is running on the detector, if yes, we can't delete the detector - try ( - XContentParser parser = RestHandlerUtils - .createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Job adJob = Job.parse(parser); - if (adJob.isEnabled()) { - listener.onFailure(new OpenSearchStatusException("Detector job is running: " + adJobId, RestStatus.BAD_REQUEST)); - return; - } - } catch (IOException e) { - String message = "Failed to parse anomaly detector job " + adJobId; - logger.error(message, e); - listener.onFailure(new OpenSearchStatusException(message, RestStatus.BAD_REQUEST)); - } - } - } - function.execute(); - } -} diff --git a/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java index bed6a7998..a600d8750 100644 --- a/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java +++ b/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java @@ -21,7 +21,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.commons.authuser.User; -import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.rest.RestRequest; import org.opensearch.timeseries.feature.SearchFeatureDao; @@ -42,7 +41,6 @@ public class IndexAnomalyDetectorActionHandler extends AbstractAnomalyDetectorAc * @param client ES node client that executes actions on the local node * @param clientUtil AD client util * @param transportService ES transport service - * @param listener ES channel used to construct bytes / builder based outputs, and send responses * @param anomalyDetectionIndices anomaly detector index manager * @param detectorId detector identifier * @param seqNo sequence number of last modification @@ -50,9 +48,10 @@ public class IndexAnomalyDetectorActionHandler extends AbstractAnomalyDetectorAc * @param refreshPolicy refresh policy * @param anomalyDetector anomaly detector instance * @param requestTimeout request time out configuration - * @param maxSingleEntityAnomalyDetectors max single-entity anomaly detectors allowed - * @param maxMultiEntityAnomalyDetectors max multi-entity detectors allowed - * @param maxAnomalyFeatures max features allowed per detector + * @param maxSingleStreamDetectors max single-stream anomaly detectors allowed + * @param maxHCDetectors max HC detectors allowed + * @param maxFeatures max features allowed per detector + * @param maxCategoricalFields max number of categorical fields * @param method Rest Method type * @param xContentRegistry Registry which is used for XContentParser * @param user User context @@ -65,7 +64,6 @@ public IndexAnomalyDetectorActionHandler( Client client, SecurityClientUtil clientUtil, TransportService transportService, - ActionListener listener, ADIndexManagement anomalyDetectionIndices, String detectorId, Long seqNo, @@ -73,9 +71,10 @@ public IndexAnomalyDetectorActionHandler( WriteRequest.RefreshPolicy refreshPolicy, AnomalyDetector anomalyDetector, TimeValue requestTimeout, - Integer maxSingleEntityAnomalyDetectors, - Integer maxMultiEntityAnomalyDetectors, - Integer maxAnomalyFeatures, + Integer maxSingleStreamDetectors, + Integer maxHCDetectors, + Integer maxFeatures, + Integer maxCategoricalFields, RestRequest.Method method, NamedXContentRegistry xContentRegistry, User user, @@ -88,7 +87,6 @@ public IndexAnomalyDetectorActionHandler( client, clientUtil, transportService, - listener, anomalyDetectionIndices, detectorId, seqNo, @@ -96,9 +94,10 @@ public IndexAnomalyDetectorActionHandler( refreshPolicy, anomalyDetector, requestTimeout, - maxSingleEntityAnomalyDetectors, - maxMultiEntityAnomalyDetectors, - maxAnomalyFeatures, + maxSingleStreamDetectors, + maxHCDetectors, + maxFeatures, + maxCategoricalFields, method, xContentRegistry, user, @@ -110,12 +109,4 @@ public IndexAnomalyDetectorActionHandler( settings ); } - - /** - * Start function to process create/update anomaly detector request. - */ - @Override - public void start() { - super.start(); - } } diff --git a/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandler.java deleted file mode 100644 index 5a3b19f24..000000000 --- a/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandler.java +++ /dev/null @@ -1,403 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.rest.handler; - -import static org.opensearch.action.DocWriteResponse.Result.CREATED; -import static org.opensearch.action.DocWriteResponse.Result.UPDATED; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.timeseries.util.ExceptionUtil.getShardsFailure; -import static org.opensearch.timeseries.util.RestHandlerUtils.createXContentParserFromRegistry; - -import java.io.IOException; -import java.time.Duration; -import java.time.Instant; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.OpenSearchStatusException; -import org.opensearch.action.get.GetRequest; -import org.opensearch.action.get.GetResponse; -import org.opensearch.action.index.IndexRequest; -import org.opensearch.action.index.IndexResponse; -import org.opensearch.action.support.WriteRequest; -import org.opensearch.ad.ExecuteADResultResponseRecorder; -import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.task.ADTaskManager; -import org.opensearch.ad.transport.AnomalyResultAction; -import org.opensearch.ad.transport.AnomalyResultRequest; -import org.opensearch.ad.transport.StopDetectorAction; -import org.opensearch.ad.transport.StopDetectorRequest; -import org.opensearch.ad.transport.StopDetectorResponse; -import org.opensearch.client.Client; -import org.opensearch.common.unit.TimeValue; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.rest.RestStatus; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; -import org.opensearch.jobscheduler.spi.schedule.Schedule; -import org.opensearch.timeseries.constant.CommonName; -import org.opensearch.timeseries.function.ExecutorFunction; -import org.opensearch.timeseries.model.IntervalTimeConfiguration; -import org.opensearch.timeseries.model.Job; -import org.opensearch.timeseries.model.TaskState; -import org.opensearch.timeseries.transport.JobResponse; -import org.opensearch.timeseries.util.RestHandlerUtils; -import org.opensearch.transport.TransportService; - -import com.google.common.base.Throwables; - -/** - * Anomaly detector job REST action handler to process POST/PUT request. - */ -public class IndexAnomalyDetectorJobActionHandler { - - private final ADIndexManagement anomalyDetectionIndices; - private final String detectorId; - private final Long seqNo; - private final Long primaryTerm; - private final Client client; - private final NamedXContentRegistry xContentRegistry; - private final TransportService transportService; - private final ADTaskManager adTaskManager; - - private final Logger logger = LogManager.getLogger(IndexAnomalyDetectorJobActionHandler.class); - private final TimeValue requestTimeout; - private final ExecuteADResultResponseRecorder recorder; - - /** - * Constructor function. - * - * @param client ES node client that executes actions on the local node - * @param anomalyDetectionIndices anomaly detector index manager - * @param detectorId detector identifier - * @param seqNo sequence number of last modification - * @param primaryTerm primary term of last modification - * @param requestTimeout request time out configuration - * @param xContentRegistry Registry which is used for XContentParser - * @param transportService transport service - * @param adTaskManager AD task manager - * @param recorder Utility to record AnomalyResultAction execution result - */ - public IndexAnomalyDetectorJobActionHandler( - Client client, - ADIndexManagement anomalyDetectionIndices, - String detectorId, - Long seqNo, - Long primaryTerm, - TimeValue requestTimeout, - NamedXContentRegistry xContentRegistry, - TransportService transportService, - ADTaskManager adTaskManager, - ExecuteADResultResponseRecorder recorder - ) { - this.client = client; - this.anomalyDetectionIndices = anomalyDetectionIndices; - this.detectorId = detectorId; - this.seqNo = seqNo; - this.primaryTerm = primaryTerm; - this.requestTimeout = requestTimeout; - this.xContentRegistry = xContentRegistry; - this.transportService = transportService; - this.adTaskManager = adTaskManager; - this.recorder = recorder; - } - - /** - * Start anomaly detector job. - * 1. If job doesn't exist, create new job. - * 2. If job exists: a). if job enabled, return error message; b). if job disabled, enable job. - * @param detector anomaly detector - * @param listener Listener to send responses - */ - public void startAnomalyDetectorJob(AnomalyDetector detector, ActionListener listener) { - // this start listener is created & injected throughout the job handler so that whenever the job response is received, - // there's the extra step of trying to index results and update detector state with a 60s delay. - ActionListener startListener = ActionListener.wrap(r -> { - try { - Instant executionEndTime = Instant.now(); - IntervalTimeConfiguration schedule = (IntervalTimeConfiguration) detector.getInterval(); - Instant executionStartTime = executionEndTime.minus(schedule.getInterval(), schedule.getUnit()); - AnomalyResultRequest getRequest = new AnomalyResultRequest( - detector.getId(), - executionStartTime.toEpochMilli(), - executionEndTime.toEpochMilli() - ); - client - .execute( - AnomalyResultAction.INSTANCE, - getRequest, - ActionListener - .wrap( - response -> recorder.indexAnomalyResult(executionStartTime, executionEndTime, response, detector), - exception -> { - - recorder - .indexAnomalyResultException( - executionStartTime, - executionEndTime, - Throwables.getStackTraceAsString(exception), - null, - detector - ); - } - ) - ); - } catch (Exception ex) { - listener.onFailure(ex); - return; - } - listener.onResponse(r); - - }, listener::onFailure); - if (!anomalyDetectionIndices.doesJobIndexExist()) { - anomalyDetectionIndices.initJobIndex(ActionListener.wrap(response -> { - if (response.isAcknowledged()) { - logger.info("Created {} with mappings.", CommonName.CONFIG_INDEX); - createJob(detector, startListener); - } else { - logger.warn("Created {} with mappings call not acknowledged.", CommonName.CONFIG_INDEX); - startListener - .onFailure( - new OpenSearchStatusException( - "Created " + CommonName.CONFIG_INDEX + " with mappings call not acknowledged.", - RestStatus.INTERNAL_SERVER_ERROR - ) - ); - } - }, exception -> startListener.onFailure(exception))); - } else { - createJob(detector, startListener); - } - } - - private void createJob(AnomalyDetector detector, ActionListener listener) { - try { - IntervalTimeConfiguration interval = (IntervalTimeConfiguration) detector.getInterval(); - Schedule schedule = new IntervalSchedule(Instant.now(), (int) interval.getInterval(), interval.getUnit()); - Duration duration = Duration.of(interval.getInterval(), interval.getUnit()); - - Job job = new Job( - detector.getId(), - schedule, - detector.getWindowDelay(), - true, - Instant.now(), - null, - Instant.now(), - duration.getSeconds(), - detector.getUser(), - detector.getCustomResultIndex() - ); - - getJobForWrite(detector, job, listener); - } catch (Exception e) { - String message = "Failed to parse anomaly detector job " + detectorId; - logger.error(message, e); - listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); - } - } - - private void getJobForWrite(AnomalyDetector detector, Job job, ActionListener listener) { - GetRequest getRequest = new GetRequest(CommonName.JOB_INDEX).id(detectorId); - - client - .get( - getRequest, - ActionListener - .wrap( - response -> onGetAnomalyDetectorJobForWrite(response, detector, job, listener), - exception -> listener.onFailure(exception) - ) - ); - } - - private void onGetAnomalyDetectorJobForWrite( - GetResponse response, - AnomalyDetector detector, - Job job, - ActionListener listener - ) throws IOException { - if (response.isExists()) { - try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Job currentAdJob = Job.parse(parser); - if (currentAdJob.isEnabled()) { - listener - .onFailure(new OpenSearchStatusException("Anomaly detector job is already running: " + detectorId, RestStatus.OK)); - return; - } else { - Job newJob = new Job( - job.getName(), - job.getSchedule(), - job.getWindowDelay(), - job.isEnabled(), - Instant.now(), - currentAdJob.getDisabledTime(), - Instant.now(), - job.getLockDurationSeconds(), - job.getUser(), - job.getCustomResultIndex() - ); - // Get latest realtime task and check its state before index job. Will reset running realtime task - // as STOPPED first if job disabled, then start new job and create new realtime task. - adTaskManager.startDetector(detector, null, job.getUser(), transportService, ActionListener.wrap(r -> { - indexAnomalyDetectorJob(newJob, null, listener); - }, e -> { - // Have logged error message in ADTaskManager#startDetector - listener.onFailure(e); - })); - } - } catch (IOException e) { - String message = "Failed to parse anomaly detector job " + job.getName(); - logger.error(message, e); - listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); - } - } else { - adTaskManager.startDetector(detector, null, job.getUser(), transportService, ActionListener.wrap(r -> { - indexAnomalyDetectorJob(job, null, listener); - }, e -> listener.onFailure(e))); - } - } - - private void indexAnomalyDetectorJob(Job job, ExecutorFunction function, ActionListener listener) throws IOException { - IndexRequest indexRequest = new IndexRequest(CommonName.JOB_INDEX) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .source(job.toXContent(XContentFactory.jsonBuilder(), RestHandlerUtils.XCONTENT_WITH_TYPE)) - .setIfSeqNo(seqNo) - .setIfPrimaryTerm(primaryTerm) - .timeout(requestTimeout) - .id(detectorId); - client - .index( - indexRequest, - ActionListener - .wrap( - response -> onIndexAnomalyDetectorJobResponse(response, function, listener), - exception -> listener.onFailure(exception) - ) - ); - } - - private void onIndexAnomalyDetectorJobResponse( - IndexResponse response, - ExecutorFunction function, - ActionListener listener - ) { - if (response == null || (response.getResult() != CREATED && response.getResult() != UPDATED)) { - String errorMsg = getShardsFailure(response); - listener.onFailure(new OpenSearchStatusException(errorMsg, response.status())); - return; - } - if (function != null) { - function.execute(); - } else { - JobResponse anomalyDetectorJobResponse = new JobResponse(response.getId()); - listener.onResponse(anomalyDetectorJobResponse); - } - } - - /** - * Stop anomaly detector job. - * 1.If job not exists, return error message - * 2.If job exists: a).if job state is disabled, return error message; b).if job state is enabled, disable job. - * - * @param detectorId detector identifier - * @param listener Listener to send responses - */ - public void stopAnomalyDetectorJob(String detectorId, ActionListener listener) { - GetRequest getRequest = new GetRequest(CommonName.JOB_INDEX).id(detectorId); - - client.get(getRequest, ActionListener.wrap(response -> { - if (response.isExists()) { - try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Job job = Job.parse(parser); - if (!job.isEnabled()) { - adTaskManager.stopLatestRealtimeTask(detectorId, TaskState.STOPPED, null, transportService, listener); - } else { - Job newJob = new Job( - job.getName(), - job.getSchedule(), - job.getWindowDelay(), - false, - job.getEnabledTime(), - Instant.now(), - Instant.now(), - job.getLockDurationSeconds(), - job.getUser(), - job.getCustomResultIndex() - ); - indexAnomalyDetectorJob( - newJob, - () -> client - .execute( - StopDetectorAction.INSTANCE, - new StopDetectorRequest(detectorId), - stopAdDetectorListener(detectorId, listener) - ), - listener - ); - } - } catch (IOException e) { - String message = "Failed to parse anomaly detector job " + detectorId; - logger.error(message, e); - listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); - } - } else { - listener.onFailure(new OpenSearchStatusException("Anomaly detector job not exist: " + detectorId, RestStatus.BAD_REQUEST)); - } - }, exception -> listener.onFailure(exception))); - } - - private ActionListener stopAdDetectorListener(String detectorId, ActionListener listener) { - return new ActionListener() { - @Override - public void onResponse(StopDetectorResponse stopDetectorResponse) { - if (stopDetectorResponse.success()) { - logger.info("AD model deleted successfully for detector {}", detectorId); - // StopDetectorTransportAction will send out DeleteModelAction which will clear all realtime cache. - // Pass null transport service to method "stopLatestRealtimeTask" to not re-clear coordinating node cache. - adTaskManager.stopLatestRealtimeTask(detectorId, TaskState.STOPPED, null, null, listener); - } else { - logger.error("Failed to delete AD model for detector {}", detectorId); - // If failed to clear all realtime cache, will try to re-clear coordinating node cache. - adTaskManager - .stopLatestRealtimeTask( - detectorId, - TaskState.FAILED, - new OpenSearchStatusException("Failed to delete AD model", RestStatus.INTERNAL_SERVER_ERROR), - transportService, - listener - ); - } - } - - @Override - public void onFailure(Exception e) { - logger.error("Failed to delete AD model for detector " + detectorId, e); - // If failed to clear all realtime cache, will try to re-clear coordinating node cache. - adTaskManager - .stopLatestRealtimeTask( - detectorId, - TaskState.FAILED, - new OpenSearchStatusException("Failed to execute stop detector action", RestStatus.INTERNAL_SERVER_ERROR), - transportService, - listener - ); - } - }; - } - -} diff --git a/src/main/java/org/opensearch/ad/rest/handler/ModelValidationActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/ModelValidationActionHandler.java deleted file mode 100644 index f37a10580..000000000 --- a/src/main/java/org/opensearch/ad/rest/handler/ModelValidationActionHandler.java +++ /dev/null @@ -1,840 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ad.rest.handler; - -import static org.opensearch.ad.settings.AnomalyDetectorSettings.CONFIG_BUCKET_MINIMUM_SUCCESS_RATE; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.INTERVAL_BUCKET_MINIMUM_SUCCESS_RATE; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.INTERVAL_RECOMMENDATION_DECREASING_MULTIPLIER; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.INTERVAL_RECOMMENDATION_INCREASING_MULTIPLIER; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_INTERVAL_REC_LENGTH_IN_MINUTES; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_TIMES_DECREASING_INTERVAL; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.TOP_VALIDATE_TIMEOUT_IN_MILLIS; - -import java.io.IOException; -import java.time.Clock; -import java.time.Duration; -import java.time.Instant; -import java.time.temporal.ChronoUnit; -import java.util.ArrayList; -import java.util.Collections; -import java.util.Comparator; -import java.util.HashMap; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Optional; -import java.util.stream.Collectors; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.OpenSearchStatusException; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.transport.ValidateAnomalyDetectorResponse; -import org.opensearch.client.Client; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.unit.TimeValue; -import org.opensearch.commons.authuser.User; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.rest.RestStatus; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.index.query.BoolQueryBuilder; -import org.opensearch.index.query.QueryBuilder; -import org.opensearch.index.query.QueryBuilders; -import org.opensearch.index.query.RangeQueryBuilder; -import org.opensearch.search.aggregations.AggregationBuilder; -import org.opensearch.search.aggregations.AggregationBuilders; -import org.opensearch.search.aggregations.Aggregations; -import org.opensearch.search.aggregations.BucketOrder; -import org.opensearch.search.aggregations.PipelineAggregatorBuilders; -import org.opensearch.search.aggregations.bucket.MultiBucketsAggregation; -import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation; -import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; -import org.opensearch.search.aggregations.bucket.histogram.DateHistogramInterval; -import org.opensearch.search.aggregations.bucket.histogram.Histogram; -import org.opensearch.search.aggregations.bucket.histogram.LongBounds; -import org.opensearch.search.aggregations.bucket.terms.Terms; -import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.search.sort.FieldSortBuilder; -import org.opensearch.search.sort.SortOrder; -import org.opensearch.timeseries.AnalysisType; -import org.opensearch.timeseries.common.exception.EndRunException; -import org.opensearch.timeseries.common.exception.ValidationException; -import org.opensearch.timeseries.constant.CommonMessages; -import org.opensearch.timeseries.feature.SearchFeatureDao; -import org.opensearch.timeseries.model.Feature; -import org.opensearch.timeseries.model.IntervalTimeConfiguration; -import org.opensearch.timeseries.model.MergeableList; -import org.opensearch.timeseries.model.TimeConfiguration; -import org.opensearch.timeseries.model.ValidationAspect; -import org.opensearch.timeseries.model.ValidationIssueType; -import org.opensearch.timeseries.util.MultiResponsesDelegateActionListener; -import org.opensearch.timeseries.util.ParseUtils; -import org.opensearch.timeseries.util.SecurityClientUtil; - -/** - *

This class executes all validation checks that are not blocking on the 'model' level. - * This mostly involves checking if the data is generally dense enough to complete model training - * which is based on if enough buckets in the last x intervals have at least 1 document present.

- *

Initially different bucket aggregations are executed with with every configuration applied and with - * different varying intervals in order to find the best interval for the data. If no interval is found with all - * configuration applied then each configuration is tested sequentially for sparsity

- */ -// TODO: Add more UT and IT -public class ModelValidationActionHandler { - protected static final String AGG_NAME_TOP = "top_agg"; - protected static final String AGGREGATION = "agg"; - protected final AnomalyDetector anomalyDetector; - protected final ClusterService clusterService; - protected final Logger logger = LogManager.getLogger(AbstractAnomalyDetectorActionHandler.class); - protected final TimeValue requestTimeout; - protected final AnomalyDetectorActionHandler handler = new AnomalyDetectorActionHandler(); - protected final Client client; - protected final SecurityClientUtil clientUtil; - protected final NamedXContentRegistry xContentRegistry; - protected final ActionListener listener; - protected final SearchFeatureDao searchFeatureDao; - protected final Clock clock; - protected final String validationType; - protected final Settings settings; - protected final User user; - - /** - * Constructor function. - * - * @param clusterService ClusterService - * @param client ES node client that executes actions on the local node - * @param clientUtil AD client util - * @param listener ES channel used to construct bytes / builder based outputs, and send responses - * @param anomalyDetector anomaly detector instance - * @param requestTimeout request time out configuration - * @param xContentRegistry Registry which is used for XContentParser - * @param searchFeatureDao Search feature DAO - * @param validationType Specified type for validation - * @param clock clock object to know when to timeout - * @param settings Node settings - * @param user User info - */ - public ModelValidationActionHandler( - ClusterService clusterService, - Client client, - SecurityClientUtil clientUtil, - ActionListener listener, - AnomalyDetector anomalyDetector, - TimeValue requestTimeout, - NamedXContentRegistry xContentRegistry, - SearchFeatureDao searchFeatureDao, - String validationType, - Clock clock, - Settings settings, - User user - ) { - this.clusterService = clusterService; - this.client = client; - this.clientUtil = clientUtil; - this.listener = listener; - this.anomalyDetector = anomalyDetector; - this.requestTimeout = requestTimeout; - this.xContentRegistry = xContentRegistry; - this.searchFeatureDao = searchFeatureDao; - this.validationType = validationType; - this.clock = clock; - this.settings = settings; - this.user = user; - } - - // Need to first check if multi entity detector or not before doing any sort of validation. - // If detector is HCAD then we will find the top entity and treat as single entity for - // validation purposes - public void checkIfMultiEntityDetector() { - ActionListener> recommendationListener = ActionListener - .wrap(topEntity -> getLatestDateForValidation(topEntity), exception -> { - listener.onFailure(exception); - logger.error("Failed to get top entity for categorical field", exception); - }); - if (anomalyDetector.isHighCardinality()) { - getTopEntity(recommendationListener); - } else { - recommendationListener.onResponse(Collections.emptyMap()); - } - } - - // For single category HCAD, this method uses bucket aggregation and sort to get the category field - // that have the highest document count in order to use that top entity for further validation - // For multi-category HCADs we use a composite aggregation to find the top fields for the entity - // with the highest doc count. - private void getTopEntity(ActionListener> topEntityListener) { - // Look at data back to the lower bound given the max interval we recommend or one given - long maxIntervalInMinutes = Math.max(MAX_INTERVAL_REC_LENGTH_IN_MINUTES, anomalyDetector.getIntervalInMinutes()); - LongBounds timeRangeBounds = getTimeRangeBounds( - Instant.now().toEpochMilli(), - new IntervalTimeConfiguration(maxIntervalInMinutes, ChronoUnit.MINUTES) - ); - RangeQueryBuilder rangeQuery = new RangeQueryBuilder(anomalyDetector.getTimeField()) - .from(timeRangeBounds.getMin()) - .to(timeRangeBounds.getMax()); - AggregationBuilder bucketAggs; - Map topKeys = new HashMap<>(); - if (anomalyDetector.getCategoryFields().size() == 1) { - bucketAggs = AggregationBuilders - .terms(AGG_NAME_TOP) - .field(anomalyDetector.getCategoryFields().get(0)) - .order(BucketOrder.count(true)); - } else { - bucketAggs = AggregationBuilders - .composite( - AGG_NAME_TOP, - anomalyDetector - .getCategoryFields() - .stream() - .map(f -> new TermsValuesSourceBuilder(f).field(f)) - .collect(Collectors.toList()) - ) - .size(1000) - .subAggregation( - PipelineAggregatorBuilders - .bucketSort("bucketSort", Collections.singletonList(new FieldSortBuilder("_count").order(SortOrder.DESC))) - .size(1) - ); - } - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() - .query(rangeQuery) - .aggregation(bucketAggs) - .trackTotalHits(false) - .size(0); - SearchRequest searchRequest = new SearchRequest() - .indices(anomalyDetector.getIndices().toArray(new String[0])) - .source(searchSourceBuilder); - final ActionListener searchResponseListener = ActionListener.wrap(response -> { - Aggregations aggs = response.getAggregations(); - if (aggs == null) { - topEntityListener.onResponse(Collections.emptyMap()); - return; - } - if (anomalyDetector.getCategoryFields().size() == 1) { - Terms entities = aggs.get(AGG_NAME_TOP); - Object key = entities - .getBuckets() - .stream() - .max(Comparator.comparingInt(entry -> (int) entry.getDocCount())) - .map(MultiBucketsAggregation.Bucket::getKeyAsString) - .orElse(null); - topKeys.put(anomalyDetector.getCategoryFields().get(0), key); - } else { - CompositeAggregation compositeAgg = aggs.get(AGG_NAME_TOP); - topKeys - .putAll( - compositeAgg - .getBuckets() - .stream() - .flatMap(bucket -> bucket.getKey().entrySet().stream()) // this would create a flattened stream of map entries - .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue())) - ); - } - for (Map.Entry entry : topKeys.entrySet()) { - if (entry.getValue() == null) { - topEntityListener.onResponse(Collections.emptyMap()); - return; - } - } - topEntityListener.onResponse(topKeys); - }, topEntityListener::onFailure); - // using the original context in listener as user roles have no permissions for internal operations like fetching a - // checkpoint - clientUtil - .asyncRequestWithInjectedSecurity( - searchRequest, - client::search, - user, - client, - AnalysisType.AD, - searchResponseListener - ); - } - - private void getLatestDateForValidation(Map topEntity) { - ActionListener> latestTimeListener = ActionListener - .wrap(latest -> getSampleRangesForValidationChecks(latest, anomalyDetector, listener, topEntity), exception -> { - listener.onFailure(exception); - logger.error("Failed to create search request for last data point", exception); - }); - searchFeatureDao.getLatestDataTime(anomalyDetector, latestTimeListener); - } - - private void getSampleRangesForValidationChecks( - Optional latestTime, - AnomalyDetector detector, - ActionListener listener, - Map topEntity - ) { - if (!latestTime.isPresent() || latestTime.get() <= 0) { - listener - .onFailure( - new ValidationException( - CommonMessages.TIME_FIELD_NOT_ENOUGH_HISTORICAL_DATA, - ValidationIssueType.TIMEFIELD_FIELD, - ValidationAspect.MODEL - ) - ); - return; - } - long timeRangeEnd = Math.min(Instant.now().toEpochMilli(), latestTime.get()); - try { - getBucketAggregates(timeRangeEnd, listener, topEntity); - } catch (IOException e) { - listener.onFailure(new EndRunException(detector.getId(), CommonMessages.INVALID_SEARCH_QUERY_MSG, e, true)); - } - } - - private void getBucketAggregates( - long latestTime, - ActionListener listener, - Map topEntity - ) throws IOException { - AggregationBuilder aggregation = getBucketAggregation(latestTime, (IntervalTimeConfiguration) anomalyDetector.getInterval()); - BoolQueryBuilder query = QueryBuilders.boolQuery().filter(anomalyDetector.getFilterQuery()); - if (anomalyDetector.isHighCardinality()) { - if (topEntity.isEmpty()) { - listener - .onFailure( - new ValidationException( - CommonMessages.CATEGORY_FIELD_TOO_SPARSE, - ValidationIssueType.CATEGORY, - ValidationAspect.MODEL - ) - ); - return; - } - for (Map.Entry entry : topEntity.entrySet()) { - query.filter(QueryBuilders.termQuery(entry.getKey(), entry.getValue())); - } - } - - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() - .query(query) - .aggregation(aggregation) - .size(0) - .timeout(requestTimeout); - SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0])).source(searchSourceBuilder); - ActionListener intervalListener = ActionListener - .wrap(interval -> processIntervalRecommendation(interval, latestTime), exception -> { - listener.onFailure(exception); - logger.error("Failed to get interval recommendation", exception); - }); - final ActionListener searchResponseListener = - new ModelValidationActionHandler.DetectorIntervalRecommendationListener( - intervalListener, - searchRequest.source(), - (IntervalTimeConfiguration) anomalyDetector.getInterval(), - clock.millis() + TOP_VALIDATE_TIMEOUT_IN_MILLIS, - latestTime, - false, - MAX_TIMES_DECREASING_INTERVAL - ); - // using the original context in listener as user roles have no permissions for internal operations like fetching a - // checkpoint - clientUtil - .asyncRequestWithInjectedSecurity( - searchRequest, - client::search, - user, - client, - AnalysisType.AD, - searchResponseListener - ); - } - - private double processBucketAggregationResults(Histogram buckets) { - int docCountOverOne = 0; - // For each entry - for (Histogram.Bucket entry : buckets.getBuckets()) { - if (entry.getDocCount() > 0) { - docCountOverOne++; - } - } - return (docCountOverOne / (double) getNumberOfSamples()); - } - - /** - * ActionListener class to handle execution of multiple bucket aggregations one after the other - * Bucket aggregation with different interval lengths are executed one by one to check if the data is dense enough - * We only need to execute the next query if the previous one led to data that is too sparse. - */ - class DetectorIntervalRecommendationListener implements ActionListener { - private final ActionListener intervalListener; - SearchSourceBuilder searchSourceBuilder; - IntervalTimeConfiguration detectorInterval; - private final long expirationEpochMs; - private final long latestTime; - boolean decreasingInterval; - int numTimesDecreasing; // maximum amount of times we will try decreasing interval for recommendation - - DetectorIntervalRecommendationListener( - ActionListener intervalListener, - SearchSourceBuilder searchSourceBuilder, - IntervalTimeConfiguration detectorInterval, - long expirationEpochMs, - long latestTime, - boolean decreasingInterval, - int numTimesDecreasing - ) { - this.intervalListener = intervalListener; - this.searchSourceBuilder = searchSourceBuilder; - this.detectorInterval = detectorInterval; - this.expirationEpochMs = expirationEpochMs; - this.latestTime = latestTime; - this.decreasingInterval = decreasingInterval; - this.numTimesDecreasing = numTimesDecreasing; - } - - @Override - public void onResponse(SearchResponse response) { - try { - Histogram aggregate = checkBucketResultErrors(response); - if (aggregate == null) { - return; - } - - long newIntervalMinute; - if (decreasingInterval) { - newIntervalMinute = (long) Math - .floor( - IntervalTimeConfiguration.getIntervalInMinute(detectorInterval) * INTERVAL_RECOMMENDATION_DECREASING_MULTIPLIER - ); - } else { - newIntervalMinute = (long) Math - .ceil( - IntervalTimeConfiguration.getIntervalInMinute(detectorInterval) * INTERVAL_RECOMMENDATION_INCREASING_MULTIPLIER - ); - } - double fullBucketRate = processBucketAggregationResults(aggregate); - // If rate is above success minimum then return interval suggestion. - if (fullBucketRate > INTERVAL_BUCKET_MINIMUM_SUCCESS_RATE) { - intervalListener.onResponse(this.detectorInterval); - } else if (expirationEpochMs < clock.millis()) { - listener - .onFailure( - new ValidationException( - CommonMessages.TIMEOUT_ON_INTERVAL_REC, - ValidationIssueType.TIMEOUT, - ValidationAspect.MODEL - ) - ); - logger.info(CommonMessages.TIMEOUT_ON_INTERVAL_REC); - // keep trying higher intervals as new interval is below max, and we aren't decreasing yet - } else if (newIntervalMinute < MAX_INTERVAL_REC_LENGTH_IN_MINUTES && !decreasingInterval) { - searchWithDifferentInterval(newIntervalMinute); - // The below block is executed only the first time when new interval is above max and - // we aren't decreasing yet, at this point we will start decreasing for the first time - // if we are inside the below block - } else if (newIntervalMinute >= MAX_INTERVAL_REC_LENGTH_IN_MINUTES && !decreasingInterval) { - IntervalTimeConfiguration givenInterval = (IntervalTimeConfiguration) anomalyDetector.getInterval(); - this.detectorInterval = new IntervalTimeConfiguration( - (long) Math - .floor( - IntervalTimeConfiguration.getIntervalInMinute(givenInterval) * INTERVAL_RECOMMENDATION_DECREASING_MULTIPLIER - ), - ChronoUnit.MINUTES - ); - if (detectorInterval.getInterval() <= 0) { - intervalListener.onResponse(null); - return; - } - this.decreasingInterval = true; - this.numTimesDecreasing -= 1; - // Searching again using an updated interval - SearchSourceBuilder updatedSearchSourceBuilder = getSearchSourceBuilder( - searchSourceBuilder.query(), - getBucketAggregation(this.latestTime, new IntervalTimeConfiguration(newIntervalMinute, ChronoUnit.MINUTES)) - ); - // using the original context in listener as user roles have no permissions for internal operations like fetching a - // checkpoint - clientUtil - .asyncRequestWithInjectedSecurity( - new SearchRequest() - .indices(anomalyDetector.getIndices().toArray(new String[0])) - .source(updatedSearchSourceBuilder), - client::search, - user, - client, - AnalysisType.AD, - this - ); - // In this case decreasingInterval has to be true already, so we will stop - // when the next new interval is below or equal to 0, or we have decreased up to max times - } else if (numTimesDecreasing >= 0 && newIntervalMinute > 0) { - this.numTimesDecreasing -= 1; - searchWithDifferentInterval(newIntervalMinute); - // this case means all intervals up to max interval recommendation length and down to either - // 0 or until we tried 10 lower intervals than the one given have been tried - // which further means the next step is to go through A/B validation checks - } else { - intervalListener.onResponse(null); - } - - } catch (Exception e) { - onFailure(e); - } - } - - private void searchWithDifferentInterval(long newIntervalMinuteValue) { - this.detectorInterval = new IntervalTimeConfiguration(newIntervalMinuteValue, ChronoUnit.MINUTES); - // Searching again using an updated interval - SearchSourceBuilder updatedSearchSourceBuilder = getSearchSourceBuilder( - searchSourceBuilder.query(), - getBucketAggregation(this.latestTime, new IntervalTimeConfiguration(newIntervalMinuteValue, ChronoUnit.MINUTES)) - ); - // using the original context in listener as user roles have no permissions for internal operations like fetching a - // checkpoint - clientUtil - .asyncRequestWithInjectedSecurity( - new SearchRequest().indices(anomalyDetector.getIndices().toArray(new String[0])).source(updatedSearchSourceBuilder), - client::search, - user, - client, - AnalysisType.AD, - this - ); - } - - @Override - public void onFailure(Exception e) { - logger.error("Failed to recommend new interval", e); - listener - .onFailure( - new ValidationException( - CommonMessages.MODEL_VALIDATION_FAILED_UNEXPECTEDLY, - ValidationIssueType.AGGREGATION, - ValidationAspect.MODEL - ) - ); - } - } - - private void processIntervalRecommendation(IntervalTimeConfiguration interval, long latestTime) { - // if interval suggestion is null that means no interval could be found with all the configurations - // applied, our next step then is to check density just with the raw data and then add each configuration - // one at a time to try and find root cause of low density - if (interval == null) { - checkRawDataSparsity(latestTime); - } else { - if (interval.equals(anomalyDetector.getInterval())) { - logger.info("Using the current interval there is enough dense data "); - // Check if there is a window delay recommendation if everything else is successful and send exception - if (Instant.now().toEpochMilli() - latestTime > timeConfigToMilliSec(anomalyDetector.getWindowDelay())) { - sendWindowDelayRec(latestTime); - return; - } - // The rate of buckets with at least 1 doc with given interval is above the success rate - listener.onResponse(null); - return; - } - // return response with interval recommendation - listener - .onFailure( - new ValidationException( - CommonMessages.INTERVAL_REC + interval.getInterval(), - ValidationIssueType.DETECTION_INTERVAL, - ValidationAspect.MODEL, - interval - ) - ); - } - } - - private AggregationBuilder getBucketAggregation(long latestTime, IntervalTimeConfiguration detectorInterval) { - return AggregationBuilders - .dateHistogram(AGGREGATION) - .field(anomalyDetector.getTimeField()) - .minDocCount(1) - .hardBounds(getTimeRangeBounds(latestTime, detectorInterval)) - .fixedInterval(DateHistogramInterval.minutes((int) IntervalTimeConfiguration.getIntervalInMinute(detectorInterval))); - } - - private SearchSourceBuilder getSearchSourceBuilder(QueryBuilder query, AggregationBuilder aggregation) { - return new SearchSourceBuilder().query(query).aggregation(aggregation).size(0).timeout(requestTimeout); - } - - private void checkRawDataSparsity(long latestTime) { - AggregationBuilder aggregation = getBucketAggregation(latestTime, (IntervalTimeConfiguration) anomalyDetector.getInterval()); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().aggregation(aggregation).size(0).timeout(requestTimeout); - SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0])).source(searchSourceBuilder); - final ActionListener searchResponseListener = ActionListener - .wrap(response -> processRawDataResults(response, latestTime), listener::onFailure); - // using the original context in listener as user roles have no permissions for internal operations like fetching a - // checkpoint - clientUtil - .asyncRequestWithInjectedSecurity( - searchRequest, - client::search, - user, - client, - AnalysisType.AD, - searchResponseListener - ); - } - - private Histogram checkBucketResultErrors(SearchResponse response) { - Aggregations aggs = response.getAggregations(); - if (aggs == null) { - // This would indicate some bug or some opensearch core changes that we are not aware of (we don't keep up-to-date with - // the large amounts of changes there). For this reason I'm not throwing a SearchException but instead a validation exception - // which will be converted to validation response. - logger.warn("Unexpected null aggregation."); - listener - .onFailure( - new ValidationException( - CommonMessages.MODEL_VALIDATION_FAILED_UNEXPECTEDLY, - ValidationIssueType.AGGREGATION, - ValidationAspect.MODEL - ) - ); - return null; - } - Histogram aggregate = aggs.get(AGGREGATION); - if (aggregate == null) { - listener.onFailure(new IllegalArgumentException("Failed to find valid aggregation result")); - return null; - } - return aggregate; - } - - private void processRawDataResults(SearchResponse response, long latestTime) { - Histogram aggregate = checkBucketResultErrors(response); - if (aggregate == null) { - return; - } - double fullBucketRate = processBucketAggregationResults(aggregate); - if (fullBucketRate < INTERVAL_BUCKET_MINIMUM_SUCCESS_RATE) { - listener - .onFailure( - new ValidationException(CommonMessages.RAW_DATA_TOO_SPARSE, ValidationIssueType.INDICES, ValidationAspect.MODEL) - ); - } else { - checkDataFilterSparsity(latestTime); - } - } - - private void checkDataFilterSparsity(long latestTime) { - AggregationBuilder aggregation = getBucketAggregation(latestTime, (IntervalTimeConfiguration) anomalyDetector.getInterval()); - BoolQueryBuilder query = QueryBuilders.boolQuery().filter(anomalyDetector.getFilterQuery()); - SearchSourceBuilder searchSourceBuilder = getSearchSourceBuilder(query, aggregation); - SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0])).source(searchSourceBuilder); - final ActionListener searchResponseListener = ActionListener - .wrap(response -> processDataFilterResults(response, latestTime), listener::onFailure); - // using the original context in listener as user roles have no permissions for internal operations like fetching a - // checkpoint - clientUtil - .asyncRequestWithInjectedSecurity( - searchRequest, - client::search, - user, - client, - AnalysisType.AD, - searchResponseListener - ); - } - - private void processDataFilterResults(SearchResponse response, long latestTime) { - Histogram aggregate = checkBucketResultErrors(response); - if (aggregate == null) { - return; - } - double fullBucketRate = processBucketAggregationResults(aggregate); - if (fullBucketRate < CONFIG_BUCKET_MINIMUM_SUCCESS_RATE) { - listener - .onFailure( - new ValidationException( - CommonMessages.FILTER_QUERY_TOO_SPARSE, - ValidationIssueType.FILTER_QUERY, - ValidationAspect.MODEL - ) - ); - // blocks below are executed if data is dense enough with filter query applied. - // If HCAD then category fields will be added to bucket aggregation to see if they - // are the root cause of the issues and if not the feature queries will be checked for sparsity - } else if (anomalyDetector.isHighCardinality()) { - getTopEntityForCategoryField(latestTime); - } else { - try { - checkFeatureQueryDelegate(latestTime); - } catch (Exception ex) { - logger.error(ex); - listener.onFailure(ex); - } - } - } - - private void getTopEntityForCategoryField(long latestTime) { - ActionListener> getTopEntityListener = ActionListener - .wrap(topEntity -> checkCategoryFieldSparsity(topEntity, latestTime), exception -> { - listener.onFailure(exception); - logger.error("Failed to get top entity for categorical field", exception); - return; - }); - getTopEntity(getTopEntityListener); - } - - private void checkCategoryFieldSparsity(Map topEntity, long latestTime) { - BoolQueryBuilder query = QueryBuilders.boolQuery().filter(anomalyDetector.getFilterQuery()); - for (Map.Entry entry : topEntity.entrySet()) { - query.filter(QueryBuilders.termQuery(entry.getKey(), entry.getValue())); - } - AggregationBuilder aggregation = getBucketAggregation(latestTime, (IntervalTimeConfiguration) anomalyDetector.getInterval()); - SearchSourceBuilder searchSourceBuilder = getSearchSourceBuilder(query, aggregation); - SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0])).source(searchSourceBuilder); - final ActionListener searchResponseListener = ActionListener - .wrap(response -> processTopEntityResults(response, latestTime), listener::onFailure); - // using the original context in listener as user roles have no permissions for internal operations like fetching a - // checkpoint - clientUtil - .asyncRequestWithInjectedSecurity( - searchRequest, - client::search, - user, - client, - AnalysisType.AD, - searchResponseListener - ); - } - - private void processTopEntityResults(SearchResponse response, long latestTime) { - Histogram aggregate = checkBucketResultErrors(response); - if (aggregate == null) { - return; - } - double fullBucketRate = processBucketAggregationResults(aggregate); - if (fullBucketRate < CONFIG_BUCKET_MINIMUM_SUCCESS_RATE) { - listener - .onFailure( - new ValidationException(CommonMessages.CATEGORY_FIELD_TOO_SPARSE, ValidationIssueType.CATEGORY, ValidationAspect.MODEL) - ); - } else { - try { - checkFeatureQueryDelegate(latestTime); - } catch (Exception ex) { - logger.error(ex); - listener.onFailure(ex); - } - } - } - - private void checkFeatureQueryDelegate(long latestTime) throws IOException { - ActionListener> validateFeatureQueriesListener = ActionListener.wrap(response -> { - windowDelayRecommendation(latestTime); - }, exception -> { - listener - .onFailure(new ValidationException(exception.getMessage(), ValidationIssueType.FEATURE_ATTRIBUTES, ValidationAspect.MODEL)); - }); - MultiResponsesDelegateActionListener> multiFeatureQueriesResponseListener = - new MultiResponsesDelegateActionListener<>( - validateFeatureQueriesListener, - anomalyDetector.getFeatureAttributes().size(), - CommonMessages.FEATURE_QUERY_TOO_SPARSE, - false - ); - - for (Feature feature : anomalyDetector.getFeatureAttributes()) { - AggregationBuilder aggregation = getBucketAggregation(latestTime, (IntervalTimeConfiguration) anomalyDetector.getInterval()); - BoolQueryBuilder query = QueryBuilders.boolQuery().filter(anomalyDetector.getFilterQuery()); - List featureFields = ParseUtils.getFieldNamesForFeature(feature, xContentRegistry); - for (String featureField : featureFields) { - query.filter(QueryBuilders.existsQuery(featureField)); - } - SearchSourceBuilder searchSourceBuilder = getSearchSourceBuilder(query, aggregation); - SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0])) - .source(searchSourceBuilder); - final ActionListener searchResponseListener = ActionListener.wrap(response -> { - Histogram aggregate = checkBucketResultErrors(response); - if (aggregate == null) { - return; - } - double fullBucketRate = processBucketAggregationResults(aggregate); - if (fullBucketRate < CONFIG_BUCKET_MINIMUM_SUCCESS_RATE) { - multiFeatureQueriesResponseListener - .onFailure( - new ValidationException( - CommonMessages.FEATURE_QUERY_TOO_SPARSE, - ValidationIssueType.FEATURE_ATTRIBUTES, - ValidationAspect.MODEL - ) - ); - } else { - multiFeatureQueriesResponseListener - .onResponse(new MergeableList<>(new ArrayList<>(Collections.singletonList(new double[] { fullBucketRate })))); - } - }, e -> { - logger.error(e); - multiFeatureQueriesResponseListener - .onFailure(new OpenSearchStatusException(CommonMessages.FEATURE_QUERY_TOO_SPARSE, RestStatus.BAD_REQUEST, e)); - }); - // using the original context in listener as user roles have no permissions for internal operations like fetching a - // checkpoint - clientUtil - .asyncRequestWithInjectedSecurity( - searchRequest, - client::search, - user, - client, - AnalysisType.AD, - searchResponseListener - ); - } - } - - private void sendWindowDelayRec(long latestTimeInMillis) { - long minutesSinceLastStamp = (long) Math.ceil((Instant.now().toEpochMilli() - latestTimeInMillis) / 60000.0); - listener - .onFailure( - new ValidationException( - String.format(Locale.ROOT, CommonMessages.WINDOW_DELAY_REC, minutesSinceLastStamp, minutesSinceLastStamp), - ValidationIssueType.WINDOW_DELAY, - ValidationAspect.MODEL, - new IntervalTimeConfiguration(minutesSinceLastStamp, ChronoUnit.MINUTES) - ) - ); - } - - private void windowDelayRecommendation(long latestTime) { - // Check if there is a better window-delay to recommend and if one was recommended - // then send exception and return, otherwise continue to let user know data is too sparse as explained below - if (Instant.now().toEpochMilli() - latestTime > timeConfigToMilliSec(anomalyDetector.getWindowDelay())) { - sendWindowDelayRec(latestTime); - return; - } - // This case has been reached if following conditions are met: - // 1. no interval recommendation was found that leads to a bucket success rate of >= 0.75 - // 2. bucket success rate with the given interval and just raw data is also below 0.75. - // 3. no single configuration during the following checks reduced the bucket success rate below 0.25 - // This means the rate with all configs applied or just raw data was below 0.75 but the rate when checking each configuration at - // a time was always above 0.25 meaning the best suggestion is to simply ingest more data or change interval since - // we have no more insight regarding the root cause of the lower density. - listener - .onFailure(new ValidationException(CommonMessages.RAW_DATA_TOO_SPARSE, ValidationIssueType.INDICES, ValidationAspect.MODEL)); - } - - private LongBounds getTimeRangeBounds(long endMillis, IntervalTimeConfiguration detectorIntervalInMinutes) { - Long detectorInterval = timeConfigToMilliSec(detectorIntervalInMinutes); - Long startMillis = endMillis - (getNumberOfSamples() * detectorInterval); - return new LongBounds(startMillis, endMillis); - } - - private int getNumberOfSamples() { - long interval = anomalyDetector.getIntervalInMilliseconds(); - return Math - .max( - (int) (Duration.ofHours(AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS).toMillis() / interval), - AnomalyDetectorSettings.MIN_TRAIN_SAMPLES - ); - } - - private Long timeConfigToMilliSec(TimeConfiguration config) { - return Optional.ofNullable((IntervalTimeConfiguration) config).map(t -> t.toDuration().toMillis()).orElse(0L); - } -} diff --git a/src/main/java/org/opensearch/ad/rest/handler/ValidateAnomalyDetectorActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/ValidateAnomalyDetectorActionHandler.java index 3c0b13c5e..cf52a2237 100644 --- a/src/main/java/org/opensearch/ad/rest/handler/ValidateAnomalyDetectorActionHandler.java +++ b/src/main/java/org/opensearch/ad/rest/handler/ValidateAnomalyDetectorActionHandler.java @@ -14,24 +14,23 @@ import java.time.Clock; import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.transport.ValidateAnomalyDetectorResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.commons.authuser.User; -import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.rest.RestRequest; import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.transport.ValidateConfigResponse; import org.opensearch.timeseries.util.SecurityClientUtil; /** * Anomaly detector REST action handler to process POST request. * POST request is for validating anomaly detector against detector and/or model configs. */ -public class ValidateAnomalyDetectorActionHandler extends AbstractAnomalyDetectorActionHandler { +public class ValidateAnomalyDetectorActionHandler extends AbstractAnomalyDetectorActionHandler { /** * Constructor function. @@ -39,13 +38,13 @@ public class ValidateAnomalyDetectorActionHandler extends AbstractAnomalyDetecto * @param clusterService ClusterService * @param client ES node client that executes actions on the local node * @param clientUtil AD client utility - * @param listener ES channel used to construct bytes / builder based outputs, and send responses * @param anomalyDetectionIndices anomaly detector index manager * @param anomalyDetector anomaly detector instance * @param requestTimeout request time out configuration * @param maxSingleEntityAnomalyDetectors max single-entity anomaly detectors allowed * @param maxMultiEntityAnomalyDetectors max multi-entity detectors allowed * @param maxAnomalyFeatures max features allowed per detector + * @param maxCategoricalFields max number of categorical fields * @param method Rest Method type * @param xContentRegistry Registry which is used for XContentParser * @param user User context @@ -58,13 +57,13 @@ public ValidateAnomalyDetectorActionHandler( ClusterService clusterService, Client client, SecurityClientUtil clientUtil, - ActionListener listener, ADIndexManagement anomalyDetectionIndices, - AnomalyDetector anomalyDetector, + Config anomalyDetector, TimeValue requestTimeout, Integer maxSingleEntityAnomalyDetectors, Integer maxMultiEntityAnomalyDetectors, Integer maxAnomalyFeatures, + Integer maxCategoricalFields, RestRequest.Method method, NamedXContentRegistry xContentRegistry, User user, @@ -78,9 +77,8 @@ public ValidateAnomalyDetectorActionHandler( client, clientUtil, null, - listener, anomalyDetectionIndices, - AnomalyDetector.NO_ID, + Config.NO_ID, null, null, null, @@ -89,6 +87,7 @@ public ValidateAnomalyDetectorActionHandler( maxSingleEntityAnomalyDetectors, maxMultiEntityAnomalyDetectors, maxAnomalyFeatures, + maxCategoricalFields, method, xContentRegistry, user, @@ -100,16 +99,4 @@ public ValidateAnomalyDetectorActionHandler( settings ); } - - // If validation type is detector then all validation in AbstractAnomalyDetectorActionHandler that is called - // by super.start() involves validation checks against the detector configurations, - // any issues raised here would block user from creating the anomaly detector. - // If validation Aspect is of type model then further non-blocker validation will be executed - // after the blocker validation is executed. Any issues that are raised for model validation - // are simply warnings for the user in terms of how configuration could be changed to lead to - // a higher likelihood of model training completing successfully - @Override - public void start() { - super.start(); - } } diff --git a/src/main/java/org/opensearch/ad/settings/ADEnabledSetting.java b/src/main/java/org/opensearch/ad/settings/ADEnabledSetting.java index ed4414f6c..0de968afb 100644 --- a/src/main/java/org/opensearch/ad/settings/ADEnabledSetting.java +++ b/src/main/java/org/opensearch/ad/settings/ADEnabledSetting.java @@ -1,12 +1,6 @@ /* + * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. */ package org.opensearch.ad.settings; @@ -34,12 +28,16 @@ public class ADEnabledSetting extends DynamicNumericSetting { */ public static final String AD_ENABLED = "plugins.anomaly_detection.enabled"; + // use TimeSeriesEnabledSetting.BREAKER_ENABLED instread + @Deprecated public static final String AD_BREAKER_ENABLED = "plugins.anomaly_detection.breaker.enabled"; public static final String LEGACY_OPENDISTRO_AD_ENABLED = "opendistro.anomaly_detection.enabled"; public static final String LEGACY_OPENDISTRO_AD_BREAKER_ENABLED = "opendistro.anomaly_detection.breaker.enabled"; + // we don't support interpolation during cold start starting 3.0 (TODO replace with the right version) + @Deprecated public static final String INTERPOLATION_IN_HCAD_COLD_START_ENABLED = "plugins.anomaly_detection.hcad_cold_start_interpolation.enabled"; public static final String DOOR_KEEPER_IN_CACHE_ENABLED = "plugins.anomaly_detection.door_keeper_in_cache.enabled"; @@ -82,7 +80,7 @@ public class ADEnabledSetting extends DynamicNumericSetting { * filter out unpopular items that are not likely to appear more * than once. Whether this bloom filter is enabled or not. */ - put(DOOR_KEEPER_IN_CACHE_ENABLED, Setting.boolSetting(DOOR_KEEPER_IN_CACHE_ENABLED, false, NodeScope, Dynamic)); + put(DOOR_KEEPER_IN_CACHE_ENABLED, Setting.boolSetting(DOOR_KEEPER_IN_CACHE_ENABLED, true, NodeScope, Dynamic)); } }); @@ -105,14 +103,6 @@ public static boolean isADEnabled() { return ADEnabledSetting.getInstance().getSettingValue(ADEnabledSetting.AD_ENABLED); } - /** - * Whether AD circuit breaker is enabled or not. If disabled, an open circuit breaker wouldn't cause an AD job to be stopped. - * @return whether AD circuit breaker is enabled or not. - */ - public static boolean isADBreakerEnabled() { - return ADEnabledSetting.getInstance().getSettingValue(ADEnabledSetting.AD_BREAKER_ENABLED); - } - /** * If enabled, we use samples plus interpolation to train models. * @return wWhether interpolation in HCAD cold start is enabled or not. diff --git a/src/main/java/org/opensearch/ad/settings/ADNumericSetting.java b/src/main/java/org/opensearch/ad/settings/ADNumericSetting.java index e064867a0..869cdf412 100644 --- a/src/main/java/org/opensearch/ad/settings/ADNumericSetting.java +++ b/src/main/java/org/opensearch/ad/settings/ADNumericSetting.java @@ -1,12 +1,6 @@ /* + * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. */ package org.opensearch.ad.settings; diff --git a/src/main/java/org/opensearch/ad/settings/AnomalyDetectorSettings.java b/src/main/java/org/opensearch/ad/settings/AnomalyDetectorSettings.java index b5f10b383..3e732b374 100644 --- a/src/main/java/org/opensearch/ad/settings/AnomalyDetectorSettings.java +++ b/src/main/java/org/opensearch/ad/settings/AnomalyDetectorSettings.java @@ -115,7 +115,7 @@ private AnomalyDetectorSettings() {} /** * @deprecated This setting is deprecated because we need to manage fault tolerance for * multiple analysis such as AD and forecasting. - * Use TimeSeriesSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE instead. + * Use TimeSeriesSettings#MAX_RETRY_FOR_UNRESPONSIVE_NODE instead. */ @Deprecated public static final Setting AD_MAX_RETRY_FOR_UNRESPONSIVE_NODE = Setting @@ -130,7 +130,7 @@ private AnomalyDetectorSettings() {} /** * @deprecated This setting is deprecated because we need to manage fault tolerance for * multiple analysis such as AD and forecasting. - * Use TimeSeriesSettings.COOLDOWN_MINUTES instead. + * Use {@link TimeSeriesSettings#COOLDOWN_MINUTES} instead. */ @Deprecated public static final Setting AD_COOLDOWN_MINUTES = Setting @@ -144,7 +144,7 @@ private AnomalyDetectorSettings() {} /** * @deprecated This setting is deprecated because we need to manage fault tolerance for * multiple analysis such as AD and forecasting. - * Use TimeSeriesSettings.BACKOFF_MINUTES instead. + * Use {@link TimeSeriesSettings#BACKOFF_MINUTES} instead. */ @Deprecated public static final Setting AD_BACKOFF_MINUTES = Setting @@ -238,10 +238,6 @@ private AnomalyDetectorSettings() {} public static final int MAX_SAMPLE_STRIDE = 64; - public static final int TRAIN_SAMPLE_TIME_RANGE_IN_HOURS = 24; - - public static final int MIN_TRAIN_SAMPLES = 512; - public static final int MAX_IMPUTATION_NEIGHBOR_DISTANCE = 2; // shingling @@ -592,37 +588,6 @@ private AnomalyDetectorSettings() {} Setting.Property.Dynamic ); - /** - * EntityRequest has entityName (# category fields * 256, the recommended limit - * of a keyword field length), model Id (roughly 256 bytes), and QueuedRequest - * fields including detector Id(roughly 128 bytes), expirationEpochMs (long, - * 8 bytes), and priority (12 bytes). - * Plus Java object size (12 bytes), we have roughly 928 bytes per request - * assuming we have 2 categorical fields (plan to support 2 categorical fields now). - * We don't want the total size exceeds 0.1% of the heap. - * We can have at most 0.1% heap / 928 = heap / 928,000. - * For t3.small, 0.1% heap is of 1MB. The queue's size is up to - * 10^ 6 / 928 = 1078 - */ - // to be replaced by TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES - @Deprecated - public static int ENTITY_REQUEST_SIZE_IN_BYTES = 928; - - /** - * EntityFeatureRequest consists of EntityRequest (928 bytes, read comments - * of ENTITY_COLD_START_QUEUE_SIZE_CONSTANT), pointer to current feature - * (8 bytes), and dataStartTimeMillis (8 bytes). We have roughly - * 928 + 16 = 944 bytes per request. - * - * We don't want the total size exceeds 0.1% of the heap. - * We should have at most 0.1% heap / 944 = heap / 944,000 - * For t3.small, 0.1% heap is of 1MB. The queue's size is up to - * 10^ 6 / 944 = 1059 - */ - // to be replaced by TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES - @Deprecated - public static int ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES = 944; - // ====================================== // pagination setting // ====================================== @@ -701,14 +666,6 @@ private AnomalyDetectorSettings() {} Setting.Property.Dynamic ); - // ====================================== - // Validate Detector API setting - // ====================================== - public static final long TOP_VALIDATE_TIMEOUT_IN_MILLIS = 10_000; - public static final long MAX_INTERVAL_REC_LENGTH_IN_MINUTES = 60L; - public static final double INTERVAL_RECOMMENDATION_INCREASING_MULTIPLIER = 1.2; - public static final double INTERVAL_RECOMMENDATION_DECREASING_MULTIPLIER = 0.8; - public static final double INTERVAL_BUCKET_MINIMUM_SUCCESS_RATE = 0.75; public static final double CONFIG_BUCKET_MINIMUM_SUCCESS_RATE = 0.25; // This value is set to decrease the number of times we decrease the interval when recommending a new one // The reason we need a max is because user could give an arbitrarly large interval where we don't know even diff --git a/src/main/java/org/opensearch/ad/stats/ADStats.java b/src/main/java/org/opensearch/ad/stats/ADStats.java index 1fb0e8fe4..433b8b0aa 100644 --- a/src/main/java/org/opensearch/ad/stats/ADStats.java +++ b/src/main/java/org/opensearch/ad/stats/ADStats.java @@ -1,84 +1,19 @@ /* + * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. */ package org.opensearch.ad.stats; -import java.util.HashMap; import java.util.Map; -/** - * This class is the main entry-point for access to the stats that the AD plugin keeps track of. - */ -public class ADStats { - - private Map> stats; - - /** - * Constructor - * - * @param stats Map of the stats that are to be kept - */ - public ADStats(Map> stats) { - this.stats = stats; - } +import org.opensearch.timeseries.stats.Stats; +import org.opensearch.timeseries.stats.TimeSeriesStat; - /** - * Get the stats - * - * @return all of the stats - */ - public Map> getStats() { - return stats; - } +public class ADStats extends Stats { - /** - * Get individual stat by stat name - * - * @param key Name of stat - * @return ADStat - * @throws IllegalArgumentException thrown on illegal statName - */ - public ADStat getStat(String key) throws IllegalArgumentException { - if (!stats.keySet().contains(key)) { - throw new IllegalArgumentException("Stat=\"" + key + "\" does not exist"); - } - return stats.get(key); + public ADStats(Map> stats) { + super(stats); } - /** - * Get a map of the stats that are kept at the node level - * - * @return Map of stats kept at the node level - */ - public Map> getNodeStats() { - return getClusterOrNodeStats(false); - } - - /** - * Get a map of the stats that are kept at the cluster level - * - * @return Map of stats kept at the cluster level - */ - public Map> getClusterStats() { - return getClusterOrNodeStats(true); - } - - private Map> getClusterOrNodeStats(Boolean getClusterStats) { - Map> statsMap = new HashMap<>(); - - for (Map.Entry> entry : stats.entrySet()) { - if (entry.getValue().isClusterLevel() == getClusterStats) { - statsMap.put(entry.getKey(), entry.getValue()); - } - } - return statsMap; - } } diff --git a/src/main/java/org/opensearch/ad/stats/suppliers/ADModelsOnNodeCountSupplier.java b/src/main/java/org/opensearch/ad/stats/suppliers/ADModelsOnNodeCountSupplier.java new file mode 100644 index 000000000..48cc36ebb --- /dev/null +++ b/src/main/java/org/opensearch/ad/stats/suppliers/ADModelsOnNodeCountSupplier.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.stats.suppliers; + +import java.util.function.Supplier; +import java.util.stream.Stream; + +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.ml.ADModelManager; + +/** + * ModelsOnNodeCountSupplier provides the number of models a node contains + */ +public class ADModelsOnNodeCountSupplier implements Supplier { + private ADModelManager modelManager; + private ADCacheProvider adCache; + + /** + * Constructor + * + * @param modelManager object that manages the model partitions hosted on the node + * @param adCache object that manages multi-entity detectors' models + */ + public ADModelsOnNodeCountSupplier(ADModelManager modelManager, ADCacheProvider adCache) { + this.modelManager = modelManager; + this.adCache = adCache; + } + + @Override + public Long get() { + return Stream.concat(modelManager.getAllModels().stream(), adCache.get().getAllModels().stream()).count(); + } +} diff --git a/src/main/java/org/opensearch/ad/stats/suppliers/ADModelsOnNodeSupplier.java b/src/main/java/org/opensearch/ad/stats/suppliers/ADModelsOnNodeSupplier.java new file mode 100644 index 000000000..26b1cb8d5 --- /dev/null +++ b/src/main/java/org/opensearch/ad/stats/suppliers/ADModelsOnNodeSupplier.java @@ -0,0 +1,82 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.stats.suppliers; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_MODEL_SIZE_PER_NODE; +import static org.opensearch.timeseries.ml.ModelState.LAST_CHECKPOINT_TIME_KEY; +import static org.opensearch.timeseries.ml.ModelState.LAST_USED_TIME_KEY; +import static org.opensearch.timeseries.ml.ModelState.MODEL_TYPE_KEY; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.timeseries.constant.CommonName; + +public class ADModelsOnNodeSupplier implements Supplier>> { + private ADModelManager modelManager; + private ADCacheProvider adCache; + // the max number of models to return per node. Defaults to 100. + private volatile int adNumModelsToReturn; + + /** + * Set that contains the model stats that should be exposed. + */ + public static Set MODEL_STATE_STAT_KEYS = new HashSet<>( + Arrays + .asList( + CommonName.MODEL_ID_FIELD, + ADCommonName.DETECTOR_ID_KEY, + MODEL_TYPE_KEY, + CommonName.ENTITY_KEY, + LAST_USED_TIME_KEY, + LAST_CHECKPOINT_TIME_KEY + ) + ); + + /** + * Constructor + * + * @param modelManager object that manages the model partitions hosted on the node + * @param adCache object that manages multi-entity detectors' models + * @param settings node settings accessor + * @param clusterService Cluster service accessor + */ + public ADModelsOnNodeSupplier(ADModelManager modelManager, ADCacheProvider adCache, Settings settings, ClusterService clusterService) { + this.modelManager = modelManager; + this.adCache = adCache; + this.adNumModelsToReturn = AD_MAX_MODEL_SIZE_PER_NODE.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_MAX_MODEL_SIZE_PER_NODE, it -> this.adNumModelsToReturn = it); + + } + + @Override + public List> get() { + Stream> adStream = Stream + .concat(modelManager.getAllModels().stream(), adCache.get().getAllModels().stream()) + .limit(adNumModelsToReturn) + .map( + modelState -> modelState + .getModelStateAsMap() + .entrySet() + .stream() + .filter(entry -> MODEL_STATE_STAT_KEYS.contains(entry.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)) + ); + + return adStream.collect(Collectors.toList()); + } +} diff --git a/src/main/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeCountSupplier.java b/src/main/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeCountSupplier.java deleted file mode 100644 index 8fdac74d7..000000000 --- a/src/main/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeCountSupplier.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.stats.suppliers; - -import java.util.function.Supplier; -import java.util.stream.Stream; - -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.ml.ModelManager; - -/** - * ModelsOnNodeCountSupplier provides the number of models a node contains - */ -public class ModelsOnNodeCountSupplier implements Supplier { - private ModelManager modelManager; - private CacheProvider cache; - - /** - * Constructor - * - * @param modelManager object that manages the model partitions hosted on the node - * @param cache object that manages multi-entity detectors' models - */ - public ModelsOnNodeCountSupplier(ModelManager modelManager, CacheProvider cache) { - this.modelManager = modelManager; - this.cache = cache; - } - - @Override - public Long get() { - return Stream.concat(modelManager.getAllModels().stream(), cache.get().getAllModels().stream()).count(); - } -} diff --git a/src/main/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplier.java b/src/main/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplier.java deleted file mode 100644 index 2cdee5fb8..000000000 --- a/src/main/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplier.java +++ /dev/null @@ -1,95 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.stats.suppliers; - -import static org.opensearch.ad.ml.ModelState.LAST_CHECKPOINT_TIME_KEY; -import static org.opensearch.ad.ml.ModelState.LAST_USED_TIME_KEY; -import static org.opensearch.ad.ml.ModelState.MODEL_TYPE_KEY; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_MODEL_SIZE_PER_NODE; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.function.Supplier; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.Settings; -import org.opensearch.timeseries.constant.CommonName; - -/** - * ModelsOnNodeSupplier provides a List of ModelStates info for the models the nodes contains - */ -public class ModelsOnNodeSupplier implements Supplier>> { - private ModelManager modelManager; - private CacheProvider cache; - // the max number of models to return per node. Defaults to 100. - private volatile int numModelsToReturn; - - /** - * Set that contains the model stats that should be exposed. - */ - public static Set MODEL_STATE_STAT_KEYS = new HashSet<>( - Arrays - .asList( - CommonName.MODEL_ID_FIELD, - ADCommonName.DETECTOR_ID_KEY, - MODEL_TYPE_KEY, - CommonName.ENTITY_KEY, - LAST_USED_TIME_KEY, - LAST_CHECKPOINT_TIME_KEY - ) - ); - - /** - * Constructor - * - * @param modelManager object that manages the model partitions hosted on the node - * @param cache object that manages multi-entity detectors' models - * @param settings node settings accessor - * @param clusterService Cluster service accessor - */ - public ModelsOnNodeSupplier(ModelManager modelManager, CacheProvider cache, Settings settings, ClusterService clusterService) { - this.modelManager = modelManager; - this.cache = cache; - this.numModelsToReturn = AD_MAX_MODEL_SIZE_PER_NODE.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_MAX_MODEL_SIZE_PER_NODE, it -> this.numModelsToReturn = it); - } - - @Override - public List> get() { - List> values = new ArrayList<>(); - Stream - .concat(modelManager.getAllModels().stream(), cache.get().getAllModels().stream()) - .limit(numModelsToReturn) - .forEach( - modelState -> values - .add( - modelState - .getModelStateAsMap() - .entrySet() - .stream() - .filter(entry -> MODEL_STATE_STAT_KEYS.contains(entry.getKey())) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)) - ) - ); - - return values; - } -} diff --git a/src/main/java/org/opensearch/ad/task/ADBatchTaskCache.java b/src/main/java/org/opensearch/ad/task/ADBatchTaskCache.java index 05897fe64..7fb488d95 100644 --- a/src/main/java/org/opensearch/ad/task/ADBatchTaskCache.java +++ b/src/main/java/org/opensearch/ad/task/ADBatchTaskCache.java @@ -11,11 +11,6 @@ package org.opensearch.ad.task; -import static org.opensearch.timeseries.settings.TimeSeriesSettings.NUM_MIN_SAMPLES; -import static org.opensearch.timeseries.settings.TimeSeriesSettings.NUM_SAMPLES_PER_TREE; -import static org.opensearch.timeseries.settings.TimeSeriesSettings.NUM_TREES; -import static org.opensearch.timeseries.settings.TimeSeriesSettings.TIME_DECAY; - import java.util.ArrayDeque; import java.util.Deque; import java.util.Map; @@ -62,7 +57,7 @@ protected ADBatchTaskCache(ADTask adTask) { this.entity = adTask.getEntity(); AnomalyDetector detector = adTask.getDetector(); - int numberOfTrees = NUM_TREES; + int numberOfTrees = TimeSeriesSettings.NUM_TREES; int shingleSize = detector.getShingleSize(); this.shingle = new ArrayDeque<>(shingleSize); int dimensions = detector.getShingleSize() * detector.getEnabledFeatureIds().size(); @@ -71,10 +66,10 @@ protected ADBatchTaskCache(ADTask adTask) { .builder() .dimensions(dimensions) .numberOfTrees(numberOfTrees) - .timeDecay(TIME_DECAY) - .sampleSize(NUM_SAMPLES_PER_TREE) - .outputAfter(NUM_MIN_SAMPLES) - .initialAcceptFraction(NUM_MIN_SAMPLES * 1.0d / NUM_SAMPLES_PER_TREE) + .timeDecay(detector.getTimeDecay()) + .sampleSize(TimeSeriesSettings.NUM_SAMPLES_PER_TREE) + .outputAfter(TimeSeriesSettings.NUM_MIN_SAMPLES) + .initialAcceptFraction(TimeSeriesSettings.NUM_MIN_SAMPLES * 1.0d / TimeSeriesSettings.NUM_SAMPLES_PER_TREE) .parallelExecutionEnabled(false) .compact(true) .precision(Precision.FLOAT_32) diff --git a/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java b/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java index f25b09af4..fba6b0206 100644 --- a/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java +++ b/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java @@ -12,24 +12,16 @@ package org.opensearch.ad.task; import static org.opensearch.ad.constant.ADCommonMessages.NO_ELIGIBLE_NODE_TO_RUN_DETECTOR; -import static org.opensearch.ad.model.ADTask.CURRENT_PIECE_FIELD; -import static org.opensearch.ad.model.ADTask.EXECUTION_END_TIME_FIELD; -import static org.opensearch.ad.model.ADTask.INIT_PROGRESS_FIELD; -import static org.opensearch.ad.model.ADTask.STATE_FIELD; -import static org.opensearch.ad.model.ADTask.TASK_PROGRESS_FIELD; -import static org.opensearch.ad.model.ADTask.WORKER_NODE_FIELD; import static org.opensearch.ad.settings.AnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS; import static org.opensearch.ad.settings.AnomalyDetectorSettings.BATCH_TASK_PIECE_SIZE; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RUNNING_ENTITIES_PER_DETECTOR_FOR_HISTORICAL_ANALYSIS; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_TOP_ENTITIES_FOR_HISTORICAL_ANALYSIS; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_TOP_ENTITIES_LIMIT_FOR_HISTORICAL_ANALYSIS; -import static org.opensearch.ad.stats.InternalStatNames.JVM_HEAP_USAGE; import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.AD_BATCH_TASK_THREAD_POOL_NAME; import static org.opensearch.timeseries.breaker.MemoryCircuitBreaker.DEFAULT_JVM_HEAP_USAGE_THRESHOLD; -import static org.opensearch.timeseries.settings.TimeSeriesSettings.NUM_MIN_SAMPLES; +import static org.opensearch.timeseries.stats.InternalStatNames.JVM_HEAP_USAGE; import static org.opensearch.timeseries.stats.StatNames.AD_EXECUTING_BATCH_TASK_COUNT; -import static org.opensearch.timeseries.util.ParseUtils.isNullOrEmpty; import java.time.Clock; import java.time.Instant; @@ -49,14 +41,10 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ThreadedActionListener; -import org.opensearch.ad.caching.PriorityTracker; -import org.opensearch.ad.cluster.HashRing; import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.feature.SinglePointFeatures; import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; @@ -67,10 +55,7 @@ import org.opensearch.ad.transport.ADBatchAnomalyResultRequest; import org.opensearch.ad.transport.ADBatchAnomalyResultResponse; import org.opensearch.ad.transport.ADBatchTaskRemoteExecutionAction; -import org.opensearch.ad.transport.ADStatsNodeResponse; import org.opensearch.ad.transport.ADStatsNodesAction; -import org.opensearch.ad.transport.ADStatsRequest; -import org.opensearch.ad.transport.handler.AnomalyResultBulkIndexHandler; import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; @@ -92,20 +77,29 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.caching.PriorityTracker; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.common.exception.ResourceNotFoundException; import org.opensearch.timeseries.common.exception.TaskCancelledException; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.feature.FeatureManager; import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.feature.SinglePointFeatures; import org.opensearch.timeseries.function.ExecutorFunction; import org.opensearch.timeseries.model.DateRange; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.FeatureData; import org.opensearch.timeseries.model.IntervalTimeConfiguration; import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.transport.StatsNodeResponse; +import org.opensearch.timeseries.transport.StatsRequest; +import org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler; import org.opensearch.timeseries.util.ExceptionUtil; import org.opensearch.timeseries.util.ParseUtils; import org.opensearch.timeseries.util.SecurityClientUtil; @@ -131,14 +125,14 @@ public class ADBatchTaskRunner { private final FeatureManager featureManager; private final CircuitBreakerService adCircuitBreakerService; private final ADTaskManager adTaskManager; - private final AnomalyResultBulkIndexHandler anomalyResultBulkIndexHandler; + private final ResultBulkIndexingHandler anomalyResultBulkIndexHandler; private final ADIndexManagement anomalyDetectionIndices; private final SearchFeatureDao searchFeatureDao; private final ADTaskCacheManager adTaskCacheManager; private final TransportRequestOptions option; private final HashRing hashRing; - private final ModelManager modelManager; + private final ADModelManager modelManager; private volatile Integer maxAdBatchTaskPerNode; private volatile Integer pieceSize; @@ -160,11 +154,11 @@ public ADBatchTaskRunner( ADTaskManager adTaskManager, ADIndexManagement anomalyDetectionIndices, ADStats adStats, - AnomalyResultBulkIndexHandler anomalyResultBulkIndexHandler, + ResultBulkIndexingHandler anomalyResultBulkIndexHandler, ADTaskCacheManager adTaskCacheManager, SearchFeatureDao searchFeatureDao, HashRing hashRing, - ModelManager modelManager + ADModelManager modelManager ) { this.settings = settings; this.threadPool = threadPool; @@ -267,7 +261,7 @@ private ActionListener getTopEntitiesListener( adTaskCacheManager.setTopEntityInited(detectorId); int totalEntities = adTaskCacheManager.getPendingEntityCount(detectorId); logger.info("Total top entities: {} for detector {}, task {}", totalEntities, detectorId, taskId); - hashRing.getNodesWithSameLocalAdVersion(dataNodes -> { + hashRing.getNodesWithSameLocalVersion(dataNodes -> { int numberOfEligibleDataNodes = dataNodes.length; // maxAdBatchTaskPerNode means how many task can run on per data node, which is hard limitation per node. // maxRunningEntitiesPerDetector means how many entities can run per detector on whole cluster, which is @@ -533,7 +527,7 @@ public void forwardOrExecuteADTask( ? adTask.getParentTaskId() // For HISTORICAL_HC_ENTITY task, return its parent task id : adTask.getTaskId(); // For HISTORICAL_HC_DETECTOR task, its task id is parent task id adTaskManager - .getAndExecuteOnLatestADTask( + .getAndExecuteOnLatestConfigTask( detectorId, parentTaskId, entity, @@ -578,7 +572,7 @@ public void forwardOrExecuteADTask( .entity(entity) .parentTaskId(parentTaskId) .build(); - adTaskManager.createADTaskDirectly(adEntityTask, r -> { + adTaskManager.createTaskDirectly(adEntityTask, r -> { adEntityTask.setTaskId(r.getId()); ActionListener workerNodeResponseListener = workerNodeResponseListener( adEntityTask, @@ -595,15 +589,15 @@ public void forwardOrExecuteADTask( ); } else { Map updatedFields = new HashMap<>(); - updatedFields.put(STATE_FIELD, TaskState.INIT.name()); - updatedFields.put(INIT_PROGRESS_FIELD, 0.0f); + updatedFields.put(TimeSeriesTask.STATE_FIELD, TaskState.INIT.name()); + updatedFields.put(TimeSeriesTask.INIT_PROGRESS_FIELD, 0.0f); ActionListener workerNodeResponseListener = workerNodeResponseListener( adTask, transportService, listener ); adTaskManager - .updateADTask( + .updateTask( adTask.getTaskId(), updatedFields, ActionListener.wrap(r -> forwardOrExecuteEntityTask(adTask, transportService, workerNodeResponseListener), e -> { @@ -634,7 +628,7 @@ private ActionListener workerNodeResponseListener( ) { ActionListener actionListener = ActionListener.wrap(r -> { listener.onResponse(r); - if (adTask.isEntityTask()) { + if (adTask.isHistoricalEntityTask()) { // When reach this line, the entity task already been put into worker node's cache. // Then it's safe to move entity from temp entities queue to running entities queue. adTaskCacheManager.moveToRunningEntity(adTask.getConfigId(), adTaskManager.convertEntityToString(adTask)); @@ -704,12 +698,12 @@ private synchronized void startNewEntityTaskLane(ADTask adTask, TransportService } private void dispatchTask(ADTask adTask, ActionListener listener) { - hashRing.getNodesWithSameLocalAdVersion(dataNodes -> { - ADStatsRequest adStatsRequest = new ADStatsRequest(dataNodes); + hashRing.getNodesWithSameLocalVersion(dataNodes -> { + StatsRequest adStatsRequest = new StatsRequest(dataNodes); adStatsRequest.addAll(ImmutableSet.of(AD_EXECUTING_BATCH_TASK_COUNT.getName(), JVM_HEAP_USAGE.getName())); client.execute(ADStatsNodesAction.INSTANCE, adStatsRequest, ActionListener.wrap(adStatsResponse -> { - List candidateNodeResponse = adStatsResponse + List candidateNodeResponse = adStatsResponse .getNodes() .stream() .filter(stat -> (long) stat.getStatsMap().get(JVM_HEAP_USAGE.getName()) < DEFAULT_JVM_HEAP_USAGE_THRESHOLD) @@ -739,9 +733,9 @@ private void dispatchTask(ADTask adTask, ActionListener listener) listener.onFailure(new LimitExceededException(adTask.getConfigId(), errorMessage)); return; } - Optional targetNode = candidateNodeResponse + Optional targetNode = candidateNodeResponse .stream() - .sorted((ADStatsNodeResponse r1, ADStatsNodeResponse r2) -> { + .sorted((StatsNodeResponse r1, StatsNodeResponse r2) -> { int result = ((Long) r1.getStatsMap().get(AD_EXECUTING_BATCH_TASK_COUNT.getName())) .compareTo((Long) r2.getStatsMap().get(AD_EXECUTING_BATCH_TASK_COUNT.getName())); if (result == 0) { @@ -808,11 +802,11 @@ private ActionListener internalBatchTaskListener(ADTask adTask, Transpor .cleanDetectorCache( adTask, transportService, - () -> adTaskManager.updateADTask(taskId, ImmutableMap.of(STATE_FIELD, TaskState.FINISHED.name())) + () -> adTaskManager.updateTask(taskId, ImmutableMap.of(TimeSeriesTask.STATE_FIELD, TaskState.FINISHED.name())) ); } else { // Set entity task as FINISHED here - adTaskManager.updateADTask(adTask.getTaskId(), ImmutableMap.of(STATE_FIELD, TaskState.FINISHED.name())); + adTaskManager.updateTask(adTask.getTaskId(), ImmutableMap.of(TimeSeriesTask.STATE_FIELD, TaskState.FINISHED.name())); adTaskManager.entityTaskDone(adTask, null, transportService); } }, e -> { @@ -845,7 +839,7 @@ private void handleException(ADTask adTask, Exception e) { adStats.getStat(StatNames.AD_BATCH_TASK_FAILURE_COUNT.getName()).increment(); } // Handle AD task exception - adTaskManager.handleADTaskException(adTask, e); + adTaskManager.handleTaskException(adTask, e); } private void executeADBatchTaskOnWorkerNode(ADTask adTask, ActionListener internalListener) { @@ -888,19 +882,19 @@ private void checkCircuitBreaker(ADTask adTask) { private void runFirstPiece(ADTask adTask, Instant executeStartTime, ActionListener internalListener) { try { adTaskManager - .updateADTask( + .updateTask( adTask.getTaskId(), ImmutableMap .of( - STATE_FIELD, + TimeSeriesTask.STATE_FIELD, TaskState.INIT.name(), - CURRENT_PIECE_FIELD, + TimeSeriesTask.CURRENT_PIECE_FIELD, adTask.getDetectionDateRange().getStartTime().toEpochMilli(), - TASK_PROGRESS_FIELD, + TimeSeriesTask.TASK_PROGRESS_FIELD, 0.0f, - INIT_PROGRESS_FIELD, + TimeSeriesTask.INIT_PROGRESS_FIELD, 0.0f, - WORKER_NODE_FIELD, + TimeSeriesTask.WORKER_NODE_FIELD, clusterService.localNode().getId() ), ActionListener.wrap(r -> { @@ -996,7 +990,7 @@ private void getDateRangeOfSourceData(ADTask adTask, BiConsumer cons dataStartTime = dataStartTime - dataStartTime % interval; dataEndTime = dataEndTime - dataEndTime % interval; logger.debug("adjusted date range: start: {}, end: {}, taskId: {}", dataStartTime, dataEndTime, taskId); - if ((dataEndTime - dataStartTime) < NUM_MIN_SAMPLES * interval) { + if ((dataEndTime - dataStartTime) < TimeSeriesSettings.NUM_MIN_SAMPLES * interval) { internalListener.onFailure(new TimeSeriesException("There is not enough data to train model").countedInStats(false)); return; } @@ -1229,10 +1223,12 @@ private void storeAnomalyResultAndRunNextPiece( false ); + String detectorId = adTask.getConfigId(); anomalyResultBulkIndexHandler - .bulkIndexAnomalyResult( + .bulk( resultIndex, anomalyResults, + detectorId, runBefore == null ? actionListener : ActionListener.runBefore(actionListener, runBefore) ); } @@ -1252,7 +1248,7 @@ private void runNextPiece( String taskState = initProgress >= 1.0f ? TaskState.RUNNING.name() : TaskState.INIT.name(); logger.debug("Init progress: {}, taskState:{}, task id: {}", initProgress, taskState, taskId); - if (initProgress >= 1.0f && adTask.isEntityTask()) { + if (initProgress >= 1.0f && adTask.isHistoricalEntityTask()) { updateDetectorLevelTaskState(detectorId, adTask.getParentTaskId(), TaskState.RUNNING.name()); } @@ -1273,17 +1269,17 @@ private void runNextPiece( float taskProgress = (float) (pieceStartTime - dataStartTime) / (dataEndTime - dataStartTime); logger.debug("Task progress: {}, task id:{}, detector id:{}", taskProgress, taskId, detectorId); adTaskManager - .updateADTask( + .updateTask( taskId, ImmutableMap .of( - STATE_FIELD, + TimeSeriesTask.STATE_FIELD, taskState, - CURRENT_PIECE_FIELD, + TimeSeriesTask.CURRENT_PIECE_FIELD, pieceStartTime, - TASK_PROGRESS_FIELD, + TimeSeriesTask.TASK_PROGRESS_FIELD, taskProgress, - INIT_PROGRESS_FIELD, + TimeSeriesTask.INIT_PROGRESS_FIELD, initProgress ), ActionListener @@ -1306,19 +1302,19 @@ private void runNextPiece( logger.info("AD task finished for detector {}, task id: {}", detectorId, taskId); adTaskCacheManager.remove(taskId, detectorId, detectorTaskId); adTaskManager - .updateADTask( + .updateTask( taskId, ImmutableMap .of( - CURRENT_PIECE_FIELD, + TimeSeriesTask.CURRENT_PIECE_FIELD, dataEndTime, - TASK_PROGRESS_FIELD, + TimeSeriesTask.TASK_PROGRESS_FIELD, 1.0f, - EXECUTION_END_TIME_FIELD, + TimeSeriesTask.EXECUTION_END_TIME_FIELD, Instant.now().toEpochMilli(), - INIT_PROGRESS_FIELD, + TimeSeriesTask.INIT_PROGRESS_FIELD, initProgress, - STATE_FIELD, + TimeSeriesTask.STATE_FIELD, TaskState.FINISHED ), ActionListener.wrap(r -> internalListener.onResponse("task execution done"), e -> internalListener.onFailure(e)) @@ -1328,7 +1324,7 @@ private void runNextPiece( private void updateDetectorLevelTaskState(String detectorId, String detectorTaskId, String newState) { ExecutorFunction function = () -> adTaskManager - .updateADTask(detectorTaskId, ImmutableMap.of(STATE_FIELD, newState), ActionListener.wrap(r -> { + .updateTask(detectorTaskId, ImmutableMap.of(TimeSeriesTask.STATE_FIELD, newState), ActionListener.wrap(r -> { logger.info("Updated HC detector task: {} state as: {} for detector: {}", detectorTaskId, newState, detectorId); adTaskCacheManager.updateDetectorTaskState(detectorId, detectorTaskId, newState); }, e -> { logger.error("Failed to update HC detector task: {} for detector: {}", detectorTaskId, detectorId); })); @@ -1352,7 +1348,7 @@ private float calculateInitProgress(String taskId) { if (rcf == null) { return 0.0f; } - float initProgress = (float) rcf.getTotalUpdates() / NUM_MIN_SAMPLES; + float initProgress = (float) rcf.getTotalUpdates() / TimeSeriesSettings.NUM_MIN_SAMPLES; logger.debug("RCF total updates {} for task {}", rcf.getTotalUpdates(), taskId); return initProgress > 1.0f ? 1.0f : initProgress; } @@ -1381,7 +1377,7 @@ private void checkIfADTaskCancelledAndCleanupCache(ADTask adTask) { String cancelledBy = adTaskCacheManager.getCancelledBy(taskId); adTaskCacheManager.remove(taskId, detectorId, detectorTaskId); if (!adTaskCacheManager.isHCTaskCoordinatingNode(detectorId) - && isNullOrEmpty(adTaskCacheManager.getTasksOfDetector(detectorId))) { + && ParseUtils.isNullOrEmpty(adTaskCacheManager.getTasksOfDetector(detectorId))) { // Clean up historical task cache for HC detector on worker node if no running entity task. logger.info("All AD task cancelled, cleanup historical task cache for detector {}", detectorId); adTaskCacheManager.removeHistoricalTaskCache(detectorId); diff --git a/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java b/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java index 014a9f798..aab2eb652 100644 --- a/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java +++ b/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java @@ -142,7 +142,7 @@ public synchronized void add(ADTask adTask) { throw new DuplicateTaskException(DETECTOR_IS_RUNNING); } // It's possible that multiple entity tasks of one detector run on same data node. - if (!adTask.isEntityTask() && containsTaskOfDetector(detectorId)) { + if (!adTask.isHistoricalEntityTask() && containsTaskOfDetector(detectorId)) { throw new DuplicateTaskException(DETECTOR_IS_RUNNING); } checkRunningTaskLimit(); @@ -154,7 +154,7 @@ public synchronized void add(ADTask adTask) { ADBatchTaskCache taskCache = new ADBatchTaskCache(adTask); taskCache.getCacheMemorySize().set(neededCacheSize); batchTaskCaches.put(taskId, taskCache); - if (adTask.isEntityTask()) { + if (adTask.isHistoricalEntityTask()) { ADHCBatchTaskRunState hcBatchTaskRunState = getHCBatchTaskRunState(detectorId, adTask.getConfigLevelTaskId()); if (hcBatchTaskRunState != null) { hcBatchTaskRunState.setLastTaskRunTimeInMillis(Instant.now().toEpochMilli()); diff --git a/src/main/java/org/opensearch/ad/task/ADTaskManager.java b/src/main/java/org/opensearch/ad/task/ADTaskManager.java index 268bbc26a..644117cfd 100644 --- a/src/main/java/org/opensearch/ad/task/ADTaskManager.java +++ b/src/main/java/org/opensearch/ad/task/ADTaskManager.java @@ -18,20 +18,7 @@ import static org.opensearch.ad.constant.ADCommonMessages.NO_ELIGIBLE_NODE_TO_RUN_DETECTOR; import static org.opensearch.ad.constant.ADCommonName.DETECTION_STATE_INDEX; import static org.opensearch.ad.indices.ADIndexManagement.ALL_AD_RESULTS_INDEX_PATTERN; -import static org.opensearch.ad.model.ADTask.COORDINATING_NODE_FIELD; import static org.opensearch.ad.model.ADTask.DETECTOR_ID_FIELD; -import static org.opensearch.ad.model.ADTask.ERROR_FIELD; -import static org.opensearch.ad.model.ADTask.ESTIMATED_MINUTES_LEFT_FIELD; -import static org.opensearch.ad.model.ADTask.EXECUTION_END_TIME_FIELD; -import static org.opensearch.ad.model.ADTask.EXECUTION_START_TIME_FIELD; -import static org.opensearch.ad.model.ADTask.INIT_PROGRESS_FIELD; -import static org.opensearch.ad.model.ADTask.IS_LATEST_FIELD; -import static org.opensearch.ad.model.ADTask.LAST_UPDATE_TIME_FIELD; -import static org.opensearch.ad.model.ADTask.PARENT_TASK_ID_FIELD; -import static org.opensearch.ad.model.ADTask.STATE_FIELD; -import static org.opensearch.ad.model.ADTask.STOPPED_BY_FIELD; -import static org.opensearch.ad.model.ADTask.TASK_PROGRESS_FIELD; -import static org.opensearch.ad.model.ADTask.TASK_TYPE_FIELD; import static org.opensearch.ad.model.ADTaskType.ALL_HISTORICAL_TASK_TYPES; import static org.opensearch.ad.model.ADTaskType.HISTORICAL_DETECTOR_TASK_TYPES; import static org.opensearch.ad.model.ADTaskType.REALTIME_TASK_TYPES; @@ -39,31 +26,21 @@ import static org.opensearch.ad.settings.AnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS; import static org.opensearch.ad.settings.AnomalyDetectorSettings.DELETE_AD_RESULT_WHEN_DELETE_DETECTOR; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS_PER_DETECTOR; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RUNNING_ENTITIES_PER_DETECTOR_FOR_HISTORICAL_ANALYSIS; -import static org.opensearch.ad.stats.InternalStatNames.AD_DETECTOR_ASSIGNED_BATCH_TASK_SLOT_COUNT; -import static org.opensearch.ad.stats.InternalStatNames.AD_USED_BATCH_TASK_SLOT_COUNT; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.AD_BATCH_TASK_THREAD_POOL_NAME; -import static org.opensearch.timeseries.constant.CommonMessages.CAN_NOT_FIND_LATEST_TASK; -import static org.opensearch.timeseries.constant.CommonMessages.CREATE_INDEX_NOT_ACKNOWLEDGED; -import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG; -import static org.opensearch.timeseries.constant.CommonName.TASK_ID_FIELD; import static org.opensearch.timeseries.model.TaskState.NOT_ENDED_STATES; import static org.opensearch.timeseries.model.TaskType.taskTypeToString; -import static org.opensearch.timeseries.settings.TimeSeriesSettings.NUM_MIN_SAMPLES; -import static org.opensearch.timeseries.util.ExceptionUtil.getErrorMessage; -import static org.opensearch.timeseries.util.ExceptionUtil.getShardsFailure; -import static org.opensearch.timeseries.util.ParseUtils.isNullOrEmpty; -import static org.opensearch.timeseries.util.RestHandlerUtils.XCONTENT_WITH_TYPE; +import static org.opensearch.timeseries.stats.InternalStatNames.AD_DETECTOR_ASSIGNED_BATCH_TASK_SLOT_COUNT; +import static org.opensearch.timeseries.stats.InternalStatNames.AD_USED_BATCH_TASK_SLOT_COUNT; import static org.opensearch.timeseries.util.RestHandlerUtils.createXContentParserFromRegistry; import java.io.IOException; import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.ArrayList; -import java.util.HashMap; +import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Locale; @@ -80,45 +57,28 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.search.TotalHits; -import org.apache.lucene.search.join.ScoreMode; import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchStatusException; import org.opensearch.ResourceAlreadyExistsException; import org.opensearch.Version; import org.opensearch.action.ActionListenerResponseHandler; -import org.opensearch.action.bulk.BulkAction; -import org.opensearch.action.bulk.BulkItemResponse; -import org.opensearch.action.bulk.BulkRequest; -import org.opensearch.action.delete.DeleteRequest; -import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.get.GetRequest; -import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.action.support.WriteRequest; -import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; -import org.opensearch.ad.cluster.HashRing; +import org.opensearch.ad.ADTaskProfileRunner; +import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.model.ADEntityTaskProfile; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.ADTaskAction; import org.opensearch.ad.model.ADTaskProfile; import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.DetectorProfile; -import org.opensearch.ad.rest.handler.IndexAnomalyDetectorJobActionHandler; import org.opensearch.ad.transport.ADBatchAnomalyResultAction; import org.opensearch.ad.transport.ADBatchAnomalyResultRequest; import org.opensearch.ad.transport.ADCancelTaskAction; import org.opensearch.ad.transport.ADCancelTaskRequest; -import org.opensearch.ad.transport.ADStatsNodeResponse; import org.opensearch.ad.transport.ADStatsNodesAction; -import org.opensearch.ad.transport.ADStatsRequest; -import org.opensearch.ad.transport.ADTaskProfileAction; -import org.opensearch.ad.transport.ADTaskProfileNodeResponse; -import org.opensearch.ad.transport.ADTaskProfileRequest; import org.opensearch.ad.transport.ForwardADTaskAction; import org.opensearch.ad.transport.ForwardADTaskRequest; import org.opensearch.client.Client; @@ -126,7 +86,6 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; -import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; @@ -140,35 +99,39 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; -import org.opensearch.index.query.NestedQueryBuilder; -import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.index.query.TermsQueryBuilder; -import org.opensearch.index.reindex.DeleteByQueryAction; -import org.opensearch.index.reindex.DeleteByQueryRequest; -import org.opensearch.index.reindex.UpdateByQueryAction; -import org.opensearch.index.reindex.UpdateByQueryRequest; -import org.opensearch.script.Script; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.sort.SortOrder; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.common.exception.DuplicateTaskException; import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.common.exception.ResourceNotFoundException; import org.opensearch.timeseries.common.exception.TaskCancelledException; import org.opensearch.timeseries.common.exception.TimeSeriesException; -import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.function.BiCheckedFunction; import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.function.ResponseTransformer; +import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.DateRange; import org.opensearch.timeseries.model.Entity; -import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.EntityTaskProfile; import org.opensearch.timeseries.model.TaskState; -import org.opensearch.timeseries.task.RealtimeTaskCache; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.task.TaskManager; import org.opensearch.timeseries.transport.JobResponse; +import org.opensearch.timeseries.transport.StatsNodeResponse; +import org.opensearch.timeseries.transport.StatsRequest; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; -import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.timeseries.util.ExceptionUtil; +import org.opensearch.timeseries.util.ParseUtils; import org.opensearch.transport.TransportRequestOptions; import org.opensearch.transport.TransportService; @@ -179,28 +142,21 @@ /** * Manage AD task. */ -public class ADTaskManager { +public class ADTaskManager extends TaskManager { public static final String AD_TASK_LEAD_NODE_MODEL_ID = "ad_task_lead_node_model_id"; public static final String AD_TASK_MAINTAINENCE_NODE_MODEL_ID = "ad_task_maintainence_node_model_id"; // HC batch task timeout after 10 minutes if no update after last known run time. public static final int HC_BATCH_TASK_CACHE_TIMEOUT_IN_MILLIS = 600_000; - private final Logger logger = LogManager.getLogger(this.getClass()); + public final Logger logger = LogManager.getLogger(this.getClass()); static final String STATE_INDEX_NOT_EXIST_MSG = "State index does not exist."; private final Set retryableErrors = ImmutableSet.of(EXCEED_HISTORICAL_ANALYSIS_LIMIT, NO_ELIGIBLE_NODE_TO_RUN_DETECTOR); - private final Client client; - private final ClusterService clusterService; - private final NamedXContentRegistry xContentRegistry; - private final ADIndexManagement detectionIndices; + private final DiscoveryNodeFilterer nodeFilter; - private final ADTaskCacheManager adTaskCacheManager; private final HashRing hashRing; - private volatile Integer maxOldAdTaskDocsPerDetector; private volatile Integer pieceIntervalSeconds; - private volatile boolean deleteADResultWhenDeleteDetector; + private volatile TransportRequestOptions transportRequestOptions; - private final ThreadPool threadPool; - private static int DEFAULT_MAINTAIN_INTERVAL_IN_SECONDS = 5; private final Semaphore checkingTaskSlot; private volatile Integer maxAdBatchTaskPerNode; @@ -208,6 +164,7 @@ public class ADTaskManager { private final Semaphore scaleEntityTaskLane; private static final int SCALE_ENTITY_TASK_LANE_INTERVAL_IN_MILLIS = 10_000; // 10 seconds + private final ADTaskProfileRunner taskProfileRunner; public ADTaskManager( Settings settings, @@ -218,29 +175,38 @@ public ADTaskManager( DiscoveryNodeFilterer nodeFilter, HashRing hashRing, ADTaskCacheManager adTaskCacheManager, - ThreadPool threadPool + ThreadPool threadPool, + NodeStateManager nodeStateManager, + ADTaskProfileRunner taskProfileRunner ) { - this.client = client; - this.xContentRegistry = xContentRegistry; - this.detectionIndices = detectionIndices; + super( + adTaskCacheManager, + clusterService, + client, + DETECTION_STATE_INDEX, + ADTaskType.REALTIME_TASK_TYPES, + ADTaskType.HISTORICAL_DETECTOR_TASK_TYPES, + Collections.emptyList(), + detectionIndices, + nodeStateManager, + AnalysisType.AD, + xContentRegistry, + DETECTOR_ID_FIELD, + MAX_OLD_AD_TASK_DOCS_PER_DETECTOR, + settings, + threadPool, + ALL_AD_RESULTS_INDEX_PATTERN, + AD_BATCH_TASK_THREAD_POOL_NAME, + DELETE_AD_RESULT_WHEN_DELETE_DETECTOR, + TaskState.STOPPED + ); + this.nodeFilter = nodeFilter; - this.clusterService = clusterService; - this.adTaskCacheManager = adTaskCacheManager; this.hashRing = hashRing; - this.maxOldAdTaskDocsPerDetector = MAX_OLD_AD_TASK_DOCS_PER_DETECTOR.get(settings); - clusterService - .getClusterSettings() - .addSettingsUpdateConsumer(MAX_OLD_AD_TASK_DOCS_PER_DETECTOR, it -> maxOldAdTaskDocsPerDetector = it); - this.pieceIntervalSeconds = BATCH_TASK_PIECE_INTERVAL_SECONDS.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(BATCH_TASK_PIECE_INTERVAL_SECONDS, it -> pieceIntervalSeconds = it); - this.deleteADResultWhenDeleteDetector = DELETE_AD_RESULT_WHEN_DELETE_DETECTOR.get(settings); - clusterService - .getClusterSettings() - .addSettingsUpdateConsumer(DELETE_AD_RESULT_WHEN_DELETE_DETECTOR, it -> deleteADResultWhenDeleteDetector = it); - this.maxAdBatchTaskPerNode = MAX_BATCH_TASK_PER_NODE.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_BATCH_TASK_PER_NODE, it -> maxAdBatchTaskPerNode = it); @@ -257,83 +223,10 @@ public ADTaskManager( clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_REQUEST_TIMEOUT, it -> { transportRequestOptions = TransportRequestOptions.builder().withType(TransportRequestOptions.Type.REG).withTimeout(it).build(); }); - this.threadPool = threadPool; + this.checkingTaskSlot = new Semaphore(1); this.scaleEntityTaskLane = new Semaphore(1); - } - - /** - * Start detector. Will create schedule job for realtime detector, - * and start AD task for historical detector. - * - * @param detectorId detector id - * @param detectionDateRange historical analysis date range - * @param handler anomaly detector job action handler - * @param user user - * @param transportService transport service - * @param context thread context - * @param listener action listener - */ - public void startDetector( - String detectorId, - DateRange detectionDateRange, - IndexAnomalyDetectorJobActionHandler handler, - User user, - TransportService transportService, - ThreadContext.StoredContext context, - ActionListener listener - ) { - // upgrade index mapping of AD default indices - detectionIndices.update(); - - getDetector(detectorId, (detector) -> { - if (!detector.isPresent()) { - listener.onFailure(new OpenSearchStatusException(FAIL_TO_FIND_CONFIG_MSG + detectorId, RestStatus.NOT_FOUND)); - return; - } - - // Validate if detector is ready to start. Will return null if ready to start. - String errorMessage = validateDetector(detector.get()); - if (errorMessage != null) { - listener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.BAD_REQUEST)); - return; - } - String resultIndex = detector.get().getCustomResultIndex(); - if (resultIndex == null) { - startRealtimeOrHistoricalDetection(detectionDateRange, handler, user, transportService, listener, detector); - return; - } - context.restore(); - detectionIndices - .initCustomResultIndexAndExecute( - resultIndex, - () -> startRealtimeOrHistoricalDetection(detectionDateRange, handler, user, transportService, listener, detector), - listener - ); - - }, listener); - } - - private void startRealtimeOrHistoricalDetection( - DateRange detectionDateRange, - IndexAnomalyDetectorJobActionHandler handler, - User user, - TransportService transportService, - ActionListener listener, - Optional detector - ) { - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - if (detectionDateRange == null) { - // start realtime job - handler.startAnomalyDetectorJob(detector.get(), listener); - } else { - // start historical analysis task - forwardApplyForTaskSlotsRequestToLeadNode(detector.get(), detectionDateRange, user, transportService, listener); - } - } catch (Exception e) { - logger.error("Failed to stash context", e); - listener.onFailure(e); - } + this.taskProfileRunner = taskProfileRunner; } /** @@ -344,21 +237,22 @@ private void startRealtimeOrHistoricalDetection( * 3. Then coordinating node will choose one data node with least load as work * node and dispatch historical analysis to it. * - * @param detector detector + * @param config config accessor * @param detectionDateRange detection date range * @param user user * @param transportService transport service * @param listener action listener */ - protected void forwardApplyForTaskSlotsRequestToLeadNode( - AnomalyDetector detector, + @Override + public void startHistorical( + Config config, DateRange detectionDateRange, User user, TransportService transportService, ActionListener listener ) { ForwardADTaskRequest forwardADTaskRequest = new ForwardADTaskRequest( - detector, + (AnomalyDetector) config, detectionDateRange, user, ADTaskAction.APPLY_FOR_TASK_SLOTS @@ -379,7 +273,7 @@ public void forwardRequestToLeadNode( TransportService transportService, ActionListener listener ) { - hashRing.buildAndGetOwningNodeWithSameLocalAdVersion(AD_TASK_LEAD_NODE_MODEL_ID, node -> { + hashRing.buildAndGetOwningNodeWithSameLocalVersion(AD_TASK_LEAD_NODE_MODEL_ID, node -> { if (!node.isPresent()) { listener.onFailure(new ResourceNotFoundException("Can't find AD task lead node")); return; @@ -414,7 +308,7 @@ public void startHistoricalAnalysis( ActionListener listener ) { String detectorId = detector.getId(); - hashRing.buildAndGetOwningNodeWithSameLocalAdVersion(detectorId, owningNode -> { + hashRing.buildAndGetOwningNodeWithSameLocalVersion(detectorId, owningNode -> { if (!owningNode.isPresent()) { logger.debug("Can't find eligible node to run as AD task's coordinating node"); listener.onFailure(new OpenSearchStatusException("No eligible node to run detector", RestStatus.INTERNAL_SERVER_ERROR)); @@ -467,7 +361,7 @@ protected void forwardDetectRequestToCoordinatingNode( DiscoveryNode node, ActionListener listener ) { - Version adVersion = hashRing.getAdVersion(node.getId()); + Version adVersion = hashRing.getVersion(node.getId()); transportService .sendRequest( node, @@ -570,15 +464,15 @@ public void checkTaskSlots( checkingTaskSlot.release(1); logger.debug("Release checking task slot semaphore on lead node for detector {}", detectorId); }); - hashRing.getNodesWithSameLocalAdVersion(nodes -> { + hashRing.getNodesWithSameLocalVersion(nodes -> { int maxAdTaskSlots = nodes.length * maxAdBatchTaskPerNode; - ADStatsRequest adStatsRequest = new ADStatsRequest(nodes); + StatsRequest adStatsRequest = new StatsRequest(nodes); adStatsRequest .addAll(ImmutableSet.of(AD_USED_BATCH_TASK_SLOT_COUNT.getName(), AD_DETECTOR_ASSIGNED_BATCH_TASK_SLOT_COUNT.getName())); client.execute(ADStatsNodesAction.INSTANCE, adStatsRequest, ActionListener.wrap(adStatsResponse -> { int totalUsedTaskSlots = 0; // Total entity tasks running on worker nodes int totalAssignedTaskSlots = 0; // Total assigned task slots on coordinating nodes - for (ADStatsNodeResponse response : adStatsResponse.getNodes()) { + for (StatsNodeResponse response : adStatsResponse.getNodes()) { totalUsedTaskSlots += (int) response.getStatsMap().get(AD_USED_BATCH_TASK_SLOT_COUNT.getName()); totalAssignedTaskSlots += (int) response.getStatsMap().get(AD_DETECTOR_ASSIGNED_BATCH_TASK_SLOT_COUNT.getName()); } @@ -704,418 +598,32 @@ private DiscoveryNode getCoordinatingNode(ADTask adTask) { return targetNode; } - /** - * Start anomaly detector. - * For historical analysis, this method will be called on coordinating node. - * For realtime task, we won't know AD job coordinating node until AD job starts. So - * this method will be called on vanilla node. - * - * Will init task index if not exist and write new AD task to index. If task index - * exists, will check if there is task running. If no running task, reset old task - * as not latest and clean old tasks which exceeds max old task doc limitation. - * Then find out node with least load and dispatch task to that node(worker node). - * - * @param detector anomaly detector - * @param detectionDateRange detection date range - * @param user user - * @param transportService transport service - * @param listener action listener - */ - public void startDetector( - AnomalyDetector detector, - DateRange detectionDateRange, - User user, - TransportService transportService, - ActionListener listener - ) { - try { - if (detectionIndices.doesStateIndexExist()) { - // If detection index exist, check if latest AD task is running - getAndExecuteOnLatestDetectorLevelTask(detector.getId(), getADTaskTypes(detectionDateRange), (adTask) -> { - if (!adTask.isPresent() || adTask.get().isDone()) { - updateLatestFlagOfOldTasksAndCreateNewTask(detector, detectionDateRange, user, listener); - } else { - listener.onFailure(new OpenSearchStatusException(DETECTOR_IS_RUNNING, RestStatus.BAD_REQUEST)); - } - }, transportService, true, listener); - } else { - // If detection index doesn't exist, create index and execute detector. - detectionIndices.initStateIndex(ActionListener.wrap(r -> { - if (r.isAcknowledged()) { - logger.info("Created {} with mappings.", DETECTION_STATE_INDEX); - updateLatestFlagOfOldTasksAndCreateNewTask(detector, detectionDateRange, user, listener); - } else { - String error = String.format(Locale.ROOT, CREATE_INDEX_NOT_ACKNOWLEDGED, DETECTION_STATE_INDEX); - logger.warn(error); - listener.onFailure(new OpenSearchStatusException(error, RestStatus.INTERNAL_SERVER_ERROR)); - } - }, e -> { - if (ExceptionsHelper.unwrapCause(e) instanceof ResourceAlreadyExistsException) { - updateLatestFlagOfOldTasksAndCreateNewTask(detector, detectionDateRange, user, listener); - } else { - logger.error("Failed to init anomaly detection state index", e); - listener.onFailure(e); - } - })); - } - } catch (Exception e) { - logger.error("Failed to start detector " + detector.getId(), e); - listener.onFailure(e); - } - } - - private ADTaskType getADTaskType(AnomalyDetector detector, DateRange detectionDateRange) { + @Override + protected TaskType getTaskType(Config config, DateRange detectionDateRange, boolean runOnce) { if (detectionDateRange == null) { - return detector.isHighCardinality() ? ADTaskType.REALTIME_HC_DETECTOR : ADTaskType.REALTIME_SINGLE_ENTITY; + return config.isHighCardinality() ? ADTaskType.REALTIME_HC_DETECTOR : ADTaskType.REALTIME_SINGLE_ENTITY; } else { - return detector.isHighCardinality() ? ADTaskType.HISTORICAL_HC_DETECTOR : ADTaskType.HISTORICAL_SINGLE_ENTITY; - } - } - - private List getADTaskTypes(DateRange detectionDateRange) { - return getADTaskTypes(detectionDateRange, false); - } - - /** - * Get list of task types. - * 1. If detection date range is null, will return all realtime task types - * 2. If detection date range is not null, will return all historical detector level tasks types - * if resetLatestTaskStateFlag is true; otherwise return all historical tasks types include - * HC entity level task type. - * @param detectionDateRange detection date range - * @param resetLatestTaskStateFlag reset latest task state or not - * @return list of AD task types - */ - private List getADTaskTypes(DateRange detectionDateRange, boolean resetLatestTaskStateFlag) { - if (detectionDateRange == null) { - return REALTIME_TASK_TYPES; - } else { - if (resetLatestTaskStateFlag) { - // return all task types include HC entity task to make sure we can reset all tasks latest flag - return ALL_HISTORICAL_TASK_TYPES; - } else { - return HISTORICAL_DETECTOR_TASK_TYPES; - } - } - } - - /** - * Stop detector. - * For realtime detector, will set detector job as disabled. - * For historical detector, will set its AD task as cancelled. - * - * @param detectorId detector id - * @param historical stop historical analysis or not - * @param handler AD job action handler - * @param user user - * @param transportService transport service - * @param listener action listener - */ - public void stopDetector( - String detectorId, - boolean historical, - IndexAnomalyDetectorJobActionHandler handler, - User user, - TransportService transportService, - ActionListener listener - ) { - getDetector(detectorId, (detector) -> { - if (!detector.isPresent()) { - listener.onFailure(new OpenSearchStatusException(FAIL_TO_FIND_CONFIG_MSG + detectorId, RestStatus.NOT_FOUND)); - return; - } - if (historical) { - // stop historical analyis - getAndExecuteOnLatestDetectorLevelTask( - detectorId, - HISTORICAL_DETECTOR_TASK_TYPES, - (task) -> stopHistoricalAnalysis(detectorId, task, user, listener), - transportService, - false,// don't need to reset task state when stop detector - listener - ); - } else { - // stop realtime detector job - handler.stopAnomalyDetectorJob(detectorId, listener); - } - }, listener); - } - - /** - * Get anomaly detector and execute consumer function. - * [Important!] Make sure listener returns in function - * - * @param detectorId detector id - * @param function consumer function - * @param listener action listener - * @param action listener response type - */ - public void getDetector(String detectorId, Consumer> function, ActionListener listener) { - GetRequest getRequest = new GetRequest(CommonName.CONFIG_INDEX, detectorId); - client.get(getRequest, ActionListener.wrap(response -> { - if (!response.isExists()) { - function.accept(Optional.empty()); - return; - } - try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - AnomalyDetector detector = AnomalyDetector.parse(parser, response.getId(), response.getVersion()); - - function.accept(Optional.of(detector)); - } catch (Exception e) { - String message = "Failed to parse anomaly detector " + detectorId; - logger.error(message, e); - listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); - } - }, exception -> { - logger.error("Failed to get detector " + detectorId, exception); - listener.onFailure(exception); - })); - } - - /** - * Get latest AD task and execute consumer function. - * [Important!] Make sure listener returns in function - * - * @param detectorId detector id - * @param adTaskTypes AD task types - * @param function consumer function - * @param transportService transport service - * @param resetTaskState reset task state or not - * @param listener action listener - * @param action listener response type - */ - public void getAndExecuteOnLatestDetectorLevelTask( - String detectorId, - List adTaskTypes, - Consumer> function, - TransportService transportService, - boolean resetTaskState, - ActionListener listener - ) { - getAndExecuteOnLatestADTask(detectorId, null, null, adTaskTypes, function, transportService, resetTaskState, listener); - } - - /** - * Get one latest AD task and execute consumer function. - * [Important!] Make sure listener returns in function - * - * @param detectorId detector id - * @param parentTaskId parent task id - * @param entity entity value - * @param adTaskTypes AD task types - * @param function consumer function - * @param transportService transport service - * @param resetTaskState reset task state or not - * @param listener action listener - * @param action listener response type - */ - public void getAndExecuteOnLatestADTask( - String detectorId, - String parentTaskId, - Entity entity, - List adTaskTypes, - Consumer> function, - TransportService transportService, - boolean resetTaskState, - ActionListener listener - ) { - getAndExecuteOnLatestADTasks(detectorId, parentTaskId, entity, adTaskTypes, (taskList) -> { - if (taskList != null && taskList.size() > 0) { - function.accept(Optional.ofNullable(taskList.get(0))); - } else { - function.accept(Optional.empty()); - } - }, transportService, resetTaskState, 1, listener); - } - - /** - * Get latest AD tasks and execute consumer function. - * If resetTaskState is true, will collect latest task's profile data from all data nodes. If no data - * node running the latest task, will reset the task state as STOPPED; otherwise, check if there is - * any stale running entities(entity exists in coordinating node cache but no task running on worker - * node) and clean up. - * [Important!] Make sure listener returns in function - * - * @param detectorId detector id - * @param parentTaskId parent task id - * @param entity entity value - * @param adTaskTypes AD task types - * @param function consumer function - * @param transportService transport service - * @param resetTaskState reset task state or not - * @param size return how many AD tasks - * @param listener action listener - * @param response type of action listener - */ - public void getAndExecuteOnLatestADTasks( - String detectorId, - String parentTaskId, - Entity entity, - List adTaskTypes, - Consumer> function, - TransportService transportService, - boolean resetTaskState, - int size, - ActionListener listener - ) { - BoolQueryBuilder query = new BoolQueryBuilder(); - query.filter(new TermQueryBuilder(DETECTOR_ID_FIELD, detectorId)); - query.filter(new TermQueryBuilder(IS_LATEST_FIELD, true)); - if (parentTaskId != null) { - query.filter(new TermQueryBuilder(PARENT_TASK_ID_FIELD, parentTaskId)); - } - if (adTaskTypes != null && adTaskTypes.size() > 0) { - query.filter(new TermsQueryBuilder(TASK_TYPE_FIELD, taskTypeToString(adTaskTypes))); + return config.isHighCardinality() ? ADTaskType.HISTORICAL_HC_DETECTOR : ADTaskType.HISTORICAL_SINGLE_ENTITY; } - if (entity != null && !isNullOrEmpty(entity.getAttributes())) { - String path = "entity"; - String entityKeyFieldName = path + ".name"; - String entityValueFieldName = path + ".value"; - - for (Map.Entry attribute : entity.getAttributes().entrySet()) { - BoolQueryBuilder entityBoolQuery = new BoolQueryBuilder(); - TermQueryBuilder entityKeyFilterQuery = QueryBuilders.termQuery(entityKeyFieldName, attribute.getKey()); - TermQueryBuilder entityValueFilterQuery = QueryBuilders.termQuery(entityValueFieldName, attribute.getValue()); - - entityBoolQuery.filter(entityKeyFilterQuery).filter(entityValueFilterQuery); - NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(path, entityBoolQuery, ScoreMode.None); - query.filter(nestedQueryBuilder); - } - } - SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - sourceBuilder.query(query).sort(EXECUTION_START_TIME_FIELD, SortOrder.DESC).size(size); - SearchRequest searchRequest = new SearchRequest(); - searchRequest.source(sourceBuilder); - searchRequest.indices(DETECTION_STATE_INDEX); - - client.search(searchRequest, ActionListener.wrap(r -> { - // https://github.com/opendistro-for-elasticsearch/anomaly-detection/pull/359#discussion_r558653132 - // getTotalHits will be null when we track_total_hits is false in the query request. - // Add more checking here to cover some unknown cases. - List adTasks = new ArrayList<>(); - if (r == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) { - // don't throw exception here as consumer functions need to handle missing task - // in different way. - function.accept(adTasks); - return; - } - - Iterator iterator = r.getHits().iterator(); - while (iterator.hasNext()) { - SearchHit searchHit = iterator.next(); - try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, searchHit.getSourceRef())) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - ADTask adTask = ADTask.parse(parser, searchHit.getId()); - adTasks.add(adTask); - } catch (Exception e) { - String message = "Failed to parse AD task for detector " + detectorId + ", task id " + searchHit.getId(); - logger.error(message, e); - listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); - } - } - if (resetTaskState) { - resetLatestDetectorTaskState(adTasks, function, transportService, listener); - } else { - function.accept(adTasks); - } - }, e -> { - if (e instanceof IndexNotFoundException) { - function.accept(new ArrayList<>()); - } else { - logger.error("Failed to search AD task for detector " + detectorId, e); - listener.onFailure(e); - } - })); } /** - * Reset latest detector task state. Will reset both historical and realtime tasks. - * [Important!] Make sure listener returns in function - * - * @param adTasks ad tasks - * @param function consumer function - * @param transportService transport service - * @param listener action listener - * @param response type of action listener - */ - private void resetLatestDetectorTaskState( - List adTasks, - Consumer> function, - TransportService transportService, - ActionListener listener - ) { - List runningHistoricalTasks = new ArrayList<>(); - List runningRealtimeTasks = new ArrayList<>(); - for (ADTask adTask : adTasks) { - if (!adTask.isEntityTask() && !adTask.isDone()) { - if (!adTask.isHistoricalTask()) { - // try to reset task state if realtime task is not ended - runningRealtimeTasks.add(adTask); - } else { - // try to reset task state if historical task not updated for 2 piece intervals - runningHistoricalTasks.add(adTask); - } - } - } - - resetHistoricalDetectorTaskState( - runningHistoricalTasks, - () -> resetRealtimeDetectorTaskState(runningRealtimeTasks, () -> function.accept(adTasks), transportService, listener), - transportService, - listener - ); - } - - private void resetRealtimeDetectorTaskState( - List runningRealtimeTasks, - ExecutorFunction function, - TransportService transportService, - ActionListener listener - ) { - if (isNullOrEmpty(runningRealtimeTasks)) { - function.execute(); - return; - } - ADTask adTask = runningRealtimeTasks.get(0); - String detectorId = adTask.getConfigId(); - GetRequest getJobRequest = new GetRequest(CommonName.JOB_INDEX).id(detectorId); - client.get(getJobRequest, ActionListener.wrap(r -> { - if (r.isExists()) { - try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Job job = Job.parse(parser); - if (!job.isEnabled()) { - logger.debug("AD job is disabled, reset realtime task as stopped for detector {}", detectorId); - resetTaskStateAsStopped(adTask, function, transportService, listener); - } else { - function.execute(); - } - } catch (IOException e) { - logger.error(" Failed to parse AD job " + detectorId, e); - listener.onFailure(e); - } - } else { - logger.debug("AD job is not found, reset realtime task as stopped for detector {}", detectorId); - resetTaskStateAsStopped(adTask, function, transportService, listener); - } - }, e -> { - logger.error("Fail to get AD realtime job for detector " + detectorId, e); - listener.onFailure(e); - })); - } - - private void resetHistoricalDetectorTaskState( - List runningHistoricalTasks, + * If resetTaskState is true, will collect latest task's profile data from all data nodes. If no data + * node running the latest task, will reset the task state as STOPPED; otherwise, check if there is + * any stale running entities(entity exists in coordinating node cache but no task running on worker + * node) and clean up. + */ + protected void resetHistoricalConfigTaskState( + List runningHistoricalTasks, ExecutorFunction function, TransportService transportService, ActionListener listener ) { - if (isNullOrEmpty(runningHistoricalTasks)) { + if (ParseUtils.isNullOrEmpty(runningHistoricalTasks)) { function.execute(); return; } - ADTask adTask = runningHistoricalTasks.get(0); + ADTask adTask = (ADTask) runningHistoricalTasks.get(0); // If AD task is still running, but its last updated time not refreshed for 2 piece intervals, we will get // task profile to check if it's really running. If task not running, reset state as STOPPED. // For example, ES process crashes, then all tasks running on it will stay as running. We can reset the task @@ -1126,13 +634,13 @@ private void resetHistoricalDetectorTaskState( } String taskId = adTask.getTaskId(); AnomalyDetector detector = adTask.getDetector(); - getADTaskProfile(adTask, ActionListener.wrap(taskProfile -> { + taskProfileRunner.getTaskProfile(adTask, ActionListener.wrap(taskProfile -> { boolean taskStopped = isTaskStopped(taskId, detector, taskProfile); if (taskStopped) { logger.debug("Reset task state as stopped, task id: {}", adTask.getTaskId()); if (taskProfile.getTaskId() == null // This means coordinating node doesn't have HC detector cache && detector.isHighCardinality() - && !isNullOrEmpty(taskProfile.getEntityTaskProfiles())) { + && !ParseUtils.isNullOrEmpty(taskProfile.getEntityTaskProfiles())) { // If coordinating node restarted, HC detector cache on it will be gone. But worker node still // runs entity tasks, we'd better stop these entity tasks to clean up resource earlier. stopHistoricalAnalysis(adTask.getConfigId(), Optional.of(adTask), null, ActionListener.wrap(r -> { @@ -1151,7 +659,8 @@ private void resetHistoricalDetectorTaskState( if (ADTaskType.HISTORICAL_HC_DETECTOR.name().equals(adTask.getTaskType())) { // Check if any running entity not run on worker node. If yes, we need to remove it // and poll next entity from pending entity queue and run it. - if (!isNullOrEmpty(taskProfile.getRunningEntities()) && hcBatchTaskExpired(taskProfile.getLatestHCTaskRunTime())) { + if (!ParseUtils.isNullOrEmpty(taskProfile.getRunningEntities()) + && hcBatchTaskExpired(taskProfile.getLatestHCTaskRunTime())) { List runningTasksInCoordinatingNodeCache = new ArrayList<>(taskProfile.getRunningEntities()); List runningTasksOnWorkerNode = new ArrayList<>(); if (taskProfile.getEntityTaskProfiles() != null && taskProfile.getEntityTaskProfiles().size() > 0) { @@ -1196,8 +705,8 @@ private boolean isTaskStopped(String taskId, AnomalyDetector detector, ADTaskPro } if (detector.isHighCardinality() && taskProfile.getTotalEntitiesInited() - && isNullOrEmpty(taskProfile.getRunningEntities()) - && isNullOrEmpty(taskProfile.getEntityTaskProfiles()) + && ParseUtils.isNullOrEmpty(taskProfile.getRunningEntities()) + && ParseUtils.isNullOrEmpty(taskProfile.getEntityTaskProfiles()) && hcBatchTaskExpired(taskProfile.getLatestHCTaskRunTime())) { logger.debug("AD task not running for HC detector {}, task {}", detectorId, taskId); return true; @@ -1212,7 +721,7 @@ public boolean hcBatchTaskExpired(Long latestHCTaskRunTime) { return latestHCTaskRunTime + HC_BATCH_TASK_CACHE_TIMEOUT_IN_MILLIS < Instant.now().toEpochMilli(); } - private void stopHistoricalAnalysis(String detectorId, Optional adTask, User user, ActionListener listener) { + public void stopHistoricalAnalysis(String detectorId, Optional adTask, User user, ActionListener listener) { if (!adTask.isPresent()) { listener.onFailure(new ResourceNotFoundException(detectorId, "Detector not started")); return; @@ -1224,7 +733,7 @@ private void stopHistoricalAnalysis(String detectorId, Optional adTask, } String taskId = adTask.get().getTaskId(); - DiscoveryNode[] dataNodes = hashRing.getNodesWithSameLocalAdVersion(); + DiscoveryNode[] dataNodes = hashRing.getNodesWithSameLocalVersion(); String userName = user == null ? null : user.getName(); ADCancelTaskRequest cancelTaskRequest = new ADCancelTaskRequest(detectorId, taskId, userName, dataNodes); @@ -1236,58 +745,15 @@ private void stopHistoricalAnalysis(String detectorId, Optional adTask, })); } - private boolean lastUpdateTimeOfHistoricalTaskExpired(ADTask adTask) { + private boolean lastUpdateTimeOfHistoricalTaskExpired(TimeSeriesTask adTask) { // Wait at least 10 seconds. Piece interval seconds is dynamic setting, user could change it to a smaller value. int waitingTime = Math.max(2 * pieceIntervalSeconds, 10); return adTask.getLastUpdateTime().plus(waitingTime, ChronoUnit.SECONDS).isBefore(Instant.now()); } - private void resetTaskStateAsStopped( - ADTask adTask, - ExecutorFunction function, - TransportService transportService, - ActionListener listener - ) { - cleanDetectorCache(adTask, transportService, () -> { - String taskId = adTask.getTaskId(); - Map updatedFields = ImmutableMap.of(STATE_FIELD, TaskState.STOPPED.name()); - updateADTask(taskId, updatedFields, ActionListener.wrap(r -> { - adTask.setState(TaskState.STOPPED.name()); - if (function != null) { - function.execute(); - } - // For realtime anomaly detection, we only create detector level task, no entity level realtime task. - if (ADTaskType.HISTORICAL_HC_DETECTOR.name().equals(adTask.getTaskType())) { - // Reset running entity tasks as STOPPED - resetEntityTasksAsStopped(taskId); - } - }, e -> { - logger.error("Failed to update task state as STOPPED for task " + taskId, e); - listener.onFailure(e); - })); - }, listener); - } - - private void resetEntityTasksAsStopped(String detectorTaskId) { - UpdateByQueryRequest updateByQueryRequest = new UpdateByQueryRequest(); - updateByQueryRequest.indices(DETECTION_STATE_INDEX); - BoolQueryBuilder query = new BoolQueryBuilder(); - query.filter(new TermQueryBuilder(PARENT_TASK_ID_FIELD, detectorTaskId)); - query.filter(new TermQueryBuilder(TASK_TYPE_FIELD, ADTaskType.HISTORICAL_HC_ENTITY.name())); - query.filter(new TermsQueryBuilder(STATE_FIELD, NOT_ENDED_STATES)); - updateByQueryRequest.setQuery(query); - updateByQueryRequest.setRefresh(true); - String script = String.format(Locale.ROOT, "ctx._source.%s='%s';", STATE_FIELD, TaskState.STOPPED.name()); - updateByQueryRequest.setScript(new Script(script)); - - client.execute(UpdateByQueryAction.INSTANCE, updateByQueryRequest, ActionListener.wrap(r -> { - List bulkFailures = r.getBulkFailures(); - if (isNullOrEmpty(bulkFailures)) { - logger.debug("Updated {} child entity tasks state for detector task {}", r.getUpdated(), detectorTaskId); - } else { - logger.error("Failed to update child entity task's state for detector task {} ", detectorTaskId); - } - }, e -> logger.error("Exception happened when update child entity task's state for detector task " + detectorTaskId, e))); + @Override + protected boolean isHistoricalHCTask(TimeSeriesTask task) { + return ADTaskType.HISTORICAL_HC_DETECTOR.name().equals(task.getTaskType()); } /** @@ -1305,8 +771,9 @@ private void resetEntityTasksAsStopped(String detectorTaskId) { * @param listener action listener * @param response type of listener */ - public void cleanDetectorCache( - ADTask adTask, + @Override + public void cleanConfigCache( + TimeSeriesTask adTask, TransportService transportService, ExecutorFunction function, ActionListener listener @@ -1315,15 +782,12 @@ public void cleanDetectorCache( String detectorId = adTask.getConfigId(); String taskId = adTask.getTaskId(); try { - forwardADTaskToCoordinatingNode( - adTask, - ADTaskAction.CLEAN_CACHE, - transportService, - ActionListener.wrap(r -> { function.execute(); }, e -> { - logger.error("Failed to clear detector cache on coordinating node " + coordinatingNode, e); - listener.onFailure(e); - }) - ); + forwardADTaskToCoordinatingNode((ADTask) adTask, ADTaskAction.CLEAN_CACHE, transportService, ActionListener.wrap(r -> { + function.execute(); + }, e -> { + logger.error("Failed to clear detector cache on coordinating node " + coordinatingNode, e); + listener.onFailure(e); + })); } catch (ResourceNotFoundException e) { logger .warn( @@ -1342,161 +806,27 @@ public void cleanDetectorCache( protected void cleanDetectorCache(ADTask adTask, TransportService transportService, ExecutorFunction function) { String detectorId = adTask.getConfigId(); String taskId = adTask.getTaskId(); - cleanDetectorCache(adTask, transportService, function, ActionListener.wrap(r -> { + cleanConfigCache(adTask, transportService, function, ActionListener.wrap(r -> { logger.debug("Successfully cleaned cache for detector {}, task {}", detectorId, taskId); }, e -> { logger.error("Failed to clean cache for detector " + detectorId + ", task " + taskId, e); })); } - /** - * Get latest historical AD task profile. - * Will not reset task state in this method. - * - * @param detectorId detector id - * @param transportService transport service - * @param profile detector profile - * @param listener action listener - */ - public void getLatestHistoricalTaskProfile( - String detectorId, - TransportService transportService, - DetectorProfile profile, - ActionListener listener - ) { - getAndExecuteOnLatestADTask(detectorId, null, null, HISTORICAL_DETECTOR_TASK_TYPES, adTask -> { - if (adTask.isPresent()) { - getADTaskProfile(adTask.get(), ActionListener.wrap(adTaskProfile -> { - DetectorProfile.Builder profileBuilder = new DetectorProfile.Builder(); - profileBuilder.adTaskProfile(adTaskProfile); - DetectorProfile detectorProfile = profileBuilder.build(); - detectorProfile.merge(profile); - listener.onResponse(detectorProfile); - }, e -> { - logger.error("Failed to get AD task profile for task " + adTask.get().getTaskId(), e); - listener.onFailure(e); - })); - } else { - DetectorProfile.Builder profileBuilder = new DetectorProfile.Builder(); - listener.onResponse(profileBuilder.build()); - } - }, transportService, false, listener); - } - - /** - * Get AD task profile. - * @param adDetectorLevelTask detector level task - * @param listener action listener - */ - private void getADTaskProfile(ADTask adDetectorLevelTask, ActionListener listener) { - String detectorId = adDetectorLevelTask.getConfigId(); - - hashRing.getAllEligibleDataNodesWithKnownAdVersion(dataNodes -> { - ADTaskProfileRequest adTaskProfileRequest = new ADTaskProfileRequest(detectorId, dataNodes); - client.execute(ADTaskProfileAction.INSTANCE, adTaskProfileRequest, ActionListener.wrap(response -> { - if (response.hasFailures()) { - listener.onFailure(response.failures().get(0)); - return; - } - - List adEntityTaskProfiles = new ArrayList<>(); - ADTaskProfile detectorTaskProfile = new ADTaskProfile(adDetectorLevelTask); - for (ADTaskProfileNodeResponse node : response.getNodes()) { - ADTaskProfile taskProfile = node.getAdTaskProfile(); - if (taskProfile != null) { - if (taskProfile.getNodeId() != null) { - // HC detector: task profile from coordinating node - // Single entity detector: task profile from worker node - detectorTaskProfile.setTaskId(taskProfile.getTaskId()); - detectorTaskProfile.setShingleSize(taskProfile.getShingleSize()); - detectorTaskProfile.setRcfTotalUpdates(taskProfile.getRcfTotalUpdates()); - detectorTaskProfile.setThresholdModelTrained(taskProfile.getThresholdModelTrained()); - detectorTaskProfile.setThresholdModelTrainingDataSize(taskProfile.getThresholdModelTrainingDataSize()); - detectorTaskProfile.setModelSizeInBytes(taskProfile.getModelSizeInBytes()); - detectorTaskProfile.setNodeId(taskProfile.getNodeId()); - detectorTaskProfile.setTotalEntitiesCount(taskProfile.getTotalEntitiesCount()); - detectorTaskProfile.setDetectorTaskSlots(taskProfile.getDetectorTaskSlots()); - detectorTaskProfile.setPendingEntitiesCount(taskProfile.getPendingEntitiesCount()); - detectorTaskProfile.setRunningEntitiesCount(taskProfile.getRunningEntitiesCount()); - detectorTaskProfile.setRunningEntities(taskProfile.getRunningEntities()); - detectorTaskProfile.setAdTaskType(taskProfile.getAdTaskType()); - } - if (taskProfile.getEntityTaskProfiles() != null) { - adEntityTaskProfiles.addAll(taskProfile.getEntityTaskProfiles()); - } - } - } - if (adEntityTaskProfiles != null && adEntityTaskProfiles.size() > 0) { - detectorTaskProfile.setEntityTaskProfiles(adEntityTaskProfiles); - } - listener.onResponse(detectorTaskProfile); - }, e -> { - logger.error("Failed to get task profile for task " + adDetectorLevelTask.getTaskId(), e); - listener.onFailure(e); - })); - }, listener); - - } - - private String validateDetector(AnomalyDetector detector) { - String error = null; - if (detector.getFeatureAttributes().size() == 0) { - error = "Can't start detector job as no features configured"; - } else if (detector.getEnabledFeatureIds().size() == 0) { - error = "Can't start detector job as no enabled features configured"; - } - return error; - } - - private void updateLatestFlagOfOldTasksAndCreateNewTask( - AnomalyDetector detector, - DateRange detectionDateRange, - User user, - ActionListener listener - ) { - UpdateByQueryRequest updateByQueryRequest = new UpdateByQueryRequest(); - updateByQueryRequest.indices(DETECTION_STATE_INDEX); - BoolQueryBuilder query = new BoolQueryBuilder(); - query.filter(new TermQueryBuilder(DETECTOR_ID_FIELD, detector.getId())); - query.filter(new TermQueryBuilder(IS_LATEST_FIELD, true)); - // make sure we reset all latest task as false when user switch from single entity to HC, vice versa. - query.filter(new TermsQueryBuilder(TASK_TYPE_FIELD, taskTypeToString(getADTaskTypes(detectionDateRange, true)))); - updateByQueryRequest.setQuery(query); - updateByQueryRequest.setRefresh(true); - String script = String.format(Locale.ROOT, "ctx._source.%s=%s;", IS_LATEST_FIELD, false); - updateByQueryRequest.setScript(new Script(script)); - - client.execute(UpdateByQueryAction.INSTANCE, updateByQueryRequest, ActionListener.wrap(r -> { - List bulkFailures = r.getBulkFailures(); - if (bulkFailures.isEmpty()) { - // Realtime AD coordinating node is chosen by job scheduler, we won't know it until realtime AD job - // runs. Just set realtime AD coordinating node as null here, and AD job runner will reset correct - // coordinating node once realtime job starts. - // For historical analysis, this method will be called on coordinating node, so we can set coordinating - // node as local node. - String coordinatingNode = detectionDateRange == null ? null : clusterService.localNode().getId(); - createNewADTask(detector, detectionDateRange, user, coordinatingNode, listener); - } else { - logger.error("Failed to update old task's state for detector: {}, response: {} ", detector.getId(), r.toString()); - listener.onFailure(bulkFailures.get(0).getCause()); - } - }, e -> { - logger.error("Failed to reset old tasks as not latest for detector " + detector.getId(), e); - listener.onFailure(e); - })); - } - - private void createNewADTask( - AnomalyDetector detector, + @Override + protected void createNewTask( + Config config, DateRange detectionDateRange, + boolean runOnce, User user, String coordinatingNode, - ActionListener listener + TaskState initialState, + ActionListener listener ) { String userName = user == null ? null : user.getName(); Instant now = Instant.now(); - String taskType = getADTaskType(detector, detectionDateRange).name(); + String taskType = getTaskType(config, detectionDateRange, runOnce).name(); ADTask adTask = new ADTask.Builder() - .configId(detector.getId()) - .detector(detector) + .configId(config.getId()) + .detector((AnomalyDetector) config) .isLatest(true) .taskType(taskType) .executionStartTime(now) @@ -1510,57 +840,38 @@ private void createNewADTask( .user(user) .build(); - createADTaskDirectly( + createTaskDirectly( adTask, - r -> onIndexADTaskResponse( + r -> onIndexConfigTaskResponse( r, adTask, - (response, delegatedListener) -> cleanOldAdTaskDocs(response, adTask, delegatedListener), + (response, delegatedListener) -> cleanOldConfigTaskDocs( + response, + adTask, + (indexResponse) -> (T) new JobResponse(indexResponse.getId()), + delegatedListener + ), listener ), listener ); } - /** - * Create AD task directly without checking index exists of not. - * [Important!] Make sure listener returns in function - * - * @param adTask AD task - * @param function consumer function - * @param listener action listener - * @param action listener response type - */ - public void createADTaskDirectly(ADTask adTask, Consumer function, ActionListener listener) { - IndexRequest request = new IndexRequest(DETECTION_STATE_INDEX); - try (XContentBuilder builder = XContentFactory.jsonBuilder()) { - request - .source(adTask.toXContent(builder, RestHandlerUtils.XCONTENT_WITH_TYPE)) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.index(request, ActionListener.wrap(r -> function.accept(r), e -> { - logger.error("Failed to create AD task for detector " + adTask.getConfigId(), e); - listener.onFailure(e); - })); - } catch (Exception e) { - logger.error("Failed to create AD task for detector " + adTask.getConfigId(), e); - listener.onFailure(e); - } - } - - private void onIndexADTaskResponse( + @Override + protected void onIndexConfigTaskResponse( IndexResponse response, ADTask adTask, - BiConsumer> function, - ActionListener listener + BiConsumer> function, + ActionListener listener ) { if (response == null || response.getResult() != CREATED) { - String errorMsg = getShardsFailure(response); + String errorMsg = ExceptionUtil.getShardsFailure(response); listener.onFailure(new OpenSearchStatusException(errorMsg, response.status())); return; } adTask.setTaskId(response.getId()); - ActionListener delegatedListener = ActionListener.wrap(r -> { listener.onResponse(r); }, e -> { - handleADTaskException(adTask, e); + ActionListener delegatedListener = ActionListener.wrap(r -> { listener.onResponse(r); }, e -> { + handleTaskException(adTask, e); if (e instanceof DuplicateTaskException) { listener.onFailure(new OpenSearchStatusException(DETECTOR_IS_RUNNING, RestStatus.BAD_REQUEST)); } else { @@ -1569,17 +880,17 @@ private void onIndexADTaskResponse( // ADTaskManager#initRealtimeTaskCacheAndCleanupStaleCache for details. Here the // realtime task cache not inited yet when create AD task, so no need to cleanup. if (adTask.isHistoricalTask()) { - adTaskCacheManager.removeHistoricalTaskCache(adTask.getConfigId()); + taskCacheManager.removeHistoricalTaskCache(adTask.getConfigId()); } listener.onFailure(e); } }); try { - // Put detector id in cache. If detector id already in cache, will throw + // Put config id in cache. If config id already in cache, will throw // DuplicateTaskException. This is to solve race condition when user send - // multiple start request for one historical detector. + // multiple start request for one historical run. if (adTask.isHistoricalTask()) { - adTaskCacheManager.add(adTask.getConfigId(), adTask); + taskCacheManager.add(adTask.getConfigId(), adTask); } } catch (Exception e) { delegatedListener.onFailure(e); @@ -1590,126 +901,13 @@ private void onIndexADTaskResponse( } } - private void cleanOldAdTaskDocs(IndexResponse response, ADTask adTask, ActionListener delegatedListener) { - BoolQueryBuilder query = new BoolQueryBuilder(); - query.filter(new TermQueryBuilder(DETECTOR_ID_FIELD, adTask.getConfigId())); - query.filter(new TermQueryBuilder(IS_LATEST_FIELD, false)); - - if (adTask.isHistoricalTask()) { - // If historical task, only delete detector level task. It may take longer time to delete entity tasks. - // We will delete child task (entity task) of detector level task in hourly cron job. - query.filter(new TermsQueryBuilder(TASK_TYPE_FIELD, taskTypeToString(HISTORICAL_DETECTOR_TASK_TYPES))); - } else { - // We don't have entity level task for realtime detection, so will delete all tasks. - query.filter(new TermsQueryBuilder(TASK_TYPE_FIELD, taskTypeToString(REALTIME_TASK_TYPES))); - } - - SearchRequest searchRequest = new SearchRequest(); - SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - sourceBuilder - .query(query) - .sort(EXECUTION_START_TIME_FIELD, SortOrder.DESC) - // Search query "from" starts from 0. - .from(maxOldAdTaskDocsPerDetector) - .size(MAX_OLD_AD_TASK_DOCS); - searchRequest.source(sourceBuilder).indices(DETECTION_STATE_INDEX); - String detectorId = adTask.getConfigId(); - - deleteTaskDocs(detectorId, searchRequest, () -> { - if (adTask.isHistoricalTask()) { - // run batch result action for historical detection - runBatchResultAction(response, adTask, delegatedListener); - } else { - // return response directly for realtime detection - JobResponse anomalyDetectorJobResponse = new JobResponse(response.getId()); - delegatedListener.onResponse(anomalyDetectorJobResponse); - } - }, delegatedListener); - } - - protected void deleteTaskDocs( - String detectorId, - SearchRequest searchRequest, - ExecutorFunction function, + @Override + protected void runBatchResultAction( + IndexResponse response, + ADTask adTask, + ResponseTransformer responseTransformer, ActionListener listener ) { - ActionListener searchListener = ActionListener.wrap(r -> { - Iterator iterator = r.getHits().iterator(); - if (iterator.hasNext()) { - BulkRequest bulkRequest = new BulkRequest(); - while (iterator.hasNext()) { - SearchHit searchHit = iterator.next(); - try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, searchHit.getSourceRef())) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - ADTask adTask = ADTask.parse(parser, searchHit.getId()); - logger.debug("Delete old task: {} of detector: {}", adTask.getTaskId(), adTask.getConfigId()); - bulkRequest.add(new DeleteRequest(DETECTION_STATE_INDEX).id(adTask.getTaskId())); - } catch (Exception e) { - listener.onFailure(e); - } - } - client.execute(BulkAction.INSTANCE, bulkRequest, ActionListener.wrap(res -> { - logger.info("Old AD tasks deleted for detector {}", detectorId); - BulkItemResponse[] bulkItemResponses = res.getItems(); - if (bulkItemResponses != null && bulkItemResponses.length > 0) { - for (BulkItemResponse bulkItemResponse : bulkItemResponses) { - if (!bulkItemResponse.isFailed()) { - logger.debug("Add detector task into cache. Task id: {}", bulkItemResponse.getId()); - // add deleted task in cache and delete its child tasks and AD results - adTaskCacheManager.addDeletedTask(bulkItemResponse.getId()); - } - } - } - // delete child tasks and AD results of this task - cleanChildTasksAndADResultsOfDeletedTask(); - - function.execute(); - }, e -> { - logger.warn("Failed to clean AD tasks for detector " + detectorId, e); - listener.onFailure(e); - })); - } else { - function.execute(); - } - }, e -> { - if (e instanceof IndexNotFoundException) { - function.execute(); - } else { - listener.onFailure(e); - } - }); - - client.search(searchRequest, searchListener); - } - - /** - * Poll deleted detector task from cache and delete its child tasks and AD results. - */ - public void cleanChildTasksAndADResultsOfDeletedTask() { - if (!adTaskCacheManager.hasDeletedTask()) { - return; - } - threadPool.schedule(() -> { - String taskId = adTaskCacheManager.pollDeletedTask(); - if (taskId == null) { - return; - } - DeleteByQueryRequest deleteADResultsRequest = new DeleteByQueryRequest(ALL_AD_RESULTS_INDEX_PATTERN); - deleteADResultsRequest.setQuery(new TermsQueryBuilder(TASK_ID_FIELD, taskId)); - client.execute(DeleteByQueryAction.INSTANCE, deleteADResultsRequest, ActionListener.wrap(res -> { - logger.debug("Successfully deleted AD results of task " + taskId); - DeleteByQueryRequest deleteChildTasksRequest = new DeleteByQueryRequest(DETECTION_STATE_INDEX); - deleteChildTasksRequest.setQuery(new TermsQueryBuilder(PARENT_TASK_ID_FIELD, taskId)); - - client.execute(DeleteByQueryAction.INSTANCE, deleteChildTasksRequest, ActionListener.wrap(r -> { - logger.debug("Successfully deleted child tasks of task " + taskId); - cleanChildTasksAndADResultsOfDeletedTask(); - }, e -> { logger.error("Failed to delete child tasks of task " + taskId, e); })); - }, ex -> { logger.error("Failed to delete AD results for task " + taskId, ex); })); - }, TimeValue.timeValueSeconds(DEFAULT_MAINTAIN_INTERVAL_IN_SECONDS), AD_BATCH_TASK_THREAD_POOL_NAME); - } - - private void runBatchResultAction(IndexResponse response, ADTask adTask, ActionListener listener) { client.execute(ADBatchAnomalyResultAction.INSTANCE, new ADBatchAnomalyResultRequest(adTask), ActionListener.wrap(r -> { String remoteOrLocal = r.isRunTaskRemotely() ? "remote" : "local"; logger @@ -1720,104 +918,9 @@ private void runBatchResultAction(IndexResponse response, ADTask adTask, ActionL remoteOrLocal, r.getNodeId() ); - JobResponse anomalyDetectorJobResponse = new JobResponse(response.getId()); - listener.onResponse(anomalyDetectorJobResponse); - }, e -> listener.onFailure(e))); - } - - /** - * Handle exceptions for AD task. Update task state and record error message. - * - * @param adTask AD task - * @param e exception - */ - public void handleADTaskException(ADTask adTask, Exception e) { - // TODO: handle timeout exception - String state = TaskState.FAILED.name(); - Map updatedFields = new HashMap<>(); - if (e instanceof DuplicateTaskException) { - // If user send multiple start detector request, we will meet race condition. - // Cache manager will put first request in cache and throw DuplicateTaskException - // for the second request. We will delete the second task. - logger - .warn( - "There is already one running task for detector, detectorId:" - + adTask.getConfigId() - + ". Will delete task " - + adTask.getTaskId() - ); - deleteADTask(adTask.getTaskId()); - return; - } - if (e instanceof TaskCancelledException) { - logger.info("AD task cancelled, taskId: {}, detectorId: {}", adTask.getTaskId(), adTask.getConfigId()); - state = TaskState.STOPPED.name(); - String stoppedBy = ((TaskCancelledException) e).getCancelledBy(); - if (stoppedBy != null) { - updatedFields.put(STOPPED_BY_FIELD, stoppedBy); - } - } else { - logger.error("Failed to execute AD batch task, task id: " + adTask.getTaskId() + ", detector id: " + adTask.getConfigId(), e); - } - updatedFields.put(ERROR_FIELD, getErrorMessage(e)); - updatedFields.put(STATE_FIELD, state); - updatedFields.put(EXECUTION_END_TIME_FIELD, Instant.now().toEpochMilli()); - updateADTask(adTask.getTaskId(), updatedFields); - } - - /** - * Update AD task with specific fields. - * - * @param taskId AD task id - * @param updatedFields updated fields, key: filed name, value: new value - */ - public void updateADTask(String taskId, Map updatedFields) { - updateADTask(taskId, updatedFields, ActionListener.wrap(response -> { - if (response.status() == RestStatus.OK) { - logger.debug("Updated AD task successfully: {}, task id: {}", response.status(), taskId); - } else { - logger.error("Failed to update AD task {}, status: {}", taskId, response.status()); - } - }, e -> { logger.error("Failed to update task: " + taskId, e); })); - } - - /** - * Update AD task for specific fields. - * - * @param taskId task id - * @param updatedFields updated fields, key: filed name, value: new value - * @param listener action listener - */ - public void updateADTask(String taskId, Map updatedFields, ActionListener listener) { - UpdateRequest updateRequest = new UpdateRequest(DETECTION_STATE_INDEX, taskId); - Map updatedContent = new HashMap<>(); - updatedContent.putAll(updatedFields); - updatedContent.put(LAST_UPDATE_TIME_FIELD, Instant.now().toEpochMilli()); - updateRequest.doc(updatedContent); - updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.update(updateRequest, listener); - } - - /** - * Delete AD task with task id. - * - * @param taskId AD task id - */ - public void deleteADTask(String taskId) { - deleteADTask(taskId, ActionListener.wrap(r -> { logger.info("Deleted AD task {} with status: {}", taskId, r.status()); }, e -> { - logger.error("Failed to delete AD task " + taskId, e); - })); - } - /** - * Delete AD task with task id. - * - * @param taskId AD task id - * @param listener action listener - */ - public void deleteADTask(String taskId, ActionListener listener) { - DeleteRequest deleteRequest = new DeleteRequest(DETECTION_STATE_INDEX, taskId); - client.delete(deleteRequest, listener); + listener.onResponse(responseTransformer.transform(response)); + }, e -> listener.onFailure(e))); } /** @@ -1830,7 +933,7 @@ public void deleteADTask(String taskId, ActionListener listener) * @return AD task cancellation state */ public ADTaskCancellationState cancelLocalTaskByDetectorId(String detectorId, String detectorTaskId, String reason, String userName) { - ADTaskCancellationState cancellationState = adTaskCacheManager.cancelByDetectorId(detectorId, detectorTaskId, reason, userName); + ADTaskCancellationState cancellationState = taskCacheManager.cancelByDetectorId(detectorId, detectorTaskId, reason, userName); logger .debug( "Cancelled AD task for detector: {}, state: {}, cancelled by: {}, reason: {}", @@ -1842,199 +945,6 @@ public ADTaskCancellationState cancelLocalTaskByDetectorId(String detectorId, St return cancellationState; } - /** - * Delete AD tasks docs. - * [Important!] Make sure listener returns in function - * - * @param detectorId detector id - * @param function AD function - * @param listener action listener - */ - public void deleteADTasks(String detectorId, ExecutorFunction function, ActionListener listener) { - DeleteByQueryRequest request = new DeleteByQueryRequest(DETECTION_STATE_INDEX); - - BoolQueryBuilder query = new BoolQueryBuilder(); - query.filter(new TermQueryBuilder(DETECTOR_ID_FIELD, detectorId)); - - request.setQuery(query); - client.execute(DeleteByQueryAction.INSTANCE, request, ActionListener.wrap(r -> { - if (r.getBulkFailures() == null || r.getBulkFailures().size() == 0) { - logger.info("AD tasks deleted for detector {}", detectorId); - deleteADResultOfDetector(detectorId); - function.execute(); - } else { - listener.onFailure(new OpenSearchStatusException("Failed to delete all AD tasks", RestStatus.INTERNAL_SERVER_ERROR)); - } - }, e -> { - logger.info("Failed to delete AD tasks for " + detectorId, e); - if (e instanceof IndexNotFoundException) { - deleteADResultOfDetector(detectorId); - function.execute(); - } else { - listener.onFailure(e); - } - })); - } - - private void deleteADResultOfDetector(String detectorId) { - if (!deleteADResultWhenDeleteDetector) { - logger.info("Won't delete ad result for {} as delete AD result setting is disabled", detectorId); - return; - } - logger.info("Start to delete AD results of detector {}", detectorId); - DeleteByQueryRequest deleteADResultsRequest = new DeleteByQueryRequest(ALL_AD_RESULTS_INDEX_PATTERN); - deleteADResultsRequest.setQuery(new TermQueryBuilder(DETECTOR_ID_FIELD, detectorId)); - client.execute(DeleteByQueryAction.INSTANCE, deleteADResultsRequest, ActionListener.wrap(response -> { - logger.debug("Successfully deleted AD results of detector " + detectorId); - }, exception -> { - logger.error("Failed to delete AD results of detector " + detectorId, exception); - adTaskCacheManager.addDeletedConfig(detectorId); - })); - } - - /** - * Clean AD results of deleted detector. - */ - public void cleanADResultOfDeletedDetector() { - String detectorId = adTaskCacheManager.pollDeletedConfig(); - if (detectorId != null) { - deleteADResultOfDetector(detectorId); - } - } - - /** - * Update latest AD task of detector. - * - * @param detectorId detector id - * @param taskTypes task types - * @param updatedFields updated fields, key: filed name, value: new value - * @param listener action listener - */ - public void updateLatestADTask( - String detectorId, - List taskTypes, - Map updatedFields, - ActionListener listener - ) { - getAndExecuteOnLatestDetectorLevelTask(detectorId, taskTypes, (adTask) -> { - if (adTask.isPresent()) { - updateADTask(adTask.get().getTaskId(), updatedFields, listener); - } else { - listener.onFailure(new ResourceNotFoundException(detectorId, CAN_NOT_FIND_LATEST_TASK)); - } - }, null, false, listener); - } - - /** - * Update latest realtime task. - * - * @param detectorId detector id - * @param state task state - * @param error error - * @param transportService transport service - * @param listener action listener - */ - public void stopLatestRealtimeTask( - String detectorId, - TaskState state, - Exception error, - TransportService transportService, - ActionListener listener - ) { - getAndExecuteOnLatestDetectorLevelTask(detectorId, REALTIME_TASK_TYPES, (adTask) -> { - if (adTask.isPresent() && !adTask.get().isDone()) { - Map updatedFields = new HashMap<>(); - updatedFields.put(ADTask.STATE_FIELD, state.name()); - if (error != null) { - updatedFields.put(ADTask.ERROR_FIELD, error.getMessage()); - } - ExecutorFunction function = () -> updateADTask(adTask.get().getTaskId(), updatedFields, ActionListener.wrap(r -> { - if (error == null) { - listener.onResponse(new JobResponse(detectorId)); - } else { - listener.onFailure(error); - } - }, e -> { listener.onFailure(e); })); - - String coordinatingNode = adTask.get().getCoordinatingNode(); - if (coordinatingNode != null && transportService != null) { - cleanDetectorCache(adTask.get(), transportService, function, listener); - } else { - function.execute(); - } - } else { - listener.onFailure(new OpenSearchStatusException("Anomaly detector job is already stopped: " + detectorId, RestStatus.OK)); - } - }, null, false, listener); - } - - /** - * Update realtime task cache on realtime detector's coordinating node. - * - * @param detectorId detector id - * @param state new state - * @param rcfTotalUpdates rcf total updates - * @param detectorIntervalInMinutes detector interval in minutes - * @param error error - * @param listener action listener - */ - public void updateLatestRealtimeTaskOnCoordinatingNode( - String detectorId, - String state, - Long rcfTotalUpdates, - Long detectorIntervalInMinutes, - String error, - ActionListener listener - ) { - Float initProgress = null; - String newState = null; - // calculate init progress and task state with RCF total updates - if (detectorIntervalInMinutes != null && rcfTotalUpdates != null) { - newState = TaskState.INIT.name(); - if (rcfTotalUpdates < NUM_MIN_SAMPLES) { - initProgress = (float) rcfTotalUpdates / NUM_MIN_SAMPLES; - } else { - newState = TaskState.RUNNING.name(); - initProgress = 1.0f; - } - } - // Check if new state is not null and override state calculated with rcf total updates - if (state != null) { - newState = state; - } - - error = Optional.ofNullable(error).orElse(""); - if (!adTaskCacheManager.isRealtimeTaskChangeNeeded(detectorId, newState, initProgress, error)) { - // If task not changed, no need to update, just return - listener.onResponse(null); - return; - } - Map updatedFields = new HashMap<>(); - updatedFields.put(COORDINATING_NODE_FIELD, clusterService.localNode().getId()); - if (initProgress != null) { - updatedFields.put(INIT_PROGRESS_FIELD, initProgress); - updatedFields.put(ESTIMATED_MINUTES_LEFT_FIELD, Math.max(0, NUM_MIN_SAMPLES - rcfTotalUpdates) * detectorIntervalInMinutes); - } - if (newState != null) { - updatedFields.put(STATE_FIELD, newState); - } - if (error != null) { - updatedFields.put(ERROR_FIELD, error); - } - Float finalInitProgress = initProgress; - // Variable used in lambda expression should be final or effectively final - String finalError = error; - String finalNewState = newState; - updateLatestADTask(detectorId, ADTaskType.REALTIME_TASK_TYPES, updatedFields, ActionListener.wrap(r -> { - logger.debug("Updated latest realtime AD task successfully for detector {}", detectorId); - adTaskCacheManager.updateRealtimeTaskCache(detectorId, finalNewState, finalInitProgress, finalError); - listener.onResponse(r); - }, e -> { - logger.error("Failed to update realtime task for detector " + detectorId, e); - listener.onFailure(e); - })); - } - /** * Init realtime task cache and clean up realtime task cache on old coordinating node. Realtime AD * depends on job scheduler to choose node (job coordinating node) to run AD job. Nodes have primary @@ -2048,33 +958,37 @@ public void updateLatestRealtimeTaskOnCoordinatingNode( * listener will return false. * * @param detectorId detector id - * @param detector anomaly detector + * @param config config accessor * @param transportService transport service * @param listener listener */ + @Override public void initRealtimeTaskCacheAndCleanupStaleCache( String detectorId, - AnomalyDetector detector, + Config config, TransportService transportService, ActionListener listener ) { try { - if (adTaskCacheManager.getRealtimeTaskCache(detectorId) != null) { + if (taskCacheManager.getRealtimeTaskCache(detectorId) != null) { listener.onResponse(false); return; } - getAndExecuteOnLatestDetectorLevelTask(detectorId, REALTIME_TASK_TYPES, (adTaskOptional) -> { + AnomalyDetector detector = (AnomalyDetector) config; + getAndExecuteOnLatestConfigLevelTask(detectorId, REALTIME_TASK_TYPES, (adTaskOptional) -> { if (!adTaskOptional.isPresent()) { logger.debug("Can't find realtime task for detector {}, init realtime task cache directly", detectorId); - ExecutorFunction function = () -> createNewADTask( + ExecutorFunction function = () -> createNewTask( detector, null, + false, detector.getUser(), clusterService.localNode().getId(), + TaskState.CREATED, ActionListener.wrap(r -> { logger.info("Recreate realtime task successfully for detector {}", detectorId); - adTaskCacheManager.initRealtimeTaskCache(detectorId, detector.getIntervalInMilliseconds()); + taskCacheManager.initRealtimeTaskCache(detectorId, detector.getIntervalInMilliseconds()); listener.onResponse(true); }, e -> { logger.error("Failed to recreate realtime task for detector " + detectorId, e); @@ -2096,19 +1010,19 @@ public void initRealtimeTaskCacheAndCleanupStaleCache( localNodeId, detectorId ); - cleanDetectorCache(adTask, transportService, () -> { + cleanConfigCache(adTask, transportService, () -> { logger .info( "Realtime task cache cleaned on old coordinating node {} for detector {}", oldCoordinatingNode, detectorId ); - adTaskCacheManager.initRealtimeTaskCache(detectorId, detector.getIntervalInMilliseconds()); + taskCacheManager.initRealtimeTaskCache(detectorId, detector.getIntervalInMilliseconds()); listener.onResponse(true); }, listener); } else { logger.info("Init realtime task cache for detector {}", detectorId); - adTaskCacheManager.initRealtimeTaskCache(detectorId, detector.getIntervalInMilliseconds()); + taskCacheManager.initRealtimeTaskCache(detectorId, detector.getIntervalInMilliseconds()); listener.onResponse(true); } }, transportService, false, listener); @@ -2119,16 +1033,16 @@ public void initRealtimeTaskCacheAndCleanupStaleCache( } private void recreateRealtimeTask(ExecutorFunction function, ActionListener listener) { - if (detectionIndices.doesStateIndexExist()) { + if (indexManagement.doesStateIndexExist()) { function.execute(); } else { // If detection index doesn't exist, create index and execute function. - detectionIndices.initStateIndex(ActionListener.wrap(r -> { + indexManagement.initStateIndex(ActionListener.wrap(r -> { if (r.isAcknowledged()) { logger.info("Created {} with mappings.", DETECTION_STATE_INDEX); function.execute(); } else { - String error = String.format(Locale.ROOT, CREATE_INDEX_NOT_ACKNOWLEDGED, DETECTION_STATE_INDEX); + String error = String.format(Locale.ROOT, CommonMessages.CREATE_INDEX_NOT_ACKNOWLEDGED, DETECTION_STATE_INDEX); logger.warn(error); listener.onFailure(new OpenSearchStatusException(error, RestStatus.INTERNAL_SERVER_ERROR)); } @@ -2143,14 +1057,6 @@ private void recreateRealtimeTask(ExecutorFunction function, ActionListener listener) { String detectorId = adTask.getConfigId(); - String taskId = adTask.isEntityTask() ? adTask.getParentTaskId() : adTask.getTaskId(); + String taskId = adTask.isHistoricalEntityTask() ? adTask.getParentTaskId() : adTask.getTaskId(); String detectorTaskId = adTask.getConfigLevelTaskId(); ActionListener wrappedListener = ActionListener.wrap(response -> { logger.info("Historical HC detector done with state: {}. Remove from cache, detector id:{}", state.name(), detectorId); - adTaskCacheManager.removeHistoricalTaskCache(detectorId); + taskCacheManager.removeHistoricalTaskCache(detectorId); }, e -> { // HC detector task may fail to update as FINISHED for some edge case if failed to get updating semaphore. // Will reset task state when get detector with task or maintain tasks in hourly cron. @@ -2260,7 +1166,7 @@ public void setHCDetectorTaskDone(ADTask adTask, TaskState state, ActionListener } else { logger.error("Failed to update task: " + taskId, e); } - adTaskCacheManager.removeHistoricalTaskCache(detectorId); + taskCacheManager.removeHistoricalTaskCache(detectorId); }); long timeoutInMillis = 2000;// wait for 2 seconds to acquire updating HC detector task semaphore @@ -2276,11 +1182,11 @@ public void setHCDetectorTaskDone(ADTask adTask, TaskState state, ActionListener taskId, ImmutableMap .of( - STATE_FIELD, + TimeSeriesTask.STATE_FIELD, hcDetectorTaskState.name(), - TASK_PROGRESS_FIELD, + TimeSeriesTask.TASK_PROGRESS_FIELD, 1.0, - EXECUTION_END_TIME_FIELD, + TimeSeriesTask.EXECUTION_END_TIME_FIELD, Instant.now().toEpochMilli() ), timeoutInMillis, @@ -2290,20 +1196,20 @@ public void setHCDetectorTaskDone(ADTask adTask, TaskState state, ActionListener }, e -> { logger.error("Failed to get finished entity tasks", e); - String errorMessage = getErrorMessage(e); + String errorMessage = ExceptionUtil.getErrorMessage(e); threadPool.executor(AD_BATCH_TASK_THREAD_POOL_NAME).execute(() -> { updateADHCDetectorTask( detectorId, taskId, ImmutableMap .of( - STATE_FIELD, + TimeSeriesTask.STATE_FIELD, TaskState.FAILED.name(),// set as FAILED if fail to get finished entity tasks. - TASK_PROGRESS_FIELD, + TimeSeriesTask.TASK_PROGRESS_FIELD, 1.0, - ERROR_FIELD, + TimeSeriesTask.ERROR_FIELD, errorMessage, - EXECUTION_END_TIME_FIELD, + TimeSeriesTask.EXECUTION_END_TIME_FIELD, Instant.now().toEpochMilli() ), timeoutInMillis, @@ -2318,11 +1224,11 @@ public void setHCDetectorTaskDone(ADTask adTask, TaskState state, ActionListener taskId, ImmutableMap .of( - STATE_FIELD, + TimeSeriesTask.STATE_FIELD, state.name(), - ERROR_FIELD, + TimeSeriesTask.ERROR_FIELD, adTask.getError(), - EXECUTION_END_TIME_FIELD, + TimeSeriesTask.EXECUTION_END_TIME_FIELD, Instant.now().toEpochMilli() ), timeoutInMillis, @@ -2344,9 +1250,12 @@ public void setHCDetectorTaskDone(ADTask adTask, TaskState state, ActionListener */ public void countEntityTasksByState(String detectorTaskId, List taskStates, ActionListener listener) { BoolQueryBuilder queryBuilder = new BoolQueryBuilder(); - queryBuilder.filter(new TermQueryBuilder(PARENT_TASK_ID_FIELD, detectorTaskId)); + queryBuilder.filter(new TermQueryBuilder(TimeSeriesTask.PARENT_TASK_ID_FIELD, detectorTaskId)); if (taskStates != null && taskStates.size() > 0) { - queryBuilder.filter(new TermsQueryBuilder(STATE_FIELD, taskStates.stream().map(s -> s.name()).collect(Collectors.toList()))); + queryBuilder + .filter( + new TermsQueryBuilder(TimeSeriesTask.STATE_FIELD, taskStates.stream().map(s -> s.name()).collect(Collectors.toList())) + ); } SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); sourceBuilder.query(queryBuilder); @@ -2409,19 +1318,19 @@ private void updateADHCDetectorTask( ActionListener listener ) { try { - if (adTaskCacheManager.tryAcquireTaskUpdatingSemaphore(detectorId, timeoutInMillis)) { + if (taskCacheManager.tryAcquireTaskUpdatingSemaphore(detectorId, timeoutInMillis)) { try { - updateADTask( + updateTask( taskId, updatedFields, - ActionListener.runAfter(listener, () -> { adTaskCacheManager.releaseTaskUpdatingSemaphore(detectorId); }) + ActionListener.runAfter(listener, () -> { taskCacheManager.releaseTaskUpdatingSemaphore(detectorId); }) ); } catch (Exception e) { logger.error("Failed to update detector task " + taskId, e); - adTaskCacheManager.releaseTaskUpdatingSemaphore(detectorId); + taskCacheManager.releaseTaskUpdatingSemaphore(detectorId); listener.onFailure(e); } - } else if (!adTaskCacheManager.isHCTaskCoordinatingNode(detectorId)) { + } else if (!taskCacheManager.isHCTaskCoordinatingNode(detectorId)) { // It's possible that AD task cache cleaned up by other task. Return null to avoid too many failure logs. logger.info("HC detector task cache does not exist, detectorId:{}, taskId:{}", detectorId, taskId); listener.onResponse(null); @@ -2457,7 +1366,7 @@ public void runNextEntityForHCADHistorical(ADTask adTask, TransportService trans "Have scaled down task slots. Will not poll next entity for detector {}, task {}, task slots: {}", detectorId, adTask.getTaskId(), - adTaskCacheManager.getDetectorTaskSlots(detectorId) + taskCacheManager.getDetectorTaskSlots(detectorId) ); listener.onResponse(new JobResponse(detectorId)); return; @@ -2496,9 +1405,9 @@ protected int scaleTaskSlots(ADTask adTask, TransportService transportService, A try { int scaleDelta = detectorTaskSlotScaleDelta(detectorId); logger.debug("start to scale task slots for detector {} with delta {}", detectorId, scaleDelta); - if (adTaskCacheManager.getAvailableNewEntityTaskLanes(detectorId) <= 0 && scaleDelta > 0) { + if (taskCacheManager.getAvailableNewEntityTaskLanes(detectorId) <= 0 && scaleDelta > 0) { // scale up to run more entities in parallel - Instant lastScaleEntityTaskLaneTime = adTaskCacheManager.getLastScaleEntityTaskLaneTime(detectorId); + Instant lastScaleEntityTaskLaneTime = taskCacheManager.getLastScaleEntityTaskLaneTime(detectorId); if (lastScaleEntityTaskLaneTime == null) { logger.debug("lastScaleEntityTaskLaneTime is null for detector {}", detectorId); scaleEntityTaskLane.release(); @@ -2508,7 +1417,7 @@ protected int scaleTaskSlots(ADTask adTask, TransportService transportService, A .plusMillis(SCALE_ENTITY_TASK_LANE_INTERVAL_IN_MILLIS) .isBefore(Instant.now()); if (lastScaleTimeExpired) { - adTaskCacheManager.refreshLastScaleEntityTaskLaneTime(detectorId); + taskCacheManager.refreshLastScaleEntityTaskLaneTime(detectorId); logger.debug("Forward scale entity task lane request to lead node for detector {}", detectorId); forwardScaleTaskSlotRequestToLeadNode( adTask, @@ -2526,9 +1435,9 @@ protected int scaleTaskSlots(ADTask adTask, TransportService transportService, A } } else { if (scaleDelta < 0) { // scale down to release task slots for other detectors - int runningEntityCount = adTaskCacheManager.getRunningEntityCount(detectorId) + adTaskCacheManager + int runningEntityCount = taskCacheManager.getRunningEntityCount(detectorId) + taskCacheManager .getTempEntityCount(detectorId); - int assignedTaskSlots = adTaskCacheManager.getDetectorTaskSlots(detectorId); + int assignedTaskSlots = taskCacheManager.getDetectorTaskSlots(detectorId); int scaleDownDelta = Math.min(assignedTaskSlots - runningEntityCount, 0 - scaleDelta); logger .debug( @@ -2538,7 +1447,7 @@ protected int scaleTaskSlots(ADTask adTask, TransportService transportService, A runningEntityCount, scaleDownDelta ); - adTaskCacheManager.scaleDownHCDetectorTaskSlots(detectorId, scaleDownDelta); + taskCacheManager.scaleDownHCDetectorTaskSlots(detectorId, scaleDownDelta); } scaleEntityTaskLane.release(); } @@ -2567,13 +1476,13 @@ protected int scaleTaskSlots(ADTask adTask, TransportService transportService, A * @return detector task slots scale delta */ public int detectorTaskSlotScaleDelta(String detectorId) { - DiscoveryNode[] eligibleDataNodes = hashRing.getNodesWithSameLocalAdVersion(); - int unfinishedEntities = adTaskCacheManager.getUnfinishedEntityCount(detectorId); + DiscoveryNode[] eligibleDataNodes = hashRing.getNodesWithSameLocalVersion(); + int unfinishedEntities = taskCacheManager.getUnfinishedEntityCount(detectorId); int totalTaskSlots = eligibleDataNodes.length * maxAdBatchTaskPerNode; int taskLaneLimit = Math.min(unfinishedEntities, Math.min(totalTaskSlots, maxRunningEntitiesPerDetector)); - adTaskCacheManager.setDetectorTaskLaneLimit(detectorId, taskLaneLimit); + taskCacheManager.setDetectorTaskLaneLimit(detectorId, taskLaneLimit); - int assignedTaskSlots = adTaskCacheManager.getDetectorTaskSlots(detectorId); + int assignedTaskSlots = taskCacheManager.getDetectorTaskSlots(detectorId); int scaleDelta = taskLaneLimit - assignedTaskSlots; logger .debug( @@ -2597,8 +1506,8 @@ public int detectorTaskSlotScaleDelta(String detectorId) { * @return task progress */ public float hcDetectorProgress(String detectorId) { - int entityCount = adTaskCacheManager.getTopEntityCount(detectorId); - int leftEntities = adTaskCacheManager.getPendingEntityCount(detectorId) + adTaskCacheManager.getRunningEntityCount(detectorId); + int entityCount = taskCacheManager.getTopEntityCount(detectorId); + int leftEntities = taskCacheManager.getPendingEntityCount(detectorId) + taskCacheManager.getRunningEntityCount(detectorId); return 1 - (float) leftEntities / entityCount; } @@ -2608,39 +1517,39 @@ public float hcDetectorProgress(String detectorId) { * @return list of AD task profile */ public ADTaskProfile getLocalADTaskProfilesByDetectorId(String detectorId) { - List tasksOfDetector = adTaskCacheManager.getTasksOfDetector(detectorId); + List tasksOfDetector = taskCacheManager.getTasksOfDetector(detectorId); ADTaskProfile detectorTaskProfile = null; String localNodeId = clusterService.localNode().getId(); - if (adTaskCacheManager.isHCTaskRunning(detectorId)) { + if (taskCacheManager.isHCTaskRunning(detectorId)) { detectorTaskProfile = new ADTaskProfile(); - if (adTaskCacheManager.isHCTaskCoordinatingNode(detectorId)) { + if (taskCacheManager.isHCTaskCoordinatingNode(detectorId)) { detectorTaskProfile.setNodeId(localNodeId); - detectorTaskProfile.setTaskId(adTaskCacheManager.getDetectorTaskId(detectorId)); - detectorTaskProfile.setDetectorTaskSlots(adTaskCacheManager.getDetectorTaskSlots(detectorId)); - detectorTaskProfile.setTotalEntitiesInited(adTaskCacheManager.topEntityInited(detectorId)); - detectorTaskProfile.setTotalEntitiesCount(adTaskCacheManager.getTopEntityCount(detectorId)); - detectorTaskProfile.setPendingEntitiesCount(adTaskCacheManager.getPendingEntityCount(detectorId)); - detectorTaskProfile.setRunningEntitiesCount(adTaskCacheManager.getRunningEntityCount(detectorId)); - detectorTaskProfile.setRunningEntities(adTaskCacheManager.getRunningEntities(detectorId)); - detectorTaskProfile.setAdTaskType(ADTaskType.HISTORICAL_HC_DETECTOR.name()); - Instant latestHCTaskRunTime = adTaskCacheManager.getLatestHCTaskRunTime(detectorId); + detectorTaskProfile.setTaskId(taskCacheManager.getDetectorTaskId(detectorId)); + detectorTaskProfile.setDetectorTaskSlots(taskCacheManager.getDetectorTaskSlots(detectorId)); + detectorTaskProfile.setTotalEntitiesInited(taskCacheManager.topEntityInited(detectorId)); + detectorTaskProfile.setTotalEntitiesCount(taskCacheManager.getTopEntityCount(detectorId)); + detectorTaskProfile.setPendingEntitiesCount(taskCacheManager.getPendingEntityCount(detectorId)); + detectorTaskProfile.setRunningEntitiesCount(taskCacheManager.getRunningEntityCount(detectorId)); + detectorTaskProfile.setRunningEntities(taskCacheManager.getRunningEntities(detectorId)); + detectorTaskProfile.setTaskType(ADTaskType.HISTORICAL_HC_DETECTOR.name()); + Instant latestHCTaskRunTime = taskCacheManager.getLatestHCTaskRunTime(detectorId); if (latestHCTaskRunTime != null) { detectorTaskProfile.setLatestHCTaskRunTime(latestHCTaskRunTime.toEpochMilli()); } } if (tasksOfDetector.size() > 0) { - List entityTaskProfiles = new ArrayList<>(); + List entityTaskProfiles = new ArrayList<>(); tasksOfDetector.forEach(taskId -> { - ADEntityTaskProfile entityTaskProfile = new ADEntityTaskProfile( - adTaskCacheManager.getShingle(taskId).size(), - adTaskCacheManager.getTRcfModel(taskId).getForest().getTotalUpdates(), - adTaskCacheManager.isThresholdModelTrained(taskId), - adTaskCacheManager.getThresholdModelTrainingDataSize(taskId), - adTaskCacheManager.getModelSize(taskId), + EntityTaskProfile entityTaskProfile = new EntityTaskProfile( + taskCacheManager.getShingle(taskId).size(), + taskCacheManager.getTRcfModel(taskId).getForest().getTotalUpdates(), + taskCacheManager.isThresholdModelTrained(taskId), + taskCacheManager.getThresholdModelTrainingDataSize(taskId), + taskCacheManager.getModelSize(taskId), localNodeId, - adTaskCacheManager.getEntity(taskId), + taskCacheManager.getEntity(taskId), taskId, ADTaskType.HISTORICAL_HC_ENTITY.name() ); @@ -2659,12 +1568,12 @@ public ADTaskProfile getLocalADTaskProfilesByDetectorId(String detectorId) { if (tasksOfDetector.size() == 1) { String taskId = tasksOfDetector.get(0); detectorTaskProfile = new ADTaskProfile( - adTaskCacheManager.getDetectorTaskId(detectorId), - adTaskCacheManager.getShingle(taskId).size(), - adTaskCacheManager.getTRcfModel(taskId).getForest().getTotalUpdates(), - adTaskCacheManager.isThresholdModelTrained(taskId), - adTaskCacheManager.getThresholdModelTrainingDataSize(taskId), - adTaskCacheManager.getModelSize(taskId), + taskCacheManager.getDetectorTaskId(detectorId), + taskCacheManager.getShingle(taskId).size(), + taskCacheManager.getTRcfModel(taskId).getForest().getTotalUpdates(), + taskCacheManager.isThresholdModelTrained(taskId), + taskCacheManager.getThresholdModelTrainingDataSize(taskId), + taskCacheManager.getModelSize(taskId), localNodeId ); // Single-flow detector only has 1 task slot. @@ -2677,7 +1586,7 @@ public ADTaskProfile getLocalADTaskProfilesByDetectorId(String detectorId) { // Clean expired HC batch task run states as it may exists after HC historical analysis done if user cancel // before querying top entities done. We will clean it in hourly cron, check "maintainRunningHistoricalTasks" // method. Clean it up here when get task profile to release memory earlier. - adTaskCacheManager.cleanExpiredHCBatchTaskRunStates(); + taskCacheManager.cleanExpiredHCBatchTaskRunStates(); }); logger.debug("Local AD task profile of detector {}: {}", detectorId, detectorTaskProfile); return detectorTaskProfile; @@ -2730,34 +1639,19 @@ public synchronized void removeStaleRunningEntity( ActionListener listener ) { String detectorId = adTask.getConfigId(); - boolean removed = adTaskCacheManager.removeRunningEntity(detectorId, entity); - if (removed && adTaskCacheManager.getPendingEntityCount(detectorId) > 0) { + boolean removed = taskCacheManager.removeRunningEntity(detectorId, entity); + if (removed && taskCacheManager.getPendingEntityCount(detectorId) > 0) { logger.debug("kick off next pending entities"); this.runNextEntityForHCADHistorical(adTask, transportService, listener); } else { - if (!adTaskCacheManager.hasEntity(detectorId)) { + if (!taskCacheManager.hasEntity(detectorId)) { setHCDetectorTaskDone(adTask, TaskState.STOPPED, listener); } } } - public boolean skipUpdateHCRealtimeTask(String detectorId, String error) { - RealtimeTaskCache realtimeTaskCache = adTaskCacheManager.getRealtimeTaskCache(detectorId); - return realtimeTaskCache != null - && realtimeTaskCache.getInitProgress() != null - && realtimeTaskCache.getInitProgress().floatValue() == 1.0 - && Objects.equals(error, realtimeTaskCache.getError()); - } - - public boolean isHCRealtimeTaskStartInitializing(String detectorId) { - RealtimeTaskCache realtimeTaskCache = adTaskCacheManager.getRealtimeTaskCache(detectorId); - return realtimeTaskCache != null - && realtimeTaskCache.getInitProgress() != null - && realtimeTaskCache.getInitProgress().floatValue() > 0; - } - public String convertEntityToString(ADTask adTask) { - if (adTask == null || !adTask.isEntityTask()) { + if (adTask == null || !adTask.isHistoricalEntityTask()) { return null; } AnomalyDetector detector = adTask.getDetector(); @@ -2844,45 +1738,8 @@ public void getADTask(String taskId, ActionListener> listener) })); } - /** - * Set old AD task's latest flag as false. - * @param adTasks list of AD tasks - */ - public void resetLatestFlagAsFalse(List adTasks) { - if (adTasks == null || adTasks.size() == 0) { - return; - } - BulkRequest bulkRequest = new BulkRequest(); - adTasks.forEach(task -> { - try { - task.setLatest(false); - task.setLastUpdateTime(Instant.now()); - IndexRequest indexRequest = new IndexRequest(DETECTION_STATE_INDEX) - .id(task.getTaskId()) - .source(task.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), XCONTENT_WITH_TYPE)); - bulkRequest.add(indexRequest); - } catch (Exception e) { - logger.error("Fail to parse task AD task to XContent, task id " + task.getTaskId(), e); - } - }); - - bulkRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.execute(BulkAction.INSTANCE, bulkRequest, ActionListener.wrap(res -> { - BulkItemResponse[] bulkItemResponses = res.getItems(); - if (bulkItemResponses != null && bulkItemResponses.length > 0) { - for (BulkItemResponse bulkItemResponse : bulkItemResponses) { - if (!bulkItemResponse.isFailed()) { - logger.warn("Reset AD tasks latest flag as false Successfully. Task id: {}", bulkItemResponse.getId()); - } else { - logger.warn("Failed to reset AD tasks latest flag as false. Task id: " + bulkItemResponse.getId()); - } - } - } - }, e -> { logger.warn("Failed to reset AD tasks latest flag as false", e); })); - } - public int getLocalAdUsedBatchTaskSlot() { - return adTaskCacheManager.getTotalBatchTaskCount(); + return taskCacheManager.getTotalBatchTaskCount(); } /** @@ -2908,7 +1765,7 @@ public int getLocalAdUsedBatchTaskSlot() { * @return assigned batch task slots */ public int getLocalAdAssignedBatchTaskSlot() { - return adTaskCacheManager.getTotalDetectorTaskSlots(); + return taskCacheManager.getTotalDetectorTaskSlots(); } // ========================================================= @@ -2928,23 +1785,23 @@ public int getLocalAdAssignedBatchTaskSlot() { */ public void maintainRunningHistoricalTasks(TransportService transportService, int size) { // Clean expired HC batch task run state cache. - adTaskCacheManager.cleanExpiredHCBatchTaskRunStates(); + taskCacheManager.cleanExpiredHCBatchTaskRunStates(); // Find owning node with highest AD version to make sure we only have 1 node maintain running historical tasks // and we use the latest logic. - Optional owningNode = hashRing.getOwningNodeWithHighestAdVersion(AD_TASK_MAINTAINENCE_NODE_MODEL_ID); + Optional owningNode = hashRing.getOwningNodeWithHighestVersion(AD_TASK_MAINTAINENCE_NODE_MODEL_ID); if (!owningNode.isPresent() || !clusterService.localNode().getId().equals(owningNode.get().getId())) { return; } logger.info("Start to maintain running historical tasks"); BoolQueryBuilder query = new BoolQueryBuilder(); - query.filter(new TermQueryBuilder(IS_LATEST_FIELD, true)); - query.filter(new TermsQueryBuilder(TASK_TYPE_FIELD, taskTypeToString(HISTORICAL_DETECTOR_TASK_TYPES))); - query.filter(new TermsQueryBuilder(STATE_FIELD, NOT_ENDED_STATES)); + query.filter(new TermQueryBuilder(TimeSeriesTask.IS_LATEST_FIELD, true)); + query.filter(new TermsQueryBuilder(TimeSeriesTask.TASK_TYPE_FIELD, taskTypeToString(HISTORICAL_DETECTOR_TASK_TYPES))); + query.filter(new TermsQueryBuilder(TimeSeriesTask.STATE_FIELD, NOT_ENDED_STATES)); SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); // default maintain interval is 5 seconds, so maintain 10 tasks will take at least 50 seconds. - sourceBuilder.query(query).sort(LAST_UPDATE_TIME_FIELD, SortOrder.DESC).size(size); + sourceBuilder.query(query).sort(TimeSeriesTask.LAST_UPDATE_TIME_FIELD, SortOrder.DESC).size(size); SearchRequest searchRequest = new SearchRequest(); searchRequest.source(sourceBuilder); searchRequest.indices(DETECTION_STATE_INDEX); @@ -2982,7 +1839,7 @@ private void maintainRunningHistoricalTask(ConcurrentLinkedQueue taskQue return; } threadPool.schedule(() -> { - resetHistoricalDetectorTaskState(ImmutableList.of(adTask), () -> { + resetHistoricalConfigTaskState(ImmutableList.of(adTask), () -> { logger.debug("Finished maintaining running historical task {}", adTask.getTaskId()); maintainRunningHistoricalTask(taskQueue, transportService); }, transportService, ActionListener.wrap(r -> { @@ -2992,20 +1849,88 @@ private void maintainRunningHistoricalTask(ConcurrentLinkedQueue taskQue } /** - * Maintain running realtime tasks. Check if realtime task cache expires or not. Remove realtime - * task cache directly if expired. + * Get list of task types. + * 1. If date range is null, will return all realtime task types + * 2. If date range is not null, will return all historical detector level tasks types + * if resetLatestTaskStateFlag is true; otherwise return all historical tasks types include + * HC entity level task type. + * @param dateRange detection date range + * @param resetLatestTaskStateFlag reset latest task state or not + * @return list of AD task types */ - public void maintainRunningRealtimeTasks() { - String[] detectorIds = adTaskCacheManager.getDetectorIdsInRealtimeTaskCache(); - if (detectorIds == null || detectorIds.length == 0) { - return; + protected List getTaskTypes(DateRange dateRange, boolean resetLatestTaskStateFlag) { + // AD does not support run once + return getTaskTypes(dateRange, resetLatestTaskStateFlag, false); + } + + @Override + protected BiCheckedFunction getTaskParser() { + return ADTask::parse; + } + + @Override + public void createRunOnceTaskAndCleanupStaleTasks( + String configId, + Config config, + TransportService transportService, + ActionListener listener + ) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("AD has no run once yet"); + } + + @Override + public List getTaskTypes(DateRange dateRange, boolean resetLatestTaskStateFlag, boolean runOnce) { + if (dateRange == null) { + return REALTIME_TASK_TYPES; + } else { + if (resetLatestTaskStateFlag) { + // return all task types include HC entity task to make sure we can reset all tasks latest flag + return ALL_HISTORICAL_TASK_TYPES; + } else { + return HISTORICAL_DETECTOR_TASK_TYPES; + } } - for (int i = 0; i < detectorIds.length; i++) { - String detectorId = detectorIds[i]; - RealtimeTaskCache taskCache = adTaskCacheManager.getRealtimeTaskCache(detectorId); - if (taskCache != null && taskCache.expired()) { - adTaskCacheManager.removeRealtimeTaskCache(detectorId); + } + + /** + * Reset latest config task state. Will reset both historical and realtime tasks. + * [Important!] Make sure listener returns in function + * + * @param tasks tasks + * @param function consumer function + * @param transportService transport service + * @param listener action listener + * @param response type of action listener + */ + @Override + protected void resetLatestConfigTaskState( + List tasks, + Consumer> function, + TransportService transportService, + ActionListener listener + ) { + List runningHistoricalTasks = new ArrayList<>(); + List runningRealtimeTasks = new ArrayList<>(); + + for (TimeSeriesTask task : tasks) { + if (!task.isHistoricalEntityTask() && !task.isDone()) { + if (task.isRealTimeTask()) { + runningRealtimeTasks.add(task); + } else if (task.isHistoricalTask()) { + runningHistoricalTasks.add(task); + } } } + + // resetRealtimeCOnfigTaskState has to be the innermost function call as we return listener there + // AD has no run once and forecasting has no historical. So the run once and historical reset + // function only forwards function call and does not return listener + resetHistoricalConfigTaskState( + runningHistoricalTasks, + () -> resetRealtimeConfigTaskState(runningRealtimeTasks, () -> function.accept(tasks), transportService, listener), + transportService, + listener + ); } } diff --git a/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultAction.java b/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultAction.java index 84fe0c6fe..df6194353 100644 --- a/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultAction.java @@ -14,10 +14,10 @@ import static org.opensearch.ad.constant.ADCommonName.AD_TASK; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class ADBatchAnomalyResultAction extends ActionType { - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detector/" + AD_TASK; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "detector/" + AD_TASK; public static final ADBatchAnomalyResultAction INSTANCE = new ADBatchAnomalyResultAction(); private ADBatchAnomalyResultAction() { diff --git a/src/main/java/org/opensearch/ad/transport/ADBatchTaskRemoteExecutionAction.java b/src/main/java/org/opensearch/ad/transport/ADBatchTaskRemoteExecutionAction.java index d865ec14c..84a22b261 100644 --- a/src/main/java/org/opensearch/ad/transport/ADBatchTaskRemoteExecutionAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADBatchTaskRemoteExecutionAction.java @@ -14,10 +14,10 @@ import static org.opensearch.ad.constant.ADCommonName.AD_TASK_REMOTE; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class ADBatchTaskRemoteExecutionAction extends ActionType { - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detector/" + AD_TASK_REMOTE; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "detector/" + AD_TASK_REMOTE; public static final ADBatchTaskRemoteExecutionAction INSTANCE = new ADBatchTaskRemoteExecutionAction(); private ADBatchTaskRemoteExecutionAction() { diff --git a/src/main/java/org/opensearch/ad/transport/ADCancelTaskAction.java b/src/main/java/org/opensearch/ad/transport/ADCancelTaskAction.java index 31f20fa00..d20759f70 100644 --- a/src/main/java/org/opensearch/ad/transport/ADCancelTaskAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADCancelTaskAction.java @@ -14,11 +14,11 @@ import static org.opensearch.ad.constant.ADCommonName.CANCEL_TASK; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class ADCancelTaskAction extends ActionType { - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detectors/" + CANCEL_TASK; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "detectors/" + CANCEL_TASK; public static final ADCancelTaskAction INSTANCE = new ADCancelTaskAction(); private ADCancelTaskAction() { diff --git a/src/main/java/org/opensearch/ad/transport/EntityProfileAction.java b/src/main/java/org/opensearch/ad/transport/ADEntityProfileAction.java similarity index 54% rename from src/main/java/org/opensearch/ad/transport/EntityProfileAction.java rename to src/main/java/org/opensearch/ad/transport/ADEntityProfileAction.java index c699d9a03..11e6a44a4 100644 --- a/src/main/java/org/opensearch/ad/transport/EntityProfileAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADEntityProfileAction.java @@ -12,14 +12,15 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; +import org.opensearch.timeseries.transport.EntityProfileResponse; -public class EntityProfileAction extends ActionType { +public class ADEntityProfileAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detectors/profile/entity"; - public static final EntityProfileAction INSTANCE = new EntityProfileAction(); + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "detectors/profile/entity"; + public static final ADEntityProfileAction INSTANCE = new ADEntityProfileAction(); - private EntityProfileAction() { + private ADEntityProfileAction() { super(NAME, EntityProfileResponse::new); } diff --git a/src/main/java/org/opensearch/ad/transport/ADEntityProfileTransportAction.java b/src/main/java/org/opensearch/ad/transport/ADEntityProfileTransportAction.java new file mode 100644 index 000000000..5ffde2999 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADEntityProfileTransportAction.java @@ -0,0 +1,53 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.timeseries.cluster.HashRing; +import org.opensearch.timeseries.transport.BaseEntityProfileTransportAction; +import org.opensearch.transport.TransportService; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * Transport action to get entity profile. + */ +public class ADEntityProfileTransportAction extends + BaseEntityProfileTransportAction { + + @Inject + public ADEntityProfileTransportAction( + ActionFilters actionFilters, + TransportService transportService, + Settings settings, + HashRing hashRing, + ClusterService clusterService, + ADCacheProvider cacheProvider + ) { + super( + actionFilters, + transportService, + settings, + hashRing, + clusterService, + cacheProvider, + ADEntityProfileAction.NAME, + AnomalyDetectorSettings.AD_REQUEST_TIMEOUT + ); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ProfileAction.java b/src/main/java/org/opensearch/ad/transport/ADProfileAction.java similarity index 59% rename from src/main/java/org/opensearch/ad/transport/ProfileAction.java rename to src/main/java/org/opensearch/ad/transport/ADProfileAction.java index 291dd0982..1d51add9e 100644 --- a/src/main/java/org/opensearch/ad/transport/ProfileAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADProfileAction.java @@ -12,20 +12,21 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; +import org.opensearch.timeseries.transport.ProfileResponse; /** * Profile transport action */ -public class ProfileAction extends ActionType { +public class ADProfileAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detectors/profile"; - public static final ProfileAction INSTANCE = new ProfileAction(); + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detectors/profile"; + public static final ADProfileAction INSTANCE = new ADProfileAction(); /** * Constructor */ - private ProfileAction() { + private ADProfileAction() { super(NAME, ProfileResponse::new); } diff --git a/src/main/java/org/opensearch/ad/transport/ADProfileTransportAction.java b/src/main/java/org/opensearch/ad/transport/ADProfileTransportAction.java new file mode 100644 index 000000000..af7c40bb4 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADProfileTransportAction.java @@ -0,0 +1,114 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_MODEL_SIZE_PER_NODE; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; +import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.model.ModelProfile; +import org.opensearch.timeseries.model.ProfileName; +import org.opensearch.timeseries.transport.BaseProfileTransportAction; +import org.opensearch.timeseries.transport.ProfileNodeRequest; +import org.opensearch.timeseries.transport.ProfileNodeResponse; +import org.opensearch.transport.TransportService; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * This class contains the logic to extract the stats from the nodes + */ +public class ADProfileTransportAction extends BaseProfileTransportAction { + private ADModelManager modelManager; + private FeatureManager featureManager; + + /** + * Constructor + * + * @param threadPool ThreadPool to use + * @param clusterService ClusterService + * @param transportService TransportService + * @param actionFilters Action Filters + * @param modelManager model manager object + * @param featureManager feature manager object + * @param cacheProvider cache provider + * @param settings Node settings accessor + */ + @Inject + public ADProfileTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + ADModelManager modelManager, + FeatureManager featureManager, + ADCacheProvider cacheProvider, + Settings settings + ) { + super( + ADProfileAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + cacheProvider, + settings, + AD_MAX_MODEL_SIZE_PER_NODE + ); + this.modelManager = modelManager; + this.featureManager = featureManager; + } + + @Override + protected ProfileNodeResponse nodeOperation(ProfileNodeRequest request) { + String detectorId = request.getConfigId(); + Set profiles = request.getProfilesToBeRetrieved(); + int shingleSize = -1; + long activeEntity = 0; + long totalUpdates = 0; + Map modelSize = null; + List modelProfiles = null; + int modelCount = 0; + if (request.isModelInPriorityCache()) { + super.nodeOperation(request); + } else { + if (profiles.contains(ProfileName.COORDINATING_NODE) || profiles.contains(ProfileName.SHINGLE_SIZE)) { + shingleSize = featureManager.getShingleSize(detectorId); + } + + if (profiles.contains(ProfileName.TOTAL_SIZE_IN_BYTES) || profiles.contains(ProfileName.MODELS)) { + modelSize = modelManager.getModelSize(detectorId); + } + } + + return new ProfileNodeResponse( + clusterService.localNode(), + modelSize, + shingleSize, + activeEntity, + totalUpdates, + modelProfiles, + modelCount + ); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADResultBulkAction.java b/src/main/java/org/opensearch/ad/transport/ADResultBulkAction.java index 041d543b7..e54a4747e 100644 --- a/src/main/java/org/opensearch/ad/transport/ADResultBulkAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADResultBulkAction.java @@ -12,18 +12,19 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; import org.opensearch.common.settings.Settings; +import org.opensearch.timeseries.transport.ResultBulkResponse; import org.opensearch.transport.TransportRequestOptions; -public class ADResultBulkAction extends ActionType { +public class ADResultBulkAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "write/bulk"; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "write/bulk"; public static final ADResultBulkAction INSTANCE = new ADResultBulkAction(); private ADResultBulkAction() { - super(NAME, ADResultBulkResponse::new); + super(NAME, ResultBulkResponse::new); } @Override diff --git a/src/main/java/org/opensearch/ad/transport/ADResultBulkRequest.java b/src/main/java/org/opensearch/ad/transport/ADResultBulkRequest.java index f5f361f69..0f8430a25 100644 --- a/src/main/java/org/opensearch/ad/transport/ADResultBulkRequest.java +++ b/src/main/java/org/opensearch/ad/transport/ADResultBulkRequest.java @@ -12,73 +12,19 @@ package org.opensearch.ad.transport; import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import org.opensearch.action.ActionRequest; -import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.action.ValidateActions; -import org.opensearch.ad.ratelimit.ResultWriteRequest; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.ratelimit.ADResultWriteRequest; import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.timeseries.transport.ResultBulkRequest; -public class ADResultBulkRequest extends ActionRequest implements Writeable { - private final List anomalyResults; - static final String NO_REQUESTS_ADDED_ERR = "no requests added"; +public class ADResultBulkRequest extends ResultBulkRequest { public ADResultBulkRequest() { - anomalyResults = new ArrayList<>(); + super(); } public ADResultBulkRequest(StreamInput in) throws IOException { - super(in); - int size = in.readVInt(); - anomalyResults = new ArrayList<>(size); - for (int i = 0; i < size; i++) { - anomalyResults.add(new ResultWriteRequest(in)); - } - } - - @Override - public ActionRequestValidationException validate() { - ActionRequestValidationException validationException = null; - if (anomalyResults.isEmpty()) { - validationException = ValidateActions.addValidationError(NO_REQUESTS_ADDED_ERR, validationException); - } - return validationException; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - super.writeTo(out); - out.writeVInt(anomalyResults.size()); - for (ResultWriteRequest result : anomalyResults) { - result.writeTo(out); - } - } - - /** - * - * @return all of the results to send - */ - public List getAnomalyResults() { - return anomalyResults; - } - - /** - * Add result to send - * @param resultWriteRequest The result write request - */ - public void add(ResultWriteRequest resultWriteRequest) { - anomalyResults.add(resultWriteRequest); - } - - /** - * - * @return total index requests - */ - public int numberOfActions() { - return anomalyResults.size(); + super(in, ADResultWriteRequest::new); } } diff --git a/src/main/java/org/opensearch/ad/transport/ADResultBulkTransportAction.java b/src/main/java/org/opensearch/ad/transport/ADResultBulkTransportAction.java index 03ca7657c..d5442e57c 100644 --- a/src/main/java/org/opensearch/ad/transport/ADResultBulkTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADResultBulkTransportAction.java @@ -14,45 +14,31 @@ import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_INDEX_PRESSURE_HARD_LIMIT; import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_INDEX_PRESSURE_SOFT_LIMIT; import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; -import static org.opensearch.index.IndexingPressure.MAX_INDEXING_BYTES; import java.io.IOException; import java.util.List; -import java.util.Random; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.action.bulk.BulkAction; import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.HandledTransportAction; import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.ad.ratelimit.ResultWriteRequest; -import org.opensearch.ad.util.BulkUtil; +import org.opensearch.ad.ratelimit.ADResultWriteRequest; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; -import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.IndexingPressure; -import org.opensearch.tasks.Task; -import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.transport.ResultBulkTransportAction; import org.opensearch.timeseries.util.RestHandlerUtils; import org.opensearch.transport.TransportService; -public class ADResultBulkTransportAction extends HandledTransportAction { +public class ADResultBulkTransportAction extends ResultBulkTransportAction { private static final Logger LOG = LogManager.getLogger(ADResultBulkTransportAction.class); - private IndexingPressure indexingPressure; - private final long primaryAndCoordinatingLimits; - private float softLimit; - private float hardLimit; - private String indexName; - private Client client; - private Random random; @Inject public ADResultBulkTransportAction( @@ -63,69 +49,51 @@ public ADResultBulkTransportAction( ClusterService clusterService, Client client ) { - super(ADResultBulkAction.NAME, transportService, actionFilters, ADResultBulkRequest::new, ThreadPool.Names.SAME); - this.indexingPressure = indexingPressure; - this.primaryAndCoordinatingLimits = MAX_INDEXING_BYTES.get(settings).getBytes(); - this.softLimit = AD_INDEX_PRESSURE_SOFT_LIMIT.get(settings); - this.hardLimit = AD_INDEX_PRESSURE_HARD_LIMIT.get(settings); - this.indexName = ADCommonName.ANOMALY_RESULT_INDEX_ALIAS; - this.client = client; + super( + ADResultBulkAction.NAME, + transportService, + actionFilters, + indexingPressure, + settings, + client, + AD_INDEX_PRESSURE_SOFT_LIMIT.get(settings), + AD_INDEX_PRESSURE_HARD_LIMIT.get(settings), + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, + ADResultBulkRequest::new + ); clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_INDEX_PRESSURE_SOFT_LIMIT, it -> softLimit = it); clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_INDEX_PRESSURE_HARD_LIMIT, it -> hardLimit = it); - // random seed is 42. Can be any number - this.random = new Random(42); } @Override - protected void doExecute(Task task, ADResultBulkRequest request, ActionListener listener) { - // Concurrent indexing memory limit = 10% of heap - // indexing pressure = indexing bytes / indexing limit - // Write all until index pressure (global indexing memory pressure) is less than 80% of 10% of heap. Otherwise, index - // all non-zero anomaly grade index requests and index zero anomaly grade index requests with probability (1 - index pressure). - long totalBytes = indexingPressure.getCurrentCombinedCoordinatingAndPrimaryBytes() + indexingPressure.getCurrentReplicaBytes(); - float indexingPressurePercent = (float) totalBytes / primaryAndCoordinatingLimits; - List results = request.getAnomalyResults(); - - if (results == null || results.size() < 1) { - listener.onResponse(new ADResultBulkResponse()); - } - + protected BulkRequest prepareBulkRequest(float indexingPressurePercent, ADResultBulkRequest request) { BulkRequest bulkRequest = new BulkRequest(); + List results = request.getAnomalyResults(); if (indexingPressurePercent <= softLimit) { - for (ResultWriteRequest resultWriteRequest : results) { - addResult(bulkRequest, resultWriteRequest.getResult(), resultWriteRequest.getCustomResultIndex()); + for (ADResultWriteRequest resultWriteRequest : results) { + addResult(bulkRequest, resultWriteRequest.getResult(), resultWriteRequest.getResultIndex()); } } else if (indexingPressurePercent <= hardLimit) { // exceed soft limit (60%) but smaller than hard limit (90%) float acceptProbability = 1 - indexingPressurePercent; - for (ResultWriteRequest resultWriteRequest : results) { + for (ADResultWriteRequest resultWriteRequest : results) { AnomalyResult result = resultWriteRequest.getResult(); if (result.isHighPriority() || random.nextFloat() < acceptProbability) { - addResult(bulkRequest, result, resultWriteRequest.getCustomResultIndex()); + addResult(bulkRequest, result, resultWriteRequest.getResultIndex()); } } } else { // if exceeding hard limit, only index non-zero grade or error result - for (ResultWriteRequest resultWriteRequest : results) { + for (ADResultWriteRequest resultWriteRequest : results) { AnomalyResult result = resultWriteRequest.getResult(); if (result.isHighPriority()) { - addResult(bulkRequest, result, resultWriteRequest.getCustomResultIndex()); + addResult(bulkRequest, result, resultWriteRequest.getResultIndex()); } } } - if (bulkRequest.numberOfActions() > 0) { - client.execute(BulkAction.INSTANCE, bulkRequest, ActionListener.wrap(bulkResponse -> { - List failedRequests = BulkUtil.getFailedIndexRequest(bulkRequest, bulkResponse); - listener.onResponse(new ADResultBulkResponse(failedRequests)); - }, e -> { - LOG.error("Failed to bulk index AD result", e); - listener.onFailure(e); - })); - } else { - listener.onResponse(new ADResultBulkResponse()); - } + return bulkRequest; } private void addResult(BulkRequest bulkRequest, AnomalyResult result, String resultIndex) { diff --git a/src/main/java/org/opensearch/ad/transport/ADResultProcessor.java b/src/main/java/org/opensearch/ad/transport/ADResultProcessor.java new file mode 100644 index 000000000..b0f01b996 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADResultProcessor.java @@ -0,0 +1,105 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_ENTITIES_PER_QUERY; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_PAGE_SIZE; + +import java.util.List; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskType; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.stats.ADStats; +import org.opensearch.ad.task.ADTaskCacheManager; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.cluster.HashRing; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.transport.ResultProcessor; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +public class ADResultProcessor extends + ResultProcessor { + private static final Logger LOG = LogManager.getLogger(ADResultProcessor.class); + + public ADResultProcessor( + Setting requestTimeoutSetting, + float intervalRatioForRequests, + String entityResultAction, + StatNames hcRequestCountStat, + Settings settings, + ClusterService clusterService, + ThreadPool threadPool, + HashRing hashRing, + NodeStateManager nodeStateManager, + TransportService transportService, + ADStats timeSeriesStats, + ADTaskManager realTimeTaskManager, + NamedXContentRegistry xContentRegistry, + Client client, + SecurityClientUtil clientUtil, + IndexNameExpressionResolver indexNameExpressionResolver, + Class transportResultResponseClazz, + FeatureManager featureManager + ) { + super( + requestTimeoutSetting, + intervalRatioForRequests, + entityResultAction, + hcRequestCountStat, + settings, + clusterService, + threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, + hashRing, + nodeStateManager, + transportService, + timeSeriesStats, + realTimeTaskManager, + xContentRegistry, + client, + clientUtil, + indexNameExpressionResolver, + transportResultResponseClazz, + featureManager, + AD_MAX_ENTITIES_PER_QUERY, + AD_PAGE_SIZE, + AnalysisType.AD, + false, + ADSingleStreamResultAction.NAME + ); + } + + @Override + protected AnomalyResultResponse createResultResponse( + List features, + String error, + Long rcfTotalUpdates, + Long configInterval, + Boolean isHC, + String taskId + ) { + return new AnomalyResultResponse(features, error, rcfTotalUpdates, configInterval, isHC, taskId); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADSingleStreamResultAction.java b/src/main/java/org/opensearch/ad/transport/ADSingleStreamResultAction.java new file mode 100644 index 000000000..9a5c74373 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADSingleStreamResultAction.java @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.ad.constant.ADCommonValue; + +public class ADSingleStreamResultAction extends ActionType { + // Internal Action which is not used for public facing RestAPIs. + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "singlestream/result"; + public static final ADSingleStreamResultAction INSTANCE = new ADSingleStreamResultAction(); + + private ADSingleStreamResultAction() { + super(NAME, AcknowledgedResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/ADSingleStreamResultTransportAction.java b/src/main/java/org/opensearch/ad/transport/ADSingleStreamResultTransportAction.java new file mode 100644 index 000000000..27a137309 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADSingleStreamResultTransportAction.java @@ -0,0 +1,83 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.ad.caching.ADCacheBuffer; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADColdStart; +import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.ad.ml.ThresholdingResult; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.ratelimit.ADCheckpointMaintainWorker; +import org.opensearch.ad.ratelimit.ADCheckpointReadWorker; +import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker; +import org.opensearch.ad.ratelimit.ADColdStartWorker; +import org.opensearch.ad.ratelimit.ADResultWriteRequest; +import org.opensearch.ad.ratelimit.ADResultWriteWorker; +import org.opensearch.ad.ratelimit.ADSaveResultStrategy; +import org.opensearch.ad.transport.handler.ADIndexMemoryPressureAwareResultHandler; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.stats.Stats; +import org.opensearch.timeseries.transport.AbstractSingleStreamResultTransportAction; +import org.opensearch.transport.TransportService; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public class ADSingleStreamResultTransportAction extends + AbstractSingleStreamResultTransportAction { + + public ADSingleStreamResultTransportAction( + TransportService transportService, + ActionFilters actionFilters, + CircuitBreakerService circuitBreakerService, + ADCacheProvider cache, + NodeStateManager stateManager, + ADCheckpointReadWorker checkpointReadQueue, + ADModelManager modelManager, + ADIndexManagement indexUtil, + ADResultWriteWorker resultWriteQueue, + Stats stats, + ADColdStartWorker forecastColdStartQueue + ) { + super( + transportService, + actionFilters, + circuitBreakerService, + cache, + stateManager, + checkpointReadQueue, + modelManager, + indexUtil, + resultWriteQueue, + stats, + forecastColdStartQueue, + ADSingleStreamResultAction.NAME, + ADIndex.RESULT, + AnalysisType.AD + ); + } + + @Override + public ADResultWriteRequest createResultWriteRequest(Config config, AnomalyResult result) { + return new ADResultWriteRequest( + System.currentTimeMillis() + config.getIntervalInMilliseconds(), + config.getId(), + RequestPriority.MEDIUM, + result, + config.getCustomResultIndex() + ); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/ADStatsNodesAction.java b/src/main/java/org/opensearch/ad/transport/ADStatsNodesAction.java index f6f39ab85..d6fa4c64b 100644 --- a/src/main/java/org/opensearch/ad/transport/ADStatsNodesAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADStatsNodesAction.java @@ -12,22 +12,23 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; +import org.opensearch.timeseries.transport.StatsNodesResponse; /** * ADStatsNodesAction class */ -public class ADStatsNodesAction extends ActionType { +public class ADStatsNodesAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "stats/nodes"; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "stats/nodes"; public static final ADStatsNodesAction INSTANCE = new ADStatsNodesAction(); /** * Constructor */ private ADStatsNodesAction() { - super(NAME, ADStatsNodesResponse::new); + super(NAME, StatsNodesResponse::new); } } diff --git a/src/main/java/org/opensearch/ad/transport/ADStatsNodesTransportAction.java b/src/main/java/org/opensearch/ad/transport/ADStatsNodesTransportAction.java index 17a81da0a..bfaacbef1 100644 --- a/src/main/java/org/opensearch/ad/transport/ADStatsNodesTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADStatsNodesTransportAction.java @@ -11,32 +11,27 @@ package org.opensearch.ad.transport; -import java.io.IOException; -import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.Set; -import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.nodes.TransportNodesAction; import org.opensearch.ad.stats.ADStats; -import org.opensearch.ad.stats.InternalStatNames; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.monitor.jvm.JvmService; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.stats.InternalStatNames; +import org.opensearch.timeseries.transport.BaseStatsNodesTransportAction; +import org.opensearch.timeseries.transport.StatsNodeResponse; +import org.opensearch.timeseries.transport.StatsRequest; import org.opensearch.transport.TransportService; /** * ADStatsNodesTransportAction contains the logic to extract the stats from the nodes */ -public class ADStatsNodesTransportAction extends - TransportNodesAction { +public class ADStatsNodesTransportAction extends BaseStatsNodesTransportAction { - private ADStats adStats; private final JvmService jvmService; private final ADTaskManager adTaskManager; @@ -47,7 +42,7 @@ public class ADStatsNodesTransportAction extends * @param clusterService ClusterService * @param transportService TransportService * @param actionFilters Action Filters - * @param adStats ADStats object + * @param adStats TimeSeriesStats object * @param jvmService ES JVM Service * @param adTaskManager AD task manager */ @@ -61,48 +56,14 @@ public ADStatsNodesTransportAction( JvmService jvmService, ADTaskManager adTaskManager ) { - super( - ADStatsNodesAction.NAME, - threadPool, - clusterService, - transportService, - actionFilters, - ADStatsRequest::new, - ADStatsNodeRequest::new, - ThreadPool.Names.MANAGEMENT, - ADStatsNodeResponse.class - ); - this.adStats = adStats; + super(threadPool, clusterService, transportService, actionFilters, adStats, ADStatsNodesAction.NAME); this.jvmService = jvmService; this.adTaskManager = adTaskManager; } @Override - protected ADStatsNodesResponse newResponse( - ADStatsRequest request, - List responses, - List failures - ) { - return new ADStatsNodesResponse(clusterService.getClusterName(), responses, failures); - } - - @Override - protected ADStatsNodeRequest newNodeRequest(ADStatsRequest request) { - return new ADStatsNodeRequest(request); - } - - @Override - protected ADStatsNodeResponse newNodeResponse(StreamInput in) throws IOException { - return new ADStatsNodeResponse(in); - } - - @Override - protected ADStatsNodeResponse nodeOperation(ADStatsNodeRequest request) { - return createADStatsNodeResponse(request.getADStatsRequest()); - } - - private ADStatsNodeResponse createADStatsNodeResponse(ADStatsRequest adStatsRequest) { - Map statValues = new HashMap<>(); + protected StatsNodeResponse createADStatsNodeResponse(StatsRequest adStatsRequest) { + Map statValues = super.createADStatsNodeResponse(adStatsRequest).getStatsMap(); Set statsToBeRetrieved = adStatsRequest.getStatsToBeRetrieved(); if (statsToBeRetrieved.contains(InternalStatNames.JVM_HEAP_USAGE.getName())) { @@ -120,12 +81,6 @@ private ADStatsNodeResponse createADStatsNodeResponse(ADStatsRequest adStatsRequ statValues.put(InternalStatNames.AD_DETECTOR_ASSIGNED_BATCH_TASK_SLOT_COUNT.getName(), assignedBatchTaskSlot); } - for (String statName : adStats.getNodeStats().keySet()) { - if (statsToBeRetrieved.contains(statName)) { - statValues.put(statName, adStats.getStats().get(statName).getValue()); - } - } - - return new ADStatsNodeResponse(clusterService.localNode(), statValues); + return new StatsNodeResponse(clusterService.localNode(), statValues); } } diff --git a/src/main/java/org/opensearch/ad/transport/ADTaskProfileAction.java b/src/main/java/org/opensearch/ad/transport/ADTaskProfileAction.java index f2b198d1c..f66d9e1ec 100644 --- a/src/main/java/org/opensearch/ad/transport/ADTaskProfileAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADTaskProfileAction.java @@ -14,11 +14,11 @@ import static org.opensearch.ad.constant.ADCommonName.AD_TASK; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class ADTaskProfileAction extends ActionType { - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detectors/profile/" + AD_TASK; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "detectors/profile/" + AD_TASK; public static final ADTaskProfileAction INSTANCE = new ADTaskProfileAction(); private ADTaskProfileAction() { diff --git a/src/main/java/org/opensearch/ad/transport/ADTaskProfileTransportAction.java b/src/main/java/org/opensearch/ad/transport/ADTaskProfileTransportAction.java index 6902d6de8..4bfbf7ca3 100644 --- a/src/main/java/org/opensearch/ad/transport/ADTaskProfileTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADTaskProfileTransportAction.java @@ -18,13 +18,13 @@ import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.nodes.TransportNodesAction; -import org.opensearch.ad.cluster.HashRing; import org.opensearch.ad.model.ADTaskProfile; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.transport.TransportService; public class ADTaskProfileTransportAction extends @@ -79,7 +79,7 @@ protected ADTaskProfileNodeResponse newNodeResponse(StreamInput in) throws IOExc @Override protected ADTaskProfileNodeResponse nodeOperation(ADTaskProfileNodeRequest request) { String remoteNodeId = request.getParentTask().getNodeId(); - Version remoteAdVersion = hashRing.getAdVersion(remoteNodeId); + Version remoteAdVersion = hashRing.getVersion(remoteNodeId); ADTaskProfile adTaskProfile = adTaskManager.getLocalADTaskProfilesByDetectorId(request.getId()); return new ADTaskProfileNodeResponse(clusterService.localNode(), adTaskProfile, remoteAdVersion); } diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobAction.java b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobAction.java index 83ea58960..b03180b70 100644 --- a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobAction.java +++ b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobAction.java @@ -12,12 +12,12 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; import org.opensearch.timeseries.transport.JobResponse; public class AnomalyDetectorJobAction extends ActionType { // External Action which used for public facing RestAPIs. - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/jobmanagement"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/jobmanagement"; public static final AnomalyDetectorJobAction INSTANCE = new AnomalyDetectorJobAction(); private AnomalyDetectorJobAction() { diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobRequest.java b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobRequest.java deleted file mode 100644 index 3a62315a6..000000000 --- a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobRequest.java +++ /dev/null @@ -1,123 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.transport; - -import java.io.IOException; - -import org.opensearch.action.ActionRequest; -import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.timeseries.model.DateRange; - -public class AnomalyDetectorJobRequest extends ActionRequest { - - private String detectorID; - private DateRange detectionDateRange; - private boolean historical; - private long seqNo; - private long primaryTerm; - private String rawPath; - - public AnomalyDetectorJobRequest(StreamInput in) throws IOException { - super(in); - detectorID = in.readString(); - seqNo = in.readLong(); - primaryTerm = in.readLong(); - rawPath = in.readString(); - if (in.readBoolean()) { - detectionDateRange = new DateRange(in); - } - historical = in.readBoolean(); - } - - public AnomalyDetectorJobRequest(String detectorID, long seqNo, long primaryTerm, String rawPath) { - this(detectorID, null, false, seqNo, primaryTerm, rawPath); - } - - /** - * Constructor function. - * - * The detectionDateRange and historical boolean can be passed in individually. - * The historical flag is for stopping detector, the detectionDateRange is for - * starting detector. It's ok if historical is true but detectionDateRange is - * null. - * - * @param detectorID detector identifier - * @param detectionDateRange detection date range - * @param historical historical analysis or not - * @param seqNo seq no - * @param primaryTerm primary term - * @param rawPath raw request path - */ - public AnomalyDetectorJobRequest( - String detectorID, - DateRange detectionDateRange, - boolean historical, - long seqNo, - long primaryTerm, - String rawPath - ) { - super(); - this.detectorID = detectorID; - this.detectionDateRange = detectionDateRange; - this.historical = historical; - this.seqNo = seqNo; - this.primaryTerm = primaryTerm; - this.rawPath = rawPath; - } - - public String getDetectorID() { - return detectorID; - } - - public DateRange getDetectionDateRange() { - return detectionDateRange; - } - - public long getSeqNo() { - return seqNo; - } - - public long getPrimaryTerm() { - return primaryTerm; - } - - public String getRawPath() { - return rawPath; - } - - public boolean isHistorical() { - return historical; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - super.writeTo(out); - out.writeString(detectorID); - out.writeLong(seqNo); - out.writeLong(primaryTerm); - out.writeString(rawPath); - if (detectionDateRange != null) { - out.writeBoolean(true); - detectionDateRange.writeTo(out); - } else { - out.writeBoolean(false); - } - out.writeBoolean(historical); - } - - @Override - public ActionRequestValidationException validate() { - return null; - } -} diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportAction.java b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportAction.java index 2ffb8b85a..358c9a062 100644 --- a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportAction.java @@ -15,47 +15,28 @@ import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_STOP_DETECTOR; import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES; import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_REQUEST_TIMEOUT; -import static org.opensearch.timeseries.util.ParseUtils.getUserContext; -import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute; -import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.HandledTransportAction; import org.opensearch.ad.ExecuteADResultResponseRecorder; +import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.rest.handler.IndexAnomalyDetectorJobActionHandler; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.rest.handler.ADIndexJobActionHandler; +import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; -import org.opensearch.common.unit.TimeValue; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.commons.authuser.User; -import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.tasks.Task; -import org.opensearch.timeseries.model.DateRange; -import org.opensearch.timeseries.transport.JobResponse; -import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.timeseries.transport.BaseJobTransportAction; import org.opensearch.transport.TransportService; -public class AnomalyDetectorJobTransportAction extends HandledTransportAction { - private final Logger logger = LogManager.getLogger(AnomalyDetectorJobTransportAction.class); - - private final Client client; - private final ClusterService clusterService; - private final Settings settings; - private final ADIndexManagement anomalyDetectionIndices; - private final NamedXContentRegistry xContentRegistry; - private volatile Boolean filterByEnabled; - private final ADTaskManager adTaskManager; - private final TransportService transportService; - private final ExecuteADResultResponseRecorder recorder; - +public class AnomalyDetectorJobTransportAction extends + BaseJobTransportAction { @Inject public AnomalyDetectorJobTransportAction( TransportService transportService, @@ -63,95 +44,23 @@ public AnomalyDetectorJobTransportAction( Client client, ClusterService clusterService, Settings settings, - ADIndexManagement anomalyDetectionIndices, NamedXContentRegistry xContentRegistry, - ADTaskManager adTaskManager, - ExecuteADResultResponseRecorder recorder - ) { - super(AnomalyDetectorJobAction.NAME, transportService, actionFilters, AnomalyDetectorJobRequest::new); - this.transportService = transportService; - this.client = client; - this.clusterService = clusterService; - this.settings = settings; - this.anomalyDetectionIndices = anomalyDetectionIndices; - this.xContentRegistry = xContentRegistry; - this.adTaskManager = adTaskManager; - filterByEnabled = AD_FILTER_BY_BACKEND_ROLES.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); - this.recorder = recorder; - } - - @Override - protected void doExecute(Task task, AnomalyDetectorJobRequest request, ActionListener actionListener) { - String detectorId = request.getDetectorID(); - DateRange detectionDateRange = request.getDetectionDateRange(); - boolean historical = request.isHistorical(); - long seqNo = request.getSeqNo(); - long primaryTerm = request.getPrimaryTerm(); - String rawPath = request.getRawPath(); - TimeValue requestTimeout = AD_REQUEST_TIMEOUT.get(settings); - String errorMessage = rawPath.endsWith(RestHandlerUtils.START_JOB) ? FAIL_TO_START_DETECTOR : FAIL_TO_STOP_DETECTOR; - ActionListener listener = wrapRestActionListener(actionListener, errorMessage); - - // By the time request reaches here, the user permissions are validated by Security plugin. - User user = getUserContext(client); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - resolveUserAndExecute( - user, - detectorId, - filterByEnabled, - listener, - (anomalyDetector) -> executeDetector( - listener, - detectorId, - detectionDateRange, - historical, - seqNo, - primaryTerm, - rawPath, - requestTimeout, - user, - context - ), - client, - clusterService, - xContentRegistry, - AnomalyDetector.class - ); - } catch (Exception e) { - logger.error(e); - listener.onFailure(e); - } - } - - private void executeDetector( - ActionListener listener, - String detectorId, - DateRange detectionDateRange, - boolean historical, - long seqNo, - long primaryTerm, - String rawPath, - TimeValue requestTimeout, - User user, - ThreadContext.StoredContext context + ADIndexJobActionHandler adIndexJobActionHandler ) { - IndexAnomalyDetectorJobActionHandler handler = new IndexAnomalyDetectorJobActionHandler( + super( + transportService, + actionFilters, client, - anomalyDetectionIndices, - detectorId, - seqNo, - primaryTerm, - requestTimeout, + clusterService, + settings, xContentRegistry, - transportService, - adTaskManager, - recorder + AD_FILTER_BY_BACKEND_ROLES, + AnomalyDetectorJobAction.NAME, + AD_REQUEST_TIMEOUT, + FAIL_TO_START_DETECTOR, + FAIL_TO_STOP_DETECTOR, + AnomalyDetector.class, + adIndexJobActionHandler ); - if (rawPath.endsWith(RestHandlerUtils.START_JOB)) { - adTaskManager.startDetector(detectorId, detectionDateRange, handler, user, transportService, context, listener); - } else if (rawPath.endsWith(RestHandlerUtils.STOP_JOB)) { - adTaskManager.stopDetector(detectorId, historical, handler, user, transportService, listener); - } } } diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyResultAction.java b/src/main/java/org/opensearch/ad/transport/AnomalyResultAction.java index d61bd5822..36c8a2c9d 100644 --- a/src/main/java/org/opensearch/ad/transport/AnomalyResultAction.java +++ b/src/main/java/org/opensearch/ad/transport/AnomalyResultAction.java @@ -12,11 +12,11 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class AnomalyResultAction extends ActionType { - // External Action which used for public facing RestAPIs. - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/run"; + // External Action which used for public facing RestAPIs or actions we need to assume cx's role. + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/run"; public static final AnomalyResultAction INSTANCE = new AnomalyResultAction(); private AnomalyResultAction() { diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyResultRequest.java b/src/main/java/org/opensearch/ad/transport/AnomalyResultRequest.java index e6f788aeb..397271da0 100644 --- a/src/main/java/org/opensearch/ad/transport/AnomalyResultRequest.java +++ b/src/main/java/org/opensearch/ad/transport/AnomalyResultRequest.java @@ -26,56 +26,24 @@ import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.transport.ResultRequest; -public class AnomalyResultRequest extends ActionRequest implements ToXContentObject { - private String adID; - // time range start and end. Unit: epoch milliseconds - private long start; - private long end; - +public class AnomalyResultRequest extends ResultRequest { public AnomalyResultRequest(StreamInput in) throws IOException { super(in); - adID = in.readString(); - start = in.readLong(); - end = in.readLong(); } public AnomalyResultRequest(String adID, long start, long end) { - super(); - this.adID = adID; - this.start = start; - this.end = end; - } - - public long getStart() { - return start; - } - - public long getEnd() { - return end; - } - - public String getAdID() { - return adID; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - super.writeTo(out); - out.writeString(adID); - out.writeLong(start); - out.writeLong(end); + super(adID, start, end); } @Override public ActionRequestValidationException validate() { ActionRequestValidationException validationException = null; - if (Strings.isEmpty(adID)) { + if (Strings.isEmpty(configId)) { validationException = addValidationError(ADCommonMessages.AD_ID_MISSING_MSG, validationException); } if (start <= 0 || end <= 0 || start > end) { @@ -90,7 +58,7 @@ public ActionRequestValidationException validate() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(ADCommonName.ID_JSON_KEY, adID); + builder.field(ADCommonName.ID_JSON_KEY, configId); builder.field(CommonName.START_JSON_KEY, start); builder.field(CommonName.END_JSON_KEY, end); builder.endObject(); diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyResultResponse.java b/src/main/java/org/opensearch/ad/transport/AnomalyResultResponse.java index 67113d3af..8708cb92a 100644 --- a/src/main/java/org/opensearch/ad/transport/AnomalyResultResponse.java +++ b/src/main/java/org/opensearch/ad/transport/AnomalyResultResponse.java @@ -17,6 +17,7 @@ import java.time.Duration; import java.time.Instant; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Optional; @@ -27,11 +28,12 @@ import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.transport.ResultResponse; -public class AnomalyResultResponse extends ActionResponse implements ToXContentObject { +public class AnomalyResultResponse extends ResultResponse { public static final String ANOMALY_GRADE_JSON_KEY = "anomalyGrade"; public static final String CONFIDENCE_JSON_KEY = "confidence"; public static final String ANOMALY_SCORE_JSON_KEY = "anomalyScore"; @@ -49,18 +51,13 @@ public class AnomalyResultResponse extends ActionResponse implements ToXContentO private Double anomalyGrade; private Double confidence; - private Double anomalyScore; - private String error; - private List features; - private Long rcfTotalUpdates; - private Long detectorIntervalInMinutes; - private Boolean isHCDetector; private Integer relativeIndex; private double[] relevantAttribution; private double[] pastValues; private double[][] expectedValuesList; private double[] likelihoodOfValues; private Double threshold; + protected Double anomalyScore; // used when returning an error/exception or empty result public AnomalyResultResponse( @@ -68,7 +65,8 @@ public AnomalyResultResponse( String error, Long rcfTotalUpdates, Long detectorIntervalInMinutes, - Boolean isHCDetector + Boolean isHCDetector, + String taskId ) { this( Double.NaN, @@ -84,7 +82,8 @@ public AnomalyResultResponse( null, null, null, - Double.NaN + Double.NaN, + taskId ); } @@ -102,16 +101,13 @@ public AnomalyResultResponse( double[] pastValues, double[][] expectedValuesList, double[] likelihoodOfValues, - Double threshold + Double threshold, + String taskId ) { + super(features, error, rcfTotalUpdates, detectorIntervalInMinutes, isHCDetector, taskId); this.anomalyGrade = anomalyGrade; this.confidence = confidence; this.anomalyScore = anomalyScore; - this.features = features; - this.error = error; - this.rcfTotalUpdates = rcfTotalUpdates; - this.detectorIntervalInMinutes = detectorIntervalInMinutes; - this.isHCDetector = isHCDetector; this.relativeIndex = relativeIndex; this.relevantAttribution = currentTimeAttribution; this.pastValues = pastValues; @@ -134,8 +130,8 @@ public AnomalyResultResponse(StreamInput in) throws IOException { // new field added since AD 1.1 // Only send AnomalyResultRequest to local node, no need to change this part for BWC rcfTotalUpdates = in.readOptionalLong(); - detectorIntervalInMinutes = in.readOptionalLong(); - isHCDetector = in.readOptionalBoolean(); + configIntervalInMinutes = in.readOptionalLong(); + isHC = in.readOptionalBoolean(); this.relativeIndex = in.readOptionalInt(); @@ -171,16 +167,13 @@ public AnomalyResultResponse(StreamInput in) throws IOException { } this.threshold = in.readOptionalDouble(); + this.taskId = in.readOptionalString(); } public double getAnomalyGrade() { return anomalyGrade; } - public List getFeatures() { - return features; - } - public double getConfidence() { return confidence; } @@ -189,22 +182,6 @@ public double getAnomalyScore() { return anomalyScore; } - public String getError() { - return error; - } - - public Long getRcfTotalUpdates() { - return rcfTotalUpdates; - } - - public Long getIntervalInMinutes() { - return detectorIntervalInMinutes; - } - - public Boolean isHCDetector() { - return isHCDetector; - } - public Integer getRelativeIndex() { return relativeIndex; } @@ -240,8 +217,8 @@ public void writeTo(StreamOutput out) throws IOException { } out.writeOptionalString(error); out.writeOptionalLong(rcfTotalUpdates); - out.writeOptionalLong(detectorIntervalInMinutes); - out.writeOptionalBoolean(isHCDetector); + out.writeOptionalLong(configIntervalInMinutes); + out.writeOptionalBoolean(isHC); out.writeOptionalInt(relativeIndex); @@ -280,6 +257,7 @@ public void writeTo(StreamOutput out) throws IOException { } out.writeOptionalDouble(threshold); + out.writeOptionalString(taskId); } @Override @@ -295,13 +273,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } builder.endArray(); builder.field(RCF_TOTAL_UPDATES_JSON_KEY, rcfTotalUpdates); - builder.field(DETECTOR_INTERVAL_IN_MINUTES_JSON_KEY, detectorIntervalInMinutes); + builder.field(DETECTOR_INTERVAL_IN_MINUTES_JSON_KEY, configIntervalInMinutes); builder.field(RELATIVE_INDEX_FIELD_JSON_KEY, relativeIndex); builder.field(RELEVANT_ATTRIBUTION_FIELD_JSON_KEY, relevantAttribution); builder.field(PAST_VALUES_FIELD_JSON_KEY, pastValues); builder.field(EXPECTED_VAL_LIST_FIELD_JSON_KEY, expectedValuesList); builder.field(LIKELIHOOD_FIELD_JSON_KEY, likelihoodOfValues); builder.field(THRESHOLD_FIELD_JSON_KEY, threshold); + builder.field(CommonName.TASK_ID_FIELD, taskId); builder.endObject(); return builder; } @@ -325,7 +304,7 @@ public static AnomalyResultResponse fromActionResponse(final ActionResponse acti * * Convert AnomalyResultResponse to AnomalyResult * - * @param detectorId Detector Id + * @param configId Detector Id * @param dataStartInstant data start time * @param dataEndInstant data end time * @param executionStartInstant execution start time @@ -335,8 +314,9 @@ public static AnomalyResultResponse fromActionResponse(final ActionResponse acti * @param error Error * @return converted AnomalyResult */ - public AnomalyResult toAnomalyResult( - String detectorId, + @Override + public List toIndexableResults( + String configId, Instant dataStartInstant, Instant dataEndInstant, Instant executionStartInstant, @@ -347,30 +327,43 @@ public AnomalyResult toAnomalyResult( ) { // Detector interval in milliseconds long detectorIntervalMilli = Duration.between(dataStartInstant, dataEndInstant).toMillis(); - return AnomalyResult - .fromRawTRCFResult( - detectorId, - detectorIntervalMilli, - null, // real time results have no task id - anomalyScore, - anomalyGrade, - confidence, - features, - dataStartInstant, - dataEndInstant, - executionStartInstant, - executionEndInstant, - error, - Optional.empty(), - user, - schemaVersion, - null, // single-stream real-time has no model id - relevantAttribution, - relativeIndex, - pastValues, - expectedValuesList, - likelihoodOfValues, - threshold + return Collections + .singletonList( + AnomalyResult + .fromRawTRCFResult( + configId, + detectorIntervalMilli, + taskId, // real time results have no task id + anomalyScore, + anomalyGrade, + confidence, + features, + dataStartInstant, + dataEndInstant, + executionStartInstant, + executionEndInstant, + error, + Optional.empty(), + user, + schemaVersion, + null, // single-stream real-time has no model id + relevantAttribution, + relativeIndex, + pastValues, + expectedValuesList, + likelihoodOfValues, + threshold + ) ); } + + @Override + public boolean shouldSave() { + // skipping writing to the result index if not necessary + // For a single-stream analysis, the result is not useful if error is null + // and rcf score (e.g., thus anomaly grade/confidence/forecasts) is null. + // For a HC analysis, we don't need to save on the detector level. + // We return 0 or Double.NaN rcf score if there is no error. + return super.shouldSave() || anomalyScore > 0; + } } diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java b/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java index 084db7f42..1ee5b3ab8 100644 --- a/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java @@ -11,139 +11,58 @@ package org.opensearch.ad.transport; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_ENTITIES_PER_QUERY; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_PAGE_SIZE; -import static org.opensearch.timeseries.constant.CommonMessages.INVALID_SEARCH_QUERY_MSG; - -import java.net.ConnectException; -import java.util.ArrayList; import java.util.HashSet; -import java.util.Iterator; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Map.Entry; import java.util.Optional; import java.util.Set; -import java.util.concurrent.atomic.AtomicReference; -import java.util.stream.Collectors; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.message.ParameterizedMessage; -import org.opensearch.ExceptionsHelper; -import org.opensearch.OpenSearchTimeoutException; -import org.opensearch.action.ActionListenerResponseHandler; import org.opensearch.action.ActionRequest; -import org.opensearch.action.search.SearchPhaseExecutionException; -import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.action.support.IndicesOptions; -import org.opensearch.action.support.ThreadedActionListener; -import org.opensearch.action.support.master.AcknowledgedResponse; -import org.opensearch.ad.cluster.HashRing; import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.feature.CompositeRetriever; -import org.opensearch.ad.feature.CompositeRetriever.PageIterator; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.feature.SinglePointFeatures; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.stats.ADStats; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.client.Client; -import org.opensearch.cluster.ClusterState; -import org.opensearch.cluster.block.ClusterBlockLevel; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; -import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; -import org.opensearch.common.lease.Releasable; import org.opensearch.common.settings.Settings; -import org.opensearch.common.transport.NetworkExceptionHelper; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.common.io.stream.NotSerializableExceptionWrapper; -import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.index.IndexNotFoundException; -import org.opensearch.node.NodeClosedException; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.NodeStateManager; -import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.breaker.CircuitBreakerService; -import org.opensearch.timeseries.common.exception.ClientException; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.common.exception.EndRunException; -import org.opensearch.timeseries.common.exception.InternalFailure; import org.opensearch.timeseries.common.exception.LimitExceededException; -import org.opensearch.timeseries.common.exception.NotSerializedExceptionName; -import org.opensearch.timeseries.common.exception.ResourceNotFoundException; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonMessages; -import org.opensearch.timeseries.ml.SingleStreamModelIdMapper; -import org.opensearch.timeseries.model.Config; -import org.opensearch.timeseries.model.Entity; -import org.opensearch.timeseries.model.FeatureData; -import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.feature.FeatureManager; import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.stats.StatNames; -import org.opensearch.timeseries.util.ExceptionUtil; -import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.transport.ResultProcessor; import org.opensearch.timeseries.util.SecurityClientUtil; -import org.opensearch.transport.ActionNotFoundTransportException; -import org.opensearch.transport.ConnectTransportException; -import org.opensearch.transport.NodeNotConnectedException; -import org.opensearch.transport.ReceiveTimeoutTransportException; -import org.opensearch.transport.TransportRequestOptions; import org.opensearch.transport.TransportService; public class AnomalyResultTransportAction extends HandledTransportAction { private static final Logger LOG = LogManager.getLogger(AnomalyResultTransportAction.class); - static final String WAIT_FOR_THRESHOLD_ERR_MSG = "Exception in waiting for threshold result"; - static final String NODE_UNRESPONSIVE_ERR_MSG = "Model node is unresponsive. Mute node"; - static final String READ_WRITE_BLOCKED = "Cannot read/write due to global block."; - static final String INDEX_READ_BLOCKED = "Cannot read user index due to read block."; - static final String NULL_RESPONSE = "Received null response from"; - - static final String TROUBLE_QUERYING_ERR_MSG = "Having trouble querying data: "; - static final String NO_ACK_ERR = "no acknowledgements from model hosting nodes."; - - private final TransportService transportService; - private final NodeStateManager stateManager; - private final FeatureManager featureManager; - private final ModelManager modelManager; - private final HashRing hashRing; - private final TransportRequestOptions option; - private final ClusterService clusterService; - private final IndexNameExpressionResolver indexNameExpressionResolver; - private final ADStats adStats; - private final CircuitBreakerService adCircuitBreakerService; - private final ThreadPool threadPool; + private ADResultProcessor resultProcessor; private final Client client; - private final SecurityClientUtil clientUtil; - private final ADTaskManager adTaskManager; - + private CircuitBreakerService adCircuitBreakerService; // Cache HC detector id. This is used to count HC failure stats. We can tell a detector // is HC or not by checking if detector id exists in this field or not. Will add // detector id to this field when start to run realtime detection and remove detector // id once realtime detection done. private final Set hcDetectors; - private NamedXContentRegistry xContentRegistry; - private Settings settings; - // within an interval, how many percents are used to process requests. - // 1.0 means we use all of the detection interval to process requests. - // to ensure we don't block next interval, it is better to set it less than 1.0. - private final float intervalRatioForRequest; - private int maxEntitiesPerInterval; - private int pageSize; + private final ADStats adStats; + private final NodeStateManager nodeStateManager; @Inject public AnomalyResultTransportAction( @@ -152,9 +71,8 @@ public AnomalyResultTransportAction( Settings settings, Client client, SecurityClientUtil clientUtil, - NodeStateManager manager, + NodeStateManager nodeStateManager, FeatureManager featureManager, - ModelManager modelManager, HashRing hashRing, ClusterService clusterService, IndexNameExpressionResolver indexNameExpressionResolver, @@ -162,37 +80,34 @@ public AnomalyResultTransportAction( ADStats adStats, ThreadPool threadPool, NamedXContentRegistry xContentRegistry, - ADTaskManager adTaskManager + ADTaskManager realTimeTaskManager ) { super(AnomalyResultAction.NAME, transportService, actionFilters, AnomalyResultRequest::new); - this.transportService = transportService; - this.settings = settings; + this.resultProcessor = new ADResultProcessor( + AnomalyDetectorSettings.AD_REQUEST_TIMEOUT, + TimeSeriesSettings.INTERVAL_RATIO_FOR_REQUESTS, + EntityADResultAction.NAME, + StatNames.AD_HC_EXECUTE_REQUEST_COUNT, + settings, + clusterService, + threadPool, + hashRing, + nodeStateManager, + transportService, + adStats, + realTimeTaskManager, + xContentRegistry, + client, + clientUtil, + indexNameExpressionResolver, + AnomalyResultResponse.class, + featureManager + ); this.client = client; - this.clientUtil = clientUtil; - this.stateManager = manager; - this.featureManager = featureManager; - this.modelManager = modelManager; - this.hashRing = hashRing; - this.option = TransportRequestOptions - .builder() - .withType(TransportRequestOptions.Type.REG) - .withTimeout(AnomalyDetectorSettings.AD_REQUEST_TIMEOUT.get(settings)) - .build(); - this.clusterService = clusterService; - this.indexNameExpressionResolver = indexNameExpressionResolver; this.adCircuitBreakerService = adCircuitBreakerService; - this.adStats = adStats; - this.threadPool = threadPool; this.hcDetectors = new HashSet<>(); - this.xContentRegistry = xContentRegistry; - this.intervalRatioForRequest = TimeSeriesSettings.INTERVAL_RATIO_FOR_REQUESTS; - - this.maxEntitiesPerInterval = AD_MAX_ENTITIES_PER_QUERY.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_MAX_ENTITIES_PER_QUERY, it -> maxEntitiesPerInterval = it); - - this.pageSize = AD_PAGE_SIZE.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_PAGE_SIZE, it -> pageSize = it); - this.adTaskManager = adTaskManager; + this.adStats = adStats; + this.nodeStateManager = nodeStateManager; } /** @@ -249,7 +164,7 @@ public AnomalyResultTransportAction( protected void doExecute(Task task, ActionRequest actionRequest, ActionListener listener) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { AnomalyResultRequest request = AnomalyResultRequest.fromActionRequest(actionRequest); - String adID = request.getAdID(); + String adID = request.getConfigId(); ActionListener original = listener; listener = ActionListener.wrap(r -> { hcDetectors.remove(adID); @@ -278,864 +193,14 @@ protected void doExecute(Task task, ActionRequest actionRequest, ActionListener< return; } try { - stateManager.getConfig(adID, AnalysisType.AD, onGetDetector(listener, adID, request)); + nodeStateManager + .getConfig(adID, AnalysisType.AD, resultProcessor.onGetConfig(listener, adID, request, Optional.of(hcDetectors))); } catch (Exception ex) { - handleExecuteException(ex, listener, adID); + ResultProcessor.handleExecuteException(ex, listener, adID); } } catch (Exception e) { LOG.error(e); listener.onFailure(e); } } - - /** - * didn't use ActionListener.wrap so that I can - * 1) use this to refer to the listener inside the listener - * 2) pass parameters using constructors - * - */ - class PageListener implements ActionListener { - private PageIterator pageIterator; - private String detectorId; - private long dataStartTime; - private long dataEndTime; - - PageListener(PageIterator pageIterator, String detectorId, long dataStartTime, long dataEndTime) { - this.pageIterator = pageIterator; - this.detectorId = detectorId; - this.dataStartTime = dataStartTime; - this.dataEndTime = dataEndTime; - } - - @Override - public void onResponse(CompositeRetriever.Page entityFeatures) { - if (pageIterator.hasNext()) { - pageIterator.next(this); - } - if (entityFeatures != null && false == entityFeatures.isEmpty()) { - // wrap expensive operation inside ad threadpool - threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME).execute(() -> { - try { - - Set>> node2Entities = entityFeatures - .getResults() - .entrySet() - .stream() - .filter(e -> hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(e.getKey().toString()).isPresent()) - .collect( - Collectors - .groupingBy( - // from entity name to its node - e -> hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(e.getKey().toString()).get(), - Collectors.toMap(Entry::getKey, Entry::getValue) - ) - ) - .entrySet(); - - Iterator>> iterator = node2Entities.iterator(); - - while (iterator.hasNext()) { - Entry> entry = iterator.next(); - DiscoveryNode modelNode = entry.getKey(); - if (modelNode == null) { - iterator.remove(); - continue; - } - String modelNodeId = modelNode.getId(); - if (stateManager.isMuted(modelNodeId, detectorId)) { - LOG - .info( - String - .format(Locale.ROOT, NODE_UNRESPONSIVE_ERR_MSG + " %s for detector %s", modelNodeId, detectorId) - ); - iterator.remove(); - } - } - - final AtomicReference failure = new AtomicReference<>(); - node2Entities.stream().forEach(nodeEntity -> { - DiscoveryNode node = nodeEntity.getKey(); - transportService - .sendRequest( - node, - EntityResultAction.NAME, - new EntityResultRequest(detectorId, nodeEntity.getValue(), dataStartTime, dataEndTime), - option, - new ActionListenerResponseHandler<>( - new EntityResultListener(node.getId(), detectorId, failure), - AcknowledgedResponse::new, - ThreadPool.Names.SAME - ) - ); - }); - - } catch (Exception e) { - LOG.error("Unexpected exception", e); - handleException(e); - } - }); - } - } - - @Override - public void onFailure(Exception e) { - LOG.error("Unexpetected exception", e); - handleException(e); - } - - private void handleException(Exception e) { - Exception convertedException = convertedQueryFailureException(e, detectorId); - if (false == (convertedException instanceof TimeSeriesException)) { - Throwable cause = ExceptionsHelper.unwrapCause(convertedException); - convertedException = new InternalFailure(detectorId, cause); - } - stateManager.setException(detectorId, convertedException); - } - } - - private ActionListener> onGetDetector( - ActionListener listener, - String adID, - AnomalyResultRequest request - ) { - return ActionListener.wrap(detectorOptional -> { - if (!detectorOptional.isPresent()) { - listener.onFailure(new EndRunException(adID, "AnomalyDetector is not available.", true)); - return; - } - - AnomalyDetector anomalyDetector = (AnomalyDetector) detectorOptional.get(); - if (anomalyDetector.isHighCardinality()) { - hcDetectors.add(adID); - adStats.getStat(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName()).increment(); - } - - long delayMillis = Optional - .ofNullable((IntervalTimeConfiguration) anomalyDetector.getWindowDelay()) - .map(t -> t.toDuration().toMillis()) - .orElse(0L); - long dataStartTime = request.getStart() - delayMillis; - long dataEndTime = request.getEnd() - delayMillis; - - adTaskManager - .initRealtimeTaskCacheAndCleanupStaleCache( - adID, - anomalyDetector, - transportService, - ActionListener - .runAfter( - initRealtimeTaskCacheListener(adID), - () -> executeAnomalyDetection(listener, adID, request, anomalyDetector, dataStartTime, dataEndTime) - ) - ); - }, exception -> handleExecuteException(exception, listener, adID)); - } - - private ActionListener initRealtimeTaskCacheListener(String detectorId) { - return ActionListener.wrap(r -> { - if (r) { - LOG.debug("Realtime task cache initied for detector {}", detectorId); - } - }, e -> LOG.error("Failed to init realtime task cache for " + detectorId, e)); - } - - private void executeAnomalyDetection( - ActionListener listener, - String adID, - AnomalyResultRequest request, - AnomalyDetector anomalyDetector, - long dataStartTime, - long dataEndTime - ) { - // HC logic starts here - if (anomalyDetector.isHighCardinality()) { - Optional previousException = stateManager.fetchExceptionAndClear(adID); - if (previousException.isPresent()) { - Exception exception = previousException.get(); - LOG.error(new ParameterizedMessage("Previous exception of [{}]", adID), exception); - if (exception instanceof EndRunException) { - EndRunException endRunException = (EndRunException) exception; - if (endRunException.isEndNow()) { - listener.onFailure(exception); - return; - } - } - } - - // assume request are in epoch milliseconds - long nextDetectionStartTime = request.getEnd() + (long) (anomalyDetector.getIntervalInMilliseconds() * intervalRatioForRequest); - - CompositeRetriever compositeRetriever = new CompositeRetriever( - dataStartTime, - dataEndTime, - anomalyDetector, - xContentRegistry, - client, - clientUtil, - nextDetectionStartTime, - settings, - maxEntitiesPerInterval, - pageSize, - indexNameExpressionResolver, - clusterService - ); - - PageIterator pageIterator = null; - - try { - pageIterator = compositeRetriever.iterator(); - } catch (Exception e) { - listener.onFailure(new EndRunException(anomalyDetector.getId(), CommonMessages.INVALID_SEARCH_QUERY_MSG, e, false)); - return; - } - - PageListener getEntityFeatureslistener = new PageListener(pageIterator, adID, dataStartTime, dataEndTime); - if (pageIterator.hasNext()) { - pageIterator.next(getEntityFeatureslistener); - } - - // We don't know when the pagination will not finish. To not - // block the following interval request to start, we return immediately. - // Pagination will stop itself when the time is up. - if (previousException.isPresent()) { - listener.onFailure(previousException.get()); - } else { - listener - .onResponse( - new AnomalyResultResponse(new ArrayList(), null, null, anomalyDetector.getIntervalInMinutes(), true) - ); - } - return; - } - - // HC logic ends and single entity logic starts here - // We are going to use only 1 model partition for a single stream detector. - // That's why we use 0 here. - String rcfModelID = SingleStreamModelIdMapper.getRcfModelId(adID, 0); - Optional asRCFNode = hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(rcfModelID); - if (!asRCFNode.isPresent()) { - listener.onFailure(new InternalFailure(adID, "RCF model node is not available.")); - return; - } - - DiscoveryNode rcfNode = asRCFNode.get(); - - // we have already returned listener inside shouldStart method - if (!shouldStart(listener, adID, anomalyDetector, rcfNode.getId(), rcfModelID)) { - return; - } - - featureManager - .getCurrentFeatures( - anomalyDetector, - dataStartTime, - dataEndTime, - onFeatureResponseForSingleEntityDetector(adID, anomalyDetector, listener, rcfModelID, rcfNode, dataStartTime, dataEndTime) - ); - } - - // For single entity detector - private ActionListener onFeatureResponseForSingleEntityDetector( - String adID, - AnomalyDetector detector, - ActionListener listener, - String rcfModelId, - DiscoveryNode rcfNode, - long dataStartTime, - long dataEndTime - ) { - return ActionListener.wrap(featureOptional -> { - List featureInResponse = null; - if (featureOptional.getUnprocessedFeatures().isPresent()) { - featureInResponse = ParseUtils.getFeatureData(featureOptional.getUnprocessedFeatures().get(), detector); - } - - if (!featureOptional.getProcessedFeatures().isPresent()) { - Optional exception = coldStartIfNoCheckPoint(detector); - if (exception.isPresent()) { - listener.onFailure(exception.get()); - return; - } - - if (!featureOptional.getUnprocessedFeatures().isPresent()) { - // Feature not available is common when we have data holes. Respond empty response - // and don't log to avoid bloating our logs. - LOG.debug("No data in current detection window between {} and {} for {}", dataStartTime, dataEndTime, adID); - listener - .onResponse( - new AnomalyResultResponse( - new ArrayList(), - "No data in current detection window", - null, - null, - false - ) - ); - } else { - LOG.debug("Return at least current feature value between {} and {} for {}", dataStartTime, dataEndTime, adID); - listener - .onResponse( - new AnomalyResultResponse(featureInResponse, "No full shingle in current detection window", null, null, false) - ); - } - return; - } - - final AtomicReference failure = new AtomicReference(); - - LOG.info("Sending RCF request to {} for model {}", rcfNode.getId(), rcfModelId); - - RCFActionListener rcfListener = new RCFActionListener( - rcfModelId, - failure, - rcfNode.getId(), - detector, - listener, - featureInResponse, - adID - ); - - transportService - .sendRequest( - rcfNode, - RCFResultAction.NAME, - new RCFResultRequest(adID, rcfModelId, featureOptional.getProcessedFeatures().get()), - option, - new ActionListenerResponseHandler<>(rcfListener, RCFResultResponse::new) - ); - }, exception -> { handleQueryFailure(exception, listener, adID); }); - } - - private void handleQueryFailure(Exception exception, ActionListener listener, String adID) { - Exception convertedQueryFailureException = convertedQueryFailureException(exception, adID); - - if (convertedQueryFailureException instanceof EndRunException) { - // invalid feature query - listener.onFailure(convertedQueryFailureException); - } else { - handleExecuteException(convertedQueryFailureException, listener, adID); - } - } - - /** - * Convert a query related exception to EndRunException - * - * These query exception can happen during the starting phase of the OpenSearch - * process. Thus, set the stopNow parameter of these EndRunException to false - * and confirm the EndRunException is not a false positive. - * - * @param exception Exception - * @param adID detector Id - * @return the converted exception if the exception is query related - */ - private Exception convertedQueryFailureException(Exception exception, String adID) { - if (ExceptionUtil.isIndexNotAvailable(exception)) { - return new EndRunException(adID, TROUBLE_QUERYING_ERR_MSG + exception.getMessage(), false).countedInStats(false); - } else if (exception instanceof SearchPhaseExecutionException && invalidQuery((SearchPhaseExecutionException) exception)) { - // This is to catch invalid aggregation on wrong field type. For example, - // sum aggregation on text field. We should end detector run for such case. - return new EndRunException( - adID, - INVALID_SEARCH_QUERY_MSG + " " + ((SearchPhaseExecutionException) exception).getDetailedMessage(), - exception, - false - ).countedInStats(false); - } - - return exception; - } - - /** - * Verify failure of rcf or threshold models. If there is no model, trigger cold - * start. If there is an exception for the previous cold start of this detector, - * throw exception to the caller. - * - * @param failure object that may contain exceptions thrown - * @param detector detector object - * @return exception if AD job execution gets resource not found exception - * @throws Exception when the input failure is not a ResourceNotFoundException. - * List of exceptions we can throw - * 1. Exception from cold start: - * 1). InternalFailure due to - * a. OpenSearchTimeoutException thrown by putModelCheckpoint during cold start - * 2). EndRunException with endNow equal to false - * a. training data not available - * b. cold start cannot succeed - * c. invalid training data - * 3) EndRunException with endNow equal to true - * a. invalid search query - * 2. LimitExceededException from one of RCF model node when the total size of the models - * is more than X% of heap memory. - * 3. InternalFailure wrapping OpenSearchTimeoutException inside caused by - * RCF/Threshold model node failing to get checkpoint to restore model before timeout. - */ - private Exception coldStartIfNoModel(AtomicReference failure, AnomalyDetector detector) throws Exception { - Exception exp = failure.get(); - if (exp == null) { - return null; - } - - // return exceptions like LimitExceededException to caller - if (!(exp instanceof ResourceNotFoundException)) { - return exp; - } - - // fetch previous cold start exception - String adID = detector.getId(); - final Optional previousException = stateManager.fetchExceptionAndClear(adID); - if (previousException.isPresent()) { - Exception exception = previousException.get(); - LOG.error("Previous exception of {}: {}", () -> adID, () -> exception); - if (exception instanceof EndRunException && ((EndRunException) exception).isEndNow()) { - return exception; - } - } - LOG.info("Trigger cold start for {}", detector.getId()); - coldStart(detector); - return previousException.orElse(new InternalFailure(adID, ADCommonMessages.NO_MODEL_ERR_MSG)); - } - - private void findException(Throwable cause, String adID, AtomicReference failure, String nodeId) { - if (cause == null) { - LOG.error(new ParameterizedMessage("Null input exception")); - return; - } - if (cause instanceof Error) { - // we cannot do anything with Error. - LOG.error(new ParameterizedMessage("Error during prediction for {}: ", adID), cause); - return; - } - - Exception causeException = (Exception) cause; - - if (causeException instanceof TimeSeriesException) { - failure.set(causeException); - } else if (causeException instanceof NotSerializableExceptionWrapper) { - // we only expect this happens on AD exceptions - Optional actualException = NotSerializedExceptionName - .convertWrappedTimeSeriesException((NotSerializableExceptionWrapper) causeException, adID); - if (actualException.isPresent()) { - TimeSeriesException adException = actualException.get(); - failure.set(adException); - if (adException instanceof ResourceNotFoundException) { - // During a rolling upgrade or blue/green deployment, ResourceNotFoundException might be caused by old node using RCF - // 1.0 - // cannot recognize new checkpoint produced by the coordinating node using compact RCF. Add pressure to mute the node - // after consecutive failures. - stateManager.addPressure(nodeId, adID); - } - } else { - // some unexpected bugs occur while predicting anomaly - failure.set(new EndRunException(adID, CommonMessages.BUG_RESPONSE, causeException, false)); - } - } else if (causeException instanceof IndexNotFoundException - && causeException.getMessage().contains(ADCommonName.CHECKPOINT_INDEX_NAME)) { - // checkpoint index does not exist - // ResourceNotFoundException will trigger cold start later - failure.set(new ResourceNotFoundException(adID, causeException.getMessage())); - } else if (causeException instanceof OpenSearchTimeoutException) { - // we can have OpenSearchTimeoutException when a node tries to load RCF or - // threshold model - failure.set(new InternalFailure(adID, causeException)); - } else if (causeException instanceof IllegalArgumentException) { - // we can have IllegalArgumentException when a model is corrupted - failure.set(new InternalFailure(adID, causeException)); - } else { - // some unexpected bug occurred or cluster is unstable (e.g., ClusterBlockException) or index is red (e.g. - // NoShardAvailableActionException) while predicting anomaly - failure.set(new EndRunException(adID, CommonMessages.BUG_RESPONSE, causeException, false)); - } - } - - void handleExecuteException(Exception ex, ActionListener listener, String adID) { - if (ex instanceof ClientException) { - listener.onFailure(ex); - } else if (ex instanceof TimeSeriesException) { - listener.onFailure(new InternalFailure((TimeSeriesException) ex)); - } else { - Throwable cause = ExceptionsHelper.unwrapCause(ex); - listener.onFailure(new InternalFailure(adID, cause)); - } - } - - private boolean invalidQuery(SearchPhaseExecutionException ex) { - // If all shards return bad request and failure cause is IllegalArgumentException, we - // consider the feature query is invalid and will not count the error in failure stats. - for (ShardSearchFailure failure : ex.shardFailures()) { - if (RestStatus.BAD_REQUEST != failure.status() || !(failure.getCause() instanceof IllegalArgumentException)) { - return false; - } - } - return true; - } - - // For single entity detector - class RCFActionListener implements ActionListener { - private String modelID; - private AtomicReference failure; - private String rcfNodeID; - private AnomalyDetector detector; - private ActionListener listener; - private List featureInResponse; - private final String adID; - - RCFActionListener( - String modelID, - AtomicReference failure, - String rcfNodeID, - AnomalyDetector detector, - ActionListener listener, - List features, - String adID - ) { - this.modelID = modelID; - this.failure = failure; - this.rcfNodeID = rcfNodeID; - this.detector = detector; - this.listener = listener; - this.featureInResponse = features; - this.adID = adID; - } - - @Override - public void onResponse(RCFResultResponse response) { - try { - stateManager.resetBackpressureCounter(rcfNodeID, adID); - if (response != null) { - listener - .onResponse( - new AnomalyResultResponse( - response.getAnomalyGrade(), - response.getConfidence(), - response.getRCFScore(), - featureInResponse, - null, - response.getTotalUpdates(), - detector.getIntervalInMinutes(), - false, - response.getRelativeIndex(), - response.getAttribution(), - response.getPastValues(), - response.getExpectedValuesList(), - response.getLikelihoodOfValues(), - response.getThreshold() - ) - ); - } else { - LOG.warn(NULL_RESPONSE + " {} for {}", modelID, rcfNodeID); - listener.onFailure(new InternalFailure(adID, ADCommonMessages.NO_MODEL_ERR_MSG)); - } - } catch (Exception ex) { - LOG.error(new ParameterizedMessage("Unexpected exception for [{}]", adID), ex); - handleExecuteException(ex, listener, adID); - } - } - - @Override - public void onFailure(Exception e) { - try { - handlePredictionFailure(e, adID, rcfNodeID, failure); - Exception exception = coldStartIfNoModel(failure, detector); - if (exception != null) { - listener.onFailure(exception); - } else { - listener.onFailure(new InternalFailure(adID, "Node connection problem or unexpected exception")); - } - } catch (Exception ex) { - LOG.error(new ParameterizedMessage("Unexpected exception for [{}]", adID), ex); - handleExecuteException(ex, listener, adID); - } - } - } - - /** - * Handle a prediction failure. Possibly (i.e., we don't always need to do that) - * convert the exception to a form that AD can recognize and handle and sets the - * input failure reference to the converted exception. - * - * @param e prediction exception - * @param adID Detector Id - * @param nodeID Node Id - * @param failure Parameter to receive the possibly converted function for the - * caller to deal with - */ - private void handlePredictionFailure(Exception e, String adID, String nodeID, AtomicReference failure) { - LOG.error(new ParameterizedMessage("Received an error from node {} while doing model inference for {}", nodeID, adID), e); - if (e == null) { - return; - } - Throwable cause = ExceptionsHelper.unwrapCause(e); - if (hasConnectionIssue(cause)) { - handleConnectionException(nodeID, adID); - } else { - findException(cause, adID, failure, nodeID); - } - } - - /** - * Check if the input exception indicates connection issues. - * During blue-green deployment, we may see ActionNotFoundTransportException. - * Count that as connection issue and isolate that node if it continues to happen. - * - * @param e exception - * @return true if we get disconnected from the node or the node is not in the - * right state (being closed) or transport request times out (sent from TimeoutHandler.run) - */ - private boolean hasConnectionIssue(Throwable e) { - return e instanceof ConnectTransportException - || e instanceof NodeClosedException - || e instanceof ReceiveTimeoutTransportException - || e instanceof NodeNotConnectedException - || e instanceof ConnectException - || NetworkExceptionHelper.isCloseConnectionException(e) - || e instanceof ActionNotFoundTransportException; - } - - private void handleConnectionException(String node, String detectorId) { - final DiscoveryNodes nodes = clusterService.state().nodes(); - if (!nodes.nodeExists(node)) { - hashRing.buildCirclesForRealtimeAD(); - return; - } - // rebuilding is not done or node is unresponsive - stateManager.addPressure(node, detectorId); - } - - /** - * Since we need to read from customer index and write to anomaly result index, - * we need to make sure we can read and write. - * - * @param state Cluster state - * @return whether we have global block or not - */ - private boolean checkGlobalBlock(ClusterState state) { - return state.blocks().globalBlockedException(ClusterBlockLevel.READ) != null - || state.blocks().globalBlockedException(ClusterBlockLevel.WRITE) != null; - } - - /** - * Similar to checkGlobalBlock, we check block on the indices level. - * - * @param state Cluster state - * @param level block level - * @param indices the indices on which to check block - * @return whether any of the index has block on the level. - */ - private boolean checkIndicesBlocked(ClusterState state, ClusterBlockLevel level, String... indices) { - // the original index might be an index expression with wildcards like "log*", - // so we need to expand the expression to concrete index name - String[] concreteIndices = indexNameExpressionResolver.concreteIndexNames(state, IndicesOptions.lenientExpandOpen(), indices); - - return state.blocks().indicesBlockedException(level, concreteIndices) != null; - } - - /** - * Check if we should start anomaly prediction. - * - * @param listener listener to respond back to AnomalyResultRequest. - * @param adID detector ID - * @param detector detector instance corresponds to adID - * @param rcfNodeId the rcf model hosting node ID for adID - * @param rcfModelID the rcf model ID for adID - * @return if we can start anomaly prediction. - */ - private boolean shouldStart( - ActionListener listener, - String adID, - AnomalyDetector detector, - String rcfNodeId, - String rcfModelID - ) { - ClusterState state = clusterService.state(); - if (checkGlobalBlock(state)) { - listener.onFailure(new InternalFailure(adID, READ_WRITE_BLOCKED)); - return false; - } - - if (stateManager.isMuted(rcfNodeId, adID)) { - listener - .onFailure( - new InternalFailure( - adID, - String.format(Locale.ROOT, NODE_UNRESPONSIVE_ERR_MSG + " %s for rcf model %s", rcfNodeId, rcfModelID) - ) - ); - return false; - } - - if (checkIndicesBlocked(state, ClusterBlockLevel.READ, detector.getIndices().toArray(new String[0]))) { - listener.onFailure(new InternalFailure(adID, INDEX_READ_BLOCKED)); - return false; - } - - return true; - } - - private void coldStart(AnomalyDetector detector) { - String detectorId = detector.getId(); - - // If last cold start is not finished, we don't trigger another one - if (stateManager.isColdStartRunning(detectorId)) { - return; - } - - final Releasable coldStartFinishingCallback = stateManager.markColdStartRunning(detectorId); - - ActionListener> listener = ActionListener.wrap(trainingData -> { - if (trainingData.isPresent()) { - double[][] dataPoints = trainingData.get(); - - ActionListener trainModelListener = ActionListener - .wrap(res -> { LOG.info("Succeeded in training {}", detectorId); }, exception -> { - if (exception instanceof TimeSeriesException) { - // e.g., partitioned model exceeds memory limit - stateManager.setException(detectorId, exception); - } else if (exception instanceof IllegalArgumentException) { - // IllegalArgumentException due to invalid training data - stateManager - .setException(detectorId, new EndRunException(detectorId, "Invalid training data", exception, false)); - } else if (exception instanceof OpenSearchTimeoutException) { - stateManager - .setException( - detectorId, - new InternalFailure(detectorId, "Time out while indexing cold start checkpoint", exception) - ); - } else { - stateManager - .setException(detectorId, new EndRunException(detectorId, "Error while training model", exception, false)); - } - }); - - modelManager - .trainModel( - detector, - dataPoints, - new ThreadedActionListener<>( - LOG, - threadPool, - TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, - trainModelListener, - false - ) - ); - } else { - stateManager.setException(detectorId, new EndRunException(detectorId, "Cannot get training data", false)); - } - }, exception -> { - if (exception instanceof OpenSearchTimeoutException) { - stateManager.setException(detectorId, new InternalFailure(detectorId, "Time out while getting training data", exception)); - } else if (exception instanceof TimeSeriesException) { - // e.g., Invalid search query - stateManager.setException(detectorId, exception); - } else { - stateManager.setException(detectorId, new EndRunException(detectorId, "Error while cold start", exception, false)); - } - }); - - final ActionListener> listenerWithReleaseCallback = ActionListener - .runAfter(listener, coldStartFinishingCallback::close); - - threadPool - .executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME) - .execute( - () -> featureManager - .getColdStartData( - detector, - new ThreadedActionListener<>( - LOG, - threadPool, - TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, - listenerWithReleaseCallback, - false - ) - ) - ); - } - - /** - * Check if checkpoint for an detector exists or not. If not and previous - * run is not EndRunException whose endNow is true, trigger cold start. - * @param detector detector object - * @return previous cold start exception - */ - private Optional coldStartIfNoCheckPoint(AnomalyDetector detector) { - String detectorId = detector.getId(); - - Optional previousException = stateManager.fetchExceptionAndClear(detectorId); - - if (previousException.isPresent()) { - Exception exception = previousException.get(); - LOG.error(new ParameterizedMessage("Previous exception of {}:", detectorId), exception); - if (exception instanceof EndRunException && ((EndRunException) exception).isEndNow()) { - return previousException; - } - } - - stateManager.getDetectorCheckpoint(detectorId, ActionListener.wrap(checkpointExists -> { - if (!checkpointExists) { - LOG.info("Trigger cold start for {}", detectorId); - coldStart(detector); - } - }, exception -> { - Throwable cause = ExceptionsHelper.unwrapCause(exception); - if (cause instanceof IndexNotFoundException) { - LOG.info("Trigger cold start for {}", detectorId); - coldStart(detector); - } else { - String errorMsg = String.format(Locale.ROOT, "Fail to get checkpoint state for %s", detectorId); - LOG.error(errorMsg, exception); - stateManager.setException(detectorId, new TimeSeriesException(errorMsg, exception)); - } - })); - - return previousException; - } - - class EntityResultListener implements ActionListener { - private String nodeId; - private final String adID; - private AtomicReference failure; - - EntityResultListener(String nodeId, String adID, AtomicReference failure) { - this.nodeId = nodeId; - this.adID = adID; - this.failure = failure; - } - - @Override - public void onResponse(AcknowledgedResponse response) { - try { - if (response.isAcknowledged() == false) { - LOG.error("Cannot send entities' features to {} for {}", nodeId, adID); - stateManager.addPressure(nodeId, adID); - } else { - stateManager.resetBackpressureCounter(nodeId, adID); - } - } catch (Exception ex) { - LOG.error("Unexpected exception: {} for {}", ex, adID); - handleException(ex); - } - } - - @Override - public void onFailure(Exception e) { - try { - // e.g., we have connection issues with all of the nodes while restarting clusters - LOG.error(new ParameterizedMessage("Cannot send entities' features to {} for {}", nodeId, adID), e); - - handleException(e); - - } catch (Exception ex) { - LOG.error("Unexpected exception: {} for {}", ex, adID); - handleException(ex); - } - } - - private void handleException(Exception e) { - handlePredictionFailure(e, adID, nodeId, failure); - if (failure.get() != null) { - stateManager.setException(adID, failure.get()); - } - } - } } diff --git a/src/main/java/org/opensearch/ad/transport/CronAction.java b/src/main/java/org/opensearch/ad/transport/CronAction.java index 1e64a0f45..91a5aa2cb 100644 --- a/src/main/java/org/opensearch/ad/transport/CronAction.java +++ b/src/main/java/org/opensearch/ad/transport/CronAction.java @@ -12,11 +12,12 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; +import org.opensearch.timeseries.transport.CronResponse; public class CronAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "cron"; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "cron"; public static final CronAction INSTANCE = new CronAction(); private CronAction() { diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelAction.java b/src/main/java/org/opensearch/ad/transport/DeleteADModelAction.java similarity index 55% rename from src/main/java/org/opensearch/ad/transport/DeleteModelAction.java rename to src/main/java/org/opensearch/ad/transport/DeleteADModelAction.java index 3af6982b0..c4eeef176 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteModelAction.java +++ b/src/main/java/org/opensearch/ad/transport/DeleteADModelAction.java @@ -12,14 +12,15 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; +import org.opensearch.timeseries.transport.DeleteModelResponse; -public class DeleteModelAction extends ActionType { +public class DeleteADModelAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "model/delete"; - public static final DeleteModelAction INSTANCE = new DeleteModelAction(); + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "model/delete"; + public static final DeleteADModelAction INSTANCE = new DeleteADModelAction(); - private DeleteModelAction() { + private DeleteADModelAction() { super(NAME, DeleteModelResponse::new); } diff --git a/src/main/java/org/opensearch/ad/transport/DeleteADModelTransportAction.java b/src/main/java/org/opensearch/ad/transport/DeleteADModelTransportAction.java new file mode 100644 index 000000000..a1a5e78a8 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/DeleteADModelTransportAction.java @@ -0,0 +1,104 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADColdStart; +import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker; +import org.opensearch.ad.task.ADTaskCacheManager; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.transport.BaseDeleteModelTransportAction; +import org.opensearch.timeseries.transport.DeleteModelNodeRequest; +import org.opensearch.timeseries.transport.DeleteModelNodeResponse; +import org.opensearch.transport.TransportService; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public class DeleteADModelTransportAction extends + BaseDeleteModelTransportAction { + private static final Logger LOG = LogManager.getLogger(DeleteADModelTransportAction.class); + private ADModelManager modelManager; + private FeatureManager featureManager; + + @Inject + public DeleteADModelTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + NodeStateManager nodeStateManager, + ADModelManager modelManager, + FeatureManager featureManager, + ADCacheProvider cache, + ADTaskCacheManager adTaskCacheManager, + ADColdStart coldStarter + ) { + super( + threadPool, + clusterService, + transportService, + actionFilters, + nodeStateManager, + cache, + adTaskCacheManager, + coldStarter, + DeleteADModelAction.NAME + ); + this.modelManager = modelManager; + this.featureManager = featureManager; + } + + /** + * + * Delete checkpoint document (including both RCF and thresholding model), in-memory models, + * buffered shingle data, transport state, and anomaly result + * + * @param request delete request + * @return delete response including local node Id. + */ + @Override + protected DeleteModelNodeResponse nodeOperation(DeleteModelNodeRequest request) { + super.nodeOperation(request); + String adID = request.getConfigID(); + + // delete in-memory models and model checkpoint + modelManager + .clear( + adID, + ActionListener + .wrap( + r -> LOG.info("Deleted model for [{}] with response [{}] ", adID, r), + e -> LOG.error("Fail to delete model for " + adID, e) + ) + ); + + // delete buffered shingle data + featureManager.clear(adID); + + LOG.info("Finished deleting ad models for {}", adID); + return new DeleteModelNodeResponse(clusterService.localNode()); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorAction.java index 75dc34638..70d655507 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorAction.java @@ -13,11 +13,11 @@ import org.opensearch.action.ActionType; import org.opensearch.action.delete.DeleteResponse; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class DeleteAnomalyDetectorAction extends ActionType { // External Action which used for public facing RestAPIs. - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/delete"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/delete"; public static final DeleteAnomalyDetectorAction INSTANCE = new DeleteAnomalyDetectorAction(); private DeleteAnomalyDetectorAction() { diff --git a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportAction.java index 221a935bc..33124125d 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportAction.java @@ -11,58 +11,28 @@ package org.opensearch.ad.transport; -import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_DELETE_DETECTOR; -import static org.opensearch.ad.model.ADTaskType.HISTORICAL_DETECTOR_TASK_TYPES; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.timeseries.util.ParseUtils.getUserContext; -import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute; -import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; - -import java.io.IOException; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.OpenSearchStatusException; -import org.opensearch.action.DocWriteResponse; -import org.opensearch.action.delete.DeleteRequest; -import org.opensearch.action.delete.DeleteResponse; -import org.opensearch.action.get.GetRequest; -import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.action.support.WriteRequest; import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.commons.authuser.User; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.index.IndexNotFoundException; -import org.opensearch.tasks.Task; -import org.opensearch.timeseries.constant.CommonName; -import org.opensearch.timeseries.function.ExecutorFunction; -import org.opensearch.timeseries.model.Job; -import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.transport.BaseDeleteConfigTransportAction; import org.opensearch.transport.TransportService; -public class DeleteAnomalyDetectorTransportAction extends HandledTransportAction { - - private static final Logger LOG = LogManager.getLogger(DeleteAnomalyDetectorTransportAction.class); - private final Client client; - private final ClusterService clusterService; - private final TransportService transportService; - private NamedXContentRegistry xContentRegistry; - private final ADTaskManager adTaskManager; - private volatile Boolean filterByEnabled; +public class DeleteAnomalyDetectorTransportAction extends + BaseDeleteConfigTransportAction { @Inject public DeleteAnomalyDetectorTransportAction( @@ -72,153 +42,24 @@ public DeleteAnomalyDetectorTransportAction( ClusterService clusterService, Settings settings, NamedXContentRegistry xContentRegistry, + NodeStateManager nodeStateManager, ADTaskManager adTaskManager ) { - super(DeleteAnomalyDetectorAction.NAME, transportService, actionFilters, DeleteAnomalyDetectorRequest::new); - this.transportService = transportService; - this.client = client; - this.clusterService = clusterService; - this.xContentRegistry = xContentRegistry; - this.adTaskManager = adTaskManager; - filterByEnabled = AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); - } - - @Override - protected void doExecute(Task task, DeleteAnomalyDetectorRequest request, ActionListener actionListener) { - String detectorId = request.getDetectorID(); - LOG.info("Delete anomaly detector job {}", detectorId); - User user = getUserContext(client); - ActionListener listener = wrapRestActionListener(actionListener, FAIL_TO_DELETE_DETECTOR); - // By the time request reaches here, the user permissions are validated by Security plugin. - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - resolveUserAndExecute( - user, - detectorId, - filterByEnabled, - listener, - (anomalyDetector) -> adTaskManager.getDetector(detectorId, detector -> { - if (!detector.isPresent()) { - // In a mixed cluster, if delete detector request routes to node running AD1.0, then it will - // not delete detector tasks. User can re-delete these deleted detector after cluster upgraded, - // in that case, the detector is not present. - LOG.info("Can't find anomaly detector {}", detectorId); - adTaskManager.deleteADTasks(detectorId, () -> deleteAnomalyDetectorJobDoc(detectorId, listener), listener); - return; - } - // Check if there is realtime job or historical analysis task running. If none of these running, we - // can delete the detector. - getDetectorJob(detectorId, listener, () -> { - adTaskManager.getAndExecuteOnLatestDetectorLevelTask(detectorId, HISTORICAL_DETECTOR_TASK_TYPES, adTask -> { - if (adTask.isPresent() && !adTask.get().isDone()) { - listener.onFailure(new OpenSearchStatusException("Detector is running", RestStatus.INTERNAL_SERVER_ERROR)); - } else { - adTaskManager.deleteADTasks(detectorId, () -> deleteAnomalyDetectorJobDoc(detectorId, listener), listener); - } - }, transportService, true, listener); - }); - }, listener), - client, - clusterService, - xContentRegistry, - AnomalyDetector.class - ); - } catch (Exception e) { - LOG.error(e); - listener.onFailure(e); - } - } - - private void deleteAnomalyDetectorJobDoc(String detectorId, ActionListener listener) { - LOG.info("Delete anomaly detector job {}", detectorId); - DeleteRequest deleteRequest = new DeleteRequest(CommonName.JOB_INDEX, detectorId) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.delete(deleteRequest, ActionListener.wrap(response -> { - if (response.getResult() == DocWriteResponse.Result.DELETED || response.getResult() == DocWriteResponse.Result.NOT_FOUND) { - deleteDetectorStateDoc(detectorId, listener); - } else { - String message = "Fail to delete anomaly detector job " + detectorId; - LOG.error(message); - listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); - } - }, exception -> { - LOG.error("Failed to delete AD job for " + detectorId, exception); - if (exception instanceof IndexNotFoundException) { - deleteDetectorStateDoc(detectorId, listener); - } else { - LOG.error("Failed to delete anomaly detector job", exception); - listener.onFailure(exception); - } - })); - } - - private void deleteDetectorStateDoc(String detectorId, ActionListener listener) { - LOG.info("Delete detector info {}", detectorId); - DeleteRequest deleteRequest = new DeleteRequest(ADCommonName.DETECTION_STATE_INDEX, detectorId); - client.delete(deleteRequest, ActionListener.wrap(response -> { - // whether deleted state doc or not, continue as state doc may not exist - deleteAnomalyDetectorDoc(detectorId, listener); - }, exception -> { - if (exception instanceof IndexNotFoundException) { - deleteAnomalyDetectorDoc(detectorId, listener); - } else { - LOG.error("Failed to delete detector state", exception); - listener.onFailure(exception); - } - })); - } - - private void deleteAnomalyDetectorDoc(String detectorId, ActionListener listener) { - LOG.info("Delete anomaly detector {}", detectorId); - DeleteRequest deleteRequest = new DeleteRequest(CommonName.CONFIG_INDEX, detectorId) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.delete(deleteRequest, new ActionListener() { - @Override - public void onResponse(DeleteResponse deleteResponse) { - listener.onResponse(deleteResponse); - } - - @Override - public void onFailure(Exception e) { - listener.onFailure(e); - } - }); - } - - private void getDetectorJob(String detectorId, ActionListener listener, ExecutorFunction function) { - if (clusterService.state().metadata().indices().containsKey(CommonName.JOB_INDEX)) { - GetRequest request = new GetRequest(CommonName.JOB_INDEX).id(detectorId); - client.get(request, ActionListener.wrap(response -> onGetAdJobResponseForWrite(response, listener, function), exception -> { - LOG.error("Fail to get anomaly detector job: " + detectorId, exception); - listener.onFailure(exception); - })); - } else { - function.execute(); - } - } - - private void onGetAdJobResponseForWrite(GetResponse response, ActionListener listener, ExecutorFunction function) - throws IOException { - if (response.isExists()) { - String adJobId = response.getId(); - if (adJobId != null) { - // check if AD job is running on the detector, if yes, we can't delete the detector - try ( - XContentParser parser = RestHandlerUtils - .createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Job adJob = Job.parse(parser); - if (adJob.isEnabled()) { - listener.onFailure(new OpenSearchStatusException("Detector job is running: " + adJobId, RestStatus.BAD_REQUEST)); - return; - } - } catch (IOException e) { - String message = "Failed to parse anomaly detector job " + adJobId; - LOG.error(message, e); - } - } - } - function.execute(); + super( + transportService, + actionFilters, + client, + clusterService, + settings, + xContentRegistry, + nodeStateManager, + adTaskManager, + DeleteAnomalyDetectorAction.NAME, + AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES, + AnalysisType.AD, + ADCommonName.DETECTION_STATE_INDEX, + AnomalyDetector.class, + ADTaskType.HISTORICAL_DETECTOR_TASK_TYPES + ); } } diff --git a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyResultsAction.java b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyResultsAction.java index ae9de4c95..84065dbb7 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyResultsAction.java +++ b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyResultsAction.java @@ -12,12 +12,12 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; import org.opensearch.index.reindex.BulkByScrollResponse; public class DeleteAnomalyResultsAction extends ActionType { // External Action which used for public facing RestAPIs. - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "results/delete"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "results/delete"; public static final DeleteAnomalyResultsAction INSTANCE = new DeleteAnomalyResultsAction(); private DeleteAnomalyResultsAction() { diff --git a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyResultsTransportAction.java b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyResultsTransportAction.java index e2db9ed4a..8c218ac64 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyResultsTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyResultsTransportAction.java @@ -13,8 +13,6 @@ import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_DELETE_AD_RESULT; import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES; -import static org.opensearch.timeseries.util.ParseUtils.addUserBackendRolesFilter; -import static org.opensearch.timeseries.util.ParseUtils.getUserContext; import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; import org.apache.logging.log4j.LogManager; @@ -32,6 +30,7 @@ import org.opensearch.index.reindex.DeleteByQueryAction; import org.opensearch.index.reindex.DeleteByQueryRequest; import org.opensearch.tasks.Task; +import org.opensearch.timeseries.util.ParseUtils; import org.opensearch.transport.TransportService; public class DeleteAnomalyResultsTransportAction extends HandledTransportAction { @@ -61,7 +60,7 @@ protected void doExecute(Task task, DeleteByQueryRequest request, ActionListener } public void delete(DeleteByQueryRequest request, ActionListener listener) { - User user = getUserContext(client); + User user = ParseUtils.getUserContext(client); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { validateRole(request, user, listener); } catch (Exception e) { @@ -79,7 +78,7 @@ private void validateRole(DeleteByQueryRequest request, User user, ActionListene } else { // Security is enabled and backend role filter is enabled try { - addUserBackendRolesFilter(user, request.getSearchRequest().source()); + ParseUtils.addUserBackendRolesFilter(user, request.getSearchRequest().source()); client.execute(DeleteByQueryAction.INSTANCE, request, listener); } catch (Exception e) { listener.onFailure(e); diff --git a/src/main/java/org/opensearch/ad/transport/EntityResultAction.java b/src/main/java/org/opensearch/ad/transport/EntityADResultAction.java similarity index 62% rename from src/main/java/org/opensearch/ad/transport/EntityResultAction.java rename to src/main/java/org/opensearch/ad/transport/EntityADResultAction.java index c519858b4..f17c23416 100644 --- a/src/main/java/org/opensearch/ad/transport/EntityResultAction.java +++ b/src/main/java/org/opensearch/ad/transport/EntityADResultAction.java @@ -13,14 +13,14 @@ import org.opensearch.action.ActionType; import org.opensearch.action.support.master.AcknowledgedResponse; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; -public class EntityResultAction extends ActionType { +public class EntityADResultAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "entity/result"; - public static final EntityResultAction INSTANCE = new EntityResultAction(); + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "entity/result"; + public static final EntityADResultAction INSTANCE = new EntityADResultAction(); - private EntityResultAction() { + private EntityADResultAction() { super(NAME, AcknowledgedResponse::new); } diff --git a/src/main/java/org/opensearch/ad/transport/EntityADResultTransportAction.java b/src/main/java/org/opensearch/ad/transport/EntityADResultTransportAction.java new file mode 100644 index 000000000..8147e47b5 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/EntityADResultTransportAction.java @@ -0,0 +1,176 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.util.Optional; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADColdStart; +import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.ad.ml.ThresholdingResult; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.ratelimit.ADCheckpointReadWorker; +import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker; +import org.opensearch.ad.ratelimit.ADColdEntityWorker; +import org.opensearch.ad.ratelimit.ADColdStartWorker; +import org.opensearch.ad.ratelimit.ADResultWriteWorker; +import org.opensearch.ad.ratelimit.ADSaveResultStrategy; +import org.opensearch.ad.stats.ADStats; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.caching.CacheProvider; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.transport.EntityResultProcessor; +import org.opensearch.timeseries.transport.EntityResultRequest; +import org.opensearch.timeseries.util.ExceptionUtil; +import org.opensearch.transport.TransportService; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * Entry-point for HCAD workflow. We have created multiple queues for + * coordinating the workflow. The overrall workflow is: 1. We store as many + * frequently used entity models in a cache as allowed by the memory limit (10% + * heap). If an entity feature is a hit, we use the in-memory model to detect + * anomalies and record results using the result write queue. 2. If an entity + * feature is a miss, we check if there is free memory or any other entity's + * model can be evacuated. An in-memory entity's frequency may be lower compared + * to the cache miss entity. If that's the case, we replace the lower frequency + * entity's model with the higher frequency entity's model. To load the higher + * frequency entity's model, we first check if a model exists on disk by sending + * a checkpoint read queue request. If there is a checkpoint, we load it to + * memory, perform detection, and save the result using the result write queue. + * Otherwise, we enqueue a cold start request to the cold start queue for model + * training. If training is successful, we save the learned model via the + * checkpoint write queue. 3. We also have the cold entity queue configured for + * cold entities, and the model training and inference are connected by serial + * juxtaposition to limit resource usage. + */ +public class EntityADResultTransportAction extends HandledTransportAction { + + private static final Logger LOG = LogManager.getLogger(EntityADResultTransportAction.class); + private CircuitBreakerService adCircuitBreakerService; + private CacheProvider cache; + private final NodeStateManager stateManager; + private ThreadPool threadPool; + private EntityResultProcessor intervalDataProcessor; + + private final ADCacheProvider entityCache; + private final ADModelManager manager; + private final ADStats timeSeriesStats; + private final ADColdStartWorker entityColdStartWorker; + private final ADCheckpointReadWorker checkpointReadQueue; + private final ADColdEntityWorker coldEntityQueue; + private final ADSaveResultStrategy adSaveResultStategy; + + @Inject + public EntityADResultTransportAction( + ActionFilters actionFilters, + TransportService transportService, + ADModelManager manager, + CircuitBreakerService adCircuitBreakerService, + ADCacheProvider entityCache, + NodeStateManager stateManager, + ADIndexManagement indexUtil, + ADResultWriteWorker resultWriteQueue, + ADCheckpointReadWorker checkpointReadQueue, + ADColdEntityWorker coldEntityQueue, + ThreadPool threadPool, + ADColdStartWorker entityColdStartWorker, + ADStats timeSeriesStats, + ADSaveResultStrategy adSaveResultStategy + ) { + super(EntityADResultAction.NAME, transportService, actionFilters, EntityResultRequest::new); + this.adCircuitBreakerService = adCircuitBreakerService; + this.cache = entityCache; + this.stateManager = stateManager; + this.threadPool = threadPool; + + this.entityCache = entityCache; + this.manager = manager; + this.timeSeriesStats = timeSeriesStats; + this.entityColdStartWorker = entityColdStartWorker; + this.checkpointReadQueue = checkpointReadQueue; + this.coldEntityQueue = coldEntityQueue; + this.adSaveResultStategy = adSaveResultStategy; + this.intervalDataProcessor = null; + } + + @Override + protected void doExecute(Task task, EntityResultRequest request, ActionListener listener) { + if (adCircuitBreakerService.isOpen()) { + threadPool + .executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME) + .execute(() -> cache.get().releaseMemoryForOpenCircuitBreaker()); + listener.onFailure(new LimitExceededException(request.getConfigId(), CommonMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG, false)); + return; + } + + try { + String detectorId = request.getConfigId(); + + Optional previousException = stateManager.fetchExceptionAndClear(detectorId); + + if (previousException.isPresent()) { + Exception exception = previousException.get(); + LOG.error("Previous exception of {}: {}", detectorId, exception); + if (exception instanceof EndRunException) { + EndRunException endRunException = (EndRunException) exception; + if (endRunException.isEndNow()) { + listener.onFailure(exception); + return; + } + } + + listener = ExceptionUtil.wrapListener(listener, exception, detectorId); + } + + this.intervalDataProcessor = new EntityResultProcessor<>( + entityCache, + manager, + timeSeriesStats, + entityColdStartWorker, + checkpointReadQueue, + coldEntityQueue, + adSaveResultStategy, + StatNames.AD_MODEL_CORRUTPION_COUNT + ); + + stateManager + .getConfig( + detectorId, + request.getAnalysisType(), + intervalDataProcessor.onGetConfig(listener, detectorId, request, previousException, request.getAnalysisType()) + ); + } catch (Exception exception) { + LOG.error("fail to get entity's anomaly grade", exception); + listener.onFailure(exception); + } + } +} diff --git a/src/main/java/org/opensearch/ad/transport/EntityResultTransportAction.java b/src/main/java/org/opensearch/ad/transport/EntityResultTransportAction.java deleted file mode 100644 index d17ce7137..000000000 --- a/src/main/java/org/opensearch/ad/transport/EntityResultTransportAction.java +++ /dev/null @@ -1,356 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.transport; - -import java.time.Instant; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import java.util.Optional; - -import org.apache.commons.lang3.tuple.Pair; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.message.ParameterizedMessage; -import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.action.support.master.AcknowledgedResponse; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.indices.ADIndex; -import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.ml.ModelState; -import org.opensearch.ad.ml.ThresholdingResult; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.ad.ratelimit.CheckpointReadWorker; -import org.opensearch.ad.ratelimit.ColdEntityWorker; -import org.opensearch.ad.ratelimit.EntityColdStartWorker; -import org.opensearch.ad.ratelimit.EntityFeatureRequest; -import org.opensearch.ad.ratelimit.RequestPriority; -import org.opensearch.ad.ratelimit.ResultWriteRequest; -import org.opensearch.ad.ratelimit.ResultWriteWorker; -import org.opensearch.ad.stats.ADStats; -import org.opensearch.common.inject.Inject; -import org.opensearch.core.action.ActionListener; -import org.opensearch.tasks.Task; -import org.opensearch.threadpool.ThreadPool; -import org.opensearch.timeseries.AnalysisType; -import org.opensearch.timeseries.NodeStateManager; -import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; -import org.opensearch.timeseries.breaker.CircuitBreakerService; -import org.opensearch.timeseries.common.exception.EndRunException; -import org.opensearch.timeseries.common.exception.LimitExceededException; -import org.opensearch.timeseries.constant.CommonMessages; -import org.opensearch.timeseries.model.Config; -import org.opensearch.timeseries.model.Entity; -import org.opensearch.timeseries.stats.StatNames; -import org.opensearch.timeseries.util.ExceptionUtil; -import org.opensearch.timeseries.util.ParseUtils; -import org.opensearch.transport.TransportService; - -/** - * Entry-point for HCAD workflow. We have created multiple queues for coordinating - * the workflow. The overrall workflow is: - * 1. We store as many frequently used entity models in a cache as allowed by the - * memory limit (10% heap). If an entity feature is a hit, we use the in-memory model - * to detect anomalies and record results using the result write queue. - * 2. If an entity feature is a miss, we check if there is free memory or any other - * entity's model can be evacuated. An in-memory entity's frequency may be lower - * compared to the cache miss entity. If that's the case, we replace the lower - * frequency entity's model with the higher frequency entity's model. To load the - * higher frequency entity's model, we first check if a model exists on disk by - * sending a checkpoint read queue request. If there is a checkpoint, we load it - * to memory, perform detection, and save the result using the result write queue. - * Otherwise, we enqueue a cold start request to the cold start queue for model - * training. If training is successful, we save the learned model via the checkpoint - * write queue. - * 3. We also have the cold entity queue configured for cold entities, and the model - * training and inference are connected by serial juxtaposition to limit resource usage. - */ -public class EntityResultTransportAction extends HandledTransportAction { - - private static final Logger LOG = LogManager.getLogger(EntityResultTransportAction.class); - private ModelManager modelManager; - private CircuitBreakerService adCircuitBreakerService; - private CacheProvider cache; - private final NodeStateManager stateManager; - private ADIndexManagement indexUtil; - private ResultWriteWorker resultWriteQueue; - private CheckpointReadWorker checkpointReadQueue; - private ColdEntityWorker coldEntityQueue; - private ThreadPool threadPool; - private EntityColdStartWorker entityColdStartWorker; - private ADStats adStats; - - @Inject - public EntityResultTransportAction( - ActionFilters actionFilters, - TransportService transportService, - ModelManager manager, - CircuitBreakerService adCircuitBreakerService, - CacheProvider entityCache, - NodeStateManager stateManager, - ADIndexManagement indexUtil, - ResultWriteWorker resultWriteQueue, - CheckpointReadWorker checkpointReadQueue, - ColdEntityWorker coldEntityQueue, - ThreadPool threadPool, - EntityColdStartWorker entityColdStartWorker, - ADStats adStats - ) { - super(EntityResultAction.NAME, transportService, actionFilters, EntityResultRequest::new); - this.modelManager = manager; - this.adCircuitBreakerService = adCircuitBreakerService; - this.cache = entityCache; - this.stateManager = stateManager; - this.indexUtil = indexUtil; - this.resultWriteQueue = resultWriteQueue; - this.checkpointReadQueue = checkpointReadQueue; - this.coldEntityQueue = coldEntityQueue; - this.threadPool = threadPool; - this.entityColdStartWorker = entityColdStartWorker; - this.adStats = adStats; - } - - @Override - protected void doExecute(Task task, EntityResultRequest request, ActionListener listener) { - if (adCircuitBreakerService.isOpen()) { - threadPool - .executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME) - .execute(() -> cache.get().releaseMemoryForOpenCircuitBreaker()); - listener.onFailure(new LimitExceededException(request.getId(), CommonMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG, false)); - return; - } - - try { - String detectorId = request.getId(); - - Optional previousException = stateManager.fetchExceptionAndClear(detectorId); - - if (previousException.isPresent()) { - Exception exception = previousException.get(); - LOG.error("Previous exception of {}: {}", detectorId, exception); - if (exception instanceof EndRunException) { - EndRunException endRunException = (EndRunException) exception; - if (endRunException.isEndNow()) { - listener.onFailure(exception); - return; - } - } - - listener = ExceptionUtil.wrapListener(listener, exception, detectorId); - } - - stateManager.getConfig(detectorId, AnalysisType.AD, onGetDetector(listener, detectorId, request, previousException)); - } catch (Exception exception) { - LOG.error("fail to get entity's anomaly grade", exception); - listener.onFailure(exception); - } - } - - private ActionListener> onGetDetector( - ActionListener listener, - String detectorId, - EntityResultRequest request, - Optional prevException - ) { - return ActionListener.wrap(detectorOptional -> { - if (!detectorOptional.isPresent()) { - listener.onFailure(new EndRunException(detectorId, "AnomalyDetector is not available.", false)); - return; - } - - AnomalyDetector detector = (AnomalyDetector) detectorOptional.get(); - - if (request.getEntities() == null) { - listener.onFailure(new EndRunException(detectorId, "Fail to get any entities from request.", false)); - return; - } - - Instant executionStartTime = Instant.now(); - Map cacheMissEntities = new HashMap<>(); - for (Entry entityEntry : request.getEntities().entrySet()) { - Entity categoricalValues = entityEntry.getKey(); - - if (isEntityFromOldNodeMsg(categoricalValues) - && detector.getCategoryFields() != null - && detector.getCategoryFields().size() == 1) { - Map attrValues = categoricalValues.getAttributes(); - // handle a request from a version before OpenSearch 1.1. - categoricalValues = Entity - .createSingleAttributeEntity(detector.getCategoryFields().get(0), attrValues.get(ADCommonName.EMPTY_FIELD)); - } - - Optional modelIdOptional = categoricalValues.getModelId(detectorId); - if (false == modelIdOptional.isPresent()) { - continue; - } - - String modelId = modelIdOptional.get(); - double[] datapoint = entityEntry.getValue(); - ModelState entityModel = cache.get().get(modelId, detector); - if (entityModel == null) { - // cache miss - cacheMissEntities.put(categoricalValues, datapoint); - continue; - } - try { - ThresholdingResult result = modelManager - .getAnomalyResultForEntity(datapoint, entityModel, modelId, categoricalValues, detector.getShingleSize()); - // result.getRcfScore() = 0 means the model is not initialized - // result.getGrade() = 0 means it is not an anomaly - // So many OpenSearchRejectedExecutionException if we write no matter what - if (result.getRcfScore() > 0) { - List resultsToSave = result - .toIndexableResults( - detector, - Instant.ofEpochMilli(request.getStart()), - Instant.ofEpochMilli(request.getEnd()), - executionStartTime, - Instant.now(), - ParseUtils.getFeatureData(datapoint, detector), - Optional.ofNullable(categoricalValues), - indexUtil.getSchemaVersion(ADIndex.RESULT), - modelId, - null, - null - ); - for (AnomalyResult r : resultsToSave) { - resultWriteQueue - .put( - new ResultWriteRequest( - System.currentTimeMillis() + detector.getIntervalInMilliseconds(), - detectorId, - result.getGrade() > 0 ? RequestPriority.HIGH : RequestPriority.MEDIUM, - r, - detector.getCustomResultIndex() - ) - ); - } - } - } catch (IllegalArgumentException e) { - // fail to score likely due to model corruption. Re-cold start to recover. - LOG.error(new ParameterizedMessage("Likely model corruption for [{}]", modelId), e); - adStats.getStat(StatNames.MODEL_CORRUTPION_COUNT.getName()).increment(); - cache.get().removeEntityModel(detectorId, modelId); - entityColdStartWorker - .put( - new EntityFeatureRequest( - System.currentTimeMillis() + detector.getIntervalInMilliseconds(), - detectorId, - RequestPriority.MEDIUM, - categoricalValues, - datapoint, - request.getStart() - ) - ); - } - } - - // split hot and cold entities - Pair, List> hotColdEntities = cache - .get() - .selectUpdateCandidate(cacheMissEntities.keySet(), detectorId, detector); - - List hotEntityRequests = new ArrayList<>(); - List coldEntityRequests = new ArrayList<>(); - - for (Entity hotEntity : hotColdEntities.getLeft()) { - double[] hotEntityValue = cacheMissEntities.get(hotEntity); - if (hotEntityValue == null) { - LOG.error(new ParameterizedMessage("feature value should not be null: [{}]", hotEntity)); - continue; - } - hotEntityRequests - .add( - new EntityFeatureRequest( - System.currentTimeMillis() + detector.getIntervalInMilliseconds(), - detectorId, - // hot entities has MEDIUM priority - RequestPriority.MEDIUM, - hotEntity, - hotEntityValue, - request.getStart() - ) - ); - } - - for (Entity coldEntity : hotColdEntities.getRight()) { - double[] coldEntityValue = cacheMissEntities.get(coldEntity); - if (coldEntityValue == null) { - LOG.error(new ParameterizedMessage("feature value should not be null: [{}]", coldEntity)); - continue; - } - coldEntityRequests - .add( - new EntityFeatureRequest( - System.currentTimeMillis() + detector.getIntervalInMilliseconds(), - detectorId, - // cold entities has LOW priority - RequestPriority.LOW, - coldEntity, - coldEntityValue, - request.getStart() - ) - ); - } - - checkpointReadQueue.putAll(hotEntityRequests); - coldEntityQueue.putAll(coldEntityRequests); - - // respond back - if (prevException.isPresent()) { - listener.onFailure(prevException.get()); - } else { - listener.onResponse(new AcknowledgedResponse(true)); - } - }, exception -> { - LOG - .error( - new ParameterizedMessage( - "fail to get entity's anomaly grade for detector [{}]: start: [{}], end: [{}]", - detectorId, - request.getStart(), - request.getEnd() - ), - exception - ); - listener.onFailure(exception); - }); - } - - /** - * Whether the received entity comes from an node that doesn't support multi-category fields. - * This can happen during rolling-upgrade or blue/green deployment. - * - * Specifically, when receiving an EntityResultRequest from an incompatible node, - * EntityResultRequest(StreamInput in) gets an String that represents an entity. - * But Entity class requires both an category field name and value. Since we - * don't have access to detector config in EntityResultRequest(StreamInput in), - * we put CommonName.EMPTY_FIELD as the placeholder. In this method, - * we use the same CommonName.EMPTY_FIELD to check if the deserialized entity - * comes from an incompatible node. If it is, we will add the field name back - * as EntityResultTranportAction has access to the detector config object. - * - * @param categoricalValues deserialized Entity from inbound message. - * @return Whether the received entity comes from an node that doesn't support multi-category fields. - */ - private boolean isEntityFromOldNodeMsg(Entity categoricalValues) { - Map attrValues = categoricalValues.getAttributes(); - return (attrValues != null && attrValues.containsKey(ADCommonName.EMPTY_FIELD)); - } -} diff --git a/src/main/java/org/opensearch/ad/transport/ForwardADTaskAction.java b/src/main/java/org/opensearch/ad/transport/ForwardADTaskAction.java index f63a188cd..43c62eed3 100644 --- a/src/main/java/org/opensearch/ad/transport/ForwardADTaskAction.java +++ b/src/main/java/org/opensearch/ad/transport/ForwardADTaskAction.java @@ -14,12 +14,12 @@ import static org.opensearch.ad.constant.ADCommonName.AD_TASK; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; import org.opensearch.timeseries.transport.JobResponse; public class ForwardADTaskAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detector/" + AD_TASK + "/forward"; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "detector/" + AD_TASK + "/forward"; public static final ForwardADTaskAction INSTANCE = new ForwardADTaskAction(); private ForwardADTaskAction() { diff --git a/src/main/java/org/opensearch/ad/transport/ForwardADTaskTransportAction.java b/src/main/java/org/opensearch/ad/transport/ForwardADTaskTransportAction.java index d2c571fa8..a0591e052 100644 --- a/src/main/java/org/opensearch/ad/transport/ForwardADTaskTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ForwardADTaskTransportAction.java @@ -11,10 +11,6 @@ package org.opensearch.ad.transport; -import static org.opensearch.ad.model.ADTask.ERROR_FIELD; -import static org.opensearch.ad.model.ADTask.STATE_FIELD; -import static org.opensearch.ad.model.ADTask.TASK_PROGRESS_FIELD; - import java.util.Arrays; import java.util.List; @@ -23,10 +19,10 @@ import org.opensearch.OpenSearchStatusException; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.ad.feature.FeatureManager; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.ADTaskAction; import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.rest.handler.ADIndexJobActionHandler; import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.common.inject.Inject; @@ -35,8 +31,10 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.tasks.Task; import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.feature.FeatureManager; import org.opensearch.timeseries.model.DateRange; import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.model.TimeSeriesTask; import org.opensearch.timeseries.transport.JobResponse; import org.opensearch.transport.TransportService; @@ -47,6 +45,7 @@ public class ForwardADTaskTransportAction extends HandledTransportAction { + indexJobHander.startConfig(detector, detectionDateRange, user, transportService, ActionListener.wrap(r -> { adTaskCacheManager.setDetectorTaskSlots(detector.getId(), availableTaskSlots); listener.onResponse(r); }, e -> listener.onFailure(e))); @@ -122,7 +123,9 @@ protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener adTaskCacheManager.setDetectorTaskSlots(detectorId, 0); logger.info("Historical HC detector done, will remove from cache, detector id:{}", detectorId); listener.onResponse(new JobResponse(detectorId)); - TaskState state = !adTask.isEntityTask() && adTask.getError() != null ? TaskState.FAILED : TaskState.FINISHED; + TaskState state = !adTask.isHistoricalEntityTask() && adTask.getError() != null + ? TaskState.FAILED + : TaskState.FINISHED; adTaskManager.setHCDetectorTaskDone(adTask, state, listener); } else { logger.debug("Run next entity for detector " + detectorId); @@ -133,11 +136,11 @@ protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener adTask.getParentTaskId(), ImmutableMap .of( - STATE_FIELD, + TimeSeriesTask.STATE_FIELD, TaskState.RUNNING.name(), - TASK_PROGRESS_FIELD, + TimeSeriesTask.TASK_PROGRESS_FIELD, adTaskManager.hcDetectorProgress(detectorId), - ERROR_FIELD, + TimeSeriesTask.ERROR_FIELD, adTask.getError() != null ? adTask.getError() : "" ) ); @@ -155,7 +158,7 @@ protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener case PUSH_BACK_ENTITY: logger.debug("Received PUSH_BACK_ENTITY action for detector {}, task {}", detectorId, adTask.getTaskId()); // Push back entity to pending entities queue and run next entity. - if (adTask.isEntityTask()) { // AD task must be entity level task. + if (adTask.isHistoricalEntityTask()) { // AD task must be entity level task. adTaskCacheManager.removeRunningEntity(detectorId, entityValue); if (adTaskManager.isRetryableError(adTask.getError()) && !adTaskCacheManager.exceedRetryLimit(adTask.getConfigId(), adTask.getTaskId())) { @@ -204,7 +207,7 @@ protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener if (detector.isHighCardinality()) { adTaskCacheManager.clearPendingEntities(detectorId); adTaskCacheManager.removeRunningEntity(detectorId, entityValue); - if (!adTaskCacheManager.hasEntity(detectorId) || !adTask.isEntityTask()) { + if (!adTaskCacheManager.hasEntity(detectorId) || !adTask.isHistoricalEntityTask()) { adTaskManager.setHCDetectorTaskDone(adTask, TaskState.STOPPED, listener); } listener.onResponse(new JobResponse(adTask.getTaskId())); diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorAction.java index c4232047d..c740ed24e 100644 --- a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorAction.java @@ -12,11 +12,11 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class GetAnomalyDetectorAction extends ActionType { // External Action which used for public facing RestAPIs. - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detectors/get"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detectors/get"; public static final GetAnomalyDetectorAction INSTANCE = new GetAnomalyDetectorAction(); private GetAnomalyDetectorAction() { diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java index f3808dab2..fba6c4582 100644 --- a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java +++ b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java @@ -19,7 +19,6 @@ import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.DetectorProfile; -import org.opensearch.ad.model.EntityProfile; import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -28,6 +27,7 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.model.EntityProfile; import org.opensearch.timeseries.model.Job; import org.opensearch.timeseries.util.RestHandlerUtils; diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java index 3b040c9e1..0bae4a5ce 100644 --- a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java @@ -11,88 +11,43 @@ package org.opensearch.ad.transport; -import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_GET_DETECTOR; -import static org.opensearch.ad.model.ADTaskType.ALL_DETECTOR_TASK_TYPES; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG; -import static org.opensearch.timeseries.util.ParseUtils.getUserContext; -import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute; -import static org.opensearch.timeseries.util.RestHandlerUtils.PROFILE; -import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.EnumSet; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; -import java.util.stream.Collectors; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.OpenSearchStatusException; -import org.opensearch.action.get.MultiGetItemResponse; -import org.opensearch.action.get.MultiGetRequest; -import org.opensearch.action.get.MultiGetResponse; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.ad.ADEntityProfileRunner; +import org.opensearch.ad.ADTaskProfileRunner; import org.opensearch.ad.AnomalyDetectorProfileRunner; -import org.opensearch.ad.EntityProfileRunner; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskProfile; import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.DetectorProfile; -import org.opensearch.ad.model.DetectorProfileName; -import org.opensearch.ad.model.EntityProfileName; import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.CheckedConsumer; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.commons.authuser.User; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.common.Strings; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.tasks.Task; -import org.opensearch.timeseries.Name; -import org.opensearch.timeseries.constant.CommonName; -import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.EntityProfile; import org.opensearch.timeseries.model.Job; import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.transport.BaseGetConfigTransportAction; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; -import org.opensearch.timeseries.util.RestHandlerUtils; import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.TransportService; -import com.google.common.collect.Sets; - -public class GetAnomalyDetectorTransportAction extends HandledTransportAction { - - private static final Logger LOG = LogManager.getLogger(GetAnomalyDetectorTransportAction.class); +public class GetAnomalyDetectorTransportAction extends + BaseGetConfigTransportAction { - private final ClusterService clusterService; - private final Client client; - private final SecurityClientUtil clientUtil; - private final Set allProfileTypeStrs; - private final Set allProfileTypes; - private final Set defaultDetectorProfileTypes; - private final Set allEntityProfileTypeStrs; - private final Set allEntityProfileTypes; - private final Set defaultEntityProfileTypes; - private final NamedXContentRegistry xContentRegistry; - private final DiscoveryNodeFilterer nodeFilter; - private final TransportService transportService; - private volatile Boolean filterByEnabled; - private final ADTaskManager adTaskManager; + public static final Logger LOG = LogManager.getLogger(GetAnomalyDetectorTransportAction.class); @Inject public GetAnomalyDetectorTransportAction( @@ -104,321 +59,105 @@ public GetAnomalyDetectorTransportAction( SecurityClientUtil clientUtil, Settings settings, NamedXContentRegistry xContentRegistry, - ADTaskManager adTaskManager + ADTaskManager adTaskManager, + ADTaskProfileRunner adTaskProfileRunner ) { - super(GetAnomalyDetectorAction.NAME, transportService, actionFilters, GetAnomalyDetectorRequest::new); - this.clusterService = clusterService; - this.client = client; - this.clientUtil = clientUtil; - List allProfiles = Arrays.asList(DetectorProfileName.values()); - this.allProfileTypes = EnumSet.copyOf(allProfiles); - this.allProfileTypeStrs = getProfileListStrs(allProfiles); - List defaultProfiles = Arrays.asList(DetectorProfileName.ERROR, DetectorProfileName.STATE); - this.defaultDetectorProfileTypes = new HashSet(defaultProfiles); - - List allEntityProfiles = Arrays.asList(EntityProfileName.values()); - this.allEntityProfileTypes = EnumSet.copyOf(allEntityProfiles); - this.allEntityProfileTypeStrs = getProfileListStrs(allEntityProfiles); - List defaultEntityProfiles = Arrays.asList(EntityProfileName.STATE); - this.defaultEntityProfileTypes = new HashSet(defaultEntityProfiles); - - this.xContentRegistry = xContentRegistry; - this.nodeFilter = nodeFilter; - filterByEnabled = AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); - this.transportService = transportService; - this.adTaskManager = adTaskManager; + super( + transportService, + nodeFilter, + actionFilters, + clusterService, + client, + clientUtil, + settings, + xContentRegistry, + adTaskManager, + GetAnomalyDetectorAction.NAME, + AnomalyDetector.class, + AnomalyDetector.PARSE_FIELD_NAME, + ADTaskType.ALL_DETECTOR_TASK_TYPES, + ADTaskType.REALTIME_HC_DETECTOR.name(), + ADTaskType.REALTIME_SINGLE_ENTITY.name(), + ADTaskType.HISTORICAL_HC_DETECTOR.name(), + ADTaskType.HISTORICAL_SINGLE_ENTITY.name(), + AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES, + adTaskProfileRunner + ); } @Override - protected void doExecute(Task task, GetAnomalyDetectorRequest request, ActionListener actionListener) { - String detectorID = request.getDetectorID(); - User user = getUserContext(client); - ActionListener listener = wrapRestActionListener(actionListener, FAIL_TO_GET_DETECTOR); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - resolveUserAndExecute( - user, - detectorID, - filterByEnabled, - listener, - (anomalyDetector) -> getExecute(request, listener), - client, - clusterService, - xContentRegistry, - AnomalyDetector.class - ); - } catch (Exception e) { - LOG.error(e); - listener.onFailure(e); + protected void fillInHistoricalTaskforBwc(Map tasks, Optional historicalAdTask) { + if (tasks.containsKey(ADTaskType.HISTORICAL.name())) { + historicalAdTask = Optional.ofNullable(tasks.get(ADTaskType.HISTORICAL.name())); } } - protected void getExecute(GetAnomalyDetectorRequest request, ActionListener listener) { - String detectorID = request.getDetectorID(); - String typesStr = request.getTypeStr(); - String rawPath = request.getRawPath(); - Entity entity = request.getEntity(); - boolean all = request.isAll(); - boolean returnJob = request.isReturnJob(); - boolean returnTask = request.isReturnTask(); - - try { - if (!Strings.isEmpty(typesStr) || rawPath.endsWith(PROFILE) || rawPath.endsWith(PROFILE + "/")) { - if (entity != null) { - Set entityProfilesToCollect = getEntityProfilesToCollect(typesStr, all); - EntityProfileRunner profileRunner = new EntityProfileRunner( - client, - clientUtil, - xContentRegistry, - TimeSeriesSettings.NUM_MIN_SAMPLES - ); - profileRunner.profile(detectorID, entity, entityProfilesToCollect, ActionListener.wrap(profile -> { - listener - .onResponse( - new GetAnomalyDetectorResponse( - 0, - null, - 0, - 0, - null, - null, - false, - null, - null, - false, - null, - null, - profile, - true - ) - ); - }, e -> listener.onFailure(e))); - } else { - Set profilesToCollect = getProfilesToCollect(typesStr, all); - AnomalyDetectorProfileRunner profileRunner = new AnomalyDetectorProfileRunner( - client, - clientUtil, - xContentRegistry, - nodeFilter, - TimeSeriesSettings.NUM_MIN_SAMPLES, - transportService, - adTaskManager - ); - profileRunner.profile(detectorID, getProfileActionListener(listener), profilesToCollect); - } - } else { - if (returnTask) { - adTaskManager.getAndExecuteOnLatestADTasks(detectorID, null, null, ALL_DETECTOR_TASK_TYPES, (taskList) -> { - Optional realtimeAdTask = Optional.empty(); - Optional historicalAdTask = Optional.empty(); - - if (taskList != null && taskList.size() > 0) { - Map adTasks = new HashMap<>(); - List duplicateAdTasks = new ArrayList<>(); - for (ADTask task : taskList) { - if (adTasks.containsKey(task.getTaskType())) { - LOG - .info( - "Found duplicate latest task of detector {}, task id: {}, task type: {}", - detectorID, - task.getTaskType(), - task.getTaskId() - ); - duplicateAdTasks.add(task); - continue; - } - adTasks.put(task.getTaskType(), task); - } - if (duplicateAdTasks.size() > 0) { - adTaskManager.resetLatestFlagAsFalse(duplicateAdTasks); - } - - if (adTasks.containsKey(ADTaskType.REALTIME_HC_DETECTOR.name())) { - realtimeAdTask = Optional.ofNullable(adTasks.get(ADTaskType.REALTIME_HC_DETECTOR.name())); - } else if (adTasks.containsKey(ADTaskType.REALTIME_SINGLE_ENTITY.name())) { - realtimeAdTask = Optional.ofNullable(adTasks.get(ADTaskType.REALTIME_SINGLE_ENTITY.name())); - } - if (adTasks.containsKey(ADTaskType.HISTORICAL_HC_DETECTOR.name())) { - historicalAdTask = Optional.ofNullable(adTasks.get(ADTaskType.HISTORICAL_HC_DETECTOR.name())); - } else if (adTasks.containsKey(ADTaskType.HISTORICAL_SINGLE_ENTITY.name())) { - historicalAdTask = Optional.ofNullable(adTasks.get(ADTaskType.HISTORICAL_SINGLE_ENTITY.name())); - } else if (adTasks.containsKey(ADTaskType.HISTORICAL.name())) { - historicalAdTask = Optional.ofNullable(adTasks.get(ADTaskType.HISTORICAL.name())); - } - } - getDetectorAndJob(detectorID, returnJob, returnTask, realtimeAdTask, historicalAdTask, listener); - }, transportService, true, 2, listener); - } else { - getDetectorAndJob(detectorID, returnJob, returnTask, Optional.empty(), Optional.empty(), listener); - } - } - } catch (Exception e) { - LOG.error(e); - listener.onFailure(e); - } - } - - private void getDetectorAndJob( - String detectorID, + @Override + protected GetAnomalyDetectorResponse createResponse( + long version, + String id, + long primaryTerm, + long seqNo, + AnomalyDetector config, + Job job, boolean returnJob, + Optional realtimeTask, + Optional historicalTask, boolean returnTask, - Optional realtimeAdTask, - Optional historicalAdTask, - ActionListener listener + RestStatus restStatus, + DetectorProfile detectorProfile, + EntityProfile entityProfile, + boolean profileResponse ) { - MultiGetRequest.Item adItem = new MultiGetRequest.Item(CommonName.CONFIG_INDEX, detectorID); - MultiGetRequest multiGetRequest = new MultiGetRequest().add(adItem); - if (returnJob) { - MultiGetRequest.Item adJobItem = new MultiGetRequest.Item(CommonName.JOB_INDEX, detectorID); - multiGetRequest.add(adJobItem); - } - client.multiGet(multiGetRequest, onMultiGetResponse(listener, returnJob, returnTask, realtimeAdTask, historicalAdTask, detectorID)); + return new GetAnomalyDetectorResponse( + version, + id, + primaryTerm, + seqNo, + config, + job, + returnJob, + realtimeTask.orElse(null), + historicalTask.orElse(null), + returnTask, + RestStatus.OK, + detectorProfile, + entityProfile, + profileResponse + ); } - private ActionListener onMultiGetResponse( - ActionListener listener, - boolean returnJob, - boolean returnTask, - Optional realtimeAdTask, - Optional historicalAdTask, - String detectorId + @Override + protected ADEntityProfileRunner createEntityProfileRunner( + Client client, + SecurityClientUtil clientUtil, + NamedXContentRegistry xContentRegistry, + long requiredSamples ) { - return new ActionListener() { - @Override - public void onResponse(MultiGetResponse multiGetResponse) { - MultiGetItemResponse[] responses = multiGetResponse.getResponses(); - AnomalyDetector detector = null; - Job adJob = null; - String id = null; - long version = 0; - long seqNo = 0; - long primaryTerm = 0; - - for (MultiGetItemResponse response : responses) { - if (CommonName.CONFIG_INDEX.equals(response.getIndex())) { - if (response.getResponse() == null || !response.getResponse().isExists()) { - listener.onFailure(new OpenSearchStatusException(FAIL_TO_FIND_CONFIG_MSG + detectorId, RestStatus.NOT_FOUND)); - return; - } - id = response.getId(); - version = response.getResponse().getVersion(); - primaryTerm = response.getResponse().getPrimaryTerm(); - seqNo = response.getResponse().getSeqNo(); - if (!response.getResponse().isSourceEmpty()) { - try ( - XContentParser parser = RestHandlerUtils - .createXContentParserFromRegistry(xContentRegistry, response.getResponse().getSourceAsBytesRef()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - detector = parser.namedObject(AnomalyDetector.class, AnomalyDetector.PARSE_FIELD_NAME, null); - } catch (Exception e) { - String message = "Failed to parse detector job " + detectorId; - listener.onFailure(buildInternalServerErrorResponse(e, message)); - return; - } - } - } - - if (CommonName.JOB_INDEX.equals(response.getIndex())) { - if (response.getResponse() != null - && response.getResponse().isExists() - && !response.getResponse().isSourceEmpty()) { - try ( - XContentParser parser = RestHandlerUtils - .createXContentParserFromRegistry(xContentRegistry, response.getResponse().getSourceAsBytesRef()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - adJob = Job.parse(parser); - } catch (Exception e) { - String message = "Failed to parse detector job " + detectorId; - listener.onFailure(buildInternalServerErrorResponse(e, message)); - return; - } - } - } - } - listener - .onResponse( - new GetAnomalyDetectorResponse( - version, - id, - primaryTerm, - seqNo, - detector, - adJob, - returnJob, - realtimeAdTask.orElse(null), - historicalAdTask.orElse(null), - returnTask, - RestStatus.OK, - null, - null, - false - ) - ); - } - - @Override - public void onFailure(Exception e) { - listener.onFailure(e); - } - }; - } - - private ActionListener getProfileActionListener(ActionListener listener) { - return ActionListener.wrap(new CheckedConsumer() { - @Override - public void accept(DetectorProfile profile) throws Exception { - listener - .onResponse( - new GetAnomalyDetectorResponse(0, null, 0, 0, null, null, false, null, null, false, null, profile, null, true) - ); - } - }, exception -> { listener.onFailure(exception); }); - } - - private OpenSearchStatusException buildInternalServerErrorResponse(Exception e, String errorMsg) { - LOG.error(errorMsg, e); - return new OpenSearchStatusException(errorMsg, RestStatus.INTERNAL_SERVER_ERROR); - } - - /** - * - * @param typesStr a list of input profile types separated by comma - * @param all whether we should return all profile in the response - * @return profiles to collect for a detector - */ - private Set getProfilesToCollect(String typesStr, boolean all) { - if (all) { - return this.allProfileTypes; - } else if (Strings.isEmpty(typesStr)) { - return this.defaultDetectorProfileTypes; - } else { - // Filter out unsupported types - Set typesInRequest = new HashSet<>(Arrays.asList(typesStr.split(","))); - return DetectorProfileName.getNames(Sets.intersection(allProfileTypeStrs, typesInRequest)); - } + return new ADEntityProfileRunner(client, clientUtil, xContentRegistry, TimeSeriesSettings.NUM_MIN_SAMPLES); } - /** - * - * @param typesStr a list of input profile types separated by comma - * @param all whether we should return all profile in the response - * @return profiles to collect for an entity - */ - private Set getEntityProfilesToCollect(String typesStr, boolean all) { - if (all) { - return this.allEntityProfileTypes; - } else if (Strings.isEmpty(typesStr)) { - return this.defaultEntityProfileTypes; - } else { - // Filter out unsupported types - Set typesInRequest = new HashSet<>(Arrays.asList(typesStr.split(","))); - return EntityProfileName.getNames(Sets.intersection(allEntityProfileTypeStrs, typesInRequest)); - } + @Override + protected AnomalyDetectorProfileRunner createProfileRunner( + Client client, + SecurityClientUtil clientUtil, + NamedXContentRegistry xContentRegistry, + DiscoveryNodeFilterer nodeFilter, + long requiredSamples, + TransportService transportService, + ADTaskManager taskManager, + ADTaskProfileRunner taskProfileRunner + ) { + return new AnomalyDetectorProfileRunner( + client, + clientUtil, + xContentRegistry, + nodeFilter, + TimeSeriesSettings.NUM_MIN_SAMPLES, + transportService, + taskManager, + taskProfileRunner + ); } - private Set getProfileListStrs(List profileList) { - return profileList.stream().map(profile -> profile.getName()).collect(Collectors.toSet()); - } } diff --git a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorAction.java index 9ee038336..56103dfc9 100644 --- a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorAction.java @@ -12,11 +12,11 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class IndexAnomalyDetectorAction extends ActionType { // External Action which used for public facing RestAPIs. - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/write"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/write"; public static final IndexAnomalyDetectorAction INSTANCE = new IndexAnomalyDetectorAction(); private IndexAnomalyDetectorAction() { diff --git a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorRequest.java b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorRequest.java index 572e847f9..6a4bb6d1d 100644 --- a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorRequest.java +++ b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorRequest.java @@ -34,6 +34,9 @@ public class IndexAnomalyDetectorRequest extends ActionRequest { private Integer maxSingleEntityAnomalyDetectors; private Integer maxMultiEntityAnomalyDetectors; private Integer maxAnomalyFeatures; + // added during refactoring for forecasting. It is fine we add a new field + // since the request is handled by the same node. + private Integer maxCategoricalFields; public IndexAnomalyDetectorRequest(StreamInput in) throws IOException { super(in); @@ -47,6 +50,7 @@ public IndexAnomalyDetectorRequest(StreamInput in) throws IOException { maxSingleEntityAnomalyDetectors = in.readInt(); maxMultiEntityAnomalyDetectors = in.readInt(); maxAnomalyFeatures = in.readInt(); + maxCategoricalFields = in.readInt(); } public IndexAnomalyDetectorRequest( @@ -59,7 +63,8 @@ public IndexAnomalyDetectorRequest( TimeValue requestTimeout, Integer maxSingleEntityAnomalyDetectors, Integer maxMultiEntityAnomalyDetectors, - Integer maxAnomalyFeatures + Integer maxAnomalyFeatures, + Integer maxCategoricalFields ) { super(); this.detectorID = detectorID; @@ -72,6 +77,7 @@ public IndexAnomalyDetectorRequest( this.maxSingleEntityAnomalyDetectors = maxSingleEntityAnomalyDetectors; this.maxMultiEntityAnomalyDetectors = maxMultiEntityAnomalyDetectors; this.maxAnomalyFeatures = maxAnomalyFeatures; + this.maxCategoricalFields = maxCategoricalFields; } public String getDetectorID() { @@ -114,6 +120,10 @@ public Integer getMaxAnomalyFeatures() { return maxAnomalyFeatures; } + public Integer getMaxCategoricalFields() { + return maxCategoricalFields; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); @@ -127,6 +137,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeInt(maxSingleEntityAnomalyDetectors); out.writeInt(maxMultiEntityAnomalyDetectors); out.writeInt(maxAnomalyFeatures); + out.writeInt(maxCategoricalFields); } @Override diff --git a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportAction.java index ac0e560b1..5d9b69910 100644 --- a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportAction.java @@ -16,7 +16,6 @@ import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES; import static org.opensearch.timeseries.util.ParseUtils.checkFilterByBackendRoles; import static org.opensearch.timeseries.util.ParseUtils.getConfig; -import static org.opensearch.timeseries.util.ParseUtils.getUserContext; import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; import java.util.List; @@ -46,8 +45,10 @@ import org.opensearch.rest.RestRequest; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.tasks.Task; +import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.feature.SearchFeatureDao; import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.util.ParseUtils; import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.TransportService; @@ -93,7 +94,7 @@ public IndexAnomalyDetectorTransportAction( @Override protected void doExecute(Task task, IndexAnomalyDetectorRequest request, ActionListener actionListener) { - User user = getUserContext(client); + User user = ParseUtils.getUserContext(client); String detectorId = request.getDetectorID(); RestRequest.Method method = request.getMethod(); String errorMessage = method == RestRequest.Method.PUT ? FAIL_TO_UPDATE_DETECTOR : FAIL_TO_CREATE_DETECTOR; @@ -116,8 +117,12 @@ private void resolveUserAndExecute( try { // Check if user has backend roles // When filter by is enabled, block users creating/updating detectors who do not have backend roles. - if (filterByEnabled && !checkFilterByBackendRoles(requestedUser, listener)) { - return; + if (filterByEnabled) { + String error = checkFilterByBackendRoles(requestedUser); + if (error != null) { + listener.onFailure(new TimeSeriesException(error)); + return; + } } if (method == RestRequest.Method.PUT) { // requestedUser == null means security is disabled or user is superadmin. In this case we don't need to @@ -164,6 +169,7 @@ protected void adExecute( Integer maxSingleEntityAnomalyDetectors = request.getMaxSingleEntityAnomalyDetectors(); Integer maxMultiEntityAnomalyDetectors = request.getMaxMultiEntityAnomalyDetectors(); Integer maxAnomalyFeatures = request.getMaxAnomalyFeatures(); + Integer maxCategoricalFields = request.getMaxCategoricalFields(); storedContext.restore(); checkIndicesAndExecute(detector.getIndices(), () -> { @@ -175,7 +181,6 @@ protected void adExecute( client, clientUtil, transportService, - listener, anomalyDetectionIndices, detectorId, seqNo, @@ -186,6 +191,7 @@ protected void adExecute( maxSingleEntityAnomalyDetectors, maxMultiEntityAnomalyDetectors, maxAnomalyFeatures, + maxCategoricalFields, method, xContentRegistry, detectorUser, @@ -193,7 +199,7 @@ protected void adExecute( searchFeatureDao, settings ); - indexAnomalyDetectorActionHandler.start(); + indexAnomalyDetectorActionHandler.start(listener); }, listener); } diff --git a/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorAction.java index c90ecc446..5ae8d6c35 100644 --- a/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorAction.java @@ -12,11 +12,11 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class PreviewAnomalyDetectorAction extends ActionType { // External Action which used for public facing RestAPIs. - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/preview"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/preview"; public static final PreviewAnomalyDetectorAction INSTANCE = new PreviewAnomalyDetectorAction(); private PreviewAnomalyDetectorAction() { diff --git a/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportAction.java index 5f6c6c9d3..ef82c43b2 100644 --- a/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportAction.java @@ -16,7 +16,6 @@ import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_ANOMALY_FEATURES; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_CONCURRENT_PREVIEW; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.timeseries.util.ParseUtils.getUserContext; import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute; import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; @@ -56,6 +55,7 @@ import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.util.ParseUtils; import org.opensearch.timeseries.util.RestHandlerUtils; import org.opensearch.transport.TransportService; @@ -103,7 +103,7 @@ protected void doExecute( ActionListener actionListener ) { String detectorId = request.getId(); - User user = getUserContext(client); + User user = ParseUtils.getUserContext(client); ActionListener listener = wrapRestActionListener(actionListener, FAIL_TO_PREVIEW_DETECTOR); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { resolveUserAndExecute( diff --git a/src/main/java/org/opensearch/ad/transport/RCFPollingAction.java b/src/main/java/org/opensearch/ad/transport/RCFPollingAction.java index 147ff74cb..b38a088eb 100644 --- a/src/main/java/org/opensearch/ad/transport/RCFPollingAction.java +++ b/src/main/java/org/opensearch/ad/transport/RCFPollingAction.java @@ -12,11 +12,11 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class RCFPollingAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "rcfpolling"; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "rcfpolling"; public static final RCFPollingAction INSTANCE = new RCFPollingAction(); private RCFPollingAction() { diff --git a/src/main/java/org/opensearch/ad/transport/RCFPollingTransportAction.java b/src/main/java/org/opensearch/ad/transport/RCFPollingTransportAction.java index a8bd64603..c7783cb8f 100644 --- a/src/main/java/org/opensearch/ad/transport/RCFPollingTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/RCFPollingTransportAction.java @@ -19,8 +19,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.ad.cluster.HashRing; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; @@ -30,6 +29,7 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.ml.SingleStreamModelIdMapper; import org.opensearch.transport.TransportException; @@ -48,7 +48,7 @@ public class RCFPollingTransportAction extends HandledTransportAction rcfNode = hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(rcfModelID); + Optional rcfNode = hashRing.getOwningNodeWithSameLocalVersionForRealtime(rcfModelID); if (!rcfNode.isPresent()) { listener.onFailure(new TimeSeriesException(adID, NO_NODE_FOUND_MSG)); return; diff --git a/src/main/java/org/opensearch/ad/transport/RCFResultAction.java b/src/main/java/org/opensearch/ad/transport/RCFResultAction.java index 3480e880a..f551f97df 100644 --- a/src/main/java/org/opensearch/ad/transport/RCFResultAction.java +++ b/src/main/java/org/opensearch/ad/transport/RCFResultAction.java @@ -12,11 +12,11 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class RCFResultAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "rcf/result"; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "rcf/result"; public static final RCFResultAction INSTANCE = new RCFResultAction(); private RCFResultAction() { diff --git a/src/main/java/org/opensearch/ad/transport/RCFResultTransportAction.java b/src/main/java/org/opensearch/ad/transport/RCFResultTransportAction.java index d7df181bb..59ca12965 100644 --- a/src/main/java/org/opensearch/ad/transport/RCFResultTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/RCFResultTransportAction.java @@ -20,14 +20,14 @@ import org.opensearch.Version; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.ad.cluster.HashRing; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.stats.ADStats; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.inject.Inject; import org.opensearch.core.action.ActionListener; import org.opensearch.tasks.Task; import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.stats.StatNames; @@ -36,7 +36,7 @@ public class RCFResultTransportAction extends HandledTransportAction { private static final Logger LOG = LogManager.getLogger(RCFResultTransportAction.class); - private ModelManager manager; + private ADModelManager manager; private CircuitBreakerService adCircuitBreakerService; private HashRing hashRing; private ADStats adStats; @@ -45,7 +45,7 @@ public class RCFResultTransportAction extends HandledTransportAction { // External Action which used for public facing RestAPIs. - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "tasks/search"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "tasks/search"; public static final SearchADTasksAction INSTANCE = new SearchADTasksAction(); private SearchADTasksAction() { diff --git a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorAction.java index c15ece9ab..90ae6cede 100644 --- a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorAction.java @@ -13,11 +13,11 @@ import org.opensearch.action.ActionType; import org.opensearch.action.search.SearchResponse; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class SearchAnomalyDetectorAction extends ActionType { // External Action which used for public facing RestAPIs. - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/search"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/search"; public static final SearchAnomalyDetectorAction INSTANCE = new SearchAnomalyDetectorAction(); private SearchAnomalyDetectorAction() { diff --git a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoAction.java b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoAction.java index 3f4f7c2fc..50f3b60d4 100644 --- a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoAction.java +++ b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoAction.java @@ -12,15 +12,16 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; +import org.opensearch.timeseries.transport.SearchConfigInfoResponse; -public class SearchAnomalyDetectorInfoAction extends ActionType { +public class SearchAnomalyDetectorInfoAction extends ActionType { // External Action which used for public facing RestAPIs. - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/info"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/info"; public static final SearchAnomalyDetectorInfoAction INSTANCE = new SearchAnomalyDetectorInfoAction(); private SearchAnomalyDetectorInfoAction() { - super(NAME, SearchAnomalyDetectorInfoResponse::new); + super(NAME, SearchConfigInfoResponse::new); } } diff --git a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoTransportAction.java b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoTransportAction.java index b932ae601..c83ac9ebd 100644 --- a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoTransportAction.java @@ -11,34 +11,14 @@ package org.opensearch.ad.transport; -import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_GET_CONFIG_INFO; -import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.core.action.ActionListener; -import org.opensearch.index.IndexNotFoundException; -import org.opensearch.index.query.QueryBuilders; -import org.opensearch.index.query.TermsQueryBuilder; -import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.tasks.Task; -import org.opensearch.timeseries.constant.CommonName; -import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.timeseries.transport.BaseSearchConfigInfoTransportAction; import org.opensearch.transport.TransportService; -public class SearchAnomalyDetectorInfoTransportAction extends - HandledTransportAction { - private static final Logger LOG = LogManager.getLogger(SearchAnomalyDetectorInfoTransportAction.class); - private final Client client; - private final ClusterService clusterService; +public class SearchAnomalyDetectorInfoTransportAction extends BaseSearchConfigInfoTransportAction { @Inject public SearchAnomalyDetectorInfoTransportAction( @@ -47,80 +27,6 @@ public SearchAnomalyDetectorInfoTransportAction( Client client, ClusterService clusterService ) { - super(SearchAnomalyDetectorInfoAction.NAME, transportService, actionFilters, SearchAnomalyDetectorInfoRequest::new); - this.client = client; - this.clusterService = clusterService; - } - - @Override - protected void doExecute( - Task task, - SearchAnomalyDetectorInfoRequest request, - ActionListener actionListener - ) { - String name = request.getName(); - String rawPath = request.getRawPath(); - ActionListener listener = wrapRestActionListener(actionListener, FAIL_TO_GET_CONFIG_INFO); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - SearchRequest searchRequest = new SearchRequest().indices(CommonName.CONFIG_INDEX); - if (rawPath.endsWith(RestHandlerUtils.COUNT)) { - // Count detectors - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchRequest.source(searchSourceBuilder); - client.search(searchRequest, new ActionListener() { - - @Override - public void onResponse(SearchResponse searchResponse) { - SearchAnomalyDetectorInfoResponse response = new SearchAnomalyDetectorInfoResponse( - searchResponse.getHits().getTotalHits().value, - false - ); - listener.onResponse(response); - } - - @Override - public void onFailure(Exception e) { - if (e.getClass() == IndexNotFoundException.class) { - // Anomaly Detectors index does not exist - // Could be that user is creating first detector - SearchAnomalyDetectorInfoResponse response = new SearchAnomalyDetectorInfoResponse(0, false); - listener.onResponse(response); - } else { - listener.onFailure(e); - } - } - }); - } else { - // Match name with existing detectors - TermsQueryBuilder query = QueryBuilders.termsQuery("name.keyword", name); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query); - searchRequest.source(searchSourceBuilder); - client.search(searchRequest, new ActionListener() { - - @Override - public void onResponse(SearchResponse searchResponse) { - boolean nameExists = false; - nameExists = searchResponse.getHits().getTotalHits().value > 0; - SearchAnomalyDetectorInfoResponse response = new SearchAnomalyDetectorInfoResponse(0, nameExists); - listener.onResponse(response); - } - - @Override - public void onFailure(Exception e) { - if (e.getClass() == IndexNotFoundException.class) { - // Anomaly Detectors index does not exist - // Could be that user is creating first detector - SearchAnomalyDetectorInfoResponse response = new SearchAnomalyDetectorInfoResponse(0, false); - listener.onResponse(response); - } else { - listener.onFailure(e); - } - } - }); - } - } catch (Exception e) { - LOG.error(e); - listener.onFailure(e); - } + super(transportService, actionFilters, client, SearchAnomalyDetectorInfoAction.NAME); } } diff --git a/src/main/java/org/opensearch/ad/transport/SearchAnomalyResultAction.java b/src/main/java/org/opensearch/ad/transport/SearchAnomalyResultAction.java index 7e0178393..e2a5969bd 100644 --- a/src/main/java/org/opensearch/ad/transport/SearchAnomalyResultAction.java +++ b/src/main/java/org/opensearch/ad/transport/SearchAnomalyResultAction.java @@ -13,11 +13,11 @@ import org.opensearch.action.ActionType; import org.opensearch.action.search.SearchResponse; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class SearchAnomalyResultAction extends ActionType { // External Action which used for public facing RestAPIs. - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "result/search"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "result/search"; public static final SearchAnomalyResultAction INSTANCE = new SearchAnomalyResultAction(); private SearchAnomalyResultAction() { diff --git a/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultAction.java b/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultAction.java index ee89c4179..8956eeb1d 100644 --- a/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultAction.java +++ b/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultAction.java @@ -12,11 +12,11 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class SearchTopAnomalyResultAction extends ActionType { // External Action which used for public facing RestAPIs. - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "result/topAnomalies"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "result/topAnomalies"; public static final SearchTopAnomalyResultAction INSTANCE = new SearchTopAnomalyResultAction(); private SearchTopAnomalyResultAction() { diff --git a/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultTransportAction.java b/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultTransportAction.java index 82a1a02a3..afe3c4729 100644 --- a/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultTransportAction.java @@ -46,7 +46,6 @@ import org.opensearch.index.query.RangeQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.script.Script; -import org.opensearch.script.ScriptType; import org.opensearch.search.aggregations.Aggregation; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.aggregations.AggregationBuilders; @@ -64,10 +63,10 @@ import org.opensearch.timeseries.common.exception.ResourceNotFoundException; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.transport.GetConfigRequest; +import org.opensearch.timeseries.util.QueryUtil; import org.opensearch.transport.TransportService; -import com.google.common.collect.ImmutableMap; - /** * Transport action to fetch top anomaly results for some HC detector. Generates a * query based on user input to fetch aggregated entity results. @@ -219,7 +218,7 @@ public SearchTopAnomalyResultTransportAction( @Override protected void doExecute(Task task, SearchTopAnomalyResultRequest request, ActionListener listener) { - GetAnomalyDetectorRequest getAdRequest = new GetAnomalyDetectorRequest( + GetConfigRequest getAdRequest = new GetConfigRequest( request.getId(), // The default version value used in org.opensearch.rest.action.RestActions.parseVersion() -3L, @@ -506,7 +505,7 @@ private QueryBuilder generateQuery(SearchTopAnomalyResultRequest request) { private AggregationBuilder generateAggregation(SearchTopAnomalyResultRequest request) { List> sources = new ArrayList<>(); for (String categoryField : request.getCategoryFields()) { - Script script = getScriptForCategoryField(categoryField); + Script script = QueryUtil.getScriptForCategoryField(categoryField); sources.add(new TermsValuesSourceBuilder(categoryField).script(script)); } @@ -529,36 +528,6 @@ private AggregationBuilder generateAggregation(SearchTopAnomalyResultRequest req .subAggregation(bucketSort); } - /** - * Generates the painless script to fetch results that have an entity name matching the passed-in category field. - * - * @param categoryField the category field to be used as a source - * @return the painless script used to get all docs with entity name values matching the category field - */ - private Script getScriptForCategoryField(String categoryField) { - StringBuilder builder = new StringBuilder() - .append("String value = null;") - .append("if (params == null || params._source == null || params._source.entity == null) {") - .append("return \"\"") - .append("}") - .append("for (item in params._source.entity) {") - .append("if (item[\"name\"] == params[\"categoryField\"]) {") - .append("value = item['value'];") - .append("break;") - .append("}") - .append("}") - .append("return value;"); - - // The last argument contains the K/V pair to inject the categoryField value into the script - return new Script( - ScriptType.INLINE, - "painless", - builder.toString(), - Collections.emptyMap(), - ImmutableMap.of("categoryField", categoryField) - ); - } - /** * Creates a descending-ordered List from a min heap. * diff --git a/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorAction.java index 3c1f53d9d..8172aeeaf 100644 --- a/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorAction.java @@ -12,15 +12,16 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; +import org.opensearch.timeseries.transport.StatsTimeSeriesResponse; -public class StatsAnomalyDetectorAction extends ActionType { +public class StatsAnomalyDetectorAction extends ActionType { // External Action which used for public facing RestAPIs. - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/stats"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/stats"; public static final StatsAnomalyDetectorAction INSTANCE = new StatsAnomalyDetectorAction(); private StatsAnomalyDetectorAction() { - super(NAME, StatsAnomalyDetectorResponse::new); + super(NAME, StatsTimeSeriesResponse::new); } } diff --git a/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportAction.java index ebf4016cf..d96887e12 100644 --- a/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportAction.java @@ -11,47 +11,31 @@ package org.opensearch.ad.transport; -import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_GET_STATS; -import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; - -import java.util.HashMap; import java.util.List; -import java.util.Map; -import java.util.Set; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.OpenSearchStatusException; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.HandledTransportAction; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyDetectorType; import org.opensearch.ad.stats.ADStats; -import org.opensearch.ad.stats.ADStatsResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; -import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.rest.RestStatus; import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.bucket.terms.StringTerms; import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.tasks.Task; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.transport.BaseStatsTransportAction; +import org.opensearch.timeseries.transport.StatsRequest; +import org.opensearch.timeseries.transport.StatsResponse; import org.opensearch.timeseries.util.MultiResponsesDelegateActionListener; import org.opensearch.transport.TransportService; -public class StatsAnomalyDetectorTransportAction extends HandledTransportAction { +public class StatsAnomalyDetectorTransportAction extends BaseStatsTransportAction { public static final String DETECTOR_TYPE_AGG = "detector_type_agg"; - private final Logger logger = LogManager.getLogger(StatsAnomalyDetectorTransportAction.class); - - private final Client client; - private final ADStats adStats; - private final ClusterService clusterService; @Inject public StatsAnomalyDetectorTransportAction( @@ -62,55 +46,7 @@ public StatsAnomalyDetectorTransportAction( ClusterService clusterService ) { - super(StatsAnomalyDetectorAction.NAME, transportService, actionFilters, ADStatsRequest::new); - this.client = client; - this.adStats = adStats; - this.clusterService = clusterService; - } - - @Override - protected void doExecute(Task task, ADStatsRequest request, ActionListener actionListener) { - ActionListener listener = wrapRestActionListener(actionListener, FAIL_TO_GET_STATS); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - getStats(client, listener, request); - } catch (Exception e) { - logger.error(e); - listener.onFailure(e); - } - } - - /** - * Make the 2 requests to get the node and cluster statistics - * - * @param client Client - * @param listener Listener to send response - * @param adStatsRequest Request containing stats to be retrieved - */ - private void getStats(Client client, ActionListener listener, ADStatsRequest adStatsRequest) { - // Use MultiResponsesDelegateActionListener to execute 2 async requests and create the response once they finish - MultiResponsesDelegateActionListener delegateListener = new MultiResponsesDelegateActionListener<>( - getRestStatsListener(listener), - 2, - "Unable to return AD Stats", - false - ); - - getClusterStats(client, delegateListener, adStatsRequest); - getNodeStats(client, delegateListener, adStatsRequest); - } - - /** - * Listener sends response once Node Stats and Cluster Stats are gathered - * - * @param listener Listener to send response - * @return ActionListener for ADStatsResponse - */ - private ActionListener getRestStatsListener(ActionListener listener) { - return ActionListener - .wrap( - adStatsResponse -> { listener.onResponse(new StatsAnomalyDetectorResponse(adStatsResponse)); }, - exception -> listener.onFailure(new OpenSearchStatusException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)) - ); + super(transportService, actionFilters, client, adStats, clusterService, StatsAnomalyDetectorAction.NAME); } /** @@ -121,15 +57,16 @@ private ActionListener getRestStatsListener(ActionListener listener, - ADStatsRequest adStatsRequest + MultiResponsesDelegateActionListener listener, + StatsRequest adStatsRequest ) { - ADStatsResponse adStatsResponse = new ADStatsResponse(); + StatsResponse adStatsResponse = new StatsResponse(); if ((adStatsRequest.getStatsToBeRetrieved().contains(StatNames.DETECTOR_COUNT.getName()) - || adStatsRequest.getStatsToBeRetrieved().contains(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName()) - || adStatsRequest.getStatsToBeRetrieved().contains(StatNames.MULTI_ENTITY_DETECTOR_COUNT.getName())) + || adStatsRequest.getStatsToBeRetrieved().contains(StatNames.SINGLE_STREAM_DETECTOR_COUNT.getName()) + || adStatsRequest.getStatsToBeRetrieved().contains(StatNames.HC_DETECTOR_COUNT.getName())) && clusterService.state().getRoutingTable().hasIndex(CommonName.CONFIG_INDEX)) { TermsAggregationBuilder termsAgg = AggregationBuilders.terms(DETECTOR_TYPE_AGG).field(AnomalyDetector.DETECTOR_TYPE_FIELD); @@ -156,13 +93,13 @@ private void getClusterStats( } } if (adStatsRequest.getStatsToBeRetrieved().contains(StatNames.DETECTOR_COUNT.getName())) { - adStats.getStat(StatNames.DETECTOR_COUNT.getName()).setValue(totalDetectors); + stats.getStat(StatNames.DETECTOR_COUNT.getName()).setValue(totalDetectors); } - if (adStatsRequest.getStatsToBeRetrieved().contains(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName())) { - adStats.getStat(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName()).setValue(totalSingleEntityDetectors); + if (adStatsRequest.getStatsToBeRetrieved().contains(StatNames.SINGLE_STREAM_DETECTOR_COUNT.getName())) { + stats.getStat(StatNames.SINGLE_STREAM_DETECTOR_COUNT.getName()).setValue(totalSingleEntityDetectors); } - if (adStatsRequest.getStatsToBeRetrieved().contains(StatNames.MULTI_ENTITY_DETECTOR_COUNT.getName())) { - adStats.getStat(StatNames.MULTI_ENTITY_DETECTOR_COUNT.getName()).setValue(totalMultiEntityDetectors); + if (adStatsRequest.getStatsToBeRetrieved().contains(StatNames.HC_DETECTOR_COUNT.getName())) { + stats.getStat(StatNames.HC_DETECTOR_COUNT.getName()).setValue(totalMultiEntityDetectors); } adStatsResponse.setClusterStats(getClusterStatsMap(adStatsRequest)); listener.onResponse(adStatsResponse); @@ -173,24 +110,6 @@ private void getClusterStats( } } - /** - * Collect Cluster Stats into map to be retrieved - * - * @param adStatsRequest Request containing stats to be retrieved - * @return Map containing Cluster Stats - */ - private Map getClusterStatsMap(ADStatsRequest adStatsRequest) { - Map clusterStats = new HashMap<>(); - Set statsToBeRetrieved = adStatsRequest.getStatsToBeRetrieved(); - adStats - .getClusterStats() - .entrySet() - .stream() - .filter(s -> statsToBeRetrieved.contains(s.getKey())) - .forEach(s -> clusterStats.put(s.getKey(), s.getValue().getValue())); - return clusterStats; - } - /** * Make async request to get the Anomaly Detection statistics from each node and, onResponse, set the * ADStatsNodesResponse field of ADStatsResponse @@ -199,14 +118,11 @@ private Map getClusterStatsMap(ADStatsRequest adStatsRequest) { * @param listener MultiResponsesDelegateActionListener to be used once both requests complete * @param adStatsRequest Request containing stats to be retrieved */ - private void getNodeStats( - Client client, - MultiResponsesDelegateActionListener listener, - ADStatsRequest adStatsRequest - ) { + @Override + protected void getNodeStats(Client client, MultiResponsesDelegateActionListener listener, StatsRequest adStatsRequest) { client.execute(ADStatsNodesAction.INSTANCE, adStatsRequest, ActionListener.wrap(adStatsResponse -> { - ADStatsResponse restADStatsResponse = new ADStatsResponse(); - restADStatsResponse.setADStatsNodesResponse(adStatsResponse); + StatsResponse restADStatsResponse = new StatsResponse(); + restADStatsResponse.setStatsNodesResponse(adStatsResponse); listener.onResponse(restADStatsResponse); }, listener::onFailure)); } diff --git a/src/main/java/org/opensearch/ad/transport/StopDetectorAction.java b/src/main/java/org/opensearch/ad/transport/StopDetectorAction.java index 5c7182920..15f617e78 100644 --- a/src/main/java/org/opensearch/ad/transport/StopDetectorAction.java +++ b/src/main/java/org/opensearch/ad/transport/StopDetectorAction.java @@ -12,15 +12,16 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; +import org.opensearch.timeseries.transport.StopConfigResponse; -public class StopDetectorAction extends ActionType { +public class StopDetectorAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detector/stop"; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "detector/stop"; public static final StopDetectorAction INSTANCE = new StopDetectorAction(); private StopDetectorAction() { - super(NAME, StopDetectorResponse::new); + super(NAME, StopConfigResponse::new); } } diff --git a/src/main/java/org/opensearch/ad/transport/StopDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/StopDetectorTransportAction.java index deafd8854..074165a35 100644 --- a/src/main/java/org/opensearch/ad/transport/StopDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/StopDetectorTransportAction.java @@ -27,10 +27,13 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.tasks.Task; import org.opensearch.timeseries.common.exception.InternalFailure; +import org.opensearch.timeseries.transport.DeleteModelRequest; +import org.opensearch.timeseries.transport.StopConfigRequest; +import org.opensearch.timeseries.transport.StopConfigResponse; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import org.opensearch.transport.TransportService; -public class StopDetectorTransportAction extends HandledTransportAction { +public class StopDetectorTransportAction extends HandledTransportAction { private static final Logger LOG = LogManager.getLogger(StopDetectorTransportAction.class); @@ -44,19 +47,19 @@ public StopDetectorTransportAction( ActionFilters actionFilters, Client client ) { - super(StopDetectorAction.NAME, transportService, actionFilters, StopDetectorRequest::new); + super(StopDetectorAction.NAME, transportService, actionFilters, StopConfigRequest::new); this.client = client; this.nodeFilter = nodeFilter; } @Override - protected void doExecute(Task task, ActionRequest actionRequest, ActionListener listener) { - StopDetectorRequest request = StopDetectorRequest.fromActionRequest(actionRequest); - String adID = request.getAdID(); + protected void doExecute(Task task, ActionRequest actionRequest, ActionListener listener) { + StopConfigRequest request = StopConfigRequest.fromActionRequest(actionRequest); + String adID = request.getConfigID(); try { DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes(); DeleteModelRequest modelDeleteRequest = new DeleteModelRequest(adID, dataNodes); - client.execute(DeleteModelAction.INSTANCE, modelDeleteRequest, ActionListener.wrap(response -> { + client.execute(DeleteADModelAction.INSTANCE, modelDeleteRequest, ActionListener.wrap(response -> { if (response.hasFailures()) { LOG.warn("Cannot delete all models of detector {}", adID); for (FailedNodeException failedNodeException : response.failures()) { @@ -64,14 +67,14 @@ protected void doExecute(Task task, ActionRequest actionRequest, ActionListener< } // if customers are using an updated detector and we haven't deleted old // checkpoints, customer would have trouble - listener.onResponse(new StopDetectorResponse(false)); + listener.onResponse(new StopConfigResponse(false)); } else { LOG.info("models of detector {} get deleted", adID); - listener.onResponse(new StopDetectorResponse(true)); + listener.onResponse(new StopConfigResponse(true)); } }, exception -> { LOG.error(new ParameterizedMessage("Deletion of detector [{}] has exception.", adID), exception); - listener.onResponse(new StopDetectorResponse(false)); + listener.onResponse(new StopConfigResponse(false)); })); } catch (Exception e) { LOG.error(FAIL_TO_STOP_DETECTOR + " " + adID, e); diff --git a/src/main/java/org/opensearch/ad/transport/ThresholdResultAction.java b/src/main/java/org/opensearch/ad/transport/ThresholdResultAction.java index 1561c08dc..f8a81252a 100644 --- a/src/main/java/org/opensearch/ad/transport/ThresholdResultAction.java +++ b/src/main/java/org/opensearch/ad/transport/ThresholdResultAction.java @@ -12,11 +12,11 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class ThresholdResultAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "threshold/result"; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "threshold/result"; public static final ThresholdResultAction INSTANCE = new ThresholdResultAction(); private ThresholdResultAction() { diff --git a/src/main/java/org/opensearch/ad/transport/ThresholdResultTransportAction.java b/src/main/java/org/opensearch/ad/transport/ThresholdResultTransportAction.java index 053d9729b..9c60fcd7f 100644 --- a/src/main/java/org/opensearch/ad/transport/ThresholdResultTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ThresholdResultTransportAction.java @@ -15,7 +15,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.common.inject.Inject; import org.opensearch.core.action.ActionListener; import org.opensearch.tasks.Task; @@ -24,10 +24,10 @@ public class ThresholdResultTransportAction extends HandledTransportAction { private static final Logger LOG = LogManager.getLogger(ThresholdResultTransportAction.class); - private ModelManager manager; + private ADModelManager manager; @Inject - public ThresholdResultTransportAction(ActionFilters actionFilters, TransportService transportService, ModelManager manager) { + public ThresholdResultTransportAction(ActionFilters actionFilters, TransportService transportService, ADModelManager manager) { super(ThresholdResultAction.NAME, transportService, actionFilters, ThresholdResultRequest::new); this.manager = manager; } diff --git a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorAction.java index 432166ac2..cf3f2325a 100644 --- a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorAction.java @@ -12,14 +12,15 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; +import org.opensearch.timeseries.transport.ValidateConfigResponse; -public class ValidateAnomalyDetectorAction extends ActionType { +public class ValidateAnomalyDetectorAction extends ActionType { - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/validate"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/validate"; public static final ValidateAnomalyDetectorAction INSTANCE = new ValidateAnomalyDetectorAction(); private ValidateAnomalyDetectorAction() { - super(NAME, ValidateAnomalyDetectorResponse::new); + super(NAME, ValidateConfigResponse::new); } } diff --git a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequest.java b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequest.java deleted file mode 100644 index 3ee1f0a6e..000000000 --- a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequest.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.transport; - -import java.io.IOException; - -import org.opensearch.action.ActionRequest; -import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.common.unit.TimeValue; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; - -public class ValidateAnomalyDetectorRequest extends ActionRequest { - - private final AnomalyDetector detector; - private final String validationType; - private final Integer maxSingleEntityAnomalyDetectors; - private final Integer maxMultiEntityAnomalyDetectors; - private final Integer maxAnomalyFeatures; - private final TimeValue requestTimeout; - - public ValidateAnomalyDetectorRequest(StreamInput in) throws IOException { - super(in); - detector = new AnomalyDetector(in); - validationType = in.readString(); - maxSingleEntityAnomalyDetectors = in.readInt(); - maxMultiEntityAnomalyDetectors = in.readInt(); - maxAnomalyFeatures = in.readInt(); - requestTimeout = in.readTimeValue(); - } - - public ValidateAnomalyDetectorRequest( - AnomalyDetector detector, - String validationType, - Integer maxSingleEntityAnomalyDetectors, - Integer maxMultiEntityAnomalyDetectors, - Integer maxAnomalyFeatures, - TimeValue requestTimeout - ) { - this.detector = detector; - this.validationType = validationType; - this.maxSingleEntityAnomalyDetectors = maxSingleEntityAnomalyDetectors; - this.maxMultiEntityAnomalyDetectors = maxMultiEntityAnomalyDetectors; - this.maxAnomalyFeatures = maxAnomalyFeatures; - this.requestTimeout = requestTimeout; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - super.writeTo(out); - detector.writeTo(out); - out.writeString(validationType); - out.writeInt(maxSingleEntityAnomalyDetectors); - out.writeInt(maxMultiEntityAnomalyDetectors); - out.writeInt(maxAnomalyFeatures); - out.writeTimeValue(requestTimeout); - } - - @Override - public ActionRequestValidationException validate() { - return null; - } - - public AnomalyDetector getDetector() { - return detector; - } - - public String getValidationType() { - return validationType; - } - - public Integer getMaxSingleEntityAnomalyDetectors() { - return maxSingleEntityAnomalyDetectors; - } - - public Integer getMaxMultiEntityAnomalyDetectors() { - return maxMultiEntityAnomalyDetectors; - } - - public Integer getMaxAnomalyFeatures() { - return maxAnomalyFeatures; - } - - public TimeValue getRequestTimeout() { - return requestTimeout; - } -} diff --git a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportAction.java index 16eec43ac..db43e038c 100644 --- a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportAction.java @@ -12,61 +12,31 @@ package org.opensearch.ad.transport; import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES; -import static org.opensearch.timeseries.util.ParseUtils.checkFilterByBackendRoles; -import static org.opensearch.timeseries.util.ParseUtils.getUserContext; - -import java.time.Clock; -import java.util.HashMap; -import java.util.List; -import java.util.Locale; -import java.util.Map; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.action.search.SearchRequest; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.DetectorValidationIssue; import org.opensearch.ad.rest.handler.ValidateAnomalyDetectorActionHandler; -import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; -import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.index.IndexNotFoundException; -import org.opensearch.index.query.QueryBuilders; import org.opensearch.rest.RestRequest; -import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.tasks.Task; -import org.opensearch.timeseries.common.exception.ValidationException; import org.opensearch.timeseries.feature.SearchFeatureDao; -import org.opensearch.timeseries.function.ExecutorFunction; -import org.opensearch.timeseries.model.IntervalTimeConfiguration; -import org.opensearch.timeseries.model.ValidationAspect; -import org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.rest.handler.Processor; +import org.opensearch.timeseries.transport.BaseValidateConfigTransportAction; +import org.opensearch.timeseries.transport.ValidateConfigRequest; +import org.opensearch.timeseries.transport.ValidateConfigResponse; import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.TransportService; -public class ValidateAnomalyDetectorTransportAction extends - HandledTransportAction { - private static final Logger logger = LogManager.getLogger(ValidateAnomalyDetectorTransportAction.class); - - private final Client client; - private final SecurityClientUtil clientUtil; - private final ClusterService clusterService; - private final NamedXContentRegistry xContentRegistry; - private final ADIndexManagement anomalyDetectionIndices; - private final SearchFeatureDao searchFeatureDao; - private volatile Boolean filterByEnabled; - private Clock clock; - private Settings settings; +public class ValidateAnomalyDetectorTransportAction extends BaseValidateConfigTransportAction { + public static final Logger logger = LogManager.getLogger(ValidateAnomalyDetectorTransportAction.class); @Inject public ValidateAnomalyDetectorTransportAction( @@ -80,176 +50,41 @@ public ValidateAnomalyDetectorTransportAction( TransportService transportService, SearchFeatureDao searchFeatureDao ) { - super(ValidateAnomalyDetectorAction.NAME, transportService, actionFilters, ValidateAnomalyDetectorRequest::new); - this.client = client; - this.clientUtil = clientUtil; - this.clusterService = clusterService; - this.xContentRegistry = xContentRegistry; - this.anomalyDetectionIndices = anomalyDetectionIndices; - this.filterByEnabled = AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); - this.searchFeatureDao = searchFeatureDao; - this.clock = Clock.systemUTC(); - this.settings = settings; + super( + ValidateAnomalyDetectorAction.NAME, + client, + clientUtil, + clusterService, + xContentRegistry, + settings, + anomalyDetectionIndices, + actionFilters, + transportService, + searchFeatureDao, + AD_FILTER_BY_BACKEND_ROLES + ); } @Override - protected void doExecute(Task task, ValidateAnomalyDetectorRequest request, ActionListener listener) { - User user = getUserContext(client); - AnomalyDetector anomalyDetector = request.getDetector(); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - resolveUserAndExecute(user, listener, () -> validateExecute(request, user, context, listener)); - } catch (Exception e) { - logger.error(e); - listener.onFailure(e); - } - } - - private void resolveUserAndExecute( - User requestedUser, - ActionListener listener, - ExecutorFunction function - ) { - try { - // Check if user has backend roles - // When filter by is enabled, block users validating detectors who do not have backend roles. - if (filterByEnabled && !checkFilterByBackendRoles(requestedUser, listener)) { - return; - } - // Validate Detector - function.execute(); - } catch (Exception e) { - listener.onFailure(e); - } - } - - private void validateExecute( - ValidateAnomalyDetectorRequest request, - User user, - ThreadContext.StoredContext storedContext, - ActionListener listener - ) { - storedContext.restore(); - AnomalyDetector detector = request.getDetector(); - ActionListener validateListener = ActionListener.wrap(response -> { - logger.debug("Result of validation process " + response); - // forcing response to be empty - listener.onResponse(new ValidateAnomalyDetectorResponse((DetectorValidationIssue) null)); - }, exception -> { - if (exception instanceof ValidationException) { - // ADValidationException is converted as validation issues returned as response to user - DetectorValidationIssue issue = parseADValidationException((ValidationException) exception); - listener.onResponse(new ValidateAnomalyDetectorResponse(issue)); - return; - } - logger.error(exception); - listener.onFailure(exception); - }); - checkIndicesAndExecute(detector.getIndices(), () -> { - ValidateAnomalyDetectorActionHandler handler = new ValidateAnomalyDetectorActionHandler( - clusterService, - client, - clientUtil, - validateListener, - anomalyDetectionIndices, - detector, - request.getRequestTimeout(), - request.getMaxSingleEntityAnomalyDetectors(), - request.getMaxMultiEntityAnomalyDetectors(), - request.getMaxAnomalyFeatures(), - RestRequest.Method.POST, - xContentRegistry, - user, - searchFeatureDao, - request.getValidationType(), - clock, - settings - ); - try { - handler.start(); - } catch (Exception exception) { - String errorMessage = String - .format(Locale.ROOT, "Unknown exception caught while validating detector %s", request.getDetector()); - logger.error(errorMessage, exception); - listener.onFailure(exception); - } - }, listener); - } - - protected DetectorValidationIssue parseADValidationException(ValidationException exception) { - String originalErrorMessage = exception.getMessage(); - String errorMessage = ""; - Map subIssues = null; - IntervalTimeConfiguration intervalSuggestion = exception.getIntervalSuggestion(); - switch (exception.getType()) { - case FEATURE_ATTRIBUTES: - int firstLeftBracketIndex = originalErrorMessage.indexOf("["); - int lastRightBracketIndex = originalErrorMessage.lastIndexOf("]"); - if (firstLeftBracketIndex != -1) { - // if feature issue messages are between square brackets like - // [Feature has issue: A, Feature has issue: B] - errorMessage = originalErrorMessage.substring(firstLeftBracketIndex + 1, lastRightBracketIndex); - subIssues = getFeatureSubIssuesFromErrorMessage(errorMessage); - } else { - // features having issue like over max feature limit, duplicate feature name, etc. - errorMessage = originalErrorMessage; - } - break; - case NAME: - case CATEGORY: - case DETECTION_INTERVAL: - case FILTER_QUERY: - case TIMEFIELD_FIELD: - case SHINGLE_SIZE_FIELD: - case WINDOW_DELAY: - case RESULT_INDEX: - case GENERAL_SETTINGS: - case AGGREGATION: - case TIMEOUT: - case INDICES: - errorMessage = originalErrorMessage; - break; - } - return new DetectorValidationIssue(exception.getAspect(), exception.getType(), errorMessage, subIssues, intervalSuggestion); - } - - // Example of method output: - // String input:Feature has invalid query returning empty aggregated data: average_total_rev, Feature has invalid query causing runtime - // exception: average_total_rev-2 - // output: "sub_issues": { - // "average_total_rev": "Feature has invalid query returning empty aggregated data", - // "average_total_rev-2": "Feature has invalid query causing runtime exception" - // } - private Map getFeatureSubIssuesFromErrorMessage(String errorMessage) { - Map result = new HashMap<>(); - String[] subIssueMessagesSuffix = errorMessage.split(", "); - for (int i = 0; i < subIssueMessagesSuffix.length; i++) { - result.put(subIssueMessagesSuffix[i].split(": ")[1], subIssueMessagesSuffix[i].split(": ")[0]); - } - return result; - } - - private void checkIndicesAndExecute( - List indices, - ExecutorFunction function, - ActionListener listener - ) { - SearchRequest searchRequest = new SearchRequest() - .indices(indices.toArray(new String[0])) - .source(new SearchSourceBuilder().size(1).query(QueryBuilders.matchAllQuery())); - client.search(searchRequest, ActionListener.wrap(r -> function.execute(), e -> { - if (e instanceof IndexNotFoundException) { - // IndexNotFoundException is converted to a ADValidationException that gets - // parsed to a DetectorValidationIssue that is returned to - // the user as a response indicating index doesn't exist - DetectorValidationIssue issue = parseADValidationException( - new ValidationException(ADCommonMessages.INDEX_NOT_FOUND, ValidationIssueType.INDICES, ValidationAspect.DETECTOR) - ); - listener.onResponse(new ValidateAnomalyDetectorResponse(issue)); - return; - } - logger.error(e); - listener.onFailure(e); - })); + protected Processor createProcessor(Config detector, ValidateConfigRequest request, User user) { + return new ValidateAnomalyDetectorActionHandler( + clusterService, + client, + clientUtil, + indexManagement, + detector, + request.getRequestTimeout(), + request.getMaxSingleEntityAnomalyDetectors(), + request.getMaxMultiEntityAnomalyDetectors(), + request.getMaxAnomalyFeatures(), + request.getMaxCategoricalFields(), + RestRequest.Method.POST, + xContentRegistry, + user, + searchFeatureDao, + request.getValidationType(), + clock, + settings + ); } } diff --git a/src/main/java/org/opensearch/ad/transport/handler/ADIndexMemoryPressureAwareResultHandler.java b/src/main/java/org/opensearch/ad/transport/handler/ADIndexMemoryPressureAwareResultHandler.java new file mode 100644 index 000000000..175017450 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/handler/ADIndexMemoryPressureAwareResultHandler.java @@ -0,0 +1,51 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport.handler; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.transport.ADResultBulkAction; +import org.opensearch.ad.transport.ADResultBulkRequest; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.transport.ResultBulkResponse; +import org.opensearch.timeseries.transport.handler.IndexMemoryPressureAwareResultHandler; + +public class ADIndexMemoryPressureAwareResultHandler extends + IndexMemoryPressureAwareResultHandler { + private static final Logger LOG = LogManager.getLogger(ADIndexMemoryPressureAwareResultHandler.class); + + @Inject + public ADIndexMemoryPressureAwareResultHandler(Client client, ADIndexManagement anomalyDetectionIndices) { + super(client, anomalyDetectionIndices); + } + + @Override + public void bulk(ADResultBulkRequest currentBulkRequest, ActionListener listener) { + if (currentBulkRequest.numberOfActions() <= 0) { + listener.onFailure(new TimeSeriesException("no result to save")); + return; + } + client.execute(ADResultBulkAction.INSTANCE, currentBulkRequest, ActionListener.wrap(response -> { + LOG.debug(CommonMessages.SUCCESS_SAVING_RESULT_MSG); + listener.onResponse(response); + }, exception -> { + LOG.error("Error in bulking results", exception); + listener.onFailure(exception); + })); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/handler/ADSearchHandler.java b/src/main/java/org/opensearch/ad/transport/handler/ADSearchHandler.java index 05c69196d..7a5312652 100644 --- a/src/main/java/org/opensearch/ad/transport/handler/ADSearchHandler.java +++ b/src/main/java/org/opensearch/ad/transport/handler/ADSearchHandler.java @@ -11,74 +11,18 @@ package org.opensearch.ad.transport.handler; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES; -import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_SEARCH; -import static org.opensearch.timeseries.util.ParseUtils.addUserBackendRolesFilter; -import static org.opensearch.timeseries.util.ParseUtils.getUserContext; -import static org.opensearch.timeseries.util.ParseUtils.isAdmin; -import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.commons.authuser.User; -import org.opensearch.core.action.ActionListener; +import org.opensearch.timeseries.transport.handler.SearchHandler; /** * Handle general search request, check user role and return search response. */ -public class ADSearchHandler { - private final Logger logger = LogManager.getLogger(ADSearchHandler.class); - private final Client client; - private volatile Boolean filterEnabled; +public class ADSearchHandler extends SearchHandler { public ADSearchHandler(Settings settings, ClusterService clusterService, Client client) { - this.client = client; - filterEnabled = AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_FILTER_BY_BACKEND_ROLES, it -> filterEnabled = it); - } - - /** - * Validate user role, add backend role filter if filter enabled - * and execute search. - * - * @param request search request - * @param actionListener action listerner - */ - public void search(SearchRequest request, ActionListener actionListener) { - User user = getUserContext(client); - ActionListener listener = wrapRestActionListener(actionListener, FAIL_TO_SEARCH); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - validateRole(request, user, listener); - } catch (Exception e) { - logger.error(e); - listener.onFailure(e); - } + super(settings, clusterService, client, AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES); } - - private void validateRole(SearchRequest request, User user, ActionListener listener) { - if (user == null || !filterEnabled || isAdmin(user)) { - // Case 1: user == null when 1. Security is disabled. 2. When user is super-admin - // Case 2: If Security is enabled and filter is disabled, proceed with search as - // user is already authenticated to hit this API. - // case 3: user is admin which means we don't have to check backend role filtering - client.search(request, listener); - } else { - // Security is enabled, filter is enabled and user isn't admin - try { - addUserBackendRolesFilter(user, request.source()); - logger.debug("Filtering result by " + user.getBackendRoles()); - client.search(request, listener); - } catch (Exception e) { - listener.onFailure(e); - } - } - } - } diff --git a/src/main/java/org/opensearch/ad/transport/handler/MultiEntityResultHandler.java b/src/main/java/org/opensearch/ad/transport/handler/MultiEntityResultHandler.java deleted file mode 100644 index 13f7e16e7..000000000 --- a/src/main/java/org/opensearch/ad/transport/handler/MultiEntityResultHandler.java +++ /dev/null @@ -1,123 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.transport.handler; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.ExceptionsHelper; -import org.opensearch.ResourceAlreadyExistsException; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.ad.transport.ADResultBulkAction; -import org.opensearch.ad.transport.ADResultBulkRequest; -import org.opensearch.ad.transport.ADResultBulkResponse; -import org.opensearch.ad.util.IndexUtils; -import org.opensearch.client.Client; -import org.opensearch.cluster.block.ClusterBlockLevel; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.inject.Inject; -import org.opensearch.common.settings.Settings; -import org.opensearch.core.action.ActionListener; -import org.opensearch.threadpool.ThreadPool; -import org.opensearch.timeseries.common.exception.TimeSeriesException; -import org.opensearch.timeseries.util.ClientUtil; - -/** - * EntityResultTransportAction depends on this class. I cannot use - * AnomalyIndexHandler < AnomalyResult > . All transport actions - * needs dependency injection. Guice has a hard time initializing generics class - * AnomalyIndexHandler < AnomalyResult > due to type erasure. - * To avoid that, I create a class with a built-in details so - * that Guice would be able to work out the details. - * - */ -public class MultiEntityResultHandler extends AnomalyIndexHandler { - private static final Logger LOG = LogManager.getLogger(MultiEntityResultHandler.class); - // package private for testing - static final String SUCCESS_SAVING_RESULT_MSG = "Result saved successfully."; - static final String CANNOT_SAVE_RESULT_ERR_MSG = "Cannot save results due to write block."; - - @Inject - public MultiEntityResultHandler( - Client client, - Settings settings, - ThreadPool threadPool, - ADIndexManagement anomalyDetectionIndices, - ClientUtil clientUtil, - IndexUtils indexUtils, - ClusterService clusterService - ) { - super( - client, - settings, - threadPool, - ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, - anomalyDetectionIndices, - clientUtil, - indexUtils, - clusterService - ); - } - - /** - * Execute the bulk request - * @param currentBulkRequest The bulk request - * @param listener callback after flushing - */ - public void flush(ADResultBulkRequest currentBulkRequest, ActionListener listener) { - if (indexUtils.checkIndicesBlocked(clusterService.state(), ClusterBlockLevel.WRITE, this.indexName)) { - listener.onFailure(new TimeSeriesException(CANNOT_SAVE_RESULT_ERR_MSG)); - return; - } - - try { - if (!anomalyDetectionIndices.doesDefaultResultIndexExist()) { - anomalyDetectionIndices.initDefaultResultIndexDirectly(ActionListener.wrap(initResponse -> { - if (initResponse.isAcknowledged()) { - bulk(currentBulkRequest, listener); - } else { - LOG.warn("Creating result index with mappings call not acknowledged."); - listener.onFailure(new TimeSeriesException("", "Creating result index with mappings call not acknowledged.")); - } - }, exception -> { - if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { - // It is possible the index has been created while we sending the create request - bulk(currentBulkRequest, listener); - } else { - LOG.warn("Unexpected error creating result index", exception); - listener.onFailure(exception); - } - })); - } else { - bulk(currentBulkRequest, listener); - } - } catch (Exception e) { - LOG.warn("Error in bulking results", e); - listener.onFailure(e); - } - } - - private void bulk(ADResultBulkRequest currentBulkRequest, ActionListener listener) { - if (currentBulkRequest.numberOfActions() <= 0) { - listener.onFailure(new TimeSeriesException("no result to save")); - return; - } - client.execute(ADResultBulkAction.INSTANCE, currentBulkRequest, ActionListener.wrap(response -> { - LOG.debug(SUCCESS_SAVING_RESULT_MSG); - listener.onResponse(response); - }, exception -> { - LOG.error("Error in bulking results", exception); - listener.onFailure(exception); - })); - } -} diff --git a/src/main/java/org/opensearch/forecast/ExecuteForecastResultResponseRecorder.java b/src/main/java/org/opensearch/forecast/ExecuteForecastResultResponseRecorder.java new file mode 100644 index 000000000..16db8cb0d --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ExecuteForecastResultResponseRecorder.java @@ -0,0 +1,95 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.Optional; + +import org.opensearch.client.Client; +import org.opensearch.commons.authuser.User; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.forecast.transport.ForecastProfileAction; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.ExecuteResultResponseRecorder; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.ResultResponse; +import org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; + +public class ExecuteForecastResultResponseRecorder extends + ExecuteResultResponseRecorder { + + public ExecuteForecastResultResponseRecorder( + ForecastIndexManagement indexManagement, + ResultBulkIndexingHandler resultHandler, + ForecastTaskManager taskManager, + DiscoveryNodeFilterer nodeFilter, + ThreadPool threadPool, + Client client, + NodeStateManager nodeStateManager, + TaskCacheManager taskCacheManager, + int rcfMinSamples + ) { + super( + indexManagement, + resultHandler, + taskManager, + nodeFilter, + threadPool, + TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME, + client, + nodeStateManager, + taskCacheManager, + rcfMinSamples, + ForecastIndex.RESULT, + AnalysisType.FORECAST, + ForecastProfileAction.INSTANCE + ); + } + + @Override + protected ForecastResult createErrorResult( + String configId, + Instant dataStartTime, + Instant dataEndTime, + Instant executeEndTime, + String errorMessage, + User user + ) { + return new ForecastResult( + configId, + null, // no task id + new ArrayList(), + dataStartTime, + dataEndTime, + executeEndTime, + Instant.now(), + errorMessage, + Optional.empty(), // single-stream forecasters have no entity + user, + indexManagement.getSchemaVersion(resultIndex) + ); + } + + @Override + protected void updateRealtimeTask(ResultResponse response, String configId) { + if (taskManager.skipUpdateRealtimeTask(configId, response.getError())) { + return; + } + + delayedUpdate(response, configId); + } +} diff --git a/src/main/java/org/opensearch/forecast/ForecastEntityProfileRunner.java b/src/main/java/org/opensearch/forecast/ForecastEntityProfileRunner.java new file mode 100644 index 000000000..e1dcaebfd --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ForecastEntityProfileRunner.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast; + +import org.opensearch.client.Client; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.settings.ForecastNumericSetting; +import org.opensearch.forecast.transport.ForecastEntityProfileAction; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.EntityProfileRunner; +import org.opensearch.timeseries.util.SecurityClientUtil; + +public class ForecastEntityProfileRunner extends EntityProfileRunner { + + public ForecastEntityProfileRunner( + Client client, + SecurityClientUtil clientUtil, + NamedXContentRegistry xContentRegistry, + long requiredSamples + ) { + super( + client, + clientUtil, + xContentRegistry, + requiredSamples, + Forecaster::parse, + ForecastNumericSetting.maxCategoricalFields(), + AnalysisType.FORECAST, + ForecastEntityProfileAction.INSTANCE, + ForecastIndex.RESULT.getIndexName(), + ForecastCommonName.FORECASTER_ID_KEY + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/ForecastJobProcessor.java b/src/main/java/org/opensearch/forecast/ForecastJobProcessor.java new file mode 100644 index 000000000..d6128d030 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ForecastJobProcessor.java @@ -0,0 +1,101 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast; + +import java.time.Instant; +import java.util.List; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.rest.handler.ForecastIndexJobActionHandler; +import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.forecast.transport.ForecastProfileAction; +import org.opensearch.forecast.transport.ForecastResultAction; +import org.opensearch.forecast.transport.ForecastResultRequest; +import org.opensearch.jobscheduler.spi.LockModel; +import org.opensearch.jobscheduler.spi.utils.LockService; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.JobProcessor; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.ResultRequest; + +public class ForecastJobProcessor extends + JobProcessor { + + private static final Logger log = LogManager.getLogger(ForecastJobProcessor.class); + + private static ForecastJobProcessor INSTANCE; + + public static ForecastJobProcessor getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (ForecastJobProcessor.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new ForecastJobProcessor(); + return INSTANCE; + } + } + + private ForecastJobProcessor() { + // Singleton class, use getJobRunnerInstance method instead of constructor + super(AnalysisType.FORECAST, TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME, ForecastResultAction.INSTANCE); + } + + public void registerSettings(Settings settings) { + super.registerSettings(settings, ForecastSettings.FORECAST_MAX_RETRY_FOR_END_RUN_EXCEPTION); + } + + @Override + protected ResultRequest createResultRequest(String configId, long start, long end) { + return new ForecastResultRequest(configId, start, end); + } + + @Override + protected void validateResultIndexAndRunJob( + Job jobParameter, + LockService lockService, + LockModel lock, + Instant executionStartTime, + Instant executionEndTime, + String configId, + String user, + List roles, + ExecuteForecastResultResponseRecorder recorder, + Config detector + ) { + ActionListener listener = ActionListener.wrap(r -> { log.debug("Result index is valid"); }, e -> { + Exception exception = new EndRunException(configId, e.getMessage(), false); + handleException(jobParameter, lockService, lock, executionStartTime, executionEndTime, exception, recorder, detector); + }); + String resultIndex = jobParameter.getCustomResultIndex(); + if (resultIndex == null) { + indexManagement.validateDefaultResultIndexForBackendJob(configId, user, roles, () -> { + listener.onResponse(true); + runJob(jobParameter, lockService, lock, executionStartTime, executionEndTime, configId, user, roles, recorder, detector); + }, listener); + } else { + indexManagement.validateCustomIndexForBackendJob(resultIndex, configId, user, roles, () -> { + listener.onResponse(true); + runJob(jobParameter, lockService, lock, executionStartTime, executionEndTime, configId, user, roles, recorder, detector); + }, listener); + } + } +} diff --git a/src/main/java/org/opensearch/forecast/ForecastProfileRunner.java b/src/main/java/org/opensearch/forecast/ForecastProfileRunner.java new file mode 100644 index 000000000..10c1301fd --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ForecastProfileRunner.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast; + +import org.opensearch.client.Client; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskProfile; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.model.ForecasterProfile; +import org.opensearch.forecast.settings.ForecastNumericSetting; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.forecast.transport.ForecastProfileAction; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.ProfileRunner; +import org.opensearch.timeseries.model.ProfileName; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +public class ForecastProfileRunner extends + ProfileRunner { + + public ForecastProfileRunner( + Client client, + SecurityClientUtil clientUtil, + NamedXContentRegistry xContentRegistry, + DiscoveryNodeFilterer nodeFilter, + long requiredSamples, + TransportService transportService, + ForecastTaskManager forecastTaskManager, + ForecastTaskProfileRunner taskProfileRunner + ) { + super( + client, + clientUtil, + xContentRegistry, + nodeFilter, + requiredSamples, + transportService, + forecastTaskManager, + AnalysisType.FORECAST, + ForecastTaskType.REALTIME_TASK_TYPES, + ForecastTaskType.RUN_ONCE_TASK_TYPES, + ForecastNumericSetting.maxCategoricalFields(), + ProfileName.FORECAST_TASK, + ForecastProfileAction.INSTANCE, + Forecaster::parse, + taskProfileRunner + ); + } + + @Override + protected ForecasterProfile.Builder createProfileBuilder() { + return new ForecasterProfile.Builder(); + } + +} diff --git a/src/main/java/org/opensearch/forecast/ForecastTaskProfileRunner.java b/src/main/java/org/opensearch/forecast/ForecastTaskProfileRunner.java new file mode 100644 index 000000000..f7deb5578 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ForecastTaskProfileRunner.java @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskProfile; +import org.opensearch.timeseries.TaskProfileRunner; + +public class ForecastTaskProfileRunner implements TaskProfileRunner { + + @Override + public void getTaskProfile(ForecastTask configLevelTask, ActionListener listener) { + // return null since forecasting have no in-memory task profiles as AD + listener.onResponse(null); + } + +} diff --git a/src/main/java/org/opensearch/forecast/caching/ForecastCacheBuffer.java b/src/main/java/org/opensearch/forecast/caching/ForecastCacheBuffer.java new file mode 100644 index 000000000..74fb06d89 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/caching/ForecastCacheBuffer.java @@ -0,0 +1,57 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.caching; + +import java.time.Clock; +import java.time.Duration; + +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastCheckpointDao; +import org.opensearch.forecast.ratelimit.ForecastCheckpointMaintainWorker; +import org.opensearch.forecast.ratelimit.ForecastCheckpointWriteWorker; +import org.opensearch.timeseries.MemoryTracker; +import org.opensearch.timeseries.MemoryTracker.Origin; +import org.opensearch.timeseries.caching.CacheBuffer; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +public class ForecastCacheBuffer extends + CacheBuffer { + + public ForecastCacheBuffer( + int minimumCapacity, + Clock clock, + MemoryTracker memoryTracker, + int checkpointIntervalHrs, + Duration modelTtl, + long memoryConsumptionPerEntity, + ForecastCheckpointWriteWorker checkpointWriteQueue, + ForecastCheckpointMaintainWorker checkpointMaintainQueue, + String configId, + long intervalSecs + ) { + super( + minimumCapacity, + clock, + memoryTracker, + checkpointIntervalHrs, + modelTtl, + memoryConsumptionPerEntity, + checkpointWriteQueue, + checkpointMaintainQueue, + configId, + intervalSecs, + Origin.REAL_TIME_FORECASTER + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/caching/ForecastCacheProvider.java b/src/main/java/org/opensearch/forecast/caching/ForecastCacheProvider.java new file mode 100644 index 000000000..f93982cc2 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/caching/ForecastCacheProvider.java @@ -0,0 +1,14 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.caching; + +import org.opensearch.timeseries.caching.CacheProvider; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +public class ForecastCacheProvider extends CacheProvider { + +} diff --git a/src/main/java/org/opensearch/forecast/caching/ForecastPriorityCache.java b/src/main/java/org/opensearch/forecast/caching/ForecastPriorityCache.java new file mode 100644 index 000000000..2d79e8224 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/caching/ForecastPriorityCache.java @@ -0,0 +1,117 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.caching; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_DEDICATED_CACHE_SIZE; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_MODEL_MAX_SIZE_PERCENTAGE; + +import java.time.Clock; +import java.time.Duration; +import java.util.ArrayDeque; +import java.util.Optional; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastCheckpointDao; +import org.opensearch.forecast.ratelimit.ForecastCheckpointMaintainWorker; +import org.opensearch.forecast.ratelimit.ForecastCheckpointWriteWorker; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.MemoryTracker; +import org.opensearch.timeseries.MemoryTracker.Origin; +import org.opensearch.timeseries.caching.PriorityCache; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.model.Config; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +public class ForecastPriorityCache extends + PriorityCache { + private ForecastCheckpointWriteWorker checkpointWriteQueue; + private ForecastCheckpointMaintainWorker checkpointMaintainQueue; + + public ForecastPriorityCache( + ForecastCheckpointDao checkpointDao, + int hcDedicatedCacheSize, + Setting checkpointTtl, + int maxInactiveStates, + MemoryTracker memoryTracker, + int numberOfTrees, + Clock clock, + ClusterService clusterService, + Duration modelTtl, + ThreadPool threadPool, + String threadPoolName, + int maintenanceFreqConstant, + Settings settings, + Setting checkpointSavingFreq, + ForecastCheckpointWriteWorker checkpointWriteQueue, + ForecastCheckpointMaintainWorker checkpointMaintainQueue + ) { + super( + checkpointDao, + hcDedicatedCacheSize, + checkpointTtl, + maxInactiveStates, + memoryTracker, + numberOfTrees, + clock, + clusterService, + modelTtl, + threadPool, + threadPoolName, + maintenanceFreqConstant, + settings, + checkpointSavingFreq, + Origin.REAL_TIME_FORECASTER, + FORECAST_DEDICATED_CACHE_SIZE, + FORECAST_MODEL_MAX_SIZE_PERCENTAGE + ); + + this.checkpointWriteQueue = checkpointWriteQueue; + this.checkpointMaintainQueue = checkpointMaintainQueue; + } + + @Override + protected ForecastCacheBuffer createEmptyCacheBuffer(Config config, long requiredMemory) { + return new ForecastCacheBuffer( + config.isHighCardinality() ? hcDedicatedCacheSize : 1, + clock, + memoryTracker, + checkpointIntervalHrs, + modelTtl, + requiredMemory, + checkpointWriteQueue, + checkpointMaintainQueue, + config.getId(), + config.getIntervalInSeconds() + ); + } + + @Override + protected ModelState createEmptyModelState(String modelId, String forecasterId) { + return new ModelState<>( + null, + modelId, + forecasterId, + ModelManager.ModelType.RCFCASTER.getName(), + clock, + 0, + Optional.empty(), + new ArrayDeque<>() + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/cluster/diskcleanup/ForecastCheckpointIndexRetention.java b/src/main/java/org/opensearch/forecast/cluster/diskcleanup/ForecastCheckpointIndexRetention.java new file mode 100644 index 000000000..92c1ac1e7 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/cluster/diskcleanup/ForecastCheckpointIndexRetention.java @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.cluster.diskcleanup; + +import java.time.Clock; +import java.time.Duration; + +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.timeseries.cluster.diskcleanup.BaseModelCheckpointIndexRetention; +import org.opensearch.timeseries.cluster.diskcleanup.IndexCleanup; + +public class ForecastCheckpointIndexRetention extends BaseModelCheckpointIndexRetention { + + public ForecastCheckpointIndexRetention(Duration defaultCheckpointTtl, Clock clock, IndexCleanup indexCleanup) { + super(defaultCheckpointTtl, clock, indexCleanup, ForecastIndex.CHECKPOINT.getIndexName()); + } + +} diff --git a/src/main/java/org/opensearch/forecast/constant/ForecastCommonName.java b/src/main/java/org/opensearch/forecast/constant/ForecastCommonName.java index 8edaf2d2b..f9dc48985 100644 --- a/src/main/java/org/opensearch/forecast/constant/ForecastCommonName.java +++ b/src/main/java/org/opensearch/forecast/constant/ForecastCommonName.java @@ -45,4 +45,9 @@ public class ForecastCommonName { // Used in stats API // ====================================== public static final String FORECASTER_ID_KEY = "forecaster_id"; + + // ====================================== + // Historical forecasters + // ====================================== + public static final String FORECAST_TASK = "forecast_task"; } diff --git a/src/main/java/org/opensearch/forecast/indices/ForecastIndexManagement.java b/src/main/java/org/opensearch/forecast/indices/ForecastIndexManagement.java index e7d3f3252..bc2798773 100644 --- a/src/main/java/org/opensearch/forecast/indices/ForecastIndexManagement.java +++ b/src/main/java/org/opensearch/forecast/indices/ForecastIndexManagement.java @@ -22,17 +22,26 @@ import java.io.IOException; import java.util.EnumMap; +import java.util.List; +import java.util.Map; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.admin.indices.alias.get.GetAliasesRequest; +import org.opensearch.action.admin.indices.alias.get.GetAliasesResponse; import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.support.IndicesOptions; import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.AliasMetadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.InjectSecurity; import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; @@ -40,6 +49,8 @@ import org.opensearch.forecast.model.ForecastResult; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.function.ExecutorFunction; import org.opensearch.timeseries.indices.IndexManagement; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; @@ -268,4 +279,104 @@ public void initCustomResultIndexDirectly(String resultIndex, ActionListener void validateDefaultResultIndexForBackendJob( + String configId, + String user, + List roles, + ExecutorFunction function, + ActionListener listener + ) { + if (doesAliasExist(ForecastCommonName.FORECAST_RESULT_INDEX_ALIAS)) { + handleExistingIndex(configId, user, roles, function, listener); + } else { + initDefaultResultIndex(configId, user, roles, function, listener); + } + } + + private void handleExistingIndex( + String configId, + String user, + List roles, + ExecutorFunction function, + ActionListener listener + ) { + GetAliasesRequest getAliasRequest = new GetAliasesRequest() + .aliases(ForecastCommonName.FORECAST_RESULT_INDEX_ALIAS) + .indicesOptions(IndicesOptions.lenientExpandOpenHidden()); + + adminClient.indices().getAliases(getAliasRequest, ActionListener.wrap(getAliasResponse -> { + String concreteIndex = getConcreteIndexFromAlias(getAliasResponse); + if (concreteIndex == null) { + listener.onFailure(new EndRunException("Result index alias mapping is empty", false)); + return; + } + + if (!isValidResultIndexMapping(concreteIndex)) { + listener.onFailure(new EndRunException("Result index mapping is not correct", false)); + return; + } + + executeWithSecurityContext(configId, user, roles, function, listener, concreteIndex); + + }, listener::onFailure)); + } + + private String getConcreteIndexFromAlias(GetAliasesResponse getAliasResponse) { + for (Map.Entry> entry : getAliasResponse.getAliases().entrySet()) { + if (!entry.getValue().isEmpty()) { + return entry.getKey(); + } + } + return null; + } + + private void initDefaultResultIndex( + String configId, + String user, + List roles, + ExecutorFunction function, + ActionListener listener + ) { + initDefaultResultIndexDirectly(ActionListener.wrap(response -> { + if (response.isAcknowledged()) { + executeWithSecurityContext(configId, user, roles, function, listener, ForecastCommonName.FORECAST_RESULT_INDEX_ALIAS); + } else { + String error = "Creating result index with mappings call not acknowledged"; + logger.error(error); + listener.onFailure(new TimeSeriesException(error)); + } + }, exception -> { + if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { + executeWithSecurityContext(configId, user, roles, function, listener, ForecastCommonName.FORECAST_RESULT_INDEX_ALIAS); + } else { + listener.onFailure(exception); + } + })); + } + + private void executeWithSecurityContext( + String securityLogId, + String user, + List roles, + ExecutorFunction function, + ActionListener listener, + String indexOrAlias + ) { + try (InjectSecurity injectSecurity = new InjectSecurity(securityLogId, settings, client.threadPool().getThreadContext())) { + injectSecurity.inject(user, roles); + ActionListener wrappedListener = ActionListener.wrap(listener::onResponse, e -> { + injectSecurity.close(); + listener.onFailure(e); + }); + validateResultIndexAndExecute(indexOrAlias, () -> { + injectSecurity.close(); + function.execute(); + }, true, wrappedListener); + } catch (Exception e) { + logger.error("Failed to validate custom index for backend job " + securityLogId, e); + listener.onFailure(e); + } + } + } diff --git a/src/main/java/org/opensearch/forecast/ml/ForecastCheckpointDao.java b/src/main/java/org/opensearch/forecast/ml/ForecastCheckpointDao.java new file mode 100644 index 000000000..dc6fa4d50 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ml/ForecastCheckpointDao.java @@ -0,0 +1,404 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.ml; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.time.Clock; +import java.time.Instant; +import java.time.ZoneOffset; +import java.util.Base64; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +import org.apache.commons.pool2.impl.GenericObjectPool; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.support.IndicesOptions; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.MatchQueryBuilder; +import org.opensearch.index.reindex.DeleteByQueryAction; +import org.opensearch.index.reindex.DeleteByQueryRequest; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.util.ClientUtil; + +import com.amazon.randomcutforest.parkservices.RCFCaster; +import com.amazon.randomcutforest.parkservices.state.RCFCasterMapper; +import com.amazon.randomcutforest.parkservices.state.RCFCasterState; +import com.google.gson.Gson; + +import io.protostuff.LinkedBuffer; +import io.protostuff.ProtostuffIOUtil; +import io.protostuff.Schema; + +/** + * The ForecastCheckpointDao class implements all the functionality required for fetching, updating and + * removing forecast checkpoints. + * + */ +public class ForecastCheckpointDao extends CheckpointDao { + public static final Logger logger = LogManager.getLogger(ForecastCheckpointDao.class); + + static final String NOT_ABLE_TO_DELETE_CHECKPOINT_MSG = "Cannot delete all checkpoints of forecaster"; + + RCFCasterMapper mapper; + private Schema rcfCasterSchema; + + public ForecastCheckpointDao( + Client client, + ClientUtil clientUtil, + Gson gson, + int maxCheckpointBytes, + GenericObjectPool serializeRCFBufferPool, + int serializeRCFBufferSize, + ForecastIndexManagement indexUtil, + RCFCasterMapper mapper, + Schema rcfCasterSchema, + Clock clock + ) { + super( + client, + clientUtil, + ForecastIndex.CHECKPOINT.getIndexName(), + gson, + maxCheckpointBytes, + serializeRCFBufferPool, + serializeRCFBufferSize, + indexUtil, + clock + ); + this.mapper = mapper; + this.rcfCasterSchema = rcfCasterSchema; + } + + /** + * Puts a RCFCaster model checkpoint in the storage. Used in single-stream forecasting. + * + * @param modelId id of the model + * @param caster the RCFCaster model + * @param listener onResponse is called with null when the operation is completed + */ + public void putCasterCheckpoint(String modelId, RCFCaster caster, ActionListener listener) { + Map source = new HashMap<>(); + Optional modelCheckpoint = toCheckpoint(Optional.of(caster)); + if (modelCheckpoint.isPresent()) { + source.put(CommonName.FIELD_MODEL, modelCheckpoint.get()); + source.put(CommonName.TIMESTAMP, clock.instant().atZone(ZoneOffset.UTC)); + source.put(CommonName.TIMESTAMP, clock.instant().atZone(ZoneOffset.UTC)); + source.put(CommonName.SCHEMA_VERSION_FIELD, indexUtil.getSchemaVersion(ForecastIndex.CHECKPOINT)); + putModelCheckpoint(modelId, source, listener); + } else { + listener.onFailure(new RuntimeException("Fail to create checkpoint to save")); + } + } + + private Optional toCheckpoint(Optional caster) { + if (caster.isEmpty()) { + return Optional.empty(); + } + Optional checkpoint = Optional.empty(); + Map.Entry result = checkoutOrNewBuffer(); + LinkedBuffer buffer = result.getKey(); + boolean needCheckin = result.getValue(); + try { + checkpoint = toCheckpoint(caster, buffer); + } catch (Exception e) { + logger.error("Failed to serialize model", e); + if (needCheckin) { + try { + serializeRCFBufferPool.invalidateObject(buffer); + needCheckin = false; + } catch (Exception x) { + logger.warn("Failed to invalidate buffer", x); + } + try { + checkpoint = toCheckpoint(caster, LinkedBuffer.allocate(serializeRCFBufferSize)); + } catch (Exception ex) { + logger.warn("Failed to generate checkpoint", ex); + } + } + } finally { + if (needCheckin) { + try { + serializeRCFBufferPool.returnObject(buffer); + } catch (Exception e) { + logger.warn("Failed to return buffer to pool", e); + } + } + } + return checkpoint; + } + + private Optional toCheckpoint(Optional caster, LinkedBuffer buffer) { + if (caster.isEmpty()) { + return Optional.empty(); + } + try { + byte[] bytes = AccessController.doPrivileged((PrivilegedAction) () -> { + RCFCasterState casterState = mapper.toState(caster.get()); + return ProtostuffIOUtil.toByteArray(casterState, rcfCasterSchema, buffer); + }); + return Optional.ofNullable(Base64.getEncoder().encodeToString(bytes)); + } finally { + buffer.clear(); + } + } + + /** + * Prepare for index request using the contents of the given model state. Used in HC forecasting. + * @param modelState an entity model state + * @return serialized JSON map or empty map if the state is too bloated + * @throws IOException when serialization fails + */ + @Override + public Map toIndexSource(ModelState modelState) throws IOException { + Map source = new HashMap<>(); + Optional model = modelState.getModel(); + + Optional serializedModel = toCheckpoint(model); + if (serializedModel.isPresent() && serializedModel.get().length() <= maxCheckpointBytes) { + // we cannot pass Optional as OpenSearch does not know how to serialize an Optional value + source.put(CommonName.FIELD_MODEL, serializedModel.get()); + } else { + logger + .warn( + new ParameterizedMessage( + "[{}]'s model is empty or too large: [{}] bytes", + modelState.getModelId(), + serializedModel.isPresent() ? serializedModel.get().length() : 0 + ) + ); + } + Optional samples = toCheckpoint(modelState.getSamples()); + if (samples.isPresent()) { + source.put(CommonName.SAMPLE_QUEUE, samples.get()); + } + // if there are no samples and no model, no need to index as other information are meta data + if (!source.containsKey(CommonName.SAMPLE_QUEUE) && !source.containsKey(CommonName.FIELD_MODEL)) { + logger.info("nothing to save for [{}]", modelState.getModelId()); + return source; + } + + source.put(ForecastCommonName.FORECASTER_ID_KEY, modelState.getConfigId()); + source.put(CommonName.TIMESTAMP, clock.instant().atZone(ZoneOffset.UTC)); + source.put(CommonName.SCHEMA_VERSION_FIELD, indexUtil.getSchemaVersion(ForecastIndex.CHECKPOINT)); + + Optional entity = modelState.getEntity(); + if (entity.isPresent()) { + source.put(CommonName.ENTITY_KEY, entity.get()); + } + return source; + } + + private void deserializeRCFCasterModel(GetResponse response, String rcfModelId, ActionListener> listener) { + Object model = null; + if (response.isExists()) { + try { + model = response.getSource().get(CommonName.FIELD_MODEL); + listener.onResponse(Optional.ofNullable(toRCFCaster((String) model))); + + } catch (Exception e) { + logger.error(new ParameterizedMessage("Unexpected error when deserializing [{}]", rcfModelId), e); + listener.onResponse(Optional.empty()); + } + } else { + listener.onResponse(Optional.empty()); + } + } + + RCFCaster toRCFCaster(String checkpoint) { + RCFCaster rcfCaster = null; + if (checkpoint != null && checkpoint.length() > 0) { + try { + byte[] bytes = Base64.getDecoder().decode(checkpoint); + RCFCasterState state = rcfCasterSchema.newMessage(); + AccessController.doPrivileged((PrivilegedAction) () -> { + ProtostuffIOUtil.mergeFrom(bytes, state, rcfCasterSchema); + return null; + }); + rcfCaster = mapper.toModel(state); + } catch (RuntimeException e) { + logger.error("Failed to deserialize RCFCaster model", e); + } + } + return rcfCaster; + } + + /** + * Returns to listener the checkpoint for the RCFCaster model. Used in single-stream forecasting. + * + * @param modelId id of the model + * @param listener onResponse is called with the model checkpoint, or empty for no such model + */ + public void getCasterModel(String modelId, ActionListener> listener) { + clientUtil + .asyncRequest( + new GetRequest(indexName, modelId), + client::get, + ActionListener.wrap(response -> deserializeRCFCasterModel(response, modelId, listener), exception -> { + // expected exception, don't print stack trace + if (exception instanceof IndexNotFoundException) { + listener.onResponse(Optional.empty()); + } else { + listener.onFailure(exception); + } + }) + ); + } + + /** + * Load json checkpoint into models. Used in HC forecasting. + * + * @param checkpoint json checkpoint contents + * @param modelId Model Id + * @return a pair of entity model and its last checkpoint time; or empty if + * the raw checkpoint is too large + */ + @Override + protected ModelState fromEntityModelCheckpoint(Map checkpoint, String modelId, String configId) { + try { + return AccessController.doPrivileged((PrivilegedAction>) () -> { + + RCFCaster rcfCaster = loadRCFCaster(checkpoint, modelId); + + Entity entity = null; + Object serializedEntity = checkpoint.get(CommonName.ENTITY_KEY); + if (serializedEntity != null) { + try { + entity = Entity.fromJsonArray(serializedEntity); + } catch (Exception e) { + logger.error(new ParameterizedMessage("fail to parse entity", serializedEntity), e); + } + } + + ModelState modelState = new ModelState( + rcfCaster, + modelId, + configId, + ModelManager.ModelType.RCFCASTER.getName(), + clock, + 0, + Optional.ofNullable(entity), + loadSampleQueue(checkpoint, modelId) + ); + + modelState.setLastCheckpointTime(loadTimestamp(checkpoint, modelId)); + + return modelState; + }); + } catch (Exception e) { + logger.warn("Exception while deserializing checkpoint " + modelId, e); + // checkpoint corrupted (e.g., a checkpoint not recognized by current code + // due to bugs). Better redo training. + return null; + } + } + + /** + * Delete checkpoints associated with a forecaster. Used in HC forecaster. + * @param forecasterId Forecaster Id + */ + public void deleteModelCheckpointByForecasterId(String forecasterId) { + // A bulk delete request is performed for each batch of matching documents. If a + // search or bulk request is rejected, the requests are retried up to 10 times, + // with exponential back off. If the maximum retry limit is reached, processing + // halts and all failed requests are returned in the response. Any delete + // requests that completed successfully still stick, they are not rolled back. + DeleteByQueryRequest deleteRequest = new DeleteByQueryRequest(indexName) + .setQuery(new MatchQueryBuilder(ForecastCommonName.FORECASTER_ID_KEY, forecasterId)) + .setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN) + .setAbortOnVersionConflict(false) // when current delete happens, previous might not finish. + // Retry in this case + .setRequestsPerSecond(500); // throttle delete requests + logger.info("Delete checkpoints of forecaster {}", forecasterId); + client.execute(DeleteByQueryAction.INSTANCE, deleteRequest, ActionListener.wrap(response -> { + if (response.isTimedOut() || !response.getBulkFailures().isEmpty() || !response.getSearchFailures().isEmpty()) { + logFailure(response, forecasterId); + } + // can return 0 docs get deleted because: + // 1) we cannot find matching docs + // 2) bad stats from OpenSearch. In this case, docs are deleted, but + // OpenSearch says deleted is 0. + logger.info("{} " + CheckpointDao.DOC_GOT_DELETED_LOG_MSG, response.getDeleted()); + }, exception -> { + if (exception instanceof IndexNotFoundException) { + logger.info(CheckpointDao.INDEX_DELETED_LOG_MSG + " {}", forecasterId); + } else { + // Gonna eventually delete in daily cron. + logger.error(NOT_ABLE_TO_DELETE_CHECKPOINT_MSG, exception); + } + })); + } + + @Override + protected DeleteByQueryRequest createDeleteCheckpointRequest(String configId) { + return new DeleteByQueryRequest(indexName) + .setQuery(new MatchQueryBuilder(ForecastCommonName.FORECASTER_ID_KEY, configId)) + .setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN) + .setAbortOnVersionConflict(false) // when current delete happens, previous might not finish. + // Retry in this case + .setRequestsPerSecond(500); // throttle delete requests + } + + @Override + protected ModelState fromSingleStreamModelCheckpoint(Map checkpoint, String modelId, String configId) { + + return AccessController.doPrivileged((PrivilegedAction>) () -> { + + RCFCaster rcfCaster = loadRCFCaster(checkpoint, modelId); + + ModelState modelState = new ModelState( + rcfCaster, + modelId, + configId, + ModelManager.ModelType.RCFCASTER.getName(), + clock, + 0, + Optional.empty(), + loadSampleQueue(checkpoint, modelId) + ); + + modelState.setLastCheckpointTime(loadTimestamp(checkpoint, modelId)); + + return modelState; + }); + } + + private RCFCaster loadRCFCaster(Map checkpoint, String modelId) { + String model = (String) checkpoint.get(CommonName.FIELD_MODEL); + if (model == null || model.length() > maxCheckpointBytes) { + logger.warn(new ParameterizedMessage("[{}]'s model too large: [{}] bytes", modelId, model.length())); + return null; + } + return toRCFCaster(model); + } + + private Instant loadTimestamp(Map checkpoint, String modelId) { + String lastCheckpointTimeString = (String) (checkpoint.get(CommonName.TIMESTAMP)); + return Instant.parse(lastCheckpointTimeString); + } +} diff --git a/src/main/java/org/opensearch/forecast/ml/ForecastColdStart.java b/src/main/java/org/opensearch/forecast/ml/ForecastColdStart.java new file mode 100644 index 000000000..0070ad675 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ml/ForecastColdStart.java @@ -0,0 +1,161 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.ml; + +import java.time.Clock; +import java.time.Duration; +import java.util.List; +import java.util.Optional; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.ratelimit.ForecastCheckpointWriteWorker; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.ml.ModelColdStart; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.settings.TimeSeriesSettings; + +import com.amazon.randomcutforest.config.ForestMode; +import com.amazon.randomcutforest.config.Precision; +import com.amazon.randomcutforest.config.TransformMethod; +import com.amazon.randomcutforest.parkservices.RCFCaster; +import com.amazon.randomcutforest.parkservices.config.Calibration; + +public class ForecastColdStart extends + ModelColdStart { + + private static final Logger logger = LogManager.getLogger(ForecastColdStart.class); + + public ForecastColdStart( + Clock clock, + ThreadPool threadPool, + NodeStateManager nodeStateManager, + int rcfSampleSize, + int numberOfTrees, + int numMinSamples, + SearchFeatureDao searchFeatureDao, + double thresholdMinPvalue, + FeatureManager featureManager, + Duration modelTtl, + ForecastCheckpointWriteWorker checkpointWriteWorker, + int coolDownMinutes, + long rcfSeed, + int defaultTrainSamples, + int maxRoundofColdStart + ) { + // 1 means we sample all real data if possible + super( + modelTtl, + coolDownMinutes, + clock, + threadPool, + numMinSamples, + checkpointWriteWorker, + rcfSeed, + numberOfTrees, + rcfSampleSize, + thresholdMinPvalue, + nodeStateManager, + 1, + defaultTrainSamples, + searchFeatureDao, + featureManager, + maxRoundofColdStart, + TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME, + AnalysisType.FORECAST + ); + } + + @Override + protected List trainModelFromDataSegments( + List pointSamples, + Optional entity, + ModelState modelState, + Config config, + String taskId + ) { + if (pointSamples == null || pointSamples.size() == 0) { + logger.info("Return early since data points must not be empty."); + return null; + } + + double[] firstPoint = pointSamples.get(0).getValueList(); + if (firstPoint == null || firstPoint.length == 0) { + logger.info("Return early since data points must not be empty."); + return null; + } + + int shingleSize = config.getShingleSize(); + int forecastHorizon = ((Forecaster) config).getHorizon(); + int dimensions = firstPoint.length * shingleSize; + + RCFCaster.Builder casterBuilder = RCFCaster + .builder() + .dimensions(dimensions) + .numberOfTrees(numberOfTrees) + .shingleSize(shingleSize) + .sampleSize(rcfSampleSize) + .internalShinglingEnabled(true) + .precision(Precision.FLOAT_32) + .anomalyRate(1 - this.thresholdMinPvalue) + .outputAfter(Math.max(shingleSize, numMinSamples)) + .calibration(Calibration.MINIMAL) + .timeDecay(config.getTimeDecay()) + .parallelExecutionEnabled(false) + .boundingBoxCacheFraction(TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO) + // the following affects the moving average in many of the transformations + // the 0.02 corresponds to a half life of 1/0.02 = 50 observations + // this is different from the timeDecay() of RCF; however it is a similar + // concept + .transformDecay(config.getTimeDecay()) + .forecastHorizon(forecastHorizon) + .initialAcceptFraction(initialAcceptFraction) + // normalize transform is required to deal with trend change in forecasting + .transformMethod(TransformMethod.NORMALIZE) + // for forecasting, we don't support other mode + .forestMode(ForestMode.STANDARD); + + casterBuilder = applyImputationMethod(config, casterBuilder); + + if (rcfSeed > 0) { + casterBuilder.randomSeed(rcfSeed); + } + + RCFCaster caster = casterBuilder.build(); + + for (int i = 0; i < pointSamples.size(); i++) { + Sample dataSample = pointSamples.get(i); + double[] dataValue = dataSample.getValueList(); + caster.process(dataValue, dataSample.getDataEndTime().getEpochSecond()); + } + + modelState.setModel(caster); + modelState.setLastUsedTime(clock.instant()); + // save to checkpoint for real time cold start that has no taskId + if (null == taskId) { + checkpointWriteWorker.write(modelState, true, RequestPriority.MEDIUM); + } + return pointSamples; + } +} diff --git a/src/main/java/org/opensearch/forecast/ml/ForecastModelManager.java b/src/main/java/org/opensearch/forecast/ml/ForecastModelManager.java new file mode 100644 index 000000000..438c9bdde --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ml/ForecastModelManager.java @@ -0,0 +1,66 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.ml; + +import java.time.Clock; +import java.util.Locale; + +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.ratelimit.ForecastCheckpointWriteWorker; +import org.opensearch.timeseries.MemoryTracker; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.ml.ModelManager; + +import com.amazon.randomcutforest.RandomCutForest; +import com.amazon.randomcutforest.parkservices.AnomalyDescriptor; +import com.amazon.randomcutforest.parkservices.ForecastDescriptor; +import com.amazon.randomcutforest.parkservices.RCFCaster; + +public class ForecastModelManager extends + ModelManager { + + public ForecastModelManager( + ForecastCheckpointDao checkpointDao, + Clock clock, + int rcfNumTrees, + int rcfNumSamplesInTree, + int rcfNumMinSamples, + ForecastColdStart entityColdStarter, + MemoryTracker memoryTracker, + FeatureManager featureManager + ) { + super(rcfNumTrees, rcfNumSamplesInTree, rcfNumMinSamples, entityColdStarter, memoryTracker, clock, featureManager, checkpointDao); + } + + @Override + protected RCFCasterResult createEmptyResult() { + return new RCFCasterResult(null, 0, 0, 0); + } + + @Override + protected RCFCasterResult toResult(RandomCutForest forecast, RCFDescriptor castDescriptor) { + if (castDescriptor instanceof ForecastDescriptor) { + ForecastDescriptor forecastDescriptor = (ForecastDescriptor) castDescriptor; + // Use forecastDescriptor in the rest of your method + return new RCFCasterResult( + forecastDescriptor.getTimedForecast().rangeVector, + forecastDescriptor.getDataConfidence(), + forecast.getTotalUpdates(), + forecastDescriptor.getRCFScore() + ); + } else { + throw new IllegalArgumentException(String.format(Locale.ROOT, "Unsupported type of AnomalyDescriptor : %s", castDescriptor)); + } + } +} diff --git a/src/main/java/org/opensearch/forecast/ml/RCFCasterResult.java b/src/main/java/org/opensearch/forecast/ml/RCFCasterResult.java new file mode 100644 index 000000000..3584c7203 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ml/RCFCasterResult.java @@ -0,0 +1,83 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.ml; + +import java.time.Instant; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.timeseries.ml.IntermediateResult; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.FeatureData; + +import com.amazon.randomcutforest.returntypes.RangeVector; + +public class RCFCasterResult extends IntermediateResult { + private final RangeVector forecast; + private final double dataQuality; + + public RCFCasterResult(RangeVector forecast, double dataQuality, long totalUpdates, double rcfScore) { + super(totalUpdates, rcfScore); + this.forecast = forecast; + this.dataQuality = dataQuality; + } + + public RangeVector getForecast() { + return forecast; + } + + public double getDataQuality() { + return dataQuality; + } + + @Override + public List toIndexableResults( + Config forecaster, + Instant dataStartInstant, + Instant dataEndInstant, + Instant executionStartInstant, + Instant executionEndInstant, + List featureData, + Optional entity, + Integer schemaVersion, + String modelId, + String taskId, + String error + ) { + if (forecast.values == null || forecast.values.length == 0) { + return Collections.emptyList(); + } + return ForecastResult + .fromRawRCFCasterResult( + forecaster.getId(), + forecaster.getIntervalInMilliseconds(), + dataQuality, + featureData, + dataStartInstant, + dataEndInstant, + executionStartInstant, + executionEndInstant, + error, + entity, + forecaster.getUser(), + schemaVersion, + modelId, + forecast.values, + forecast.upper, + forecast.lower, + taskId + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/model/FilterBy.java b/src/main/java/org/opensearch/forecast/model/FilterBy.java new file mode 100644 index 000000000..d2be61012 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/model/FilterBy.java @@ -0,0 +1,11 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.model; + +public enum FilterBy { + BUILD_IN_QUERY, + CUSTOM_QUERY +} diff --git a/src/main/java/org/opensearch/forecast/model/ForecastResult.java b/src/main/java/org/opensearch/forecast/model/ForecastResult.java index 1ce75ff63..34ff4da66 100644 --- a/src/main/java/org/opensearch/forecast/model/ForecastResult.java +++ b/src/main/java/org/opensearch/forecast/model/ForecastResult.java @@ -68,8 +68,9 @@ public class ForecastResult extends IndexableResult { private final Instant forecastDataEndTime; private final Integer horizonIndex; protected final Double dataQuality; + private final String entityId; - // used when indexing exception or error or an empty result + // used when indexing exception or error or a feature only result public ForecastResult( String forecasterId, String taskId, @@ -81,8 +82,7 @@ public ForecastResult( String error, Optional entity, User user, - Integer schemaVersion, - String modelId + Integer schemaVersion ) { this( forecasterId, @@ -97,7 +97,6 @@ public ForecastResult( entity, user, schemaVersion, - modelId, null, null, null, @@ -121,7 +120,6 @@ public ForecastResult( Optional entity, User user, Integer schemaVersion, - String modelId, String featureId, Float forecastValue, Float lowerBound, @@ -141,7 +139,6 @@ public ForecastResult( entity, user, schemaVersion, - modelId, taskId ); this.featureId = featureId; @@ -149,10 +146,11 @@ public ForecastResult( this.forecastValue = forecastValue; this.lowerBound = lowerBound; this.upperBound = upperBound; - this.confidenceIntervalWidth = lowerBound != null && upperBound != null ? Math.abs(upperBound - lowerBound) : Float.NaN; + this.confidenceIntervalWidth = safeAbsoluteDifference(lowerBound, upperBound); this.forecastDataStartTime = forecastDataStartTime; this.forecastDataEndTime = forecastDataEndTime; this.horizonIndex = horizonIndex; + this.entityId = getEntityId(entity, configId); } public static List fromRawRCFCasterResult( @@ -175,9 +173,13 @@ public static List fromRawRCFCasterResult( String taskId ) { int inputLength = featureData.size(); - int numberOfForecasts = forecastsValues.length / inputLength; + int numberOfForecasts = 0; + if (forecastsValues != null) { + numberOfForecasts = forecastsValues.length / inputLength; + } - List convertedForecastValues = new ArrayList<>(numberOfForecasts); + // +1 for actual value + List convertedForecastValues = new ArrayList<>(numberOfForecasts + 1); // store feature data and forecast value separately for easy query on feature data // we can join them using forecasterId, entityId, and executionStartTime/executionEndTime @@ -196,7 +198,6 @@ public static List fromRawRCFCasterResult( entity, user, schemaVersion, - modelId, null, null, null, @@ -219,22 +220,22 @@ public static List fromRawRCFCasterResult( taskId, dataQuality, null, - null, - null, + dataStartTime, + dataEndTime, executionStartTime, executionEndTime, error, entity, user, schemaVersion, - modelId, featureData.get(j).getFeatureId(), forecastsValues[k], forecastsLowers[k], forecastsUppers[k], forecastDataStartTime, forecastDataEndTime, - i + // horizon starts from 1 + i + 1 ) ); } @@ -255,6 +256,7 @@ public ForecastResult(StreamInput input) throws IOException { this.forecastDataStartTime = input.readOptionalInstant(); this.forecastDataEndTime = input.readOptionalInstant(); this.horizonIndex = input.readOptionalInt(); + this.entityId = input.readOptionalString(); } @Override @@ -286,14 +288,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws xContentBuilder.field(CommonName.ERROR_FIELD, error); } if (optionalEntity.isPresent()) { - xContentBuilder.field(CommonName.ENTITY_FIELD, optionalEntity.get()); + xContentBuilder.field(CommonName.ENTITY_KEY, optionalEntity.get()); } if (user != null) { xContentBuilder.field(CommonName.USER_FIELD, user); } - if (modelId != null) { - xContentBuilder.field(CommonName.MODEL_ID_FIELD, modelId); - } if (dataQuality != null && !dataQuality.isNaN()) { xContentBuilder.field(CommonName.DATA_QUALITY_FIELD, dataQuality); } @@ -312,13 +311,18 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (upperBound != null) { xContentBuilder.field(UPPER_BOUND_FIELD, upperBound); } + if (confidenceIntervalWidth != null) { + xContentBuilder.field(INTERVAL_WIDTH_FIELD, confidenceIntervalWidth); + } if (forecastDataStartTime != null) { xContentBuilder.field(FORECAST_DATA_START_TIME_FIELD, forecastDataStartTime.toEpochMilli()); } if (forecastDataEndTime != null) { xContentBuilder.field(FORECAST_DATA_END_TIME_FIELD, forecastDataEndTime.toEpochMilli()); } - if (horizonIndex != null) { + // the document with the actual value should not contain horizonIndex + // its horizonIndex is -1. Actual forecast value starts from horizon index 1 + if (horizonIndex != null && horizonIndex > 0) { xContentBuilder.field(HORIZON_INDEX_FIELD, horizonIndex); } if (featureId != null) { @@ -340,7 +344,6 @@ public static ForecastResult parse(XContentParser parser) throws IOException { Entity entity = null; User user = null; Integer schemaVersion = CommonValue.NO_SCHEMA_VERSION; - String modelId = null; String taskId = null; String featureId = null; @@ -385,7 +388,7 @@ public static ForecastResult parse(XContentParser parser) throws IOException { case CommonName.ERROR_FIELD: error = parser.text(); break; - case CommonName.ENTITY_FIELD: + case CommonName.ENTITY_KEY: entity = Entity.parse(parser); break; case CommonName.USER_FIELD: @@ -394,9 +397,6 @@ public static ForecastResult parse(XContentParser parser) throws IOException { case CommonName.SCHEMA_VERSION_FIELD: schemaVersion = parser.intValue(); break; - case CommonName.MODEL_ID_FIELD: - modelId = parser.text(); - break; case FEATURE_ID_FIELD: featureId = parser.text(); break; @@ -440,7 +440,6 @@ public static ForecastResult parse(XContentParser parser) throws IOException { Optional.ofNullable(entity), user, schemaVersion, - modelId, featureId, forecastValue, lowerBound, @@ -469,7 +468,8 @@ public boolean equals(Object o) { && Objects.equal(confidenceIntervalWidth, that.confidenceIntervalWidth) && Objects.equal(forecastDataStartTime, that.forecastDataStartTime) && Objects.equal(forecastDataEndTime, that.forecastDataEndTime) - && Objects.equal(horizonIndex, that.horizonIndex); + && Objects.equal(horizonIndex, that.horizonIndex) + && Objects.equal(entityId, that.entityId); } @Generated @@ -487,7 +487,8 @@ public int hashCode() { confidenceIntervalWidth, forecastDataStartTime, forecastDataEndTime, - horizonIndex + horizonIndex, + entityId ); return result; } @@ -507,6 +508,7 @@ public String toString() { .append("forecastDataStartTime", forecastDataStartTime) .append("forecastDataEndTime", forecastDataEndTime) .append("horizonIndex", horizonIndex) + .append("entityId", entityId) .toString(); } @@ -523,6 +525,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalInstant(forecastDataStartTime); out.writeOptionalInstant(forecastDataEndTime); out.writeOptionalInt(horizonIndex); + out.writeOptionalString(entityId); } public static ForecastResult getDummyResult() { @@ -537,8 +540,7 @@ public static ForecastResult getDummyResult() { null, Optional.empty(), null, - CommonValue.NO_SCHEMA_VERSION, - null + CommonValue.NO_SCHEMA_VERSION ); } @@ -587,4 +589,43 @@ public Instant getForecastDataEndTime() { public Integer getHorizonIndex() { return horizonIndex; } + + public String getEntityId() { + return entityId; + } + + /** + * Safely calculates the absolute difference between two Float values. + * + *

This method handles potential null values, as well as special Float values + * like NaN, Infinity, and -Infinity. If either of the input values is null, + * the method returns null. If the difference results in NaN or Infinity values, + * the method returns Float.MAX_VALUE. + * + *

Note: Float.MIN_VALUE is considered the smallest positive nonzero value + * of type float. The smallest negative value is -Float.MAX_VALUE. + * + * @param a The first Float value. + * @param b The second Float value. + * @return The absolute difference between the two values, or null if any input is null. + * If the result is NaN or Infinity, returns Float.MAX_VALUE. + */ + public Float safeAbsoluteDifference(Float a, Float b) { + // Check for null values + if (a == null || b == null) { + return null; // or throw an exception, or handle as per your requirements + } + + // Calculate the difference + float diff = a - b; + + // Check for special values + if (Float.isNaN(diff) || Float.isInfinite(diff)) { + return Float.MAX_VALUE; // or handle in any other way you see fit + } + + // Return the absolute difference + return Math.abs(diff); + } + } diff --git a/src/main/java/org/opensearch/forecast/model/ForecastResultBucket.java b/src/main/java/org/opensearch/forecast/model/ForecastResultBucket.java new file mode 100644 index 000000000..aa3dc21db --- /dev/null +++ b/src/main/java/org/opensearch/forecast/model/ForecastResultBucket.java @@ -0,0 +1,114 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.model; + +import java.io.IOException; +import java.util.Map; + +import org.apache.commons.lang.builder.ToStringBuilder; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.annotation.Generated; + +import com.google.common.base.Objects; + +public class ForecastResultBucket implements ToXContentObject, Writeable { + public static final String BUCKETS_FIELD = "buckets"; + public static final String KEY_FIELD = "key"; + public static final String DOC_COUNT_FIELD = "doc_count"; + public static final String BUCKET_INDEX_FIELD = "bucket_index"; + + // e.g., "ip": "1.2.3.4" + private final Map key; + private final int docCount; + private final Map aggregations; + private final int bucketIndex; + + public ForecastResultBucket(Map key, int docCount, Map aggregations, int bucketIndex) { + this.key = key; + this.docCount = docCount; + this.aggregations = aggregations; + this.bucketIndex = bucketIndex; + } + + public ForecastResultBucket(StreamInput input) throws IOException { + this.key = input.readMap(); + this.docCount = input.readInt(); + this.aggregations = input.readMap(StreamInput::readString, StreamInput::readDouble); + this.bucketIndex = input.readInt(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder + .startObject() + .field(KEY_FIELD, key) + .field(DOC_COUNT_FIELD, docCount) + .field(BUCKET_INDEX_FIELD, bucketIndex); + + for (Map.Entry entry : aggregations.entrySet()) { + xContentBuilder.field(entry.getKey(), entry.getValue()); + } + return xContentBuilder.endObject(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(key); + out.writeInt(docCount); + out.writeMap(aggregations, StreamOutput::writeString, StreamOutput::writeDouble); + out.writeInt(bucketIndex); + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + ForecastResultBucket that = (ForecastResultBucket) o; + return Objects.equal(key, that.getKey()) + && Objects.equal(docCount, that.getDocCount()) + && Objects.equal(aggregations, that.getAggregations()) + && Objects.equal(bucketIndex, that.getBucketIndex()); + } + + @Generated + @Override + public int hashCode() { + return Objects.hashCode(key, docCount, aggregations, bucketIndex); + } + + @Generated + @Override + public String toString() { + return new ToStringBuilder(this) + .append("key", key) + .append("docCount", docCount) + .append("aggregations", aggregations) + .append("bucketIndex", bucketIndex) + .toString(); + } + + public Map getKey() { + return key; + } + + public int getDocCount() { + return docCount; + } + + public Map getAggregations() { + return aggregations; + } + + public int getBucketIndex() { + return bucketIndex; + } +} diff --git a/src/main/java/org/opensearch/forecast/model/ForecastTask.java b/src/main/java/org/opensearch/forecast/model/ForecastTask.java index 4d7e889d7..4325e56b5 100644 --- a/src/main/java/org/opensearch/forecast/model/ForecastTask.java +++ b/src/main/java/org/opensearch/forecast/model/ForecastTask.java @@ -1,6 +1,17 @@ /* +<<<<<<< HEAD * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 +======= + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. +>>>>>>> f22eaa95 (test) */ package org.opensearch.forecast.model; @@ -128,8 +139,9 @@ public static Builder builder() { } @Override - public boolean isEntityTask() { - return ForecastTaskType.FORECAST_HISTORICAL_HC_ENTITY.name().equals(taskType); + public boolean isHistoricalEntityTask() { + // we have no backtesting + return false; } public static class Builder extends TimeSeriesTask.Builder { @@ -324,7 +336,10 @@ public static ForecastTask parse(XContentParser parser, String taskId) throws IO forecaster.getUser(), forecaster.getCustomResultIndex(), forecaster.getHorizon(), - forecaster.getImputationOption() + forecaster.getImputationOption(), + forecaster.getRecencyEmphasis(), + forecaster.getSeasonIntervals(), + forecaster.getHistoryIntervals() ); return new Builder() .taskId(parsedTaskId) diff --git a/src/main/java/org/opensearch/forecast/model/ForecastTaskProfile.java b/src/main/java/org/opensearch/forecast/model/ForecastTaskProfile.java new file mode 100644 index 000000000..fbde0b7d5 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/model/ForecastTaskProfile.java @@ -0,0 +1,119 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.model; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.timeseries.TaskProfile; + +public class ForecastTaskProfile extends TaskProfile { + + public ForecastTaskProfile( + ForecastTask forecastTask, + Integer shingleSize, + Long rcfTotalUpdates, + Long modelSizeInBytes, + String nodeId, + String taskId, + String taskType + ) { + super(forecastTask, shingleSize, rcfTotalUpdates, modelSizeInBytes, nodeId, taskId, taskType); + } + + public ForecastTaskProfile(StreamInput input) throws IOException { + if (input.readBoolean()) { + this.task = new ForecastTask(input); + } else { + this.task = null; + } + this.shingleSize = input.readOptionalInt(); + this.rcfTotalUpdates = input.readOptionalLong(); + this.modelSizeInBytes = input.readOptionalLong(); + this.nodeId = input.readOptionalString(); + this.taskId = input.readOptionalString(); + this.taskType = input.readOptionalString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + if (task != null) { + out.writeBoolean(true); + task.writeTo(out); + } else { + out.writeBoolean(false); + } + + out.writeOptionalInt(shingleSize); + out.writeOptionalLong(rcfTotalUpdates); + out.writeOptionalLong(modelSizeInBytes); + out.writeOptionalString(nodeId); + out.writeOptionalString(taskId); + out.writeOptionalString(taskType); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + super.toXContent(xContentBuilder); + return xContentBuilder.endObject(); + } + + public static ForecastTaskProfile parse(XContentParser parser) throws IOException { + ForecastTask forecastTask = null; + Integer shingleSize = null; + Long rcfTotalUpdates = null; + Long modelSizeInBytes = null; + String nodeId = null; + String taskId = null; + String taskType = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case ForecastCommonName.FORECAST_TASK: + forecastTask = ForecastTask.parse(parser); + break; + case SHINGLE_SIZE_FIELD: + shingleSize = parser.intValue(); + break; + case RCF_TOTAL_UPDATES_FIELD: + rcfTotalUpdates = parser.longValue(); + break; + case MODEL_SIZE_IN_BYTES: + modelSizeInBytes = parser.longValue(); + break; + case NODE_ID_FIELD: + nodeId = parser.text(); + break; + case TASK_ID_FIELD: + taskId = parser.text(); + break; + case TASK_TYPE_FIELD: + taskType = parser.text(); + break; + default: + parser.skipChildren(); + break; + } + } + return new ForecastTaskProfile(forecastTask, shingleSize, rcfTotalUpdates, modelSizeInBytes, nodeId, taskId, taskType); + } + + @Override + protected String getTaskFieldName() { + return ForecastCommonName.FORECAST_TASK; + } +} diff --git a/src/main/java/org/opensearch/forecast/model/ForecastTaskType.java b/src/main/java/org/opensearch/forecast/model/ForecastTaskType.java index 76e1aac88..16cbf902a 100644 --- a/src/main/java/org/opensearch/forecast/model/ForecastTaskType.java +++ b/src/main/java/org/opensearch/forecast/model/ForecastTaskType.java @@ -25,45 +25,33 @@ * to single-stream forecasting, and two tasks for HC, one at the forecaster level and another at the entity level. * * Real-time forecasting: - * - FORECAST_REALTIME_SINGLE_STREAM: Represents a task type for single-stream forecasting. Ideal for scenarios where a single + * - REALTIME_FORECAST_SINGLE_STREAM: Represents a task type for single-stream forecasting. Ideal for scenarios where a single * time series is processed in real-time. - * - FORECAST_REALTIME_HC_FORECASTER: Represents a task type for high cardinality (HC) forecasting. Used when dealing with a + * - REALTIME_FORECAST_HC_FORECASTER: Represents a task type for high cardinality (HC) forecasting. Used when dealing with a * large number of distinct entities in real-time. * - * Historical forecasting: - * - FORECAST_HISTORICAL_SINGLE_STREAM: Represents a forecaster-level task for single-stream historical forecasting. - * Suitable for analyzing a single time series in a sequential manner. - * - FORECAST_HISTORICAL_HC_FORECASTER: A forecaster-level task to track overall state, initialization progress, errors, etc., - * for HC forecasting. Central to managing multiple historical time series with high cardinality. - * - FORECAST_HISTORICAL_HC_ENTITY: An entity-level task to track the state, initialization progress, errors, etc., of a - * specific entity within HC historical forecasting. Allows for fine-grained information recording at the entity level. + * Run once forecasting: + * - RUN_ONCE_FORECAST_SINGLE_STREAM: forecast once in single-stream scenario. + * - RUN_ONCE_FORECAST_HC_FORECASTER: forecast once in HC scenario. + * + * enum names need to start with REALTIME or HISTORICAL we use prefix in TaskManager to check if a task is of certain type (e.g., historical) * */ public enum ForecastTaskType implements TaskType { - FORECAST_REALTIME_SINGLE_STREAM, - FORECAST_REALTIME_HC_FORECASTER, - FORECAST_HISTORICAL_SINGLE_STREAM, - // forecaster level task to track overall state, init progress, error etc. for HC forecaster - FORECAST_HISTORICAL_HC_FORECASTER, - // entity level task to track just one specific entity's state, init progress, error etc. - FORECAST_HISTORICAL_HC_ENTITY; + REALTIME_FORECAST_SINGLE_STREAM, + REALTIME_FORECAST_HC_FORECASTER, + RUN_ONCE_FORECAST_SINGLE_STREAM, + RUN_ONCE_FORECAST_HC_FORECASTER; - public static List HISTORICAL_FORECASTER_TASK_TYPES = ImmutableList - .of(ForecastTaskType.FORECAST_HISTORICAL_HC_FORECASTER, ForecastTaskType.FORECAST_HISTORICAL_SINGLE_STREAM); - public static List ALL_HISTORICAL_TASK_TYPES = ImmutableList - .of( - ForecastTaskType.FORECAST_HISTORICAL_HC_FORECASTER, - ForecastTaskType.FORECAST_HISTORICAL_SINGLE_STREAM, - ForecastTaskType.FORECAST_HISTORICAL_HC_ENTITY - ); public static List REALTIME_TASK_TYPES = ImmutableList - .of(ForecastTaskType.FORECAST_REALTIME_SINGLE_STREAM, ForecastTaskType.FORECAST_REALTIME_HC_FORECASTER); + .of(ForecastTaskType.REALTIME_FORECAST_SINGLE_STREAM, ForecastTaskType.REALTIME_FORECAST_HC_FORECASTER); public static List ALL_FORECAST_TASK_TYPES = ImmutableList .of( - ForecastTaskType.FORECAST_REALTIME_SINGLE_STREAM, - ForecastTaskType.FORECAST_REALTIME_HC_FORECASTER, - ForecastTaskType.FORECAST_HISTORICAL_SINGLE_STREAM, - ForecastTaskType.FORECAST_HISTORICAL_HC_FORECASTER, - ForecastTaskType.FORECAST_HISTORICAL_HC_ENTITY + REALTIME_FORECAST_SINGLE_STREAM, + REALTIME_FORECAST_HC_FORECASTER, + RUN_ONCE_FORECAST_SINGLE_STREAM, + RUN_ONCE_FORECAST_HC_FORECASTER ); + public static List RUN_ONCE_TASK_TYPES = ImmutableList + .of(ForecastTaskType.RUN_ONCE_FORECAST_SINGLE_STREAM, ForecastTaskType.RUN_ONCE_FORECAST_HC_FORECASTER); } diff --git a/src/main/java/org/opensearch/forecast/model/Forecaster.java b/src/main/java/org/opensearch/forecast/model/Forecaster.java index c572c28db..c56c3d40c 100644 --- a/src/main/java/org/opensearch/forecast/model/Forecaster.java +++ b/src/main/java/org/opensearch/forecast/model/Forecaster.java @@ -28,6 +28,7 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.forecast.constant.ForecastCommonMessages; import org.opensearch.forecast.settings.ForecastNumericSetting; +import org.opensearch.forecast.settings.ForecastSettings; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.timeseries.common.exception.ValidationException; @@ -38,6 +39,7 @@ import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.Feature; import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.ShingleGetter; import org.opensearch.timeseries.model.TimeConfiguration; import org.opensearch.timeseries.model.ValidationAspect; import org.opensearch.timeseries.model.ValidationIssueType; @@ -53,6 +55,47 @@ * AnomalyDetector's constructor because detection interval cannot be null. */ public class Forecaster extends Config { + static class ForecastShingleGetter implements ShingleGetter { + private Integer seasonIntervals; + private Integer horizon; + + public ForecastShingleGetter(Integer seasonIntervals, Integer horizon) { + this.seasonIntervals = seasonIntervals; + this.horizon = horizon; + } + + /** + * If the given shingle size is not null, return given shingle size; + * if seasonality or horizon is not null, return max(seasonality hint / 2, horizon / 3); + * otherwise, return default shingle size. + * + * @param customShingleSize Given shingle size + * @return Shingle size + */ + @Override + public Integer getShingleSize(Integer customShingleSize) { + // Return customShingleSize if not null + if (customShingleSize != null) { + return customShingleSize; + } + + // Initialize candidate with the default value + int candidate = TimeSeriesSettings.DEFAULT_SHINGLE_SIZE; + + // Update candidate if seasonIntervals is not null and its half is greater + if (seasonIntervals != null) { + candidate = Math.max(candidate, seasonIntervals / 2); + } + + // Update candidate if horizon is not null and its third is greater + if (horizon != null) { + candidate = Math.max(candidate, horizon / 3); + } + + return candidate; + } + } + public static final String FORECAST_PARSE_FIELD_NAME = "Forecaster"; public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( Forecaster.class, @@ -85,7 +128,10 @@ public Forecaster( User user, String resultIndex, Integer horizon, - ImputationOption imputationOption + ImputationOption imputationOption, + Integer recencyEmphasis, + Integer seasonIntervals, + Integer historyIntervals ) { super( forecasterId, @@ -105,36 +151,60 @@ public Forecaster( user, resultIndex, forecastInterval, - imputationOption + imputationOption, + recencyEmphasis, + seasonIntervals, + new ForecastShingleGetter(seasonIntervals, horizon), + historyIntervals ); checkAndThrowValidationErrors(ValidationAspect.FORECASTER); if (forecastInterval == null) { - errorMessage = ForecastCommonMessages.NULL_FORECAST_INTERVAL; - issueType = ValidationIssueType.FORECAST_INTERVAL; + throw new ValidationException( + ForecastCommonMessages.NULL_FORECAST_INTERVAL, + ValidationIssueType.FORECAST_INTERVAL, + ValidationAspect.FORECASTER + ); } else if (((IntervalTimeConfiguration) forecastInterval).getInterval() <= 0) { - errorMessage = ForecastCommonMessages.INVALID_FORECAST_INTERVAL; - issueType = ValidationIssueType.FORECAST_INTERVAL; + throw new ValidationException( + ForecastCommonMessages.INVALID_FORECAST_INTERVAL, + ValidationIssueType.FORECAST_INTERVAL, + ValidationAspect.FORECASTER + ); } int maxCategoryFields = ForecastNumericSetting.maxCategoricalFields(); if (categoryFields != null && categoryFields.size() > maxCategoryFields) { - errorMessage = CommonMessages.getTooManyCategoricalFieldErr(maxCategoryFields); - issueType = ValidationIssueType.CATEGORY; + throw new ValidationException( + CommonMessages.getTooManyCategoricalFieldErr(maxCategoryFields), + ValidationIssueType.CATEGORY, + ValidationAspect.FORECASTER + ); } if (invalidHorizon(horizon)) { - errorMessage = "Horizon size must be a positive integer no larger than " - + TimeSeriesSettings.MAX_SHINGLE_SIZE * DEFAULT_HORIZON_SHINGLE_RATIO - + ". Got " - + horizon; - issueType = ValidationIssueType.SHINGLE_SIZE_FIELD; + throw new ValidationException( + "Horizon size must be a positive integer no larger than " + + TimeSeriesSettings.MAX_SHINGLE_SIZE * DEFAULT_HORIZON_SHINGLE_RATIO + + ". Got " + + horizon, + ValidationIssueType.HORIZON_SIZE, + ValidationAspect.FORECASTER + ); } - checkAndThrowValidationErrors(ValidationAspect.FORECASTER); + // 4 comes from Preprocessor.isForecastReasonable + // we have already assigned this.shingleSize in super class + if (this.shingleSize < 4) { + throw new ValidationException( + "Shingle size must be no less than " + ForecastSettings.MINIMUM_SHINLE_SIZE + ". Got " + shingleSize, + ValidationIssueType.SHINGLE_SIZE_FIELD, + ValidationAspect.FORECASTER + ); + } - this.horizon = horizon; + this.horizon = horizon == null ? suggestHorizon() : horizon; } public Forecaster(StreamInput input) throws IOException { @@ -220,7 +290,10 @@ public static Forecaster parse( List categoryField = null; Integer horizon = null; - ImputationOption interpolationOption = null; + ImputationOption imputationOption = null; + Integer recencyEmphasis = null; + Integer seasonality = null; + Integer historyIntervals = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -329,7 +402,16 @@ public static Forecaster parse( horizon = parser.intValue(); break; case IMPUTATION_OPTION_FIELD: - interpolationOption = ImputationOption.parse(parser); + imputationOption = ImputationOption.parse(parser); + break; + case RECENCY_EMPHASIS_FIELD: + recencyEmphasis = parser.intValue(); + break; + case SEASONALITY_FIELD: + seasonality = parser.currentToken() == XContentParser.Token.VALUE_NULL ? null : parser.intValue(); + break; + case HISTORY_INTERVAL_FIELD: + historyIntervals = parser.intValue(); break; default: parser.skipChildren(); @@ -347,7 +429,7 @@ public static Forecaster parse( filterQuery, forecastInterval, windowDelay, - getShingleSize(shingleSize), + shingleSize, uiMetadata, schemaVersion, lastUpdateTime, @@ -355,7 +437,10 @@ public static Forecaster parse( user, resultIndex, horizon, - interpolationOption + imputationOption, + recencyEmphasis, + seasonality, + historyIntervals ); return forecaster; } @@ -402,4 +487,8 @@ protected ValidationAspect getConfigValidationAspect() { public Integer getHorizon() { return horizon; } + + public Integer suggestHorizon() { + return this.shingleSize * DEFAULT_HORIZON_SHINGLE_RATIO; + } } diff --git a/src/main/java/org/opensearch/forecast/model/ForecasterProfile.java b/src/main/java/org/opensearch/forecast/model/ForecasterProfile.java new file mode 100644 index 000000000..809409d0d --- /dev/null +++ b/src/main/java/org/opensearch/forecast/model/ForecasterProfile.java @@ -0,0 +1,67 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.model; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.timeseries.model.ConfigProfile; + +public class ForecasterProfile extends ConfigProfile { + + public static class Builder extends ConfigProfile.Builder { + private ForecastTaskProfile forecastTaskProfile; + + public Builder() {} + + @Override + public Builder taskProfile(ForecastTaskProfile forecastTaskProfile) { + this.forecastTaskProfile = forecastTaskProfile; + return this; + } + + @Override + public ForecasterProfile build() { + ForecasterProfile profile = new ForecasterProfile(); + profile.state = state; + profile.error = error; + profile.modelProfile = modelProfile; + profile.modelCount = modelCount; + profile.shingleSize = shingleSize; + profile.coordinatingNode = coordinatingNode; + profile.totalSizeInBytes = totalSizeInBytes; + profile.initProgress = initProgress; + profile.totalEntities = totalEntities; + profile.activeEntities = activeEntities; + profile.taskProfile = forecastTaskProfile; + + return profile; + } + } + + public ForecasterProfile() {} + + public ForecasterProfile(StreamInput in) throws IOException { + super(in); + } + + @Override + protected ForecastTaskProfile createTaskProfile(StreamInput in) throws IOException { + return new ForecastTaskProfile(in); + } + + @Override + protected String getTaskFieldName() { + return ForecastCommonName.FORECAST_TASK; + } +} diff --git a/src/main/java/org/opensearch/forecast/model/Order.java b/src/main/java/org/opensearch/forecast/model/Order.java new file mode 100644 index 000000000..6471bd75a --- /dev/null +++ b/src/main/java/org/opensearch/forecast/model/Order.java @@ -0,0 +1,11 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.model; + +public enum Order { + ASC, + DESC +} diff --git a/src/main/java/org/opensearch/forecast/model/Subaggregation.java b/src/main/java/org/opensearch/forecast/model/Subaggregation.java new file mode 100644 index 000000000..376b0226b --- /dev/null +++ b/src/main/java/org/opensearch/forecast/model/Subaggregation.java @@ -0,0 +1,115 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.model; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.timeseries.annotation.Generated; +import org.opensearch.timeseries.util.ParseUtils; + +import com.google.common.base.Objects; + +public class Subaggregation implements Writeable, ToXContentObject { + private static final String AGGREGATION_QUERY = "aggregation_query"; + private static final String ORDER = "order"; + + private final AggregationBuilder aggregation; + private final Order order; + + public Subaggregation(AggregationBuilder aggregation, Order order) { + super(); + this.aggregation = aggregation; + this.order = order; + } + + public Subaggregation(StreamInput input) throws IOException { + this.aggregation = input.readNamedWriteable(AggregationBuilder.class); + this.order = input.readEnum(Order.class); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder + .startObject() + .field(ORDER, order.name()) + .field(AGGREGATION_QUERY) + .startObject() + .value(aggregation) + .endObject(); + return xContentBuilder.endObject(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeNamedWriteable(aggregation); + out.writeEnum(order); + } + + /** + * Parse raw json content into Subaggregation instance. + * + * @param parser json based content parser + * @return feature instance + * @throws IOException IOException if content can't be parsed correctly + */ + public static Subaggregation parse(XContentParser parser) throws IOException { + Order order = Order.ASC; + AggregationBuilder aggregation = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + + parser.nextToken(); + switch (fieldName) { + case ORDER: + order = Order.valueOf(parser.text()); + break; + case AGGREGATION_QUERY: + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + aggregation = ParseUtils.toAggregationBuilder(parser); + break; + default: + break; + } + } + return new Subaggregation(aggregation, order); + } + + @Generated + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + Subaggregation feature = (Subaggregation) o; + return Objects.equal(order, feature.getOrder()) && Objects.equal(aggregation, feature.getAggregation()); + } + + @Generated + @Override + public int hashCode() { + return Objects.hashCode(aggregation, order); + } + + public AggregationBuilder getAggregation() { + return aggregation; + } + + public Order getOrder() { + return order; + } +} diff --git a/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointMaintainWorker.java b/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointMaintainWorker.java new file mode 100644 index 000000000..1ebba3fe1 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointMaintainWorker.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.ratelimit; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_MILLISECS; + +import java.time.Clock; +import java.time.Duration; +import java.util.Optional; +import java.util.Random; +import java.util.function.Function; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.CheckpointMaintainRequest; +import org.opensearch.timeseries.ratelimit.CheckpointMaintainWorker; +import org.opensearch.timeseries.ratelimit.CheckpointWriteRequest; +import org.opensearch.timeseries.ratelimit.RateLimitedRequestWorker; + +public class ForecastCheckpointMaintainWorker extends CheckpointMaintainWorker { + public static final String WORKER_NAME = "forecast-checkpoint-maintain"; + + public ForecastCheckpointMaintainWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + RateLimitedRequestWorker targetQueue, + Duration stateTtl, + NodeStateManager nodeStateManager, + Function> converter + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + targetQueue, + stateTtl, + nodeStateManager, + converter, + AnalysisType.FORECAST + ); + + this.batchSize = FORECAST_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(FORECAST_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE, it -> this.batchSize = it); + + this.expectedExecutionTimeInMilliSecsPerRequest = ForecastSettings.FORECAST_EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_MILLISECS + .get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer( + FORECAST_EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_MILLISECS, + it -> this.expectedExecutionTimeInMilliSecsPerRequest = it + ); + } + +} diff --git a/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointReadWorker.java b/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointReadWorker.java new file mode 100644 index 000000000..5bbcb4e1e --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointReadWorker.java @@ -0,0 +1,103 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.ratelimit; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_CHECKPOINT_READ_QUEUE_BATCH_SIZE; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_CHECKPOINT_READ_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.util.Random; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Provider; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.caching.ForecastPriorityCache; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastCheckpointDao; +import org.opensearch.forecast.ml.ForecastColdStart; +import org.opensearch.forecast.ml.ForecastModelManager; +import org.opensearch.forecast.ml.RCFCasterResult; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.stats.ForecastStats; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.CheckpointReadWorker; +import org.opensearch.timeseries.stats.StatNames; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +public class ForecastCheckpointReadWorker extends + CheckpointReadWorker { + public static final String WORKER_NAME = "forecast-checkpoint-read"; + + public ForecastCheckpointReadWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + ForecastModelManager modelManager, + ForecastCheckpointDao checkpointDao, + ForecastColdStartWorker entityColdStartQueue, + NodeStateManager stateManager, + ForecastIndexManagement indexUtil, + Provider cacheProvider, + Duration stateTtl, + ForecastCheckpointWriteWorker checkpointWriteQueue, + ForecastStats forecastStats, + ForecastSaveResultStrategy saveResultStrategy + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + executionTtl, + modelManager, + checkpointDao, + entityColdStartQueue, + stateManager, + indexUtil, + cacheProvider, + stateTtl, + checkpointWriteQueue, + forecastStats, + FORECAST_CHECKPOINT_READ_QUEUE_CONCURRENCY, + FORECAST_CHECKPOINT_READ_QUEUE_BATCH_SIZE, + ForecastCommonName.FORECAST_CHECKPOINT_INDEX_NAME, + StatNames.FORECAST_MODEL_CORRUTPION_COUNT, + AnalysisType.FORECAST, + saveResultStrategy + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointWriteWorker.java b/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointWriteWorker.java new file mode 100644 index 000000000..6a238ee58 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointWriteWorker.java @@ -0,0 +1,83 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.ratelimit; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_CHECKPOINT_WRITE_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.util.Random; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastCheckpointDao; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +public class ForecastCheckpointWriteWorker extends + CheckpointWriteWorker { + public static final String WORKER_NAME = "forecast-checkpoint-write"; + + public ForecastCheckpointWriteWorker( + long heapSize, + int singleRequestSize, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + ForecastCheckpointDao checkpoint, + String indexName, + Duration checkpointInterval, + NodeStateManager timeSeriesNodeStateManager, + Duration stateTtl + ) { + super( + WORKER_NAME, + heapSize, + singleRequestSize, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + FORECAST_CHECKPOINT_WRITE_QUEUE_CONCURRENCY, + executionTtl, + FORECAST_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE, + stateTtl, + timeSeriesNodeStateManager, + checkpoint, + indexName, + checkpointInterval, + AnalysisType.FORECAST + ); + } + +} diff --git a/src/main/java/org/opensearch/forecast/ratelimit/ForecastColdEntityWorker.java b/src/main/java/org/opensearch/forecast/ratelimit/ForecastColdEntityWorker.java new file mode 100644 index 000000000..43831f8df --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ratelimit/ForecastColdEntityWorker.java @@ -0,0 +1,96 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.ratelimit; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_CHECKPOINT_READ_QUEUE_BATCH_SIZE; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS; + +import java.time.Clock; +import java.time.Duration; +import java.util.Random; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.caching.ForecastPriorityCache; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastCheckpointDao; +import org.opensearch.forecast.ml.ForecastColdStart; +import org.opensearch.forecast.ml.ForecastModelManager; +import org.opensearch.forecast.ml.RCFCasterResult; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.ColdEntityWorker; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +/** + * A queue slowly releasing low-priority requests to CheckpointReadQueue + * + * ColdEntityQueue is a queue to absorb cold entities. Like hot entities, we load a cold + * entity's model checkpoint from disk, train models if the checkpoint is not found, + * query for missed features to complete a shingle, use the models to check whether + * the incoming feature is normal, update models, and save the detection results to disks.  + * Implementation-wise, we reuse the queues we have developed for hot entities. + * The differences are: we process hot entities as long as resources (e.g., AD + * thread pool has availability) are available, while we release cold entity requests + * to other queues at a slow controlled pace. Also, cold entity requests' priority is low. + * So only when there are no hot entity requests to process are we going to process cold + * entity requests.  + * + */ +public class ForecastColdEntityWorker extends + ColdEntityWorker { + public static final String WORKER_NAME = "forecast-cold-entity"; + + public ForecastColdEntityWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService forecastCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + ForecastCheckpointReadWorker checkpointReadQueue, + Duration stateTtl, + NodeStateManager nodeStateManager + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + forecastCircuitBreakerService, + threadPool, + TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + checkpointReadQueue, + stateTtl, + nodeStateManager, + FORECAST_CHECKPOINT_READ_QUEUE_BATCH_SIZE, + FORECAST_EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS, + AnalysisType.FORECAST + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/ratelimit/ForecastColdStartWorker.java b/src/main/java/org/opensearch/forecast/ratelimit/ForecastColdStartWorker.java new file mode 100644 index 000000000..971a8cdbb --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ratelimit/ForecastColdStartWorker.java @@ -0,0 +1,128 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.ratelimit; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_COLD_START_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayDeque; +import java.util.Optional; +import java.util.Random; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.caching.ForecastPriorityCache; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastCheckpointDao; +import org.opensearch.forecast.ml.ForecastColdStart; +import org.opensearch.forecast.ml.ForecastModelManager; +import org.opensearch.forecast.ml.RCFCasterResult; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.ratelimit.ColdStartWorker; +import org.opensearch.timeseries.ratelimit.FeatureRequest; +import org.opensearch.timeseries.util.ParseUtils; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +public class ForecastColdStartWorker extends + ColdStartWorker { + public static final String WORKER_NAME = "forecast-hc-cold-start"; + + public ForecastColdStartWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService circuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + ForecastColdStart coldStarter, + Duration stateTtl, + NodeStateManager nodeStateManager, + ForecastPriorityCache cacheProvider, + ForecastModelManager forecastModelManager, + ForecastSaveResultStrategy saveStrategy + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + circuitBreakerService, + threadPool, + TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + FORECAST_COLD_START_QUEUE_CONCURRENCY, + executionTtl, + coldStarter, + stateTtl, + nodeStateManager, + cacheProvider, + AnalysisType.FORECAST, + forecastModelManager, + saveStrategy + ); + } + + @Override + protected ModelState createEmptyState(FeatureRequest coldStartRequest, String modelId, String configId) { + return new ModelState( + null, + modelId, + configId, + ModelManager.ModelType.RCFCASTER.getName(), + clock, + 0, + coldStartRequest.getEntity(), + new ArrayDeque<>() + ); + } + + @Override + protected ForecastResult createIndexableResult(Config config, String taskId, String modelId, Sample entry, Optional entity) { + return new ForecastResult( + config.getId(), + taskId, + ParseUtils.getFeatureData(entry.getValueList(), config), + entry.getDataStartTime(), + entry.getDataEndTime(), + Instant.now(), + Instant.now(), + "", + entity, + config.getUser(), + config.getSchemaVersion() + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/ratelimit/ForecastResultWriteRequest.java b/src/main/java/org/opensearch/forecast/ratelimit/ForecastResultWriteRequest.java new file mode 100644 index 000000000..54c33f5bb --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ratelimit/ForecastResultWriteRequest.java @@ -0,0 +1,36 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.ratelimit; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.ratelimit.ResultWriteRequest; + +public class ForecastResultWriteRequest extends ResultWriteRequest { + + public ForecastResultWriteRequest( + long expirationEpochMs, + String forecasterId, + RequestPriority priority, + ForecastResult result, + String resultIndex + ) { + super(expirationEpochMs, forecasterId, priority, result, resultIndex); + } + + public ForecastResultWriteRequest(StreamInput in) throws IOException { + super(in, ForecastResult::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/ratelimit/ForecastResultWriteWorker.java b/src/main/java/org/opensearch/forecast/ratelimit/ForecastResultWriteWorker.java new file mode 100644 index 000000000..7f991bcf6 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ratelimit/ForecastResultWriteWorker.java @@ -0,0 +1,110 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.ratelimit; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_RESULT_WRITE_QUEUE_BATCH_SIZE; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_RESULT_WRITE_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.util.List; +import java.util.Random; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.transport.ForecastResultBulkRequest; +import org.opensearch.forecast.transport.handler.ForecastIndexMemoryPressureAwareResultHandler; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.ratelimit.ResultWriteWorker; + +public class ForecastResultWriteWorker extends + ResultWriteWorker { + public static final String WORKER_NAME = "forecast-result-write"; + + public ForecastResultWriteWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + ForecastIndexMemoryPressureAwareResultHandler resultHandler, + NamedXContentRegistry xContentRegistry, + NodeStateManager stateManager, + Duration stateTtl + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + FORECAST_RESULT_WRITE_QUEUE_CONCURRENCY, + executionTtl, + FORECAST_RESULT_WRITE_QUEUE_BATCH_SIZE, + stateTtl, + stateManager, + resultHandler, + xContentRegistry, + ForecastResult::parse, + AnalysisType.FORECAST + ); + } + + @Override + protected ForecastResultBulkRequest toBatchRequest(List toProcess) { + final ForecastResultBulkRequest bulkRequest = new ForecastResultBulkRequest(); + for (ForecastResultWriteRequest request : toProcess) { + bulkRequest.add(request); + } + return bulkRequest; + } + + @Override + protected ForecastResultWriteRequest createResultWriteRequest( + long expirationEpochMs, + String configId, + RequestPriority priority, + ForecastResult result, + String resultIndex + ) { + return new ForecastResultWriteRequest(expirationEpochMs, configId, priority, result, resultIndex); + } +} diff --git a/src/main/java/org/opensearch/forecast/ratelimit/ForecastSaveResultStrategy.java b/src/main/java/org/opensearch/forecast/ratelimit/ForecastSaveResultStrategy.java new file mode 100644 index 000000000..1dc3029a0 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ratelimit/ForecastSaveResultStrategy.java @@ -0,0 +1,90 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.ratelimit; + +import java.time.Instant; +import java.util.List; +import java.util.Optional; + +import org.opensearch.forecast.ml.RCFCasterResult; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.ratelimit.FeatureRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.ratelimit.SaveResultStrategy; +import org.opensearch.timeseries.util.ParseUtils; + +public class ForecastSaveResultStrategy implements SaveResultStrategy { + private int resultMappingVersion; + private ForecastResultWriteWorker resultWriteWorker; + + public ForecastSaveResultStrategy(int resultMappingVersion, ForecastResultWriteWorker resultWriteWorker) { + this.resultMappingVersion = resultMappingVersion; + this.resultWriteWorker = resultWriteWorker; + } + + @Override + public void saveResult(RCFCasterResult result, Config config, FeatureRequest origRequest, String modelId) { + saveResult( + result, + config, + Instant.ofEpochMilli(origRequest.getDataStartTimeMillis()), + Instant.ofEpochMilli(origRequest.getDataStartTimeMillis() + config.getIntervalInMilliseconds()), + modelId, + origRequest.getCurrentFeature(), + origRequest.getEntity(), + origRequest.getTaskId() + ); + } + + @Override + public void saveResult( + RCFCasterResult result, + Config config, + Instant dataStart, + Instant dataEnd, + String modelId, + double[] currentData, + Optional entity, + String taskId + ) { + if (result != null && result.getRcfScore() > 0) { + List indexableResults = result + .toIndexableResults( + config, + dataStart, + dataEnd, + Instant.now(), + Instant.now(), + ParseUtils.getFeatureData(currentData, config), + entity, + resultMappingVersion, + modelId, + taskId, + null + ); + + for (ForecastResult r : indexableResults) { + saveResult(r, config); + } + } + } + + @Override + public void saveResult(ForecastResult result, Config config) { + resultWriteWorker + .put( + new ForecastResultWriteRequest( + System.currentTimeMillis() + config.getIntervalInMilliseconds(), + config.getId(), + RequestPriority.MEDIUM, + result, + config.getCustomResultIndex() + ) + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/AbstractForecastSearchAction.java b/src/main/java/org/opensearch/forecast/rest/AbstractForecastSearchAction.java new file mode 100644 index 000000000..0146981ca --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/AbstractForecastSearchAction.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.rest; + +import java.util.List; + +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.action.ActionType; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.timeseries.AbstractSearchAction; + +public abstract class AbstractForecastSearchAction extends AbstractSearchAction { + + public AbstractForecastSearchAction( + List urlPaths, + List> deprecatedPaths, + String index, + Class clazz, + ActionType actionType + ) { + super( + urlPaths, + deprecatedPaths, + index, + clazz, + actionType, + ForecastEnabledSetting::isForecastEnabled, + ForecastCommonMessages.DISABLED_ERR_MSG + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/AbstractForecasterAction.java b/src/main/java/org/opensearch/forecast/rest/AbstractForecasterAction.java new file mode 100644 index 000000000..bac785c79 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/AbstractForecasterAction.java @@ -0,0 +1,88 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.rest; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_INTERVAL; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_REQUEST_TIMEOUT; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_WINDOW_DELAY; +import static org.opensearch.forecast.settings.ForecastSettings.MAX_FORECAST_FEATURES; +import static org.opensearch.forecast.settings.ForecastSettings.MAX_HC_FORECASTERS; +import static org.opensearch.forecast.settings.ForecastSettings.MAX_SINGLE_STREAM_FORECASTERS; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.forecast.settings.ForecastNumericSetting; +import org.opensearch.rest.BaseRestHandler; + +/** + * This class consists of the base class for validating and indexing forecast REST handlers. + */ +public abstract class AbstractForecasterAction extends BaseRestHandler { + /** + * Timeout duration for the forecast request. + */ + protected volatile TimeValue requestTimeout; + + /** + * Interval at which forecasts are generated. + */ + protected volatile TimeValue forecastInterval; + + /** + * Delay duration before the forecast window begins. + */ + protected volatile TimeValue forecastWindowDelay; + + /** + * Maximum number of single stream forecasters allowed. + */ + protected volatile Integer maxSingleStreamForecasters; + + /** + * Maximum number of high-cardinality (HC) forecasters allowed. + */ + protected volatile Integer maxHCForecasters; + + /** + * Maximum number of features to be used for forecasting. + */ + protected volatile Integer maxForecastFeatures; + + /** + * Maximum number of categorical fields allowed. + */ + protected volatile Integer maxCategoricalFields; + + /** + * Constructor for the base class for validating and indexing forecast REST handlers. + * + * @param settings Settings for the forecast plugin. + * @param clusterService Cluster service. + */ + public AbstractForecasterAction(Settings settings, ClusterService clusterService) { + this.requestTimeout = FORECAST_REQUEST_TIMEOUT.get(settings); + this.forecastInterval = FORECAST_INTERVAL.get(settings); + this.forecastWindowDelay = FORECAST_WINDOW_DELAY.get(settings); + this.maxSingleStreamForecasters = MAX_SINGLE_STREAM_FORECASTERS.get(settings); + this.maxHCForecasters = MAX_HC_FORECASTERS.get(settings); + this.maxForecastFeatures = MAX_FORECAST_FEATURES; + this.maxCategoricalFields = ForecastNumericSetting.maxCategoricalFields(); + // TODO: will add more cluster setting consumer later + // TODO: inject ClusterSettings only if clusterService is only used to get ClusterSettings + clusterService.getClusterSettings().addSettingsUpdateConsumer(FORECAST_REQUEST_TIMEOUT, it -> requestTimeout = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(FORECAST_INTERVAL, it -> forecastInterval = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(FORECAST_WINDOW_DELAY, it -> forecastWindowDelay = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_SINGLE_STREAM_FORECASTERS, it -> maxSingleStreamForecasters = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_HC_FORECASTERS, it -> maxHCForecasters = it); + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/ForecasterExecutionInput.java b/src/main/java/org/opensearch/forecast/rest/ForecasterExecutionInput.java new file mode 100644 index 000000000..9ba626fcd --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/ForecasterExecutionInput.java @@ -0,0 +1,141 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.time.Instant; + +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.timeseries.annotation.Generated; +import org.opensearch.timeseries.util.ParseUtils; + +import com.google.common.base.Objects; + +/** + * Input data needed to trigger forecaster. + */ +public class ForecasterExecutionInput implements ToXContentObject { + + private static final String FORECASTER_ID_FIELD = "forecaster_id"; + private static final String PERIOD_START_FIELD = "period_start"; + private static final String PERIOD_END_FIELD = "period_end"; + private static final String FORECASTER_FIELD = "forecaster"; + private Instant periodStart; + private Instant periodEnd; + private String forecasterId; + private Forecaster forecaster; + + public ForecasterExecutionInput(String forecasterId, Instant periodStart, Instant periodEnd, Forecaster forecaster) { + this.periodStart = periodStart; + this.periodEnd = periodEnd; + this.forecasterId = forecasterId; + this.forecaster = forecaster; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder + .startObject() + .field(FORECASTER_ID_FIELD, forecasterId) + .field(PERIOD_START_FIELD, periodStart.toEpochMilli()) + .field(PERIOD_END_FIELD, periodEnd.toEpochMilli()) + .field(FORECASTER_FIELD, forecaster); + return xContentBuilder.endObject(); + } + + public static ForecasterExecutionInput parse(XContentParser parser) throws IOException { + return parse(parser, null); + } + + public static ForecasterExecutionInput parse(XContentParser parser, String inputConfigId) throws IOException { + Instant periodStart = null; + Instant periodEnd = null; + Forecaster forecaster = null; + String forecasterId = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case FORECASTER_ID_FIELD: + forecasterId = parser.text(); + break; + case PERIOD_START_FIELD: + periodStart = ParseUtils.toInstant(parser); + break; + case PERIOD_END_FIELD: + periodEnd = ParseUtils.toInstant(parser); + break; + case FORECASTER_FIELD: + XContentParser.Token token = parser.currentToken(); + if (parser.currentToken().equals(XContentParser.Token.START_OBJECT)) { + forecaster = Forecaster.parse(parser, forecasterId); + } + break; + default: + break; + } + } + if (!Strings.isNullOrEmpty(inputConfigId)) { + forecasterId = inputConfigId; + } + return new ForecasterExecutionInput(forecasterId, periodStart, periodEnd, forecaster); + } + + @Generated + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + ForecasterExecutionInput that = (ForecasterExecutionInput) o; + return Objects.equal(periodStart, that.periodStart) + && Objects.equal(periodEnd, that.periodEnd) + && Objects.equal(forecasterId, that.forecasterId) + && Objects.equal(forecaster, that.forecaster); + } + + @Generated + @Override + public int hashCode() { + return Objects.hashCode(periodStart, periodEnd, forecasterId); + } + + public Instant getPeriodStart() { + return periodStart; + } + + public Instant getPeriodEnd() { + return periodEnd; + } + + public String getForecasterId() { + return forecasterId; + } + + public void setForecasterId(String forecasterId) { + this.forecasterId = forecasterId; + } + + public Forecaster getForecaster() { + return forecaster; + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/RestDeleteForecasterAction.java b/src/main/java/org/opensearch/forecast/rest/RestDeleteForecasterAction.java new file mode 100644 index 000000000..54be0bb31 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/RestDeleteForecasterAction.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.rest; + +import static org.opensearch.timeseries.util.RestHandlerUtils.FORECASTER_ID; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.forecast.transport.DeleteForecasterAction; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.transport.DeleteConfigRequest; +import org.owasp.encoder.Encode; + +import com.google.common.collect.ImmutableList; + +public class RestDeleteForecasterAction extends BaseRestHandler { + public static final String DELETE_FORECASTER_ACTION = "delete_forecaster"; + + public RestDeleteForecasterAction() {} + + @Override + public String getName() { + return DELETE_FORECASTER_ACTION; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + if (!ForecastEnabledSetting.isForecastEnabled()) { + throw new IllegalStateException(ForecastCommonMessages.DISABLED_ERR_MSG); + } + + try { + String forecasterId = request.param(FORECASTER_ID); + DeleteConfigRequest deleteForecasterRequest = new DeleteConfigRequest(forecasterId); + return channel -> client + .execute(DeleteForecasterAction.INSTANCE, deleteForecasterRequest, new RestToXContentListener<>(channel)); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException(Encode.forHtml(e.getMessage())); + } + + } + + @Override + public List routes() { + return ImmutableList + .of( + // delete forecaster document + new Route( + RestRequest.Method.DELETE, + String.format(Locale.ROOT, "%s/{%s}", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, FORECASTER_ID) + ) + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/RestForecasterJobAction.java b/src/main/java/org/opensearch/forecast/rest/RestForecasterJobAction.java new file mode 100644 index 000000000..a5f98829b --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/RestForecasterJobAction.java @@ -0,0 +1,75 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.rest; + +import static org.opensearch.timeseries.util.RestHandlerUtils.FORECASTER_ID; +import static org.opensearch.timeseries.util.RestHandlerUtils.START_JOB; +import static org.opensearch.timeseries.util.RestHandlerUtils.STOP_JOB; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.forecast.transport.ForecasterJobAction; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.rest.RestJobAction; +import org.opensearch.timeseries.transport.JobRequest; +import org.owasp.encoder.Encode; + +import com.google.common.collect.ImmutableList; + +public class RestForecasterJobAction extends RestJobAction { + public static final String FORECAST_JOB_ACTION = "forecaster_job_action"; + + @Override + public String getName() { + return FORECAST_JOB_ACTION; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + if (!ForecastEnabledSetting.isForecastEnabled()) { + throw new IllegalStateException(ForecastCommonMessages.DISABLED_ERR_MSG); + } + + try { + String forecasterId = request.param(FORECASTER_ID); + String rawPath = request.rawPath(); + DateRange dateRange = parseInputDateRange(request); + + // false means we don't support backtesting and thus no need to stop backtesting + JobRequest forecasterJobRequest = new JobRequest(forecasterId, dateRange, false, rawPath); + + return channel -> client.execute(ForecasterJobAction.INSTANCE, forecasterJobRequest, new RestToXContentListener<>(channel)); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException(Encode.forHtml(e.getMessage())); + } + + } + + @Override + public List routes() { + return ImmutableList + .of( + /// start forecaster Job + new Route( + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/{%s}/%s", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, FORECASTER_ID, START_JOB) + ), + /// stop forecaster Job + new Route( + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/{%s}/%s", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, FORECASTER_ID, STOP_JOB) + ) + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/RestForecasterSuggestAction.java b/src/main/java/org/opensearch/forecast/rest/RestForecasterSuggestAction.java new file mode 100644 index 000000000..16a06336e --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/RestForecasterSuggestAction.java @@ -0,0 +1,122 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_REQUEST_TIMEOUT; +import static org.opensearch.timeseries.util.RestHandlerUtils.SUGGEST; +import static org.opensearch.timeseries.util.RestHandlerUtils.TYPE; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; +import java.util.concurrent.TimeUnit; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.ValidationException; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.forecast.transport.SuggestForecasterParamAction; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.transport.SuggestConfigParamRequest; +import org.owasp.encoder.Encode; + +import com.google.common.collect.ImmutableList; + +/** + * This class consists of the REST handler to validate anomaly detector configurations. + */ +public class RestForecasterSuggestAction extends BaseRestHandler { + private static final String FORECASTER_SUGGEST_ACTION = "forecaster_suggest_action"; + + private volatile TimeValue requestTimeout; + + public RestForecasterSuggestAction(Settings settings, ClusterService clusterService) { + this.requestTimeout = FORECAST_REQUEST_TIMEOUT.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(FORECAST_REQUEST_TIMEOUT, it -> requestTimeout = it); + } + + @Override + public String getName() { + return FORECASTER_SUGGEST_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of( + new Route( + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/%s/{%s}", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, SUGGEST, TYPE) + ) + ); + } + + @Override + protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + if (!ForecastEnabledSetting.isForecastEnabled()) { + throw new IllegalStateException(ForecastCommonMessages.DISABLED_ERR_MSG); + } + + try { + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + // we have to get the param from a subclass of BaseRestHandler. Otherwise, we cannot parse the type out of request params + String typesStr = request.param(TYPE); + + Forecaster config = parseConfig(parser); + + if (config != null) { + return channel -> { + SuggestConfigParamRequest suggestForecasterParamRequest = new SuggestConfigParamRequest( + AnalysisType.FORECAST, + config, + typesStr, + requestTimeout + ); + client + .execute( + SuggestForecasterParamAction.INSTANCE, + suggestForecasterParamRequest, + new RestToXContentListener<>(channel) + ); + }; + } else { + ValidationException validationException = new ValidationException(); + validationException.addValidationError("fail to parse config"); + throw validationException; + } + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException(Encode.forHtml(e.getMessage())); + } + } + + private Forecaster parseConfig(XContentParser parser) throws IOException { + try { + // use default forecaster interval in case of validation exception since it can be empty + return Forecaster.parse(parser, null, null, new TimeValue(1, TimeUnit.MINUTES), null); + } catch (Exception e) { + ValidationException validationException = new ValidationException(); + validationException.addValidationError(e.getMessage()); + throw validationException; + } + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/RestGetForecasterAction.java b/src/main/java/org/opensearch/forecast/rest/RestGetForecasterAction.java new file mode 100644 index 000000000..c36a925a4 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/RestGetForecasterAction.java @@ -0,0 +1,147 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.rest; + +import static org.opensearch.timeseries.util.RestHandlerUtils.FORECASTER_ID; +import static org.opensearch.timeseries.util.RestHandlerUtils.TYPE; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.forecast.transport.GetForecasterAction; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestActions; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.transport.GetConfigRequest; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.owasp.encoder.Encode; + +import com.google.common.collect.ImmutableList; + +/** + * This class consists of the REST handler to retrieve an anomaly detector. + */ +public class RestGetForecasterAction extends BaseRestHandler { + + private static final String GET_FORECASTER_ACTION = "get_forecaster"; + + public RestGetForecasterAction() {} + + @Override + public String getName() { + return GET_FORECASTER_ACTION; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + if (!ForecastEnabledSetting.isForecastEnabled()) { + throw new IllegalStateException(ForecastCommonMessages.DISABLED_ERR_MSG); + } + + try { + String forecasterId = request.param(FORECASTER_ID); + String typesStr = request.param(TYPE); + + String rawPath = request.rawPath(); + boolean returnJob = request.paramAsBoolean("job", false); + boolean returnTask = request.paramAsBoolean("task", false); + boolean all = request.paramAsBoolean("_all", false); + GetConfigRequest getForecasterRequest = new GetConfigRequest( + forecasterId, + RestActions.parseVersion(request), + returnJob, + returnTask, + typesStr, + rawPath, + all, + RestHandlerUtils.buildEntity(request, forecasterId) + ); + + return channel -> client.execute(GetForecasterAction.INSTANCE, getForecasterRequest, new RestToXContentListener<>(channel)); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException(Encode.forHtml(e.getMessage())); + } + } + + @Override + public List routes() { + return ImmutableList + .of( + // Opensearch-only API. Considering users may provide entity in the search body, + // support POST as well. + + // profile API + new Route( + RestRequest.Method.POST, + String + .format( + Locale.ROOT, + "%s/{%s}/%s", + TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, + FORECASTER_ID, + RestHandlerUtils.PROFILE + ) + ), + // types is a profile names. See a complete list of supported profiles names in + // org.opensearch.ad.model.ProfileName. + new Route( + RestRequest.Method.POST, + String + .format( + Locale.ROOT, + "%s/{%s}/%s/{%s}", + TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, + FORECASTER_ID, + RestHandlerUtils.PROFILE, + TYPE + ) + ), + new Route( + RestRequest.Method.GET, + String + .format( + Locale.ROOT, + "%s/{%s}/%s", + TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, + FORECASTER_ID, + RestHandlerUtils.PROFILE + ) + ), + // types is a profile names. See a complete list of supported profiles names in + // org.opensearch.ad.model.ProfileName. + new Route( + RestRequest.Method.GET, + String + .format( + Locale.ROOT, + "%s/{%s}/%s/{%s}", + TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, + FORECASTER_ID, + RestHandlerUtils.PROFILE, + TYPE + ) + ), + + // get forecaster API + new Route( + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s/{%s}", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, FORECASTER_ID) + ) + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/RestIndexForecasterAction.java b/src/main/java/org/opensearch/forecast/rest/RestIndexForecasterAction.java new file mode 100644 index 000000000..24a9ab037 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/RestIndexForecasterAction.java @@ -0,0 +1,148 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.util.RestHandlerUtils.FORECASTER_ID; +import static org.opensearch.timeseries.util.RestHandlerUtils.IF_PRIMARY_TERM; +import static org.opensearch.timeseries.util.RestHandlerUtils.IF_SEQ_NO; +import static org.opensearch.timeseries.util.RestHandlerUtils.REFRESH; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.forecast.transport.IndexForecasterAction; +import org.opensearch.forecast.transport.IndexForecasterRequest; +import org.opensearch.forecast.transport.IndexForecasterResponse; +import org.opensearch.index.seqno.SequenceNumbers; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.RestResponse; +import org.opensearch.rest.action.RestResponseListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.model.Config; +import org.owasp.encoder.Encode; + +import com.google.common.collect.ImmutableList; + +/** + * Rest handlers to create and update forecaster. + */ +public class RestIndexForecasterAction extends AbstractForecasterAction { + private static final String INDEX_FORECASTER_ACTION = "index_forecaster_action"; + private final Logger logger = LogManager.getLogger(RestIndexForecasterAction.class); + + public RestIndexForecasterAction(Settings settings, ClusterService clusterService) { + super(settings, clusterService); + } + + @Override + public String getName() { + return INDEX_FORECASTER_ACTION; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + if (!ForecastEnabledSetting.isForecastEnabled()) { + throw new IllegalStateException(ForecastCommonMessages.DISABLED_ERR_MSG); + } + + try { + String forecasterId = request.param(FORECASTER_ID, Config.NO_ID); + logger.info("Forecaster {} action for forecasterId {}", request.method(), forecasterId); + + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Forecaster forecaster = Forecaster.parse(parser, forecasterId, null, forecastInterval, forecastWindowDelay); + + long seqNo = request.paramAsLong(IF_SEQ_NO, SequenceNumbers.UNASSIGNED_SEQ_NO); + long primaryTerm = request.paramAsLong(IF_PRIMARY_TERM, SequenceNumbers.UNASSIGNED_PRIMARY_TERM); + WriteRequest.RefreshPolicy refreshPolicy = request.hasParam(REFRESH) + ? WriteRequest.RefreshPolicy.parse(request.param(REFRESH)) + : WriteRequest.RefreshPolicy.IMMEDIATE; + RestRequest.Method method = request.getHttpRequest().method(); + + IndexForecasterRequest indexAnomalyDetectorRequest = new IndexForecasterRequest( + forecasterId, + seqNo, + primaryTerm, + refreshPolicy, + forecaster, + method, + requestTimeout, + maxSingleStreamForecasters, + maxHCForecasters, + maxForecastFeatures, + maxCategoricalFields + ); + + return channel -> client + .execute(IndexForecasterAction.INSTANCE, indexAnomalyDetectorRequest, indexForecasterResponse(channel, method)); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException(Encode.forHtml(e.getMessage())); + } catch (ValidationException e) { + // convert 500 to 400 errors for validation failures + throw new OpenSearchStatusException(e.getMessage(), RestStatus.BAD_REQUEST); + } + } + + @Override + public List routes() { + return ImmutableList + .of( + // Create + new Route(RestRequest.Method.POST, TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI), + // Update + new Route( + RestRequest.Method.PUT, + String.format(Locale.ROOT, "%s/{%s}", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, FORECASTER_ID) + ) + ); + } + + private RestResponseListener indexForecasterResponse(RestChannel channel, RestRequest.Method method) { + return new RestResponseListener(channel) { + @Override + public RestResponse buildResponse(IndexForecasterResponse response) throws Exception { + RestStatus restStatus = RestStatus.CREATED; + if (method == RestRequest.Method.PUT) { + restStatus = RestStatus.OK; + } + BytesRestResponse bytesRestResponse = new BytesRestResponse( + restStatus, + response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS) + ); + if (restStatus == RestStatus.CREATED) { + String location = String.format(Locale.ROOT, "%s/%s", TimeSeriesAnalyticsPlugin.FORECAST_BASE_URI, response.getId()); + bytesRestResponse.addHeader("Location", location); + } + return bytesRestResponse; + } + }; + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/RestRunOnceForecasterAction.java b/src/main/java/org/opensearch/forecast/rest/RestRunOnceForecasterAction.java new file mode 100644 index 000000000..042e21820 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/RestRunOnceForecasterAction.java @@ -0,0 +1,81 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.rest; + +import static org.opensearch.timeseries.util.RestHandlerUtils.FORECASTER_ID; +import static org.opensearch.timeseries.util.RestHandlerUtils.RUN_ONCE; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.joda.time.Instant; +import org.opensearch.client.node.NodeClient; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.forecast.transport.ForecastResultRequest; +import org.opensearch.forecast.transport.ForecastRunOnceAction; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.owasp.encoder.Encode; + +import com.google.common.collect.ImmutableList; + +/** + * This class consists of the REST handler to handle request to forecast. + */ +public class RestRunOnceForecasterAction extends BaseRestHandler { + + public static final String FORECASTER_ACTION = "run_forecaster_once"; + + public RestRunOnceForecasterAction() {} + + @Override + public String getName() { + return FORECASTER_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of( + // execute forester once + new Route( + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/{%s}/%s", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, FORECASTER_ID, RUN_ONCE) + ) + ); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + if (!ForecastEnabledSetting.isForecastEnabled()) { + throw new IllegalStateException(ForecastCommonMessages.DISABLED_ERR_MSG); + } + + try { + String forecasterId = request.param(FORECASTER_ID); + + ForecastResultRequest getRequest = new ForecastResultRequest( + forecasterId, + -1L, // will set it in ResultProcessor.onGetConfig + Instant.now().getMillis() + ); + + return channel -> client.execute(ForecastRunOnceAction.INSTANCE, getRequest, new RestToXContentListener<>(channel)); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException(Encode.forHtml(e.getMessage())); + } + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/RestSearchForecastTasksAction.java b/src/main/java/org/opensearch/forecast/rest/RestSearchForecastTasksAction.java new file mode 100644 index 000000000..6b72e42e6 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/RestSearchForecastTasksAction.java @@ -0,0 +1,44 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.rest; + +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.transport.SearchForecastTasksAction; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; + +import com.google.common.collect.ImmutableList; + +/** + * This class consists of the REST handler to search AD tasks. + */ +public class RestSearchForecastTasksAction extends AbstractForecastSearchAction { + + private static final String URL_PATH = TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI + "/tasks/_search"; + private final String SEARCH_FORECASTER_TASKS = "search_forecaster_tasks"; + + public RestSearchForecastTasksAction() { + super( + ImmutableList.of(URL_PATH), + ImmutableList.of(), + ForecastIndex.STATE.getIndexName(), + ForecastTask.class, + SearchForecastTasksAction.INSTANCE + ); + } + + @Override + public String getName() { + return SEARCH_FORECASTER_TASKS; + } + +} diff --git a/src/main/java/org/opensearch/forecast/rest/RestSearchForecasterAction.java b/src/main/java/org/opensearch/forecast/rest/RestSearchForecasterAction.java new file mode 100644 index 000000000..1e5d76b7a --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/RestSearchForecasterAction.java @@ -0,0 +1,39 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.rest; + +import static org.opensearch.timeseries.util.RestHandlerUtils.SEARCH; + +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.transport.SearchForecasterAction; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.constant.CommonName; + +import com.google.common.collect.ImmutableList; + +/** + * This class consists of the REST handler to search anomaly detectors. + */ +public class RestSearchForecasterAction extends AbstractForecastSearchAction { + + private static final String URL_PATH = TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI + "/" + SEARCH; + private final String SEARCH_FORECASTER_ACTION = "search_forecaster"; + + public RestSearchForecasterAction() { + super(ImmutableList.of(URL_PATH), ImmutableList.of(), CommonName.CONFIG_INDEX, Forecaster.class, SearchForecasterAction.INSTANCE); + } + + @Override + public String getName() { + return SEARCH_FORECASTER_ACTION; + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/RestSearchForecasterInfoAction.java b/src/main/java/org/opensearch/forecast/rest/RestSearchForecasterInfoAction.java new file mode 100644 index 000000000..16cc54ecd --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/RestSearchForecasterInfoAction.java @@ -0,0 +1,77 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.rest; + +import static org.opensearch.timeseries.util.RestHandlerUtils.COUNT; +import static org.opensearch.timeseries.util.RestHandlerUtils.MATCH; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.forecast.transport.SearchForecasterInfoAction; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.transport.SearchConfigInfoRequest; +import org.owasp.encoder.Encode; + +import com.google.common.collect.ImmutableList; + +public class RestSearchForecasterInfoAction extends BaseRestHandler { + + public static final String SEARCH_FORECASTER_INFO_ACTION = "search_forecaster_info"; + + public RestSearchForecasterInfoAction() {} + + @Override + public String getName() { + return SEARCH_FORECASTER_INFO_ACTION; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, org.opensearch.client.node.NodeClient client) throws IOException { + if (!ForecastEnabledSetting.isForecastEnabled()) { + throw new IllegalStateException(ForecastCommonMessages.DISABLED_ERR_MSG); + } + + try { + String forecasterName = request.param("name", null); + String rawPath = request.rawPath(); + + SearchConfigInfoRequest searchForecasterInfoRequest = new SearchConfigInfoRequest(forecasterName, rawPath); + return channel -> client + .execute(SearchForecasterInfoAction.INSTANCE, searchForecasterInfoRequest, new RestToXContentListener<>(channel)); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException(Encode.forHtml(e.getMessage())); + } + } + + @Override + public List routes() { + return ImmutableList + .of( + new Route( + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s/%s", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, COUNT) + ), + new Route( + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s/%s", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, MATCH) + ) + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/RestSearchTopForecastResultAction.java b/src/main/java/org/opensearch/forecast/rest/RestSearchTopForecastResultAction.java new file mode 100644 index 000000000..49e922e9b --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/RestSearchTopForecastResultAction.java @@ -0,0 +1,94 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.forecast.transport.SearchTopForecastResultAction; +import org.opensearch.forecast.transport.SearchTopForecastResultRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.owasp.encoder.Encode; + +import com.google.common.collect.ImmutableList; + +/** + * The REST handler to search top entity anomaly results for HC detectors. + */ +public class RestSearchTopForecastResultAction extends BaseRestHandler { + + private static final String URL_PATH = String + .format( + Locale.ROOT, + "%s/{%s}/%s/%s", + TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, + RestHandlerUtils.FORECASTER_ID, + RestHandlerUtils.RESULTS, + RestHandlerUtils.TOP_FORECASTS + ); + private final String SEARCH_TOP_FORECASTS_ACTION = "search_top_forecasts"; + + public RestSearchTopForecastResultAction() {} + + @Override + public String getName() { + return SEARCH_TOP_FORECASTS_ACTION; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + + // Throw error if disabled + if (!ForecastEnabledSetting.isForecastEnabled()) { + throw new IllegalStateException(ForecastCommonMessages.DISABLED_ERR_MSG); + } + + try { + // Get the typed request + SearchTopForecastResultRequest searchTopAnomalyResultRequest = getSearchTopForecastResultRequest(request); + + return channel -> client + .execute(SearchTopForecastResultAction.INSTANCE, searchTopAnomalyResultRequest, new RestToXContentListener<>(channel)); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException(Encode.forHtml(e.getMessage())); + } + + } + + private SearchTopForecastResultRequest getSearchTopForecastResultRequest(RestRequest request) throws IOException { + String forecasterId; + if (request.hasParam(RestHandlerUtils.FORECASTER_ID)) { + forecasterId = request.param(RestHandlerUtils.FORECASTER_ID); + } else { + throw new IllegalStateException(ForecastCommonMessages.FORECASTER_ID_MISSING_MSG); + } + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + return SearchTopForecastResultRequest.parse(parser, forecasterId); + } + + @Override + public List routes() { + return ImmutableList.of(new Route(RestRequest.Method.POST, URL_PATH), new Route(RestRequest.Method.GET, URL_PATH)); + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/RestStatsForecasterAction.java b/src/main/java/org/opensearch/forecast/rest/RestStatsForecasterAction.java new file mode 100644 index 000000000..e4f6f5edd --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/RestStatsForecasterAction.java @@ -0,0 +1,77 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.rest; + +import java.util.List; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.forecast.stats.ForecastStats; +import org.opensearch.forecast.transport.StatsForecasterAction; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.rest.RestStatsAction; +import org.opensearch.timeseries.transport.StatsRequest; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.owasp.encoder.Encode; + +import com.google.common.collect.ImmutableList; + +/** + * RestStatsForecasterAction consists of the REST handler to get the stats from forecasting. + */ +public class RestStatsForecasterAction extends RestStatsAction { + + private static final String STATS_FORECASTER_ACTION = "stats_forecaster"; + + /** + * Constructor + * + * @param timeSeriesStats TimeSeriesStats object + * @param nodeFilter util class to get eligible data nodes + */ + public RestStatsForecasterAction(ForecastStats timeSeriesStats, DiscoveryNodeFilterer nodeFilter) { + super(timeSeriesStats, nodeFilter); + } + + @Override + public String getName() { + return STATS_FORECASTER_ACTION; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) { + if (!ForecastEnabledSetting.isForecastEnabled()) { + throw new IllegalStateException(ForecastCommonMessages.DISABLED_ERR_MSG); + } + + try { + StatsRequest forecastStatsRequest = getRequest(request); + return channel -> client.execute(StatsForecasterAction.INSTANCE, forecastStatsRequest, new RestToXContentListener<>(channel)); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException(Encode.forHtml(e.getMessage())); + } + } + + @Override + public List routes() { + return ImmutableList + .of( + new Route(RestRequest.Method.GET, TimeSeriesAnalyticsPlugin.FORECAST_BASE_URI + "/{nodeId}/stats/"), + new Route(RestRequest.Method.GET, TimeSeriesAnalyticsPlugin.FORECAST_BASE_URI + "/{nodeId}/stats/{stat}"), + new Route(RestRequest.Method.GET, TimeSeriesAnalyticsPlugin.FORECAST_BASE_URI + "/stats/"), + new Route(RestRequest.Method.GET, TimeSeriesAnalyticsPlugin.FORECAST_BASE_URI + "/stats/{stat}") + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/RestValidateForecasterAction.java b/src/main/java/org/opensearch/forecast/rest/RestValidateForecasterAction.java new file mode 100644 index 000000000..93ff62288 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/RestValidateForecasterAction.java @@ -0,0 +1,116 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.util.RestHandlerUtils.TYPE; +import static org.opensearch.timeseries.util.RestHandlerUtils.VALIDATE; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.forecast.transport.ValidateForecasterAction; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.model.ConfigValidationIssue; +import org.opensearch.timeseries.rest.RestValidateAction; +import org.opensearch.timeseries.transport.ValidateConfigRequest; +import org.owasp.encoder.Encode; + +import com.google.common.collect.ImmutableList; + +/** + * This class consists of the REST handler to validate anomaly detector configurations. + */ +public class RestValidateForecasterAction extends AbstractForecasterAction { + private static final String VALIDATE_FORECASTER_ACTION = "validate_forecaster_action"; + + private RestValidateAction validateAction; + + public RestValidateForecasterAction(Settings settings, ClusterService clusterService) { + super(settings, clusterService); + this.validateAction = new RestValidateAction( + AnalysisType.FORECAST, + maxSingleStreamForecasters, + maxHCForecasters, + maxForecastFeatures, + maxCategoricalFields, + requestTimeout + ); + } + + @Override + public String getName() { + return VALIDATE_FORECASTER_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of( + new Route( + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/%s", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, VALIDATE) + ), + new Route( + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/%s/{%s}", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, VALIDATE, TYPE) + ) + ); + } + + @Override + protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + if (!ForecastEnabledSetting.isForecastEnabled()) { + throw new IllegalStateException(ForecastCommonMessages.DISABLED_ERR_MSG); + } + + try { + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + // we have to get the param from a subclass of BaseRestHandler. Otherwise, we cannot parse the type out of request params + String typesStr = request.param(TYPE); + + return channel -> { + try { + ValidateConfigRequest validateForecasterRequest = validateAction.prepareRequest(request, client, typesStr); + client.execute(ValidateForecasterAction.INSTANCE, validateForecasterRequest, new RestToXContentListener<>(channel)); + } catch (Exception ex) { + if (ex instanceof ValidationException) { + ValidationException forecastException = (ValidationException) ex; + ConfigValidationIssue issue = new ConfigValidationIssue( + forecastException.getAspect(), + forecastException.getType(), + forecastException.getMessage() + ); + validateAction.sendValidationParseResponse(issue, channel); + } else { + throw ex; + } + } + }; + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException(Encode.forHtml(e.getMessage())); + } + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/handler/AbstractForecasterActionHandler.java b/src/main/java/org/opensearch/forecast/rest/handler/AbstractForecasterActionHandler.java new file mode 100644 index 000000000..a41f2bbcc --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/handler/AbstractForecasterActionHandler.java @@ -0,0 +1,282 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.rest.handler; + +import java.io.IOException; +import java.time.Clock; +import java.time.Instant; +import java.util.Locale; +import java.util.Set; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.forecast.transport.IndexForecasterResponse; +import org.opensearch.rest.RestRequest; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.timeseries.rest.handler.AbstractTimeSeriesActionHandler; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.ValidateConfigResponse; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.Sets; + +/** + * AbstractForecasterActionHandler extends the AbstractTimeSeriesActionHandler to provide a specialized + * base for handling forecasting-related actions within OpenSearch. This abstract class encapsulates common + * logic and utilities specifically tailored for managing forecasting tasks, including the validation, creation, + * and updating of forecaster configurations. + * + * Key functionalities include: + * - Processing REST requests related to forecasting tasks, ensuring they meet the required standards and formats. + * - Interacting with ForecastIndex and ForecastIndexManagement for forecast-specific index operations. + * - Validating forecasting configurations against various constraints, such as maximum allowed single-stream + * and high cardinality (HC) forecasters, ensuring the configurations adhere to defined limits. + * - Managing user permissions and security for forecasting operations, leveraging the SecurityClientUtil. + * - Extending support for forecasting-specific fields and settings, such as forecast horizon, imputation options, + * and emphasis on recent data. + * + * Usage: + * This class is designed to be extended by concrete handlers that implement forecasting-specific logic for actions + * such as creating a new forecaster, updating existing configurations, or validating forecasting models. It provides + * a structured framework that includes essential services like client communication, security utilities, and task management, + * allowing implementers to focus on the unique aspects of their forecasting tasks. + * + * Extending classes are required to implement abstract methods defined in both this class and its parent, providing + * functionality for parsing forecasting configurations, handling validation exceptions, and constructing response + * objects for REST calls. + * + * Example Extension: + * A concrete implementation might include a IndexForecasterActionHandler that leverages this class to handle the + * creation of new forecaster configurations, including validation against predefined limits and index management. + */ +public abstract class AbstractForecasterActionHandler extends + AbstractTimeSeriesActionHandler { + protected final Logger logger = LogManager.getLogger(AbstractForecasterActionHandler.class); + + public static final String EXCEEDED_MAX_HC_FORECASTERS_PREFIX_MSG = "Can't create more than %d HC forecasters."; + public static final String EXCEEDED_MAX_SINGLE_STREAM_FORECASTERS_PREFIX_MSG = "Can't create more than %d single-stream forecasters."; + public static final String NO_DOCS_IN_USER_INDEX_MSG = "Can't create forecasters as no document is found in the indices: "; + public static final String DUPLICATE_FORECASTER_MSG = + "Cannot create forecasters with name [%s] as it's already used by another forecaster"; + public static final String VALIDATION_FEATURE_FAILURE = "Validation failed for feature(s) of forecaster %s"; + + /** + * Constructor function. + * + * @param clusterService ClusterService + * @param client ES node client that executes actions on the local node + * @param clientUtil Forecast security client + * @param transportService ES transport service + * @param forecastIndices forecast index manager + * @param forecasterId forecaster identifier + * @param seqNo sequence number of last modification + * @param primaryTerm primary term of last modification + * @param refreshPolicy refresh policy + * @param forecaster forecaster instance + * @param requestTimeout request time out configuration + * @param maxSingleStreamForecasters max single-stream forecasters allowed + * @param maxHCForecasters max HC forecasters allowed + * @param maxFeatures max features allowed per forecaster + * @param maxCategoricalFields max categorical fields allowed + * @param method Rest Method type + * @param xContentRegistry Registry which is used for XContentParser + * @param user User context + * @param forecastTaskManager Forecast task manager + * @param searchFeatureDao Search utility + * @param validationType validation type in validate API. Can be null (no validation). + * @param isDryRun Whether handler is dryrun or not + * @param clock clock object to know when to timeout + * @param settings Node settings + */ + public AbstractForecasterActionHandler( + ClusterService clusterService, + Client client, + SecurityClientUtil clientUtil, + TransportService transportService, + ForecastIndexManagement forecastIndices, + String forecasterId, + Long seqNo, + Long primaryTerm, + WriteRequest.RefreshPolicy refreshPolicy, + Config forecaster, + TimeValue requestTimeout, + Integer maxSingleStreamForecasters, + Integer maxHCForecasters, + Integer maxFeatures, + Integer maxCategoricalFields, + RestRequest.Method method, + NamedXContentRegistry xContentRegistry, + User user, + ForecastTaskManager forecastTaskManager, + SearchFeatureDao searchFeatureDao, + String validationType, + boolean isDryRun, + Clock clock, + Settings settings + ) { + super( + forecaster, + forecastIndices, + isDryRun, + client, + forecasterId, + clientUtil, + user, + method, + clusterService, + xContentRegistry, + transportService, + requestTimeout, + refreshPolicy, + seqNo, + primaryTerm, + validationType, + searchFeatureDao, + maxFeatures, + maxCategoricalFields, + AnalysisType.FORECAST, + forecastTaskManager, + ForecastTaskType.RUN_ONCE_TASK_TYPES, + true, + maxSingleStreamForecasters, + maxHCForecasters, + clock, + settings + ); + } + + @Override + protected TimeSeriesException createValidationException(String msg, ValidationIssueType type) { + return new ValidationException(msg, type, ValidationAspect.FORECASTER); + } + + @Override + protected Forecaster parse(XContentParser parser, GetResponse response) throws IOException { + return Forecaster.parse(parser, response.getId(), response.getVersion()); + } + + @Override + protected String getExceedMaxSingleStreamConfigsErrorMsg(int maxSingleStreamConfigs) { + return String.format(Locale.ROOT, EXCEEDED_MAX_SINGLE_STREAM_FORECASTERS_PREFIX_MSG, getMaxSingleStreamConfigs()); + } + + @Override + protected String getExceedMaxHCConfigsErrorMsg(int maxHCConfigs) { + return String.format(Locale.ROOT, EXCEEDED_MAX_HC_FORECASTERS_PREFIX_MSG, getMaxHCConfigs()); + } + + @Override + protected String getNoDocsInUserIndexErrorMsg(String suppliedIndices) { + return String.format(Locale.ROOT, NO_DOCS_IN_USER_INDEX_MSG, suppliedIndices); + } + + @Override + protected String getDuplicateConfigErrorMsg(String name) { + return String.format(Locale.ROOT, DUPLICATE_FORECASTER_MSG, name); + } + + @Override + protected Config copyConfig(User user, Config config) { + return new Forecaster( + config.getId(), + config.getVersion(), + config.getName(), + config.getDescription(), + config.getTimeField(), + config.getIndices(), + config.getFeatureAttributes(), + config.getFilterQuery(), + config.getInterval(), + config.getWindowDelay(), + config.getShingleSize(), + config.getUiMetadata(), + config.getSchemaVersion(), + Instant.now(), + config.getCategoryFields(), + user, + config.getCustomResultIndex(), + ((Forecaster) config).getHorizon(), + config.getImputationOption(), + config.getRecencyEmphasis(), + config.getSeasonIntervals(), + config.getHistoryIntervals() + ); + } + + @SuppressWarnings("unchecked") + @Override + protected T createIndexConfigResponse(IndexResponse indexResponse, Config config) { + return (T) new IndexForecasterResponse( + indexResponse.getId(), + indexResponse.getVersion(), + indexResponse.getSeqNo(), + indexResponse.getPrimaryTerm(), + (Forecaster) config, + RestStatus.CREATED + ); + } + + @Override + protected Set getDefaultValidationType() { + return Sets.newHashSet(ValidationAspect.FORECASTER); + } + + @Override + protected String getFeatureErrorMsg(String name) { + return String.format(Locale.ROOT, VALIDATION_FEATURE_FAILURE, name); + } + + @Override + protected void validateModel(ActionListener listener) { + ForecastModelValidationActionHandler modelValidationActionHandler = new ForecastModelValidationActionHandler( + clusterService, + client, + clientUtil, + (ActionListener) listener, + (Forecaster) config, + requestTimeout, + xContentRegistry, + searchFeatureDao, + validationType, + clock, + settings, + user + ); + modelValidationActionHandler.start(); + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/handler/ForecastIndexJobActionHandler.java b/src/main/java/org/opensearch/forecast/rest/handler/ForecastIndexJobActionHandler.java new file mode 100644 index 000000000..c746eba79 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/handler/ForecastIndexJobActionHandler.java @@ -0,0 +1,109 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.rest.handler; + +import static org.opensearch.forecast.model.ForecastTaskType.RUN_ONCE_TASK_TYPES; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_REQUEST_TIMEOUT; + +import java.util.List; + +import org.opensearch.OpenSearchStatusException; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.ExecuteForecastResultResponseRecorder; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.forecast.transport.ForecastProfileAction; +import org.opensearch.forecast.transport.ForecastResultAction; +import org.opensearch.forecast.transport.ForecastResultRequest; +import org.opensearch.forecast.transport.StopForecasterAction; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.rest.handler.IndexJobActionHandler; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.JobResponse; +import org.opensearch.timeseries.transport.ResultRequest; +import org.opensearch.transport.TransportService; + +public class ForecastIndexJobActionHandler extends + IndexJobActionHandler { + + public ForecastIndexJobActionHandler( + Client client, + ForecastIndexManagement indexManagement, + NamedXContentRegistry xContentRegistry, + ForecastTaskManager adTaskManager, + ExecuteForecastResultResponseRecorder recorder, + NodeStateManager nodeStateManager, + Settings settings + ) { + super( + client, + indexManagement, + xContentRegistry, + adTaskManager, + recorder, + ForecastResultAction.INSTANCE, + AnalysisType.FORECAST, + ForecastIndex.STATE.getIndexName(), + StopForecasterAction.INSTANCE, + nodeStateManager, + settings, + FORECAST_REQUEST_TIMEOUT + ); + } + + @Override + protected ResultRequest createResultRequest(String configID, long start, long end) { + return new ForecastResultRequest(configID, start, end); + } + + @Override + protected List getBatchConfigTaskTypes() { + return RUN_ONCE_TASK_TYPES; + } + + /** + * Stop config. + * For realtime, will set job as disabled. + * For run once, will set its task as inactive. + * + * @param configId config id + * @param historical stop historical analysis or not + * @param user user + * @param transportService transport service + * @param listener action listener + */ + @Override + public void stopConfig( + String configId, + boolean historical, + User user, + TransportService transportService, + ActionListener listener + ) { + // make sure forecaster exists + nodeStateManager.getConfig(configId, AnalysisType.FORECAST, (config) -> { + if (!config.isPresent()) { + listener.onFailure(new OpenSearchStatusException(CommonMessages.FAIL_TO_FIND_CONFIG_MSG + configId, RestStatus.NOT_FOUND)); + return; + } + taskManager.getAndExecuteOnLatestConfigLevelTask(configId, ForecastTaskType.RUN_ONCE_TASK_TYPES, (task) -> { + // stop realtime forecaster job + stopJob(configId, transportService, listener); + }, transportService, true, listener); // true means reset task state as inactive/stopped state + }, listener); + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/handler/ForecastModelValidationActionHandler.java b/src/main/java/org/opensearch/forecast/rest/handler/ForecastModelValidationActionHandler.java new file mode 100644 index 000000000..f03c1fdc7 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/handler/ForecastModelValidationActionHandler.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.rest.handler; + +import java.time.Clock; + +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.rest.handler.ModelValidationActionHandler; +import org.opensearch.timeseries.transport.ValidateConfigResponse; +import org.opensearch.timeseries.util.SecurityClientUtil; + +public class ForecastModelValidationActionHandler extends ModelValidationActionHandler { + + public ForecastModelValidationActionHandler( + ClusterService clusterService, + Client client, + SecurityClientUtil clientUtil, + ActionListener listener, + Forecaster config, + TimeValue requestTimeout, + NamedXContentRegistry xContentRegistry, + SearchFeatureDao searchFeatureDao, + String validationType, + Clock clock, + Settings settings, + User user + ) { + super( + clusterService, + client, + clientUtil, + listener, + config, + requestTimeout, + xContentRegistry, + searchFeatureDao, + validationType, + clock, + settings, + user, + AnalysisType.FORECAST + ); + } + +} diff --git a/src/main/java/org/opensearch/forecast/rest/handler/IndexForecasterActionHandler.java b/src/main/java/org/opensearch/forecast/rest/handler/IndexForecasterActionHandler.java new file mode 100644 index 000000000..78cce651b --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/handler/IndexForecasterActionHandler.java @@ -0,0 +1,106 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.rest.handler; + +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.forecast.transport.IndexForecasterResponse; +import org.opensearch.rest.RestRequest; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +/** + * process create/update forecaster request + * + */ +public class IndexForecasterActionHandler extends AbstractForecasterActionHandler { + /** + * Constructor function. + * + * @param clusterService ClusterService + * @param client OS node client that executes actions on the local node + * @param transportService OS transport service + * @param forecastIndices forecast index manager + * @param forecasterId forecaster identifier + * @param seqNo sequence number of last modification + * @param primaryTerm primary term of last modification + * @param refreshPolicy refresh policy + * @param forecaster forecaster instance + * @param requestTimeout request time out configuration + * @param maxSingleStreamForecasters max single-stream forecasters allowed + * @param maxHCForecasters max HC forecasters allowed + * @param maxForecastFeatures max features allowed per forecaster + * @param maxCategoricalFields max number of categorical fields + * @param method Rest Method type + * @param xContentRegistry Registry which is used for XContentParser + * @param user User context + */ + public IndexForecasterActionHandler( + ClusterService clusterService, + Client client, + SecurityClientUtil clientUtil, + TransportService transportService, + ForecastIndexManagement forecastIndices, + String forecasterId, + Long seqNo, + Long primaryTerm, + WriteRequest.RefreshPolicy refreshPolicy, + Forecaster forecaster, + TimeValue requestTimeout, + Integer maxSingleStreamForecasters, + Integer maxHCForecasters, + Integer maxForecastFeatures, + Integer maxCategoricalFields, + RestRequest.Method method, + NamedXContentRegistry xContentRegistry, + User user, + ForecastTaskManager taskManager, + SearchFeatureDao searchFeatureDao, + Settings settings + ) { + super( + clusterService, + client, + clientUtil, + transportService, + forecastIndices, + forecasterId, + seqNo, + primaryTerm, + refreshPolicy, + forecaster, + requestTimeout, + maxSingleStreamForecasters, + maxHCForecasters, + maxForecastFeatures, + maxCategoricalFields, + method, + xContentRegistry, + user, + taskManager, + searchFeatureDao, + null, + false, + null, + settings + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/handler/ValidateForecasterActionHandler.java b/src/main/java/org/opensearch/forecast/rest/handler/ValidateForecasterActionHandler.java new file mode 100644 index 000000000..8ebf9ea5e --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/handler/ValidateForecasterActionHandler.java @@ -0,0 +1,101 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.rest.handler; + +import java.time.Clock; + +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.rest.RestRequest.Method; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.transport.ValidateConfigResponse; +import org.opensearch.timeseries.util.SecurityClientUtil; + +/** + * ValidateForecasterActionHandler extends the AbstractForecasterActionHandler to specifically handle + * the validation of forecasting configurations within OpenSearch. This class is responsible for initiating + * and executing the validation process for forecasters, ensuring that the configurations provided by the user + * meet the necessary criteria and constraints for successful forecasting operations. + * + * Key responsibilities include: + * - Performing thorough validation of forecaster configurations, including checks against maximum allowed + * configurations, feature constraints, and categorical field limitations. + * - Utilizing the SearchFeatureDao to validate the feasibility of feature queries included in the forecaster + * configuration, ensuring that they can be successfully executed within the OpenSearch environment. + * - Leveraging the broader framework provided by AbstractForecasterActionHandler to manage common tasks such + * as security checks, user context management, and interaction with forecast indices. + * + * Usage: + * This handler is invoked during the forecaster configuration validation process, typically through REST API + * calls made by users attempting to create or update forecasters. It is designed to provide immediate feedback + * on the validity of the proposed configuration, helping users to adjust their settings to meet the system's + * requirements and best practices. + * + * The ValidateForecasterActionHandler is instantiated with detailed configuration parameters, including limits + * on the number of features, categorical fields, and other critical settings. It then proceeds to validate these + * configurations against the existing system constraints and the data available in the specified indices. + * + * Example: + * The handler could be triggered by a REST call to validate a new forecaster configuration before its creation, + * ensuring that all specified settings and features are valid and that the forecaster is likely to operate + * successfully within the given constraints and data environment. + */ +public class ValidateForecasterActionHandler extends AbstractForecasterActionHandler { + + public ValidateForecasterActionHandler( + ClusterService clusterService, + Client client, + SecurityClientUtil clientUtil, + ForecastIndexManagement forecastIndices, + Config forecaster, + TimeValue requestTimeout, + Integer maxSingleStreamForecasters, + Integer maxHCForecasters, + Integer maxFeatures, + Integer maxCategoricalFields, + Method method, + NamedXContentRegistry xContentRegistry, + User user, + SearchFeatureDao searchFeatureDao, + String validationType, + Clock clock, + Settings settings + ) { + super( + clusterService, + client, + clientUtil, + null, + forecastIndices, + Config.NO_ID, + null, + null, + null, + forecaster, + requestTimeout, + maxSingleStreamForecasters, + maxHCForecasters, + maxFeatures, + maxCategoricalFields, + method, + xContentRegistry, + user, + null, + searchFeatureDao, + validationType, + true, + clock, + settings + ); + } + +} diff --git a/src/main/java/org/opensearch/forecast/settings/ForecastEnabledSetting.java b/src/main/java/org/opensearch/forecast/settings/ForecastEnabledSetting.java index 1db9bf340..b22ffc1cd 100644 --- a/src/main/java/org/opensearch/forecast/settings/ForecastEnabledSetting.java +++ b/src/main/java/org/opensearch/forecast/settings/ForecastEnabledSetting.java @@ -27,31 +27,12 @@ public class ForecastEnabledSetting extends DynamicNumericSetting { */ public static final String FORECAST_ENABLED = "plugins.forecast.enabled"; - public static final String FORECAST_BREAKER_ENABLED = "plugins.forecast.breaker.enabled"; - - public static final String FORECAST_DOOR_KEEPER_IN_CACHE_ENABLED = "plugins.forecast.door_keeper_in_cache.enabled";; - public static final Map> settings = unmodifiableMap(new HashMap>() { { /** * forecast enable/disable setting */ put(FORECAST_ENABLED, Setting.boolSetting(FORECAST_ENABLED, true, NodeScope, Dynamic)); - - /** - * forecast breaker enable/disable setting - */ - put(FORECAST_BREAKER_ENABLED, Setting.boolSetting(FORECAST_BREAKER_ENABLED, true, NodeScope, Dynamic)); - - /** - * We have a bloom filter placed in front of inactive entity cache to - * filter out unpopular items that are not likely to appear more - * than once. Whether this bloom filter is enabled or not. - */ - put( - FORECAST_DOOR_KEEPER_IN_CACHE_ENABLED, - Setting.boolSetting(FORECAST_DOOR_KEEPER_IN_CACHE_ENABLED, false, NodeScope, Dynamic) - ); } }); @@ -73,20 +54,4 @@ public static synchronized ForecastEnabledSetting getInstance() { public static boolean isForecastEnabled() { return ForecastEnabledSetting.getInstance().getSettingValue(ForecastEnabledSetting.FORECAST_ENABLED); } - - /** - * Whether forecast circuit breaker is enabled or not. If disabled, an open circuit breaker wouldn't cause an forecast job to be stopped. - * @return whether forecast circuit breaker is enabled or not. - */ - public static boolean isForecastBreakerEnabled() { - return ForecastEnabledSetting.getInstance().getSettingValue(ForecastEnabledSetting.FORECAST_BREAKER_ENABLED); - } - - /** - * If enabled, we filter out unpopular items that are not likely to appear more than once - * @return wWhether door keeper in cache is enabled or not. - */ - public static boolean isDoorKeeperInCacheEnabled() { - return ForecastEnabledSetting.getInstance().getSettingValue(ForecastEnabledSetting.FORECAST_DOOR_KEEPER_IN_CACHE_ENABLED); - } } diff --git a/src/main/java/org/opensearch/forecast/settings/ForecastSettings.java b/src/main/java/org/opensearch/forecast/settings/ForecastSettings.java index 8aeaeb6c3..be0975f2b 100644 --- a/src/main/java/org/opensearch/forecast/settings/ForecastSettings.java +++ b/src/main/java/org/opensearch/forecast/settings/ForecastSettings.java @@ -77,7 +77,7 @@ public final class ForecastSettings { public static final int MAX_FORECAST_FEATURES = 1; // ====================================== - // AD Index setting + // Index setting // ====================================== public static int FORECAST_MAX_UPDATE_RETRY_TIMES = 10_000; @@ -386,4 +386,8 @@ public final class ForecastSettings { public static final Setting FORECAST_MAX_MODEL_SIZE_PER_NODE = Setting .intSetting("plugins.forecast.max_model_size_per_node", 100, 1, 10_000, Setting.Property.NodeScope, Setting.Property.Dynamic); + // ====================================== + // ML + // ====================================== + public static final int MINIMUM_SHINLE_SIZE = 4; } diff --git a/src/main/java/org/opensearch/forecast/stats/ForecastModelsOnNodeSupplier.java b/src/main/java/org/opensearch/forecast/stats/ForecastModelsOnNodeSupplier.java new file mode 100644 index 000000000..fd53bbbdd --- /dev/null +++ b/src/main/java/org/opensearch/forecast/stats/ForecastModelsOnNodeSupplier.java @@ -0,0 +1,80 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.stats; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_MAX_MODEL_SIZE_PER_NODE; +import static org.opensearch.timeseries.ml.ModelState.LAST_CHECKPOINT_TIME_KEY; +import static org.opensearch.timeseries.ml.ModelState.LAST_USED_TIME_KEY; +import static org.opensearch.timeseries.ml.ModelState.MODEL_TYPE_KEY; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.caching.ForecastCacheProvider; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.timeseries.constant.CommonName; + +public class ForecastModelsOnNodeSupplier implements Supplier>> { + private ForecastCacheProvider forecastCache; + private volatile int forecastNumModelsToReturn; + + /** + * Set that contains the model stats that should be exposed. + */ + public static Set MODEL_STATE_STAT_KEYS = new HashSet<>( + Arrays + .asList( + CommonName.MODEL_ID_FIELD, + MODEL_TYPE_KEY, + CommonName.ENTITY_KEY, + LAST_USED_TIME_KEY, + LAST_CHECKPOINT_TIME_KEY, + ForecastCommonName.FORECASTER_ID_KEY + ) + ); + + /** + * Constructor + * + * @param forecastCache object that manages HC forecasters' models + * @param settings node settings accessor + * @param clusterService Cluster service accessor + */ + public ForecastModelsOnNodeSupplier(ForecastCacheProvider forecastCache, Settings settings, ClusterService clusterService) { + this.forecastCache = forecastCache; + this.forecastNumModelsToReturn = FORECAST_MAX_MODEL_SIZE_PER_NODE.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(FORECAST_MAX_MODEL_SIZE_PER_NODE, it -> this.forecastNumModelsToReturn = it); + } + + @Override + public List> get() { + Stream> forecastStream = forecastCache + .get() + .getAllModels() + .stream() + .limit(forecastNumModelsToReturn) + .map( + modelState -> modelState + .getModelStateAsMap() + .entrySet() + .stream() + .filter(entry -> MODEL_STATE_STAT_KEYS.contains(entry.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)) + ); + + return forecastStream.collect(Collectors.toList()); + } +} diff --git a/src/main/java/org/opensearch/forecast/stats/ForecastStats.java b/src/main/java/org/opensearch/forecast/stats/ForecastStats.java new file mode 100644 index 000000000..197043cde --- /dev/null +++ b/src/main/java/org/opensearch/forecast/stats/ForecastStats.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.stats; + +import java.util.Map; + +import org.opensearch.timeseries.stats.Stats; +import org.opensearch.timeseries.stats.TimeSeriesStat; + +public class ForecastStats extends Stats { + + public ForecastStats(Map> stats) { + super(stats); + } + +} diff --git a/src/main/java/org/opensearch/forecast/stats/suppliers/ForecastModelsOnNodeCountSupplier.java b/src/main/java/org/opensearch/forecast/stats/suppliers/ForecastModelsOnNodeCountSupplier.java new file mode 100644 index 000000000..7a8b5283d --- /dev/null +++ b/src/main/java/org/opensearch/forecast/stats/suppliers/ForecastModelsOnNodeCountSupplier.java @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.stats.suppliers; + +import java.util.function.Supplier; + +import org.opensearch.forecast.caching.ForecastCacheProvider; + +/** + * ModelsOnNodeCountSupplier provides the number of models a node contains + */ +public class ForecastModelsOnNodeCountSupplier implements Supplier { + private ForecastCacheProvider forecastCache; + + /** + * Constructor + * + * @param forecastCache object that manages models + */ + public ForecastModelsOnNodeCountSupplier(ForecastCacheProvider forecastCache) { + this.forecastCache = forecastCache; + } + + @Override + public Long get() { + return forecastCache.get().getAllModels().stream().count(); + } +} diff --git a/src/main/java/org/opensearch/forecast/task/ForecastTaskManager.java b/src/main/java/org/opensearch/forecast/task/ForecastTaskManager.java new file mode 100644 index 000000000..bc2c63002 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/task/ForecastTaskManager.java @@ -0,0 +1,521 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.task; + +import static org.opensearch.action.DocWriteResponse.Result.CREATED; +import static org.opensearch.forecast.constant.ForecastCommonMessages.FORECASTER_IS_RUNNING; +import static org.opensearch.forecast.indices.ForecastIndexManagement.ALL_FORECAST_RESULTS_INDEX_PATTERN; +import static org.opensearch.forecast.model.ForecastTask.FORECASTER_ID_FIELD; +import static org.opensearch.forecast.model.ForecastTaskType.REALTIME_TASK_TYPES; +import static org.opensearch.forecast.settings.ForecastSettings.DELETE_FORECAST_RESULT_WHEN_DELETE_FORECASTER; +import static org.opensearch.forecast.settings.ForecastSettings.MAX_OLD_TASK_DOCS_PER_FORECASTER; +import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.AD_BATCH_TASK_THREAD_POOL_NAME; +import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME; +import static org.opensearch.timeseries.model.TimeSeriesTask.TASK_ID_FIELD; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.index.query.TermsQueryBuilder; +import org.opensearch.index.reindex.DeleteByQueryAction; +import org.opensearch.index.reindex.DeleteByQueryRequest; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.common.exception.DuplicateTaskException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.function.BiCheckedFunction; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.function.ResponseTransformer; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.transport.JobResponse; +import org.opensearch.timeseries.util.ExceptionUtil; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.transport.TransportService; + +public class ForecastTaskManager extends + TaskManager { + private final Logger logger = LogManager.getLogger(ForecastTaskManager.class); + + public ForecastTaskManager( + TaskCacheManager forecastTaskCacheManager, + Client client, + NamedXContentRegistry xContentRegistry, + ForecastIndexManagement forecastIndices, + ClusterService clusterService, + Settings settings, + ThreadPool threadPool, + NodeStateManager nodeStateManager + ) { + super( + forecastTaskCacheManager, + clusterService, + client, + ForecastIndex.STATE.getIndexName(), + ForecastTaskType.REALTIME_TASK_TYPES, + Collections.emptyList(), + ForecastTaskType.RUN_ONCE_TASK_TYPES, + forecastIndices, + nodeStateManager, + AnalysisType.FORECAST, + xContentRegistry, + FORECASTER_ID_FIELD, + MAX_OLD_TASK_DOCS_PER_FORECASTER, + settings, + threadPool, + ALL_FORECAST_RESULTS_INDEX_PATTERN, + FORECAST_THREAD_POOL_NAME, + DELETE_FORECAST_RESULT_WHEN_DELETE_FORECASTER, + TaskState.INACTIVE + ); + } + + /** + * Init realtime task cache Realtime forecast depending on job scheduler to choose node (job coordinating node) + * to run forecast job. Nodes have primary or replica shard of the job index are candidate to run forecast job. + * Job scheduler will build hash ring on these candidate nodes and choose one to run forecast job. + * If forecast job index shard relocated, for example new node added into cluster, then job scheduler will + * rebuild hash ring and may choose different node to run forecast job. So we need to init realtime task cache + * on new forecast job coordinating node. + * + * If realtime task cache inited for the first time on this node, listener will return true; otherwise + * listener will return false. + * + * We don't clean up realtime task cache on old coordinating node as HourlyCron will clear cache on old coordinating node. + * + * @param forecasterId forecaster id + * @param forecaster forecaster + * @param transportService transport service + * @param listener listener + */ + @Override + public void initRealtimeTaskCacheAndCleanupStaleCache( + String forecasterId, + Config forecaster, + TransportService transportService, + ActionListener listener + ) { + try { + if (taskCacheManager.getRealtimeTaskCache(forecasterId) != null) { + listener.onResponse(false); + return; + } + + getAndExecuteOnLatestConfigLevelTask(forecasterId, REALTIME_TASK_TYPES, (forecastTaskOptional) -> { + if (forecastTaskOptional.isEmpty()) { + logger.debug("Can't find realtime task for forecaster {}, init realtime task cache directly", forecasterId); + ExecutorFunction function = () -> createNewTask( + forecaster, + null, + false, + forecaster.getUser(), + clusterService.localNode().getId(), + TaskState.CREATED, + ActionListener.wrap(r -> { + logger.info("Recreate realtime task successfully for forecaster {}", forecasterId); + taskCacheManager.initRealtimeTaskCache(forecasterId, forecaster.getIntervalInMilliseconds()); + listener.onResponse(true); + }, e -> { + logger.error("Failed to recreate realtime task for forecaster " + forecasterId, e); + listener.onFailure(e); + }) + ); + recreateRealtimeTaskBeforeExecuting(function, listener); + return; + } + + logger.info("Init realtime task cache for forecaster {}", forecasterId); + taskCacheManager.initRealtimeTaskCache(forecasterId, forecaster.getIntervalInMilliseconds()); + listener.onResponse(true); + }, transportService, false, listener); + } catch (Exception e) { + logger.error("Failed to init realtime task cache for " + forecasterId, e); + listener.onFailure(e); + } + } + + /** + * Update forecast task with specific fields. + * + * @param taskId forecast task id + * @param updatedFields updated fields, key: filed name, value: new value + */ + public void updateForecastTask(String taskId, Map updatedFields) { + updateForecastTask(taskId, updatedFields, ActionListener.wrap(response -> { + if (response.status() == RestStatus.OK) { + logger.debug("Updated forecast task successfully: {}, task id: {}", response.status(), taskId); + } else { + logger.error("Failed to update forecast task {}, status: {}", taskId, response.status()); + } + }, e -> { logger.error("Failed to update task: " + taskId, e); })); + } + + /** + * Update forecast task for specific fields. + * + * @param taskId task id + * @param updatedFields updated fields, key: filed name, value: new value + * @param listener action listener + */ + public void updateForecastTask(String taskId, Map updatedFields, ActionListener listener) { + UpdateRequest updateRequest = new UpdateRequest(ForecastIndex.STATE.getIndexName(), taskId); + Map updatedContent = new HashMap<>(); + updatedContent.putAll(updatedFields); + updatedContent.put(TimeSeriesTask.LAST_UPDATE_TIME_FIELD, Instant.now().toEpochMilli()); + updateRequest.doc(updatedContent); + updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.update(updateRequest, listener); + } + + private void recreateRealtimeTaskBeforeExecuting(ExecutorFunction function, ActionListener listener) { + if (indexManagement.doesStateIndexExist()) { + function.execute(); + } else { + // If forecast state index doesn't exist, create index and execute function. + indexManagement.initStateIndex(ActionListener.wrap(r -> { + if (r.isAcknowledged()) { + logger.info("Created {} with mappings.", ForecastIndex.STATE.getIndexName()); + function.execute(); + } else { + String error = String + .format(Locale.ROOT, CommonMessages.CREATE_INDEX_NOT_ACKNOWLEDGED, ForecastIndex.STATE.getIndexName()); + logger.warn(error); + listener.onFailure(new OpenSearchStatusException(error, RestStatus.INTERNAL_SERVER_ERROR)); + } + }, e -> { + if (ExceptionsHelper.unwrapCause(e) instanceof ResourceAlreadyExistsException) { + function.execute(); + } else { + logger.error("Failed to init anomaly detection state index", e); + listener.onFailure(e); + } + })); + } + } + + /** + * Poll deleted detector task from cache and delete its child tasks and AD results. + */ + @Override + public void cleanChildTasksAndResultsOfDeletedTask() { + if (!taskCacheManager.hasDeletedTask()) { + return; + } + threadPool.schedule(() -> { + String taskId = taskCacheManager.pollDeletedTask(); + if (taskId == null) { + return; + } + DeleteByQueryRequest deleteForecastResultsRequest = new DeleteByQueryRequest(ALL_FORECAST_RESULTS_INDEX_PATTERN); + deleteForecastResultsRequest.setQuery(new TermsQueryBuilder(TASK_ID_FIELD, taskId)); + client.execute(DeleteByQueryAction.INSTANCE, deleteForecastResultsRequest, ActionListener.wrap(res -> { + logger.debug("Successfully deleted forecast results of task " + taskId); + DeleteByQueryRequest deleteChildTasksRequest = new DeleteByQueryRequest(ForecastIndex.STATE.getIndexName()); + deleteChildTasksRequest.setQuery(new TermsQueryBuilder(TimeSeriesTask.PARENT_TASK_ID_FIELD, taskId)); + + client.execute(DeleteByQueryAction.INSTANCE, deleteChildTasksRequest, ActionListener.wrap(r -> { + logger.debug("Successfully deleted child tasks of task " + taskId); + cleanChildTasksAndResultsOfDeletedTask(); + }, e -> { logger.error("Failed to delete child tasks of task " + taskId, e); })); + }, ex -> { logger.error("Failed to delete forecast results for task " + taskId, ex); })); + }, TimeValue.timeValueSeconds(DEFAULT_MAINTAIN_INTERVAL_IN_SECONDS), AD_BATCH_TASK_THREAD_POOL_NAME); + } + + @Override + public void startHistorical( + Config config, + DateRange dateRange, + User user, + TransportService transportService, + ActionListener listener + ) { + // TODO Auto-generated method stub + + } + + @Override + protected TaskType getTaskType(Config config, DateRange dateRange, boolean runOnce) { + if (runOnce) { + return config.isHighCardinality() + ? ForecastTaskType.RUN_ONCE_FORECAST_HC_FORECASTER + : ForecastTaskType.RUN_ONCE_FORECAST_SINGLE_STREAM; + } else { + return config.isHighCardinality() + ? ForecastTaskType.REALTIME_FORECAST_HC_FORECASTER + : ForecastTaskType.REALTIME_FORECAST_SINGLE_STREAM; + } + } + + @Override + protected void createNewTask( + Config config, + DateRange dateRange, + boolean runOnce, + User user, + String coordinatingNode, + TaskState initialState, + ActionListener listener + ) { + String userName = user == null ? null : user.getName(); + Instant now = Instant.now(); + String taskType = getTaskType(config, dateRange, runOnce).name(); + ForecastTask.Builder forecastTaskBuilder = new ForecastTask.Builder() + .configId(config.getId()) + .forecaster((Forecaster) config) + .isLatest(true) + .taskType(taskType) + .executionStartTime(now) + .state(initialState.name()) + .lastUpdateTime(now) + .startedBy(userName) + .coordinatingNode(coordinatingNode) + .user(user); + + ResponseTransformer responseTransformer; + + final ForecastTask forecastTask; + + // used for run once + if (initialState == TaskState.INIT_TEST) { + forecastTask = forecastTaskBuilder.build(); + responseTransformer = (indexResponse) -> (T) forecastTask; + } else { + forecastTask = forecastTaskBuilder.taskProgress(0.0f).initProgress(0.0f).dateRange(dateRange).build(); + // used for real time + responseTransformer = (indexResponse) -> (T) new JobResponse(indexResponse.getId()); + } + + createTaskDirectly( + forecastTask, + r -> onIndexConfigTaskResponse( + r, + forecastTask, + (response, delegatedListener) -> cleanOldConfigTaskDocs(response, forecastTask, responseTransformer, delegatedListener), + listener + ), + listener + ); + + } + + @Override + public void cleanConfigCache( + TimeSeriesTask task, + TransportService transportService, + ExecutorFunction function, + ActionListener listener + ) { + // no op for forecaster as we rely on state ttl to auto clean it + // only execute function + function.execute(); + } + + @Override + protected boolean isHistoricalHCTask(TimeSeriesTask task) { + // we have no backtesting + return false; + } + + @Override + protected void onIndexConfigTaskResponse( + IndexResponse response, + ForecastTask forecastTask, + BiConsumer> function, + ActionListener listener + ) { + if (response == null || response.getResult() != CREATED) { + String errorMsg = ExceptionUtil.getShardsFailure(response); + listener.onFailure(new OpenSearchStatusException(errorMsg, response.status())); + return; + } + forecastTask.setTaskId(response.getId()); + ActionListener delegatedListener = ActionListener.wrap(r -> { listener.onResponse(r); }, e -> { + handleTaskException(forecastTask, e); + if (e instanceof DuplicateTaskException) { + listener.onFailure(new OpenSearchStatusException(FORECASTER_IS_RUNNING, RestStatus.BAD_REQUEST)); + } else { + // TODO: For historical forecast task, what to do if any exception happened? + // For realtime forecast, task cache will be inited when realtime job starts, check + // ForecastTaskManager#initRealtimeTaskCache for details. Here the + // realtime task cache not inited yet when create forecast task, so no need to cleanup. + listener.onFailure(e); + } + }); + // TODO: what to do if this is a historical task? + if (function != null) { + function.accept(response, delegatedListener); + } + } + + @Override + protected void runBatchResultAction( + IndexResponse response, + ForecastTask tsTask, + ResponseTransformer responseTransformer, + ActionListener listener + ) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Forecast does not support back testing yet."); + } + + @Override + protected BiCheckedFunction getTaskParser() { + return ForecastTask::parse; + } + + @Override + public void createRunOnceTaskAndCleanupStaleTasks( + String configId, + Config config, + TransportService transportService, + ActionListener listener + ) { + ForecastTaskType taskType = config.isHighCardinality() + ? ForecastTaskType.RUN_ONCE_FORECAST_HC_FORECASTER + : ForecastTaskType.RUN_ONCE_FORECAST_SINGLE_STREAM; + + try { + + if (indexManagement.doesStateIndexExist()) { + // If state index exist, check if latest task is running + getAndExecuteOnLatestConfigLevelTask(config.getId(), Arrays.asList(taskType), (task) -> { + if (!task.isPresent() || task.get().isDone()) { + updateLatestFlagOfOldTasksAndCreateNewTask(config, null, true, config.getUser(), TaskState.INIT_TEST, listener); + } else { + listener.onFailure(new OpenSearchStatusException("run once is on-going", RestStatus.BAD_REQUEST)); + } + }, transportService, true, listener); + } else { + // If state index doesn't exist, create index and execute forecast. + indexManagement.initStateIndex(ActionListener.wrap(r -> { + if (r.isAcknowledged()) { + logger.info("Created {} with mappings.", stateIndex); + updateLatestFlagOfOldTasksAndCreateNewTask(config, null, true, config.getUser(), TaskState.INIT_TEST, listener); + } else { + String error = String.format(Locale.ROOT, CommonMessages.CREATE_INDEX_NOT_ACKNOWLEDGED, stateIndex); + logger.warn(error); + listener.onFailure(new OpenSearchStatusException(error, RestStatus.INTERNAL_SERVER_ERROR)); + } + }, e -> { + if (ExceptionsHelper.unwrapCause(e) instanceof ResourceAlreadyExistsException) { + updateLatestFlagOfOldTasksAndCreateNewTask(config, null, true, config.getUser(), TaskState.INIT_TEST, listener); + } else { + logger.error("Failed to init anomaly detection state index", e); + listener.onFailure(e); + } + })); + } + } catch (Exception e) { + logger.error("Failed to start detector " + config.getId(), e); + listener.onFailure(e); + } + } + + @Override + public List getTaskTypes(DateRange dateRange, boolean resetLatestTaskStateFlag, boolean runOnce) { + if (runOnce) { + return ForecastTaskType.RUN_ONCE_TASK_TYPES; + } else { + return ForecastTaskType.REALTIME_TASK_TYPES; + } + } + + private void resetRunOnceConfigTaskState( + List runOnceTasks, + ExecutorFunction function, + TransportService transportService, + ActionListener listener + ) { + if (ParseUtils.isNullOrEmpty(runOnceTasks)) { + function.execute(); + return; + } + ForecastTask forecastTask = (ForecastTask) runOnceTasks.get(0); + resetTaskStateAsStopped(forecastTask, function, transportService, listener); + } + + /** + * Reset latest config task state. Will reset both historical and realtime tasks. + * [Important!] Make sure listener returns in function + * + * @param tasks tasks + * @param function consumer function + * @param transportService transport service + * @param listener action listener + * @param response type of action listener + */ + @Override + protected void resetLatestConfigTaskState( + List tasks, + Consumer> function, + TransportService transportService, + ActionListener listener + ) { + List runningRealtimeTasks = new ArrayList<>(); + List runningRunOnceTasks = new ArrayList<>(); + + for (TimeSeriesTask task : tasks) { + if (!task.isHistoricalEntityTask() && !task.isDone()) { + if (task.isRealTimeTask()) { + runningRealtimeTasks.add(task); + } else if (task.isRunOnceTask()) { + runningRunOnceTasks.add(task); + } + } + } + + resetRunOnceConfigTaskState( + runningRunOnceTasks, + () -> resetRealtimeConfigTaskState(runningRealtimeTasks, () -> function.accept(tasks), transportService, listener), + transportService, + listener + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/BuildInQuery.java b/src/main/java/org/opensearch/forecast/transport/BuildInQuery.java new file mode 100644 index 000000000..c36c930c6 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/BuildInQuery.java @@ -0,0 +1,14 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +public enum BuildInQuery { + MIN_CONFIDENCE_INTERVAL_WIDTH, + MAX_CONFIDENCE_INTERVAL_WIDTH, + MIN_VALUE_WITHIN_THE_HORIZON, + MAX_VALUE_WITHIN_THE_HORIZON, + DISTANCE_TO_THRESHOLD_VALUE +} diff --git a/src/main/java/org/opensearch/forecast/transport/DeleteForecastModelAction.java b/src/main/java/org/opensearch/forecast/transport/DeleteForecastModelAction.java new file mode 100644 index 000000000..eab816842 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/DeleteForecastModelAction.java @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; +import org.opensearch.timeseries.transport.DeleteModelResponse; + +public class DeleteForecastModelAction extends ActionType { + // Internal Action which is not used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.INTERNAL_ACTION_PREFIX + "model/delete"; + public static final DeleteForecastModelAction INSTANCE = new DeleteForecastModelAction(); + + private DeleteForecastModelAction() { + super(NAME, DeleteModelResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/DeleteForecastModelTransportAction.java b/src/main/java/org/opensearch/forecast/transport/DeleteForecastModelTransportAction.java new file mode 100644 index 000000000..fad3bdd12 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/DeleteForecastModelTransportAction.java @@ -0,0 +1,58 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.forecast.caching.ForecastCacheProvider; +import org.opensearch.forecast.caching.ForecastPriorityCache; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastCheckpointDao; +import org.opensearch.forecast.ml.ForecastColdStart; +import org.opensearch.forecast.ratelimit.ForecastCheckpointWriteWorker; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.BaseDeleteModelTransportAction; +import org.opensearch.transport.TransportService; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +public class DeleteForecastModelTransportAction extends + BaseDeleteModelTransportAction { + + @Inject + public DeleteForecastModelTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + NodeStateManager nodeStateManager, + ForecastCacheProvider cache, + TaskCacheManager taskCacheManager, + ForecastColdStart coldStarter + ) { + super( + threadPool, + clusterService, + transportService, + actionFilters, + nodeStateManager, + cache, + taskCacheManager, + coldStarter, + DeleteForecastModelAction.NAME + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/DeleteForecasterAction.java b/src/main/java/org/opensearch/forecast/transport/DeleteForecasterAction.java new file mode 100644 index 000000000..c18bc2327 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/DeleteForecasterAction.java @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.forecast.constant.ForecastCommonValue; + +public class DeleteForecasterAction extends ActionType { + // External Action which used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecaster/delete"; + public static final DeleteForecasterAction INSTANCE = new DeleteForecasterAction(); + + private DeleteForecasterAction() { + super(NAME, DeleteResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/DeleteForecasterTransportAction.java b/src/main/java/org/opensearch/forecast/transport/DeleteForecasterTransportAction.java new file mode 100644 index 000000000..bf6094934 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/DeleteForecasterTransportAction.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.BaseDeleteConfigTransportAction; +import org.opensearch.transport.TransportService; + +public class DeleteForecasterTransportAction extends + BaseDeleteConfigTransportAction { + + @Inject + public DeleteForecasterTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ClusterService clusterService, + Settings settings, + NamedXContentRegistry xContentRegistry, + NodeStateManager nodeStateManager, + ForecastTaskManager taskManager + ) { + super( + transportService, + actionFilters, + client, + clusterService, + settings, + xContentRegistry, + nodeStateManager, + taskManager, + DeleteForecasterAction.NAME, + ForecastSettings.FORECAST_FILTER_BY_BACKEND_ROLES, + AnalysisType.FORECAST, + ForecastIndex.STATE.getIndexName(), + Forecaster.class, + ForecastTaskType.RUN_ONCE_TASK_TYPES + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/EntityForecastResultAction.java b/src/main/java/org/opensearch/forecast/transport/EntityForecastResultAction.java new file mode 100644 index 000000000..77eec3d51 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/EntityForecastResultAction.java @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.forecast.constant.ForecastCommonValue; + +public class EntityForecastResultAction extends ActionType { + // Internal Action which is not used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.INTERNAL_ACTION_PREFIX + "entity/result"; + public static final EntityForecastResultAction INSTANCE = new EntityForecastResultAction(); + + private EntityForecastResultAction() { + super(NAME, AcknowledgedResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/EntityForecastResultTransportAction.java b/src/main/java/org/opensearch/forecast/transport/EntityForecastResultTransportAction.java new file mode 100644 index 000000000..d638b3bae --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/EntityForecastResultTransportAction.java @@ -0,0 +1,174 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ +package org.opensearch.forecast.transport; + +import java.util.Optional; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.forecast.caching.ForecastCacheProvider; +import org.opensearch.forecast.caching.ForecastPriorityCache; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastCheckpointDao; +import org.opensearch.forecast.ml.ForecastColdStart; +import org.opensearch.forecast.ml.ForecastModelManager; +import org.opensearch.forecast.ml.RCFCasterResult; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.ratelimit.ForecastCheckpointReadWorker; +import org.opensearch.forecast.ratelimit.ForecastCheckpointWriteWorker; +import org.opensearch.forecast.ratelimit.ForecastColdEntityWorker; +import org.opensearch.forecast.ratelimit.ForecastColdStartWorker; +import org.opensearch.forecast.ratelimit.ForecastResultWriteWorker; +import org.opensearch.forecast.ratelimit.ForecastSaveResultStrategy; +import org.opensearch.forecast.stats.ForecastStats; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.caching.CacheProvider; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.transport.EntityResultProcessor; +import org.opensearch.timeseries.transport.EntityResultRequest; +import org.opensearch.timeseries.util.ExceptionUtil; +import org.opensearch.transport.TransportService; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +/** + * Entry-point for HC forecast workflow. We have created multiple queues for coordinating + * the workflow. The overrall workflow is: + * 1. We store as many frequently used entity models in a cache as allowed by the + * memory limit (by default 10% heap). If an entity feature is a hit, we use the in-memory model + * to forecast and record results using the result write queue. + * 2. If an entity feature is a miss, we check if there is free memory or any other + * entity's model can be evacuated. An in-memory entity's frequency may be lower + * compared to the cache miss entity. If that's the case, we replace the lower + * frequency entity's model with the higher frequency entity's model. To load the + * higher frequency entity's model, we first check if a model exists on disk by + * sending a checkpoint read queue request. If there is a checkpoint, we load it + * to memory, perform forecast, and save the result using the result write queue. + * Otherwise, we enqueue a cold start request to the cold start queue for model + * training. If training is successful, we save the learned model via the checkpoint + * write queue. + * 3. We also have the cold entity queue configured for cold entities, and the model + * training and inference are connected by serial juxtaposition to limit resource usage. + */ +public class EntityForecastResultTransportAction extends HandledTransportAction { + + private static final Logger LOG = LogManager.getLogger(EntityForecastResultTransportAction.class); + private CircuitBreakerService circuitBreakerService; + private CacheProvider cache; + private final NodeStateManager stateManager; + private ThreadPool threadPool; + private EntityResultProcessor intervalDataProcessor; + + private final ForecastCacheProvider entityCache; + private final ForecastModelManager manager; + private final ForecastStats timeSeriesStats; + private final ForecastColdStartWorker entityColdStartWorker; + private final ForecastCheckpointReadWorker checkpointReadQueue; + private final ForecastColdEntityWorker coldEntityQueue; + private final ForecastSaveResultStrategy forecastSaveResultStategy; + + @Inject + public EntityForecastResultTransportAction( + ActionFilters actionFilters, + TransportService transportService, + ForecastModelManager manager, + CircuitBreakerService adCircuitBreakerService, + ForecastCacheProvider entityCache, + NodeStateManager stateManager, + ForecastIndexManagement indexUtil, + ForecastResultWriteWorker resultWriteQueue, + ForecastCheckpointReadWorker checkpointReadQueue, + ForecastColdEntityWorker coldEntityQueue, + ThreadPool threadPool, + ForecastColdStartWorker entityColdStartWorker, + ForecastStats timeSeriesStats, + ForecastSaveResultStrategy forecastSaveResultStategy + ) { + super(EntityForecastResultAction.NAME, transportService, actionFilters, EntityResultRequest::new); + this.circuitBreakerService = adCircuitBreakerService; + this.cache = entityCache; + this.stateManager = stateManager; + this.threadPool = threadPool; + this.intervalDataProcessor = null; + this.entityCache = entityCache; + this.manager = manager; + this.timeSeriesStats = timeSeriesStats; + this.entityColdStartWorker = entityColdStartWorker; + this.checkpointReadQueue = checkpointReadQueue; + this.coldEntityQueue = coldEntityQueue; + this.forecastSaveResultStategy = forecastSaveResultStategy; + } + + @Override + protected void doExecute(Task task, EntityResultRequest request, ActionListener listener) { + if (circuitBreakerService.isOpen()) { + threadPool + .executor(TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME) + .execute(() -> cache.get().releaseMemoryForOpenCircuitBreaker()); + listener.onFailure(new LimitExceededException(request.getConfigId(), CommonMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG, false)); + return; + } + + try { + String forecasterId = request.getConfigId(); + + Optional previousException = stateManager.fetchExceptionAndClear(forecasterId); + + if (previousException.isPresent()) { + Exception exception = previousException.get(); + LOG.error("Previous exception of {}: {}", forecasterId, exception); + if (exception instanceof EndRunException) { + EndRunException endRunException = (EndRunException) exception; + if (endRunException.isEndNow()) { + listener.onFailure(exception); + return; + } + } + + listener = ExceptionUtil.wrapListener(listener, exception, forecasterId); + } + + intervalDataProcessor = new EntityResultProcessor<>( + entityCache, + manager, + timeSeriesStats, + entityColdStartWorker, + checkpointReadQueue, + coldEntityQueue, + forecastSaveResultStategy, + StatNames.FORECAST_MODEL_CORRUTPION_COUNT + ); + + stateManager + .getConfig( + forecasterId, + request.getAnalysisType(), + intervalDataProcessor.onGetConfig(listener, forecasterId, request, previousException, request.getAnalysisType()) + ); + } catch (Exception exception) { + LOG.error("fail to get entity's forecasts", exception); + listener.onFailure(exception); + } + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastEntityProfileAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastEntityProfileAction.java new file mode 100644 index 000000000..de4b48ef4 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastEntityProfileAction.java @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; +import org.opensearch.timeseries.transport.EntityProfileResponse; + +public class ForecastEntityProfileAction extends ActionType { + // Internal Action which is not used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.INTERNAL_ACTION_PREFIX + "forecasters/profile/entity"; + public static final ForecastEntityProfileAction INSTANCE = new ForecastEntityProfileAction(); + + private ForecastEntityProfileAction() { + super(NAME, EntityProfileResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastEntityProfileTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastEntityProfileTransportAction.java new file mode 100644 index 000000000..6fe726c4e --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastEntityProfileTransportAction.java @@ -0,0 +1,53 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.caching.ForecastCacheProvider; +import org.opensearch.forecast.caching.ForecastPriorityCache; +import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.timeseries.cluster.HashRing; +import org.opensearch.timeseries.transport.BaseEntityProfileTransportAction; +import org.opensearch.transport.TransportService; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +/** + * Transport action to get entity profile. + */ +public class ForecastEntityProfileTransportAction extends + BaseEntityProfileTransportAction { + + @Inject + public ForecastEntityProfileTransportAction( + ActionFilters actionFilters, + TransportService transportService, + Settings settings, + HashRing hashRing, + ClusterService clusterService, + ForecastCacheProvider cacheProvider + ) { + super( + actionFilters, + transportService, + settings, + hashRing, + clusterService, + cacheProvider, + ForecastEntityProfileAction.NAME, + ForecastSettings.FORECAST_REQUEST_TIMEOUT + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastProfileAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastProfileAction.java new file mode 100644 index 000000000..35595a76f --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastProfileAction.java @@ -0,0 +1,33 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; +import org.opensearch.timeseries.transport.ProfileResponse; + +/** + * Profile transport action + */ +public class ForecastProfileAction extends ActionType { + // Internal Action which is not used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecasters/profile"; + public static final ForecastProfileAction INSTANCE = new ForecastProfileAction(); + + /** + * Constructor + */ + private ForecastProfileAction() { + super(NAME, ProfileResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastProfileTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastProfileTransportAction.java new file mode 100644 index 000000000..87c7ccdba --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastProfileTransportAction.java @@ -0,0 +1,63 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_MAX_MODEL_SIZE_PER_NODE; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.caching.ForecastCacheProvider; +import org.opensearch.forecast.caching.ForecastPriorityCache; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.transport.BaseProfileTransportAction; +import org.opensearch.transport.TransportService; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +/** + * This class contains the logic to extract the stats from the nodes + */ +public class ForecastProfileTransportAction extends BaseProfileTransportAction { + + /** + * Constructor + * + * @param threadPool ThreadPool to use + * @param clusterService ClusterService + * @param transportService TransportService + * @param actionFilters Action Filters + * @param cacheProvider cache provider + * @param settings Node settings accessor + */ + @Inject + public ForecastProfileTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + ForecastCacheProvider cacheProvider, + Settings settings + ) { + super( + ForecastProfileAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + cacheProvider, + settings, + FORECAST_MAX_MODEL_SIZE_PER_NODE + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastResultAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastResultAction.java new file mode 100644 index 000000000..ef9178a02 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastResultAction.java @@ -0,0 +1,25 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; + +public class ForecastResultAction extends ActionType { + // External Action which used for public facing RestAPIs or actions we need to assume cx's role. + public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecaster/run"; + public static final ForecastResultAction INSTANCE = new ForecastResultAction(); + + private ForecastResultAction() { + super(NAME, ForecastResultResponse::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkAction.java new file mode 100644 index 000000000..6394636b3 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkAction.java @@ -0,0 +1,34 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.constant.ForecastCommonValue; +import org.opensearch.timeseries.transport.ResultBulkResponse; +import org.opensearch.transport.TransportRequestOptions; + +public class ForecastResultBulkAction extends ActionType { + + // Internal Action which is not used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.INTERNAL_ACTION_PREFIX + "write/bulk"; + public static final ForecastResultBulkAction INSTANCE = new ForecastResultBulkAction(); + + private ForecastResultBulkAction() { + super(NAME, ResultBulkResponse::new); + } + + @Override + public TransportRequestOptions transportOptions(Settings settings) { + return TransportRequestOptions.builder().withType(TransportRequestOptions.Type.BULK).build(); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkRequest.java b/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkRequest.java new file mode 100644 index 000000000..730275b4d --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkRequest.java @@ -0,0 +1,30 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.ratelimit.ForecastResultWriteRequest; +import org.opensearch.timeseries.transport.ResultBulkRequest; + +public class ForecastResultBulkRequest extends ResultBulkRequest { + + public ForecastResultBulkRequest() { + super(); + } + + public ForecastResultBulkRequest(StreamInput in) throws IOException { + super(in, ForecastResultWriteRequest::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkTransportAction.java new file mode 100644 index 000000000..95422a98a --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkTransportAction.java @@ -0,0 +1,91 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_INDEX_PRESSURE_HARD_LIMIT; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_INDEX_PRESSURE_SOFT_LIMIT; + +import java.util.List; + +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.ratelimit.ForecastResultWriteRequest; +import org.opensearch.index.IndexingPressure; +import org.opensearch.timeseries.transport.ResultBulkTransportAction; +import org.opensearch.transport.TransportService; + +public class ForecastResultBulkTransportAction extends + ResultBulkTransportAction { + + @Inject + public ForecastResultBulkTransportAction( + TransportService transportService, + ActionFilters actionFilters, + IndexingPressure indexingPressure, + Settings settings, + ClusterService clusterService, + Client client + ) { + super( + ForecastResultBulkAction.NAME, + transportService, + actionFilters, + indexingPressure, + settings, + client, + FORECAST_INDEX_PRESSURE_SOFT_LIMIT.get(settings), + FORECAST_INDEX_PRESSURE_HARD_LIMIT.get(settings), + ForecastIndex.RESULT.getIndexName(), + ForecastResultBulkRequest::new + ); + clusterService.getClusterSettings().addSettingsUpdateConsumer(FORECAST_INDEX_PRESSURE_SOFT_LIMIT, it -> softLimit = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(FORECAST_INDEX_PRESSURE_HARD_LIMIT, it -> hardLimit = it); + } + + @Override + protected BulkRequest prepareBulkRequest(float indexingPressurePercent, ForecastResultBulkRequest request) { + BulkRequest bulkRequest = new BulkRequest(); + List results = request.getAnomalyResults(); + + if (indexingPressurePercent <= softLimit) { + for (ForecastResultWriteRequest resultWriteRequest : results) { + addResult(bulkRequest, resultWriteRequest.getResult(), resultWriteRequest.getResultIndex()); + } + } else if (indexingPressurePercent <= hardLimit) { + // exceed soft limit (60%) but smaller than hard limit (90%) + float acceptProbability = 1 - indexingPressurePercent; + for (ForecastResultWriteRequest resultWriteRequest : results) { + ForecastResult result = resultWriteRequest.getResult(); + if (random.nextFloat() < acceptProbability) { + addResult(bulkRequest, result, resultWriteRequest.getResultIndex()); + } + } + } else { + // if exceeding hard limit, only index error result + for (ForecastResultWriteRequest resultWriteRequest : results) { + ForecastResult result = resultWriteRequest.getResult(); + if (result.isHighPriority()) { + addResult(bulkRequest, result, resultWriteRequest.getResultIndex()); + } + } + } + + return bulkRequest; + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastResultProcessor.java b/src/main/java/org/opensearch/forecast/transport/ForecastResultProcessor.java new file mode 100644 index 000000000..8489bf47f --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastResultProcessor.java @@ -0,0 +1,109 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_MAX_ENTITIES_PER_INTERVAL; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_PAGE_SIZE; + +import java.util.List; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.stats.ForecastStats; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.cluster.HashRing; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.ResultProcessor; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +public class ForecastResultProcessor extends + ResultProcessor { + + private static final Logger LOG = LogManager.getLogger(ForecastResultProcessor.class); + + public ForecastResultProcessor( + Setting requestTimeoutSetting, + float intervalRatioForRequests, + String entityResultAction, + StatNames hcRequestCountStat, + Settings settings, + ClusterService clusterService, + ThreadPool threadPool, + HashRing hashRing, + NodeStateManager nodeStateManager, + TransportService transportService, + ForecastStats timeSeriesStats, + ForecastTaskManager realTimeTaskManager, + NamedXContentRegistry xContentRegistry, + Client client, + SecurityClientUtil clientUtil, + IndexNameExpressionResolver indexNameExpressionResolver, + Class transportResultResponseClazz, + FeatureManager featureManager, + AnalysisType analysisType, + boolean runOnce + ) { + super( + requestTimeoutSetting, + intervalRatioForRequests, + entityResultAction, + hcRequestCountStat, + settings, + clusterService, + threadPool, + TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME, + hashRing, + nodeStateManager, + transportService, + timeSeriesStats, + realTimeTaskManager, + xContentRegistry, + client, + clientUtil, + indexNameExpressionResolver, + transportResultResponseClazz, + featureManager, + FORECAST_MAX_ENTITIES_PER_INTERVAL, + FORECAST_PAGE_SIZE, + analysisType, + runOnce, + ForecastSingleStreamResultAction.NAME + ); + } + + @Override + protected ForecastResultResponse createResultResponse( + List features, + String error, + Long rcfTotalUpdates, + Long configInterval, + Boolean isHC, + String taskId + ) { + return new ForecastResultResponse(features, error, rcfTotalUpdates, configInterval, isHC, taskId); + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastResultRequest.java b/src/main/java/org/opensearch/forecast/transport/ForecastResultRequest.java new file mode 100644 index 000000000..074e975e4 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastResultRequest.java @@ -0,0 +1,71 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; +import java.util.Locale; + +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.Strings; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.transport.ResultRequest; + +public class ForecastResultRequest extends ResultRequest { + + public ForecastResultRequest(StreamInput in) throws IOException { + super(in); + in.readEnum(AnalysisType.class); + } + + public ForecastResultRequest(String forecastID, long start, long end) { + super(forecastID, start, end); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (Strings.isEmpty(configId)) { + validationException = addValidationError(ForecastCommonMessages.FORECASTER_ID_MISSING_MSG, validationException); + } + // at least end time should be set + if (end <= 0) { + validationException = addValidationError( + String.format(Locale.ROOT, "%s: start %d, end %d", ForecastCommonMessages.INVALID_TIMESTAMP_ERR_MSG, start, end), + validationException + ); + } + return validationException; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ForecastCommonName.ID_JSON_KEY, configId); + builder.field(CommonName.START_JSON_KEY, start); + builder.field(CommonName.END_JSON_KEY, end); + builder.endObject(); + return builder; + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastResultResponse.java b/src/main/java/org/opensearch/forecast/transport/ForecastResultResponse.java new file mode 100644 index 000000000..b1c7a8b47 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastResultResponse.java @@ -0,0 +1,221 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import java.io.IOException; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Optional; + +import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.transport.ResultResponse; + +public class ForecastResultResponse extends ResultResponse { + public static final String DATA_QUALITY_JSON_KEY = "dataQuality"; + public static final String ERROR_JSON_KEY = "error"; + public static final String FEATURES_JSON_KEY = "features"; + public static final String FEATURE_VALUE_JSON_KEY = "value"; + public static final String RCF_TOTAL_UPDATES_JSON_KEY = "rcfTotalUpdates"; + public static final String FORECASTER_INTERVAL_IN_MINUTES_JSON_KEY = "forecasterIntervalInMinutes"; + public static final String FORECAST_VALUES_JSON_KEY = "forecastValues"; + public static final String FORECAST_UPPERS_JSON_KEY = "forecastUppers"; + public static final String FORECAST_LOWERS_JSON_KEY = "forecastLowers"; + + private Double dataQuality; + private float[] forecastsValues; + private float[] forecastsUppers; + private float[] forecastsLowers; + + // used when returning an error/exception or empty result + public ForecastResultResponse( + List features, + String error, + Long rcfTotalUpdates, + Long forecasterIntervalInMinutes, + Boolean isHCForecaster, + String taskId + ) { + this(Double.NaN, features, error, rcfTotalUpdates, forecasterIntervalInMinutes, isHCForecaster, null, null, null, taskId); + } + + public ForecastResultResponse( + Double confidence, + List features, + String error, + Long rcfTotalUpdates, + Long forecasterIntervalInMinutes, + Boolean isHCForecaster, + float[] forecastsValues, + float[] forecastsUppers, + float[] forecastsLowers, + String taskId + ) { + super(features, error, rcfTotalUpdates, forecasterIntervalInMinutes, isHCForecaster, taskId); + this.dataQuality = confidence; + this.forecastsValues = forecastsValues; + this.forecastsUppers = forecastsUppers; + this.forecastsLowers = forecastsLowers; + this.taskId = taskId; + } + + public ForecastResultResponse(StreamInput in) throws IOException { + super(in); + dataQuality = in.readDouble(); + int size = in.readVInt(); + features = new ArrayList(); + for (int i = 0; i < size; i++) { + features.add(new FeatureData(in)); + } + error = in.readOptionalString(); + rcfTotalUpdates = in.readOptionalLong(); + configIntervalInMinutes = in.readOptionalLong(); + isHC = in.readOptionalBoolean(); + + if (in.readBoolean()) { + forecastsValues = in.readFloatArray(); + forecastsUppers = in.readFloatArray(); + forecastsLowers = in.readFloatArray(); + } else { + forecastsValues = null; + forecastsUppers = null; + forecastsLowers = null; + } + taskId = in.readOptionalString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(dataQuality); + out.writeVInt(features.size()); + for (FeatureData feature : features) { + feature.writeTo(out); + } + out.writeOptionalString(error); + out.writeOptionalLong(rcfTotalUpdates); + out.writeOptionalLong(configIntervalInMinutes); + out.writeOptionalBoolean(isHC); + + if (forecastsValues != null) { + if (forecastsUppers == null || forecastsLowers == null) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "null value: forecastsUppers: %s, forecastsLowers: %s", forecastsUppers, forecastsLowers) + ); + } + out.writeBoolean(true); + out.writeFloatArray(forecastsValues); + out.writeFloatArray(forecastsUppers); + out.writeFloatArray(forecastsLowers); + } else { + out.writeBoolean(false); + } + out.writeOptionalString(taskId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (dataQuality != null && !dataQuality.equals(Double.NaN)) { + builder.field(DATA_QUALITY_JSON_KEY, dataQuality); + } + if (error != null) { + builder.field(ERROR_JSON_KEY, error); + } + if (features != null && features.size() > 0) { + builder.startArray(FEATURES_JSON_KEY); + for (FeatureData feature : features) { + feature.toXContent(builder, params); + } + builder.endArray(); + } + if (rcfTotalUpdates != null) { + builder.field(RCF_TOTAL_UPDATES_JSON_KEY, rcfTotalUpdates); + } + if (forecastsValues != null) { + builder.field(FORECAST_VALUES_JSON_KEY, forecastsValues); + } + if (forecastsUppers != null) { + builder.field(FORECAST_UPPERS_JSON_KEY, forecastsUppers); + } + if (forecastsLowers != null) { + builder.field(FORECAST_LOWERS_JSON_KEY, forecastsLowers); + } + if (taskId != null) { + builder.field(CommonName.TASK_ID_FIELD, taskId); + } + // don't show interval as we only need to access it in memory to compute init estimated time remaining + builder.endObject(); + return builder; + } + + /** + * + * Convert ForecastResultResponse to ForecastResult + * + * @param forecastId Forecaster Id + * @param dataStartInstant data start time + * @param dataEndInstant data end time + * @param executionStartInstant execution start time + * @param executionEndInstant execution end time + * @param schemaVersion Schema version + * @param user Detector author + * @param error Error + * @return converted ForecastResult + */ + @Override + public List toIndexableResults( + String forecastId, + Instant dataStartInstant, + Instant dataEndInstant, + Instant executionStartInstant, + Instant executionEndInstant, + Integer schemaVersion, + User user, + String error + ) { + // Forecast interval in milliseconds + long forecasterIntervalMilli = Duration.between(dataStartInstant, dataEndInstant).toMillis(); + return ForecastResult + .fromRawRCFCasterResult( + forecastId, + forecasterIntervalMilli, + dataQuality, + features, + dataStartInstant, + dataEndInstant, + executionStartInstant, + executionEndInstant, + error, + Optional.empty(), + user, + schemaVersion, + null, // single-stream real-time has no model id + forecastsValues, + forecastsUppers, + forecastsLowers, + taskId // real time results have no task id + ); + } + + @Override + public boolean shouldSave() { + return super.shouldSave() || (forecastsValues != null && forecastsValues.length > 0); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastResultTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastResultTransportAction.java new file mode 100644 index 000000000..1db61e5d9 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastResultTransportAction.java @@ -0,0 +1,188 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import java.util.HashSet; +import java.util.Optional; +import java.util.Set; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastModelManager; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.forecast.stats.ForecastStats; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.cluster.HashRing; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.ResultProcessor; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +public class ForecastResultTransportAction extends HandledTransportAction { + + private static final Logger LOG = LogManager.getLogger(ForecastResultTransportAction.class); + private ResultProcessor resultProcessor; + private final Client client; + private CircuitBreakerService circuitBreakerService; + // Cache HC forecaster id. This is used to count HC failure stats. We can tell a forecaster + // is HC or not by checking if forecaster id exists in this field or not. Will add + // forecaster id to this field when start to run realtime detection and remove forecaster + // id once realtime detection done. + private final Set hcForecasters; + private final ForecastStats forecastStats; + private final NodeStateManager nodeStateManager; + private final Settings settings; + private final ClusterService clusterService; + private final ThreadPool threadPool; + private final HashRing hashRing; + private final TransportService transportService; + private final ForecastTaskManager realTimeTaskManager; + private final NamedXContentRegistry xContentRegistry; + private final SecurityClientUtil clientUtil; + private final IndexNameExpressionResolver indexNameExpressionResolver; + private final FeatureManager featureManager; + + @Inject + public ForecastResultTransportAction( + ActionFilters actionFilters, + TransportService transportService, + Settings settings, + Client client, + SecurityClientUtil clientUtil, + NodeStateManager nodeStateManager, + FeatureManager featureManager, + ForecastModelManager modelManager, + HashRing hashRing, + ClusterService clusterService, + IndexNameExpressionResolver indexNameExpressionResolver, + CircuitBreakerService circuitBreakerService, + ForecastStats forecastStats, + ThreadPool threadPool, + NamedXContentRegistry xContentRegistry, + ForecastTaskManager realTimeTaskManager + ) { + super(ForecastResultAction.NAME, transportService, actionFilters, ForecastResultRequest::new); + + this.settings = settings; + this.clusterService = clusterService; + this.threadPool = threadPool; + this.hashRing = hashRing; + this.transportService = transportService; + this.realTimeTaskManager = realTimeTaskManager; + this.xContentRegistry = xContentRegistry; + this.clientUtil = clientUtil; + this.indexNameExpressionResolver = indexNameExpressionResolver; + this.featureManager = featureManager; + + this.client = client; + this.circuitBreakerService = circuitBreakerService; + this.hcForecasters = new HashSet<>(); + this.forecastStats = forecastStats; + this.nodeStateManager = nodeStateManager; + + this.resultProcessor = null; + } + + @Override + protected void doExecute(Task task, ForecastResultRequest request, ActionListener listener) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + String forecastID = request.getConfigId(); + ActionListener original = listener; + listener = ActionListener.wrap(r -> { + hcForecasters.remove(forecastID); + original.onResponse(r); + }, e -> { + // If exception is TimeSeriesException and it should not be counted in stats, + // we will not count it in failure stats. + if (!(e instanceof TimeSeriesException) || ((TimeSeriesException) e).isCountedInStats()) { + forecastStats.getStat(StatNames.FORECAST_EXECUTE_FAIL_COUNT.getName()).increment(); + if (hcForecasters.contains(forecastID)) { + forecastStats.getStat(StatNames.FORECAST_HC_EXECUTE_FAIL_COUNT.getName()).increment(); + } + } + hcForecasters.remove(forecastID); + original.onFailure(e); + }); + + if (!ForecastEnabledSetting.isForecastEnabled()) { + throw new EndRunException(forecastID, ForecastCommonMessages.DISABLED_ERR_MSG, true).countedInStats(false); + } + + forecastStats.getStat(StatNames.FORECAST_EXECUTE_REQUEST_COUNT.getName()).increment(); + + if (circuitBreakerService.isOpen()) { + listener.onFailure(new LimitExceededException(forecastID, CommonMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG, false)); + return; + } + + this.resultProcessor = new ForecastResultProcessor( + ForecastSettings.FORECAST_REQUEST_TIMEOUT, + TimeSeriesSettings.INTERVAL_RATIO_FOR_REQUESTS, + EntityForecastResultAction.NAME, + StatNames.FORECAST_HC_EXECUTE_REQUEST_COUNT, + settings, + clusterService, + threadPool, + hashRing, + nodeStateManager, + transportService, + forecastStats, + realTimeTaskManager, + xContentRegistry, + client, + clientUtil, + indexNameExpressionResolver, + ForecastResultResponse.class, + featureManager, + AnalysisType.FORECAST, + false + ); + + try { + nodeStateManager + .getConfig( + forecastID, + AnalysisType.FORECAST, + resultProcessor.onGetConfig(listener, forecastID, request, Optional.of(hcForecasters)) + ); + } catch (Exception ex) { + ResultProcessor.handleExecuteException(ex, listener, forecastID); + } + } catch (Exception e) { + LOG.error(e); + listener.onFailure(e); + } + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceAction.java new file mode 100644 index 000000000..addbf3216 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceAction.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; + +public class ForecastRunOnceAction extends ActionType { + // External Action which used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecaster/runOnce"; + public static final ForecastRunOnceAction INSTANCE = new ForecastRunOnceAction(); + + private ForecastRunOnceAction() { + super(NAME, ForecastResultResponse::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceProfileAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceProfileAction.java new file mode 100644 index 000000000..1025d3dfc --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceProfileAction.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; +import org.opensearch.timeseries.transport.BooleanResponse; + +public class ForecastRunOnceProfileAction extends ActionType { + // External Action which used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecaster/runOnceProfile"; + public static final ForecastRunOnceProfileAction INSTANCE = new ForecastRunOnceProfileAction(); + + private ForecastRunOnceProfileAction() { + super(NAME, BooleanResponse::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceProfileRequest.java b/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceProfileRequest.java new file mode 100644 index 000000000..67f807f50 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceProfileRequest.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import java.io.IOException; + +import org.opensearch.action.support.nodes.BaseNodesRequest; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class ForecastRunOnceProfileRequest extends BaseNodesRequest { + private String configId; + + public ForecastRunOnceProfileRequest(StreamInput in) throws IOException { + super(in); + configId = in.readString(); + } + + /** + * Constructor + * + * @param configId config id + */ + public ForecastRunOnceProfileRequest(String configId, DiscoveryNode... nodes) { + super(nodes); + this.configId = configId; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(configId); + } + + public String getConfigId() { + return configId; + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceProfileTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceProfileTransportAction.java new file mode 100644 index 000000000..a9fe218a8 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceProfileTransportAction.java @@ -0,0 +1,93 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.nodes.TransportNodesAction; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.forecast.ratelimit.ForecastCheckpointReadWorker; +import org.opensearch.forecast.ratelimit.ForecastColdStartWorker; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.transport.BooleanNodeResponse; +import org.opensearch.timeseries.transport.BooleanResponse; +import org.opensearch.timeseries.transport.ForecastRunOnceProfileNodeRequest; +import org.opensearch.transport.TransportService; + +public class ForecastRunOnceProfileTransportAction extends + TransportNodesAction { + private final ForecastColdStartWorker coldStartWorker; + private final ForecastCheckpointReadWorker checkpointReadWorker; + + /** + * Constructor + * + * @param threadPool ThreadPool to use + * @param clusterService ClusterService + * @param transportService TransportService + * @param actionFilters Action Filters + * @param settings Node settings accessor + */ + @Inject + public ForecastRunOnceProfileTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + Settings settings, + ForecastColdStartWorker coldStartWorker, + ForecastCheckpointReadWorker checkpointReadWorker + ) { + super( + ForecastRunOnceProfileAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + ForecastRunOnceProfileRequest::new, + ForecastRunOnceProfileNodeRequest::new, + ThreadPool.Names.MANAGEMENT, + BooleanNodeResponse.class + ); + this.coldStartWorker = coldStartWorker; + this.checkpointReadWorker = checkpointReadWorker; + } + + @Override + protected BooleanResponse newResponse( + ForecastRunOnceProfileRequest request, + List responses, + List failures + ) { + return new BooleanResponse(clusterService.getClusterName(), responses, failures); + } + + @Override + protected ForecastRunOnceProfileNodeRequest newNodeRequest(ForecastRunOnceProfileRequest request) { + return new ForecastRunOnceProfileNodeRequest(request); + } + + @Override + protected BooleanNodeResponse newNodeResponse(StreamInput in) throws IOException { + return new BooleanNodeResponse(in); + } + + @Override + protected BooleanNodeResponse nodeOperation(ForecastRunOnceProfileNodeRequest request) { + String configId = request.getConfigId(); + + return new BooleanNodeResponse( + clusterService.localNode(), + coldStartWorker.hasConfigId(configId) || checkpointReadWorker.hasConfigId(configId) + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceTransportAction.java new file mode 100644 index 000000000..dcd132a05 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceTransportAction.java @@ -0,0 +1,354 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import static org.opensearch.core.rest.RestStatus.CONFLICT; +import static org.opensearch.core.rest.RestStatus.FORBIDDEN; +import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; +import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_FILTER_BY_BACKEND_ROLES; +import static org.opensearch.forecast.settings.ForecastSettings.MAX_FORECAST_FEATURES; +import static org.opensearch.forecast.settings.ForecastSettings.MAX_HC_FORECASTERS; +import static org.opensearch.forecast.settings.ForecastSettings.MAX_SINGLE_STREAM_FORECASTERS; +import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute; + +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastModelManager; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.forecast.settings.ForecastNumericSetting; +import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.forecast.stats.ForecastStats; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.ExistsQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.SearchHits; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.cluster.HashRing; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.ResultProcessor; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.ImmutableMap; + +public class ForecastRunOnceTransportAction extends HandledTransportAction { + + private static final Logger LOG = LogManager.getLogger(ForecastRunOnceTransportAction.class); + private ResultProcessor resultProcessor; + private final Client client; + private CircuitBreakerService circuitBreakerService; + private final NodeStateManager nodeStateManager; + + private final Settings settings; + private final ClusterService clusterService; + private final ThreadPool threadPool; + private final HashRing hashRing; + private final TransportService transportService; + private final ForecastTaskManager taskManager; + private final NamedXContentRegistry xContentRegistry; + private final SecurityClientUtil clientUtil; + private final IndexNameExpressionResolver indexNameExpressionResolver; + private final FeatureManager featureManager; + private final ForecastStats forecastStats; + private volatile Boolean filterByEnabled; + + protected volatile Integer maxSingleStreamForecasters; + protected volatile Integer maxHCForecasters; + protected volatile Integer maxForecastFeatures; + protected volatile Integer maxCategoricalFields; + + @Inject + public ForecastRunOnceTransportAction( + ActionFilters actionFilters, + TransportService transportService, + Settings settings, + Client client, + SecurityClientUtil clientUtil, + NodeStateManager nodeStateManager, + FeatureManager featureManager, + ForecastModelManager modelManager, + HashRing hashRing, + ClusterService clusterService, + IndexNameExpressionResolver indexNameExpressionResolver, + CircuitBreakerService circuitBreakerService, + ForecastStats forecastStats, + ThreadPool threadPool, + NamedXContentRegistry xContentRegistry, + ForecastTaskManager realTimeTaskManager + ) { + super(ForecastRunOnceAction.NAME, transportService, actionFilters, ForecastResultRequest::new); + + this.resultProcessor = null; + this.settings = settings; + this.clusterService = clusterService; + this.threadPool = threadPool; + this.hashRing = hashRing; + this.transportService = transportService; + this.taskManager = realTimeTaskManager; + this.xContentRegistry = xContentRegistry; + this.clientUtil = clientUtil; + this.indexNameExpressionResolver = indexNameExpressionResolver; + this.featureManager = featureManager; + this.forecastStats = forecastStats; + + this.client = client; + this.circuitBreakerService = circuitBreakerService; + this.nodeStateManager = nodeStateManager; + filterByEnabled = ForecastSettings.FORECAST_FILTER_BY_BACKEND_ROLES.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(FORECAST_FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); + + this.maxSingleStreamForecasters = MAX_SINGLE_STREAM_FORECASTERS.get(settings); + this.maxHCForecasters = MAX_HC_FORECASTERS.get(settings); + this.maxForecastFeatures = MAX_FORECAST_FEATURES; + this.maxCategoricalFields = ForecastNumericSetting.maxCategoricalFields(); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_SINGLE_STREAM_FORECASTERS, it -> maxSingleStreamForecasters = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_HC_FORECASTERS, it -> maxHCForecasters = it); + } + + @Override + protected void doExecute(Task task, ForecastResultRequest request, ActionListener listener) { + String forecastID = request.getConfigId(); + User user = ParseUtils.getUserContext(client); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + + resolveUserAndExecute( + user, + forecastID, + filterByEnabled, + listener, + (forecaster) -> executeRunOnce(forecastID, request, listener), + client, + clusterService, + xContentRegistry, + Forecaster.class + ); + } catch (Exception e) { + LOG.error(e); + listener.onFailure(new OpenSearchStatusException("Failed to run once forecaster " + forecastID, INTERNAL_SERVER_ERROR)); + } + } + + private void executeRunOnce(String forecastID, ForecastResultRequest request, ActionListener listener) { + if (!ForecastEnabledSetting.isForecastEnabled()) { + listener.onFailure(new OpenSearchStatusException(ForecastCommonMessages.DISABLED_ERR_MSG, FORBIDDEN)); + } + + if (circuitBreakerService.isOpen()) { + listener.onFailure(new OpenSearchStatusException(CommonMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG, SERVICE_UNAVAILABLE)); + return; + } + + client.execute(ForecastRunOnceProfileAction.INSTANCE, new ForecastRunOnceProfileRequest(forecastID), ActionListener.wrap(r -> { + if (r.isAnswerTrue()) { + listener + .onFailure( + new OpenSearchStatusException( + "cannot start a new test " + forecastID + " since current test hasn't finished.", + CONFLICT + ) + ); + } else { + nodeStateManager.getJob(forecastID, ActionListener.wrap(jobOptional -> { + if (jobOptional.isPresent() && jobOptional.get().isEnabled()) { + listener + .onFailure( + new OpenSearchStatusException("Cannot run once " + forecastID + " when real time job is running.", CONFLICT) + ); + return; + } + + triggerRunOnce(forecastID, request, listener); + }, e -> { + if (e instanceof IndexNotFoundException) { + triggerRunOnce(forecastID, request, listener); + } else { + LOG.error(e); + listener + .onFailure(new OpenSearchStatusException("Fail to verify if job " + forecastID + " starts or not.", CONFLICT)); + } + })); + } + }, e -> { + LOG.error(e); + listener.onFailure(new OpenSearchStatusException("Failed to run once forecaster " + forecastID, INTERNAL_SERVER_ERROR)); + })); + } + + private void checkIfRunOnceFinished(String forecastID, String taskId, AtomicInteger waitTimes) { + client.execute(ForecastRunOnceProfileAction.INSTANCE, new ForecastRunOnceProfileRequest(forecastID), ActionListener.wrap(r -> { + if (r.isAnswerTrue()) { + handleRunOnceNotFinished(forecastID, taskId, waitTimes); + } else { + handleRunOnceFinished(forecastID, taskId); + } + }, e -> { + LOG.error("Failed to profile run once of forecaster " + forecastID, e); + handleRunOnceNotFinished(forecastID, taskId, waitTimes); + })); + } + + private void handleRunOnceNotFinished(String forecastID, String taskId, AtomicInteger waitTimes) { + if (waitTimes.get() < 10) { + waitTimes.addAndGet(1); + threadPool + .schedule( + () -> checkIfRunOnceFinished(forecastID, taskId, waitTimes), + new TimeValue(10, TimeUnit.SECONDS), + TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME + ); + } else { + LOG.warn("Timed out run once of forecaster {}", forecastID); + updateTaskState(forecastID, taskId, TaskState.INACTIVE); + } + } + + private void handleRunOnceFinished(String forecastID, String taskId) { + LOG.info("Run once of forecaster {} finished", forecastID); + nodeStateManager.getConfig(forecastID, AnalysisType.FORECAST, ActionListener.wrap(configOptional -> { + if (configOptional.isEmpty()) { + updateTaskState(forecastID, taskId, TaskState.INACTIVE); + return; + } + checkForecastResults(forecastID, taskId, configOptional.get()); + }, e -> { + LOG.error("Fail to get config", e); + updateTaskState(forecastID, taskId, TaskState.INACTIVE); + })); + } + + private void checkForecastResults(String forecastID, String taskId, Config config) { + BoolQueryBuilder filterQuery = new BoolQueryBuilder(); + filterQuery.filter(QueryBuilders.termQuery(ForecastCommonName.FORECASTER_ID_KEY, forecastID)); + ExistsQueryBuilder forecastsExistFilter = QueryBuilders.existsQuery(ForecastResult.VALUE_FIELD); + filterQuery.must(forecastsExistFilter); + // run-once analysis result also stored in result index, which has non-null task_id. + filterQuery.filter(QueryBuilders.termQuery(CommonName.TASK_ID_FIELD, taskId)); + + SearchSourceBuilder source = new SearchSourceBuilder().query(filterQuery).size(1); + + SearchRequest request = new SearchRequest(ForecastIndexManagement.ALL_FORECAST_RESULTS_INDEX_PATTERN); + request.source(source); + if (config.getCustomResultIndex() != null) { + request.indices(config.getCustomResultIndex()); + } + + client.search(request, ActionListener.wrap(searchResponse -> { + SearchHits hits = searchResponse.getHits(); + if (hits.getTotalHits().value > 0) { + // has at least one result + updateTaskState(forecastID, taskId, TaskState.TEST_COMPLETE); + } else { + updateTaskState(forecastID, taskId, TaskState.INIT_TEST_FAILED); + } + }, e -> { + LOG.error("Fail to search result", e); + updateTaskState(forecastID, taskId, TaskState.INACTIVE); + })); + } + + private void updateTaskState(String forecastID, String taskId, TaskState state) { + taskManager.updateTask(taskId, ImmutableMap.of(TimeSeriesTask.STATE_FIELD, state.name()), ActionListener.wrap(updateResponse -> { + LOG.info("Updated forecaster task: {} state as: {} for forecaster: {}", taskId, state.name(), forecastID); + }, e -> { LOG.error("Failed to update forecaster task: {} for forecaster: {}", taskId, forecastID, e); })); + } + + private void triggerRunOnce(String forecastID, ForecastResultRequest request, ActionListener listener) { + try { + resultProcessor = new ForecastResultProcessor( + ForecastSettings.FORECAST_REQUEST_TIMEOUT, + TimeSeriesSettings.INTERVAL_RATIO_FOR_REQUESTS, + EntityForecastResultAction.NAME, + StatNames.FORECAST_HC_EXECUTE_REQUEST_COUNT, + settings, + clusterService, + threadPool, + hashRing, + nodeStateManager, + transportService, + forecastStats, + taskManager, + xContentRegistry, + client, + clientUtil, + indexNameExpressionResolver, + ForecastResultResponse.class, + featureManager, + AnalysisType.FORECAST, + true + ); + + ActionListener wrappedListener = ActionListener.wrap(r -> { + AtomicInteger waitTimes = new AtomicInteger(0); + + threadPool + .schedule( + () -> checkIfRunOnceFinished(forecastID, r.getTaskId(), waitTimes), + new TimeValue(10, TimeUnit.SECONDS), + TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME + ); + listener.onResponse(r); + }, e -> { + LOG.error("Failed to finish run once of forecaster " + forecastID, e); + listener.onFailure(new OpenSearchStatusException("Failed to run once forecaster " + forecastID, INTERNAL_SERVER_ERROR)); + }); + + nodeStateManager + .getConfig( + forecastID, + AnalysisType.FORECAST, + resultProcessor.onGetConfig(wrappedListener, forecastID, request, Optional.empty()) + ); + + // check for status + } catch (Exception ex) { + ResultProcessor.handleExecuteException(ex, listener, forecastID); + } + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastSingleStreamResultAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastSingleStreamResultAction.java new file mode 100644 index 000000000..6b8687b82 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastSingleStreamResultAction.java @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.forecast.constant.ForecastCommonValue; + +public class ForecastSingleStreamResultAction extends ActionType { + // Internal Action which is not used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.INTERNAL_ACTION_PREFIX + "singlestream/result"; + public static final ForecastSingleStreamResultAction INSTANCE = new ForecastSingleStreamResultAction(); + + private ForecastSingleStreamResultAction() { + super(NAME, AcknowledgedResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastSingleStreamResultTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastSingleStreamResultTransportAction.java new file mode 100644 index 000000000..9da80bdf2 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastSingleStreamResultTransportAction.java @@ -0,0 +1,88 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.common.inject.Inject; +import org.opensearch.forecast.caching.ForecastCacheBuffer; +import org.opensearch.forecast.caching.ForecastCacheProvider; +import org.opensearch.forecast.caching.ForecastPriorityCache; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastCheckpointDao; +import org.opensearch.forecast.ml.ForecastColdStart; +import org.opensearch.forecast.ml.ForecastModelManager; +import org.opensearch.forecast.ml.RCFCasterResult; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.ratelimit.ForecastCheckpointMaintainWorker; +import org.opensearch.forecast.ratelimit.ForecastCheckpointReadWorker; +import org.opensearch.forecast.ratelimit.ForecastCheckpointWriteWorker; +import org.opensearch.forecast.ratelimit.ForecastColdStartWorker; +import org.opensearch.forecast.ratelimit.ForecastResultWriteRequest; +import org.opensearch.forecast.ratelimit.ForecastResultWriteWorker; +import org.opensearch.forecast.ratelimit.ForecastSaveResultStrategy; +import org.opensearch.forecast.stats.ForecastStats; +import org.opensearch.forecast.transport.handler.ForecastIndexMemoryPressureAwareResultHandler; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.transport.AbstractSingleStreamResultTransportAction; +import org.opensearch.transport.TransportService; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +public class ForecastSingleStreamResultTransportAction extends + AbstractSingleStreamResultTransportAction { + + private static final Logger LOG = LogManager.getLogger(ForecastSingleStreamResultTransportAction.class); + + @Inject + public ForecastSingleStreamResultTransportAction( + TransportService transportService, + ActionFilters actionFilters, + CircuitBreakerService circuitBreakerService, + ForecastCacheProvider cache, + NodeStateManager stateManager, + ForecastCheckpointReadWorker checkpointReadQueue, + ForecastModelManager modelManager, + ForecastIndexManagement indexUtil, + ForecastResultWriteWorker resultWriteQueue, + ForecastStats stats, + ForecastColdStartWorker forecastColdStartQueue + ) { + super( + transportService, + actionFilters, + circuitBreakerService, + cache, + stateManager, + checkpointReadQueue, + modelManager, + indexUtil, + resultWriteQueue, + stats, + forecastColdStartQueue, + ForecastSingleStreamResultAction.NAME, + ForecastIndex.RESULT, + AnalysisType.FORECAST + ); + } + + @Override + public ForecastResultWriteRequest createResultWriteRequest(Config config, ForecastResult result) { + return new ForecastResultWriteRequest( + System.currentTimeMillis() + config.getIntervalInMilliseconds(), + config.getId(), + RequestPriority.MEDIUM, + result, + config.getCustomResultIndex() + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastStatsNodesAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastStatsNodesAction.java new file mode 100644 index 000000000..3d1bd793e --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastStatsNodesAction.java @@ -0,0 +1,34 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; +import org.opensearch.timeseries.transport.StatsNodesResponse; + +/** + * ADStatsNodesAction class + */ +public class ForecastStatsNodesAction extends ActionType { + + // Internal Action which is not used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.INTERNAL_ACTION_PREFIX + "stats/nodes"; + public static final ForecastStatsNodesAction INSTANCE = new ForecastStatsNodesAction(); + + /** + * Constructor + */ + private ForecastStatsNodesAction() { + super(NAME, StatsNodesResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastStatsNodesTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastStatsNodesTransportAction.java new file mode 100644 index 000000000..9d2ec58d7 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastStatsNodesTransportAction.java @@ -0,0 +1,36 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.forecast.stats.ForecastStats; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.transport.BaseStatsNodesTransportAction; +import org.opensearch.transport.TransportService; + +/** + * ForecastStatsNodesTransportAction contains the logic to extract the stats from the nodes + */ +public class ForecastStatsNodesTransportAction extends BaseStatsNodesTransportAction { + @Inject + public ForecastStatsNodesTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + ForecastStats stats + ) { + super(threadPool, clusterService, transportService, actionFilters, stats, ForecastStatsNodesAction.NAME); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecasterJobAction.java b/src/main/java/org/opensearch/forecast/transport/ForecasterJobAction.java new file mode 100644 index 000000000..bfd915288 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecasterJobAction.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; +import org.opensearch.timeseries.transport.JobResponse; + +public class ForecasterJobAction extends ActionType { + // External Action which used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecaster/jobmanagement"; + public static final ForecasterJobAction INSTANCE = new ForecasterJobAction(); + + private ForecasterJobAction() { + super(NAME, JobResponse::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecasterJobTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ForecasterJobTransportAction.java new file mode 100644 index 000000000..b6f35c27f --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecasterJobTransportAction.java @@ -0,0 +1,61 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import static org.opensearch.forecast.constant.ForecastCommonMessages.FAIL_TO_START_FORECASTER; +import static org.opensearch.forecast.constant.ForecastCommonMessages.FAIL_TO_STOP_FORECASTER; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_FILTER_BY_BACKEND_ROLES; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_REQUEST_TIMEOUT; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.ExecuteForecastResultResponseRecorder; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.rest.handler.ForecastIndexJobActionHandler; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.BaseJobTransportAction; +import org.opensearch.transport.TransportService; + +public class ForecasterJobTransportAction extends + BaseJobTransportAction { + + @Inject + public ForecasterJobTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ClusterService clusterService, + Settings settings, + NamedXContentRegistry xContentRegistry, + ForecastIndexJobActionHandler forecastIndexJobActionHandler + ) { + super( + transportService, + actionFilters, + client, + clusterService, + settings, + xContentRegistry, + FORECAST_FILTER_BY_BACKEND_ROLES, + ForecasterJobAction.NAME, + FORECAST_REQUEST_TIMEOUT, + FAIL_TO_START_FORECASTER, + FAIL_TO_STOP_FORECASTER, + Forecaster.class, + forecastIndexJobActionHandler + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/GetForecasterAction.java b/src/main/java/org/opensearch/forecast/transport/GetForecasterAction.java new file mode 100644 index 000000000..ef5d13540 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/GetForecasterAction.java @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; + +public class GetForecasterAction extends ActionType { + // External Action which used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecasters/get"; + public static final GetForecasterAction INSTANCE = new GetForecasterAction(); + + private GetForecasterAction() { + super(NAME, GetForecasterResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/GetForecasterResponse.java b/src/main/java/org/opensearch/forecast/transport/GetForecasterResponse.java new file mode 100644 index 000000000..5ac88a509 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/GetForecasterResponse.java @@ -0,0 +1,220 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import java.io.IOException; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.model.ForecasterProfile; +import org.opensearch.timeseries.model.EntityProfile; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.util.RestHandlerUtils; + +public class GetForecasterResponse extends ActionResponse implements ToXContentObject { + + public static final String FORECASTER_PROFILE = "forecasterProfile"; + public static final String ENTITY_PROFILE = "entityProfile"; + private String id; + private long version; + private long primaryTerm; + private long seqNo; + private Forecaster forecaster; + private Job forecastJob; + private ForecastTask realtimeTask; + private ForecastTask runOnceTask; + private RestStatus restStatus; + private ForecasterProfile forecasterProfile; + private EntityProfile entityProfile; + private boolean profileResponse; + private boolean returnJob; + private boolean returnTask; + + public GetForecasterResponse(StreamInput in) throws IOException { + super(in); + profileResponse = in.readBoolean(); + if (profileResponse) { + String profileType = in.readString(); + if (FORECASTER_PROFILE.equals(profileType)) { + forecasterProfile = new ForecasterProfile(in); + } else { + entityProfile = new EntityProfile(in); + } + } else { + id = in.readString(); + version = in.readLong(); + primaryTerm = in.readLong(); + seqNo = in.readLong(); + restStatus = in.readEnum(RestStatus.class); + forecaster = new Forecaster(in); + returnJob = in.readBoolean(); + if (returnJob) { + forecastJob = new Job(in); + } else { + forecastJob = null; + } + returnTask = in.readBoolean(); + if (in.readBoolean()) { + realtimeTask = new ForecastTask(in); + } else { + realtimeTask = null; + } + if (in.readBoolean()) { + runOnceTask = new ForecastTask(in); + } else { + runOnceTask = null; + } + } + + } + + public GetForecasterResponse( + String id, + long version, + long primaryTerm, + long seqNo, + Forecaster forecaster, + Job job, + boolean returnJob, + ForecastTask realtimeTask, + ForecastTask runOnceTask, + boolean returnTask, + RestStatus restStatus, + ForecasterProfile forecasterProfile, + EntityProfile entityProfile, + boolean profileResponse + ) { + this.id = id; + this.version = version; + this.primaryTerm = primaryTerm; + this.seqNo = seqNo; + this.forecaster = forecaster; + this.forecastJob = job; + this.returnJob = returnJob; + if (this.returnJob) { + this.forecastJob = job; + } else { + this.forecastJob = null; + } + this.returnTask = returnTask; + if (this.returnTask) { + this.realtimeTask = realtimeTask; + this.runOnceTask = runOnceTask; + } else { + this.realtimeTask = null; + this.runOnceTask = null; + } + this.restStatus = restStatus; + this.forecasterProfile = forecasterProfile; + this.entityProfile = entityProfile; + this.profileResponse = profileResponse; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + if (profileResponse) { + out.writeBoolean(true); // profileResponse is true + if (forecasterProfile != null) { + out.writeString(FORECASTER_PROFILE); + forecasterProfile.writeTo(out); + } else if (entityProfile != null) { + out.writeString(ENTITY_PROFILE); + entityProfile.writeTo(out); + } + } else { + out.writeBoolean(false); // profileResponse is false + out.writeString(id); + out.writeLong(version); + out.writeLong(primaryTerm); + out.writeLong(seqNo); + out.writeEnum(restStatus); + forecaster.writeTo(out); + if (returnJob) { + out.writeBoolean(true); // returnJob is true + forecastJob.writeTo(out); + } else { + out.writeBoolean(false); // returnJob is false + } + out.writeBoolean(returnTask); + if (realtimeTask != null) { + out.writeBoolean(true); + realtimeTask.writeTo(out); + } else { + out.writeBoolean(false); + } + if (runOnceTask != null) { + out.writeBoolean(true); + runOnceTask.writeTo(out); + } else { + out.writeBoolean(false); + } + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + if (profileResponse) { + if (forecasterProfile != null) { + forecasterProfile.toXContent(builder, params); + } else { + entityProfile.toXContent(builder, params); + } + } else { + builder.startObject(); + builder.field(RestHandlerUtils._ID, id); + builder.field(RestHandlerUtils._VERSION, version); + builder.field(RestHandlerUtils._PRIMARY_TERM, primaryTerm); + builder.field(RestHandlerUtils._SEQ_NO, seqNo); + builder.field(RestHandlerUtils.REST_STATUS, restStatus); + builder.field(RestHandlerUtils.FORECASTER, forecaster); + if (returnJob) { + builder.field(RestHandlerUtils.FORECASTER_JOB, forecastJob); + } + if (returnTask) { + builder.field(RestHandlerUtils.REALTIME_TASK, realtimeTask); + builder.field(RestHandlerUtils.RUN_ONCE_TASK, runOnceTask); + } + builder.endObject(); + } + return builder; + } + + public Job getForecastJob() { + return forecastJob; + } + + public ForecastTask getRealtimeTask() { + return realtimeTask; + } + + public ForecastTask getRunOnceTask() { + return runOnceTask; + } + + public Forecaster getForecaster() { + return forecaster; + } + + public ForecasterProfile getForecasterProfile() { + return forecasterProfile; + } + + public EntityProfile getEntityProfile() { + return entityProfile; + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/GetForecasterTransportAction.java b/src/main/java/org/opensearch/forecast/transport/GetForecasterTransportAction.java new file mode 100644 index 000000000..08bfa79d0 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/GetForecasterTransportAction.java @@ -0,0 +1,150 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import java.util.Optional; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.ForecastEntityProfileRunner; +import org.opensearch.forecast.ForecastProfileRunner; +import org.opensearch.forecast.ForecastTaskProfileRunner; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskProfile; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.model.ForecasterProfile; +import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.timeseries.model.EntityProfile; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.BaseGetConfigTransportAction; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +public class GetForecasterTransportAction extends + BaseGetConfigTransportAction { + + @Inject + public GetForecasterTransportAction( + TransportService transportService, + DiscoveryNodeFilterer nodeFilter, + ActionFilters actionFilters, + ClusterService clusterService, + Client client, + SecurityClientUtil clientUtil, + Settings settings, + NamedXContentRegistry xContentRegistry, + ForecastTaskManager forecastTaskManager, + ForecastTaskProfileRunner taskProfileRunner + ) { + super( + transportService, + nodeFilter, + actionFilters, + clusterService, + client, + clientUtil, + settings, + xContentRegistry, + forecastTaskManager, + GetForecasterAction.NAME, + Forecaster.class, + Forecaster.FORECAST_PARSE_FIELD_NAME, + ForecastTaskType.ALL_FORECAST_TASK_TYPES, + ForecastTaskType.REALTIME_FORECAST_HC_FORECASTER.name(), + ForecastTaskType.REALTIME_FORECAST_SINGLE_STREAM.name(), + ForecastTaskType.RUN_ONCE_FORECAST_HC_FORECASTER.name(), + ForecastTaskType.RUN_ONCE_FORECAST_SINGLE_STREAM.name(), + ForecastSettings.FORECAST_FILTER_BY_BACKEND_ROLES, + taskProfileRunner + ); + } + + @Override + protected GetForecasterResponse createResponse( + long version, + String id, + long primaryTerm, + long seqNo, + Forecaster config, + Job job, + boolean returnJob, + Optional realtimeTask, + Optional historicalTask, + boolean returnTask, + RestStatus restStatus, + ForecasterProfile forecasterProfile, + EntityProfile entityProfile, + boolean profileResponse + ) { + return new GetForecasterResponse( + id, + version, + primaryTerm, + seqNo, + config, + job, + returnJob, + realtimeTask.orElse(null), + historicalTask.orElse(null), + returnTask, + restStatus, + forecasterProfile, + entityProfile, + profileResponse + ); + } + + @Override + protected ForecastEntityProfileRunner createEntityProfileRunner( + Client client, + SecurityClientUtil clientUtil, + NamedXContentRegistry xContentRegistry, + long requiredSamples + ) { + return new ForecastEntityProfileRunner(client, clientUtil, xContentRegistry, TimeSeriesSettings.NUM_MIN_SAMPLES); + } + + @Override + protected ForecastProfileRunner createProfileRunner( + Client client, + SecurityClientUtil clientUtil, + NamedXContentRegistry xContentRegistry, + DiscoveryNodeFilterer nodeFilter, + long requiredSamples, + TransportService transportService, + ForecastTaskManager taskManager, + ForecastTaskProfileRunner taskProfileRunner + ) { + return new ForecastProfileRunner( + client, + clientUtil, + xContentRegistry, + nodeFilter, + TimeSeriesSettings.NUM_MIN_SAMPLES, + transportService, + taskManager, + taskProfileRunner + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/IndexForecasterAction.java b/src/main/java/org/opensearch/forecast/transport/IndexForecasterAction.java new file mode 100644 index 000000000..23613a89f --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/IndexForecasterAction.java @@ -0,0 +1,25 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; + +public class IndexForecasterAction extends ActionType { + // External Action which used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecaster/write"; + public static final IndexForecasterAction INSTANCE = new IndexForecasterAction(); + + private IndexForecasterAction() { + super(NAME, IndexForecasterResponse::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/IndexForecasterRequest.java b/src/main/java/org/opensearch/forecast/transport/IndexForecasterRequest.java new file mode 100644 index 000000000..60a3a1964 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/IndexForecasterRequest.java @@ -0,0 +1,144 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.rest.RestRequest; + +public class IndexForecasterRequest extends ActionRequest { + private String forecastID; + private long seqNo; + private long primaryTerm; + private WriteRequest.RefreshPolicy refreshPolicy; + private Forecaster forecaster; + private RestRequest.Method method; + private TimeValue requestTimeout; + private Integer maxSingleStreamForecasters; + private Integer maxHCForecasters; + private Integer maxForecastFeatures; + private Integer maxCategoricalFields; + + public IndexForecasterRequest(StreamInput in) throws IOException { + super(in); + forecastID = in.readString(); + seqNo = in.readLong(); + primaryTerm = in.readLong(); + refreshPolicy = in.readEnum(WriteRequest.RefreshPolicy.class); + forecaster = new Forecaster(in); + method = in.readEnum(RestRequest.Method.class); + requestTimeout = in.readTimeValue(); + maxSingleStreamForecasters = in.readInt(); + maxHCForecasters = in.readInt(); + maxForecastFeatures = in.readInt(); + maxCategoricalFields = in.readInt(); + } + + public IndexForecasterRequest( + String forecasterID, + long seqNo, + long primaryTerm, + WriteRequest.RefreshPolicy refreshPolicy, + Forecaster forecaster, + RestRequest.Method method, + TimeValue requestTimeout, + Integer maxSingleEntityAnomalyDetectors, + Integer maxMultiEntityAnomalyDetectors, + Integer maxAnomalyFeatures, + Integer maxCategoricalFields + ) { + super(); + this.forecastID = forecasterID; + this.seqNo = seqNo; + this.primaryTerm = primaryTerm; + this.refreshPolicy = refreshPolicy; + this.forecaster = forecaster; + this.method = method; + this.requestTimeout = requestTimeout; + this.maxSingleStreamForecasters = maxSingleEntityAnomalyDetectors; + this.maxHCForecasters = maxMultiEntityAnomalyDetectors; + this.maxForecastFeatures = maxAnomalyFeatures; + this.maxCategoricalFields = maxCategoricalFields; + } + + public String getForecasterID() { + return forecastID; + } + + public long getSeqNo() { + return seqNo; + } + + public long getPrimaryTerm() { + return primaryTerm; + } + + public WriteRequest.RefreshPolicy getRefreshPolicy() { + return refreshPolicy; + } + + public Forecaster getForecaster() { + return forecaster; + } + + public RestRequest.Method getMethod() { + return method; + } + + public TimeValue getRequestTimeout() { + return requestTimeout; + } + + public Integer getMaxSingleStreamForecasters() { + return maxSingleStreamForecasters; + } + + public Integer getMaxHCForecasters() { + return maxHCForecasters; + } + + public Integer getMaxForecastFeatures() { + return maxForecastFeatures; + } + + public Integer getMaxCategoricalFields() { + return maxCategoricalFields; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(forecastID); + out.writeLong(seqNo); + out.writeLong(primaryTerm); + out.writeEnum(refreshPolicy); + forecaster.writeTo(out); + out.writeEnum(method); + out.writeTimeValue(requestTimeout); + out.writeInt(maxSingleStreamForecasters); + out.writeInt(maxHCForecasters); + out.writeInt(maxForecastFeatures); + out.writeInt(maxCategoricalFields); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/IndexForecasterResponse.java b/src/main/java/org/opensearch/forecast/transport/IndexForecasterResponse.java new file mode 100644 index 000000000..85362a07d --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/IndexForecasterResponse.java @@ -0,0 +1,77 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import java.io.IOException; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.timeseries.util.RestHandlerUtils; + +public class IndexForecasterResponse extends ActionResponse implements ToXContentObject { + private final String id; + private final long version; + private final long seqNo; + private final long primaryTerm; + private final Forecaster forecaster; + private final RestStatus restStatus; + + public IndexForecasterResponse(StreamInput in) throws IOException { + super(in); + id = in.readString(); + version = in.readLong(); + seqNo = in.readLong(); + primaryTerm = in.readLong(); + forecaster = new Forecaster(in); + restStatus = in.readEnum(RestStatus.class); + } + + public IndexForecasterResponse(String id, long version, long seqNo, long primaryTerm, Forecaster forecaster, RestStatus restStatus) { + this.id = id; + this.version = version; + this.seqNo = seqNo; + this.primaryTerm = primaryTerm; + this.forecaster = forecaster; + this.restStatus = restStatus; + } + + public String getId() { + return id; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(id); + out.writeLong(version); + out.writeLong(seqNo); + out.writeLong(primaryTerm); + forecaster.writeTo(out); + out.writeEnum(restStatus); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder + .startObject() + .field(RestHandlerUtils._ID, id) + .field(RestHandlerUtils._VERSION, version) + .field(RestHandlerUtils._SEQ_NO, seqNo) + .field(RestHandlerUtils.FORECASTER, forecaster) + .field(RestHandlerUtils._PRIMARY_TERM, primaryTerm) + .endObject(); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/IndexForecasterTransportAction.java b/src/main/java/org/opensearch/forecast/transport/IndexForecasterTransportAction.java new file mode 100644 index 000000000..c9bc28b72 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/IndexForecasterTransportAction.java @@ -0,0 +1,223 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import static org.opensearch.forecast.constant.ForecastCommonMessages.FAIL_TO_CREATE_FORECASTER; +import static org.opensearch.forecast.constant.ForecastCommonMessages.FAIL_TO_UPDATE_FORECASTER; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_FILTER_BY_BACKEND_ROLES; +import static org.opensearch.timeseries.util.ParseUtils.checkFilterByBackendRoles; +import static org.opensearch.timeseries.util.ParseUtils.getConfig; +import static org.opensearch.timeseries.util.ParseUtils.getUserContext; +import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; + +import java.util.List; +import java.util.function.Consumer; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.rest.handler.IndexForecasterActionHandler; +import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.rest.RestRequest; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +public class IndexForecasterTransportAction extends HandledTransportAction { + private static final Logger LOG = LogManager.getLogger(IndexForecasterTransportAction.class); + private final Client client; + private final SecurityClientUtil clientUtil; + private final TransportService transportService; + private final ForecastIndexManagement forecastIndices; + private final ClusterService clusterService; + private final NamedXContentRegistry xContentRegistry; + private volatile Boolean filterByEnabled; + private final SearchFeatureDao searchFeatureDao; + private final ForecastTaskManager taskManager; + private final Settings settings; + + @Inject + public IndexForecasterTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + SecurityClientUtil clientUtil, + ClusterService clusterService, + Settings settings, + ForecastIndexManagement forecastIndices, + NamedXContentRegistry xContentRegistry, + SearchFeatureDao searchFeatureDao, + ForecastTaskManager taskManager + ) { + super(IndexForecasterAction.NAME, transportService, actionFilters, IndexForecasterRequest::new); + this.client = client; + this.clientUtil = clientUtil; + this.transportService = transportService; + this.clusterService = clusterService; + this.forecastIndices = forecastIndices; + this.xContentRegistry = xContentRegistry; + filterByEnabled = ForecastSettings.FORECAST_FILTER_BY_BACKEND_ROLES.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(FORECAST_FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); + this.searchFeatureDao = searchFeatureDao; + this.taskManager = taskManager; + this.settings = settings; + } + + @Override + protected void doExecute(Task task, IndexForecasterRequest request, ActionListener actionListener) { + User user = getUserContext(client); + String forecasterId = request.getForecasterID(); + RestRequest.Method method = request.getMethod(); + String errorMessage = method == RestRequest.Method.PUT ? FAIL_TO_UPDATE_FORECASTER : FAIL_TO_CREATE_FORECASTER; + ActionListener listener = wrapRestActionListener(actionListener, errorMessage); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + resolveUserAndExecute( + user, + forecasterId, + method, + listener, + (forecaster) -> forecastExecute(request, user, forecaster, context, listener) + ); + } catch (Exception e) { + LOG.error(e); + listener.onFailure(e); + } + } + + private void resolveUserAndExecute( + User requestedUser, + String forecasterId, + RestRequest.Method method, + ActionListener listener, + Consumer function + ) { + try { + // requestedUser == null means security is disabled or user is superadmin. In this case we don't need to + // check if request user have access to the forecaster or not. But we still need to get current forecaster for + // this case, so we can keep current forecaster's user data. + boolean filterByBackendRole = requestedUser == null ? false : filterByEnabled; + + // Check if user has backend roles + // When filter by is enabled, block users creating/updating detectors who do not have backend roles. + if (filterByEnabled) { + String error = checkFilterByBackendRoles(requestedUser); + if (error != null) { + listener.onFailure(new IllegalArgumentException(error)); + return; + } + } + if (method == RestRequest.Method.PUT) { + // Update forecaster request, check if user has permissions to update the forecaster + // Get forecaster and verify backend roles + getConfig( + requestedUser, + forecasterId, + listener, + function, + client, + clusterService, + xContentRegistry, + filterByBackendRole, + Forecaster.class + ); + } else { + // Create Detector. No need to get current detector. + function.accept(null); + } + } catch (Exception e) { + listener.onFailure(e); + } + } + + protected void forecastExecute( + IndexForecasterRequest request, + User user, + Forecaster currentForecaster, + ThreadContext.StoredContext storedContext, + ActionListener listener + ) { + forecastIndices.update(); + String forecasterId = request.getForecasterID(); + long seqNo = request.getSeqNo(); + long primaryTerm = request.getPrimaryTerm(); + WriteRequest.RefreshPolicy refreshPolicy = request.getRefreshPolicy(); + Forecaster forecaster = request.getForecaster(); + RestRequest.Method method = request.getMethod(); + TimeValue requestTimeout = request.getRequestTimeout(); + Integer maxSingleStreamForecasters = request.getMaxSingleStreamForecasters(); + Integer maxHCForecasters = request.getMaxHCForecasters(); + Integer maxForecastFeatures = request.getMaxForecastFeatures(); + Integer maxCategoricalFields = request.getMaxCategoricalFields(); + + storedContext.restore(); + checkIndicesAndExecute(forecaster.getIndices(), () -> { + // Don't replace forecaster's user when update detector + // Github issue: https://github.com/opensearch-project/anomaly-detection/issues/124 + User forecastUser = currentForecaster == null ? user : currentForecaster.getUser(); + IndexForecasterActionHandler indexForecasterActionHandler = new IndexForecasterActionHandler( + clusterService, + client, + clientUtil, + transportService, + forecastIndices, + forecasterId, + seqNo, + primaryTerm, + refreshPolicy, + forecaster, + requestTimeout, + maxSingleStreamForecasters, + maxHCForecasters, + maxForecastFeatures, + maxCategoricalFields, + method, + xContentRegistry, + forecastUser, + taskManager, + searchFeatureDao, + settings + ); + indexForecasterActionHandler.start(listener); + }, listener); + } + + private void checkIndicesAndExecute(List indices, ExecutorFunction function, ActionListener listener) { + SearchRequest searchRequest = new SearchRequest() + .indices(indices.toArray(new String[0])) + .source(new SearchSourceBuilder().size(1).query(QueryBuilders.matchAllQuery())); + client.search(searchRequest, ActionListener.wrap(r -> { function.execute(); }, e -> { + // Due to below issue with security plugin, we get security_exception when invalid index name is mentioned. + // https://github.com/opendistro-for-elasticsearch/security/issues/718 + LOG.error(e); + listener.onFailure(e); + })); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/RelationalOperation.java b/src/main/java/org/opensearch/forecast/transport/RelationalOperation.java new file mode 100644 index 000000000..853767ba0 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/RelationalOperation.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +public enum RelationalOperation { + GREATER_THAN(">"), + GREATER_THAN_OR_EQUAL_TO(">="), + LESS_THAN("<"), + LESS_THAN_OR_EQUAL_TO("<="); + + private final String symbol; + + RelationalOperation(String symbol) { + this.symbol = symbol; + } + + public String getSymbol() { + return this.symbol; + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/SearchForecastTasksAction.java b/src/main/java/org/opensearch/forecast/transport/SearchForecastTasksAction.java new file mode 100644 index 000000000..1dda427a0 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/SearchForecastTasksAction.java @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.forecast.constant.ForecastCommonValue; + +public class SearchForecastTasksAction extends ActionType { + // External Action which used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "tasks/search"; + public static final SearchForecastTasksAction INSTANCE = new SearchForecastTasksAction(); + + private SearchForecastTasksAction() { + super(NAME, SearchResponse::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/SearchForecastTasksTransportAction.java b/src/main/java/org/opensearch/forecast/transport/SearchForecastTasksTransportAction.java new file mode 100644 index 000000000..5545e7668 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/SearchForecastTasksTransportAction.java @@ -0,0 +1,41 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.forecast.transport.handler.ForecastSearchHandler; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +public class SearchForecastTasksTransportAction extends HandledTransportAction { + private ForecastSearchHandler searchHandler; + + @Inject + public SearchForecastTasksTransportAction( + TransportService transportService, + ActionFilters actionFilters, + ForecastSearchHandler searchHandler + ) { + super(SearchForecastTasksAction.NAME, transportService, actionFilters, SearchRequest::new); + this.searchHandler = searchHandler; + } + + @Override + protected void doExecute(Task task, SearchRequest request, ActionListener listener) { + searchHandler.search(request, listener); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/SearchForecasterAction.java b/src/main/java/org/opensearch/forecast/transport/SearchForecasterAction.java new file mode 100644 index 000000000..b4777a4b7 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/SearchForecasterAction.java @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.forecast.constant.ForecastCommonValue; + +public class SearchForecasterAction extends ActionType { + // External Action which used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecasters/search"; + public static final SearchForecasterAction INSTANCE = new SearchForecasterAction(); + + private SearchForecasterAction() { + super(NAME, SearchResponse::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/SearchForecasterInfoAction.java b/src/main/java/org/opensearch/forecast/transport/SearchForecasterInfoAction.java new file mode 100644 index 000000000..ba5aec5cc --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/SearchForecasterInfoAction.java @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; +import org.opensearch.timeseries.transport.SearchConfigInfoResponse; + +public class SearchForecasterInfoAction extends ActionType { + // External Action which used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecaster/info"; + public static final SearchForecasterInfoAction INSTANCE = new SearchForecasterInfoAction(); + + private SearchForecasterInfoAction() { + super(NAME, SearchConfigInfoResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/SearchForecasterInfoTransportAction.java b/src/main/java/org/opensearch/forecast/transport/SearchForecasterInfoTransportAction.java new file mode 100644 index 000000000..7131c65a4 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/SearchForecasterInfoTransportAction.java @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.timeseries.transport.BaseSearchConfigInfoTransportAction; +import org.opensearch.transport.TransportService; + +public class SearchForecasterInfoTransportAction extends BaseSearchConfigInfoTransportAction { + + @Inject + public SearchForecasterInfoTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) { + super(transportService, actionFilters, client, SearchForecasterInfoAction.NAME); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/SearchForecasterTransportAction.java b/src/main/java/org/opensearch/forecast/transport/SearchForecasterTransportAction.java new file mode 100644 index 000000000..b53d09b76 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/SearchForecasterTransportAction.java @@ -0,0 +1,41 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.forecast.transport.handler.ForecastSearchHandler; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +public class SearchForecasterTransportAction extends HandledTransportAction { + private ForecastSearchHandler searchHandler; + + @Inject + public SearchForecasterTransportAction( + TransportService transportService, + ActionFilters actionFilters, + ForecastSearchHandler searchHandler + ) { + super(SearchForecasterAction.NAME, transportService, actionFilters, SearchRequest::new); + this.searchHandler = searchHandler; + } + + @Override + protected void doExecute(Task task, SearchRequest request, ActionListener listener) { + searchHandler.search(request, listener); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/SearchTopForecastResultAction.java b/src/main/java/org/opensearch/forecast/transport/SearchTopForecastResultAction.java new file mode 100644 index 000000000..831c42b69 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/SearchTopForecastResultAction.java @@ -0,0 +1,25 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; + +public class SearchTopForecastResultAction extends ActionType { + // External Action which used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "result/topForecasts"; + public static final SearchTopForecastResultAction INSTANCE = new SearchTopForecastResultAction(); + + private SearchTopForecastResultAction() { + super(NAME, SearchTopForecastResultResponse::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/SearchTopForecastResultRequest.java b/src/main/java/org/opensearch/forecast/transport/SearchTopForecastResultRequest.java new file mode 100644 index 000000000..d605d1ff8 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/SearchTopForecastResultRequest.java @@ -0,0 +1,448 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import static org.opensearch.action.ValidateActions.addValidationError; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.index.query.AbstractQueryBuilder.parseInnerQueryBuilder; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.Strings; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParseException; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.model.FilterBy; +import org.opensearch.forecast.model.Subaggregation; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.timeseries.util.ParseUtils; + +/** + * Request for getting the top forecast results for HC forecasters. + *

+ * forecasterId, filterBy, and forecastFrom are required. + * One or two of buildInQuery, entity, threshold, filterQuery, subaggregations will be set to + * appropriate value depending on filterBy. + * Other parameters will be set to default values if left blank. + */ +public class SearchTopForecastResultRequest extends ActionRequest implements ToXContentObject { + + private static final String TASK_ID_FIELD = "task_id"; + private static final String SIZE_FIELD = "size"; + private static final String SPLIT_BY_FIELD = "split_by"; + private static final String FILTER_BY_FIELD = "filter_by"; + private static final String BUILD_IN_QUERY_FIELD = "build_in_query"; + private static final String THRESHOLD_FIELD = "threshold"; + private static final String RELATION_TO_THRESHOLD_FIELD = "relation_to_threshold"; + private static final String FILTER_QUERY_FIELD = "filter_query"; + public static final String SUBAGGREGATIONS_FIELD = "subaggregations"; + // forecast from looks for data end time + private static final String FORECAST_FROM_FIELD = "forecast_from"; + private static final String RUN_ONCE_FIELD = "run_once"; + + private String forecasterId; + private String taskId; + private boolean runOnce; + private Integer size; + private List splitBy; + private FilterBy filterBy; + private BuildInQuery buildInQuery; + private Float threshold; + private RelationalOperation relationToThreshold; + private QueryBuilder filterQuery; + private List subaggregations; + private Instant forecastFrom; + + public SearchTopForecastResultRequest(StreamInput in) throws IOException { + super(in); + forecasterId = in.readOptionalString(); + taskId = in.readOptionalString(); + runOnce = in.readBoolean(); + size = in.readOptionalInt(); + splitBy = in.readOptionalStringList(); + if (in.readBoolean()) { + filterBy = in.readEnum(FilterBy.class); + } else { + filterBy = null; + } + if (in.readBoolean()) { + buildInQuery = in.readEnum(BuildInQuery.class); + } else { + buildInQuery = null; + } + threshold = in.readOptionalFloat(); + if (in.readBoolean()) { + relationToThreshold = in.readEnum(RelationalOperation.class); + } else { + relationToThreshold = null; + } + if (in.readBoolean()) { + filterQuery = in.readNamedWriteable(QueryBuilder.class); + } else { + filterQuery = null; + } + if (in.readBoolean()) { + subaggregations = in.readList(Subaggregation::new); + } else { + subaggregations = null; + } + forecastFrom = in.readOptionalInstant(); + } + + public SearchTopForecastResultRequest( + String forecasterId, + String taskId, + boolean runOnce, + Integer size, + List splitBy, + FilterBy filterBy, + BuildInQuery buildInQuery, + Float threshold, + RelationalOperation relationToThreshold, + QueryBuilder filterQuery, + List subaggregations, + Instant forecastFrom + ) { + super(); + this.forecasterId = forecasterId; + this.taskId = taskId; + this.runOnce = runOnce; + this.size = size; + this.splitBy = splitBy; + this.filterBy = filterBy; + this.buildInQuery = buildInQuery; + this.threshold = threshold; + this.relationToThreshold = relationToThreshold; + this.filterQuery = filterQuery; + this.subaggregations = subaggregations; + this.forecastFrom = forecastFrom; + } + + public String getTaskId() { + return taskId; + } + + public boolean isRunOnce() { + return runOnce; + } + + public Integer getSize() { + return size; + } + + public String getForecasterId() { + return forecasterId; + } + + public List getSplitBy() { + return splitBy; + } + + public FilterBy getFilterBy() { + return filterBy; + } + + public BuildInQuery getBuildInQuery() { + return buildInQuery; + } + + public Float getThreshold() { + return threshold; + } + + public QueryBuilder getFilterQuery() { + return filterQuery; + } + + public List getSubaggregations() { + return subaggregations; + } + + public Instant getForecastFrom() { + return forecastFrom; + } + + public RelationalOperation getRelationToThreshold() { + return relationToThreshold; + } + + public void setTaskId(String taskId) { + this.taskId = taskId; + } + + public void setSize(Integer size) { + this.size = size; + } + + public void setForecasterId(String forecasterId) { + this.forecasterId = forecasterId; + } + + public void setRunOnce(boolean runOnce) { + this.runOnce = runOnce; + } + + public void setSplitBy(List splitBy) { + this.splitBy = splitBy; + } + + public void setFilterBy(FilterBy filterBy) { + this.filterBy = filterBy; + } + + public void setBuildInQuery(BuildInQuery buildInQuery) { + this.buildInQuery = buildInQuery; + } + + public void setThreshold(Float threshold) { + this.threshold = threshold; + } + + public void setFilterQuery(QueryBuilder filterQuery) { + this.filterQuery = filterQuery; + } + + public void setSubaggregations(List subaggregations) { + this.subaggregations = subaggregations; + } + + public void setForecastFrom(Instant forecastFrom) { + this.forecastFrom = forecastFrom; + } + + public void setRelationToThreshold(RelationalOperation relationToThreshold) { + this.relationToThreshold = relationToThreshold; + } + + public static SearchTopForecastResultRequest parse(XContentParser parser, String forecasterId) throws IOException { + String taskId = null; + Integer size = null; + List splitBy = null; + FilterBy filterBy = null; + BuildInQuery buildInQuery = null; + Float threshold = null; + RelationalOperation relationToThreshold = null; + QueryBuilder filterQuery = null; + List subaggregations = new ArrayList<>(); + Instant forecastFrom = null; + boolean runOnce = false; + + // "forecasterId" and "historical" params come from the original API path, not in the request body + // and therefore don't need to be parsed + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case TASK_ID_FIELD: + taskId = parser.text(); + break; + case SIZE_FIELD: + size = parser.intValue(); + break; + case SPLIT_BY_FIELD: + splitBy = Arrays.asList(parser.text().split(",")); + break; + case FILTER_BY_FIELD: + filterBy = FilterBy.valueOf(parser.text().toUpperCase(Locale.ROOT)); + break; + case BUILD_IN_QUERY_FIELD: + buildInQuery = BuildInQuery.valueOf(parser.text().toUpperCase(Locale.ROOT)); + break; + case THRESHOLD_FIELD: + threshold = parser.floatValue(); + break; + case RELATION_TO_THRESHOLD_FIELD: + relationToThreshold = RelationalOperation.valueOf(parser.text().toUpperCase(Locale.ROOT)); + break; + case FILTER_QUERY_FIELD: + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + try { + filterQuery = parseInnerQueryBuilder(parser); + } catch (ParsingException | XContentParseException e) { + throw new ValidationException( + "Custom query error in data filter: " + e.getMessage(), + ValidationIssueType.FILTER_QUERY, + ValidationAspect.FORECASTER + ); + } catch (IllegalArgumentException e) { + if (!e.getMessage().contains("empty clause")) { + throw e; + } + } + break; + case SUBAGGREGATIONS_FIELD: + try { + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + subaggregations.add(Subaggregation.parse(parser)); + } + } catch (Exception e) { + if (e instanceof ParsingException || e instanceof XContentParseException) { + throw new ValidationException( + "Custom query error: " + e.getMessage(), + ValidationIssueType.SUBAGGREGATION, + ValidationAspect.FORECASTER + ); + } + throw e; + } + break; + case FORECAST_FROM_FIELD: + forecastFrom = ParseUtils.toInstant(parser); + break; + case RUN_ONCE_FIELD: + runOnce = parser.booleanValue(); + break; + default: + parser.skipChildren(); + break; + } + } + + return new SearchTopForecastResultRequest( + forecasterId, + taskId, + runOnce, + size, + splitBy, + filterBy, + buildInQuery, + threshold, + relationToThreshold, + filterQuery, + subaggregations, + forecastFrom + ); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + // "forecasterId" and "historical" params come from the original API path, not in the request body + // and therefore don't need to be in the generated json + builder + .field(TASK_ID_FIELD, taskId) + .field(SPLIT_BY_FIELD, String.join(",", splitBy)) + .field(FILTER_BY_FIELD, filterBy.name()) + .field(RUN_ONCE_FIELD, runOnce); + + if (size != null) { + builder.field(SIZE_FIELD, size); + } + if (buildInQuery != null) { + builder.field(BUILD_IN_QUERY_FIELD, buildInQuery); + } + if (threshold != null) { + builder.field(THRESHOLD_FIELD, threshold); + } + if (relationToThreshold != null) { + builder.field(RELATION_TO_THRESHOLD_FIELD, relationToThreshold); + } + if (filterQuery != null) { + builder.field(FILTER_QUERY_FIELD, filterQuery); + } + if (subaggregations != null) { + builder.field(SUBAGGREGATIONS_FIELD, subaggregations.toArray()); + } + if (forecastFrom != null) { + builder.field(FORECAST_FROM_FIELD, forecastFrom.toString()); + } + + return builder.endObject(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeOptionalString(forecasterId); + out.writeOptionalString(taskId); + out.writeBoolean(runOnce); + out.writeOptionalInt(size); + out.writeOptionalStringCollection(splitBy); + if (filterBy == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + out.writeEnum(filterBy); + } + if (buildInQuery == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + out.writeEnum(buildInQuery); + } + out.writeOptionalFloat(threshold); + if (relationToThreshold == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + out.writeEnum(relationToThreshold); + } + if (filterQuery == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + out.writeNamedWriteable(filterQuery); + } + if (subaggregations == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + out.writeList(subaggregations); + } + out.writeOptionalInstant(forecastFrom); + } + + @Override + public ActionRequestValidationException validate() { + if (forecasterId == null) { + return addValidationError("Cannot find forecasterId", null); + } + if (filterBy == null) { + return addValidationError("Must set filter_by", null); + } + if (forecastFrom == null) { + return addValidationError("Must set forecast_from with epoch of milliseconds", null); + } + if (!((filterBy == FilterBy.BUILD_IN_QUERY) == (buildInQuery != null))) { + throw new IllegalArgumentException( + "If 'filter_by' is set to BUILD_IN_QUERY, a 'build_in_query' type must be provided. Otherwise, 'build_in_query' should not be given." + ); + } + + if (filterBy == FilterBy.BUILD_IN_QUERY + && buildInQuery == BuildInQuery.DISTANCE_TO_THRESHOLD_VALUE + && (threshold == null || relationToThreshold == null)) { + return addValidationError( + String + .format(Locale.ROOT, "Must set threshold and relation_to_threshold, but get %s and %s", threshold, relationToThreshold), + null + ); + } + if (filterBy == FilterBy.CUSTOM_QUERY && (subaggregations == null || subaggregations.isEmpty())) { + return addValidationError("Must set subaggregations", null); + } + if (!runOnce && !Strings.isNullOrEmpty(taskId)) { + return addValidationError("task id must not be set when run_once is false", null); + } + return null; + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/SearchTopForecastResultResponse.java b/src/main/java/org/opensearch/forecast/transport/SearchTopForecastResultResponse.java new file mode 100644 index 000000000..ac4e5ed56 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/SearchTopForecastResultResponse.java @@ -0,0 +1,55 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.forecast.model.ForecastResultBucket; + +/** + * Response for getting the top anomaly results for HC detectors + */ +public class SearchTopForecastResultResponse extends ActionResponse implements ToXContentObject { + public static final String BUCKETS_FIELD = "buckets"; + + private List forecastResultBuckets; + + public SearchTopForecastResultResponse(StreamInput in) throws IOException { + super(in); + forecastResultBuckets = in.readList(ForecastResultBucket::new); + } + + public SearchTopForecastResultResponse(List forecastResultBuckets) { + this.forecastResultBuckets = forecastResultBuckets; + } + + public List getForecastResultBuckets() { + return forecastResultBuckets; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeList(forecastResultBuckets); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + // no need to show bucket with empty value + return builder.startObject().field(BUCKETS_FIELD, forecastResultBuckets).endObject(); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/SearchTopForecastResultTransportAction.java b/src/main/java/org/opensearch/forecast/transport/SearchTopForecastResultTransportAction.java new file mode 100644 index 000000000..069914f86 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/SearchTopForecastResultTransportAction.java @@ -0,0 +1,605 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.forecast.indices.ForecastIndexManagement.ALL_FORECAST_RESULTS_INDEX_PATTERN; +import static org.opensearch.timeseries.util.RestHandlerUtils.createXContentParserFromRegistry; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.ResourceNotFoundException; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.GroupedActionListener; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.routing.Preference; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.forecast.model.FilterBy; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.model.ForecastResultBucket; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.model.Order; +import org.opensearch.forecast.model.Subaggregation; +import org.opensearch.forecast.transport.handler.ForecastSearchHandler; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.ExistsQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.search.SearchHit; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.AggregatorFactories; +import org.opensearch.search.aggregations.BucketOrder; +import org.opensearch.search.aggregations.bucket.MultiBucketsAggregation; +import org.opensearch.search.aggregations.bucket.terms.Terms; +import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; +import org.opensearch.search.aggregations.metrics.NumericMetricsAggregation; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.transport.GetConfigRequest; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.QueryUtil; +import org.opensearch.transport.TransportService; + +/** + * Transport action to fetch top forecast results for HC forecaster. + */ +public class SearchTopForecastResultTransportAction extends + HandledTransportAction { + private static final Logger logger = LogManager.getLogger(SearchTopForecastResultTransportAction.class); + private ForecastSearchHandler searchHandler; + // Number of buckets to return per page + private static final String defaultIndex = ALL_FORECAST_RESULTS_INDEX_PATTERN; + + private static final int DEFAULT_SIZE = 5; + private static final int MAX_SIZE = 50; + + protected static final String AGG_NAME_TERM = "term_agg"; + + private final Client client; + private NamedXContentRegistry xContent; + + @Inject + public SearchTopForecastResultTransportAction( + TransportService transportService, + ActionFilters actionFilters, + ForecastSearchHandler searchHandler, + Client client, + NamedXContentRegistry xContent + ) { + super(SearchTopForecastResultAction.NAME, transportService, actionFilters, SearchTopForecastResultRequest::new); + this.searchHandler = searchHandler; + this.client = client; + this.xContent = xContent; + } + + @Override + protected void doExecute(Task task, SearchTopForecastResultRequest request, ActionListener listener) { + GetConfigRequest getForecasterRequest = new GetConfigRequest( + request.getForecasterId(), + // The default version value used in + // org.opensearch.rest.action.RestActions.parseVersion() + -3L, + false, + true, + "", + "", + false, + null + ); + + client.execute(GetForecasterAction.INSTANCE, getForecasterRequest, ActionListener.wrap(getForecasterResponse -> { + // Make sure forecaster exists + if (getForecasterResponse.getForecaster() == null) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "No forecaster found with ID %s", request.getForecasterId())); + } + + Forecaster forecaster = getForecasterResponse.getForecaster(); + // Make sure forecaster is HC + List categoryFields = forecaster.getCategoryFields(); + if (categoryFields == null || categoryFields.isEmpty()) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "No category fields found for forecaster ID %s", request.getForecasterId()) + ); + } + + // Validating the category fields. Setting the list to be all category fields, + // unless otherwise specified + if (request.getSplitBy() == null || request.getSplitBy().isEmpty()) { + request.setSplitBy(categoryFields); + } else { + for (String categoryField : request.getSplitBy()) { + if (!categoryFields.contains(categoryField)) { + throw new IllegalArgumentException( + String + .format( + Locale.ROOT, + "Category field %s doesn't exist for forecaster ID %s", + categoryField, + request.getForecasterId() + ) + ); + } + } + } + + // Validating run once tasks if runOnce is true. Setting the task id to the + // latest run once task's ID, unless otherwise specified + if (request.isRunOnce() == true && Strings.isNullOrEmpty(request.getTaskId())) { + ForecastTask runOnceTask = getForecasterResponse.getRunOnceTask(); + if (runOnceTask == null) { + throw new ResourceNotFoundException( + String.format(Locale.ROOT, "No latest run once tasks found for forecaster ID %s", request.getForecasterId()) + ); + } + request.setTaskId(runOnceTask.getTaskId()); + } + + // Validating the size. If nothing passed use default + if (request.getSize() == null) { + request.setSize(DEFAULT_SIZE); + } else if (request.getSize() > MAX_SIZE) { + throw new IllegalArgumentException("Size cannot exceed " + MAX_SIZE); + } else if (request.getSize() <= 0) { + throw new IllegalArgumentException("Size must be a positive integer"); + } + + // Generating the search request which will contain the generated query + SearchRequest searchRequest = generateQuery(request, forecaster); + + // Adding search over any custom result indices + if (!Strings.isNullOrEmpty(forecaster.getCustomResultIndex())) { + searchRequest.indices(forecaster.getCustomResultIndex()); + } + // Utilizing the existing search() from SearchHandler to handle security + // permissions. Both user role + // and backend role filtering is handled in there, and any error will be + // propagated up and + // returned as a failure in this Listener. + // This same method is used for security handling for the search results action. + // Since this action + // is doing fundamentally the same thing, we can reuse the security logic here. + searchHandler.search(searchRequest, onSearchResponse(request, categoryFields, forecaster, listener)); + }, exception -> { + logger.error("Failed to get top forecast results", exception); + listener.onFailure(exception); + })); + + } + + private ActionListener onSearchResponse( + SearchTopForecastResultRequest request, + List categoryFields, + Forecaster forecaster, + ActionListener listener + ) { + return ActionListener.wrap(response -> { + Aggregations aggs = response.getAggregations(); + if (aggs == null) { + // empty result (e.g., cannot find forecasts within [forecast from, forecast from + horizon * interval] range). + listener.onResponse(new SearchTopForecastResultResponse(new ArrayList<>())); + return; + } + + Aggregation aggResults = aggs.get(AGG_NAME_TERM); + if (aggResults == null) { + // empty result + listener.onResponse(new SearchTopForecastResultResponse(new ArrayList<>())); + return; + } + + List buckets = ((MultiBucketsAggregation) aggResults).getBuckets(); + if (buckets == null || buckets.size() == 0) { + // empty result + listener + .onFailure( + new ResourceNotFoundException( + "No forecast value found. forecast_from timestamp or other parameters might be incorrect." + ) + ); + return; + } + + final GroupedActionListener groupListeneer = new GroupedActionListener<>(ActionListener.wrap(r -> { + // Keep original bucket order + // Sort the collection based on getBucketIndex() in ascending order + // and convert it to a List + List sortedList = r + .stream() + .sorted((a, b) -> Integer.compare(a.getBucketIndex(), b.getBucketIndex())) + .collect(Collectors.toList()); + listener.onResponse(new SearchTopForecastResultResponse(new ArrayList<>(sortedList))); + }, exception -> { + logger.warn("Failed to find valid aggregation result", exception); + listener + .onFailure(new OpenSearchStatusException("Failed to find valid aggregation result", RestStatus.INTERNAL_SERVER_ERROR)); + }), buckets.size()); + + for (int i = 0; i < buckets.size(); i++) { + MultiBucketsAggregation.Bucket bucket = buckets.get(i); + createForecastResultBucket(bucket, i, request, categoryFields, forecaster, groupListeneer); + } + }, e -> listener.onFailure(e)); + } + + public void createForecastResultBucket( + MultiBucketsAggregation.Bucket bucket, + int bucketIndex, + SearchTopForecastResultRequest request, + List categoryFields, + Forecaster forecaster, + ActionListener listener + ) { + Map aggregationsMap = new HashMap<>(); + for (Aggregation aggregation : bucket.getAggregations()) { + if (!(aggregation instanceof NumericMetricsAggregation.SingleValue)) { + listener + .onFailure( + new IllegalArgumentException( + String.format(Locale.ROOT, "A single value aggregation is required; received [{}]", aggregation) + ) + ); + } + NumericMetricsAggregation.SingleValue singleValueAggregation = (NumericMetricsAggregation.SingleValue) aggregation; + aggregationsMap.put(aggregation.getName(), singleValueAggregation.value()); + } + if (bucket instanceof Terms.Bucket) { + // our terms key is string + convertToCategoricalFieldValuePair( + (String) bucket.getKey(), + bucketIndex, + (int) bucket.getDocCount(), + aggregationsMap, + request, + categoryFields, + forecaster, + listener + ); + } else { + listener + .onFailure( + new IllegalArgumentException(String.format(Locale.ROOT, "We only use terms aggregation in top, but got %s", bucket)) + ); + } + } + + private void convertToCategoricalFieldValuePair( + String keyInSearchResponse, + int bucketIndex, + int docCount, + Map aggregations, + SearchTopForecastResultRequest request, + List categoryFields, + Forecaster forecaster, + ActionListener listener + ) { + List splitBy = request.getSplitBy(); + Map keys = new HashMap<>(); + // TODO: we only support two categorical fields. Expand to support more categorical fields + if (splitBy == null || splitBy.size() == categoryFields.size()) { + // use all categorical fields in splitBy. Convert entity id to concrete attributes. + findMatchingCategoricalFieldValuePair(keyInSearchResponse, docCount, aggregations, bucketIndex, forecaster, listener); + } else { + keys.put(splitBy.get(0), keyInSearchResponse); + listener.onResponse(new ForecastResultBucket(keys, docCount, aggregations, bucketIndex)); + } + } + + private void findMatchingCategoricalFieldValuePair( + String entityId, + int docCount, + Map aggregations, + int bucketIndex, + Forecaster forecaster, + ActionListener listener + ) { + TermQueryBuilder entityIdFilter = QueryBuilders.termQuery(CommonName.ENTITY_ID_FIELD, entityId); + + BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().filter(entityIdFilter); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(internalFilterQuery).size(1); + + String resultIndex = Strings.isNullOrEmpty(forecaster.getCustomResultIndex()) ? defaultIndex : forecaster.getCustomResultIndex(); + SearchRequest searchRequest = new SearchRequest() + .indices(resultIndex) + .source(searchSourceBuilder) + .preference(Preference.LOCAL.toString()); + + String failure = String.format(Locale.ROOT, "Cannot find a result matching entity id %s", entityId); + + final ActionListener searchResponseListener = ActionListener.wrap(searchResponse -> { + try { + SearchHit[] hits = searchResponse.getHits().getHits(); + if (hits.length == 0) { + listener.onFailure(new IllegalArgumentException(failure)); + return; + } + SearchHit searchHit = hits[0]; + try (XContentParser parser = createXContentParserFromRegistry(xContent, searchHit.getSourceRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Optional entity = ForecastResult.parse(parser).getEntity(); + if (entity.isEmpty()) { + listener.onFailure(new IllegalArgumentException(failure)); + return; + } + + listener + .onResponse( + new ForecastResultBucket(convertMap(entity.get().getAttributes()), docCount, aggregations, bucketIndex) + ); + } catch (Exception e) { + listener.onFailure(new IllegalArgumentException(failure, e)); + } + } catch (Exception e) { + listener.onFailure(new IllegalArgumentException(failure, e)); + } + }, e -> listener.onFailure(new IllegalArgumentException(failure, e))); + + searchHandler.search(searchRequest, searchResponseListener); + } + + private Map convertMap(Map stringMap) { + // Create a new Map and copy the entries + Map objectMap = new HashMap<>(); + for (Map.Entry entry : stringMap.entrySet()) { + objectMap.put(entry.getKey(), entry.getValue()); + } + return objectMap; + } + + /** + * Generates the entire search request to pass to the search handler + * + * @param request the request containing the all of the user-specified + * parameters needed to generate the request + * @param forecaster Forecaster config + * @return the SearchRequest to pass to the SearchHandler + */ + private SearchRequest generateQuery(SearchTopForecastResultRequest request, Forecaster forecaster) { + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + QueryBuilder rangeQuery = generateDateFilter(request, forecaster); + boolQueryBuilder = boolQueryBuilder.filter(rangeQuery); + + // we only look for documents containing forecasts + boolQueryBuilder.filter(new ExistsQueryBuilder(ForecastResult.VALUE_FIELD)); + + FilterBy filterBy = request.getFilterBy(); + switch (filterBy) { + case CUSTOM_QUERY: + if (request.getFilterQuery() != null) { + boolQueryBuilder = boolQueryBuilder.filter(request.getFilterQuery()); + } + break; + case BUILD_IN_QUERY: + QueryBuilder buildInSubFilter = generateBuildInSubFilter(request, forecaster); + if (buildInSubFilter != null) { + boolQueryBuilder = boolQueryBuilder.filter(buildInSubFilter); + } + break; + default: + throw new IllegalArgumentException(String.format(Locale.ROOT, "Unexpected filter by %s", request.getFilterBy())); + } + + boolQueryBuilder = generateTaskIdFilter(request, boolQueryBuilder); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(boolQueryBuilder).trackTotalHits(false).size(0); + + AggregationBuilder termsAgg = generateTermsAggregation(request, forecaster); + if (termsAgg != null) { + searchSourceBuilder = searchSourceBuilder.aggregation(termsAgg); + } + return new SearchRequest().indices(defaultIndex).source(searchSourceBuilder); + } + + private QueryBuilder generateBuildInSubFilter(SearchTopForecastResultRequest request, Forecaster forecaster) { + BuildInQuery buildInQuery = request.getBuildInQuery(); + switch (buildInQuery) { + case MIN_CONFIDENCE_INTERVAL_WIDTH: + case MAX_CONFIDENCE_INTERVAL_WIDTH: + // Include only documents where horizon_index is configured horizon (indicating the "latest" forecast). + return QueryBuilders.termQuery(ForecastResult.HORIZON_INDEX_FIELD, forecaster.getHorizon()); + case DISTANCE_TO_THRESHOLD_VALUE: + RangeQueryBuilder res = QueryBuilders.rangeQuery(ForecastResult.VALUE_FIELD); + Float threshold = request.getThreshold(); + switch (request.getRelationToThreshold()) { + case GREATER_THAN: + res = res.gt(threshold); + break; + case GREATER_THAN_OR_EQUAL_TO: + res = res.gte(threshold); + break; + case LESS_THAN: + res = res.lt(threshold); + break; + case LESS_THAN_OR_EQUAL_TO: + res = res.lte(threshold); + break; + } + return res; + default: + // no need to generate filter in cases like MIN_VALUE_WITHIN_THE_HORIZON + return null; + } + } + + /** + * Adding the date filter (needed regardless of filter by type) + * @param request top forecaster request + * @return filter for date + */ + private RangeQueryBuilder generateDateFilter(SearchTopForecastResultRequest request, Forecaster forecaster) { + // forecast from is data end time for forecast + // return QueryBuilders.termQuery(CommonName.DATA_END_TIME_FIELD, request.getForecastFrom().toEpochMilli()); + long startInclusive = request.getForecastFrom().toEpochMilli(); + long endExclusive = startInclusive + forecaster.getIntervalInMilliseconds(); + return QueryBuilders.rangeQuery(CommonName.DATA_END_TIME_FIELD).gte(startInclusive).lt(endExclusive); + } + + /** + * Generates the query with appropriate filters on the results indices. If + * fetching real-time results: must_not filter on task_id (because real-time + * results don't have a 'task_id' field associated with them in the document). + * If fetching historical results: term filter on the task_id. + * + * @param request the request containing the necessary fields to generate the query + * @param query Bool query to generate + * @return input bool query with added id related filter + */ + private BoolQueryBuilder generateTaskIdFilter(SearchTopForecastResultRequest request, BoolQueryBuilder query) { + if (!Strings.isNullOrEmpty(request.getTaskId())) { + query.filter(QueryBuilders.termQuery(CommonName.TASK_ID_FIELD, request.getTaskId())); + } else { + TermQueryBuilder forecasterIdFilter = QueryBuilders.termQuery(ForecastCommonName.FORECASTER_ID_KEY, request.getForecasterId()); + ExistsQueryBuilder taskIdExistsFilter = QueryBuilders.existsQuery(CommonName.TASK_ID_FIELD); + query.filter(forecasterIdFilter).mustNot(taskIdExistsFilter); + } + return query; + } + + /** + * Generates aggregation. Creating a list of sources based on the + * set of category fields, and sorting on the returned result buckets + * + * @param request the request containing the necessary fields to generate the + * aggregation + * @return the generated aggregation as an AggregationBuilder + */ + private TermsAggregationBuilder generateTermsAggregation(SearchTopForecastResultRequest request, Forecaster forecaster) { + // TODO: use multi_terms or composite when multiple categorical fields are required. + // Right now, since we only support two categorical fields, we either use terms + // aggregation for one categorical field or terms aggregation on entity_id for + // all categorical fields. + TermsAggregationBuilder termsAgg = AggregationBuilders.terms(AGG_NAME_TERM).size(request.getSize()); + + if (request.getSplitBy().size() == forecaster.getCategoryFields().size()) { + termsAgg = termsAgg.field(CommonName.ENTITY_ID_FIELD); + } else if (request.getSplitBy().size() == 1) { + termsAgg = termsAgg.script(QueryUtil.getScriptForCategoryField(request.getSplitBy().get(0))); + } + + List orders = new ArrayList<>(); + + FilterBy filterBy = request.getFilterBy(); + switch (filterBy) { + case BUILD_IN_QUERY: + Pair aggregationOrderPair = generateBuildInSubAggregation(request); + termsAgg.subAggregation(aggregationOrderPair.getLeft()); + orders.add(aggregationOrderPair.getRight()); + break; + case CUSTOM_QUERY: + // if customers defined customized aggregation + for (Subaggregation subaggregation : request.getSubaggregations()) { + AggregatorFactories.Builder internalAgg; + try { + internalAgg = ParseUtils.parseAggregators(subaggregation.getAggregation().toString(), xContent, null); + AggregationBuilder aggregation = internalAgg.getAggregatorFactories().iterator().next(); + termsAgg.subAggregation(aggregation); + orders.add(BucketOrder.aggregation(aggregation.getName(), subaggregation.getOrder() == Order.ASC ? true : false)); + } catch (IOException e) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Unexpected IOException when parsing %s", subaggregation), + e + ); + } + } + break; + default: + throw new IllegalArgumentException(String.format(Locale.ROOT, "Unexpected filter by %s", filterBy)); + } + + if (orders.isEmpty()) { + throw new IllegalArgumentException("Cannot have empty order list"); + } + + termsAgg.order(orders); + + return termsAgg; + } + + private Pair generateBuildInSubAggregation(SearchTopForecastResultRequest request) { + String aggregationName = null; + AggregationBuilder aggregation = null; + BucketOrder order = null; + BuildInQuery buildInQuery = request.getBuildInQuery(); + switch (buildInQuery) { + case MIN_CONFIDENCE_INTERVAL_WIDTH: + aggregationName = BuildInQuery.MIN_CONFIDENCE_INTERVAL_WIDTH.name(); + aggregation = AggregationBuilders.min(aggregationName).field(ForecastResult.INTERVAL_WIDTH_FIELD); + order = BucketOrder.aggregation(aggregationName, true); + return Pair.of(aggregation, order); + case MAX_CONFIDENCE_INTERVAL_WIDTH: + aggregationName = BuildInQuery.MAX_CONFIDENCE_INTERVAL_WIDTH.name(); + aggregation = AggregationBuilders.max(aggregationName).field(ForecastResult.INTERVAL_WIDTH_FIELD); + order = BucketOrder.aggregation(aggregationName, false); + return Pair.of(aggregation, order); + case MIN_VALUE_WITHIN_THE_HORIZON: + aggregationName = BuildInQuery.MIN_VALUE_WITHIN_THE_HORIZON.name(); + aggregation = AggregationBuilders.min(aggregationName).field(ForecastResult.VALUE_FIELD); + order = BucketOrder.aggregation(aggregationName, true); + return Pair.of(aggregation, order); + case MAX_VALUE_WITHIN_THE_HORIZON: + aggregationName = BuildInQuery.MAX_VALUE_WITHIN_THE_HORIZON.name(); + aggregation = AggregationBuilders.max(aggregationName).field(ForecastResult.VALUE_FIELD); + order = BucketOrder.aggregation(aggregationName, false); + return Pair.of(aggregation, order); + case DISTANCE_TO_THRESHOLD_VALUE: + RelationalOperation relationToThreshold = request.getRelationToThreshold(); + switch (relationToThreshold) { + case GREATER_THAN: + case GREATER_THAN_OR_EQUAL_TO: + aggregationName = BuildInQuery.DISTANCE_TO_THRESHOLD_VALUE.name(); + aggregation = AggregationBuilders.max(aggregationName).field(ForecastResult.VALUE_FIELD); + order = BucketOrder.aggregation(aggregationName, false); + return Pair.of(aggregation, order); + case LESS_THAN: + case LESS_THAN_OR_EQUAL_TO: + aggregationName = BuildInQuery.DISTANCE_TO_THRESHOLD_VALUE.name(); + aggregation = AggregationBuilders.min(aggregationName).field(ForecastResult.VALUE_FIELD); + order = BucketOrder.aggregation(aggregationName, true); + return Pair.of(aggregation, order); + default: + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Unexpected relation to threshold %s", relationToThreshold) + ); + } + default: + throw new IllegalArgumentException(String.format(Locale.ROOT, "Unexpected build in query type %s", buildInQuery)); + } + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/StatsForecasterAction.java b/src/main/java/org/opensearch/forecast/transport/StatsForecasterAction.java new file mode 100644 index 000000000..951850a67 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/StatsForecasterAction.java @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; +import org.opensearch.timeseries.transport.StatsTimeSeriesResponse; + +public class StatsForecasterAction extends ActionType { + // External Action which used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecaster/stats"; + public static final StatsForecasterAction INSTANCE = new StatsForecasterAction(); + + private StatsForecasterAction() { + super(NAME, StatsTimeSeriesResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/StatsForecasterTransportAction.java b/src/main/java/org/opensearch/forecast/transport/StatsForecasterTransportAction.java new file mode 100644 index 000000000..d2c6d3619 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/StatsForecasterTransportAction.java @@ -0,0 +1,129 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.forecast.stats.ForecastStats; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.ExistsQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.bucket.SingleBucketAggregation; +import org.opensearch.search.aggregations.bucket.filter.FilterAggregationBuilder; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.transport.BaseStatsTransportAction; +import org.opensearch.timeseries.transport.StatsRequest; +import org.opensearch.timeseries.transport.StatsResponse; +import org.opensearch.timeseries.util.MultiResponsesDelegateActionListener; +import org.opensearch.transport.TransportService; + +public class StatsForecasterTransportAction extends BaseStatsTransportAction { + public final Logger logger = LogManager.getLogger(StatsForecasterTransportAction.class); + private final String WITH_CATEGORY_FIELD = "with_category_field"; + private final String WITHOUT_CATEGORY_FIELD = "without_category_field"; + + @Inject + public StatsForecasterTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ForecastStats stats, + ClusterService clusterService + + ) { + super(transportService, actionFilters, client, stats, clusterService, StatsForecasterAction.NAME); + } + + /** + * Make async request to get the number of detectors in AnomalyDetector.ANOMALY_DETECTORS_INDEX if necessary + * and, onResponse, gather the cluster statistics + * + * @param client Client + * @param listener MultiResponsesDelegateActionListener to be used once both requests complete + * @param statsRequest Request containing stats to be retrieved + */ + @Override + public void getClusterStats(Client client, MultiResponsesDelegateActionListener listener, StatsRequest statsRequest) { + StatsResponse adStatsResponse = new StatsResponse(); + if ((statsRequest.getStatsToBeRetrieved().contains(StatNames.FORECASTER_COUNT.getName()) + || statsRequest.getStatsToBeRetrieved().contains(StatNames.SINGLE_STREAM_FORECASTER_COUNT.getName()) + || statsRequest.getStatsToBeRetrieved().contains(StatNames.HC_FORECASTER_COUNT.getName())) + && clusterService.state().getRoutingTable().hasIndex(CommonName.CONFIG_INDEX)) { + + // Create the query + ExistsQueryBuilder existsQuery = QueryBuilders.existsQuery(Config.CATEGORY_FIELD); + BoolQueryBuilder boolQuery = QueryBuilders.boolQuery().mustNot(existsQuery); + + FilterAggregationBuilder withFieldAgg = AggregationBuilders.filter(WITH_CATEGORY_FIELD, existsQuery); + FilterAggregationBuilder withoutFieldAgg = AggregationBuilders.filter(WITHOUT_CATEGORY_FIELD, boolQuery); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(QueryBuilders.matchAllQuery()); + searchSourceBuilder.size(0); + searchSourceBuilder.aggregation(withFieldAgg); + searchSourceBuilder.aggregation(withoutFieldAgg); + + SearchRequest searchRequest = new SearchRequest(CommonName.CONFIG_INDEX); + searchRequest.source(searchSourceBuilder); + + // Execute the query + client.search(searchRequest, ActionListener.wrap(searchResponse -> { + // Parse the response + SingleBucketAggregation withField = (SingleBucketAggregation) searchResponse.getAggregations().get(WITH_CATEGORY_FIELD); + SingleBucketAggregation withoutField = (SingleBucketAggregation) searchResponse + .getAggregations() + .get(WITHOUT_CATEGORY_FIELD); + if (statsRequest.getStatsToBeRetrieved().contains(StatNames.FORECASTER_COUNT.getName())) { + stats.getStat(StatNames.FORECASTER_COUNT.getName()).setValue(withField.getDocCount() + withoutField.getDocCount()); + } + if (statsRequest.getStatsToBeRetrieved().contains(StatNames.SINGLE_STREAM_FORECASTER_COUNT.getName())) { + stats.getStat(StatNames.SINGLE_STREAM_FORECASTER_COUNT.getName()).setValue(withoutField.getDocCount()); + } + if (statsRequest.getStatsToBeRetrieved().contains(StatNames.HC_FORECASTER_COUNT.getName())) { + stats.getStat(StatNames.HC_FORECASTER_COUNT.getName()).setValue(withField.getDocCount()); + } + adStatsResponse.setClusterStats(getClusterStatsMap(statsRequest)); + listener.onResponse(adStatsResponse); + }, e -> listener.onFailure(e))); + } else { + adStatsResponse.setClusterStats(getClusterStatsMap(statsRequest)); + listener.onResponse(adStatsResponse); + } + } + + /** + * Make async request to get the forecasting statistics from each node and, onResponse, set the + * StatsNodesResponse field of StatsResponse + * + * @param client Client + * @param listener MultiResponsesDelegateActionListener to be used once both requests complete + * @param statsRequest Request containing stats to be retrieved + */ + @Override + public void getNodeStats(Client client, MultiResponsesDelegateActionListener listener, StatsRequest statsRequest) { + client.execute(ForecastStatsNodesAction.INSTANCE, statsRequest, ActionListener.wrap(adStatsResponse -> { + StatsResponse restStatsResponse = new StatsResponse(); + restStatsResponse.setStatsNodesResponse(adStatsResponse); + listener.onResponse(restStatsResponse); + }, listener::onFailure)); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/StopForecasterAction.java b/src/main/java/org/opensearch/forecast/transport/StopForecasterAction.java new file mode 100644 index 000000000..9b38db2eb --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/StopForecasterAction.java @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; +import org.opensearch.timeseries.transport.StopConfigResponse; + +public class StopForecasterAction extends ActionType { + // Internal Action which is not used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.INTERNAL_ACTION_PREFIX + "forecaster/stop"; + public static final StopForecasterAction INSTANCE = new StopForecasterAction(); + + private StopForecasterAction() { + super(NAME, StopConfigResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/StopForecasterTransportAction.java b/src/main/java/org/opensearch/forecast/transport/StopForecasterTransportAction.java new file mode 100644 index 000000000..a3a35e3f2 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/StopForecasterTransportAction.java @@ -0,0 +1,85 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import static org.opensearch.forecast.constant.ForecastCommonMessages.FAIL_TO_STOP_FORECASTER; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.ExceptionsHelper; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.common.exception.InternalFailure; +import org.opensearch.timeseries.transport.DeleteModelRequest; +import org.opensearch.timeseries.transport.StopConfigRequest; +import org.opensearch.timeseries.transport.StopConfigResponse; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.transport.TransportService; + +public class StopForecasterTransportAction extends HandledTransportAction { + + private static final Logger LOG = LogManager.getLogger(StopForecasterTransportAction.class); + + private final Client client; + private final DiscoveryNodeFilterer nodeFilter; + + @Inject + public StopForecasterTransportAction( + TransportService transportService, + DiscoveryNodeFilterer nodeFilter, + ActionFilters actionFilters, + Client client + ) { + super(StopForecasterAction.NAME, transportService, actionFilters, StopConfigRequest::new); + this.client = client; + this.nodeFilter = nodeFilter; + } + + @Override + protected void doExecute(Task task, ActionRequest actionRequest, ActionListener listener) { + StopConfigRequest request = StopConfigRequest.fromActionRequest(actionRequest); + String configId = request.getConfigID(); + try { + DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes(); + DeleteModelRequest modelDeleteRequest = new DeleteModelRequest(configId, dataNodes); + client.execute(DeleteForecastModelAction.INSTANCE, modelDeleteRequest, ActionListener.wrap(response -> { + if (response.hasFailures()) { + LOG.warn("Cannot delete all models of forecaster {}", configId); + for (FailedNodeException failedNodeException : response.failures()) { + LOG.warn("Deleting models of node has exception", failedNodeException); + } + // if customers are using an updated detector and we haven't deleted old + // checkpoints, customer would have trouble + listener.onResponse(new StopConfigResponse(false)); + } else { + LOG.info("models of forecaster {} get deleted", configId); + listener.onResponse(new StopConfigResponse(true)); + } + }, exception -> { + LOG.error(new ParameterizedMessage("Deletion of forecaster [{}] has exception.", configId), exception); + listener.onResponse(new StopConfigResponse(false)); + })); + } catch (Exception e) { + LOG.error(FAIL_TO_STOP_FORECASTER + " " + configId, e); + Throwable cause = ExceptionsHelper.unwrapCause(e); + listener.onFailure(new InternalFailure(configId, FAIL_TO_STOP_FORECASTER, cause)); + } + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/SuggestForecasterParamAction.java b/src/main/java/org/opensearch/forecast/transport/SuggestForecasterParamAction.java new file mode 100644 index 000000000..bbcee1b72 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/SuggestForecasterParamAction.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; +import org.opensearch.timeseries.transport.SuggestConfigParamResponse; + +public class SuggestForecasterParamAction extends ActionType { + public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecaster/suggest"; + public static final SuggestForecasterParamAction INSTANCE = new SuggestForecasterParamAction(); + + private SuggestForecasterParamAction() { + super(NAME, SuggestConfigParamResponse::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/SuggestForecasterParamTransportAction.java b/src/main/java/org/opensearch/forecast/transport/SuggestForecasterParamTransportAction.java new file mode 100644 index 000000000..b632ef866 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/SuggestForecasterParamTransportAction.java @@ -0,0 +1,103 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_FILTER_BY_BACKEND_ROLES; + +import java.util.Set; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.ValidationException; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.transport.BaseSuggestConfigParamTransportAction; +import org.opensearch.timeseries.transport.SuggestConfigParamRequest; +import org.opensearch.timeseries.transport.SuggestConfigParamResponse; +import org.opensearch.timeseries.util.MultiResponsesDelegateActionListener; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +public class SuggestForecasterParamTransportAction extends BaseSuggestConfigParamTransportAction { + public static final Logger logger = LogManager.getLogger(SuggestForecasterParamTransportAction.class); + + @Inject + public SuggestForecasterParamTransportAction( + Client client, + SecurityClientUtil clientUtil, + ClusterService clusterService, + Settings settings, + ForecastIndexManagement anomalyDetectionIndices, + ActionFilters actionFilters, + TransportService transportService, + SearchFeatureDao searchFeatureDao + ) { + super( + SuggestForecasterParamAction.NAME, + client, + clientUtil, + clusterService, + settings, + actionFilters, + transportService, + FORECAST_FILTER_BY_BACKEND_ROLES, + AnalysisType.FORECAST, + searchFeatureDao + ); + } + + @Override + public void suggestExecute( + SuggestConfigParamRequest request, + User user, + ThreadContext.StoredContext storedContext, + ActionListener listener + ) { + storedContext.restore(); + // if type param isn't blank and isn't a part of possible validation types throws exception + Set params = getParametersToSuggest(request.getParam()); + if (params.isEmpty()) { + ValidationException validationException = new ValidationException(); + validationException.addValidationError(CommonMessages.NOT_EXISTENT_SUGGEST_TYPE); + listener.onFailure(validationException); + return; + } + + Config config = request.getConfig(); + MultiResponsesDelegateActionListener delegateListener = + new MultiResponsesDelegateActionListener( + listener, + params.size(), + CommonMessages.FAIL_SUGGEST_ERR_MSG + config.getId(), + false + ); + + if (params.contains(SuggestName.INTERVAL)) { + suggestInterval(request.getConfig(), user, request.getRequestTimeout(), delegateListener); + } + + if (params.contains(SuggestName.HISTORY)) { + delegateListener.onResponse(new SuggestConfigParamResponse.Builder().history(config.suggestHistory()).build()); + } + + if (params.contains(SuggestName.HORIZON)) { + Forecaster forecaster = (Forecaster) config; + delegateListener.onResponse(new SuggestConfigParamResponse.Builder().horizon(forecaster.suggestHorizon()).build()); + } + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/SuggestName.java b/src/main/java/org/opensearch/forecast/transport/SuggestName.java new file mode 100644 index 000000000..15f3239c3 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/SuggestName.java @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import java.util.Collection; +import java.util.Set; + +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.timeseries.Name; +import org.opensearch.timeseries.constant.CommonMessages; + +public enum SuggestName implements Name { + INTERVAL(Forecaster.FORECAST_INTERVAL_FIELD), + HORIZON(Forecaster.HORIZON_FIELD), + HISTORY(Forecaster.HISTORY_INTERVAL_FIELD); + + private String name; + + SuggestName(String name) { + this.name = name; + } + + /** + * Get suggest name + * + * @return name + */ + @Override + public String getName() { + return name; + } + + public static SuggestName getName(String name) { + switch (name) { + case Forecaster.FORECAST_INTERVAL_FIELD: + return INTERVAL; + case Forecaster.HORIZON_FIELD: + return HORIZON; + case Forecaster.HISTORY_INTERVAL_FIELD: + return HISTORY; + default: + throw new IllegalArgumentException(CommonMessages.NOT_EXISTENT_SUGGEST_TYPE); + } + } + + public static Set getNames(Collection names) { + return Name.getNameFromCollection(names, SuggestName::getName); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ValidateForecasterAction.java b/src/main/java/org/opensearch/forecast/transport/ValidateForecasterAction.java new file mode 100644 index 000000000..26cf17666 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ValidateForecasterAction.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; +import org.opensearch.timeseries.transport.ValidateConfigResponse; + +public class ValidateForecasterAction extends ActionType { + public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecaster/validate"; + public static final ValidateForecasterAction INSTANCE = new ValidateForecasterAction(); + + private ValidateForecasterAction() { + super(NAME, ValidateConfigResponse::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ValidateForecasterTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ValidateForecasterTransportAction.java new file mode 100644 index 000000000..01b38ffa6 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ValidateForecasterTransportAction.java @@ -0,0 +1,84 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_FILTER_BY_BACKEND_ROLES; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.rest.handler.ValidateForecasterActionHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.rest.handler.Processor; +import org.opensearch.timeseries.transport.BaseValidateConfigTransportAction; +import org.opensearch.timeseries.transport.ValidateConfigRequest; +import org.opensearch.timeseries.transport.ValidateConfigResponse; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +public class ValidateForecasterTransportAction extends BaseValidateConfigTransportAction { + public static final Logger logger = LogManager.getLogger(ValidateForecasterTransportAction.class); + + @Inject + public ValidateForecasterTransportAction( + Client client, + SecurityClientUtil clientUtil, + ClusterService clusterService, + NamedXContentRegistry xContentRegistry, + Settings settings, + ForecastIndexManagement anomalyDetectionIndices, + ActionFilters actionFilters, + TransportService transportService, + SearchFeatureDao searchFeatureDao + ) { + super( + ValidateForecasterAction.NAME, + client, + clientUtil, + clusterService, + xContentRegistry, + settings, + anomalyDetectionIndices, + actionFilters, + transportService, + searchFeatureDao, + FORECAST_FILTER_BY_BACKEND_ROLES + ); + } + + @Override + protected Processor createProcessor(Config forecaster, ValidateConfigRequest request, User user) { + return new ValidateForecasterActionHandler( + clusterService, + client, + clientUtil, + indexManagement, + forecaster, + request.getRequestTimeout(), + request.getMaxSingleEntityAnomalyDetectors(), + request.getMaxMultiEntityAnomalyDetectors(), + request.getMaxAnomalyFeatures(), + request.getMaxCategoricalFields(), + RestRequest.Method.POST, + xContentRegistry, + user, + searchFeatureDao, + request.getValidationType(), + clock, + settings + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/handler/ForecastIndexMemoryPressureAwareResultHandler.java b/src/main/java/org/opensearch/forecast/transport/handler/ForecastIndexMemoryPressureAwareResultHandler.java new file mode 100644 index 000000000..1f94257e3 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/handler/ForecastIndexMemoryPressureAwareResultHandler.java @@ -0,0 +1,51 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport.handler; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.transport.ForecastResultBulkAction; +import org.opensearch.forecast.transport.ForecastResultBulkRequest; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.transport.ResultBulkResponse; +import org.opensearch.timeseries.transport.handler.IndexMemoryPressureAwareResultHandler; + +public class ForecastIndexMemoryPressureAwareResultHandler extends + IndexMemoryPressureAwareResultHandler { + private static final Logger LOG = LogManager.getLogger(ForecastIndexMemoryPressureAwareResultHandler.class); + + @Inject + public ForecastIndexMemoryPressureAwareResultHandler(Client client, ForecastIndexManagement anomalyDetectionIndices) { + super(client, anomalyDetectionIndices); + } + + @Override + public void bulk(ForecastResultBulkRequest currentBulkRequest, ActionListener listener) { + if (currentBulkRequest.numberOfActions() <= 0) { + listener.onFailure(new TimeSeriesException("no result to save")); + return; + } + client.execute(ForecastResultBulkAction.INSTANCE, currentBulkRequest, ActionListener.wrap(response -> { + LOG.debug(CommonMessages.SUCCESS_SAVING_RESULT_MSG); + listener.onResponse(response); + }, exception -> { + LOG.error("Error in bulking results", exception); + listener.onFailure(exception); + })); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/handler/ForecastSearchHandler.java b/src/main/java/org/opensearch/forecast/transport/handler/ForecastSearchHandler.java new file mode 100644 index 000000000..61979f534 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/handler/ForecastSearchHandler.java @@ -0,0 +1,28 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport.handler; + +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.timeseries.transport.handler.SearchHandler; + +/** + * Handle general search request, check user role and return search response. + */ +public class ForecastSearchHandler extends SearchHandler { + + public ForecastSearchHandler(Settings settings, ClusterService clusterService, Client client) { + super(settings, clusterService, client, ForecastSettings.FORECAST_FILTER_BY_BACKEND_ROLES); + } +} diff --git a/src/main/java/org/opensearch/ad/AbstractProfileRunner.java b/src/main/java/org/opensearch/timeseries/AbstractProfileRunner.java similarity index 92% rename from src/main/java/org/opensearch/ad/AbstractProfileRunner.java rename to src/main/java/org/opensearch/timeseries/AbstractProfileRunner.java index e402a4da1..79345db34 100644 --- a/src/main/java/org/opensearch/ad/AbstractProfileRunner.java +++ b/src/main/java/org/opensearch/timeseries/AbstractProfileRunner.java @@ -9,11 +9,11 @@ * GitHub history for details. */ -package org.opensearch.ad; +package org.opensearch.timeseries; import java.util.Locale; -import org.opensearch.ad.model.InitProgressProfile; +import org.opensearch.timeseries.model.InitProgressProfile; public abstract class AbstractProfileRunner { protected long requiredSamples; diff --git a/src/main/java/org/opensearch/ad/rest/AbstractSearchAction.java b/src/main/java/org/opensearch/timeseries/AbstractSearchAction.java similarity index 73% rename from src/main/java/org/opensearch/ad/rest/AbstractSearchAction.java rename to src/main/java/org/opensearch/timeseries/AbstractSearchAction.java index 1d0611cf7..43681e78f 100644 --- a/src/main/java/org/opensearch/ad/rest/AbstractSearchAction.java +++ b/src/main/java/org/opensearch/timeseries/AbstractSearchAction.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.rest; +package org.opensearch.timeseries; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; import static org.opensearch.timeseries.util.RestHandlerUtils.getSourceContext; @@ -17,6 +17,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.function.Supplier; import org.apache.commons.lang3.tuple.Pair; import org.apache.logging.log4j.LogManager; @@ -24,8 +25,6 @@ import org.opensearch.action.ActionType; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; -import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.client.node.NodeClient; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContentObject; @@ -36,6 +35,7 @@ import org.opensearch.rest.RestResponse; import org.opensearch.rest.action.RestResponseListener; import org.opensearch.search.builder.SearchSourceBuilder; +import org.owasp.encoder.Encode; /** * Abstract class to handle search request. @@ -47,6 +47,8 @@ public abstract class AbstractSearchAction extends B protected final List urlPaths; protected final List> deprecatedPaths; protected final ActionType actionType; + protected final Supplier adEnabledSupplier; + protected final String disabledMsg; private final Logger logger = LogManager.getLogger(AbstractSearchAction.class); @@ -55,29 +57,38 @@ public AbstractSearchAction( List> deprecatedPaths, String index, Class clazz, - ActionType actionType + ActionType actionType, + Supplier adEnabledSupplier, + String disabledMsg ) { this.index = index; this.clazz = clazz; this.urlPaths = urlPaths; this.deprecatedPaths = deprecatedPaths; this.actionType = actionType; + this.adEnabledSupplier = adEnabledSupplier; + this.disabledMsg = disabledMsg; } @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { - if (!ADEnabledSetting.isADEnabled()) { - throw new IllegalStateException(ADCommonMessages.DISABLED_ERR_MSG); + if (!adEnabledSupplier.get()) { + throw new IllegalStateException(disabledMsg); } - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.parseXContent(request.contentOrSourceParamParser()); - // order of response will be re-arranged everytime we use `_source`, we sometimes do this - // even if user doesn't give this field as we exclude ui_metadata if request isn't from OSD - // ref-link: https://github.com/elastic/elasticsearch/issues/17639 - searchSourceBuilder.fetchSource(getSourceContext(request, searchSourceBuilder)); - searchSourceBuilder.seqNoAndPrimaryTerm(true).version(true); - SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(this.index); - return channel -> client.execute(actionType, searchRequest, search(channel)); + try { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.parseXContent(request.contentOrSourceParamParser()); + // order of response will be re-arranged everytime we use `_source`, we sometimes do this + // even if user doesn't give this field as we exclude ui_metadata if request isn't from OSD + // ref-link: https://github.com/elastic/elasticsearch/issues/17639 + searchSourceBuilder.fetchSource(getSourceContext(request, searchSourceBuilder)); + searchSourceBuilder.seqNoAndPrimaryTerm(true).version(true); + SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(this.index); + return channel -> client.execute(actionType, searchRequest, search(channel)); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException(Encode.forHtml(e.getMessage())); + } + } protected void onFailure(RestChannel channel, Exception e) { diff --git a/src/main/java/org/opensearch/ad/DetectorModelSize.java b/src/main/java/org/opensearch/timeseries/AnalysisModelSize.java similarity index 74% rename from src/main/java/org/opensearch/ad/DetectorModelSize.java rename to src/main/java/org/opensearch/timeseries/AnalysisModelSize.java index 52e4660e6..5e70c456c 100644 --- a/src/main/java/org/opensearch/ad/DetectorModelSize.java +++ b/src/main/java/org/opensearch/timeseries/AnalysisModelSize.java @@ -9,16 +9,16 @@ * GitHub history for details. */ -package org.opensearch.ad; +package org.opensearch.timeseries; import java.util.Map; -public interface DetectorModelSize { +public interface AnalysisModelSize { /** * Gets all of a detector's model sizes hosted on a node * - * @param detectorId Detector Id + * @param id Analysis Id * @return a map of model id to its memory size */ - Map getModelSize(String detectorId); + Map getModelSize(String id); } diff --git a/src/main/java/org/opensearch/timeseries/AnalysisType.java b/src/main/java/org/opensearch/timeseries/AnalysisType.java index 7d7cc805e..f0f4e2025 100644 --- a/src/main/java/org/opensearch/timeseries/AnalysisType.java +++ b/src/main/java/org/opensearch/timeseries/AnalysisType.java @@ -7,5 +7,13 @@ public enum AnalysisType { AD, - FORECAST + FORECAST; + + public boolean isForecast() { + return this == FORECAST; + } + + public boolean isAD() { + return this == AD; + } } diff --git a/src/main/java/org/opensearch/ad/EntityProfileRunner.java b/src/main/java/org/opensearch/timeseries/EntityProfileRunner.java similarity index 78% rename from src/main/java/org/opensearch/ad/EntityProfileRunner.java rename to src/main/java/org/opensearch/timeseries/EntityProfileRunner.java index ce14dbef2..43dbe3cbc 100644 --- a/src/main/java/org/opensearch/ad/EntityProfileRunner.java +++ b/src/main/java/org/opensearch/timeseries/EntityProfileRunner.java @@ -9,10 +9,11 @@ * GitHub history for details. */ -package org.opensearch.ad; +package org.opensearch.timeseries; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import java.io.IOException; import java.util.List; import java.util.Map; import java.util.Optional; @@ -20,26 +21,17 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.apache.lucene.search.join.ScoreMode; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionType; import org.opensearch.action.get.GetRequest; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.ad.model.EntityProfile; -import org.opensearch.ad.model.EntityProfileName; -import org.opensearch.ad.model.EntityState; -import org.opensearch.ad.model.InitProgressProfile; -import org.opensearch.ad.settings.ADNumericSetting; -import org.opensearch.ad.transport.EntityProfileAction; -import org.opensearch.ad.transport.EntityProfileRequest; -import org.opensearch.ad.transport.EntityProfileResponse; import org.opensearch.client.Client; import org.opensearch.cluster.routing.Preference; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; @@ -49,43 +41,73 @@ import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.function.BiCheckedFunction; +import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.EntityProfile; +import org.opensearch.timeseries.model.EntityProfileName; +import org.opensearch.timeseries.model.EntityState; +import org.opensearch.timeseries.model.InitProgressProfile; import org.opensearch.timeseries.model.IntervalTimeConfiguration; import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.transport.EntityProfileRequest; +import org.opensearch.timeseries.transport.EntityProfileResponse; import org.opensearch.timeseries.util.MultiResponsesDelegateActionListener; import org.opensearch.timeseries.util.ParseUtils; import org.opensearch.timeseries.util.SecurityClientUtil; -public class EntityProfileRunner extends AbstractProfileRunner { +public class EntityProfileRunner> extends AbstractProfileRunner { private final Logger logger = LogManager.getLogger(EntityProfileRunner.class); - static final String NOT_HC_DETECTOR_ERR_MSG = "This is not a high cardinality detector"; + public static final String NOT_HC_DETECTOR_ERR_MSG = "This is not a high cardinality detector"; static final String EMPTY_ENTITY_ATTRIBUTES = "Empty entity attributes"; static final String NO_ENTITY = "Cannot find entity"; private Client client; private SecurityClientUtil clientUtil; private NamedXContentRegistry xContentRegistry; - - public EntityProfileRunner(Client client, SecurityClientUtil clientUtil, NamedXContentRegistry xContentRegistry, long requiredSamples) { + private BiCheckedFunction configParser; + private int maxCategoryFields; + private AnalysisType analysisType; + private EntityProfileActionType entityProfileAction; + private String resultIndexAlias; + private String configIdField; + + public EntityProfileRunner( + Client client, + SecurityClientUtil clientUtil, + NamedXContentRegistry xContentRegistry, + long requiredSamples, + BiCheckedFunction configParser, + int maxCategoryFields, + AnalysisType analysisType, + EntityProfileActionType entityProfileAction, + String resultIndexAlias, + String configIdField + ) { super(requiredSamples); this.client = client; this.clientUtil = clientUtil; this.xContentRegistry = xContentRegistry; + this.configParser = configParser; + this.maxCategoryFields = maxCategoryFields; + this.analysisType = analysisType; + this.entityProfileAction = entityProfileAction; + this.resultIndexAlias = resultIndexAlias; + this.configIdField = configIdField; } /** * Get profile info of specific entity. * - * @param detectorId detector identifier + * @param configId config identifier * @param entityValue entity value * @param profilesToCollect profiles to collect * @param listener action listener to handle exception and process entity profile response */ public void profile( - String detectorId, + String configId, Entity entityValue, Set profilesToCollect, ActionListener listener @@ -94,7 +116,7 @@ public void profile( listener.onFailure(new IllegalArgumentException(CommonMessages.EMPTY_PROFILES_COLLECT)); return; } - GetRequest getDetectorRequest = new GetRequest(CommonName.CONFIG_INDEX, detectorId); + GetRequest getDetectorRequest = new GetRequest(CommonName.CONFIG_INDEX, configId); client.get(getDetectorRequest, ActionListener.wrap(getResponse -> { if (getResponse != null && getResponse.isExists()) { @@ -104,21 +126,20 @@ public void profile( .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, getResponse.getSourceAsString()) ) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - AnomalyDetector detector = AnomalyDetector.parse(parser, detectorId); - List categoryFields = detector.getCategoryFields(); - int maxCategoryFields = ADNumericSetting.maxCategoricalFields(); + Config config = configParser.apply(parser, configId); + List categoryFields = config.getCategoryFields(); if (categoryFields == null || categoryFields.size() == 0) { listener.onFailure(new IllegalArgumentException(NOT_HC_DETECTOR_ERR_MSG)); } else if (categoryFields.size() > maxCategoryFields) { listener.onFailure(new IllegalArgumentException(CommonMessages.getTooManyCategoricalFieldErr(maxCategoryFields))); } else { - validateEntity(entityValue, categoryFields, detectorId, profilesToCollect, detector, listener); + validateEntity(entityValue, categoryFields, configId, profilesToCollect, config, listener); } } catch (Exception t) { listener.onFailure(t); } } else { - listener.onFailure(new IllegalArgumentException(CommonMessages.FAIL_TO_FIND_CONFIG_MSG + detectorId)); + listener.onFailure(new OpenSearchStatusException(CommonMessages.FAIL_TO_FIND_CONFIG_MSG + configId, RestStatus.NOT_FOUND)); } }, listener::onFailure)); } @@ -143,14 +164,18 @@ private void validateEntity( List categoryFields, String detectorId, Set profilesToCollect, - AnomalyDetector detector, + Config config, ActionListener listener ) { Map attributes = entity.getAttributes(); - if (attributes == null || attributes.size() != categoryFields.size()) { + if (attributes == null) { listener.onFailure(new IllegalArgumentException(EMPTY_ENTITY_ATTRIBUTES)); return; } + if (attributes.size() != categoryFields.size()) { + listener.onFailure(new IllegalArgumentException(NO_ENTITY)); + return; + } for (String field : categoryFields) { if (false == attributes.containsKey(field)) { listener.onFailure(new IllegalArgumentException("Cannot find " + field)); @@ -158,15 +183,15 @@ private void validateEntity( } } - BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().filter(detector.getFilterQuery()); + BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().filter(config.getFilterQuery()); - for (TermQueryBuilder term : entity.getTermQueryBuilders()) { + for (TermQueryBuilder term : entity.getTermQueryForCustomerIndex()) { internalFilterQuery.filter(term); } SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(internalFilterQuery).size(1); - SearchRequest searchRequest = new SearchRequest(detector.getIndices().toArray(new String[0]), searchSourceBuilder) + SearchRequest searchRequest = new SearchRequest(config.getIndices().toArray(new String[0]), searchSourceBuilder) .preference(Preference.LOCAL.toString()); final ActionListener searchResponseListener = ActionListener.wrap(searchResponse -> { try { @@ -174,7 +199,7 @@ private void validateEntity( listener.onFailure(new IllegalArgumentException(NO_ENTITY)); return; } - prepareEntityProfile(listener, detectorId, entity, profilesToCollect, detector, categoryFields.get(0)); + prepareEntityProfile(listener, detectorId, entity, profilesToCollect, config, categoryFields.get(0)); } catch (Exception e) { listener.onFailure(new IllegalArgumentException(NO_ENTITY)); return; @@ -186,9 +211,9 @@ private void validateEntity( .asyncRequestWithInjectedSecurity( searchRequest, client::search, - detector.getId(), + config.getId(), client, - AnalysisType.AD, + analysisType, searchResponseListener ); @@ -199,16 +224,16 @@ private void prepareEntityProfile( String detectorId, Entity entityValue, Set profilesToCollect, - AnomalyDetector detector, + Config config, String categoryField ) { EntityProfileRequest request = new EntityProfileRequest(detectorId, entityValue, profilesToCollect); client .execute( - EntityProfileAction.INSTANCE, + entityProfileAction, request, - ActionListener.wrap(r -> getJob(detectorId, entityValue, profilesToCollect, detector, r, listener), listener::onFailure) + ActionListener.wrap(r -> getJob(detectorId, entityValue, profilesToCollect, config, r, listener), listener::onFailure) ); } @@ -216,7 +241,7 @@ private void getJob( String detectorId, Entity entityValue, Set profilesToCollect, - AnomalyDetector detector, + Config config, EntityProfileResponse entityProfileResponse, ActionListener listener ) { @@ -266,7 +291,7 @@ private void getJob( detectorId, entityValue, profilesToCollect, - detector, + config, job, delegateListener ); @@ -278,7 +303,7 @@ private void getJob( detectorId, enabledTimeMs, entityValue, - detector.getCustomResultIndex() + config.getCustomResultIndex() ); EntityProfile.Builder builder = new EntityProfile.Builder(); @@ -331,7 +356,7 @@ private void profileStateRelated( String detectorId, Entity entityValue, Set profilesToCollect, - AnomalyDetector detector, + Config config, Job job, MultiResponsesDelegateActionListener delegateListener ) { @@ -342,7 +367,7 @@ private void profileStateRelated( } else if (totalUpdates >= requiredSamples) { sendRunningState(profilesToCollect, entityValue, delegateListener); } else { - sendInitState(profilesToCollect, entityValue, detector, totalUpdates, delegateListener); + sendInitState(profilesToCollect, entityValue, config, totalUpdates, delegateListener); } } @@ -389,7 +414,7 @@ private void sendRunningState( private void sendInitState( Set profilesToCollect, Entity entityValue, - AnomalyDetector detector, + Config config, long updates, MultiResponsesDelegateActionListener delegateListener ) { @@ -398,64 +423,21 @@ private void sendInitState( builder.state(EntityState.INIT); } if (profilesToCollect.contains(EntityProfileName.INIT_PROGRESS)) { - long intervalMins = ((IntervalTimeConfiguration) detector.getInterval()).toDuration().toMinutes(); + long intervalMins = ((IntervalTimeConfiguration) config.getInterval()).toDuration().toMinutes(); InitProgressProfile initProgress = computeInitProgressProfile(updates, intervalMins); builder.initProgress(initProgress); } delegateListener.onResponse(builder.build()); } - private SearchRequest createLastSampleTimeRequest(String detectorId, long enabledTime, Entity entity, String resultIndex) { + private SearchRequest createLastSampleTimeRequest(String configId, long enabledTime, Entity entity, String resultIndex) { BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - String path = "entity"; - String entityName = path + ".name"; - String entityValue = path + ".value"; - - for (Map.Entry attribute : entity.getAttributes().entrySet()) { - /* - * each attribute pair corresponds to a nested query like - "nested": { - "query": { - "bool": { - "filter": [ - { - "term": { - "entity.name": { - "value": "turkey4", - "boost": 1 - } - } - }, - { - "term": { - "entity.value": { - "value": "Turkey", - "boost": 1 - } - } - } - ] - } - }, - "path": "entity", - "ignore_unmapped": false, - "score_mode": "none", - "boost": 1 - } - },*/ - BoolQueryBuilder nestedBoolQueryBuilder = new BoolQueryBuilder(); - - TermQueryBuilder entityNameFilterQuery = QueryBuilders.termQuery(entityName, attribute.getKey()); - nestedBoolQueryBuilder.filter(entityNameFilterQuery); - TermQueryBuilder entityValueFilterQuery = QueryBuilders.termQuery(entityValue, attribute.getValue()); - nestedBoolQueryBuilder.filter(entityValueFilterQuery); - - NestedQueryBuilder nestedNameQueryBuilder = new NestedQueryBuilder(path, nestedBoolQueryBuilder, ScoreMode.None); + for (NestedQueryBuilder nestedNameQueryBuilder : entity.getTermQueryForResultIndex()) { boolQueryBuilder.filter(nestedNameQueryBuilder); } - boolQueryBuilder.filter(QueryBuilders.termQuery(AnomalyResult.DETECTOR_ID_FIELD, detectorId)); + boolQueryBuilder.filter(QueryBuilders.termQuery(configIdField, configId)); boolQueryBuilder.filter(QueryBuilders.rangeQuery(CommonName.EXECUTION_END_TIME_FIELD).gte(enabledTime)); @@ -465,7 +447,7 @@ private SearchRequest createLastSampleTimeRequest(String detectorId, long enable .trackTotalHits(false) .size(0); - SearchRequest request = new SearchRequest(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS); + SearchRequest request = new SearchRequest(resultIndexAlias); request.source(source); if (resultIndex != null) { request.indices(resultIndex); diff --git a/src/main/java/org/opensearch/timeseries/ExecuteResultResponseRecorder.java b/src/main/java/org/opensearch/timeseries/ExecuteResultResponseRecorder.java new file mode 100644 index 000000000..3b19684a7 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ExecuteResultResponseRecorder.java @@ -0,0 +1,366 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries; + +import java.time.Instant; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionType; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.transport.RCFPollingAction; +import org.opensearch.ad.transport.RCFPollingRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.search.SearchHits; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.ProfileName; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.transport.ProfileRequest; +import org.opensearch.timeseries.transport.ProfileResponse; +import org.opensearch.timeseries.transport.ResultResponse; +import org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.timeseries.util.ExceptionUtil; + +public abstract class ExecuteResultResponseRecorder & TimeSeriesIndex, IndexManagementType extends IndexManagement, TaskCacheManagerType extends TaskCacheManager, TaskTypeEnum extends TaskType, TaskClass extends TimeSeriesTask, TaskManagerType extends TaskManager, IndexableResultType extends IndexableResult, ProfileActionType extends ActionType> { + + private static final Logger log = LogManager.getLogger(ExecuteResultResponseRecorder.class); + + protected IndexManagementType indexManagement; + private ResultBulkIndexingHandler resultHandler; + protected TaskManagerType taskManager; + private DiscoveryNodeFilterer nodeFilter; + private ThreadPool threadPool; + private String threadPoolName; + private Client client; + private NodeStateManager nodeStateManager; + private TaskCacheManager taskCacheManager; + private int rcfMinSamples; + protected IndexType resultIndex; + private AnalysisType analysisType; + private ProfileActionType profileAction; + + public ExecuteResultResponseRecorder( + IndexManagementType indexManagement, + ResultBulkIndexingHandler resultHandler, + TaskManagerType taskManager, + DiscoveryNodeFilterer nodeFilter, + ThreadPool threadPool, + String threadPoolName, + Client client, + NodeStateManager nodeStateManager, + TaskCacheManager taskCacheManager, + int rcfMinSamples, + IndexType resultIndex, + AnalysisType analysisType, + ProfileActionType profileAction + ) { + this.indexManagement = indexManagement; + this.resultHandler = resultHandler; + this.taskManager = taskManager; + this.nodeFilter = nodeFilter; + this.threadPool = threadPool; + this.threadPoolName = threadPoolName; + this.client = client; + this.nodeStateManager = nodeStateManager; + this.taskCacheManager = taskCacheManager; + this.rcfMinSamples = rcfMinSamples; + this.resultIndex = resultIndex; + this.analysisType = analysisType; + this.profileAction = profileAction; + } + + public void indexResult( + Instant detectionStartTime, + Instant executionStartTime, + ResultResponse response, + Config config + ) { + String configId = config.getId(); + try { + + if (!response.shouldSave()) { + updateRealtimeTask(response, configId); + return; + } + IntervalTimeConfiguration windowDelay = (IntervalTimeConfiguration) config.getWindowDelay(); + Instant dataStartTime = detectionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit()); + Instant dataEndTime = executionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit()); + User user = config.getUser(); + + if (response.getError() != null) { + log.info("Result action run successfully for {} with error {}", configId, response.getError()); + } + + List analysisResults = response + .toIndexableResults( + configId, + dataStartTime, + dataEndTime, + executionStartTime, + Instant.now(), + indexManagement.getSchemaVersion(resultIndex), + user, + response.getError() + ); + + String resultIndex = config.getCustomResultIndex(); + resultHandler + .bulk( + resultIndex, + analysisResults, + configId, + ActionListener + .wrap( + r -> {}, + exception -> log.error(String.format(Locale.ROOT, "Fail to bulk for %s", configId), exception) + ) + ); + updateRealtimeTask(response, configId); + } catch (EndRunException e) { + throw e; + } catch (Exception e) { + log.error("Failed to index result for " + configId, e); + } + } + + /** + * + * If result action is handled asynchronously, response won't contain the result. + * This function wait some time before fetching update. + * One side-effect is if the config is already deleted the latest task will get deleted too. + * This delayed update can cause ResourceNotFoundException. + * + * @param response response returned from executing AnomalyResultAction + * @param configId config Id + */ + protected void delayedUpdate(ResultResponse response, String configId) { + DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes(); + Set profiles = new HashSet<>(); + profiles.add(ProfileName.INIT_PROGRESS); + ProfileRequest profileRequest = new ProfileRequest(configId, profiles, true, dataNodes); + Runnable profileHCInitProgress = () -> { + client.execute(profileAction, profileRequest, ActionListener.wrap(r -> { + log.debug("Update latest realtime task for config {}, total updates: {}", configId, r.getTotalUpdates()); + updateLatestRealtimeTask(configId, null, r.getTotalUpdates(), response.getConfigIntervalInMinutes(), response.getError()); + }, e -> { log.error("Failed to update latest realtime task for " + configId, e); })); + }; + if (!taskManager.isHCRealtimeTaskStartInitializing(configId)) { + // real time init progress is 0 may mean this is a newly started detector + // Delay real time cache update by one minute. If we are in init status, the delay may give the model training time to + // finish. We can change the detector running immediately instead of waiting for the next interval. + threadPool.schedule(profileHCInitProgress, new TimeValue(60, TimeUnit.SECONDS), threadPoolName); + } else { + profileHCInitProgress.run(); + } + } + + protected void updateLatestRealtimeTask( + String configId, + String taskState, + Long rcfTotalUpdates, + Long configIntervalInMinutes, + String error + ) { + // Don't need info as this will be printed repeatedly in each interval + ActionListener listener = ActionListener.wrap(r -> { + if (r != null) { + log.debug("Updated latest realtime task successfully for config {}, taskState: {}", configId, taskState); + } + }, e -> { + if ((e instanceof ResourceNotFoundException) && e.getMessage().contains(CommonMessages.CAN_NOT_FIND_LATEST_TASK)) { + // Clear realtime task cache, will recreate task in next run, check ADResultProcessor. + log.error("Can't find latest realtime task of config " + configId); + taskManager.removeRealtimeTaskCache(configId); + } else { + log.error("Failed to update latest realtime task for config " + configId, e); + } + }); + + // rcfTotalUpdates is null when we save exception messages + if (!taskCacheManager.hasQueriedResultIndex(configId) && rcfTotalUpdates != null && rcfTotalUpdates < rcfMinSamples) { + // confirm the total updates number since it is possible that we have already had results after job enabling time + // If yes, total updates should be at least rcfMinSamples so that the init progress reaches 100%. + confirmTotalRCFUpdatesFound( + configId, + taskState, + rcfTotalUpdates, + configIntervalInMinutes, + error, + ActionListener + .wrap( + r -> taskManager + .updateLatestRealtimeTaskOnCoordinatingNode(configId, taskState, r, configIntervalInMinutes, error, listener), + e -> { + log.error("Fail to confirm rcf update", e); + taskManager + .updateLatestRealtimeTaskOnCoordinatingNode( + configId, + taskState, + rcfTotalUpdates, + configIntervalInMinutes, + error, + listener + ); + } + ) + ); + } else { + taskManager + .updateLatestRealtimeTaskOnCoordinatingNode(configId, taskState, rcfTotalUpdates, configIntervalInMinutes, error, listener); + } + } + + /** + * The function is not only indexing the result with the exception, but also updating the task state after + * 60s if the exception is related to cold start (index not found exceptions) for a single stream detector. + * + * @param executeStartTime execution start time + * @param executeEndTime execution end time + * @param errorMessage Error message to record + * @param taskState task state (e.g., stopped) + * @param config config accessor + */ + public void indexResultException( + Instant executeStartTime, + Instant executeEndTime, + String errorMessage, + String taskState, + Config config + ) { + String configId = config.getId(); + try { + IntervalTimeConfiguration windowDelay = (IntervalTimeConfiguration) config.getWindowDelay(); + Instant dataStartTime = executeStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit()); + Instant dataEndTime = executeEndTime.minus(windowDelay.getInterval(), windowDelay.getUnit()); + User user = config.getUser(); + + IndexableResultType resultToSave = createErrorResult(configId, dataStartTime, dataEndTime, executeEndTime, errorMessage, user); + String resultIndex = config.getCustomResultIndex(); + if (resultIndex != null && !indexManagement.doesIndexExist(resultIndex)) { + // Set result index as null, will write exception to default result index. + resultHandler.index(resultToSave, configId, null); + } else { + resultHandler.index(resultToSave, configId, resultIndex); + } + + if (errorMessage.contains(ADCommonMessages.NO_MODEL_ERR_MSG) && !config.isHighCardinality()) { + // single stream detector raises ResourceNotFoundException containing ADCommonMessages.NO_CHECKPOINT_ERR_MSG + // when there is no checkpoint. + // Delay real time cache update by one minute so we will have trained models by then and update the state + // document accordingly. + threadPool.schedule(() -> { + RCFPollingRequest request = new RCFPollingRequest(configId); + client.execute(RCFPollingAction.INSTANCE, request, ActionListener.wrap(rcfPollResponse -> { + long totalUpdates = rcfPollResponse.getTotalUpdates(); + // if there are updates, don't record failures + updateLatestRealtimeTask( + configId, + taskState, + totalUpdates, + config.getIntervalInMinutes(), + totalUpdates > 0 ? "" : errorMessage + ); + }, e -> { + log.error("Fail to execute RCFRollingAction", e); + updateLatestRealtimeTask(configId, taskState, null, null, errorMessage); + })); + }, new TimeValue(60, TimeUnit.SECONDS), threadPoolName); + } else { + updateLatestRealtimeTask(configId, taskState, null, null, errorMessage); + } + + } catch (Exception e) { + log.error("Failed to index anomaly result for " + configId, e); + } + } + + private void confirmTotalRCFUpdatesFound( + String configId, + String taskState, + Long rcfTotalUpdates, + Long configIntervalInMinutes, + String error, + ActionListener listener + ) { + nodeStateManager.getConfig(configId, analysisType, ActionListener.wrap(configOptional -> { + if (!configOptional.isPresent()) { + listener.onFailure(new TimeSeriesException(configId, "fail to get config")); + return; + } + nodeStateManager.getJob(configId, ActionListener.wrap(jobOptional -> { + if (!jobOptional.isPresent()) { + listener.onFailure(new TimeSeriesException(configId, "fail to get job")); + return; + } + + ProfileUtil + .confirmRealtimeInitStatus( + configOptional.get(), + jobOptional.get().getEnabledTime().toEpochMilli(), + client, + analysisType, + ActionListener.wrap(searchResponse -> { + ActionListener.completeWith(listener, () -> { + SearchHits hits = searchResponse.getHits(); + Long correctedTotalUpdates = rcfTotalUpdates; + if (hits.getTotalHits().value > 0L) { + // correct the number if we have already had results after job enabling time + // so that the detector won't stay initialized + correctedTotalUpdates = Long.valueOf(rcfMinSamples); + } + taskCacheManager.markResultIndexQueried(configId); + return correctedTotalUpdates; + }); + }, exception -> { + if (ExceptionUtil.isIndexNotAvailable(exception)) { + // anomaly result index is not created yet + taskCacheManager.markResultIndexQueried(configId); + listener.onResponse(0L); + } else { + listener.onFailure(exception); + } + }) + ); + }, e -> listener.onFailure(new TimeSeriesException(configId, "fail to get job")))); + }, e -> listener.onFailure(new TimeSeriesException(configId, "fail to get config")))); + } + + protected abstract IndexableResultType createErrorResult( + String configId, + Instant dataStartTime, + Instant dataEndTime, + Instant executeEndTime, + String errorMessage, + User user + ); + + // protected abstract void updateRealtimeTask(ResultResponseType response, String configId); + protected abstract void updateRealtimeTask(ResultResponse response, String configId); +} diff --git a/src/main/java/org/opensearch/timeseries/JobProcessor.java b/src/main/java/org/opensearch/timeseries/JobProcessor.java new file mode 100644 index 000000000..8ce5e861b --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/JobProcessor.java @@ -0,0 +1,583 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries; + +import java.time.Instant; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.ActionType; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.commons.InjectSecurity; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.jobscheduler.spi.JobExecutionContext; +import org.opensearch.jobscheduler.spi.LockModel; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; +import org.opensearch.jobscheduler.spi.utils.LockService; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.InternalFailure; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.rest.handler.IndexJobActionHandler; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.transport.JobResponse; +import org.opensearch.timeseries.transport.ProfileResponse; +import org.opensearch.timeseries.transport.ResultRequest; +import org.opensearch.timeseries.transport.ResultResponse; +import org.opensearch.timeseries.util.SecurityUtil; + +import com.google.common.base.Throwables; + +/** + * JobScheduler will call job runner to get time series analysis result periodically + */ +public abstract class JobProcessor & TimeSeriesIndex, IndexManagementType extends IndexManagement, TaskCacheManagerType extends TaskCacheManager, TaskTypeEnum extends TaskType, TaskClass extends TimeSeriesTask, TaskManagerType extends TaskManager, IndexableResultType extends IndexableResult, ProfileActionType extends ActionType, ExecuteResultResponseRecorderType extends ExecuteResultResponseRecorder, IndexJobActionHandlerType extends IndexJobActionHandler> { + + private static final Logger log = LogManager.getLogger(JobProcessor.class); + + private Settings settings; + private int maxRetryForEndRunException; + private Client client; + private ThreadPool threadPool; + private ConcurrentHashMap endRunExceptionCount; + protected IndexManagementType indexManagement; + private TaskManagerType taskManager; + private NodeStateManager nodeStateManager; + private ExecuteResultResponseRecorderType recorder; + private AnalysisType analysisType; + private String threadPoolName; + private ActionType> resultAction; + private IndexJobActionHandlerType indexJobActionHandler; + + protected JobProcessor( + AnalysisType analysisType, + String threadPoolName, + ActionType> resultAction + ) { + // Singleton class, use getJobRunnerInstance method instead of constructor + this.endRunExceptionCount = new ConcurrentHashMap<>(); + this.analysisType = analysisType; + this.threadPoolName = threadPoolName; + this.resultAction = resultAction; + } + + public void setClient(Client client) { + this.client = client; + } + + public void setThreadPool(ThreadPool threadPool) { + this.threadPool = threadPool; + } + + protected void registerSettings(Settings settings, Setting maxRetryForEndRunExceptionSetting) { + this.settings = settings; + this.maxRetryForEndRunException = maxRetryForEndRunExceptionSetting.get(settings); + } + + public void setTaskManager(TaskManagerType adTaskManager) { + this.taskManager = adTaskManager; + } + + public void setIndexManagement(IndexManagementType anomalyDetectionIndices) { + this.indexManagement = anomalyDetectionIndices; + } + + public void setNodeStateManager(NodeStateManager nodeStateManager) { + this.nodeStateManager = nodeStateManager; + } + + public void setExecuteResultResponseRecorder(ExecuteResultResponseRecorderType recorder) { + this.recorder = recorder; + } + + public void setIndexJobActionHandler(IndexJobActionHandlerType indexJobActionHandler) { + this.indexJobActionHandler = indexJobActionHandler; + } + + public void process(Job jobParameter, JobExecutionContext context) { + String configId = jobParameter.getName(); + + log.info("Start to run {} job {}", analysisType, configId); + + taskManager.refreshRealtimeJobRunTime(configId); + + Instant executionEndTime = Instant.now(); + IntervalSchedule schedule = (IntervalSchedule) jobParameter.getSchedule(); + Instant executionStartTime = executionEndTime.minus(schedule.getInterval(), schedule.getUnit()); + + final LockService lockService = context.getLockService(); + + Runnable runnable = () -> { + try { + nodeStateManager.getConfig(configId, analysisType, ActionListener.wrap(configOptional -> { + if (!configOptional.isPresent()) { + log.error(new ParameterizedMessage("fail to get config [{}]", configId)); + return; + } + Config config = configOptional.get(); + + if (jobParameter.getLockDurationSeconds() != null) { + lockService + .acquireLock( + jobParameter, + context, + ActionListener + .wrap( + lock -> runJob( + jobParameter, + lockService, + lock, + executionStartTime, + executionEndTime, + recorder, + config + ), + exception -> { + indexResultException( + jobParameter, + lockService, + null, + executionStartTime, + executionEndTime, + exception, + false, + recorder, + config + ); + throw new IllegalStateException("Failed to acquire lock for job: " + configId); + } + ) + ); + } else { + log.warn("Can't get lock for job: " + configId); + } + + }, e -> log.error(new ParameterizedMessage("fail to get config [{}]", configId), e))); + } catch (Exception e) { + // os log won't show anything if there is an exception happens (maybe due to running on a ExecutorService) + // we at least log the error. + log.error("Can't start job: " + configId, e); + throw e; + } + }; + + ExecutorService executor = threadPool.executor(threadPoolName); + executor.submit(runnable); + } + + /** + * Get analysis result, index result or handle exception if failed. + * + * @param jobParameter scheduled job parameter + * @param lockService lock service + * @param lock lock to run job + * @param executionStartTime analysis start time + * @param executionEndTime analysis end time + * @param recorder utility to record job execution result + * @param detector associated detector accessor + */ + public void runJob( + Job jobParameter, + LockService lockService, + LockModel lock, + Instant executionStartTime, + Instant executionEndTime, + ExecuteResultResponseRecorderType recorder, + Config detector + ) { + String configId = jobParameter.getName(); + if (lock == null) { + indexResultException( + jobParameter, + lockService, + lock, + executionStartTime, + executionEndTime, + "Can't run job due to null lock", + false, + recorder, + detector + ); + return; + } + indexManagement.update(); + + User userInfo = SecurityUtil.getUserFromJob(jobParameter, settings); + + String user = userInfo.getName(); + List roles = userInfo.getRoles(); + + validateResultIndexAndRunJob( + jobParameter, + lockService, + lock, + executionStartTime, + executionEndTime, + configId, + user, + roles, + recorder, + detector + ); + } + + protected abstract void validateResultIndexAndRunJob( + Job jobParameter, + LockService lockService, + LockModel lock, + Instant executionStartTime, + Instant executionEndTime, + String configId, + String user, + List roles, + ExecuteResultResponseRecorderType recorder2, + Config detector + ); + + protected void runJob( + Job jobParameter, + LockService lockService, + LockModel lock, + Instant executionStartTime, + Instant executionEndTime, + String configId, + String user, + List roles, + ExecuteResultResponseRecorderType recorder, + Config detector + ) { + // using one thread in the write threadpool + try (InjectSecurity injectSecurity = new InjectSecurity(configId, settings, client.threadPool().getThreadContext())) { + // Injecting user role to verify if the user has permissions for our API. + injectSecurity.inject(user, roles); + + ResultRequest request = createResultRequest(configId, executionStartTime.toEpochMilli(), executionEndTime.toEpochMilli()); + client.execute(resultAction, request, ActionListener.wrap(response -> { + indexResult(jobParameter, lockService, lock, executionStartTime, executionEndTime, response, recorder, detector); + }, + exception -> { + handleException(jobParameter, lockService, lock, executionStartTime, executionEndTime, exception, recorder, detector); + } + )); + } catch (Exception e) { + indexResultException(jobParameter, lockService, lock, executionStartTime, executionEndTime, e, true, recorder, detector); + log.error("Failed to execute AD job " + configId, e); + } + } + + /** + * Handle exception from anomaly result action. + * + * 1. If exception is {@link EndRunException} + * a). if isEndNow == true, stop job and store exception in result + * b). if isEndNow == false, record count of {@link EndRunException} for this + * analysis. If count of {@link EndRunException} exceeds upper limit, will + * stop job and store exception in result; otherwise, just + * store exception in result, not stop job for the config. + * + * 2. If exception is not {@link EndRunException}, decrease count of + * {@link EndRunException} for the config and index exception in + * result. If exception is {@link InternalFailure}, will not log exception + * stack trace as already logged in {@link JobProcessor}. + * + * TODO: Handle finer granularity exception such as some exception may be + * transient and retry in current job may succeed. Currently, we don't + * know which exception is transient and retryable in + * {@link JobProcessor}. So we don't add backoff retry + * now to avoid bring extra load to cluster, expecially the code start + * process is relatively heavy by sending out 24 queries, initializing + * models, and saving checkpoints. + * Sometimes missing anomaly and notification is not acceptable. For example, + * current detection interval is 1hour, and there should be anomaly in + * current interval, some transient exception may fail current AD job, + * so no anomaly found and user never know it. Then we start next AD job, + * maybe there is no anomaly in next 1hour, user will never know something + * wrong happened. In one word, this is some tradeoff between protecting + * our performance, user experience and what we can do currently. + * + * @param jobParameter scheduled job parameter + * @param lockService lock service + * @param lock lock to run job + * @param detectionStartTime detection start time + * @param executionStartTime detection end time + * @param exception exception + * @param recorder utility to record job execution result + * @param config associated config accessor + */ + public void handleException( + Job jobParameter, + LockService lockService, + LockModel lock, + Instant detectionStartTime, + Instant executionStartTime, + Exception exception, + ExecuteResultResponseRecorderType recorder, + Config config + ) { + String configId = jobParameter.getName(); + if (exception instanceof EndRunException) { + log.error("EndRunException happened when executing result action for " + configId, exception); + + if (((EndRunException) exception).isEndNow()) { + // Stop AD job if EndRunException shows we should end job now. + log.info("JobRunner will stop job due to EndRunException for {}", configId); + stopJobForEndRunException( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + (EndRunException) exception, + recorder, + config + ); + } else { + endRunExceptionCount.compute(configId, (k, v) -> { + if (v == null) { + return 1; + } else { + return v + 1; + } + }); + log.info("EndRunException happened for {}", configId); + // if AD job failed consecutively due to EndRunException and failed times exceeds upper limit, will stop AD job + if (endRunExceptionCount.get(configId) > maxRetryForEndRunException) { + log + .info( + "JobRunner will stop job due to EndRunException retry exceeds upper limit {} for {}", + maxRetryForEndRunException, + configId + ); + stopJobForEndRunException( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + (EndRunException) exception, + recorder, + config + ); + return; + } + indexResultException( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + exception.getMessage(), + true, + recorder, + config + ); + } + } else { + endRunExceptionCount.remove(configId); + if (exception instanceof InternalFailure) { + log.error("InternalFailure happened when executing result action for " + configId, exception); + } else { + log.error("Failed to execute result action for " + configId, exception); + } + indexResultException( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + exception, + true, + recorder, + config + ); + } + } + + private void stopJobForEndRunException( + Job jobParameter, + LockService lockService, + LockModel lock, + Instant detectionStartTime, + Instant executionStartTime, + EndRunException exception, + ExecuteResultResponseRecorderType recorder, + Config config + ) { + String configId = jobParameter.getName(); + endRunExceptionCount.remove(configId); + String errorPrefix = exception.isEndNow() + ? "Stopped analysis: " + : "Stopped analysis as job failed consecutively for more than " + this.maxRetryForEndRunException + " times: "; + String error = errorPrefix + exception.getMessage(); + + ExecutorFunction runAfer = () -> indexResultException( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + error, + true, + TaskState.STOPPED.name(), + recorder, + config + ); + + ActionListener stopListener = ActionListener.wrap(jobResponse -> { + log.info("Job was disabled by JobRunner for " + configId); + runAfer.execute(); + }, exp -> { + log.error("JobRunner failed to update job as disabled for " + configId, exp); + runAfer.execute(); + }); + + // transport service is null as we cannot access transport service outside of transport action + // to reset real time job we don't need transport service and we have guarded against the null + // reference in task manager + indexJobActionHandler.stopJob(configId, null, stopListener); + } + + private void indexResult( + Job jobParameter, + LockService lockService, + LockModel lock, + Instant executionStartTime, + Instant executionEndTime, + ResultResponse response, + ExecuteResultResponseRecorderType recorder, + Config detector + ) { + String detectorId = jobParameter.getName(); + endRunExceptionCount.remove(detectorId); + try { + recorder.indexResult(executionStartTime, executionEndTime, response, detector); + } catch (EndRunException e) { + handleException(jobParameter, lockService, lock, executionStartTime, executionEndTime, e, recorder, detector); + } catch (Exception e) { + log.error("Failed to index anomaly result for " + detectorId, e); + } finally { + releaseLock(jobParameter, lockService, lock); + } + + } + + private void indexResultException( + Job jobParameter, + LockService lockService, + LockModel lock, + Instant detectionStartTime, + Instant executionStartTime, + Exception exception, + boolean releaseLock, + ExecuteResultResponseRecorderType recorder, + Config detector + ) { + try { + String errorMessage = exception instanceof TimeSeriesException + ? exception.getMessage() + : Throwables.getStackTraceAsString(exception); + indexResultException( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + errorMessage, + releaseLock, + recorder, + detector + ); + } catch (Exception e) { + log.error("Failed to index result for " + jobParameter.getName(), e); + } + } + + private void indexResultException( + Job jobParameter, + LockService lockService, + LockModel lock, + Instant detectionStartTime, + Instant executionStartTime, + String errorMessage, + boolean releaseLock, + ExecuteResultResponseRecorderType recorder, + Config detector + ) { + indexResultException( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + errorMessage, + releaseLock, + null, + recorder, + detector + ); + } + + private void indexResultException( + Job jobParameter, + LockService lockService, + LockModel lock, + Instant detectionStartTime, + Instant executionStartTime, + String errorMessage, + boolean releaseLock, + String taskState, + ExecuteResultResponseRecorderType recorder, + Config detector + ) { + try { + recorder.indexResultException(detectionStartTime, executionStartTime, errorMessage, taskState, detector); + } finally { + if (releaseLock) { + releaseLock(jobParameter, lockService, lock); + } + } + } + + private void releaseLock(Job jobParameter, LockService lockService, LockModel lock) { + lockService + .release( + lock, + ActionListener + .wrap(released -> { log.info("Released lock for {} job {}", analysisType, jobParameter.getName()); }, exception -> { + log + .error( + new ParameterizedMessage("Failed to release lock for [{}] job [{}]", analysisType, jobParameter.getName()), + exception + ); + }) + ); + } + + protected abstract ResultRequest createResultRequest(String configID, long start, long end); +} diff --git a/src/main/java/org/opensearch/timeseries/JobRunner.java b/src/main/java/org/opensearch/timeseries/JobRunner.java new file mode 100644 index 000000000..68a50ee4f --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/JobRunner.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries; + +import org.opensearch.ad.ADJobProcessor; +import org.opensearch.forecast.ForecastJobProcessor; +import org.opensearch.jobscheduler.spi.JobExecutionContext; +import org.opensearch.jobscheduler.spi.ScheduledJobParameter; +import org.opensearch.jobscheduler.spi.ScheduledJobRunner; +import org.opensearch.timeseries.model.Job; + +public class JobRunner implements ScheduledJobRunner { + private static JobRunner INSTANCE; + + public static JobRunner getJobRunnerInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (JobRunner.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new JobRunner(); + return INSTANCE; + } + } + + @Override + public void runJob(ScheduledJobParameter scheduledJobParameter, JobExecutionContext context) { + if (!(scheduledJobParameter instanceof Job)) { + throw new IllegalArgumentException( + "Job parameter is not instance of Job, type: " + scheduledJobParameter.getClass().getCanonicalName() + ); + } + Job jobParameter = (Job) scheduledJobParameter; + switch (jobParameter.getAnalysisType()) { + case AD: + ADJobProcessor.getInstance().process(jobParameter, context); + break; + case FORECAST: + ForecastJobProcessor.getInstance().process(jobParameter, context); + break; + default: + throw new IllegalArgumentException("Analysis type is not supported, type: : " + jobParameter.getAnalysisType()); + } + } +} diff --git a/src/main/java/org/opensearch/timeseries/MaintenanceState.java b/src/main/java/org/opensearch/timeseries/MaintenanceState.java index 07bbb9546..7fc55dc6d 100644 --- a/src/main/java/org/opensearch/timeseries/MaintenanceState.java +++ b/src/main/java/org/opensearch/timeseries/MaintenanceState.java @@ -22,11 +22,11 @@ public interface MaintenanceState { default void maintenance(Map stateToClean, Duration stateTtl) { stateToClean.entrySet().stream().forEach(entry -> { - K detectorId = entry.getKey(); + K configId = entry.getKey(); V state = entry.getValue(); if (state.expired(stateTtl)) { - stateToClean.remove(detectorId); + stateToClean.remove(configId); } }); diff --git a/src/main/java/org/opensearch/timeseries/MemoryTracker.java b/src/main/java/org/opensearch/timeseries/MemoryTracker.java index 1599960b3..66850fe3b 100644 --- a/src/main/java/org/opensearch/timeseries/MemoryTracker.java +++ b/src/main/java/org/opensearch/timeseries/MemoryTracker.java @@ -290,7 +290,7 @@ public synchronized boolean syncMemoryState(Origin origin, long totalBytes, long .format( Locale.ROOT, "Memory states do not match. Recorded: total bytes %d, reserved bytes %d." - + "Actual: total bytes %d, reserved bytes: %d", + + " Actual: total bytes %d, reserved bytes: %d", recordedTotalBytes, recordedReservedBytes, totalBytes, diff --git a/src/main/java/org/opensearch/timeseries/Name.java b/src/main/java/org/opensearch/timeseries/Name.java index d53a2a33a..8a8f5940b 100644 --- a/src/main/java/org/opensearch/timeseries/Name.java +++ b/src/main/java/org/opensearch/timeseries/Name.java @@ -13,8 +13,10 @@ import java.util.Collection; import java.util.HashSet; +import java.util.List; import java.util.Set; import java.util.function.Function; +import java.util.stream.Collectors; /** * A super type for enum types returning names @@ -30,4 +32,8 @@ static Set getNameFromCollection(Collection names, F } return res; } + + static Set getListStrs(List profileList) { + return profileList.stream().map(profile -> profile.getName()).collect(Collectors.toSet()); + } } diff --git a/src/main/java/org/opensearch/timeseries/NodeStateManager.java b/src/main/java/org/opensearch/timeseries/NodeStateManager.java index 799a1b6ca..37f3336f6 100644 --- a/src/main/java/org/opensearch/timeseries/NodeStateManager.java +++ b/src/main/java/org/opensearch/timeseries/NodeStateManager.java @@ -78,6 +78,8 @@ public class NodeStateManager implements MaintenanceState, CleanState, Exception * @param clock A UTC clock * @param stateTtl Max time to keep state in memory * @param clusterService Cluster service accessor + * @param maxRetryForUnresponsiveNodeSetting max retry number for unresponsive node + * @param backoffMinutesSetting back off minutes setting */ public NodeStateManager( Client client, @@ -206,9 +208,9 @@ public void getConfig( ) { XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); Config config = null; - if (analysisType == AnalysisType.AD) { + if (analysisType.isAD()) { config = AnomalyDetector.parse(parser, response.getId(), response.getVersion()); - } else if (analysisType == AnalysisType.FORECAST) { + } else if (analysisType.isForecast()) { config = Forecaster.parse(parser, response.getId(), response.getVersion()); } else { throw new UnsupportedOperationException("This method is not supported"); @@ -232,7 +234,7 @@ public void getConfig(String configID, AnalysisType context, ActionListener configParser = context == AnalysisType.AD + BiCheckedFunction configParser = context.isAD() ? AnomalyDetector::parse : Forecaster::parse; clientUtil.asyncRequest(request, client::get, onGetConfigResponse(configID, configParser, listener)); diff --git a/src/main/java/org/opensearch/timeseries/ProfileRunner.java b/src/main/java/org/opensearch/timeseries/ProfileRunner.java new file mode 100644 index 000000000..6e486c17b --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ProfileRunner.java @@ -0,0 +1,567 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries; + +import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; +import static org.opensearch.core.rest.RestStatus.NOT_FOUND; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionType; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation; +import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; +import org.opensearch.search.aggregations.metrics.CardinalityAggregationBuilder; +import org.opensearch.search.aggregations.metrics.InternalCardinality; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.function.BiCheckedFunction; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.ConfigProfile; +import org.opensearch.timeseries.model.ConfigState; +import org.opensearch.timeseries.model.InitProgressProfile; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.ProfileName; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.transport.ProfileRequest; +import org.opensearch.timeseries.transport.ProfileResponse; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.timeseries.util.ExceptionUtil; +import org.opensearch.timeseries.util.MultiResponsesDelegateActionListener; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +public abstract class ProfileRunner & TimeSeriesIndex, IndexManagementType extends IndexManagement, TaskProfileType extends TaskProfile, TaskManagerType extends TaskManager, ConfigProfileType extends ConfigProfile, ProfileActionType extends ActionType, TaskProfileRunnerType extends TaskProfileRunner> + extends AbstractProfileRunner { + private final Logger logger = LogManager.getLogger(ProfileRunner.class); + protected Client client; + protected SecurityClientUtil clientUtil; + protected NamedXContentRegistry xContentRegistry; + protected DiscoveryNodeFilterer nodeFilter; + protected final TransportService transportService; + protected final TaskManagerType taskManager; + protected final int maxTotalEntitiesToTrack; + protected final AnalysisType analysisType; + protected final List realTimeTaskTypes; + protected final List batchConfigTaskTypes; + protected int maxCategoricalFields; + protected ProfileName taskProfile; + protected TaskProfileRunnerType taskProfileRunner; + protected ProfileActionType profileAction; + protected BiCheckedFunction configParser; + + public ProfileRunner( + Client client, + SecurityClientUtil clientUtil, + NamedXContentRegistry xContentRegistry, + DiscoveryNodeFilterer nodeFilter, + long requiredSamples, + TransportService transportService, + TaskManagerType taskManager, + AnalysisType analysisType, + List realTimeTaskTypes, + List batchConfigTaskTypes, + int maxCategoricalFields, + ProfileName taskProfile, + ProfileActionType profileAction, + BiCheckedFunction configParser, + TaskProfileRunnerType taskProfileRunner + ) { + super(requiredSamples); + this.client = client; + this.clientUtil = clientUtil; + this.xContentRegistry = xContentRegistry; + this.nodeFilter = nodeFilter; + if (requiredSamples <= 0) { + throw new IllegalArgumentException("required samples should be a positive number, but was " + requiredSamples); + } + this.transportService = transportService; + this.taskManager = taskManager; + this.maxTotalEntitiesToTrack = TimeSeriesSettings.MAX_TOTAL_ENTITIES_TO_TRACK; + this.analysisType = analysisType; + this.realTimeTaskTypes = realTimeTaskTypes; + this.batchConfigTaskTypes = batchConfigTaskTypes; + this.maxCategoricalFields = maxCategoricalFields; + this.taskProfile = taskProfile; + this.profileAction = profileAction; + this.configParser = configParser; + this.taskProfileRunner = taskProfileRunner; + } + + public void profile(String configId, ActionListener listener, Set profilesToCollect) { + if (profilesToCollect.isEmpty()) { + listener.onFailure(new IllegalArgumentException(CommonMessages.EMPTY_PROFILES_COLLECT)); + return; + } + calculateTotalResponsesToWait(configId, profilesToCollect, listener); + } + + private void calculateTotalResponsesToWait( + String configId, + Set profilesToCollect, + ActionListener listener + ) { + GetRequest getConfigRequest = new GetRequest(CommonName.CONFIG_INDEX, configId); + client.get(getConfigRequest, ActionListener.wrap(getConfigResponse -> { + if (getConfigResponse != null && getConfigResponse.isExists()) { + try ( + XContentParser xContentParser = XContentType.JSON + .xContent() + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, getConfigResponse.getSourceAsString()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, xContentParser.nextToken(), xContentParser); + Config config = configParser.apply(xContentParser, configId); + prepareProfile(config, listener, profilesToCollect); + } catch (Exception e) { + logger.error(CommonMessages.FAIL_TO_PARSE_CONFIG_MSG + configId, e); + listener.onFailure(new OpenSearchStatusException(CommonMessages.FAIL_TO_PARSE_CONFIG_MSG + configId, BAD_REQUEST)); + } + } else { + listener.onFailure(new OpenSearchStatusException(CommonMessages.FAIL_TO_FIND_CONFIG_MSG + configId, NOT_FOUND)); + } + }, exception -> { + logger.error(CommonMessages.FAIL_TO_FIND_CONFIG_MSG + configId, exception); + listener.onFailure(new OpenSearchStatusException(CommonMessages.FAIL_TO_FIND_CONFIG_MSG + configId, NOT_FOUND)); + })); + } + + protected void prepareProfile(Config config, ActionListener listener, Set profilesToCollect) { + String configId = config.getId(); + GetRequest getRequest = new GetRequest(CommonName.JOB_INDEX, configId); + client.get(getRequest, ActionListener.wrap(getResponse -> { + if (getResponse != null && getResponse.isExists()) { + try ( + XContentParser parser = XContentType.JSON + .xContent() + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, getResponse.getSourceAsString()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Job job = Job.parse(parser); + long enabledTimeMs = job.getEnabledTime().toEpochMilli(); + + int totalResponsesToWait = 0; + if (profilesToCollect.contains(ProfileName.ERROR)) { + totalResponsesToWait++; + } + + // total number of listeners we need to define. Needed by MultiResponsesDelegateActionListener to decide + // when to consolidate results and return to users + if (profilesToCollect.contains(ProfileName.TOTAL_ENTITIES)) { + totalResponsesToWait++; + } + if (profilesToCollect.contains(ProfileName.COORDINATING_NODE) + || profilesToCollect.contains(ProfileName.SHINGLE_SIZE) + || profilesToCollect.contains(ProfileName.TOTAL_SIZE_IN_BYTES) + || profilesToCollect.contains(ProfileName.MODELS) + || profilesToCollect.contains(ProfileName.ACTIVE_ENTITIES) + || profilesToCollect.contains(ProfileName.INIT_PROGRESS) + || profilesToCollect.contains(ProfileName.STATE)) { + totalResponsesToWait++; + } + if (profilesToCollect.contains(taskProfile)) { + totalResponsesToWait++; + } + + MultiResponsesDelegateActionListener delegateListener = + new MultiResponsesDelegateActionListener( + listener, + totalResponsesToWait, + CommonMessages.FAIL_FETCH_ERR_MSG + configId, + false + ); + if (profilesToCollect.contains(ProfileName.ERROR)) { + taskManager.getAndExecuteOnLatestConfigLevelTask(configId, realTimeTaskTypes, task -> { + ConfigProfileType.Builder profileBuilder = createProfileBuilder(); + if (task.isPresent()) { + long lastUpdateTimeMs = task.get().getLastUpdateTime().toEpochMilli(); + + // if state index hasn't been updated, we should not use the error field + // For example, before a detector is enabled, if the error message contains + // the phrase "stopped due to blah", we should not show this when the detector + // is enabled. + if (lastUpdateTimeMs > enabledTimeMs && task.get().getError() != null) { + profileBuilder.error(task.get().getError()); + } + delegateListener.onResponse(profileBuilder.build()); + } else { + // detector state for this detector does not exist + delegateListener.onResponse(profileBuilder.build()); + } + }, transportService, false, delegateListener); + } + + // total number of listeners we need to define. Needed by MultiResponsesDelegateActionListener to decide + // when to consolidate results and return to users + if (profilesToCollect.contains(ProfileName.TOTAL_ENTITIES)) { + profileEntityStats(delegateListener, config); + } + if (profilesToCollect.contains(ProfileName.COORDINATING_NODE) + || profilesToCollect.contains(ProfileName.SHINGLE_SIZE) + || profilesToCollect.contains(ProfileName.TOTAL_SIZE_IN_BYTES) + || profilesToCollect.contains(ProfileName.MODELS) + || profilesToCollect.contains(ProfileName.ACTIVE_ENTITIES) + || profilesToCollect.contains(ProfileName.INIT_PROGRESS) + || profilesToCollect.contains(ProfileName.STATE)) { + profileModels(config, profilesToCollect, job, true, delegateListener); + } + if (profilesToCollect.contains(taskProfile)) { + getLatestHistoricalTaskProfile(configId, transportService, null, delegateListener); + } + + } catch (Exception e) { + logger.error(CommonMessages.FAIL_TO_GET_PROFILE_MSG, e); + listener.onFailure(e); + } + } else { + onGetDetectorForPrepare(configId, listener, profilesToCollect); + } + }, exception -> { + if (ExceptionUtil.isIndexNotAvailable(exception)) { + logger.info(exception.getMessage()); + onGetDetectorForPrepare(configId, listener, profilesToCollect); + } else { + logger.error(CommonMessages.FAIL_TO_GET_PROFILE_MSG + configId); + listener.onFailure(exception); + } + })); + } + + private void profileEntityStats(MultiResponsesDelegateActionListener listener, Config config) { + List categoryField = config.getCategoryFields(); + if (!config.isHighCardinality() || categoryField.size() > maxCategoricalFields) { + listener.onResponse(createProfileBuilder().build()); + } else { + if (categoryField.size() == 1) { + // Run a cardinality aggregation to count the cardinality of single category fields + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + CardinalityAggregationBuilder aggBuilder = new CardinalityAggregationBuilder(CommonName.TOTAL_ENTITIES); + aggBuilder.field(categoryField.get(0)); + searchSourceBuilder.aggregation(aggBuilder); + + SearchRequest request = new SearchRequest(config.getIndices().toArray(new String[0]), searchSourceBuilder); + final ActionListener searchResponseListener = ActionListener.wrap(searchResponse -> { + Map aggMap = searchResponse.getAggregations().asMap(); + InternalCardinality totalEntities = (InternalCardinality) aggMap.get(CommonName.TOTAL_ENTITIES); + long value = totalEntities.getValue(); + ConfigProfileType.Builder profileBuilder = createProfileBuilder(); + ConfigProfileType profile = profileBuilder.totalEntities(value).build(); + listener.onResponse(profile); + }, searchException -> { + logger.warn(CommonMessages.FAIL_TO_GET_TOTAL_ENTITIES + config.getId()); + listener.onFailure(searchException); + }); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + request, + client::search, + config.getId(), + client, + analysisType, + searchResponseListener + ); + } else { + // Run a composite query and count the number of buckets to decide cardinality of multiple category fields + AggregationBuilder bucketAggs = AggregationBuilders + .composite( + CommonName.TOTAL_ENTITIES, + config.getCategoryFields().stream().map(f -> new TermsValuesSourceBuilder(f).field(f)).collect(Collectors.toList()) + ) + .size(maxTotalEntitiesToTrack); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().aggregation(bucketAggs).trackTotalHits(false).size(0); + SearchRequest searchRequest = new SearchRequest() + .indices(config.getIndices().toArray(new String[0])) + .source(searchSourceBuilder); + final ActionListener searchResponseListener = ActionListener.wrap(searchResponse -> { + ConfigProfileType.Builder profileBuilder = createProfileBuilder(); + Aggregations aggs = searchResponse.getAggregations(); + if (aggs == null) { + // This would indicate some bug or some opensearch core changes that we are not aware of (we don't keep up-to-date + // with + // the large amounts of changes there). For example, they may change to if there are results return it; otherwise + // return + // null instead of an empty Aggregations as they currently do. + logger.warn("Unexpected null aggregation."); + listener.onResponse(profileBuilder.totalEntities(0L).build()); + return; + } + + Aggregation aggrResult = aggs.get(CommonName.TOTAL_ENTITIES); + if (aggrResult == null) { + listener.onFailure(new IllegalArgumentException("Fail to find valid aggregation result")); + return; + } + + CompositeAggregation compositeAgg = (CompositeAggregation) aggrResult; + ConfigProfileType profile = profileBuilder.totalEntities(Long.valueOf(compositeAgg.getBuckets().size())).build(); + listener.onResponse(profile); + }, searchException -> { + logger.warn(CommonMessages.FAIL_TO_GET_TOTAL_ENTITIES + config.getId()); + listener.onFailure(searchException); + }); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + config.getId(), + client, + analysisType, + searchResponseListener + ); + } + + } + } + + protected void onGetDetectorForPrepare(String configId, ActionListener listener, Set profiles) { + ConfigProfileType.Builder profileBuilder = createProfileBuilder(); + if (profiles.contains(ProfileName.STATE)) { + profileBuilder.state(ConfigState.DISABLED); + } + if (profiles.contains(taskProfile)) { + getLatestHistoricalTaskProfile(configId, transportService, profileBuilder.build(), listener); + } else { + listener.onResponse(profileBuilder.build()); + } + } + + /** + * Profile models related + * + * @param config Config accessor + * @param profiles profiles to collect + * @param job Job accessor + * @param modelInPriorityCache Whether the models are stored in priority cache. AD single stream models are stored in ModelManager. + * Other models are stored in priority cache. + * @param listener returns collected profiles + */ + protected void profileModels( + Config config, + Set profiles, + Job job, + boolean modelInPriorityCache, + MultiResponsesDelegateActionListener listener + ) { + DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes(); + ProfileRequest profileRequest = new ProfileRequest(config.getId(), profiles, modelInPriorityCache, dataNodes); + client.execute(profileAction, profileRequest, onModelResponse(config, profiles, job, modelInPriorityCache, listener));// get init + // progress + } + + private ActionListener onModelResponse( + Config config, + Set profilesToCollect, + Job job, + boolean modelInPriorityCache, + MultiResponsesDelegateActionListener listener + ) { + boolean isMultientityDetector = config.isHighCardinality(); + return ActionListener.wrap(profileResponse -> { + ConfigProfileType.Builder profile = createProfileBuilder(); + if (profilesToCollect.contains(ProfileName.COORDINATING_NODE)) { + profile.coordinatingNode(profileResponse.getCoordinatingNode()); + } + if (profilesToCollect.contains(ProfileName.SHINGLE_SIZE)) { + profile.shingleSize(profileResponse.getShingleSize()); + } + if (profilesToCollect.contains(ProfileName.TOTAL_SIZE_IN_BYTES)) { + profile.totalSizeInBytes(profileResponse.getTotalSizeInBytes()); + } + if (profilesToCollect.contains(ProfileName.MODELS)) { + profile.modelProfile(profileResponse.getModelProfile()); + profile.modelCount(profileResponse.getModelCount()); + } + if (isMultientityDetector && profilesToCollect.contains(ProfileName.ACTIVE_ENTITIES)) { + profile.activeEntities(profileResponse.getActiveEntities()); + } + + // only need to do it for models in priority cache. AD single stream analysis has a + // different workflow to determine state and init progress + if (modelInPriorityCache + && (profilesToCollect.contains(ProfileName.INIT_PROGRESS) || profilesToCollect.contains(ProfileName.STATE))) { + profileStateRelated(job, profilesToCollect, profileResponse, profile, config, listener); + } else { + listener.onResponse(profile.build()); + } + }, listener::onFailure); + } + + private void profileStateRelated( + Job job, + Set profilesToCollect, + ProfileResponse profileResponse, + ConfigProfileType.Builder profileBuilder, + Config config, + MultiResponsesDelegateActionListener listener + ) { + if (job.isEnabled()) { + if (profileResponse.getTotalUpdates() < requiredSamples) { + // need to double check for an HC analysis + // since what ProfileResponse returns is the highest priority entity currently in memory, but + // another entity might have already been initialized and sit somewhere else (in memory or on disk). + long enabledTime = job.getEnabledTime().toEpochMilli(); + long totalUpdates = profileResponse.getTotalUpdates(); + ProfileUtil + .confirmRealtimeInitStatus( + config, + enabledTime, + client, + analysisType, + onInittedEver(enabledTime, profileBuilder, profilesToCollect, config, totalUpdates, listener) + ); + } else { + createRunningStateAndInitProgress(profilesToCollect, profileBuilder); + listener.onResponse(profileBuilder.build()); + } + } else { + if (profilesToCollect.contains(ProfileName.STATE)) { + profileBuilder.state(ConfigState.DISABLED); + } + listener.onResponse(profileBuilder.build()); + } + } + + private ActionListener onInittedEver( + long lastUpdateTimeMs, + ConfigProfileType.Builder profileBuilder, + Set profilesToCollect, + Config config, + long totalUpdates, + MultiResponsesDelegateActionListener listener + ) { + return ActionListener.wrap(searchResponse -> { + SearchHits hits = searchResponse.getHits(); + if (hits.getTotalHits().value == 0L) { + processInitResponse(config, profilesToCollect, totalUpdates, false, profileBuilder, listener); + } else { + createRunningStateAndInitProgress(profilesToCollect, profileBuilder); + listener.onResponse(profileBuilder.build()); + } + }, exception -> { + if (ExceptionUtil.isIndexNotAvailable(exception)) { + // anomaly result index is not created yet + processInitResponse(config, profilesToCollect, totalUpdates, false, profileBuilder, listener); + } else { + logger + .error( + "Fail to find any anomaly result with anomaly score larger than 0 after AD job enabled time for detector {}", + config.getId() + ); + listener.onFailure(exception); + } + }); + } + + protected void createRunningStateAndInitProgress( + Set profilesToCollect, + ConfigProfileType.Builder builder + ) { + if (profilesToCollect.contains(ProfileName.STATE)) { + builder.state(ConfigState.RUNNING).build(); + } + + if (profilesToCollect.contains(ProfileName.INIT_PROGRESS)) { + InitProgressProfile initProgress = new InitProgressProfile("100%", 0, 0); + builder.initProgress(initProgress); + } + } + + protected void processInitResponse( + Config config, + Set profilesToCollect, + long totalUpdates, + boolean hideMinutesLeft, + ConfigProfileType.Builder builder, + MultiResponsesDelegateActionListener listener + ) { + if (profilesToCollect.contains(ProfileName.STATE)) { + builder.state(ConfigState.INIT); + } + + if (profilesToCollect.contains(ProfileName.INIT_PROGRESS)) { + if (hideMinutesLeft) { + InitProgressProfile initProgress = computeInitProgressProfile(totalUpdates, 0); + builder.initProgress(initProgress); + } else { + long intervalMins = ((IntervalTimeConfiguration) config.getInterval()).toDuration().toMinutes(); + InitProgressProfile initProgress = computeInitProgressProfile(totalUpdates, intervalMins); + builder.initProgress(initProgress); + } + } + + listener.onResponse(builder.build()); + } + + /** + * Get latest historical config task profile. + * Will not reset task state in this method. + * + * @param configId config id + * @param transportService transport service + * @param profile config profile + * @param listener action listener + */ + public void getLatestHistoricalTaskProfile( + String configId, + TransportService transportService, + ConfigProfileType profile, + ActionListener listener + ) { + taskManager.getAndExecuteOnLatestConfigTask(configId, null, null, batchConfigTaskTypes, task -> { + if (task.isPresent()) { + taskProfileRunner.getTaskProfile(task.get(), ActionListener.wrap(taskProfile -> { + ConfigProfileType.Builder profileBuilder = createProfileBuilder(); + profileBuilder.taskProfile(taskProfile); + ConfigProfileType configProfile = profileBuilder.build(); + configProfile.merge(profile); + listener.onResponse(configProfile); + }, e -> { + logger.error("Failed to get task profile for task " + task.get().getTaskId(), e); + listener.onFailure(e); + })); + } else { + ConfigProfileType.Builder profileBuilder = createProfileBuilder(); + listener.onResponse(profileBuilder.build()); + } + }, transportService, false, listener); + } + + protected abstract ConfigProfileType.Builder createProfileBuilder(); + +} diff --git a/src/main/java/org/opensearch/timeseries/ProfileTask.java b/src/main/java/org/opensearch/timeseries/ProfileTask.java new file mode 100644 index 000000000..9b68a2db6 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ProfileTask.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.timeseries.model.TimeSeriesTask; + +/** + * Break the cross dependency between TaskManager and ProfileRunner. Instead of + * depending on each other, they depend on the interface. + * + */ +public interface ProfileTask> { + void getTaskProfile(TaskClass configLevelTask, ActionListener listener); +} diff --git a/src/main/java/org/opensearch/timeseries/ProfileUtil.java b/src/main/java/org/opensearch/timeseries/ProfileUtil.java new file mode 100644 index 000000000..b6de04ba7 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ProfileUtil.java @@ -0,0 +1,106 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.ExistsQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.Config; + +public class ProfileUtil { + /** + * Create search request to check if we have at least 1 anomaly score larger than 0 after AD job enabled time. + * Note this function is only meant to check for status of real time analysis. + * + * @param detectorId detector id + * @param enabledTime the time when AD job is enabled in milliseconds + * @return the search request + */ + private static SearchRequest createADRealtimeInittedEverRequest(String detectorId, long enabledTime, String resultIndex) { + BoolQueryBuilder filterQuery = new BoolQueryBuilder(); + filterQuery.filter(QueryBuilders.termQuery(AnomalyResult.DETECTOR_ID_FIELD, detectorId)); + filterQuery.filter(QueryBuilders.rangeQuery(CommonName.EXECUTION_END_TIME_FIELD).gte(enabledTime)); + filterQuery.filter(QueryBuilders.rangeQuery(AnomalyResult.ANOMALY_SCORE_FIELD).gt(0)); + // Historical analysis result also stored in result index, which has non-null task_id. + // For realtime detection result, we should filter task_id == null + ExistsQueryBuilder taskIdExistsFilter = QueryBuilders.existsQuery(CommonName.TASK_ID_FIELD); + filterQuery.mustNot(taskIdExistsFilter); + + SearchSourceBuilder source = new SearchSourceBuilder().query(filterQuery).size(1); + + SearchRequest request = new SearchRequest(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS); + request.source(source); + if (resultIndex != null) { + request.indices(resultIndex); + } + return request; + } + + /** + * Create search request to check if we have at least 1 forecast after AD job enabled time. + * Note this function is only meant to check for status of real time analysis. + * + * @param forecasterId forecaster id + * @param enabledTime the time when forecast job is enabled in milliseconds + * @return the search request + */ + private static SearchRequest createForecastRealtimeInittedEverRequest(String forecasterId, long enabledTime, String resultIndex) { + BoolQueryBuilder filterQuery = new BoolQueryBuilder(); + filterQuery.filter(QueryBuilders.termQuery(ForecastCommonName.FORECASTER_ID_KEY, forecasterId)); + filterQuery.filter(QueryBuilders.rangeQuery(CommonName.EXECUTION_END_TIME_FIELD).gte(enabledTime)); + ExistsQueryBuilder forecastsExistFilter = QueryBuilders.existsQuery(ForecastResult.VALUE_FIELD); + filterQuery.must(forecastsExistFilter); + // Historical/run-once analysis result also stored in result index, which has non-null task_id. + // For realtime detection result, we should filter task_id == null + ExistsQueryBuilder taskIdExistsFilter = QueryBuilders.existsQuery(CommonName.TASK_ID_FIELD); + filterQuery.mustNot(taskIdExistsFilter); + + SearchSourceBuilder source = new SearchSourceBuilder().query(filterQuery).size(1); + + SearchRequest request = new SearchRequest(ForecastIndex.RESULT.getIndexName()); + request.source(source); + if (resultIndex != null) { + request.indices(resultIndex); + } + return request; + } + + public static void confirmRealtimeInitStatus( + Config config, + long enabledTime, + Client client, + AnalysisType analysisType, + ActionListener listener + ) { + SearchRequest searchLatestResult = null; + if (analysisType.isAD()) { + searchLatestResult = createADRealtimeInittedEverRequest(config.getId(), enabledTime, config.getCustomResultIndex()); + } else if (analysisType.isForecast()) { + searchLatestResult = createForecastRealtimeInittedEverRequest(config.getId(), enabledTime, config.getCustomResultIndex()); + } else { + throw new IllegalArgumentException("Analysis type is not supported, type: : " + analysisType); + } + + client.search(searchLatestResult, listener); + } +} diff --git a/src/main/java/org/opensearch/timeseries/TaskProfile.java b/src/main/java/org/opensearch/timeseries/TaskProfile.java new file mode 100644 index 000000000..4abf41897 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/TaskProfile.java @@ -0,0 +1,173 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries; + +import java.io.IOException; +import java.util.Objects; + +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.annotation.Generated; +import org.opensearch.timeseries.model.TimeSeriesTask; + +public abstract class TaskProfile implements ToXContentObject, Writeable { + + public static final String SHINGLE_SIZE_FIELD = "shingle_size"; + public static final String RCF_TOTAL_UPDATES_FIELD = "rcf_total_updates"; + public static final String MODEL_SIZE_IN_BYTES = "model_size_in_bytes"; + public static final String NODE_ID_FIELD = "node_id"; + public static final String TASK_ID_FIELD = "task_id"; + public static final String TASK_TYPE_FIELD = "task_type"; + public static final String ENTITY_TASK_PROFILE_FIELD = "entity_task_profiles"; + + protected TaskType task; + protected Integer shingleSize; + protected Long rcfTotalUpdates; + protected Long modelSizeInBytes; + protected String nodeId; + protected String taskId; + protected String taskType; + + public TaskProfile() { + + } + + public TaskProfile(TaskType task) { + this.task = task; + } + + public TaskProfile(String taskId, int shingleSize, long rcfTotalUpdates, long modelSizeInBytes, String nodeId) { + this.taskId = taskId; + this.shingleSize = shingleSize; + this.rcfTotalUpdates = rcfTotalUpdates; + this.modelSizeInBytes = modelSizeInBytes; + this.nodeId = nodeId; + } + + public TaskProfile( + TaskType adTask, + Integer shingleSize, + Long rcfTotalUpdates, + Long modelSizeInBytes, + String nodeId, + String taskId, + String adTaskType + ) { + this.task = adTask; + this.shingleSize = shingleSize; + this.rcfTotalUpdates = rcfTotalUpdates; + this.modelSizeInBytes = modelSizeInBytes; + this.nodeId = nodeId; + this.taskId = taskId; + this.taskType = adTaskType; + } + + public TaskType getTask() { + return task; + } + + public void setTask(TaskType adTask) { + this.task = adTask; + } + + public Integer getShingleSize() { + return shingleSize; + } + + public void setShingleSize(Integer shingleSize) { + this.shingleSize = shingleSize; + } + + public Long getRcfTotalUpdates() { + return rcfTotalUpdates; + } + + public void setRcfTotalUpdates(Long rcfTotalUpdates) { + this.rcfTotalUpdates = rcfTotalUpdates; + } + + public Long getModelSizeInBytes() { + return modelSizeInBytes; + } + + public void setModelSizeInBytes(Long modelSizeInBytes) { + this.modelSizeInBytes = modelSizeInBytes; + } + + public String getNodeId() { + return nodeId; + } + + public void setNodeId(String nodeId) { + this.nodeId = nodeId; + } + + public String getTaskId() { + return taskId; + } + + public void setTaskId(String taskId) { + this.taskId = taskId; + } + + public String getTaskType() { + return taskType; + } + + public void setTaskType(String taskType) { + this.taskType = taskType; + } + + @Generated + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + TaskProfile that = (TaskProfile) o; + return Objects.equals(task, that.task) + && Objects.equals(shingleSize, that.shingleSize) + && Objects.equals(rcfTotalUpdates, that.rcfTotalUpdates) + && Objects.equals(modelSizeInBytes, that.modelSizeInBytes) + && Objects.equals(nodeId, that.nodeId) + && Objects.equals(taskId, that.taskId) + && Objects.equals(taskType, that.taskType); + } + + @Generated + @Override + public int hashCode() { + return Objects.hash(task, shingleSize, rcfTotalUpdates, modelSizeInBytes, nodeId, taskId, taskType); + } + + protected void toXContent(XContentBuilder xContentBuilder) throws IOException { + if (task != null) { + xContentBuilder.field(getTaskFieldName(), task); + } + if (shingleSize != null) { + xContentBuilder.field(SHINGLE_SIZE_FIELD, shingleSize); + } + if (rcfTotalUpdates != null) { + xContentBuilder.field(RCF_TOTAL_UPDATES_FIELD, rcfTotalUpdates); + } + if (modelSizeInBytes != null) { + xContentBuilder.field(MODEL_SIZE_IN_BYTES, modelSizeInBytes); + } + if (nodeId != null) { + xContentBuilder.field(NODE_ID_FIELD, nodeId); + } + if (taskId != null) { + xContentBuilder.field(TASK_ID_FIELD, taskId); + } + if (taskType != null) { + xContentBuilder.field(TASK_TYPE_FIELD, taskType); + } + } + + protected abstract String getTaskFieldName(); +} diff --git a/src/main/java/org/opensearch/timeseries/TaskProfileRunner.java b/src/main/java/org/opensearch/timeseries/TaskProfileRunner.java new file mode 100644 index 000000000..6f29fd244 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/TaskProfileRunner.java @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.timeseries.model.TimeSeriesTask; + +/** + * break the cross dependency between ProfileRunner and TaskManager. Instead, both of them depend on TaskProfileRunner. + */ +public interface TaskProfileRunner> { + void getTaskProfile(TaskClass configLevelTask, ActionListener listener); +} diff --git a/src/main/java/org/opensearch/timeseries/TimeSeriesAnalyticsPlugin.java b/src/main/java/org/opensearch/timeseries/TimeSeriesAnalyticsPlugin.java index 7dadac650..6ac451d48 100644 --- a/src/main/java/org/opensearch/timeseries/TimeSeriesAnalyticsPlugin.java +++ b/src/main/java/org/opensearch/timeseries/TimeSeriesAnalyticsPlugin.java @@ -12,6 +12,8 @@ package org.opensearch.timeseries; import static java.util.Collections.unmodifiableList; +import static org.opensearch.ad.constant.ADCommonName.ANOMALY_RESULT_INDEX_ALIAS; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_COOLDOWN_MINUTES; import java.security.AccessController; import java.security.PrivilegedAction; @@ -33,33 +35,29 @@ import org.apache.logging.log4j.Logger; import org.opensearch.SpecialPermission; import org.opensearch.action.ActionRequest; -import org.opensearch.ad.AnomalyDetectorJobRunner; +import org.opensearch.ad.ADJobProcessor; +import org.opensearch.ad.ADTaskProfileRunner; import org.opensearch.ad.AnomalyDetectorRunner; import org.opensearch.ad.ExecuteADResultResponseRecorder; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; -import org.opensearch.ad.caching.PriorityCache; -import org.opensearch.ad.cluster.ADClusterEventListener; -import org.opensearch.ad.cluster.ADDataMigrator; -import org.opensearch.ad.cluster.ClusterManagerEventListener; -import org.opensearch.ad.cluster.HashRing; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityColdStarter; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADColdStart; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.ml.HybridThresholdingModel; -import org.opensearch.ad.ml.ModelManager; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyResult; import org.opensearch.ad.model.DetectorInternalState; -import org.opensearch.ad.ratelimit.CheckPointMaintainRequestAdapter; -import org.opensearch.ad.ratelimit.CheckpointMaintainWorker; -import org.opensearch.ad.ratelimit.CheckpointReadWorker; -import org.opensearch.ad.ratelimit.CheckpointWriteWorker; -import org.opensearch.ad.ratelimit.ColdEntityWorker; -import org.opensearch.ad.ratelimit.EntityColdStartWorker; -import org.opensearch.ad.ratelimit.ResultWriteWorker; +import org.opensearch.ad.ratelimit.ADCheckpointMaintainWorker; +import org.opensearch.ad.ratelimit.ADCheckpointReadWorker; +import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker; +import org.opensearch.ad.ratelimit.ADColdEntityWorker; +import org.opensearch.ad.ratelimit.ADColdStartWorker; +import org.opensearch.ad.ratelimit.ADResultWriteWorker; +import org.opensearch.ad.ratelimit.ADSaveResultStrategy; import org.opensearch.ad.rest.RestAnomalyDetectorJobAction; import org.opensearch.ad.rest.RestDeleteAnomalyDetectorAction; import org.opensearch.ad.rest.RestDeleteAnomalyResultsAction; @@ -74,17 +72,14 @@ import org.opensearch.ad.rest.RestSearchTopAnomalyResultAction; import org.opensearch.ad.rest.RestStatsAnomalyDetectorAction; import org.opensearch.ad.rest.RestValidateAnomalyDetectorAction; +import org.opensearch.ad.rest.handler.ADIndexJobActionHandler; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.settings.ADNumericSetting; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.settings.LegacyOpenDistroAnomalyDetectorSettings; -import org.opensearch.ad.stats.ADStat; import org.opensearch.ad.stats.ADStats; -import org.opensearch.ad.stats.suppliers.CounterSupplier; -import org.opensearch.ad.stats.suppliers.IndexStatusSupplier; -import org.opensearch.ad.stats.suppliers.ModelsOnNodeCountSupplier; -import org.opensearch.ad.stats.suppliers.ModelsOnNodeSupplier; -import org.opensearch.ad.stats.suppliers.SettableSupplier; +import org.opensearch.ad.stats.suppliers.ADModelsOnNodeCountSupplier; +import org.opensearch.ad.stats.suppliers.ADModelsOnNodeSupplier; import org.opensearch.ad.task.ADBatchTaskRunner; import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; @@ -94,6 +89,10 @@ import org.opensearch.ad.transport.ADBatchTaskRemoteExecutionTransportAction; import org.opensearch.ad.transport.ADCancelTaskAction; import org.opensearch.ad.transport.ADCancelTaskTransportAction; +import org.opensearch.ad.transport.ADEntityProfileAction; +import org.opensearch.ad.transport.ADEntityProfileTransportAction; +import org.opensearch.ad.transport.ADProfileAction; +import org.opensearch.ad.transport.ADProfileTransportAction; import org.opensearch.ad.transport.ADResultBulkAction; import org.opensearch.ad.transport.ADResultBulkTransportAction; import org.opensearch.ad.transport.ADStatsNodesAction; @@ -105,17 +104,14 @@ import org.opensearch.ad.transport.AnomalyResultAction; import org.opensearch.ad.transport.AnomalyResultTransportAction; import org.opensearch.ad.transport.CronAction; -import org.opensearch.ad.transport.CronTransportAction; +import org.opensearch.ad.transport.DeleteADModelAction; +import org.opensearch.ad.transport.DeleteADModelTransportAction; import org.opensearch.ad.transport.DeleteAnomalyDetectorAction; import org.opensearch.ad.transport.DeleteAnomalyDetectorTransportAction; import org.opensearch.ad.transport.DeleteAnomalyResultsAction; import org.opensearch.ad.transport.DeleteAnomalyResultsTransportAction; -import org.opensearch.ad.transport.DeleteModelAction; -import org.opensearch.ad.transport.DeleteModelTransportAction; -import org.opensearch.ad.transport.EntityProfileAction; -import org.opensearch.ad.transport.EntityProfileTransportAction; -import org.opensearch.ad.transport.EntityResultAction; -import org.opensearch.ad.transport.EntityResultTransportAction; +import org.opensearch.ad.transport.EntityADResultAction; +import org.opensearch.ad.transport.EntityADResultTransportAction; import org.opensearch.ad.transport.ForwardADTaskAction; import org.opensearch.ad.transport.ForwardADTaskTransportAction; import org.opensearch.ad.transport.GetAnomalyDetectorAction; @@ -124,8 +120,6 @@ import org.opensearch.ad.transport.IndexAnomalyDetectorTransportAction; import org.opensearch.ad.transport.PreviewAnomalyDetectorAction; import org.opensearch.ad.transport.PreviewAnomalyDetectorTransportAction; -import org.opensearch.ad.transport.ProfileAction; -import org.opensearch.ad.transport.ProfileTransportAction; import org.opensearch.ad.transport.RCFPollingAction; import org.opensearch.ad.transport.RCFPollingTransportAction; import org.opensearch.ad.transport.RCFResultAction; @@ -148,11 +142,8 @@ import org.opensearch.ad.transport.ThresholdResultTransportAction; import org.opensearch.ad.transport.ValidateAnomalyDetectorAction; import org.opensearch.ad.transport.ValidateAnomalyDetectorTransportAction; +import org.opensearch.ad.transport.handler.ADIndexMemoryPressureAwareResultHandler; import org.opensearch.ad.transport.handler.ADSearchHandler; -import org.opensearch.ad.transport.handler.AnomalyIndexHandler; -import org.opensearch.ad.transport.handler.AnomalyResultBulkIndexHandler; -import org.opensearch.ad.transport.handler.MultiEntityResultHandler; -import org.opensearch.ad.util.IndexUtils; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.node.DiscoveryNodes; @@ -171,8 +162,91 @@ import org.opensearch.core.xcontent.XContentParserUtils; import org.opensearch.env.Environment; import org.opensearch.env.NodeEnvironment; +import org.opensearch.forecast.ExecuteForecastResultResponseRecorder; +import org.opensearch.forecast.ForecastJobProcessor; +import org.opensearch.forecast.ForecastTaskProfileRunner; +import org.opensearch.forecast.caching.ForecastCacheProvider; +import org.opensearch.forecast.caching.ForecastPriorityCache; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastCheckpointDao; +import org.opensearch.forecast.ml.ForecastColdStart; +import org.opensearch.forecast.ml.ForecastModelManager; +import org.opensearch.forecast.model.ForecastResult; import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.ratelimit.ForecastCheckpointMaintainWorker; +import org.opensearch.forecast.ratelimit.ForecastCheckpointReadWorker; +import org.opensearch.forecast.ratelimit.ForecastCheckpointWriteWorker; +import org.opensearch.forecast.ratelimit.ForecastColdEntityWorker; +import org.opensearch.forecast.ratelimit.ForecastColdStartWorker; +import org.opensearch.forecast.ratelimit.ForecastResultWriteWorker; +import org.opensearch.forecast.ratelimit.ForecastSaveResultStrategy; +import org.opensearch.forecast.rest.RestDeleteForecasterAction; +import org.opensearch.forecast.rest.RestForecasterJobAction; +import org.opensearch.forecast.rest.RestForecasterSuggestAction; +import org.opensearch.forecast.rest.RestGetForecasterAction; +import org.opensearch.forecast.rest.RestIndexForecasterAction; +import org.opensearch.forecast.rest.RestRunOnceForecasterAction; +import org.opensearch.forecast.rest.RestSearchForecastTasksAction; +import org.opensearch.forecast.rest.RestSearchForecasterAction; +import org.opensearch.forecast.rest.RestSearchForecasterInfoAction; +import org.opensearch.forecast.rest.RestSearchTopForecastResultAction; +import org.opensearch.forecast.rest.RestStatsForecasterAction; +import org.opensearch.forecast.rest.RestValidateForecasterAction; +import org.opensearch.forecast.rest.handler.ForecastIndexJobActionHandler; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.forecast.settings.ForecastNumericSetting; import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.forecast.stats.ForecastModelsOnNodeSupplier; +import org.opensearch.forecast.stats.ForecastStats; +import org.opensearch.forecast.stats.suppliers.ForecastModelsOnNodeCountSupplier; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.forecast.transport.DeleteForecastModelAction; +import org.opensearch.forecast.transport.DeleteForecastModelTransportAction; +import org.opensearch.forecast.transport.DeleteForecasterAction; +import org.opensearch.forecast.transport.DeleteForecasterTransportAction; +import org.opensearch.forecast.transport.EntityForecastResultAction; +import org.opensearch.forecast.transport.EntityForecastResultTransportAction; +import org.opensearch.forecast.transport.ForecastEntityProfileAction; +import org.opensearch.forecast.transport.ForecastEntityProfileTransportAction; +import org.opensearch.forecast.transport.ForecastProfileAction; +import org.opensearch.forecast.transport.ForecastProfileTransportAction; +import org.opensearch.forecast.transport.ForecastResultAction; +import org.opensearch.forecast.transport.ForecastResultBulkAction; +import org.opensearch.forecast.transport.ForecastResultBulkTransportAction; +import org.opensearch.forecast.transport.ForecastResultTransportAction; +import org.opensearch.forecast.transport.ForecastRunOnceAction; +import org.opensearch.forecast.transport.ForecastRunOnceProfileAction; +import org.opensearch.forecast.transport.ForecastRunOnceProfileTransportAction; +import org.opensearch.forecast.transport.ForecastRunOnceTransportAction; +import org.opensearch.forecast.transport.ForecastSingleStreamResultAction; +import org.opensearch.forecast.transport.ForecastSingleStreamResultTransportAction; +import org.opensearch.forecast.transport.ForecastStatsNodesAction; +import org.opensearch.forecast.transport.ForecastStatsNodesTransportAction; +import org.opensearch.forecast.transport.ForecasterJobAction; +import org.opensearch.forecast.transport.ForecasterJobTransportAction; +import org.opensearch.forecast.transport.GetForecasterAction; +import org.opensearch.forecast.transport.GetForecasterTransportAction; +import org.opensearch.forecast.transport.IndexForecasterAction; +import org.opensearch.forecast.transport.IndexForecasterTransportAction; +import org.opensearch.forecast.transport.SearchForecastTasksAction; +import org.opensearch.forecast.transport.SearchForecastTasksTransportAction; +import org.opensearch.forecast.transport.SearchForecasterAction; +import org.opensearch.forecast.transport.SearchForecasterInfoAction; +import org.opensearch.forecast.transport.SearchForecasterInfoTransportAction; +import org.opensearch.forecast.transport.SearchForecasterTransportAction; +import org.opensearch.forecast.transport.SearchTopForecastResultAction; +import org.opensearch.forecast.transport.SearchTopForecastResultTransportAction; +import org.opensearch.forecast.transport.StatsForecasterAction; +import org.opensearch.forecast.transport.StatsForecasterTransportAction; +import org.opensearch.forecast.transport.StopForecasterAction; +import org.opensearch.forecast.transport.StopForecasterTransportAction; +import org.opensearch.forecast.transport.SuggestForecasterParamAction; +import org.opensearch.forecast.transport.SuggestForecasterParamTransportAction; +import org.opensearch.forecast.transport.ValidateForecasterAction; +import org.opensearch.forecast.transport.ValidateForecasterTransportAction; +import org.opensearch.forecast.transport.handler.ForecastIndexMemoryPressureAwareResultHandler; +import org.opensearch.forecast.transport.handler.ForecastSearchHandler; import org.opensearch.jobscheduler.spi.JobSchedulerExtension; import org.opensearch.jobscheduler.spi.ScheduledJobParser; import org.opensearch.jobscheduler.spi.ScheduledJobRunner; @@ -189,19 +263,38 @@ import org.opensearch.threadpool.ScalingExecutorBuilder; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.cluster.ADDataMigrator; +import org.opensearch.timeseries.cluster.ClusterEventListener; +import org.opensearch.timeseries.cluster.ClusterManagerEventListener; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.dataprocessor.Imputer; import org.opensearch.timeseries.dataprocessor.LinearUniformImputer; +import org.opensearch.timeseries.feature.FeatureManager; import org.opensearch.timeseries.feature.SearchFeatureDao; import org.opensearch.timeseries.function.ThrowingSupplierWrapper; import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.ratelimit.CheckPointMaintainRequestAdapter; +import org.opensearch.timeseries.settings.TimeSeriesEnabledSetting; import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.TimeSeriesStat; +import org.opensearch.timeseries.stats.suppliers.CounterSupplier; +import org.opensearch.timeseries.stats.suppliers.IndexStatusSupplier; +import org.opensearch.timeseries.stats.suppliers.SettableSupplier; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.CronTransportAction; +import org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler; import org.opensearch.timeseries.util.ClientUtil; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.timeseries.util.IndexUtils; import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.watcher.ResourceWatcherService; +import com.amazon.randomcutforest.parkservices.RCFCaster; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; +import com.amazon.randomcutforest.parkservices.state.RCFCasterMapper; +import com.amazon.randomcutforest.parkservices.state.RCFCasterState; import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestMapper; import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState; import com.amazon.randomcutforest.serialize.json.v1.V1JsonToV3StateConverter; @@ -216,7 +309,7 @@ import io.protostuff.runtime.RuntimeSchema; /** - * Entry point of AD plugin. + * Entry point of time series analytics plugin. */ public class TimeSeriesAnalyticsPlugin extends Plugin implements ActionPlugin, ScriptPlugin, JobSchedulerExtension { @@ -236,27 +329,32 @@ public class TimeSeriesAnalyticsPlugin extends Plugin implements ActionPlugin, S public static final String FORECAST_FORECASTERS_URI = FORECAST_BASE_URI + "/forecasters"; public static final String FORECAST_THREAD_POOL_PREFIX = "opensearch.forecast."; public static final String FORECAST_THREAD_POOL_NAME = "forecast-threadpool"; - public static final String FORECAST_BATCH_TASK_THREAD_POOL_NAME = "forecast-batch-task-threadpool"; public static final String TIME_SERIES_JOB_TYPE = "opensearch_time_series_analytics"; private static Gson gson; private ADIndexManagement anomalyDetectionIndices; + private ForecastIndexManagement forecastIndices; private AnomalyDetectorRunner anomalyDetectorRunner; private Client client; private ClusterService clusterService; private ThreadPool threadPool; private ADStats adStats; + private ForecastStats forecastStats; private ClientUtil clientUtil; private SecurityClientUtil securityClientUtil; private DiscoveryNodeFilterer nodeFilter; private IndexUtils indexUtils; private ADTaskManager adTaskManager; + private ForecastTaskManager forecastTaskManager; private ADBatchTaskRunner adBatchTaskRunner; // package private for testing GenericObjectPool serializeRCFBufferPool; private NodeStateManager stateManager; private ExecuteADResultResponseRecorder adResultResponseRecorder; + private ExecuteForecastResultResponseRecorder forecastResultResponseRecorder; + private ADIndexJobActionHandler adIndexJobActionHandler; + private ForecastIndexJobActionHandler forecastIndexJobActionHandler; static { SpecialPermission.check(); @@ -277,14 +375,16 @@ public List getRestHandlers( IndexNameExpressionResolver indexNameExpressionResolver, Supplier nodesInCluster ) { - AnomalyDetectorJobRunner jobRunner = AnomalyDetectorJobRunner.getJobRunnerInstance(); - jobRunner.setClient(client); - jobRunner.setThreadPool(threadPool); - jobRunner.setSettings(settings); - jobRunner.setAnomalyDetectionIndices(anomalyDetectionIndices); - jobRunner.setAdTaskManager(adTaskManager); - jobRunner.setNodeStateManager(stateManager); - jobRunner.setExecuteADResultResponseRecorder(adResultResponseRecorder); + // AD + ADJobProcessor adJobRunner = ADJobProcessor.getInstance(); + adJobRunner.setClient(client); + adJobRunner.setThreadPool(threadPool); + adJobRunner.registerSettings(settings); + adJobRunner.setIndexManagement(anomalyDetectionIndices); + adJobRunner.setTaskManager(adTaskManager); + adJobRunner.setNodeStateManager(stateManager); + adJobRunner.setExecuteResultResponseRecorder(adResultResponseRecorder); + adJobRunner.setIndexJobActionHandler(adIndexJobActionHandler); RestGetAnomalyDetectorAction restGetAnomalyDetectorAction = new RestGetAnomalyDetectorAction(); RestIndexAnomalyDetectorAction restIndexAnomalyDetectorAction = new RestIndexAnomalyDetectorAction(settings, clusterService); @@ -301,8 +401,33 @@ public List getRestHandlers( RestSearchTopAnomalyResultAction searchTopAnomalyResultAction = new RestSearchTopAnomalyResultAction(); RestValidateAnomalyDetectorAction validateAnomalyDetectorAction = new RestValidateAnomalyDetectorAction(settings, clusterService); + // Forecast + RestIndexForecasterAction restIndexForecasterAction = new RestIndexForecasterAction(settings, clusterService); + RestForecasterJobAction restForecasterJobAction = new RestForecasterJobAction(); + RestGetForecasterAction restGetForecasterAction = new RestGetForecasterAction(); + RestDeleteForecasterAction deleteForecasterAction = new RestDeleteForecasterAction(); + RestSearchForecasterAction searchForecasterAction = new RestSearchForecasterAction(); + RestSearchForecasterInfoAction searchForecasterInfoAction = new RestSearchForecasterInfoAction(); + RestSearchTopForecastResultAction searchTopForecastResultAction = new RestSearchTopForecastResultAction(); + RestSearchForecastTasksAction searchForecastTasksAction = new RestSearchForecastTasksAction(); + RestStatsForecasterAction statsForecasterAction = new RestStatsForecasterAction(forecastStats, this.nodeFilter); + RestRunOnceForecasterAction runOnceForecasterAction = new RestRunOnceForecasterAction(); + RestValidateForecasterAction validateForecasterAction = new RestValidateForecasterAction(settings, clusterService); + RestForecasterSuggestAction suggestForecasterParamAction = new RestForecasterSuggestAction(settings, clusterService); + + ForecastJobProcessor forecastJobRunner = ForecastJobProcessor.getInstance(); + forecastJobRunner.setClient(client); + forecastJobRunner.setThreadPool(threadPool); + forecastJobRunner.registerSettings(settings); + forecastJobRunner.setIndexManagement(forecastIndices); + forecastJobRunner.setTaskManager(forecastTaskManager); + forecastJobRunner.setNodeStateManager(stateManager); + forecastJobRunner.setExecuteResultResponseRecorder(forecastResultResponseRecorder); + forecastJobRunner.setIndexJobActionHandler(forecastIndexJobActionHandler); + return ImmutableList .of( + // AD restGetAnomalyDetectorAction, restIndexAnomalyDetectorAction, searchAnomalyDetectorAction, @@ -316,7 +441,20 @@ public List getRestHandlers( previewAnomalyDetectorAction, deleteAnomalyResultsAction, searchTopAnomalyResultAction, - validateAnomalyDetectorAction + validateAnomalyDetectorAction, + // Forecast + restIndexForecasterAction, + restForecasterJobAction, + restGetForecasterAction, + deleteForecasterAction, + searchForecasterAction, + searchForecasterInfoAction, + searchTopForecastResultAction, + searchForecastTasksAction, + statsForecasterAction, + runOnceForecasterAction, + validateForecasterAction, + suggestForecasterParamAction ); } @@ -339,30 +477,51 @@ public Collection createComponents( IndexNameExpressionResolver indexNameExpressionResolver, Supplier repositoriesServiceSupplier ) { - ADEnabledSetting.getInstance().init(clusterService); - ADNumericSetting.getInstance().init(clusterService); + // ===================== + // Common components + // ===================== this.client = client; this.threadPool = threadPool; Settings settings = environment.settings(); this.clientUtil = new ClientUtil(client); - this.indexUtils = new IndexUtils(client, clientUtil, clusterService, indexNameExpressionResolver); + this.indexUtils = new IndexUtils(clusterService, indexNameExpressionResolver); this.nodeFilter = new DiscoveryNodeFilterer(clusterService); - // convert from checked IOException to unchecked RuntimeException - this.anomalyDetectionIndices = ThrowingSupplierWrapper - .throwingSupplierWrapper( - () -> new ADIndexManagement( - client, - clusterService, - threadPool, - settings, - nodeFilter, - TimeSeriesSettings.MAX_UPDATE_RETRY_TIMES - ) - ) - .get(); this.clusterService = clusterService; - Imputer imputer = new LinearUniformImputer(true); + + JvmService jvmService = new JvmService(environment.settings()); + RandomCutForestMapper rcfMapper = new RandomCutForestMapper(); + rcfMapper.setSaveExecutorContextEnabled(true); + rcfMapper.setSaveTreeStateEnabled(true); + rcfMapper.setPartialTreeStateEnabled(true); + V1JsonToV3StateConverter converter = new V1JsonToV3StateConverter(); + + CircuitBreakerService circuitBreakerService = new CircuitBreakerService(jvmService).init(); + + long heapSizeBytes = JvmInfo.jvmInfo().getMem().getHeapMax().getBytes(); + + serializeRCFBufferPool = AccessController.doPrivileged(new PrivilegedAction>() { + @Override + public GenericObjectPool run() { + return new GenericObjectPool<>(new BasePooledObjectFactory() { + @Override + public LinkedBuffer create() throws Exception { + return LinkedBuffer.allocate(TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES); + } + + @Override + public PooledObject wrap(LinkedBuffer obj) { + return new DefaultPooledObject<>(obj); + } + }); + } + }); + serializeRCFBufferPool.setMaxTotal(TimeSeriesSettings.MAX_TOTAL_RCF_SERIALIZATION_BUFFERS); + serializeRCFBufferPool.setMaxIdle(TimeSeriesSettings.MAX_TOTAL_RCF_SERIALIZATION_BUFFERS); + serializeRCFBufferPool.setMinIdle(0); + serializeRCFBufferPool.setBlockWhenExhausted(false); + serializeRCFBufferPool.setTimeBetweenEvictionRuns(TimeSeriesSettings.HOURLY_MAINTENANCE); + stateManager = new NodeStateManager( client, xContentRegistry, @@ -375,6 +534,7 @@ public Collection createComponents( TimeSeriesSettings.BACKOFF_MINUTES ); securityClientUtil = new SecurityClientUtil(stateManager, settings); + SearchFeatureDao searchFeatureDao = new SearchFeatureDao( client, xContentRegistry, @@ -385,27 +545,12 @@ public Collection createComponents( TimeSeriesSettings.NUM_SAMPLES_PER_TREE ); - JvmService jvmService = new JvmService(environment.settings()); - RandomCutForestMapper mapper = new RandomCutForestMapper(); - mapper.setSaveExecutorContextEnabled(true); - mapper.setSaveTreeStateEnabled(true); - mapper.setPartialTreeStateEnabled(true); - V1JsonToV3StateConverter converter = new V1JsonToV3StateConverter(); - - double modelMaxSizePercent = AnomalyDetectorSettings.AD_MODEL_MAX_SIZE_PERCENTAGE.get(settings); - - CircuitBreakerService adCircuitBreakerService = new CircuitBreakerService(jvmService).init(); - - MemoryTracker memoryTracker = new MemoryTracker(jvmService, modelMaxSizePercent, clusterService, adCircuitBreakerService); - FeatureManager featureManager = new FeatureManager( searchFeatureDao, imputer, getClock(), - AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, - AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, - AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, - AnomalyDetectorSettings.MIN_TRAIN_SAMPLES, + TimeSeriesSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, + TimeSeriesSettings.MIN_TRAIN_SAMPLES, AnomalyDetectorSettings.MAX_SHINGLE_PROPORTION_MISSING, AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE, AnomalyDetectorSettings.PREVIEW_SAMPLE_RATE, @@ -415,36 +560,36 @@ public Collection createComponents( AD_THREAD_POOL_NAME ); - long heapSizeBytes = JvmInfo.jvmInfo().getMem().getHeapMax().getBytes(); + Random random = new Random(42); - serializeRCFBufferPool = AccessController.doPrivileged(new PrivilegedAction>() { - @Override - public GenericObjectPool run() { - return new GenericObjectPool<>(new BasePooledObjectFactory() { - @Override - public LinkedBuffer create() throws Exception { - return LinkedBuffer.allocate(TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES); - } + // ===================== + // AD components + // ===================== + ADEnabledSetting.getInstance().init(clusterService); + ADNumericSetting.getInstance().init(clusterService); + // convert from checked IOException to unchecked RuntimeException + this.anomalyDetectionIndices = ThrowingSupplierWrapper + .throwingSupplierWrapper( + () -> new ADIndexManagement( + client, + clusterService, + threadPool, + settings, + nodeFilter, + TimeSeriesSettings.MAX_UPDATE_RETRY_TIMES + ) + ) + .get(); - @Override - public PooledObject wrap(LinkedBuffer obj) { - return new DefaultPooledObject<>(obj); - } - }); - } - }); - serializeRCFBufferPool.setMaxTotal(TimeSeriesSettings.MAX_TOTAL_RCF_SERIALIZATION_BUFFERS); - serializeRCFBufferPool.setMaxIdle(TimeSeriesSettings.MAX_TOTAL_RCF_SERIALIZATION_BUFFERS); - serializeRCFBufferPool.setMinIdle(0); - serializeRCFBufferPool.setBlockWhenExhausted(false); - serializeRCFBufferPool.setTimeBetweenEvictionRuns(TimeSeriesSettings.HOURLY_MAINTENANCE); + double adModelMaxSizePercent = AnomalyDetectorSettings.AD_MODEL_MAX_SIZE_PERCENTAGE.get(settings); + + MemoryTracker adMemoryTracker = new MemoryTracker(jvmService, adModelMaxSizePercent, clusterService, circuitBreakerService); - CheckpointDao checkpoint = new CheckpointDao( + ADCheckpointDao adCheckpoint = new ADCheckpointDao( client, clientUtil, - ADCommonName.CHECKPOINT_INDEX_NAME, gson, - mapper, + rcfMapper, converter, new ThresholdedRandomCutForestMapper(), AccessController @@ -457,30 +602,30 @@ public PooledObject wrap(LinkedBuffer obj) { TimeSeriesSettings.MAX_CHECKPOINT_BYTES, serializeRCFBufferPool, TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES, - 1 - TimeSeriesSettings.THRESHOLD_MIN_PVALUE + 1 - TimeSeriesSettings.THRESHOLD_MIN_PVALUE, + getClock() ); - Random random = new Random(42); + ADCacheProvider adCacheProvider = new ADCacheProvider(); - CacheProvider cacheProvider = new CacheProvider(); - - CheckPointMaintainRequestAdapter adapter = new CheckPointMaintainRequestAdapter( - cacheProvider, - checkpoint, - ADCommonName.CHECKPOINT_INDEX_NAME, - AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ, - getClock(), - clusterService, - settings - ); + CheckPointMaintainRequestAdapter adAdapter = + new CheckPointMaintainRequestAdapter<>( + adCheckpoint, + ADCommonName.CHECKPOINT_INDEX_NAME, + AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ, + getClock(), + clusterService, + settings, + adCacheProvider + ); - CheckpointWriteWorker checkpointWriteQueue = new CheckpointWriteWorker( + ADCheckpointWriteWorker adCheckpointWriteQueue = new ADCheckpointWriteWorker( heapSizeBytes, TimeSeriesSettings.CHECKPOINT_WRITE_QUEUE_SIZE_IN_BYTES, AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT, clusterService, random, - adCircuitBreakerService, + circuitBreakerService, threadPool, settings, TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, @@ -489,20 +634,20 @@ public PooledObject wrap(LinkedBuffer obj) { TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, TimeSeriesSettings.QUEUE_MAINTENANCE, - checkpoint, + adCheckpoint, ADCommonName.CHECKPOINT_INDEX_NAME, TimeSeriesSettings.HOURLY_MAINTENANCE, stateManager, TimeSeriesSettings.HOURLY_MAINTENANCE ); - CheckpointMaintainWorker checkpointMaintainQueue = new CheckpointMaintainWorker( + ADCheckpointMaintainWorker adCheckpointMaintainQueue = new ADCheckpointMaintainWorker( heapSizeBytes, TimeSeriesSettings.CHECKPOINT_MAINTAIN_REQUEST_SIZE_IN_BYTES, AnomalyDetectorSettings.AD_CHECKPOINT_MAINTAIN_QUEUE_MAX_HEAP_PERCENT, clusterService, random, - adCircuitBreakerService, + circuitBreakerService, threadPool, settings, TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, @@ -510,108 +655,81 @@ public PooledObject wrap(LinkedBuffer obj) { TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, - checkpointWriteQueue, + adCheckpointWriteQueue, TimeSeriesSettings.HOURLY_MAINTENANCE, stateManager, - adapter + adAdapter::convert ); - EntityCache cache = new PriorityCache( - checkpoint, + ADPriorityCache adPriorityCache = new ADPriorityCache( + adCheckpoint, AnomalyDetectorSettings.AD_DEDICATED_CACHE_SIZE.get(settings), AnomalyDetectorSettings.AD_CHECKPOINT_TTL, AnomalyDetectorSettings.MAX_INACTIVE_ENTITIES, - memoryTracker, + adMemoryTracker, TimeSeriesSettings.NUM_TREES, getClock(), clusterService, TimeSeriesSettings.HOURLY_MAINTENANCE, threadPool, - checkpointWriteQueue, TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, - checkpointMaintainQueue, settings, - AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ + AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ, + adCheckpointWriteQueue, + adCheckpointMaintainQueue ); - cacheProvider.set(cache); + // cache provider allows us to break circular dependency among PriorityCache, CacheBuffer, + // CheckPointMaintainRequestAdapter, and CheckpointMaintainWorker + adCacheProvider.set(adPriorityCache); - EntityColdStarter entityColdStarter = new EntityColdStarter( + ADColdStart adEntityColdStarter = new ADColdStart( getClock(), threadPool, stateManager, TimeSeriesSettings.NUM_SAMPLES_PER_TREE, TimeSeriesSettings.NUM_TREES, - TimeSeriesSettings.TIME_DECAY, TimeSeriesSettings.NUM_MIN_SAMPLES, AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, - imputer, searchFeatureDao, TimeSeriesSettings.THRESHOLD_MIN_PVALUE, featureManager, - settings, - TimeSeriesSettings.HOURLY_MAINTENANCE, - checkpointWriteQueue, - TimeSeriesSettings.MAX_COLD_START_ROUNDS - ); - - EntityColdStartWorker coldstartQueue = new EntityColdStartWorker( - heapSizeBytes, - AnomalyDetectorSettings.ENTITY_REQUEST_SIZE_IN_BYTES, - AnomalyDetectorSettings.AD_ENTITY_COLD_START_QUEUE_MAX_HEAP_PERCENT, - clusterService, - random, - adCircuitBreakerService, - threadPool, - settings, - TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, - getClock(), - TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, - TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, - TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, - TimeSeriesSettings.QUEUE_MAINTENANCE, - entityColdStarter, TimeSeriesSettings.HOURLY_MAINTENANCE, - stateManager, - cacheProvider + adCheckpointWriteQueue, + TimeSeriesSettings.MAX_COLD_START_ROUNDS, + (int) (AD_COOLDOWN_MINUTES.get(settings).getMinutes()) ); - ModelManager modelManager = new ModelManager( - checkpoint, + ADModelManager adModelManager = new ADModelManager( + adCheckpoint, getClock(), TimeSeriesSettings.NUM_TREES, TimeSeriesSettings.NUM_SAMPLES_PER_TREE, - TimeSeriesSettings.TIME_DECAY, TimeSeriesSettings.NUM_MIN_SAMPLES, TimeSeriesSettings.THRESHOLD_MIN_PVALUE, AnomalyDetectorSettings.MIN_PREVIEW_SIZE, TimeSeriesSettings.HOURLY_MAINTENANCE, AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ, - entityColdStarter, + adEntityColdStarter, featureManager, - memoryTracker, + adMemoryTracker, settings, clusterService ); - MultiEntityResultHandler multiEntityResultHandler = new MultiEntityResultHandler( + ADIndexMemoryPressureAwareResultHandler adIndexMemoryPressureAwareResultHandler = new ADIndexMemoryPressureAwareResultHandler( client, - settings, - threadPool, - anomalyDetectionIndices, - this.clientUtil, - this.indexUtils, - clusterService + anomalyDetectionIndices ); - ResultWriteWorker resultWriteQueue = new ResultWriteWorker( + ADResultWriteWorker adResultWriteQueue = new ADResultWriteWorker( heapSizeBytes, TimeSeriesSettings.RESULT_WRITE_QUEUE_SIZE_IN_BYTES, AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT, clusterService, random, - adCircuitBreakerService, + circuitBreakerService, threadPool, settings, TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, @@ -620,62 +738,94 @@ public PooledObject wrap(LinkedBuffer obj) { TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, TimeSeriesSettings.QUEUE_MAINTENANCE, - multiEntityResultHandler, + adIndexMemoryPressureAwareResultHandler, xContentRegistry, stateManager, TimeSeriesSettings.HOURLY_MAINTENANCE ); - Map> stats = ImmutableMap - .>builder() - .put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) - .put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) - .put(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) - .put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) + ADSaveResultStrategy adSaveResultStrategy = new ADSaveResultStrategy( + anomalyDetectionIndices.getSchemaVersion(ADIndex.RESULT), + adResultWriteQueue + ); + + ADColdStartWorker adColdstartQueue = new ADColdStartWorker( + heapSizeBytes, + TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES, + AnomalyDetectorSettings.AD_ENTITY_COLD_START_QUEUE_MAX_HEAP_PERCENT, + clusterService, + random, + circuitBreakerService, + threadPool, + settings, + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, + getClock(), + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + TimeSeriesSettings.QUEUE_MAINTENANCE, + adEntityColdStarter, + TimeSeriesSettings.HOURLY_MAINTENANCE, + stateManager, + adPriorityCache, + adModelManager, + adSaveResultStrategy + ); + + Map> adStatsMap = ImmutableMap + .>builder() + // ad stats + .put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) .put( - StatNames.MODEL_INFORMATION.getName(), - new ADStat<>(false, new ModelsOnNodeSupplier(modelManager, cacheProvider, settings, clusterService)) + StatNames.ANOMALY_RESULTS_INDEX_STATUS.getName(), + new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, ADCommonName.ANOMALY_RESULT_INDEX_ALIAS)) ) .put( - StatNames.ANOMALY_DETECTORS_INDEX_STATUS.getName(), - new ADStat<>(true, new IndexStatusSupplier(indexUtils, CommonName.CONFIG_INDEX)) + StatNames.AD_MODELS_CHECKPOINT_INDEX_STATUS.getName(), + new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, ADCommonName.CHECKPOINT_INDEX_NAME)) ) .put( - StatNames.ANOMALY_RESULTS_INDEX_STATUS.getName(), - new ADStat<>(true, new IndexStatusSupplier(indexUtils, ADCommonName.ANOMALY_RESULT_INDEX_ALIAS)) + StatNames.ANOMALY_DETECTION_STATE_STATUS.getName(), + new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, ADCommonName.DETECTION_STATE_INDEX)) ) + .put(StatNames.DETECTOR_COUNT.getName(), new TimeSeriesStat<>(true, new SettableSupplier())) + .put(StatNames.SINGLE_STREAM_DETECTOR_COUNT.getName(), new TimeSeriesStat<>(true, new SettableSupplier())) + .put(StatNames.HC_DETECTOR_COUNT.getName(), new TimeSeriesStat<>(true, new SettableSupplier())) + .put(StatNames.AD_EXECUTING_BATCH_TASK_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.AD_CANCELED_BATCH_TASK_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.AD_TOTAL_BATCH_TASK_EXECUTION_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.AD_BATCH_TASK_FAILURE_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.AD_MODEL_CORRUTPION_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) .put( - StatNames.MODELS_CHECKPOINT_INDEX_STATUS.getName(), - new ADStat<>(true, new IndexStatusSupplier(indexUtils, ADCommonName.CHECKPOINT_INDEX_NAME)) + StatNames.MODEL_INFORMATION.getName(), + new TimeSeriesStat<>(false, new ADModelsOnNodeSupplier(adModelManager, adCacheProvider, settings, clusterService)) ) .put( - StatNames.ANOMALY_DETECTION_JOB_INDEX_STATUS.getName(), - new ADStat<>(true, new IndexStatusSupplier(indexUtils, CommonName.JOB_INDEX)) + StatNames.CONFIG_INDEX_STATUS.getName(), + new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, CommonName.CONFIG_INDEX)) ) .put( - StatNames.ANOMALY_DETECTION_STATE_STATUS.getName(), - new ADStat<>(true, new IndexStatusSupplier(indexUtils, ADCommonName.DETECTION_STATE_INDEX)) + StatNames.JOB_INDEX_STATUS.getName(), + new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, CommonName.JOB_INDEX)) + ) + .put( + StatNames.MODEL_COUNT.getName(), + new TimeSeriesStat<>(false, new ADModelsOnNodeCountSupplier(adModelManager, adCacheProvider)) ) - .put(StatNames.DETECTOR_COUNT.getName(), new ADStat<>(true, new SettableSupplier())) - .put(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName(), new ADStat<>(true, new SettableSupplier())) - .put(StatNames.MULTI_ENTITY_DETECTOR_COUNT.getName(), new ADStat<>(true, new SettableSupplier())) - .put(StatNames.AD_EXECUTING_BATCH_TASK_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) - .put(StatNames.AD_CANCELED_BATCH_TASK_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) - .put(StatNames.AD_TOTAL_BATCH_TASK_EXECUTION_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) - .put(StatNames.AD_BATCH_TASK_FAILURE_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) - .put(StatNames.MODEL_COUNT.getName(), new ADStat<>(false, new ModelsOnNodeCountSupplier(modelManager, cacheProvider))) - .put(StatNames.MODEL_CORRUTPION_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) .build(); - adStats = new ADStats(stats); + adStats = new ADStats(adStatsMap); - CheckpointReadWorker checkpointReadQueue = new CheckpointReadWorker( + ADCheckpointReadWorker adCheckpointReadQueue = new ADCheckpointReadWorker( heapSizeBytes, - AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, + TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES, AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, clusterService, random, - adCircuitBreakerService, + circuitBreakerService, threadPool, settings, TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, @@ -684,25 +834,25 @@ public PooledObject wrap(LinkedBuffer obj) { TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, TimeSeriesSettings.QUEUE_MAINTENANCE, - modelManager, - checkpoint, - coldstartQueue, - resultWriteQueue, + adModelManager, + adCheckpoint, + adColdstartQueue, stateManager, anomalyDetectionIndices, - cacheProvider, + adCacheProvider, TimeSeriesSettings.HOURLY_MAINTENANCE, - checkpointWriteQueue, - adStats + adCheckpointWriteQueue, + adStats, + adSaveResultStrategy ); - ColdEntityWorker coldEntityQueue = new ColdEntityWorker( + ADColdEntityWorker adColdEntityQueue = new ADColdEntityWorker( heapSizeBytes, - AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, + TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES, AnomalyDetectorSettings.AD_COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, clusterService, random, - adCircuitBreakerService, + circuitBreakerService, threadPool, settings, TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, @@ -710,17 +860,52 @@ public PooledObject wrap(LinkedBuffer obj) { TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, - checkpointReadQueue, + adCheckpointReadQueue, TimeSeriesSettings.HOURLY_MAINTENANCE, stateManager ); - ADDataMigrator dataMigrator = new ADDataMigrator(client, clusterService, xContentRegistry, anomalyDetectionIndices); - HashRing hashRing = new HashRing(nodeFilter, getClock(), settings, client, clusterService, dataMigrator, modelManager); + ADDataMigrator adDataMigrator = new ADDataMigrator(client, clusterService, xContentRegistry, anomalyDetectionIndices); + + anomalyDetectorRunner = new AnomalyDetectorRunner(adModelManager, featureManager, AnomalyDetectorSettings.MAX_PREVIEW_RESULTS); + + ADTaskCacheManager adTaskCacheManager = new ADTaskCacheManager(settings, clusterService, adMemoryTracker); + + ResultBulkIndexingHandler anomalyResultBulkIndexHandler = + new ResultBulkIndexingHandler<>( + client, + settings, + threadPool, + ANOMALY_RESULT_INDEX_ALIAS, + anomalyDetectionIndices, + this.clientUtil, + this.indexUtils, + clusterService, + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF + ); + + ADSearchHandler adSearchHandler = new ADSearchHandler(settings, clusterService, client); + + ResultBulkIndexingHandler anomalyResultHandler = new ResultBulkIndexingHandler<>( + client, + settings, + threadPool, + ANOMALY_RESULT_INDEX_ALIAS, + anomalyDetectionIndices, + this.clientUtil, + this.indexUtils, + clusterService, + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF + ); - anomalyDetectorRunner = new AnomalyDetectorRunner(modelManager, featureManager, AnomalyDetectorSettings.MAX_PREVIEW_RESULTS); + // ===================== + // common components, need AD/forecasting components to initialize + // ===================== + HashRing hashRing = new HashRing(nodeFilter, getClock(), settings, client, clusterService, adDataMigrator, adModelManager); + ADTaskProfileRunner adTaskProfileRunner = new ADTaskProfileRunner(hashRing, client); - ADTaskCacheManager adTaskCacheManager = new ADTaskCacheManager(settings, clusterService, memoryTracker); adTaskManager = new ADTaskManager( settings, clusterService, @@ -730,24 +915,18 @@ public PooledObject wrap(LinkedBuffer obj) { nodeFilter, hashRing, adTaskCacheManager, - threadPool - ); - AnomalyResultBulkIndexHandler anomalyResultBulkIndexHandler = new AnomalyResultBulkIndexHandler( - client, - settings, threadPool, - this.clientUtil, - this.indexUtils, - clusterService, - anomalyDetectionIndices + stateManager, + adTaskProfileRunner ); + adBatchTaskRunner = new ADBatchTaskRunner( settings, threadPool, clusterService, client, securityClientUtil, - adCircuitBreakerService, + circuitBreakerService, featureManager, adTaskManager, anomalyDetectionIndices, @@ -756,51 +935,382 @@ public PooledObject wrap(LinkedBuffer obj) { adTaskCacheManager, searchFeatureDao, hashRing, - modelManager + adModelManager ); - ADSearchHandler adSearchHandler = new ADSearchHandler(settings, clusterService, client); + adResultResponseRecorder = new ExecuteADResultResponseRecorder( + anomalyDetectionIndices, + anomalyResultHandler, + adTaskManager, + nodeFilter, + threadPool, + client, + stateManager, + adTaskCacheManager, + TimeSeriesSettings.NUM_MIN_SAMPLES + ); - AnomalyIndexHandler anomalyResultHandler = new AnomalyIndexHandler( + adIndexJobActionHandler = new ADIndexJobActionHandler( client, + anomalyDetectionIndices, + xContentRegistry, + adTaskManager, + adResultResponseRecorder, + stateManager, + settings + ); + + // ===================== + // forecast components + // ===================== + ForecastEnabledSetting.getInstance().init(clusterService); + ForecastNumericSetting.getInstance().init(clusterService); + + forecastIndices = ThrowingSupplierWrapper + .throwingSupplierWrapper( + () -> new ForecastIndexManagement( + client, + clusterService, + threadPool, + settings, + nodeFilter, + ForecastSettings.FORECAST_MAX_UPDATE_RETRY_TIMES + ) + ) + .get(); + + double forecastModelMaxSizePercent = ForecastSettings.FORECAST_MODEL_MAX_SIZE_PERCENTAGE.get(settings); + + MemoryTracker forecastMemoryTracker = new MemoryTracker( + jvmService, + forecastModelMaxSizePercent, + clusterService, + circuitBreakerService + ); + + ForecastCheckpointDao forecastCheckpoint = new ForecastCheckpointDao( + client, + clientUtil, + gson, + TimeSeriesSettings.MAX_CHECKPOINT_BYTES, + serializeRCFBufferPool, + TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES, + forecastIndices, + new RCFCasterMapper(), + AccessController.doPrivileged((PrivilegedAction>) () -> RuntimeSchema.getSchema(RCFCasterState.class)), + getClock() + ); + + ForecastCacheProvider forecastCacheProvider = new ForecastCacheProvider(); + + CheckPointMaintainRequestAdapter forecastAdapter = + new CheckPointMaintainRequestAdapter( + forecastCheckpoint, + ForecastIndex.CHECKPOINT.getIndexName(), + ForecastSettings.FORECAST_CHECKPOINT_SAVING_FREQ, + getClock(), + clusterService, + settings, + forecastCacheProvider + ); + + ForecastCheckpointWriteWorker forecastCheckpointWriteQueue = new ForecastCheckpointWriteWorker( + heapSizeBytes, + TimeSeriesSettings.CHECKPOINT_WRITE_QUEUE_SIZE_IN_BYTES, + ForecastSettings.FORECAST_CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT, + clusterService, + random, + circuitBreakerService, + threadPool, settings, + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, + getClock(), + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + TimeSeriesSettings.QUEUE_MAINTENANCE, + forecastCheckpoint, + ForecastIndex.CHECKPOINT.getIndexName(), + TimeSeriesSettings.HOURLY_MAINTENANCE, + stateManager, + TimeSeriesSettings.HOURLY_MAINTENANCE + ); + + ForecastCheckpointMaintainWorker forecastCheckpointMaintainQueue = new ForecastCheckpointMaintainWorker( + heapSizeBytes, + TimeSeriesSettings.CHECKPOINT_MAINTAIN_REQUEST_SIZE_IN_BYTES, + ForecastSettings.FORECAST_CHECKPOINT_MAINTAIN_QUEUE_MAX_HEAP_PERCENT, + clusterService, + random, + circuitBreakerService, threadPool, - ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, - anomalyDetectionIndices, - this.clientUtil, - this.indexUtils, - clusterService + settings, + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, + getClock(), + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + forecastCheckpointWriteQueue, + TimeSeriesSettings.HOURLY_MAINTENANCE, + stateManager, + forecastAdapter::convert ); - adResultResponseRecorder = new ExecuteADResultResponseRecorder( - anomalyDetectionIndices, - anomalyResultHandler, - adTaskManager, + ForecastPriorityCache forecastPriorityCache = new ForecastPriorityCache( + forecastCheckpoint, + ForecastSettings.FORECAST_DEDICATED_CACHE_SIZE.get(settings), + AnomalyDetectorSettings.AD_CHECKPOINT_TTL, + AnomalyDetectorSettings.MAX_INACTIVE_ENTITIES, + adMemoryTracker, + TimeSeriesSettings.NUM_TREES, + getClock(), + clusterService, + TimeSeriesSettings.HOURLY_MAINTENANCE, + threadPool, + FORECAST_THREAD_POOL_NAME, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + settings, + ForecastSettings.FORECAST_CHECKPOINT_SAVING_FREQ, + forecastCheckpointWriteQueue, + forecastCheckpointMaintainQueue + ); + + // cache provider allows us to break circular dependency among PriorityCache, CacheBuffer, + // CheckPointMaintainRequestAdapter, and CheckpointMaintainWorker + forecastCacheProvider.set(forecastPriorityCache); + + ForecastColdStart forecastColdStarter = new ForecastColdStart( + getClock(), + threadPool, + stateManager, + TimeSeriesSettings.NUM_SAMPLES_PER_TREE, + TimeSeriesSettings.NUM_TREES, + TimeSeriesSettings.NUM_MIN_SAMPLES, + searchFeatureDao, + TimeSeriesSettings.THRESHOLD_MIN_PVALUE, + featureManager, + TimeSeriesSettings.HOURLY_MAINTENANCE, + forecastCheckpointWriteQueue, + (int) (AD_COOLDOWN_MINUTES.get(settings).getMinutes()), + -1, // no hard coded random seed + -1, // interpolation is disabled so we don't need to specify the number of sampled points + TimeSeriesSettings.MAX_COLD_START_ROUNDS + ); + + ForecastModelManager forecastModelManager = new ForecastModelManager( + forecastCheckpoint, + getClock(), + TimeSeriesSettings.NUM_TREES, + TimeSeriesSettings.NUM_SAMPLES_PER_TREE, + TimeSeriesSettings.NUM_MIN_SAMPLES, + forecastColdStarter, + forecastMemoryTracker, + featureManager + ); + + ForecastIndexMemoryPressureAwareResultHandler forecastIndexMemoryPressureAwareResultHandler = + new ForecastIndexMemoryPressureAwareResultHandler(client, forecastIndices); + + ForecastResultWriteWorker forecastResultWriteQueue = new ForecastResultWriteWorker( + heapSizeBytes, + TimeSeriesSettings.RESULT_WRITE_QUEUE_SIZE_IN_BYTES, + ForecastSettings.FORECAST_RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT, + clusterService, + random, + circuitBreakerService, + threadPool, + settings, + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, + getClock(), + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + TimeSeriesSettings.QUEUE_MAINTENANCE, + forecastIndexMemoryPressureAwareResultHandler, + xContentRegistry, + stateManager, + TimeSeriesSettings.HOURLY_MAINTENANCE + ); + + ForecastSaveResultStrategy forecastSaveResultStrategy = new ForecastSaveResultStrategy( + forecastIndices.getSchemaVersion(ForecastIndex.RESULT), + forecastResultWriteQueue + ); + + ForecastColdStartWorker forecastColdstartQueue = new ForecastColdStartWorker( + heapSizeBytes, + TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES, + ForecastSettings.FORECAST_COLD_START_QUEUE_MAX_HEAP_PERCENT, + clusterService, + random, + circuitBreakerService, + threadPool, + settings, + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, + getClock(), + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + TimeSeriesSettings.QUEUE_MAINTENANCE, + forecastColdStarter, + TimeSeriesSettings.HOURLY_MAINTENANCE, + stateManager, + forecastPriorityCache, + forecastModelManager, + forecastSaveResultStrategy + ); + + Map> forecastStatsMap = ImmutableMap + .>builder() + // forecast stats + .put(StatNames.FORECAST_EXECUTE_REQUEST_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.FORECAST_EXECUTE_FAIL_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.FORECAST_HC_EXECUTE_REQUEST_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.FORECAST_HC_EXECUTE_FAIL_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put( + StatNames.FORECAST_RESULTS_INDEX_STATUS.getName(), + new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, ForecastIndex.RESULT.getIndexName())) + ) + .put( + StatNames.FORECAST_MODELS_CHECKPOINT_INDEX_STATUS.getName(), + new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, ForecastIndex.CHECKPOINT.getIndexName())) + ) + .put( + StatNames.FORECAST_STATE_STATUS.getName(), + new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, ForecastIndex.STATE.getIndexName())) + ) + .put(StatNames.FORECASTER_COUNT.getName(), new TimeSeriesStat<>(true, new SettableSupplier())) + .put(StatNames.SINGLE_STREAM_FORECASTER_COUNT.getName(), new TimeSeriesStat<>(true, new SettableSupplier())) + .put(StatNames.HC_FORECASTER_COUNT.getName(), new TimeSeriesStat<>(true, new SettableSupplier())) + .put(StatNames.FORECAST_MODEL_CORRUTPION_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put( + StatNames.MODEL_INFORMATION.getName(), + new TimeSeriesStat<>(false, new ForecastModelsOnNodeSupplier(forecastCacheProvider, settings, clusterService)) + ) + .put( + StatNames.CONFIG_INDEX_STATUS.getName(), + new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, CommonName.CONFIG_INDEX)) + ) + .put( + StatNames.JOB_INDEX_STATUS.getName(), + new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, CommonName.JOB_INDEX)) + ) + .put(StatNames.MODEL_COUNT.getName(), new TimeSeriesStat<>(false, new ForecastModelsOnNodeCountSupplier(forecastCacheProvider))) + .build(); + + forecastStats = new ForecastStats(forecastStatsMap); + + ForecastCheckpointReadWorker forecastCheckpointReadQueue = new ForecastCheckpointReadWorker( + heapSizeBytes, + TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES, + ForecastSettings.FORECAST_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + clusterService, + random, + circuitBreakerService, + threadPool, + settings, + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, + getClock(), + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + TimeSeriesSettings.QUEUE_MAINTENANCE, + forecastModelManager, + forecastCheckpoint, + forecastColdstartQueue, + stateManager, + forecastIndices, + forecastCacheProvider, + TimeSeriesSettings.HOURLY_MAINTENANCE, + forecastCheckpointWriteQueue, + forecastStats, + forecastSaveResultStrategy + ); + + ForecastColdEntityWorker forecastColdEntityQueue = new ForecastColdEntityWorker( + heapSizeBytes, + TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES, + ForecastSettings.FORECAST_COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, + clusterService, + random, + circuitBreakerService, + threadPool, + settings, + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, + getClock(), + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + forecastCheckpointReadQueue, + TimeSeriesSettings.HOURLY_MAINTENANCE, + stateManager + ); + + TaskCacheManager forecastTaskCacheManager = new TaskCacheManager(settings, clusterService); + + forecastTaskManager = new ForecastTaskManager( + forecastTaskCacheManager, + client, + xContentRegistry, + forecastIndices, + clusterService, + settings, + threadPool, + stateManager + ); + + ResultBulkIndexingHandler forecastResultHandler = + new ResultBulkIndexingHandler<>( + client, + settings, + threadPool, + ForecastIndex.RESULT.getIndexName(), + forecastIndices, + this.clientUtil, + this.indexUtils, + clusterService, + ForecastSettings.FORECAST_BACKOFF_INITIAL_DELAY, + ForecastSettings.FORECAST_MAX_RETRY_FOR_BACKOFF + ); + + ForecastSearchHandler forecastSearchHandler = new ForecastSearchHandler(settings, clusterService, client); + + forecastResultResponseRecorder = new ExecuteForecastResultResponseRecorder( + forecastIndices, + forecastResultHandler, + forecastTaskManager, nodeFilter, threadPool, client, stateManager, - adTaskCacheManager, + forecastTaskCacheManager, TimeSeriesSettings.NUM_MIN_SAMPLES ); + forecastIndexJobActionHandler = new ForecastIndexJobActionHandler( + client, + forecastIndices, + xContentRegistry, + forecastTaskManager, + forecastResultResponseRecorder, + stateManager, + settings + ); + // return objects used by Guice to inject dependencies for e.g., // transport action handler constructors return ImmutableList .of( - anomalyDetectionIndices, - anomalyDetectorRunner, + // common components searchFeatureDao, imputer, gson, jvmService, hashRing, featureManager, - modelManager, stateManager, - new ADClusterEventListener(clusterService, hashRing), - adCircuitBreakerService, - adStats, + new ClusterEventListener(clusterService, hashRing), + circuitBreakerService, new ClusterManagerEventListener( clusterService, threadPool, @@ -809,23 +1319,51 @@ public PooledObject wrap(LinkedBuffer obj) { clientUtil, nodeFilter, AnomalyDetectorSettings.AD_CHECKPOINT_TTL, + ForecastSettings.FORECAST_CHECKPOINT_TTL, settings ), nodeFilter, - multiEntityResultHandler, - checkpoint, - cacheProvider, + // AD components + anomalyDetectionIndices, + anomalyDetectorRunner, + adModelManager, + adStats, + adIndexMemoryPressureAwareResultHandler, + adCheckpoint, + adCacheProvider, adTaskManager, adBatchTaskRunner, adSearchHandler, - coldstartQueue, - resultWriteQueue, - checkpointReadQueue, - checkpointWriteQueue, - coldEntityQueue, - entityColdStarter, + adColdstartQueue, + adResultWriteQueue, + adCheckpointReadQueue, + adCheckpointWriteQueue, + adColdEntityQueue, + adEntityColdStarter, adTaskCacheManager, - adResultResponseRecorder + adResultResponseRecorder, + adIndexJobActionHandler, + adSaveResultStrategy, + new ADTaskProfileRunner(hashRing, client), + // forecast components + forecastIndices, + forecastStats, + forecastModelManager, + forecastIndexMemoryPressureAwareResultHandler, + forecastCheckpoint, + forecastCacheProvider, + forecastColdstartQueue, + forecastResultWriteQueue, + forecastCheckpointReadQueue, + forecastCheckpointWriteQueue, + forecastColdEntityQueue, + forecastColdStarter, + forecastTaskManager, + forecastSearchHandler, + forecastIndexJobActionHandler, + forecastTaskCacheManager, + forecastSaveResultStrategy, + new ForecastTaskProfileRunner() ); } @@ -857,14 +1395,29 @@ public List> getExecutorBuilders(Settings settings) { Math.max(1, OpenSearchExecutors.allocatedProcessors(settings) / 8), TimeValue.timeValueMinutes(10), AD_THREAD_POOL_PREFIX + AD_BATCH_TASK_THREAD_POOL_NAME + ), + new ScalingExecutorBuilder( + FORECAST_THREAD_POOL_NAME, + 1, + // this pool is used by both real time and run once. + // HCAD can be heavy after supporting 1 million entities. + // Limit to use at most 3/4 of the processors. + Math.max(1, OpenSearchExecutors.allocatedProcessors(settings) * 3 / 4), + TimeValue.timeValueMinutes(10), + FORECAST_THREAD_POOL_PREFIX + FORECAST_THREAD_POOL_NAME ) ); } @Override public List> getSettings() { - List> enabledSetting = ADEnabledSetting.getInstance().getSettings(); - List> numericSetting = ADNumericSetting.getInstance().getSettings(); + List> adEnabledSetting = ADEnabledSetting.getInstance().getSettings(); + List> adNumericSetting = ADNumericSetting.getInstance().getSettings(); + + List> forecastEnabledSetting = ForecastEnabledSetting.getInstance().getSettings(); + List> forecastNumericSetting = ForecastNumericSetting.getInstance().getSettings(); + + List> timeSeriesEnabledSetting = TimeSeriesEnabledSetting.getInstance().getSettings(); List> systemSetting = ImmutableList .of( @@ -960,6 +1513,15 @@ public List> getSettings() { // ====================================== // Forecast settings // ====================================== + // HC forecasting cache + ForecastSettings.FORECAST_DEDICATED_CACHE_SIZE, + // config parameters + ForecastSettings.FORECAST_INTERVAL, + ForecastSettings.FORECAST_WINDOW_DELAY, + // Fault tolerance + ForecastSettings.FORECAST_BACKOFF_MINUTES, + ForecastSettings.FORECAST_BACKOFF_INITIAL_DELAY, + ForecastSettings.FORECAST_MAX_RETRY_FOR_BACKOFF, // result index rollover ForecastSettings.FORECAST_RESULT_HISTORY_MAX_DOCS_PER_SHARD, ForecastSettings.FORECAST_RESULT_HISTORY_RETENTION_PERIOD, @@ -972,6 +1534,40 @@ public List> getSettings() { ForecastSettings.FORECAST_INDEX_PRESSURE_SOFT_LIMIT, ForecastSettings.FORECAST_INDEX_PRESSURE_HARD_LIMIT, ForecastSettings.FORECAST_MAX_PRIMARY_SHARDS, + // restful apis + ForecastSettings.FORECAST_REQUEST_TIMEOUT, + // resource constraint + ForecastSettings.MAX_SINGLE_STREAM_FORECASTERS, + ForecastSettings.MAX_HC_FORECASTERS, + // Security + ForecastSettings.FORECAST_FILTER_BY_BACKEND_ROLES, + // Historical + ForecastSettings.MAX_OLD_TASK_DOCS_PER_FORECASTER, + // rate limiting + ForecastSettings.FORECAST_CHECKPOINT_READ_QUEUE_CONCURRENCY, + ForecastSettings.FORECAST_CHECKPOINT_WRITE_QUEUE_CONCURRENCY, + ForecastSettings.FORECAST_COLD_START_QUEUE_CONCURRENCY, + ForecastSettings.FORECAST_RESULT_WRITE_QUEUE_CONCURRENCY, + ForecastSettings.FORECAST_CHECKPOINT_READ_QUEUE_BATCH_SIZE, + ForecastSettings.FORECAST_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE, + ForecastSettings.FORECAST_RESULT_WRITE_QUEUE_BATCH_SIZE, + ForecastSettings.FORECAST_COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, + ForecastSettings.FORECAST_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + ForecastSettings.FORECAST_CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT, + ForecastSettings.FORECAST_RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT, + ForecastSettings.FORECAST_CHECKPOINT_MAINTAIN_QUEUE_MAX_HEAP_PERCENT, + ForecastSettings.FORECAST_COLD_START_QUEUE_MAX_HEAP_PERCENT, + ForecastSettings.FORECAST_EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS, + ForecastSettings.FORECAST_EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_MILLISECS, + ForecastSettings.FORECAST_CHECKPOINT_SAVING_FREQ, + ForecastSettings.FORECAST_CHECKPOINT_TTL, + // query limit + ForecastSettings.FORECAST_MAX_ENTITIES_PER_INTERVAL, + ForecastSettings.FORECAST_PAGE_SIZE, + // stats/profile API + ForecastSettings.FORECAST_MAX_MODEL_SIZE_PER_NODE, + // clean resource + ForecastSettings.DELETE_FORECAST_RESULT_WHEN_DELETE_FORECASTER, // ====================================== // Common settings // ====================================== @@ -984,7 +1580,14 @@ public List> getSettings() { ); return unmodifiableList( Stream - .of(enabledSetting.stream(), systemSetting.stream(), numericSetting.stream()) + .of( + adEnabledSetting.stream(), + forecastEnabledSetting.stream(), + timeSeriesEnabledSetting.stream(), + systemSetting.stream(), + adNumericSetting.stream(), + forecastNumericSetting.stream() + ) .reduce(Stream::concat) .orElseGet(Stream::empty) .collect(Collectors.toList()) @@ -1010,14 +1613,15 @@ public List getNamedXContent() { public List> getActions() { return Arrays .asList( - new ActionHandler<>(DeleteModelAction.INSTANCE, DeleteModelTransportAction.class), + // AD + new ActionHandler<>(DeleteADModelAction.INSTANCE, DeleteADModelTransportAction.class), new ActionHandler<>(StopDetectorAction.INSTANCE, StopDetectorTransportAction.class), new ActionHandler<>(RCFResultAction.INSTANCE, RCFResultTransportAction.class), new ActionHandler<>(ThresholdResultAction.INSTANCE, ThresholdResultTransportAction.class), new ActionHandler<>(AnomalyResultAction.INSTANCE, AnomalyResultTransportAction.class), new ActionHandler<>(CronAction.INSTANCE, CronTransportAction.class), new ActionHandler<>(ADStatsNodesAction.INSTANCE, ADStatsNodesTransportAction.class), - new ActionHandler<>(ProfileAction.INSTANCE, ProfileTransportAction.class), + new ActionHandler<>(ADProfileAction.INSTANCE, ADProfileTransportAction.class), new ActionHandler<>(RCFPollingAction.INSTANCE, RCFPollingTransportAction.class), new ActionHandler<>(SearchAnomalyDetectorAction.INSTANCE, SearchAnomalyDetectorTransportAction.class), new ActionHandler<>(SearchAnomalyResultAction.INSTANCE, SearchAnomalyResultTransportAction.class), @@ -1028,8 +1632,8 @@ public List getNamedXContent() { new ActionHandler<>(IndexAnomalyDetectorAction.INSTANCE, IndexAnomalyDetectorTransportAction.class), new ActionHandler<>(AnomalyDetectorJobAction.INSTANCE, AnomalyDetectorJobTransportAction.class), new ActionHandler<>(ADResultBulkAction.INSTANCE, ADResultBulkTransportAction.class), - new ActionHandler<>(EntityResultAction.INSTANCE, EntityResultTransportAction.class), - new ActionHandler<>(EntityProfileAction.INSTANCE, EntityProfileTransportAction.class), + new ActionHandler<>(EntityADResultAction.INSTANCE, EntityADResultTransportAction.class), + new ActionHandler<>(ADEntityProfileAction.INSTANCE, ADEntityProfileTransportAction.class), new ActionHandler<>(SearchAnomalyDetectorInfoAction.INSTANCE, SearchAnomalyDetectorInfoTransportAction.class), new ActionHandler<>(PreviewAnomalyDetectorAction.INSTANCE, PreviewAnomalyDetectorTransportAction.class), new ActionHandler<>(ADBatchAnomalyResultAction.INSTANCE, ADBatchAnomalyResultTransportAction.class), @@ -1039,7 +1643,30 @@ public List getNamedXContent() { new ActionHandler<>(ForwardADTaskAction.INSTANCE, ForwardADTaskTransportAction.class), new ActionHandler<>(DeleteAnomalyResultsAction.INSTANCE, DeleteAnomalyResultsTransportAction.class), new ActionHandler<>(SearchTopAnomalyResultAction.INSTANCE, SearchTopAnomalyResultTransportAction.class), - new ActionHandler<>(ValidateAnomalyDetectorAction.INSTANCE, ValidateAnomalyDetectorTransportAction.class) + new ActionHandler<>(ValidateAnomalyDetectorAction.INSTANCE, ValidateAnomalyDetectorTransportAction.class), + // forecast + new ActionHandler<>(IndexForecasterAction.INSTANCE, IndexForecasterTransportAction.class), + new ActionHandler<>(ForecastResultAction.INSTANCE, ForecastResultTransportAction.class), + new ActionHandler<>(EntityForecastResultAction.INSTANCE, EntityForecastResultTransportAction.class), + new ActionHandler<>(ForecastResultBulkAction.INSTANCE, ForecastResultBulkTransportAction.class), + new ActionHandler<>(ForecastSingleStreamResultAction.INSTANCE, ForecastSingleStreamResultTransportAction.class), + new ActionHandler<>(ForecasterJobAction.INSTANCE, ForecasterJobTransportAction.class), + new ActionHandler<>(StopForecasterAction.INSTANCE, StopForecasterTransportAction.class), + new ActionHandler<>(DeleteForecastModelAction.INSTANCE, DeleteForecastModelTransportAction.class), + new ActionHandler<>(GetForecasterAction.INSTANCE, GetForecasterTransportAction.class), + new ActionHandler<>(DeleteForecasterAction.INSTANCE, DeleteForecasterTransportAction.class), + new ActionHandler<>(SearchForecasterAction.INSTANCE, SearchForecasterTransportAction.class), + new ActionHandler<>(SearchForecasterInfoAction.INSTANCE, SearchForecasterInfoTransportAction.class), + new ActionHandler<>(SearchTopForecastResultAction.INSTANCE, SearchTopForecastResultTransportAction.class), + new ActionHandler<>(ForecastEntityProfileAction.INSTANCE, ForecastEntityProfileTransportAction.class), + new ActionHandler<>(ForecastProfileAction.INSTANCE, ForecastProfileTransportAction.class), + new ActionHandler<>(SearchForecastTasksAction.INSTANCE, SearchForecastTasksTransportAction.class), + new ActionHandler<>(StatsForecasterAction.INSTANCE, StatsForecasterTransportAction.class), + new ActionHandler<>(ForecastStatsNodesAction.INSTANCE, ForecastStatsNodesTransportAction.class), + new ActionHandler<>(ForecastRunOnceAction.INSTANCE, ForecastRunOnceTransportAction.class), + new ActionHandler<>(ForecastRunOnceProfileAction.INSTANCE, ForecastRunOnceProfileTransportAction.class), + new ActionHandler<>(ValidateForecasterAction.INSTANCE, ValidateForecasterTransportAction.class), + new ActionHandler<>(SuggestForecasterParamAction.INSTANCE, SuggestForecasterParamTransportAction.class) ); } @@ -1055,7 +1682,7 @@ public String getJobIndex() { @Override public ScheduledJobRunner getJobRunner() { - return AnomalyDetectorJobRunner.getJobRunnerInstance(); + return JobRunner.getJobRunnerInstance(); } @Override diff --git a/src/main/java/org/opensearch/timeseries/breaker/CircuitBreakerService.java b/src/main/java/org/opensearch/timeseries/breaker/CircuitBreakerService.java index efa48ec7f..dd5ed15c8 100644 --- a/src/main/java/org/opensearch/timeseries/breaker/CircuitBreakerService.java +++ b/src/main/java/org/opensearch/timeseries/breaker/CircuitBreakerService.java @@ -16,8 +16,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.monitor.jvm.JvmService; +import org.opensearch.timeseries.settings.TimeSeriesEnabledSetting; /** * Class {@code CircuitBreakerService} provide storing, retrieving circuit breakers functions. @@ -76,7 +76,7 @@ public CircuitBreakerService init() { } public Boolean isOpen() { - if (!ADEnabledSetting.isADBreakerEnabled()) { + if (!TimeSeriesEnabledSetting.isBreakerEnabled()) { return false; } diff --git a/src/main/java/org/opensearch/ad/caching/CacheBuffer.java b/src/main/java/org/opensearch/timeseries/caching/CacheBuffer.java similarity index 73% rename from src/main/java/org/opensearch/ad/caching/CacheBuffer.java rename to src/main/java/org/opensearch/timeseries/caching/CacheBuffer.java index fb48fd273..8d5605816 100644 --- a/src/main/java/org/opensearch/ad/caching/CacheBuffer.java +++ b/src/main/java/org/opensearch/timeseries/caching/CacheBuffer.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.caching; +package org.opensearch.timeseries.caching; import java.time.Clock; import java.time.Duration; @@ -25,273 +25,150 @@ import org.apache.commons.lang.builder.HashCodeBuilder; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelState; -import org.opensearch.ad.model.InitProgressProfile; -import org.opensearch.ad.ratelimit.CheckpointMaintainRequest; -import org.opensearch.ad.ratelimit.CheckpointMaintainWorker; -import org.opensearch.ad.ratelimit.CheckpointWriteWorker; -import org.opensearch.ad.ratelimit.RequestPriority; -import org.opensearch.ad.util.DateUtils; import org.opensearch.timeseries.ExpiringState; import org.opensearch.timeseries.MemoryTracker; import org.opensearch.timeseries.MemoryTracker.Origin; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ratelimit.CheckpointMaintainRequest; +import org.opensearch.timeseries.ratelimit.CheckpointMaintainWorker; +import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.util.DateUtils; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public abstract class CacheBuffer & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CheckpointWriterType extends CheckpointWriteWorker, CheckpointMaintainerType extends CheckpointMaintainWorker> + implements + ExpiringState { -/** - * We use a layered cache to manage active entities’ states. We have a two-level - * cache that stores active entity states in each node. Each detector has its - * dedicated cache that stores ten (dynamically adjustable) entities’ states per - * node. A detector’s hottest entities load their states in the dedicated cache. - * If less than 10 entities use the dedicated cache, the secondary cache can use - * the rest of the free memory available to AD. The secondary cache is a shared - * memory among all detectors for the long tail. The shared cache size is 10% - * heap minus all of the dedicated cache consumed by single-entity and multi-entity - * detectors. The shared cache’s size shrinks as the dedicated cache is filled - * up or more detectors are started. - * - * Implementation-wise, both dedicated cache and shared cache are stored in items - * and minimumCapacity controls the boundary. If items size is equals to or less - * than minimumCapacity, consider items as dedicated cache; otherwise, consider - * top minimumCapacity active entities (last X entities in priorityList) as in dedicated - * cache and all others in shared cache. - */ -public class CacheBuffer implements ExpiringState { private static final Logger LOG = LogManager.getLogger(CacheBuffer.class); - // max entities to track per detector - private final int MAX_TRACKING_ENTITIES = 1000000; + protected Instant lastUsedTime; + protected final Clock clock; + + protected final MemoryTracker memoryTracker; + protected int checkpointIntervalHrs; + protected final Duration modelTtl; + // max entities to track per detector + protected final int MAX_TRACKING_ENTITIES = 1000000; // the reserved cache size. So no matter how many entities there are, we will // keep the size for minimum capacity entities - private int minimumCapacity; - - // key is model id - private final ConcurrentHashMap> items; + protected int minimumCapacity; // memory consumption per entity - private final long memoryConsumptionPerEntity; - private final MemoryTracker memoryTracker; - private final Duration modelTtl; - private final String detectorId; - private Instant lastUsedTime; - private long reservedBytes; - private final PriorityTracker priorityTracker; - private final Clock clock; - private final CheckpointWriteWorker checkpointWriteQueue; - private final CheckpointMaintainWorker checkpointMaintainQueue; - private int checkpointIntervalHrs; + protected final long memoryConsumptionPerModel; + protected long reservedBytes; + protected final CheckpointWriterType checkpointWriteQueue; + protected final CheckpointMaintainerType checkpointMaintainQueue; + protected final String configId; + protected final Origin origin; + protected final PriorityTracker priorityTracker; + // key is model id + protected final ConcurrentHashMap> items; public CacheBuffer( int minimumCapacity, - long intervalSecs, - long memoryConsumptionPerEntity, - MemoryTracker memoryTracker, Clock clock, + MemoryTracker memoryTracker, + int checkpointIntervalHrs, Duration modelTtl, - String detectorId, - CheckpointWriteWorker checkpointWriteQueue, - CheckpointMaintainWorker checkpointMaintainQueue, - int checkpointIntervalHrs + long memoryConsumptionPerEntity, + CheckpointWriterType checkpointWriteQueue, + CheckpointMaintainerType checkpointMaintainQueue, + String configId, + long intervalSecs, + Origin origin ) { - this.memoryConsumptionPerEntity = memoryConsumptionPerEntity; - setMinimumCapacity(minimumCapacity); - - this.items = new ConcurrentHashMap<>(); - this.memoryTracker = memoryTracker; - - this.modelTtl = modelTtl; - this.detectorId = detectorId; this.lastUsedTime = clock.instant(); - this.clock = clock; - this.priorityTracker = new PriorityTracker(clock, intervalSecs, clock.instant().getEpochSecond(), MAX_TRACKING_ENTITIES); + this.memoryTracker = memoryTracker; + setCheckpointIntervalHrs(checkpointIntervalHrs); + this.modelTtl = modelTtl; + this.memoryConsumptionPerModel = memoryConsumptionPerEntity; this.checkpointWriteQueue = checkpointWriteQueue; this.checkpointMaintainQueue = checkpointMaintainQueue; - setCheckpointIntervalHrs(checkpointIntervalHrs); - } - - /** - * Update step at period t_k: - * new priority = old priority + log(1+e^{\log(g(t_k-L))-old priority}) where g(n) = e^{0.125n}, - * and n is the period. - * @param entityModelId model Id - */ - private void update(String entityModelId) { - priorityTracker.updatePriority(entityModelId); - - Instant now = clock.instant(); - items.get(entityModelId).setLastUsedTime(now); - lastUsedTime = now; - } - - /** - * Insert the model state associated with a model Id to the cache - * @param entityModelId the model Id - * @param value the ModelState - */ - public void put(String entityModelId, ModelState value) { - // race conditions can happen between the put and one of the following operations: - // remove: not a problem as it is unlikely we are removing and putting the same thing - // maintenance: not a problem as we are unlikely to maintain an entry that's not - // already in the cache - // clear: not a problem as we are releasing memory in MemoryTracker. - // The newly added one loses references and soon GC will collect it. - // We have memory tracking correction to fix incorrect memory usage record. - // put from other threads: not a problem as the entry is associated with - // entityModelId and our put is idempotent - put(entityModelId, value, value.getPriority()); + this.configId = configId; + this.origin = origin; + this.priorityTracker = new PriorityTracker(clock, intervalSecs, clock.instant().getEpochSecond(), MAX_TRACKING_ENTITIES); + this.items = new ConcurrentHashMap<>(); + // called after minimumCapacity and memoryConsumptionPerModel are set + setMinimumCapacity(minimumCapacity); } - /** - * Insert the model state associated with a model Id to the cache. Update priority. - * @param entityModelId the model Id - * @param value the ModelState - * @param priority the priority - */ - private void put(String entityModelId, ModelState value, float priority) { - ModelState contentNode = items.get(entityModelId); - if (contentNode == null) { - priorityTracker.addPriority(entityModelId, priority); - items.put(entityModelId, value); - Instant now = clock.instant(); - value.setLastUsedTime(now); - lastUsedTime = now; - // shared cache empty means we are consuming reserved cache. - // Since we have already considered them while allocating CacheBuffer, - // skip bookkeeping. - if (!sharedCacheEmpty()) { - memoryTracker.consumeMemory(memoryConsumptionPerEntity, false, Origin.REAL_TIME_DETECTOR); - } - } else { - update(entityModelId); - items.put(entityModelId, value); + public void setMinimumCapacity(int minimumCapacity) { + if (minimumCapacity < 0) { + throw new IllegalArgumentException("minimum capacity should be larger than or equal 0"); } + this.minimumCapacity = minimumCapacity; + this.reservedBytes = memoryConsumptionPerModel * minimumCapacity; } - /** - * Retrieve the ModelState associated with the model Id or null if the CacheBuffer - * contains no mapping for the model Id - * @param key the model Id - * @return the Model state to which the specified model Id is mapped, or null - * if this CacheBuffer contains no mapping for the model Id - */ - public ModelState get(String key) { - // We can get an item that is to be removed soon due to race condition. - // This is acceptable as it won't cause any corruption and exception. - // And this item is used for scoring one last time. - ModelState node = items.get(key); - if (node == null) { - return null; - } - update(key); - return node; + @Override + public boolean expired(Duration stateTtl) { + return expired(lastUsedTime, stateTtl, clock.instant()); } - /** - * Retrieve the ModelState associated with the model Id or null if the CacheBuffer - * contains no mapping for the model Id. Compared to get method, the method won't - * increment entity priority. Used in cache buffer maintenance. - * - * @param key the model Id - * @return the Model state to which the specified model Id is mapped, or null - * if this CacheBuffer contains no mapping for the model Id - */ - public ModelState getWithoutUpdatePriority(String key) { - // We can get an item that is to be removed soon due to race condition. - // This is acceptable as it won't cause any corruption and exception. - // And this item is used for scoring one last time. - ModelState node = items.get(key); - if (node == null) { - return null; + public void setCheckpointIntervalHrs(int checkpointIntervalHrs) { + this.checkpointIntervalHrs = checkpointIntervalHrs; + // 0 can cause java.lang.ArithmeticException: / by zero + // negative value is meaningless + if (checkpointIntervalHrs <= 0) { + this.checkpointIntervalHrs = 1; } - return node; } - /** - * - * @return whether there is one item that can be removed from shared cache - */ - public boolean canRemove() { - return !items.isEmpty() && items.size() > minimumCapacity; + public int getCheckpointIntervalHrs() { + return checkpointIntervalHrs; } /** - * remove the smallest priority item. - * @return the associated ModelState associated with the key, or null if there - * is no associated ModelState for the key - */ - public ModelState remove() { - // race conditions can happen between the put and one of the following operations: - // remove from other threads: not a problem. If they remove the same item, - // our method is idempotent. If they remove two different items, - // they don't impact each other. - // maintenance: not a problem as all of the data structures are concurrent. - // Two threads removing the same entry is not a problem. - // clear: not a problem as we are releasing memory in MemoryTracker. - // The removed one loses references and soon GC will collect it. - // We have memory tracking correction to fix incorrect memory usage record. - // put: not a problem as it is unlikely we are removing and putting the same thing - Optional key = priorityTracker.getMinimumPriorityEntityId(); - if (key.isPresent()) { - return remove(key.get()); - } - return null; + * + * @return reserved bytes by the CacheBuffer + */ + public long getReservedBytes() { + return reservedBytes; } /** - * Remove everything associated with the key and make a checkpoint. - * - * @param keyToRemove The key to remove - * @return the associated ModelState associated with the key, or null if there - * is no associated ModelState for the key - */ - public ModelState remove(String keyToRemove) { - return remove(keyToRemove, true); + * + * @return the estimated number of bytes per entity state + */ + public long getMemoryConsumptionPerModel() { + return memoryConsumptionPerModel; } - /** - * Remove everything associated with the key and make a checkpoint if input specified so. - * - * @param keyToRemove The key to remove - * @param saveCheckpoint Whether saving checkpoint or not - * @return the associated ModelState associated with the key, or null if there - * is no associated ModelState for the key - */ - public ModelState remove(String keyToRemove, boolean saveCheckpoint) { - priorityTracker.removePriority(keyToRemove); - - // if shared cache is empty, we are using reserved memory - boolean reserved = sharedCacheEmpty(); - - ModelState valueRemoved = items.remove(keyToRemove); + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; - if (valueRemoved != null) { - if (!reserved) { - // release in shared memory - memoryTracker.releaseMemory(memoryConsumptionPerEntity, false, Origin.REAL_TIME_DETECTOR); - } + if (obj instanceof CacheBuffer) { + @SuppressWarnings("unchecked") + CacheBuffer other = + (CacheBuffer) obj; - EntityModel modelRemoved = valueRemoved.getModel(); - if (modelRemoved != null) { - if (saveCheckpoint) { - // null model has only samples. For null model we save a checkpoint - // regardless of last checkpoint time. whether If we don't save, - // we throw the new samples and might never be able to initialize the model - boolean isNullModel = !modelRemoved.getTrcf().isPresent(); - checkpointWriteQueue.write(valueRemoved, isNullModel, RequestPriority.MEDIUM); - } + EqualsBuilder equalsBuilder = new EqualsBuilder(); + equalsBuilder.append(configId, other.configId); - modelRemoved.clear(); - } + return equalsBuilder.isEquals(); } + return false; + } - return valueRemoved; + @Override + public int hashCode() { + return new HashCodeBuilder().append(configId).toHashCode(); } - /** - * @return whether dedicated cache is available or not - */ - public boolean dedicatedCacheAvailable() { - return items.size() < minimumCapacity; + public String getConfigId() { + return configId; } /** @@ -302,56 +179,47 @@ public boolean sharedCacheEmpty() { } /** - * - * @return the estimated number of bytes per entity state - */ - public long getMemoryConsumptionPerEntity() { - return memoryConsumptionPerEntity; - } - - /** - * - * If the cache is not full, check if some other items can replace internal entities - * within the same detector. - * - * @param priority another entity's priority - * @return whether one entity can be replaced by another entity with a certain priority - */ - public boolean canReplaceWithinDetector(float priority) { - if (items.isEmpty()) { - return false; + * + * @return bytes consumed in the shared cache by the CacheBuffer + */ + public long getBytesInSharedCache() { + int sharedCacheEntries = items.size() - minimumCapacity; + if (sharedCacheEntries > 0) { + return memoryConsumptionPerModel * sharedCacheEntries; } - Optional> minPriorityItem = priorityTracker.getMinimumPriority(); - return minPriorityItem.isPresent() && priority > minPriorityItem.get().getValue(); + return 0; } /** - * Replace the smallest priority entity with the input entity - * @param entityModelId the Model Id - * @param value the model State - * @return the associated ModelState associated with the key, or null if there - * is no associated ModelState for the key + * Clear associated memory. Used when we are removing an detector. */ - public ModelState replace(String entityModelId, ModelState value) { - ModelState replaced = remove(); - put(entityModelId, value); - return replaced; + public void clear() { + // race conditions can happen between the put and remove/maintenance/put: + // not a problem as we are releasing memory in MemoryTracker. + // The newly added one loses references and soon GC will collect it. + // We have memory tracking correction to fix incorrect memory usage record. + memoryTracker.releaseMemory(getReservedBytes(), true, origin); + if (!sharedCacheEmpty()) { + memoryTracker.releaseMemory(getBytesInSharedCache(), false, origin); + } + items.clear(); + priorityTracker.clearPriority(); } /** * Remove expired state and save checkpoints of existing states * @return removed states */ - public List> maintenance() { + public List> maintenance() { List modelsToSave = new ArrayList<>(); - List> removedStates = new ArrayList<>(); + List> removedStates = new ArrayList<>(); Instant now = clock.instant(); int currentHour = DateUtils.getUTCHourOfDay(now); int currentSlot = currentHour % checkpointIntervalHrs; items.entrySet().stream().forEach(entry -> { String entityModelId = entry.getKey(); try { - ModelState modelState = entry.getValue(); + ModelState modelState = entry.getValue(); if (modelState.getLastUsedTime().plus(modelTtl).isBefore(now)) { // race conditions can happen between the put and one of the following operations: @@ -397,7 +265,7 @@ public List> maintenance() { new CheckpointMaintainRequest( // the request expires when the next maintainance starts System.currentTimeMillis() + modelTtl.toMillis(), - detectorId, + configId, RequestPriority.LOW, entityModelId ) @@ -414,9 +282,97 @@ public List> maintenance() { } /** + * Remove everything associated with the key and make a checkpoint if input specified so. + * + * @param keyToRemove The key to remove + * @param saveCheckpoint Whether saving checkpoint or not + * @return the associated ModelState associated with the key, or null if there + * is no associated ModelState for the key + */ + public ModelState remove(String keyToRemove, boolean saveCheckpoint) { + priorityTracker.removePriority(keyToRemove); + + // if shared cache is empty, we are using reserved memory + boolean reserved = sharedCacheEmpty(); + + ModelState valueRemoved = items.remove(keyToRemove); + + if (valueRemoved != null) { + if (!reserved) { + // release in shared memory + memoryTracker.releaseMemory(memoryConsumptionPerModel, false, origin); + } + + if (saveCheckpoint) { + // null model has only samples. For null model we save a checkpoint + // regardless of last checkpoint time. whether If we don't save, + // we throw the new samples and might never be able to initialize the model + checkpointWriteQueue.write(valueRemoved, valueRemoved.getModel().isEmpty(), RequestPriority.MEDIUM); + } + + valueRemoved.clear(); + } + + return valueRemoved; + } + + /** + * Remove everything associated with the key and make a checkpoint. * - * @return the number of active entities + * @param keyToRemove The key to remove + * @return the associated ModelState associated with the key, or null if there + * is no associated ModelState for the key */ + public ModelState remove(String keyToRemove) { + return remove(keyToRemove, true); + } + + public PriorityTracker getPriorityTracker() { + return priorityTracker; + } + + /** + * remove the smallest priority item. + * @return the associated ModelState associated with the key, or null if there + * is no associated ModelState for the key + */ + public ModelState remove() { + // race conditions can happen between the put and one of the following operations: + // remove from other threads: not a problem. If they remove the same item, + // our method is idempotent. If they remove two different items, + // they don't impact each other. + // maintenance: not a problem as all of the data structures are concurrent. + // Two threads removing the same entry is not a problem. + // clear: not a problem as we are releasing memory in MemoryTracker. + // The removed one loses references and soon GC will collect it. + // We have memory tracking correction to fix incorrect memory usage record. + // put: not a problem as it is unlikely we are removing and putting the same thing + Optional key = priorityTracker.getMinimumPriorityEntityId(); + if (key.isPresent()) { + return remove(key.get()); + } + return null; + } + + /** + * + * @return whether there is one item that can be removed from shared cache + */ + public boolean canRemove() { + return !items.isEmpty() && items.size() > minimumCapacity; + } + + /** + * @return whether dedicated cache is available or not + */ + public boolean dedicatedCacheAvailable() { + return items.size() < minimumCapacity; + } + + /** + * + * @return the number of active entities + */ public int getActiveEntities() { return items.size(); } @@ -436,7 +392,7 @@ public boolean isActive(String entityModelId) { * @return Last used time of the model */ public long getLastUsedTime(String entityModelId) { - ModelState state = items.get(entityModelId); + ModelState state = items.get(entityModelId); if (state != null) { return state.getLastUsedTime().toEpochMilli(); } @@ -448,105 +404,139 @@ public long getLastUsedTime(String entityModelId) { * @param entityModelId entity Id * @return Get the model of an entity */ - public Optional getModel(String entityModelId) { - return Optional.of(items).map(map -> map.get(entityModelId)).map(state -> state.getModel()); + public ModelState getModelState(String entityModelId) { + // flatMap allows for mapping the inner Optional directly, which results in + // a single Optional instead of a nested Optional>. + return items.get(entityModelId); } /** - * Clear associated memory. Used when we are removing an detector. + * Update step at period t_k: + * new priority = old priority + log(1+e^{\log(g(t_k-L))-old priority}) where g(n) = e^{0.125n}, + * and n is the period. + * @param entityModelId model Id */ - public void clear() { - // race conditions can happen between the put and remove/maintenance/put: - // not a problem as we are releasing memory in MemoryTracker. + private void update(String entityModelId) { + priorityTracker.updatePriority(entityModelId); + + Instant now = clock.instant(); + items.get(entityModelId).setLastUsedTime(now); + lastUsedTime = now; + } + + /** + * Insert the model state associated with a model Id to the cache + * @param entityModelId the model Id + * @param value the ModelState + */ + public void put(String entityModelId, ModelState value) { + // race conditions can happen between the put and one of the following operations: + // remove: not a problem as it is unlikely we are removing and putting the same thing + // maintenance: not a problem as we are unlikely to maintain an entry that's not + // already in the cache + // clear: not a problem as we are releasing memory in MemoryTracker. // The newly added one loses references and soon GC will collect it. // We have memory tracking correction to fix incorrect memory usage record. - memoryTracker.releaseMemory(getReservedBytes(), true, Origin.REAL_TIME_DETECTOR); - if (!sharedCacheEmpty()) { - memoryTracker.releaseMemory(getBytesInSharedCache(), false, Origin.REAL_TIME_DETECTOR); + // put from other threads: not a problem as the entry is associated with + // entityModelId and our put is idempotent + put(entityModelId, value, value.getPriority()); + } + + /** + * Insert the model state associated with a model Id to the cache. Update priority. + * @param entityModelId the model Id + * @param value the ModelState + * @param priority the priority + */ + private void put(String entityModelId, ModelState value, float priority) { + ModelState contentNode = items.get(entityModelId); + if (contentNode == null) { + priorityTracker.addPriority(entityModelId, priority); + items.put(entityModelId, value); + Instant now = clock.instant(); + value.setLastUsedTime(now); + lastUsedTime = now; + // shared cache empty means we are consuming reserved cache. + // Since we have already considered them while allocating CacheBuffer, + // skip bookkeeping. + if (!sharedCacheEmpty()) { + memoryTracker.consumeMemory(memoryConsumptionPerModel, false, origin); + } + } else { + update(entityModelId); + items.put(entityModelId, value); } - items.clear(); - priorityTracker.clearPriority(); } /** - * - * @return reserved bytes by the CacheBuffer + * Retrieve the ModelState associated with the model Id or null if the CacheBuffer + * contains no mapping for the model Id + * @param key the model Id + * @return the Model state to which the specified model Id is mapped, or null + * if this CacheBuffer contains no mapping for the model Id */ - public long getReservedBytes() { - return reservedBytes; + public ModelState get(String key) { + // We can get an item that is to be removed soon due to race condition. + // This is acceptable as it won't cause any corruption and exception. + // And this item is used for scoring one last time. + ModelState node = items.get(key); + if (node == null) { + return null; + } + update(key); + return node; } /** + * Retrieve the ModelState associated with the model Id or null if the CacheBuffer + * contains no mapping for the model Id. Compared to get method, the method won't + * increment entity priority. Used in cache buffer maintenance. * - * @return bytes consumed in the shared cache by the CacheBuffer + * @param key the model Id + * @return the Model state to which the specified model Id is mapped, or null + * if this CacheBuffer contains no mapping for the model Id */ - public long getBytesInSharedCache() { - int sharedCacheEntries = items.size() - minimumCapacity; - if (sharedCacheEntries > 0) { - return memoryConsumptionPerEntity * sharedCacheEntries; + public ModelState getWithoutUpdatePriority(String key) { + // We can get an item that is to be removed soon due to race condition. + // This is acceptable as it won't cause any corruption and exception. + // And this item is used for scoring one last time. + ModelState node = items.get(key); + if (node == null) { + return null; } - return 0; + return node; } - @Override - public boolean equals(Object obj) { - if (this == obj) - return true; - if (obj == null) - return false; - if (getClass() != obj.getClass()) + /** + * + * If the cache is not full, check if some other items can replace internal entities + * within the same config. + * + * @param priority another entity's priority + * @return whether one entity can be replaced by another entity with a certain priority + */ + public boolean canReplaceWithinConfig(float priority) { + if (items.isEmpty()) { return false; - if (obj instanceof InitProgressProfile) { - CacheBuffer other = (CacheBuffer) obj; - - EqualsBuilder equalsBuilder = new EqualsBuilder(); - equalsBuilder.append(detectorId, other.detectorId); - - return equalsBuilder.isEquals(); } - return false; - } - - @Override - public int hashCode() { - return new HashCodeBuilder().append(detectorId).toHashCode(); - } - - @Override - public boolean expired(Duration stateTtl) { - return expired(lastUsedTime, stateTtl, clock.instant()); + Optional> minPriorityItem = priorityTracker.getMinimumPriority(); + return minPriorityItem.isPresent() && priority > minPriorityItem.get().getValue(); } - public String getId() { - return detectorId; + /** + * Replace the smallest priority entity with the input entity + * @param entityModelId the Model Id + * @param value the model State + * @return the associated ModelState associated with the key, or null if there + * is no associated ModelState for the key + */ + public ModelState replace(String entityModelId, ModelState value) { + ModelState replaced = remove(); + put(entityModelId, value); + return replaced; } - public List> getAllModels() { + public List> getAllModelStates() { return items.values().stream().collect(Collectors.toList()); } - - public PriorityTracker getPriorityTracker() { - return priorityTracker; - } - - public void setMinimumCapacity(int minimumCapacity) { - if (minimumCapacity < 0) { - throw new IllegalArgumentException("minimum capacity should be larger than or equal 0"); - } - this.minimumCapacity = minimumCapacity; - this.reservedBytes = memoryConsumptionPerEntity * minimumCapacity; - } - - public void setCheckpointIntervalHrs(int checkpointIntervalHrs) { - this.checkpointIntervalHrs = checkpointIntervalHrs; - // 0 can cause java.lang.ArithmeticException: / by zero - // negative value is meaningless - if (checkpointIntervalHrs <= 0) { - this.checkpointIntervalHrs = 1; - } - } - - public int getCheckpointIntervalHrs() { - return checkpointIntervalHrs; - } } diff --git a/src/main/java/org/opensearch/timeseries/caching/CacheProvider.java b/src/main/java/org/opensearch/timeseries/caching/CacheProvider.java new file mode 100644 index 000000000..9b4a53705 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/caching/CacheProvider.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.caching; + +import org.opensearch.common.inject.Provider; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * A wrapper to call concrete implementation of caching. Used in transport + * action. Don't use interface because transport action handler constructor + * requires a concrete class as input. + * + */ +public class CacheProvider> + implements + Provider { + private CacheType cache; + + public CacheProvider() { + + } + + @Override + public CacheType get() { + return cache; + } + + public void set(CacheType cache) { + this.cache = cache; + } +} diff --git a/src/main/java/org/opensearch/ad/caching/DoorKeeper.java b/src/main/java/org/opensearch/timeseries/caching/DoorKeeper.java similarity index 60% rename from src/main/java/org/opensearch/ad/caching/DoorKeeper.java rename to src/main/java/org/opensearch/timeseries/caching/DoorKeeper.java index 5bb5e3cd5..488b5b7b4 100644 --- a/src/main/java/org/opensearch/ad/caching/DoorKeeper.java +++ b/src/main/java/org/opensearch/timeseries/caching/DoorKeeper.java @@ -9,57 +9,51 @@ * GitHub history for details. */ -package org.opensearch.ad.caching; +package org.opensearch.timeseries.caching; import java.time.Clock; import java.time.Duration; import java.time.Instant; +import java.util.HashMap; +import java.util.Map; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.timeseries.ExpiringState; import org.opensearch.timeseries.MaintenanceState; -import com.google.common.base.Charsets; -import com.google.common.hash.BloomFilter; -import com.google.common.hash.Funnels; - /** - * A bloom filter with regular reset. + * A hashmap thats track the exact frequency of each element and reset regularly. * - * Reference: https://arxiv.org/abs/1512.00727 + * The name of door keeper derives from https://arxiv.org/abs/1512.00727 * */ public class DoorKeeper implements MaintenanceState, ExpiringState { private final Logger LOG = LogManager.getLogger(DoorKeeper.class); // stores entity's model id - private BloomFilter bloomFilter; - // the number of expected insertions to the constructed BloomFilter; must be positive private final long expectedInsertions; - // the desired false positive probability (must be positive and less than 1.0) - private final double fpp; + private Map frequencyMap; private Instant lastMaintenanceTime; private final Duration resetInterval; private final Clock clock; private Instant lastAccessTime; + private final int countThreshold; - public DoorKeeper(long expectedInsertions, double fpp, Duration resetInterval, Clock clock) { + public DoorKeeper(long expectedInsertions, Duration resetInterval, Clock clock, int countThreshold) { this.expectedInsertions = expectedInsertions; - this.fpp = fpp; this.resetInterval = resetInterval; this.clock = clock; + this.countThreshold = countThreshold; this.lastAccessTime = clock.instant(); maintenance(); } - public boolean mightContain(String modelId) { + public void put(String modelId) { this.lastAccessTime = clock.instant(); - return bloomFilter.mightContain(modelId); - } - - public boolean put(String modelId) { - this.lastAccessTime = clock.instant(); - return bloomFilter.put(modelId); + this.frequencyMap.put(modelId, this.frequencyMap.getOrDefault(modelId, 0) + 1); + if (frequencyMap.size() > expectedInsertions) { + reset(); + } } /** @@ -67,13 +61,22 @@ public boolean put(String modelId) { */ @Override public void maintenance() { - if (bloomFilter == null || lastMaintenanceTime.plus(resetInterval).isBefore(clock.instant())) { + if (frequencyMap == null || lastMaintenanceTime.plus(resetInterval).isBefore(clock.instant())) { LOG.debug("maintaining for doorkeeper"); - bloomFilter = BloomFilter.create(Funnels.stringFunnel(Charsets.US_ASCII), expectedInsertions, fpp); - lastMaintenanceTime = clock.instant(); + reset(); } } + private void reset() { + frequencyMap = new HashMap<>(); + lastMaintenanceTime = clock.instant(); + } + + public boolean appearsMoreThanThreshold(String item) { + this.lastAccessTime = clock.instant(); + return this.frequencyMap.getOrDefault(item, 0) > countThreshold; + } + @Override public boolean expired(Duration stateTtl) { // ignore stateTtl since we have customized resetInterval diff --git a/src/main/java/org/opensearch/ad/caching/PriorityCache.java b/src/main/java/org/opensearch/timeseries/caching/PriorityCache.java similarity index 64% rename from src/main/java/org/opensearch/ad/caching/PriorityCache.java rename to src/main/java/org/opensearch/timeseries/caching/PriorityCache.java index 40e28975d..47fadce63 100644 --- a/src/main/java/org/opensearch/ad/caching/PriorityCache.java +++ b/src/main/java/org/opensearch/timeseries/caching/PriorityCache.java @@ -9,16 +9,14 @@ * GitHub history for details. */ -package org.opensearch.ad.caching; - -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_DEDICATED_CACHE_SIZE; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MODEL_MAX_SIZE_PERCENTAGE; +package org.opensearch.timeseries.caching; import java.time.Clock; import java.time.Duration; import java.time.Instant; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.Iterator; import java.util.List; @@ -32,22 +30,13 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.locks.ReentrantLock; +import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Triple; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelManager.ModelType; -import org.opensearch.ad.ml.ModelState; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.ModelProfile; -import org.opensearch.ad.ratelimit.CheckpointMaintainWorker; -import org.opensearch.ad.ratelimit.CheckpointWriteWorker; -import org.opensearch.ad.settings.ADEnabledSetting; -import org.opensearch.ad.util.DateUtils; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; @@ -57,47 +46,59 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.MemoryTracker; import org.opensearch.timeseries.MemoryTracker.Origin; -import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.ModelProfile; +import org.opensearch.timeseries.ratelimit.CheckpointMaintainWorker; +import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker; import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.util.DateUtils; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; -public class PriorityCache implements EntityCache { - private final Logger LOG = LogManager.getLogger(PriorityCache.class); +public abstract class PriorityCache & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CheckpointWriterType extends CheckpointWriteWorker, CheckpointMaintainerType extends CheckpointMaintainWorker, CacheBufferType extends CacheBuffer> + implements + TimeSeriesCache { + + private static final Logger LOG = LogManager.getLogger(PriorityCache.class); // detector id -> CacheBuffer, weight based - private final Map activeEnities; - private final CheckpointDao checkpointDao; - private volatile int dedicatedCacheSize; + private final Map activeEnities; + private final CheckpointDaoType checkpointDao; + protected volatile int hcDedicatedCacheSize; // LRU Cache, key is model id - private Cache> inActiveEntities; - private final MemoryTracker memoryTracker; + private Cache> inActiveEntities; + protected final MemoryTracker memoryTracker; private final ReentrantLock maintenanceLock; private final int numberOfTrees; - private final Clock clock; - private final Duration modelTtl; + protected final Clock clock; + protected final Duration modelTtl; // A bloom filter placed in front of inactive entity cache to // filter out unpopular items that are not likely to appear more // than once. Key is detector id private Map doorKeepers; private ThreadPool threadPool; + private String threadPoolName; private Random random; - private CheckpointWriteWorker checkpointWriteQueue; // iterating through all of inactive entities is heavy. We don't want to do // it again and again for no obvious benefits. private Instant lastInActiveEntityMaintenance; protected int maintenanceFreqConstant; - private CheckpointMaintainWorker checkpointMaintainQueue; - private int checkpointIntervalHrs; + protected int checkpointIntervalHrs; + private Origin origin; public PriorityCache( - CheckpointDao checkpointDao, - int dedicatedCacheSize, + CheckpointDaoType checkpointDao, + int hcDedicatedCacheSize, Setting checkpointTtl, int maxInactiveStates, MemoryTracker memoryTracker, @@ -106,22 +107,24 @@ public PriorityCache( ClusterService clusterService, Duration modelTtl, ThreadPool threadPool, - CheckpointWriteWorker checkpointWriteQueue, + String threadPoolName, int maintenanceFreqConstant, - CheckpointMaintainWorker checkpointMaintainQueue, Settings settings, - Setting checkpointSavingFreq + Setting checkpointSavingFreq, + Origin origin, + Setting dedicatedCacheSizeSetting, + Setting modelMaxSizePercent ) { this.checkpointDao = checkpointDao; this.activeEnities = new ConcurrentHashMap<>(); - this.dedicatedCacheSize = dedicatedCacheSize; - clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_DEDICATED_CACHE_SIZE, (it) -> { - this.dedicatedCacheSize = it; - this.setDedicatedCacheSizeListener(); + this.hcDedicatedCacheSize = hcDedicatedCacheSize; + clusterService.getClusterSettings().addSettingsUpdateConsumer(dedicatedCacheSizeSetting, (it) -> { + this.hcDedicatedCacheSize = it; + this.setHCDedicatedCacheSizeListener(); this.tryClearUpMemory(); }, this::validateDedicatedCacheSize); - clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_MODEL_MAX_SIZE_PERCENTAGE, it -> this.tryClearUpMemory()); + clusterService.getClusterSettings().addSettingsUpdateConsumer(modelMaxSizePercent, it -> this.tryClearUpMemory()); this.memoryTracker = memoryTracker; this.maintenanceLock = new ReentrantLock(); @@ -138,56 +141,58 @@ public PriorityCache( }); this.threadPool = threadPool; + this.threadPoolName = threadPoolName; this.random = new Random(42); - this.checkpointWriteQueue = checkpointWriteQueue; this.lastInActiveEntityMaintenance = Instant.MIN; this.maintenanceFreqConstant = maintenanceFreqConstant; - this.checkpointMaintainQueue = checkpointMaintainQueue; this.checkpointIntervalHrs = DateUtils.toDuration(checkpointSavingFreq.get(settings)).toHoursPart(); clusterService.getClusterSettings().addSettingsUpdateConsumer(checkpointSavingFreq, it -> { this.checkpointIntervalHrs = DateUtils.toDuration(it).toHoursPart(); this.setCheckpointFreqListener(); }); + this.origin = origin; } @Override - public ModelState get(String modelId, AnomalyDetector detector) { - String detectorId = detector.getId(); - CacheBuffer buffer = computeBufferIfAbsent(detector, detectorId); - ModelState modelState = buffer.get(modelId); + public ModelState get(String modelId, Config config) { + String configId = config.getId(); + CacheBufferType buffer = activeEnities.get(configId); + ModelState modelState = null; + if (buffer != null) { + modelState = buffer.get(modelId); + } // during maintenance period, stop putting new entries if (!maintenanceLock.isLocked() && modelState == null) { - if (ADEnabledSetting.isDoorKeeperInCacheEnabled()) { - DoorKeeper doorKeeper = doorKeepers.computeIfAbsent(detectorId, id -> { - // reset every 60 intervals - return new DoorKeeper( - TimeSeriesSettings.DOOR_KEEPER_FOR_CACHE_MAX_INSERTION, - TimeSeriesSettings.DOOR_KEEPER_FALSE_POSITIVE_RATE, - detector.getIntervalDuration().multipliedBy(TimeSeriesSettings.DOOR_KEEPER_MAINTENANCE_FREQ), - clock - ); - }); + DoorKeeper doorKeeper = doorKeepers.computeIfAbsent(configId, id -> { + // reset every 60 intervals + return new DoorKeeper( + TimeSeriesSettings.DOOR_KEEPER_FOR_CACHE_MAX_INSERTION, + config.getIntervalDuration().multipliedBy(TimeSeriesSettings.DOOR_KEEPER_MAINTENANCE_FREQ), + clock, + TimeSeriesSettings.CACHE_DOOR_KEEPER_COUNT_THRESHOLD + ); + }); - // first hit, ignore - // since door keeper may get reset during maintenance, it is possible - // the entity is still active even though door keeper has no record of - // this model Id. We have to call isActive method to make sure. Otherwise, - // the entity might miss an anomaly result every 60 intervals due to door keeper - // reset. - if (!doorKeeper.mightContain(modelId) && !isActive(detectorId, modelId)) { - doorKeeper.put(modelId); - return null; - } + // first few hits, ignore + // since door keeper may get reset during maintenance, it is possible + // the entity is still active even though door keeper has no record of + // this model Id. We have to call isActive method to make sure. Otherwise, + // the entity might miss a result every 60 intervals due to door keeper + // reset. + if (!doorKeeper.appearsMoreThanThreshold(modelId) && !isActive(configId, modelId)) { + doorKeeper.put(modelId); + return null; } try { - ModelState state = inActiveEntities.get(modelId, new Callable>() { + ModelState state = inActiveEntities.get(modelId, new Callable>() { @Override - public ModelState call() { - return new ModelState<>(null, modelId, detectorId, ModelType.ENTITY.getName(), clock, 0); + public ModelState call() { + return createEmptyModelState(modelId, configId); } + }); // make sure no model has been stored due to previous race conditions @@ -214,7 +219,9 @@ public ModelState call() { // intervals. // update state using new priority or create a new one - state.setPriority(buffer.getPriorityTracker().getUpdatedPriority(state.getPriority())); + if (buffer != null) { + state.setPriority(buffer.getPriorityTracker().getUpdatedPriority(state.getPriority())); + } // adjust shared memory in case we have used dedicated cache memory for other detectors if (random.nextInt(maintenanceFreqConstant) == 1) { @@ -229,61 +236,67 @@ public ModelState call() { return modelState; } - private Optional> getStateFromInactiveEntiiyCache(String modelId) { + private Optional> getStateFromInactiveEntiiyCache(String modelId) { if (modelId == null) { return Optional.empty(); } - // null if not even recorded in inActiveEntities yet because of doorKeeper + // null if not even recorded in inActiveEntities yet because of doorKeeper or first time start config return Optional.ofNullable(inActiveEntities.getIfPresent(modelId)); } @Override - public boolean hostIfPossible(AnomalyDetector detector, ModelState toUpdate) { - if (toUpdate == null) { + public boolean hostIfPossible(Config config, ModelState toUpdate) { + if (toUpdate == null || toUpdate.getModel().isEmpty()) { + System.out.println("hello"); return false; } String modelId = toUpdate.getModelId(); - String detectorId = toUpdate.getId(); + String configId = toUpdate.getConfigId(); - if (Strings.isEmpty(modelId) || Strings.isEmpty(detectorId)) { + if (Strings.isEmpty(modelId) || Strings.isEmpty(configId)) { + System.out.println("hello2"); return false; } - CacheBuffer buffer = computeBufferIfAbsent(detector, detectorId); + CacheBufferType buffer = computeBufferIfAbsent(config, configId); - Optional> state = getStateFromInactiveEntiiyCache(modelId); - if (false == state.isPresent()) { - return false; + Optional> state = getStateFromInactiveEntiiyCache(modelId); + ModelState modelState = null; + if (state.isPresent()) { + modelState = state.get(); + } else { + modelState = createEmptyModelState(modelId, configId); } - ModelState modelState = state.get(); - float priority = modelState.getPriority(); toUpdate.setLastUsedTime(clock.instant()); toUpdate.setPriority(priority); // current buffer's dedicated cache has free slots or can allocate in shared cache - if (buffer.dedicatedCacheAvailable() || memoryTracker.canAllocate(buffer.getMemoryConsumptionPerEntity())) { + if (buffer.dedicatedCacheAvailable() || memoryTracker.canAllocate(buffer.getMemoryConsumptionPerModel())) { // buffer.put will call MemoryTracker.consumeMemory buffer.put(modelId, toUpdate); + System.out.println("hello3"); return true; } - if (memoryTracker.canAllocate(buffer.getMemoryConsumptionPerEntity())) { + if (memoryTracker.canAllocate(buffer.getMemoryConsumptionPerModel())) { // buffer.put will call MemoryTracker.consumeMemory buffer.put(modelId, toUpdate); + System.out.println("hello4"); return true; } // can replace an entity in the same CacheBuffer living in reserved or shared cache - if (buffer.canReplaceWithinDetector(priority)) { - ModelState removed = buffer.replace(modelId, toUpdate); + if (buffer.canReplaceWithinConfig(priority)) { + ModelState removed = buffer.replace(modelId, toUpdate); // null in the case of some other threads have emptied the queue at // the same time so there is nothing to replace if (removed != null) { addIntoInactiveCache(removed); + System.out.println("hello5"); return true; } } @@ -291,20 +304,22 @@ public boolean hostIfPossible(AnomalyDetector detector, ModelState // If two threads try to remove the same entity and add their own state, the 2nd remove // returns null and only the first one succeeds. float scaledPriority = buffer.getPriorityTracker().getScaledPriority(priority); - Triple bufferToRemoveEntity = canReplaceInSharedCache(buffer, scaledPriority); - CacheBuffer bufferToRemove = bufferToRemoveEntity.getLeft(); + Triple bufferToRemoveEntity = canReplaceInSharedCache(buffer, scaledPriority); + CacheBufferType bufferToRemove = bufferToRemoveEntity.getLeft(); String entityModelId = bufferToRemoveEntity.getMiddle(); - ModelState removed = null; + ModelState removed = null; if (bufferToRemove != null && ((removed = bufferToRemove.remove(entityModelId)) != null)) { buffer.put(modelId, toUpdate); addIntoInactiveCache(removed); + System.out.println("hello6"); return true; } + System.out.println("hello7"); return false; } - private void addIntoInactiveCache(ModelState removed) { + private void addIntoInactiveCache(ModelState removed) { if (removed == null) { return; } @@ -314,10 +329,10 @@ private void addIntoInactiveCache(ModelState removed) { inActiveEntities.put(removed.getModelId(), removed); } - private void addEntity(List destination, Entity entity, String detectorId) { + private void addEntity(List destination, Entity entity, String configId) { // It's possible our doorkeepr prevented the entity from entering inactive entities cache if (entity != null) { - Optional modelId = entity.getModelId(detectorId); + Optional modelId = entity.getModelId(configId); if (modelId.isPresent() && inActiveEntities.getIfPresent(modelId.get()) != null) { destination.add(entity); } @@ -325,38 +340,31 @@ private void addEntity(List destination, Entity entity, String detectorI } @Override - public Pair, List> selectUpdateCandidate( - Collection cacheMissEntities, - String detectorId, - AnomalyDetector detector - ) { + public Pair, List> selectUpdateCandidate(Collection cacheMissEntities, String configId, Config config) { List hotEntities = new ArrayList<>(); List coldEntities = new ArrayList<>(); - CacheBuffer buffer = activeEnities.get(detectorId); + CacheBufferType buffer = activeEnities.get(configId); if (buffer == null) { - // don't want to create side-effects by creating a CacheBuffer - // In current implementation, this branch is impossible as we call - // PriorityCache.get method before invoking this method. The - // PriorityCache.get method creates a CacheBuffer if not present. - // Since this method is public, need to deal with this case in case of misuse. - return Pair.of(hotEntities, coldEntities); + // when a config is just started or during run once, there is + // no cache buffer yet. Make every cache miss entities hot + return Pair.of(new ArrayList<>(cacheMissEntities), coldEntities); } Iterator cacheMissEntitiesIter = cacheMissEntities.iterator(); // current buffer's dedicated cache has free slots while (cacheMissEntitiesIter.hasNext() && buffer.dedicatedCacheAvailable()) { - addEntity(hotEntities, cacheMissEntitiesIter.next(), detectorId); + addEntity(hotEntities, cacheMissEntitiesIter.next(), configId); } - while (cacheMissEntitiesIter.hasNext() && memoryTracker.canAllocate(buffer.getMemoryConsumptionPerEntity())) { + while (cacheMissEntitiesIter.hasNext() && memoryTracker.canAllocate(buffer.getMemoryConsumptionPerModel())) { // can allocate in shared cache // race conditions can happen when multiple threads evaluating this condition. // This is a problem as our AD memory usage is close to full and we put // more things than we planned. One model in HCAD is small, // it is fine we exceed a little. We have regular maintenance to remove // extra memory usage. - addEntity(hotEntities, cacheMissEntitiesIter.next(), detectorId); + addEntity(hotEntities, cacheMissEntitiesIter.next(), configId); } // check if we can replace anything in dedicated or shared cache @@ -370,23 +378,23 @@ public Pair, List> selectUpdateCandidate( // thread safe as each detector has one thread at one time and only the // thread can access its buffer. Entity entity = cacheMissEntitiesIter.next(); - Optional modelId = entity.getModelId(detectorId); + Optional modelId = entity.getModelId(configId); if (false == modelId.isPresent()) { continue; } - Optional> state = getStateFromInactiveEntiiyCache(modelId.get()); + Optional> state = getStateFromInactiveEntiiyCache(modelId.get()); if (false == state.isPresent()) { // not even recorded in inActiveEntities yet because of doorKeeper continue; } - ModelState modelState = state.get(); + ModelState modelState = state.get(); float priority = modelState.getPriority(); - if (buffer.canReplaceWithinDetector(priority)) { - addEntity(hotEntities, entity, detectorId); + if (buffer.canReplaceWithinConfig(priority)) { + addEntity(hotEntities, entity, configId); } else { // re-evaluate replacement condition in other buffers otherBufferReplaceCandidates.add(entity); @@ -395,7 +403,7 @@ public Pair, List> selectUpdateCandidate( // record current minimum priority among all detectors to save redundant // scanning of all CacheBuffers - CacheBuffer bufferToRemove = null; + CacheBufferType bufferToRemove = null; float minPriority = Float.MIN_VALUE; // check if we can replace in other CacheBuffer @@ -405,77 +413,64 @@ public Pair, List> selectUpdateCandidate( // If two threads try to remove the same entity and add their own state, the 2nd remove // returns null and only the first one succeeds. Entity entity = cacheMissEntitiesIter.next(); - Optional modelId = entity.getModelId(detectorId); + Optional modelId = entity.getModelId(configId); if (false == modelId.isPresent()) { continue; } - Optional> inactiveState = getStateFromInactiveEntiiyCache(modelId.get()); + Optional> inactiveState = getStateFromInactiveEntiiyCache(modelId.get()); if (false == inactiveState.isPresent()) { // empty state should not stand a chance to replace others continue; } - ModelState state = inactiveState.get(); + ModelState state = inactiveState.get(); float priority = state.getPriority(); float scaledPriority = buffer.getPriorityTracker().getScaledPriority(priority); if (scaledPriority <= minPriority) { // not even larger than the minPriority, we can put this to coldEntities - addEntity(coldEntities, entity, detectorId); + addEntity(coldEntities, entity, configId); continue; } // Float.MIN_VALUE means we need to re-iterate through all CacheBuffers if (minPriority == Float.MIN_VALUE) { - Triple bufferToRemoveEntity = canReplaceInSharedCache(buffer, scaledPriority); + Triple bufferToRemoveEntity = canReplaceInSharedCache(buffer, scaledPriority); bufferToRemove = bufferToRemoveEntity.getLeft(); minPriority = bufferToRemoveEntity.getRight(); } if (bufferToRemove != null) { - addEntity(hotEntities, entity, detectorId); + addEntity(hotEntities, entity, configId); // reset minPriority after the replacement so that we need to iterate all CacheBuffer // again minPriority = Float.MIN_VALUE; } else { // after trying everything, we can now safely put this to cold entities list - addEntity(coldEntities, entity, detectorId); + addEntity(coldEntities, entity, configId); } } return Pair.of(hotEntities, coldEntities); } - private CacheBuffer computeBufferIfAbsent(AnomalyDetector detector, String detectorId) { - CacheBuffer buffer = activeEnities.get(detectorId); + private CacheBufferType computeBufferIfAbsent(Config config, String configId) { + CacheBufferType buffer = activeEnities.get(configId); if (buffer == null) { - long requiredBytes = getRequiredMemory(detector, dedicatedCacheSize); + long requiredBytes = getRequiredMemory(config, config.isHighCardinality() ? hcDedicatedCacheSize : 1); if (memoryTracker.canAllocateReserved(requiredBytes)) { - memoryTracker.consumeMemory(requiredBytes, true, Origin.REAL_TIME_DETECTOR); - long intervalSecs = detector.getIntervalInSeconds(); - - buffer = new CacheBuffer( - dedicatedCacheSize, - intervalSecs, - getRequiredMemory(detector, 1), - memoryTracker, - clock, - modelTtl, - detectorId, - checkpointWriteQueue, - checkpointMaintainQueue, - checkpointIntervalHrs - ); - activeEnities.put(detectorId, buffer); + memoryTracker.consumeMemory(requiredBytes, true, origin); + buffer = createEmptyCacheBuffer(config, requiredBytes); + activeEnities.put(configId, buffer); // There can be race conditions between tryClearUpMemory and // activeEntities.put above as tryClearUpMemory accesses activeEnities too. // Put tryClearUpMemory after consumeMemory to prevent that. tryClearUpMemory(); } else { - throw new LimitExceededException(detectorId, CommonMessages.MEMORY_LIMIT_EXCEEDED_ERR_MSG); + throw new LimitExceededException(configId, CommonMessages.MEMORY_LIMIT_EXCEEDED_ERR_MSG); } } @@ -484,20 +479,12 @@ private CacheBuffer computeBufferIfAbsent(AnomalyDetector detector, String detec /** * - * @param detector Detector config accessor + * @param config Detector config accessor * @param numberOfEntity number of entities * @return Memory in bytes required for hosting numberOfEntity entities */ - private long getRequiredMemory(AnomalyDetector detector, int numberOfEntity) { - int dimension = detector.getEnabledFeatureIds().size() * detector.getShingleSize(); - return numberOfEntity * memoryTracker - .estimateTRCFModelSize( - dimension, - numberOfTrees, - TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO, - detector.getShingleSize().intValue(), - true - ); + private long getRequiredMemory(Config config, int numberOfEntity) { + return numberOfEntity * getRequiredMemoryPerEntity(config, memoryTracker, numberOfTrees); } /** @@ -511,12 +498,12 @@ private long getRequiredMemory(AnomalyDetector detector, int numberOfEntity) { * @param candidatePriority the candidate entity's priority * @return the CacheBuffer if we can find a CacheBuffer to make room for the candidate entity */ - private Triple canReplaceInSharedCache(CacheBuffer originBuffer, float candidatePriority) { - CacheBuffer minPriorityBuffer = null; + private Triple canReplaceInSharedCache(CacheBufferType originBuffer, float candidatePriority) { + CacheBufferType minPriorityBuffer = null; float minPriority = candidatePriority; String minPriorityEntityModelId = null; - for (Map.Entry entry : activeEnities.entrySet()) { - CacheBuffer buffer = entry.getValue(); + for (Map.Entry entry : activeEnities.entrySet()) { + CacheBufferType buffer = entry.getValue(); if (buffer != originBuffer && buffer.canRemove()) { Optional> priorityEntry = buffer.getPriorityTracker().getMinimumScaledPriority(); if (!priorityEntry.isPresent()) { @@ -536,12 +523,12 @@ private Triple canReplaceInSharedCache(CacheBuffer o /** * Clear up overused memory. Can happen due to race condition or other detectors * consumes resources from shared memory. - * tryClearUpMemory is ran using AD threadpool because the function is expensive. + * tryClearUpMemory is ran using analysis-specific threadpool because the function is expensive. */ private void tryClearUpMemory() { try { if (maintenanceLock.tryLock()) { - threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME).execute(() -> clearMemory()); + threadPool.executor(threadPoolName).execute(() -> clearMemory()); } else { threadPool.schedule(() -> { try { @@ -549,7 +536,7 @@ private void tryClearUpMemory() { } catch (Exception e) { LOG.error("Fail to clear up memory taken by CacheBuffer. Will retry during maintenance."); } - }, new TimeValue(random.nextInt(90), TimeUnit.SECONDS), TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME); + }, new TimeValue(random.nextInt(90), TimeUnit.SECONDS), threadPoolName); } } finally { if (maintenanceLock.isHeldByCurrentThread()) { @@ -561,12 +548,12 @@ private void tryClearUpMemory() { private void clearMemory() { recalculateUsedMemory(); long memoryToShed = memoryTracker.memoryToShed(); - PriorityQueue> removalCandiates = null; + PriorityQueue> removalCandiates = null; if (memoryToShed > 0) { // sort the triple in an ascending order of priority removalCandiates = new PriorityQueue<>((x, y) -> Float.compare(x.getLeft(), y.getLeft())); - for (Map.Entry entry : activeEnities.entrySet()) { - CacheBuffer buffer = entry.getValue(); + for (Map.Entry entry : activeEnities.entrySet()) { + CacheBufferType buffer = entry.getValue(); Optional> priorityEntry = buffer.getPriorityTracker().getMinimumScaledPriority(); if (!priorityEntry.isPresent()) { continue; @@ -579,12 +566,12 @@ private void clearMemory() { } while (memoryToShed > 0) { if (false == removalCandiates.isEmpty()) { - Triple toRemove = removalCandiates.poll(); - CacheBuffer minPriorityBuffer = toRemove.getMiddle(); + Triple toRemove = removalCandiates.poll(); + CacheBufferType minPriorityBuffer = toRemove.getMiddle(); String minPriorityEntityModelId = toRemove.getRight(); - ModelState removed = minPriorityBuffer.remove(minPriorityEntityModelId); - memoryToShed -= minPriorityBuffer.getMemoryConsumptionPerEntity(); + ModelState removed = minPriorityBuffer.remove(minPriorityEntityModelId); + memoryToShed -= minPriorityBuffer.getMemoryConsumptionPerModel(); addIntoInactiveCache(removed); if (minPriorityBuffer.canRemove()) { @@ -609,12 +596,12 @@ private void clearMemory() { private void recalculateUsedMemory() { long reserved = 0; long shared = 0; - for (Map.Entry entry : activeEnities.entrySet()) { - CacheBuffer buffer = entry.getValue(); + for (Map.Entry entry : activeEnities.entrySet()) { + CacheBufferType buffer = entry.getValue(); reserved += buffer.getReservedBytes(); shared += buffer.getBytesInSharedCache(); } - memoryTracker.syncMemoryState(Origin.REAL_TIME_DETECTOR, reserved + shared, reserved); + memoryTracker.syncMemoryState(origin, reserved + shared, reserved); } /** @@ -630,15 +617,15 @@ public void maintenance() { // clean up memory if we allocate more memory than we should tryClearUpMemory(); activeEnities.entrySet().stream().forEach(cacheBufferEntry -> { - String detectorId = cacheBufferEntry.getKey(); - CacheBuffer cacheBuffer = cacheBufferEntry.getValue(); + String configId = cacheBufferEntry.getKey(); + CacheBufferType cacheBuffer = cacheBufferEntry.getValue(); // remove expired cache buffer if (cacheBuffer.expired(modelTtl)) { - activeEnities.remove(detectorId); + activeEnities.remove(configId); cacheBuffer.clear(); } else { - List> removedStates = cacheBuffer.maintenance(); - for (ModelState state : removedStates) { + List> removedStates = cacheBuffer.maintenance(); + for (ModelState state : removedStates) { addIntoInactiveCache(state); } } @@ -647,11 +634,11 @@ public void maintenance() { maintainInactiveCache(); doorKeepers.entrySet().stream().forEach(doorKeeperEntry -> { - String detectorId = doorKeeperEntry.getKey(); + String configId = doorKeeperEntry.getKey(); DoorKeeper doorKeeper = doorKeeperEntry.getValue(); // doorKeeper has its own state ttl if (doorKeeper.expired(null)) { - doorKeepers.remove(detectorId); + doorKeepers.remove(configId); } else { doorKeeper.maintenance(); } @@ -666,19 +653,19 @@ public void maintenance() { /** * Permanently deletes models hosted in memory and persisted in index. * - * @param detectorId id the of the detector for which models are to be permanently deleted + * @param configId id the of the config for which models are to be permanently deleted */ @Override - public void clear(String detectorId) { - if (Strings.isEmpty(detectorId)) { + public void clear(String configId) { + if (Strings.isEmpty(configId)) { return; } - CacheBuffer buffer = activeEnities.remove(detectorId); + CacheBufferType buffer = activeEnities.remove(configId); if (buffer != null) { buffer.clear(); } - checkpointDao.deleteModelCheckpointByDetectorId(detectorId); - doorKeepers.remove(detectorId); + checkpointDao.deleteModelCheckpointByConfigId(configId); + doorKeepers.remove(configId); } /** @@ -688,7 +675,7 @@ public void clear(String detectorId) { */ @Override public int getActiveEntities(String detectorId) { - CacheBuffer cacheBuffer = activeEnities.get(detectorId); + CacheBufferType cacheBuffer = activeEnities.get(detectorId); if (cacheBuffer != null) { return cacheBuffer.getActiveEntities(); } @@ -697,13 +684,13 @@ public int getActiveEntities(String detectorId) { /** * Whether an entity is active or not - * @param detectorId The Id of the detector that an entity belongs to + * @param configId The Id of the detector that an entity belongs to * @param entityModelId Entity's Model Id * @return Whether an entity is active or not */ @Override - public boolean isActive(String detectorId, String entityModelId) { - CacheBuffer cacheBuffer = activeEnities.get(detectorId); + public boolean isActive(String configId, String entityModelId) { + CacheBufferType cacheBuffer = activeEnities.get(configId); if (cacheBuffer != null) { return cacheBuffer.isActive(entityModelId); } @@ -711,32 +698,22 @@ public boolean isActive(String detectorId, String entityModelId) { } @Override - public long getTotalUpdates(String detectorId) { + public long getTotalUpdates(String configId) { return Optional .of(activeEnities) - .map(entities -> entities.get(detectorId)) + .map(entities -> entities.get(configId)) .map(buffer -> buffer.getPriorityTracker().getHighestPriorityEntityId()) .map(entityModelIdOptional -> entityModelIdOptional.get()) - .map(entityModelId -> getTotalUpdates(detectorId, entityModelId)) + .map(entityModelId -> getTotalUpdates(configId, entityModelId)) .orElse(0L); } @Override - public long getTotalUpdates(String detectorId, String entityModelId) { - CacheBuffer cacheBuffer = activeEnities.get(detectorId); - if (cacheBuffer != null) { - Optional modelOptional = cacheBuffer.getModel(entityModelId); - // TODO: make it work for shingles. samples.size() is not the real shingle - long accumulatedShingles = modelOptional - .flatMap(model -> model.getTrcf()) - .map(trcf -> trcf.getForest()) - .map(rcf -> rcf.getTotalUpdates()) - .orElseGet( - () -> modelOptional.map(model -> model.getSamples()).map(samples -> samples.size()).map(Long::valueOf).orElse(0L) - ); - return accumulatedShingles; - } - return 0L; + public long getTotalUpdates(String configId, String entityModelId) { + return Optional + .ofNullable(activeEnities.get(configId)) + .map(cacheBuffer -> getTotalUpdates(cacheBuffer.getModelState(entityModelId))) + .orElse(0L); } /** @@ -756,24 +733,25 @@ public int getTotalActiveEntities() { * @return list of modelStates */ @Override - public List> getAllModels() { - List> states = new ArrayList<>(); - activeEnities.values().stream().forEach(cacheBuffer -> states.addAll(cacheBuffer.getAllModels())); + public List> getAllModels() { + List> states = new ArrayList<>(); + activeEnities.values().stream().forEach(cacheBuffer -> states.addAll(cacheBuffer.getAllModelStates())); return states; } /** - * Gets all of a detector's model sizes hosted on a node + * Gets all of a config's model sizes hosted on a node * + * @param configId config Id * @return a map of model id to its memory size */ @Override - public Map getModelSize(String detectorId) { - CacheBuffer cacheBuffer = activeEnities.get(detectorId); + public Map getModelSize(String configId) { + CacheBufferType cacheBuffer = activeEnities.get(configId); Map res = new HashMap<>(); if (cacheBuffer != null) { - long size = cacheBuffer.getMemoryConsumptionPerEntity(); - cacheBuffer.getAllModels().forEach(entry -> res.put(entry.getModelId(), size)); + long size = cacheBuffer.getMemoryConsumptionPerModel(); + cacheBuffer.getAllModelStates().forEach(entry -> res.put(entry.getModelId(), size)); } return res; } @@ -792,8 +770,8 @@ public Map getModelSize(String detectorId) { * milliseconds when the entity's state is lastly used. Otherwise, return -1. */ @Override - public long getLastActiveMs(String detectorId, String entityModelId) { - CacheBuffer cacheBuffer = activeEnities.get(detectorId); + public long getLastActiveTime(String detectorId, String entityModelId) { + CacheBufferType cacheBuffer = activeEnities.get(detectorId); long lastUsedMs = -1; if (cacheBuffer != null) { lastUsedMs = cacheBuffer.getLastUsedTime(entityModelId); @@ -801,7 +779,7 @@ public long getLastActiveMs(String detectorId, String entityModelId) { return lastUsedMs; } } - ModelState stateInActive = inActiveEntities.getIfPresent(entityModelId); + ModelState stateInActive = inActiveEntities.getIfPresent(entityModelId); if (stateInActive != null) { lastUsedMs = stateInActive.getLastUsedTime().toEpochMilli(); } @@ -815,7 +793,7 @@ public void releaseMemoryForOpenCircuitBreaker() { tryClearUpMemory(); activeEnities.values().stream().forEach(cacheBuffer -> { if (cacheBuffer.canRemove()) { - ModelState removed = cacheBuffer.remove(); + ModelState removed = cacheBuffer.remove(); addIntoInactiveCache(removed); } }); @@ -831,9 +809,9 @@ private void maintainInactiveCache() { inActiveEntities.cleanUp(); // // make sure no model has been stored due to bugs - for (ModelState state : inActiveEntities.asMap().values()) { - EntityModel model = state.getModel(); - if (model != null && model.getTrcf().isPresent()) { + for (ModelState state : inActiveEntities.asMap().values()) { + Optional modelOptional = state.getModel(); + if (modelOptional.isPresent()) { LOG.warn(new ParameterizedMessage("Inactive entity's model is null: [{}]. Maybe there are bugs.", state.getModelId())); state.setModel(null); } @@ -846,8 +824,8 @@ private void maintainInactiveCache() { * Called when dedicated cache size changes. Will adjust existing cache buffer's * cache size */ - private void setDedicatedCacheSizeListener() { - activeEnities.values().stream().forEach(cacheBuffer -> cacheBuffer.setMinimumCapacity(dedicatedCacheSize)); + private void setHCDedicatedCacheSizeListener() { + activeEnities.values().stream().forEach(cacheBuffer -> cacheBuffer.setMinimumCapacity(hcDedicatedCacheSize)); } private void setCheckpointFreqListener() { @@ -856,20 +834,16 @@ private void setCheckpointFreqListener() { @Override public List getAllModelProfile(String detectorId) { - CacheBuffer cacheBuffer = activeEnities.get(detectorId); - List res = new ArrayList<>(); + CacheBufferType cacheBuffer = activeEnities.get(detectorId); if (cacheBuffer != null) { - long size = cacheBuffer.getMemoryConsumptionPerEntity(); - cacheBuffer.getAllModels().forEach(entry -> { - EntityModel model = entry.getModel(); - Entity entity = null; - if (model != null && model.getEntity().isPresent()) { - entity = model.getEntity().get(); - } - res.add(new ModelProfile(entry.getModelId(), entity, size)); - }); + long size = cacheBuffer.getMemoryConsumptionPerModel(); + return cacheBuffer + .getAllModelStates() + .stream() + .map(entry -> new ModelProfile(entry.getModelId(), entry.getEntity().orElse(null), size)) + .collect(Collectors.toList()); } - return res; + return Collections.emptyList(); } /** @@ -881,14 +855,14 @@ public List getAllModelProfile(String detectorId) { */ @Override public Optional getModelProfile(String detectorId, String entityModelId) { - CacheBuffer cacheBuffer = activeEnities.get(detectorId); - if (cacheBuffer != null && cacheBuffer.getModel(entityModelId).isPresent()) { - EntityModel model = cacheBuffer.getModel(entityModelId).get(); + CacheBufferType cacheBuffer = activeEnities.get(detectorId); + if (cacheBuffer != null && cacheBuffer.getModelState(entityModelId) != null) { + ModelState modelState = cacheBuffer.getModelState(entityModelId); Entity entity = null; - if (model != null && model.getEntity().isPresent()) { - entity = model.getEntity().get(); + if (modelState != null && modelState.getEntity().isPresent()) { + entity = modelState.getEntity().get(); } - return Optional.of(new ModelProfile(entityModelId, entity, cacheBuffer.getMemoryConsumptionPerEntity())); + return Optional.of(new ModelProfile(entityModelId, entity, cacheBuffer.getMemoryConsumptionPerModel())); } return Optional.empty(); } @@ -900,11 +874,11 @@ public Optional getModelProfile(String detectorId, String entityMo * @param newDedicatedCacheSize the new dedicated cache size to validate */ private void validateDedicatedCacheSize(Integer newDedicatedCacheSize) { - if (this.dedicatedCacheSize < newDedicatedCacheSize) { - int delta = newDedicatedCacheSize - this.dedicatedCacheSize; + if (this.hcDedicatedCacheSize < newDedicatedCacheSize) { + int delta = newDedicatedCacheSize - this.hcDedicatedCacheSize; long totalIncreasedBytes = 0; - for (CacheBuffer cacheBuffer : activeEnities.values()) { - totalIncreasedBytes += cacheBuffer.getMemoryConsumptionPerEntity() * delta; + for (CacheBufferType cacheBuffer : activeEnities.values()) { + totalIncreasedBytes += cacheBuffer.getMemoryConsumptionPerModel() * delta; } if (false == memoryTracker.canAllocateReserved(totalIncreasedBytes)) { @@ -915,13 +889,13 @@ private void validateDedicatedCacheSize(Integer newDedicatedCacheSize) { /** * Get a model state without incurring priority update. Used in maintenance. - * @param detectorId Detector Id + * @param configId Config Id * @param modelId Model Id * @return Model state */ @Override - public Optional> getForMaintainance(String detectorId, String modelId) { - CacheBuffer buffer = activeEnities.get(detectorId); + public Optional> getForMaintainance(String configId, String modelId) { + CacheBufferType buffer = activeEnities.get(configId); if (buffer == null) { return Optional.empty(); } @@ -929,31 +903,31 @@ public Optional> getForMaintainance(String detectorId, S } /** - * Remove entity model from active entity buffer and delete checkpoint. Used to clean corrupted model. - * @param detectorId Detector Id - * @param entityModelId Model Id + * Remove model from active entity buffer and delete checkpoint. Used to clean corrupted model. + * @param configId config Id + * @param modelId Model Id */ @Override - public void removeEntityModel(String detectorId, String entityModelId) { - CacheBuffer buffer = activeEnities.get(detectorId); + public void removeModel(String configId, String modelId) { + CacheBufferType buffer = activeEnities.get(configId); if (buffer != null) { - ModelState removed = null; - if ((removed = buffer.remove(entityModelId, false)) != null) { + ModelState removed = buffer.remove(modelId, false); + if (removed != null) { addIntoInactiveCache(removed); } } checkpointDao .deleteModelCheckpoint( - entityModelId, + modelId, ActionListener .wrap( - r -> LOG.debug(new ParameterizedMessage("Succeeded in deleting checkpoint [{}].", entityModelId)), - e -> LOG.error(new ParameterizedMessage("Failed to delete checkpoint [{}].", entityModelId), e) + r -> LOG.debug(new ParameterizedMessage("Succeeded in deleting checkpoint [{}].", modelId)), + e -> LOG.error(new ParameterizedMessage("Failed to delete checkpoint [{}].", modelId), e) ) ); } - private Cache> createInactiveCache(Duration inactiveEntityTtl, int maxInactiveStates) { + private Cache> createInactiveCache(Duration inactiveEntityTtl, int maxInactiveStates) { return CacheBuilder .newBuilder() .expireAfterAccess(inactiveEntityTtl.toHours(), TimeUnit.HOURS) @@ -961,4 +935,8 @@ private Cache> createInactiveCache(Duration inac .concurrencyLevel(1) .build(); } + + protected abstract ModelState createEmptyModelState(String modelId, String configId); + + protected abstract CacheBufferType createEmptyCacheBuffer(Config config, long memoryConsumptionPerEntity); } diff --git a/src/main/java/org/opensearch/ad/caching/PriorityTracker.java b/src/main/java/org/opensearch/timeseries/caching/PriorityTracker.java similarity index 97% rename from src/main/java/org/opensearch/ad/caching/PriorityTracker.java rename to src/main/java/org/opensearch/timeseries/caching/PriorityTracker.java index 439d67679..07f2087ec 100644 --- a/src/main/java/org/opensearch/ad/caching/PriorityTracker.java +++ b/src/main/java/org/opensearch/timeseries/caching/PriorityTracker.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.caching; +package org.opensearch.timeseries.caching; import java.time.Clock; import java.util.AbstractMap.SimpleImmutableEntry; @@ -236,7 +236,7 @@ public void updatePriority(String entityId) { * @param entityId Entity Id * @param priority priority */ - protected void addPriority(String entityId, float priority) { + public void addPriority(String entityId, float priority) { PriorityNode node = new PriorityNode(entityId, priority); key2Priority.put(entityId, node); priorityList.add(node); @@ -260,7 +260,7 @@ private void adjustSizeIfRequired() { * Remove an entity in the tracker * @param entityId Entity Id */ - protected void removePriority(String entityId) { + public void removePriority(String entityId) { // remove if the key matches; priority does not matter priorityList.remove(new PriorityNode(entityId, 0)); key2Priority.remove(entityId); @@ -269,7 +269,7 @@ protected void removePriority(String entityId) { /** * Remove all of entities */ - protected void clearPriority() { + public void clearPriority() { key2Priority.clear(); priorityList.clear(); } @@ -292,7 +292,7 @@ protected void clearPriority() { * * @return new priority */ - float getUpdatedPriority(float oldPriority) { + public float getUpdatedPriority(float oldPriority) { long increment = computeWeightedPriorityIncrement(); oldPriority += Math.log(1 + Math.exp(increment - oldPriority)); // if overflow happens, using the most recent decayed count instead. @@ -319,7 +319,7 @@ float getUpdatedPriority(float oldPriority) { * @param currentPriority Current priority * @return the scaled priority */ - float getScaledPriority(float currentPriority) { + public float getScaledPriority(float currentPriority) { return currentPriority - computeWeightedPriorityIncrement(); } diff --git a/src/main/java/org/opensearch/timeseries/caching/TimeSeriesCache.java b/src/main/java/org/opensearch/timeseries/caching/TimeSeriesCache.java new file mode 100644 index 000000000..fa5b0c1eb --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/caching/TimeSeriesCache.java @@ -0,0 +1,187 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.caching; + +import java.util.Collection; +import java.util.List; +import java.util.Optional; + +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.timeseries.AnalysisModelSize; +import org.opensearch.timeseries.CleanState; +import org.opensearch.timeseries.MaintenanceState; +import org.opensearch.timeseries.MemoryTracker; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.ModelProfile; +import org.opensearch.timeseries.settings.TimeSeriesSettings; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public interface TimeSeriesCache extends MaintenanceState, CleanState, AnalysisModelSize { + /** + * + * @param config Analysis config + * @param toUpdate Model state candidate + * @return if we can host the given model state + */ + boolean hostIfPossible(Config config, ModelState toUpdate); + + /** + * Get a model state without incurring priority update or load from state from disk. Used in maintenance. + * @param configId Config Id + * @param modelId Model Id + * @return Model state + */ + Optional> getForMaintainance(String configId, String modelId); + + /** + * Get the ModelState associated with the modelId. May or may not load the + * ModelState depending on the underlying cache's memory consumption. + * + * @param modelId Model Id + * @param config config accessor + * @return the ModelState associated with the config or null if no cached item + * for the config + */ + ModelState get(String modelId, Config config); + + /** + * Whether an entity is active or not + * @param configId The Id of the config that an entity belongs to + * @param entityModelId Entity model Id + * @return Whether an entity is active or not + */ + boolean isActive(String configId, String entityModelId); + + /** + * Get total updates of the config's most active entity's RCF model. + * + * @param configId detector id + * @return RCF model total updates of most active entity. + */ + long getTotalUpdates(String configId); + + /** + * Get RCF model total updates of specific entity + * + * @param configId config id + * @param entityModelId entity model id + * @return RCF model total updates of specific entity. + */ + long getTotalUpdates(String configId, String entityModelId); + + /** + * Gets modelStates of all model hosted on a node + * + * @return list of modelStates + */ + List> getAllModels(); + + /** + * Get the number of active entities of a config + * @param configId Config Id + * @return The number of active entities + */ + int getActiveEntities(String configId); + + /** + * + * @return total active entities in the cache + */ + int getTotalActiveEntities(); + + /** + * Return when the last active time of an entity's state. + * + * If the entity's state is active in the cache, the value indicates when the cache + * is lastly accessed (get/put). If the entity's state is inactive in the cache, + * the value indicates when the cache state is created or when the entity is evicted + * from active entity cache. + * + * @param configId The Id of the config that an entity belongs to + * @param entityModelId Entity's Model Id + * @return if the entity is in the cache, return the timestamp in epoch + * milliseconds when the entity's state is lastly used. Otherwise, return -1. + */ + long getLastActiveTime(String configId, String entityModelId); + + /** + * Release memory when memory circuit breaker is open + */ + void releaseMemoryForOpenCircuitBreaker(); + + /** + * Select candidate entities for which we can load models + * @param cacheMissEntities Cache miss entities + * @param configId Config Id + * @param config Config object + * @return A list of entities that are admitted into the cache as a result of the + * update and the left-over entities + */ + Pair, List> selectUpdateCandidate(Collection cacheMissEntities, String configId, Config config); + + /** + * + * @param configId Detector Id + * @return a detector's model information + */ + List getAllModelProfile(String configId); + + /** + * Gets an entity's model sizes + * + * @param configId Detector Id + * @param entityModelId Entity's model Id + * @return the entity's memory size + */ + Optional getModelProfile(String configId, String entityModelId); + + /** + * Remove entity model from active entity buffer and delete checkpoint. Used to clean corrupted model. + * @param configId config Id + * @param entityModelId Model Id + */ + void removeModel(String configId, String entityModelId); + + /** + * + * @param config Detector config accessor + * @param memoryTracker memory tracker + * @param numberOfTrees number of trees + * @return Memory in bytes required for hosting one entity model + */ + default long getRequiredMemoryPerEntity(Config config, MemoryTracker memoryTracker, int numberOfTrees) { + int dimension = config.getEnabledFeatureIds().size() * config.getShingleSize(); + return memoryTracker + .estimateTRCFModelSize( + dimension, + numberOfTrees, + TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO, + config.getShingleSize().intValue(), + true + ); + } + + default long getTotalUpdates(ModelState modelState) { + // TODO: make it work for shingles. samples.size() is not the real shingle + long accumulatedShingles = Optional + .ofNullable(modelState) + .flatMap(model -> model.getModel()) + .map(trcf -> trcf.getForest()) + .map(rcf -> rcf.getTotalUpdates()) + .orElseGet( + () -> Optional + .ofNullable(modelState) + .map(model -> model.getSamples()) + .map(samples -> samples.size()) + .map(Long::valueOf) + .orElse(0L) + ); + return accumulatedShingles; + } +} diff --git a/src/main/java/org/opensearch/ad/cluster/ADDataMigrator.java b/src/main/java/org/opensearch/timeseries/cluster/ADDataMigrator.java similarity index 97% rename from src/main/java/org/opensearch/ad/cluster/ADDataMigrator.java rename to src/main/java/org/opensearch/timeseries/cluster/ADDataMigrator.java index 4050c22f5..7dfcf37fb 100644 --- a/src/main/java/org/opensearch/ad/cluster/ADDataMigrator.java +++ b/src/main/java/org/opensearch/timeseries/cluster/ADDataMigrator.java @@ -9,12 +9,10 @@ * GitHub history for details. */ -package org.opensearch.ad.cluster; +package org.opensearch.timeseries.cluster; import static org.opensearch.ad.constant.ADCommonName.DETECTION_STATE_INDEX; import static org.opensearch.ad.model.ADTask.DETECTOR_ID_FIELD; -import static org.opensearch.ad.model.ADTask.IS_LATEST_FIELD; -import static org.opensearch.ad.model.ADTask.TASK_TYPE_FIELD; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_DETECTOR_UPPER_LIMIT; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.timeseries.model.TaskType.taskTypeToString; @@ -59,6 +57,7 @@ import org.opensearch.timeseries.function.ExecutorFunction; import org.opensearch.timeseries.model.Job; import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.model.TimeSeriesTask; import org.opensearch.timeseries.util.ExceptionUtil; /** @@ -212,10 +211,10 @@ private void checkIfRealtimeTaskExistsAndBackfill( BoolQueryBuilder query = new BoolQueryBuilder(); query.filter(new TermQueryBuilder(DETECTOR_ID_FIELD, jobId)); if (job.isEnabled()) { - query.filter(new TermQueryBuilder(IS_LATEST_FIELD, true)); + query.filter(new TermQueryBuilder(TimeSeriesTask.IS_LATEST_FIELD, true)); } - query.filter(new TermsQueryBuilder(TASK_TYPE_FIELD, taskTypeToString(ADTaskType.REALTIME_TASK_TYPES))); + query.filter(new TermsQueryBuilder(TimeSeriesTask.TASK_TYPE_FIELD, taskTypeToString(ADTaskType.REALTIME_TASK_TYPES))); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(1); SearchRequest searchRequest = new SearchRequest(DETECTION_STATE_INDEX).source(searchSourceBuilder); client.search(searchRequest, ActionListener.wrap(r -> { diff --git a/src/main/java/org/opensearch/ad/cluster/ADClusterEventListener.java b/src/main/java/org/opensearch/timeseries/cluster/ClusterEventListener.java similarity index 72% rename from src/main/java/org/opensearch/ad/cluster/ADClusterEventListener.java rename to src/main/java/org/opensearch/timeseries/cluster/ClusterEventListener.java index 4f629c7bb..3712bfb73 100644 --- a/src/main/java/org/opensearch/ad/cluster/ADClusterEventListener.java +++ b/src/main/java/org/opensearch/timeseries/cluster/ClusterEventListener.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.cluster; +package org.opensearch.timeseries.cluster; import java.util.concurrent.Semaphore; @@ -23,18 +23,18 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.gateway.GatewayService; -public class ADClusterEventListener implements ClusterStateListener { - private static final Logger LOG = LogManager.getLogger(ADClusterEventListener.class); - static final String NOT_RECOVERED_MSG = "Cluster is not recovered yet."; - static final String IN_PROGRESS_MSG = "Cluster state change in progress, return."; - static final String NODE_CHANGED_MSG = "Cluster node changed"; +public class ClusterEventListener implements ClusterStateListener { + private static final Logger LOG = LogManager.getLogger(ClusterEventListener.class); + public static final String NOT_RECOVERED_MSG = "Cluster is not recovered yet."; + public static final String IN_PROGRESS_MSG = "Cluster state change in progress, return."; + public static final String NODE_CHANGED_MSG = "Cluster node changed"; private final Semaphore inProgress; private HashRing hashRing; private final ClusterService clusterService; @Inject - public ADClusterEventListener(ClusterService clusterService, HashRing hashRing) { + public ClusterEventListener(ClusterService clusterService, HashRing hashRing) { this.clusterService = clusterService; this.clusterService.addListener(this); this.hashRing = hashRing; @@ -55,16 +55,13 @@ public void clusterChanged(ClusterChangedEvent event) { } try { - // Init AD version hash ring as early as possible. Some test case may fail as AD + // Init version hash ring as early as possible. Some test case may fail as AD // version hash ring not initialized when test run. if (!hashRing.isHashRingInited()) { hashRing .buildCircles( ActionListener - .wrap( - r -> LOG.info("Init AD version hash ring successfully"), - e -> LOG.error("Failed to init AD version hash ring") - ) + .wrap(r -> LOG.info("Init version hash ring successfully"), e -> LOG.error("Failed to init version hash ring")) ); } Delta delta = event.nodesDelta(); @@ -74,7 +71,7 @@ public void clusterChanged(ClusterChangedEvent event) { hashRing.addNodeChangeEvent(); hashRing.buildCircles(delta, ActionListener.runAfter(ActionListener.wrap(hasRingBuildDone -> { LOG.info("Hash ring build result: {}", hasRingBuildDone); - }, e -> { LOG.error("Failed updating AD version hash ring", e); }), () -> inProgress.release())); + }, e -> { LOG.error("Failed updating version hash ring", e); }), () -> inProgress.release())); } else { inProgress.release(); } diff --git a/src/main/java/org/opensearch/ad/cluster/ClusterManagerEventListener.java b/src/main/java/org/opensearch/timeseries/cluster/ClusterManagerEventListener.java similarity index 56% rename from src/main/java/org/opensearch/ad/cluster/ClusterManagerEventListener.java rename to src/main/java/org/opensearch/timeseries/cluster/ClusterManagerEventListener.java index 8b8a40405..e4bd8ae1d 100644 --- a/src/main/java/org/opensearch/ad/cluster/ClusterManagerEventListener.java +++ b/src/main/java/org/opensearch/timeseries/cluster/ClusterManagerEventListener.java @@ -9,14 +9,14 @@ * GitHub history for details. */ -package org.opensearch.ad.cluster; +package org.opensearch.timeseries.cluster; import java.time.Clock; import java.time.Duration; +import java.util.Arrays; +import java.util.List; -import org.opensearch.ad.cluster.diskcleanup.IndexCleanup; -import org.opensearch.ad.cluster.diskcleanup.ModelCheckpointIndexRetention; -import org.opensearch.ad.util.DateUtils; +import org.opensearch.ad.cluster.diskcleanup.ADCheckpointIndexRetention; import org.opensearch.client.Client; import org.opensearch.cluster.LocalNodeClusterManagerListener; import org.opensearch.cluster.service.ClusterService; @@ -24,16 +24,20 @@ import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.forecast.cluster.diskcleanup.ForecastCheckpointIndexRetention; import org.opensearch.threadpool.Scheduler.Cancellable; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.cluster.diskcleanup.IndexCleanup; import org.opensearch.timeseries.util.ClientUtil; +import org.opensearch.timeseries.util.DateUtils; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import com.google.common.annotations.VisibleForTesting; public class ClusterManagerEventListener implements LocalNodeClusterManagerListener { - private Cancellable checkpointIndexRetentionCron; + private Cancellable adCheckpointIndexRetentionCron; + private Cancellable forecastCheckpointIndexRetentionCron; private Cancellable hourlyCron; private ClusterService clusterService; private ThreadPool threadPool; @@ -41,7 +45,8 @@ public class ClusterManagerEventListener implements LocalNodeClusterManagerListe private Clock clock; private ClientUtil clientUtil; private DiscoveryNodeFilterer nodeFilter; - private Duration checkpointTtlDuration; + private Duration adCheckpointTtlDuration; + private Duration forecastCheckpointTtlDuration; public ClusterManagerEventListener( ClusterService clusterService, @@ -50,7 +55,8 @@ public ClusterManagerEventListener( Clock clock, ClientUtil clientUtil, DiscoveryNodeFilterer nodeFilter, - Setting checkpointTtl, + Setting adCheckpointTtl, + Setting forecastCheckpointTtl, Settings settings ) { this.clusterService = clusterService; @@ -61,15 +67,22 @@ public ClusterManagerEventListener( this.clientUtil = clientUtil; this.nodeFilter = nodeFilter; - this.checkpointTtlDuration = DateUtils.toDuration(checkpointTtl.get(settings)); + this.adCheckpointTtlDuration = DateUtils.toDuration(adCheckpointTtl.get(settings)); + this.forecastCheckpointTtlDuration = DateUtils.toDuration(forecastCheckpointTtl.get(settings)); - clusterService.getClusterSettings().addSettingsUpdateConsumer(checkpointTtl, it -> { - this.checkpointTtlDuration = DateUtils.toDuration(it); - cancel(checkpointIndexRetentionCron); + clusterService.getClusterSettings().addSettingsUpdateConsumer(adCheckpointTtl, it -> { + this.adCheckpointTtlDuration = DateUtils.toDuration(it); + cancel(adCheckpointIndexRetentionCron); IndexCleanup indexCleanup = new IndexCleanup(client, clientUtil, clusterService); - checkpointIndexRetentionCron = threadPool + adCheckpointIndexRetentionCron = threadPool .scheduleWithFixedDelay( - new ModelCheckpointIndexRetention(checkpointTtlDuration, clock, indexCleanup), + new ADCheckpointIndexRetention(adCheckpointTtlDuration, clock, indexCleanup), + TimeValue.timeValueHours(24), + executorName() + ); + forecastCheckpointIndexRetentionCron = threadPool + .scheduleWithFixedDelay( + new ForecastCheckpointIndexRetention(forecastCheckpointTtlDuration, clock, indexCleanup), TimeValue.timeValueHours(24), executorName() ); @@ -89,19 +102,27 @@ public void beforeStop() { }); } - if (checkpointIndexRetentionCron == null) { + if (adCheckpointIndexRetentionCron == null) { IndexCleanup indexCleanup = new IndexCleanup(client, clientUtil, clusterService); - checkpointIndexRetentionCron = threadPool + adCheckpointIndexRetentionCron = threadPool + .scheduleWithFixedDelay( + new ADCheckpointIndexRetention(adCheckpointTtlDuration, clock, indexCleanup), + TimeValue.timeValueHours(24), + executorName() + ); + forecastCheckpointIndexRetentionCron = threadPool .scheduleWithFixedDelay( - new ModelCheckpointIndexRetention(checkpointTtlDuration, clock, indexCleanup), + new ForecastCheckpointIndexRetention(forecastCheckpointTtlDuration, clock, indexCleanup), TimeValue.timeValueHours(24), executorName() ); clusterService.addLifecycleListener(new LifecycleListener() { @Override public void beforeStop() { - cancel(checkpointIndexRetentionCron); - checkpointIndexRetentionCron = null; + cancel(adCheckpointIndexRetentionCron); + adCheckpointIndexRetentionCron = null; + cancel(forecastCheckpointIndexRetentionCron); + forecastCheckpointIndexRetentionCron = null; } }); } @@ -110,9 +131,11 @@ public void beforeStop() { @Override public void offClusterManager() { cancel(hourlyCron); - cancel(checkpointIndexRetentionCron); hourlyCron = null; - checkpointIndexRetentionCron = null; + cancel(adCheckpointIndexRetentionCron); + adCheckpointIndexRetentionCron = null; + cancel(forecastCheckpointIndexRetentionCron); + forecastCheckpointIndexRetentionCron = null; } private void cancel(Cancellable cron) { @@ -122,11 +145,11 @@ private void cancel(Cancellable cron) { } @VisibleForTesting - protected Cancellable getCheckpointIndexRetentionCron() { - return checkpointIndexRetentionCron; + public List getCheckpointIndexRetentionCron() { + return Arrays.asList(adCheckpointIndexRetentionCron, forecastCheckpointIndexRetentionCron); } - protected Cancellable getHourlyCron() { + public Cancellable getHourlyCron() { return hourlyCron; } diff --git a/src/main/java/org/opensearch/ad/cluster/DailyCron.java b/src/main/java/org/opensearch/timeseries/cluster/DailyCron.java similarity index 86% rename from src/main/java/org/opensearch/ad/cluster/DailyCron.java rename to src/main/java/org/opensearch/timeseries/cluster/DailyCron.java index 2692608d2..4a7c9a6d5 100644 --- a/src/main/java/org/opensearch/ad/cluster/DailyCron.java +++ b/src/main/java/org/opensearch/timeseries/cluster/DailyCron.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.cluster; +package org.opensearch.timeseries.cluster; import java.time.Clock; import java.time.Duration; @@ -30,9 +30,9 @@ public class DailyCron implements Runnable { private static final Logger LOG = LogManager.getLogger(DailyCron.class); protected static final String FIELD_MODEL = "queue"; - static final String CANNOT_DELETE_OLD_CHECKPOINT_MSG = "Cannot delete old checkpoint."; - static final String CHECKPOINT_NOT_EXIST_MSG = "Checkpoint index does not exist."; - static final String CHECKPOINT_DELETED_MSG = "checkpoint docs get deleted"; + public static final String CANNOT_DELETE_OLD_CHECKPOINT_MSG = "Cannot delete old checkpoint."; + public static final String CHECKPOINT_NOT_EXIST_MSG = "Checkpoint index does not exist."; + public static final String CHECKPOINT_DELETED_MSG = "checkpoint docs get deleted"; private final Clock clock; private final Duration checkpointTtl; @@ -54,7 +54,7 @@ public void run() { QueryBuilders .rangeQuery(CommonName.TIMESTAMP) .lte(clock.millis() - checkpointTtl.toMillis()) - .format(ADCommonName.EPOCH_MILLIS_FORMAT) + .format(CommonName.EPOCH_MILLIS_FORMAT) ) ) .setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN); diff --git a/src/main/java/org/opensearch/ad/cluster/HashRing.java b/src/main/java/org/opensearch/timeseries/cluster/HashRing.java similarity index 76% rename from src/main/java/org/opensearch/ad/cluster/HashRing.java rename to src/main/java/org/opensearch/timeseries/cluster/HashRing.java index 30ea1724f..759d70113 100644 --- a/src/main/java/org/opensearch/ad/cluster/HashRing.java +++ b/src/main/java/org/opensearch/timeseries/cluster/HashRing.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.cluster; +package org.opensearch.timeseries.cluster; import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_COOLDOWN_MINUTES; @@ -35,7 +35,7 @@ import org.opensearch.action.admin.cluster.node.info.NodeInfo; import org.opensearch.action.admin.cluster.node.info.NodesInfoRequest; import org.opensearch.action.admin.cluster.node.info.PluginsAndModules; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; import org.opensearch.client.ClusterAdminClient; @@ -69,11 +69,11 @@ public class HashRing { // Semaphore to control only 1 thread can build AD hash ring. private Semaphore buildHashRingSemaphore; - // This field is to track AD version of all nodes. - // Key: node id; Value: AD node info - private Map nodeAdVersions; - // This field records AD version hash ring in realtime way. Historical detection will use this hash ring. - // Key: AD version; Value: hash ring which only contains eligible data nodes + // This field is to track time series plugin version of all nodes. + // Key: node id; Value: node info + private Map nodeVersions; + // This field records time series version hash ring in realtime way. Historical detection will use this hash ring. + // Key: time series version; Value: hash ring which only contains eligible data nodes private TreeMap> circles; // Track if hash ring inited or not. If not inited, the first clusterManager event will try to init it. private AtomicBoolean hashRingInited; @@ -82,8 +82,8 @@ public class HashRing { private long lastUpdateForRealtimeAD; // Cool down period before next hash ring rebuild. We need this as realtime AD needs stable hash ring. private volatile TimeValue coolDownPeriodForRealtimeAD; - // This field records AD version hash ring with cooldown period. Realtime job will use this hash ring. - // Key: AD version; Value: hash ring which only contains eligible data nodes + // This field records time series version hash ring with cooldown period. Realtime job will use this hash ring. + // Key: time series version; Value: hash ring which only contains eligible data nodes private TreeMap> circlesForRealtimeAD; // Record node change event. Will check if there is node change event when rebuild AD hash ring with @@ -95,7 +95,7 @@ public class HashRing { private final ADDataMigrator dataMigrator; private final Clock clock; private final Client client; - private final ModelManager modelManager; + private final ADModelManager modelManager; public HashRing( DiscoveryNodeFilterer nodeFilter, @@ -104,7 +104,7 @@ public HashRing( Client client, ClusterService clusterService, ADDataMigrator dataMigrator, - ModelManager modelManager + ADModelManager modelManager ) { this.nodeFilter = nodeFilter; this.buildHashRingSemaphore = new Semaphore(1); @@ -116,7 +116,7 @@ public HashRing( this.client = client; this.clusterService = clusterService; this.dataMigrator = dataMigrator; - this.nodeAdVersions = new ConcurrentHashMap<>(); + this.nodeVersions = new ConcurrentHashMap<>(); this.circles = new TreeMap<>(); this.circlesForRealtimeAD = new TreeMap<>(); this.hashRingInited = new AtomicBoolean(false); @@ -129,17 +129,17 @@ public boolean isHashRingInited() { } /** - * Build AD version based circles with discovery node delta change. Listen to clusterManager event in - * {@link ADClusterEventListener#clusterChanged(ClusterChangedEvent)}. + * Build version based circles with discovery node delta change. Listen to clusterManager event in + * {@link ClusterEventListener#clusterChanged(ClusterChangedEvent)}. * Will remove the removed nodes from cache and send request to newly added nodes to get their - * plugin information; then add new nodes to AD version hash ring. + * plugin information; then add new nodes to version hash ring. * * @param delta discovery node delta change * @param listener action listener */ public void buildCircles(DiscoveryNodes.Delta delta, ActionListener listener) { if (!buildHashRingSemaphore.tryAcquire()) { - LOG.info("AD version hash ring change is in progress. Can't build hash ring for node delta event."); + LOG.info("hash ring change is in progress. Can't build hash ring for node delta event."); listener.onResponse(false); return; } @@ -151,14 +151,14 @@ public void buildCircles(DiscoveryNodes.Delta delta, ActionListener lis } /** - * Build AD version based circles by comparing with all eligible data nodes. + * Build version based circles by comparing with all eligible data nodes. * 1. Remove nodes which are not eligible now; - * 2. Add nodes which are not in AD version circles. + * 2. Add nodes which are not in version circles. * @param actionListener action listener */ public void buildCircles(ActionListener actionListener) { if (!buildHashRingSemaphore.tryAcquire()) { - LOG.info("AD version hash ring change is in progress. Can't rebuild hash ring."); + LOG.info("hash ring change is in progress. Can't rebuild hash ring."); actionListener.onResponse(false); return; } @@ -167,35 +167,35 @@ public void buildCircles(ActionListener actionListener) { for (DiscoveryNode node : allNodes) { nodeIds.add(node.getId()); } - Set currentNodeIds = nodeAdVersions.keySet(); + Set currentNodeIds = nodeVersions.keySet(); Set removedNodeIds = Sets.difference(currentNodeIds, nodeIds); Set addedNodeIds = Sets.difference(nodeIds, currentNodeIds); buildCircles(removedNodeIds, addedNodeIds, actionListener); } - public void buildCirclesForRealtimeAD() { + public void buildCirclesForRealtime() { if (nodeChangeEvents.isEmpty()) { return; } - buildCircles(ActionListener.wrap(r -> { LOG.debug("build circles on AD versions successfully"); }, e -> { - LOG.error("Failed to build circles on AD versions", e); - })); + buildCircles( + ActionListener.wrap(r -> { LOG.debug("build circles successfully"); }, e -> { LOG.error("Failed to build circles", e); }) + ); } /** - * Build AD version hash ring. - * 1. Delete removed nodes from AD version hash ring. - * 2. Add new nodes to AD version hash ring + * Build version hash ring. + * 1. Delete removed nodes from version hash ring. + * 2. Add new nodes to version hash ring * - * If fail to acquire semaphore to update AD version hash ring, will return false to + * If fail to acquire semaphore to update version hash ring, will return false to * action listener; otherwise will return true. The "true" response just mean we got * semaphore and finished rebuilding hash ring, but the hash ring may stay the same. * Hash ring changed or not depends on if "removedNodeIds" or "addedNodeIds" is empty. * * We use different way to build hash ring for realtime job and historical analysis - * 1. For historical analysis,if node removed, we remove it immediately from adVersionCircles - * to avoid new AD task routes to it. If new node added, we add it immediately to adVersionCircles - * to make load more balanced and speed up AD task running. + * 1. For historical analysis,if node removed, we remove it immediately from version circles + * to avoid new task routes to it. If new node added, we add it immediately to version circles + * to make load more balanced and speed up task running. * 2. For realtime job, we don't record which node running detector's model partition. We just * use hash ring to get owning node. If we rebuild hash ring frequently, realtime job may get * different owning node and need to restore model on new owning node. If that happens a lot, @@ -205,7 +205,7 @@ public void buildCirclesForRealtimeAD() { * and still send RCF request to it. If new node added during cooldown period, realtime job won't * choose it as model partition owning node, thus we may have skewed load on data nodes. * - * [Important!]: When you call this function, make sure you TRY ACQUIRE adVersionCircleInProgress first. + * [Important!]: When you call this function, make sure you TRY ACQUIRE buildHashRingSemaphore first. * Check {@link HashRing#buildCircles(ActionListener)} and * {@link HashRing#buildCircles(DiscoveryNodes.Delta, ActionListener)} * @@ -222,10 +222,10 @@ private void buildCircles(Set removedNodeIds, Set addedNodeIds, if (removedNodeIds != null && removedNodeIds.size() > 0) { LOG.info("Node removed: {}", Arrays.toString(removedNodeIds.toArray(new String[0]))); for (String nodeId : removedNodeIds) { - ADNodeInfo nodeInfo = nodeAdVersions.remove(nodeId); + TimeSeriesNodeInfo nodeInfo = nodeVersions.remove(nodeId); if (nodeInfo != null && nodeInfo.isEligibleDataNode()) { - removeNodeFromCircles(nodeId, nodeInfo.getAdVersion()); - LOG.info("Remove data node from AD version hash ring: {}", nodeId); + removeNodeFromCircles(nodeId, nodeInfo.getVersion()); + LOG.info("Remove data node from version hash ring: {}", nodeId); } } } @@ -234,12 +234,12 @@ private void buildCircles(Set removedNodeIds, Set addedNodeIds, if (addedNodeIds != null) { allAddedNodes.addAll(addedNodeIds); } - if (!nodeAdVersions.containsKey(localNode.getId())) { + if (!nodeVersions.containsKey(localNode.getId())) { allAddedNodes.add(localNode.getId()); } if (allAddedNodes.size() == 0) { actionListener.onResponse(true); - // rebuild AD version hash ring with cooldown. + // rebuild version hash ring with cooldown. rebuildCirclesForRealtimeAD(); buildHashRingSemaphore.release(); return; @@ -264,15 +264,16 @@ private void buildCircles(Set removedNodeIds, Set addedNodeIds, } TreeMap circle = null; for (PluginInfo pluginInfo : plugins.getPluginInfos()) { + // if (AD_PLUGIN_NAME.equals(pluginInfo.getName()) || AD_PLUGIN_NAME_FOR_TEST.equals(pluginInfo.getName())) { if (CommonName.TIME_SERIES_PLUGIN_NAME.equals(pluginInfo.getName()) || CommonName.TIME_SERIES_PLUGIN_NAME_FOR_TEST.equals(pluginInfo.getName())) { - Version version = ADVersionUtil.fromString(pluginInfo.getVersion()); + Version version = VersionUtil.fromString(pluginInfo.getVersion()); boolean eligibleNode = nodeFilter.isEligibleNode(curNode); if (eligibleNode) { circle = circles.computeIfAbsent(version, key -> new TreeMap<>()); - LOG.info("Add data node to AD version hash ring: {}", curNode.getId()); + LOG.info("Add data node to version hash ring: {}", curNode.getId()); } - nodeAdVersions.put(curNode.getId(), new ADNodeInfo(version, eligibleNode)); + nodeVersions.put(curNode.getId(), new TimeSeriesNodeInfo(version, eligibleNode)); break; } } @@ -283,15 +284,15 @@ private void buildCircles(Set removedNodeIds, Set addedNodeIds, } } } - LOG.info("All nodes with known AD version: {}", nodeAdVersions); + LOG.info("All nodes with known version: {}", nodeVersions); - // rebuild AD version hash ring with cooldown after all new node added. + // rebuild version hash ring with cooldown after all new node added. rebuildCirclesForRealtimeAD(); if (!dataMigrator.isMigrated() && circles.size() > 0) { - // Find owning node with highest AD version to make sure the data migration logic be compatible to - // latest AD version when upgrade. - Optional owningNode = getOwningNodeWithHighestAdVersion(DEFAULT_HASH_RING_MODEL_ID); + // Find owning node with highest version to make sure the data migration logic be compatible to + // latest version when upgrade. + Optional owningNode = getOwningNodeWithHighestVersion(DEFAULT_HASH_RING_MODEL_ID); String localNodeId = localNode.getId(); if (owningNode.isPresent() && localNodeId.equals(owningNode.get().getId())) { dataMigrator.migrateData(); @@ -305,18 +306,18 @@ private void buildCircles(Set removedNodeIds, Set addedNodeIds, }, e -> { buildHashRingSemaphore.release(); actionListener.onFailure(e); - LOG.error("Fail to get node info to build AD version hash ring", e); + LOG.error("Fail to get node info to build hash ring", e); })); } catch (Exception e) { - LOG.error("Failed to build AD version circles", e); + LOG.error("Failed to build circles", e); buildHashRingSemaphore.release(); actionListener.onFailure(e); } } - private void removeNodeFromCircles(String nodeId, Version adVersion) { - if (adVersion != null) { - TreeMap circle = this.circles.get(adVersion); + private void removeNodeFromCircles(String nodeId, Version version) { + if (version != null) { + TreeMap circle = this.circles.get(version); List deleted = new ArrayList<>(); for (Map.Entry entry : circle.entrySet()) { if (entry.getValue().getId().equals(nodeId)) { @@ -324,7 +325,7 @@ private void removeNodeFromCircles(String nodeId, Version adVersion) { } } if (deleted.size() == circle.size()) { - circles.remove(adVersion); + circles.remove(version); } else { for (Integer key : deleted) { circle.remove(key); @@ -336,7 +337,7 @@ private void removeNodeFromCircles(String nodeId, Version adVersion) { private void rebuildCirclesForRealtimeAD() { // Check if it's eligible to rebuild hash ring with cooldown if (eligibleToRebuildCirclesForRealtimeAD()) { - LOG.info("Rebuild AD hash ring for realtime AD with cooldown, nodeChangeEvents size {}", nodeChangeEvents.size()); + LOG.info("Rebuild hash ring for realtime with cooldown, nodeChangeEvents size {}", nodeChangeEvents.size()); int size = nodeChangeEvents.size(); TreeMap> newCircles = new TreeMap<>(); for (Map.Entry> entry : circles.entrySet()) { @@ -344,17 +345,17 @@ private void rebuildCirclesForRealtimeAD() { } circlesForRealtimeAD = newCircles; lastUpdateForRealtimeAD = clock.millis(); - LOG.info("Build AD version hash ring successfully"); + LOG.info("Build version hash ring successfully"); String localNodeId = clusterService.localNode().getId(); Set modelIds = modelManager.getAllModelIds(); for (String modelId : modelIds) { - Optional node = getOwningNodeWithSameLocalAdVersionForRealtimeAD(modelId); + Optional node = getOwningNodeWithSameLocalVersionForRealtime(modelId); if (node.isPresent() && !node.get().getId().equals(localNodeId)) { LOG.info(REMOVE_MODEL_MSG + " {}", modelId); modelManager .stopModel( // stopModel will clear model cache - SingleStreamModelIdMapper.getDetectorIdForModelId(modelId), + SingleStreamModelIdMapper.getConfigIdForModelId(modelId), modelId, ActionListener .wrap( @@ -366,7 +367,7 @@ private void rebuildCirclesForRealtimeAD() { } // It's possible that multiple threads add new event to nodeChangeEvents, // but this is the only place to consume/poll the event and there is only - // one thread poll it as we are using adVersionCircleInProgress semaphore(1) + // one thread poll it as we are using buildHashRingSemaphore // to control only 1 thread build hash ring. while (size-- > 0) { Boolean poll = nodeChangeEvents.poll(); @@ -397,7 +398,7 @@ private void rebuildCirclesForRealtimeAD() { * * @return true if it's eligible to rebuild hash ring */ - protected boolean eligibleToRebuildCirclesForRealtimeAD() { + public boolean eligibleToRebuildCirclesForRealtimeAD() { // Check if there is any node change event if (nodeChangeEvents.isEmpty() && !circlesForRealtimeAD.isEmpty()) { return false; @@ -412,71 +413,71 @@ protected boolean eligibleToRebuildCirclesForRealtimeAD() { } /** - * Get owning node with highest AD version circle. + * Get owning node with highest version circle. * @param modelId model id * @return owning node */ - public Optional getOwningNodeWithHighestAdVersion(String modelId) { + public Optional getOwningNodeWithHighestVersion(String modelId) { int modelHash = Murmur3HashFunction.hash(modelId); Map.Entry> versionTreeMapEntry = circles.lastEntry(); if (versionTreeMapEntry == null) { return Optional.empty(); } - TreeMap adVersionCircle = versionTreeMapEntry.getValue(); - Map.Entry entry = adVersionCircle.higherEntry(modelHash); - return Optional.ofNullable(Optional.ofNullable(entry).orElse(adVersionCircle.firstEntry())).map(x -> x.getValue()); + TreeMap versionCircle = versionTreeMapEntry.getValue(); + Map.Entry entry = versionCircle.higherEntry(modelHash); + return Optional.ofNullable(Optional.ofNullable(entry).orElse(versionCircle.firstEntry())).map(x -> x.getValue()); } /** - * Get owning node with same AD version of local node. + * Get owning node with same version of local node. * @param modelId model id * @param function consumer function * @param listener action listener * @param listener response type */ - public void buildAndGetOwningNodeWithSameLocalAdVersion( + public void buildAndGetOwningNodeWithSameLocalVersion( String modelId, Consumer> function, ActionListener listener ) { buildCircles(ActionListener.wrap(r -> { DiscoveryNode localNode = clusterService.localNode(); - Version adVersion = nodeAdVersions.containsKey(localNode.getId()) ? getAdVersion(localNode.getId()) : Version.CURRENT; - Optional owningNode = getOwningNodeWithSameAdVersionDirectly(modelId, adVersion, false); + Version version = nodeVersions.containsKey(localNode.getId()) ? getVersion(localNode.getId()) : Version.CURRENT; + Optional owningNode = getOwningNodeWithSameVersionDirectly(modelId, version, false); function.accept(owningNode); }, e -> listener.onFailure(e))); } - public Optional getOwningNodeWithSameLocalAdVersionForRealtimeAD(String modelId) { + public Optional getOwningNodeWithSameLocalVersionForRealtime(String modelId) { try { DiscoveryNode localNode = clusterService.localNode(); - Version adVersion = nodeAdVersions.containsKey(localNode.getId()) ? getAdVersion(localNode.getId()) : Version.CURRENT; - Optional owningNode = getOwningNodeWithSameAdVersionDirectly(modelId, adVersion, true); + Version version = nodeVersions.containsKey(localNode.getId()) ? getVersion(localNode.getId()) : Version.CURRENT; + Optional owningNode = getOwningNodeWithSameVersionDirectly(modelId, version, true); // rebuild hash ring - buildCirclesForRealtimeAD(); + buildCirclesForRealtime(); return owningNode; } catch (Exception e) { - LOG.error("Failed to get owning node with same local AD version", e); + LOG.error("Failed to get owning node with same local time series version", e); return Optional.empty(); } } - private Optional getOwningNodeWithSameAdVersionDirectly(String modelId, Version adVersion, boolean forRealtime) { + private Optional getOwningNodeWithSameVersionDirectly(String modelId, Version version, boolean forRealtime) { int modelHash = Murmur3HashFunction.hash(modelId); - TreeMap adVersionCircle = forRealtime ? circlesForRealtimeAD.get(adVersion) : circles.get(adVersion); - if (adVersionCircle != null) { - Map.Entry entry = adVersionCircle.higherEntry(modelHash); - return Optional.ofNullable(Optional.ofNullable(entry).orElse(adVersionCircle.firstEntry())).map(x -> x.getValue()); + TreeMap versionCircle = forRealtime ? circlesForRealtimeAD.get(version) : circles.get(version); + if (versionCircle != null) { + Map.Entry entry = versionCircle.higherEntry(modelHash); + return Optional.ofNullable(Optional.ofNullable(entry).orElse(versionCircle.firstEntry())).map(x -> x.getValue()); } return Optional.empty(); } - public void getNodesWithSameLocalAdVersion(Consumer function, ActionListener listener) { + public void getNodesWithSameLocalVersion(Consumer function, ActionListener listener) { buildCircles(ActionListener.wrap(updated -> { DiscoveryNode localNode = clusterService.localNode(); - Version adVersion = nodeAdVersions.containsKey(localNode.getId()) ? getAdVersion(localNode.getId()) : Version.CURRENT; - Set nodes = getNodesWithSameAdVersion(adVersion, false); - if (!nodeAdVersions.containsKey(localNode.getId())) { + Version version = nodeVersions.containsKey(localNode.getId()) ? getVersion(localNode.getId()) : Version.CURRENT; + Set nodes = getNodesWithSameVersion(version, false); + if (!nodeVersions.containsKey(localNode.getId())) { nodes.add(localNode); } // Make sure listener return in function @@ -484,17 +485,17 @@ public void getNodesWithSameLocalAdVersion(Consumer functio }, e -> listener.onFailure(e))); } - public DiscoveryNode[] getNodesWithSameLocalAdVersion() { + public DiscoveryNode[] getNodesWithSameLocalVersion() { DiscoveryNode localNode = clusterService.localNode(); - Version adVersion = nodeAdVersions.containsKey(localNode.getId()) ? getAdVersion(localNode.getId()) : Version.CURRENT; - Set nodes = getNodesWithSameAdVersion(adVersion, false); + Version version = nodeVersions.containsKey(localNode.getId()) ? getVersion(localNode.getId()) : Version.CURRENT; + Set nodes = getNodesWithSameVersion(version, false); // rebuild hash ring - buildCirclesForRealtimeAD(); + buildCirclesForRealtime(); return nodes.toArray(new DiscoveryNode[0]); } - protected Set getNodesWithSameAdVersion(Version adVersion, boolean forRealtime) { - TreeMap circle = forRealtime ? circlesForRealtimeAD.get(adVersion) : circles.get(adVersion); + public Set getNodesWithSameVersion(Version version, boolean forRealtime) { + TreeMap circle = forRealtime ? circlesForRealtimeAD.get(version) : circles.get(version); Set nodeIds = new HashSet<>(); Set nodes = new HashSet<>(); if (circle == null) { @@ -511,13 +512,13 @@ protected Set getNodesWithSameAdVersion(Version adVersion, boolea } /** - * Get AD version. + * Get time series version. * @param nodeId node id - * @return AD version + * @return version */ - public Version getAdVersion(String nodeId) { - ADNodeInfo adNodeInfo = nodeAdVersions.get(nodeId); - return adNodeInfo == null ? null : adNodeInfo.getAdVersion(); + public Version getVersion(String nodeId) { + TimeSeriesNodeInfo nodeInfo = nodeVersions.get(nodeId); + return nodeInfo == null ? null : nodeInfo.getVersion(); } /** @@ -561,17 +562,17 @@ private String getIpAddress(TransportAddress address) { } /** - * Get all eligible data nodes whose AD versions are known in AD version based hash ring. + * Get all eligible data nodes whose time series versions are known in hash ring. * @param function consumer function * @param listener action listener * @param action listener response type */ - public void getAllEligibleDataNodesWithKnownAdVersion(Consumer function, ActionListener listener) { + public void getAllEligibleDataNodesWithKnownVersion(Consumer function, ActionListener listener) { buildCircles(ActionListener.wrap(r -> { DiscoveryNode[] eligibleDataNodes = nodeFilter.getEligibleDataNodes(); List allNodes = new ArrayList<>(); for (DiscoveryNode node : eligibleDataNodes) { - if (nodeAdVersions.containsKey(node.getId())) { + if (nodeVersions.containsKey(node.getId())) { allNodes.add(node); } } diff --git a/src/main/java/org/opensearch/ad/cluster/HourlyCron.java b/src/main/java/org/opensearch/timeseries/cluster/HourlyCron.java similarity index 84% rename from src/main/java/org/opensearch/ad/cluster/HourlyCron.java rename to src/main/java/org/opensearch/timeseries/cluster/HourlyCron.java index 687aace69..4381ade8d 100644 --- a/src/main/java/org/opensearch/ad/cluster/HourlyCron.java +++ b/src/main/java/org/opensearch/timeseries/cluster/HourlyCron.java @@ -9,23 +9,23 @@ * GitHub history for details. */ -package org.opensearch.ad.cluster; +package org.opensearch.timeseries.cluster; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.action.FailedNodeException; import org.opensearch.ad.transport.CronAction; -import org.opensearch.ad.transport.CronRequest; import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.action.ActionListener; +import org.opensearch.timeseries.transport.CronRequest; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; public class HourlyCron implements Runnable { private static final Logger LOG = LogManager.getLogger(HourlyCron.class); - static final String SUCCEEDS_LOG_MSG = "Hourly maintenance succeeds"; - static final String NODE_EXCEPTION_LOG_MSG = "Hourly maintenance of node has exception"; - static final String EXCEPTION_LOG_MSG = "Hourly maintenance has exception."; + public static final String SUCCEEDS_LOG_MSG = "Hourly maintenance succeeds"; + public static final String NODE_EXCEPTION_LOG_MSG = "Hourly maintenance of node has exception"; + public static final String EXCEPTION_LOG_MSG = "Hourly maintenance has exception."; private DiscoveryNodeFilterer nodeFilter; private Client client; diff --git a/src/main/java/org/opensearch/ad/cluster/ADNodeInfo.java b/src/main/java/org/opensearch/timeseries/cluster/TimeSeriesNodeInfo.java similarity index 70% rename from src/main/java/org/opensearch/ad/cluster/ADNodeInfo.java rename to src/main/java/org/opensearch/timeseries/cluster/TimeSeriesNodeInfo.java index e438623d5..f67d663ae 100644 --- a/src/main/java/org/opensearch/ad/cluster/ADNodeInfo.java +++ b/src/main/java/org/opensearch/timeseries/cluster/TimeSeriesNodeInfo.java @@ -9,25 +9,25 @@ * GitHub history for details. */ -package org.opensearch.ad.cluster; +package org.opensearch.timeseries.cluster; import org.opensearch.Version; /** - * This class records AD version of nodes and whether node is eligible data node to run AD. + * This class records time series plugin version of nodes and whether node is eligible data node to run time series analysis. */ -public class ADNodeInfo { - // AD plugin version +public class TimeSeriesNodeInfo { + // time series plugin version private Version adVersion; // Is node eligible to run AD. private boolean isEligibleDataNode; - public ADNodeInfo(Version version, boolean isEligibleDataNode) { + public TimeSeriesNodeInfo(Version version, boolean isEligibleDataNode) { this.adVersion = version; this.isEligibleDataNode = isEligibleDataNode; } - public Version getAdVersion() { + public Version getVersion() { return adVersion; } diff --git a/src/main/java/org/opensearch/ad/cluster/ADVersionUtil.java b/src/main/java/org/opensearch/timeseries/cluster/VersionUtil.java similarity index 95% rename from src/main/java/org/opensearch/ad/cluster/ADVersionUtil.java rename to src/main/java/org/opensearch/timeseries/cluster/VersionUtil.java index 7e880de66..8d506732d 100644 --- a/src/main/java/org/opensearch/ad/cluster/ADVersionUtil.java +++ b/src/main/java/org/opensearch/timeseries/cluster/VersionUtil.java @@ -9,12 +9,12 @@ * GitHub history for details. */ -package org.opensearch.ad.cluster; +package org.opensearch.timeseries.cluster; import org.opensearch.Version; import org.opensearch.timeseries.constant.CommonName; -public class ADVersionUtil { +public class VersionUtil { public static final int VERSION_SEGMENTS = 3; diff --git a/src/main/java/org/opensearch/ad/cluster/diskcleanup/ModelCheckpointIndexRetention.java b/src/main/java/org/opensearch/timeseries/cluster/diskcleanup/BaseModelCheckpointIndexRetention.java similarity index 86% rename from src/main/java/org/opensearch/ad/cluster/diskcleanup/ModelCheckpointIndexRetention.java rename to src/main/java/org/opensearch/timeseries/cluster/diskcleanup/BaseModelCheckpointIndexRetention.java index 28fc05e37..6885d611e 100644 --- a/src/main/java/org/opensearch/ad/cluster/diskcleanup/ModelCheckpointIndexRetention.java +++ b/src/main/java/org/opensearch/timeseries/cluster/diskcleanup/BaseModelCheckpointIndexRetention.java @@ -9,14 +9,13 @@ * GitHub history for details. */ -package org.opensearch.ad.cluster.diskcleanup; +package org.opensearch.timeseries.cluster.diskcleanup; import java.time.Clock; import java.time.Duration; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.core.action.ActionListener; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.QueryBuilders; @@ -34,8 +33,8 @@ * We will keep the this logic, and add new clean up way based on shard size. *

*/ -public class ModelCheckpointIndexRetention implements Runnable { - private static final Logger LOG = LogManager.getLogger(ModelCheckpointIndexRetention.class); +public class BaseModelCheckpointIndexRetention implements Runnable { + private static final Logger LOG = LogManager.getLogger(BaseModelCheckpointIndexRetention.class); // The recommended max shard size is 50G, we don't wanna our index exceeds this number private static final long MAX_SHARD_SIZE_IN_BYTE = 50 * 1024 * 1024 * 1024L; @@ -46,25 +45,32 @@ public class ModelCheckpointIndexRetention implements Runnable { private final Duration defaultCheckpointTtl; private final Clock clock; private final IndexCleanup indexCleanup; + private final String checkpointIndexName; - public ModelCheckpointIndexRetention(Duration defaultCheckpointTtl, Clock clock, IndexCleanup indexCleanup) { + public BaseModelCheckpointIndexRetention( + Duration defaultCheckpointTtl, + Clock clock, + IndexCleanup indexCleanup, + String checkpointIndexName + ) { this.defaultCheckpointTtl = defaultCheckpointTtl; this.clock = clock; this.indexCleanup = indexCleanup; + this.checkpointIndexName = checkpointIndexName; } @Override public void run() { indexCleanup .deleteDocsByQuery( - ADCommonName.CHECKPOINT_INDEX_NAME, + checkpointIndexName, QueryBuilders .boolQuery() .filter( QueryBuilders .rangeQuery(CommonName.TIMESTAMP) .lte(clock.millis() - defaultCheckpointTtl.toMillis()) - .format(ADCommonName.EPOCH_MILLIS_FORMAT) + .format(CommonName.EPOCH_MILLIS_FORMAT) ), ActionListener.wrap(response -> { cleanupBasedOnShardSize(defaultCheckpointTtl.minusDays(1)); @@ -79,7 +85,7 @@ public void run() { private void cleanupBasedOnShardSize(Duration cleanUpTtl) { indexCleanup .deleteDocsBasedOnShardSize( - ADCommonName.CHECKPOINT_INDEX_NAME, + checkpointIndexName, MAX_SHARD_SIZE_IN_BYTE, QueryBuilders .boolQuery() @@ -87,7 +93,7 @@ private void cleanupBasedOnShardSize(Duration cleanUpTtl) { QueryBuilders .rangeQuery(CommonName.TIMESTAMP) .lte(clock.millis() - cleanUpTtl.toMillis()) - .format(ADCommonName.EPOCH_MILLIS_FORMAT) + .format(CommonName.EPOCH_MILLIS_FORMAT) ), ActionListener.wrap(cleanupNeeded -> { if (cleanupNeeded) { diff --git a/src/main/java/org/opensearch/ad/cluster/diskcleanup/IndexCleanup.java b/src/main/java/org/opensearch/timeseries/cluster/diskcleanup/IndexCleanup.java similarity index 98% rename from src/main/java/org/opensearch/ad/cluster/diskcleanup/IndexCleanup.java rename to src/main/java/org/opensearch/timeseries/cluster/diskcleanup/IndexCleanup.java index bd37127cb..899f41a73 100644 --- a/src/main/java/org/opensearch/ad/cluster/diskcleanup/IndexCleanup.java +++ b/src/main/java/org/opensearch/timeseries/cluster/diskcleanup/IndexCleanup.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.cluster.diskcleanup; +package org.opensearch.timeseries.cluster.diskcleanup; import java.util.Arrays; import java.util.Objects; diff --git a/src/main/java/org/opensearch/timeseries/constant/CommonMessages.java b/src/main/java/org/opensearch/timeseries/constant/CommonMessages.java index 0576f9693..996ea4dc0 100644 --- a/src/main/java/org/opensearch/timeseries/constant/CommonMessages.java +++ b/src/main/java/org/opensearch/timeseries/constant/CommonMessages.java @@ -13,6 +13,8 @@ import java.util.Locale; +import org.opensearch.timeseries.settings.TimeSeriesSettings; + public class CommonMessages { // ====================================== // Validation message @@ -65,6 +67,11 @@ public static String getTooManyCategoricalFieldErr(int limit) { public static String FEATURE_QUERY_TOO_SPARSE = "Data is most likely too sparse when given feature queries are applied. Consider revising feature queries."; public static String TIMEOUT_ON_INTERVAL_REC = "Timed out getting interval recommendation"; + public static final String NOT_EXISTENT_VALIDATION_TYPE = "The given validation type doesn't exist"; + public static final String NOT_EXISTENT_SUGGEST_TYPE = "The given suggest type doesn't exist"; + public static final String DESCRIPTION_LENGTH_TOO_LONG = "Description length is too long. Max length is " + + TimeSeriesSettings.MAX_DESCRIPTION_LENGTH + + " characters."; // ====================================== // Index message @@ -77,7 +84,7 @@ public static String getTooManyCategoricalFieldErr(int limit) { // Resource constraints // ====================================== public static final String MEMORY_CIRCUIT_BROKEN_ERR_MSG = - "The total OpenSearch memory usage exceeds our threshold, opening the AD memory circuit."; + "The total OpenSearch memory usage exceeds our threshold, opening the memory circuit."; // ====================================== // Transport @@ -139,4 +146,9 @@ public static String getTooManyCategoricalFieldErr(int limit) { // Stats API // ====================================== public static String FAIL_TO_GET_STATS = "Fail to get stats"; + + // ====================================== + // Suggest API + // ====================================== + public static String FAIL_SUGGEST_ERR_MSG = "Fail to suggest parameters for "; } diff --git a/src/main/java/org/opensearch/timeseries/constant/CommonName.java b/src/main/java/org/opensearch/timeseries/constant/CommonName.java index 0b997ea5d..d348ff13e 100644 --- a/src/main/java/org/opensearch/timeseries/constant/CommonName.java +++ b/src/main/java/org/opensearch/timeseries/constant/CommonName.java @@ -11,6 +11,8 @@ package org.opensearch.timeseries.constant; +import org.opensearch.timeseries.stats.StatNames; + public class CommonName { // ====================================== @@ -62,7 +64,6 @@ public class CommonName { public static final String EXECUTION_START_TIME_FIELD = "execution_start_time"; public static final String EXECUTION_END_TIME_FIELD = "execution_end_time"; public static final String ERROR_FIELD = "error"; - public static final String ENTITY_FIELD = "entity"; public static final String USER_FIELD = "user"; public static final String CONFIDENCE_FIELD = "confidence"; public static final String DATA_QUALITY_FIELD = "data_quality"; @@ -70,12 +71,15 @@ public class CommonName { public static final String MODEL_ID_FIELD = "model_id"; public static final String TIMESTAMP = "timestamp"; public static final String FIELD_MODEL = "model"; + public static final String ANALYSIS_TYPE_FIELD = "analysis_type"; + public static final String ANSWER_FIELD = "answer"; + public static final String RUN_ONCE_FIELD = "run_once"; // entity sample in checkpoint. // kept for bwc purpose public static final String ENTITY_SAMPLE = "sp"; // current key for entity samples - public static final String ENTITY_SAMPLE_QUEUE = "samples"; + public static final String SAMPLE_QUEUE = "samples"; // ====================================== // Profile name @@ -105,6 +109,7 @@ public class CommonName { public static final String CONFIG_ID_KEY = "config_id"; public static final String MODEL_ID_KEY = "model_id"; public static final String TASK_ID_FIELD = "task_id"; + public static final String TASK = "task"; public static final String ENTITY_ID_FIELD = "entity_id"; // ====================================== @@ -113,4 +118,36 @@ public class CommonName { public static final String TIME_SERIES_PLUGIN_NAME = "opensearch-time-series-analytics"; public static final String TIME_SERIES_PLUGIN_NAME_FOR_TEST = "org.opensearch.timeseries.TimeSeriesAnalyticsPlugin"; public static final String TIME_SERIES_PLUGIN_VERSION_FOR_TEST = "NA"; + + // ====================================== + // Profile name + // ====================================== + public static final String CATEGORICAL_FIELD = "category_field"; + public static final String STATE = "state"; + public static final String ERROR = "error"; + public static final String COORDINATING_NODE = "coordinating_node"; + public static final String SHINGLE_SIZE = "shingle_size"; + public static final String TOTAL_SIZE_IN_BYTES = "total_size_in_bytes"; + public static final String MODELS = "models"; + public static final String MODEL = "model"; + public static final String INIT_PROGRESS = "init_progress"; + public static final String TOTAL_ENTITIES = "total_entities"; + public static final String ACTIVE_ENTITIES = "active_entities"; + public static final String ENTITY_INFO = "entity_info"; + public static final String TOTAL_UPDATES = "total_updates"; + public static final String MODEL_COUNT = StatNames.MODEL_COUNT.getName(); + + // ====================================== + // Ultrawarm node attributes + // ====================================== + // hot node + public static String HOT_BOX_TYPE = "hot"; + // warm node + public static String WARM_BOX_TYPE = "warm"; + // box type + public static final String BOX_TYPE_KEY = "box_type"; + // ====================================== + // Format name + // ====================================== + public static final String EPOCH_MILLIS_FORMAT = "epoch_millis"; } diff --git a/src/main/java/org/opensearch/timeseries/dataprocessor/FixedValueImputer.java b/src/main/java/org/opensearch/timeseries/dataprocessor/FixedValueImputer.java deleted file mode 100644 index 9b8f6bf21..000000000 --- a/src/main/java/org/opensearch/timeseries/dataprocessor/FixedValueImputer.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.timeseries.dataprocessor; - -import java.util.Arrays; - -/** - * fixing missing value (denoted using Double.NaN) using a fixed set of specified values. - * The 2nd parameter of interpolate is ignored as we infer the number of imputed values - * using the number of Double.NaN. - */ -public class FixedValueImputer extends Imputer { - private double[] fixedValue; - - public FixedValueImputer(double[] fixedValue) { - this.fixedValue = fixedValue; - } - - /** - * Given an array of samples, fill with given value. - * We will ignore the rest of samples beyond the 2nd element. - * - * @return an imputed array of size numImputed - */ - @Override - public double[][] impute(double[][] samples, int numImputed) { - int numFeatures = samples.length; - double[][] imputed = new double[numFeatures][numImputed]; - - for (int featureIndex = 0; featureIndex < numFeatures; featureIndex++) { - imputed[featureIndex] = singleFeatureInterpolate(samples[featureIndex], numImputed, fixedValue[featureIndex]); - } - return imputed; - } - - private double[] singleFeatureInterpolate(double[] samples, int numInterpolants, double defaultVal) { - return Arrays.stream(samples).map(d -> Double.isNaN(d) ? defaultVal : d).toArray(); - } - - @Override - protected double[] singleFeatureImpute(double[] samples, int numInterpolants) { - throw new UnsupportedOperationException("The operation is not supported"); - } -} diff --git a/src/main/java/org/opensearch/timeseries/dataprocessor/Imputer.java b/src/main/java/org/opensearch/timeseries/dataprocessor/Imputer.java index 4e885421c..801489c7b 100644 --- a/src/main/java/org/opensearch/timeseries/dataprocessor/Imputer.java +++ b/src/main/java/org/opensearch/timeseries/dataprocessor/Imputer.java @@ -5,6 +5,8 @@ package org.opensearch.timeseries.dataprocessor; +import static org.apache.commons.math3.linear.MatrixUtils.createRealMatrix; + /* * An object for imputing feature vectors. * @@ -24,18 +26,20 @@ public abstract class Imputer { * `numFeatures`. * * - * @param samples A `numFeatures x numSamples` list of feature vectors. + * @param samples A `numSamples x numFeatures` list of feature vectors. * @param numImputed The desired number of imputed vectors. - * @return A `numFeatures x numImputed` list of feature vectors. + * @return A `numSamples x numFeatures` list of feature vectors. */ public double[][] impute(double[][] samples, int numImputed) { - int numFeatures = samples.length; - double[][] interpolants = new double[numFeatures][numImputed]; - + // convert to a `numFeatures x numSamples` list of feature vectors + double[][] transposed = transpose(samples); + int numFeatures = transposed.length; + double[][] imputants = new double[numFeatures][numImputed]; for (int featureIndex = 0; featureIndex < numFeatures; featureIndex++) { - interpolants[featureIndex] = singleFeatureImpute(samples[featureIndex], numImputed); + imputants[featureIndex] = singleFeatureImpute(transposed[featureIndex], numImputed); } - return interpolants; + // transpose back to a `numSamples x numFeatures` list of feature vectors + return transpose(imputants); } /** @@ -45,4 +49,8 @@ public double[][] impute(double[][] samples, int numImputed) { * @return input array with missing values imputed */ protected abstract double[] singleFeatureImpute(double[] samples, int numImputed); + + private double[][] transpose(double[][] matrix) { + return createRealMatrix(matrix).transpose().getData(); + } } diff --git a/src/main/java/org/opensearch/timeseries/dataprocessor/PreviousValueImputer.java b/src/main/java/org/opensearch/timeseries/dataprocessor/PreviousValueImputer.java deleted file mode 100644 index e91c90814..000000000 --- a/src/main/java/org/opensearch/timeseries/dataprocessor/PreviousValueImputer.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.timeseries.dataprocessor; - -/** - * Given an array of samples, fill missing values (represented using Double.NaN) - * with previous value. - * The return array may be smaller than the input array as we remove leading missing - * values after interpolation. If the first sample is Double.NaN - * as there is no last known value to fill in. - * The 2nd parameter of interpolate is ignored as we infer the number of imputed values - * using the number of Double.NaN. - * - */ -public class PreviousValueImputer extends Imputer { - - @Override - protected double[] singleFeatureImpute(double[] samples, int numInterpolants) { - int numSamples = samples.length; - double[] interpolants = new double[numSamples]; - - if (numSamples > 0) { - System.arraycopy(samples, 0, interpolants, 0, samples.length); - if (numSamples > 1) { - double lastKnownValue = Double.NaN; - for (int interpolantIndex = 0; interpolantIndex < numSamples; interpolantIndex++) { - if (Double.isNaN(interpolants[interpolantIndex])) { - if (!Double.isNaN(lastKnownValue)) { - interpolants[interpolantIndex] = lastKnownValue; - } - } else { - lastKnownValue = interpolants[interpolantIndex]; - } - } - } - } - return interpolants; - } -} diff --git a/src/main/java/org/opensearch/timeseries/dataprocessor/ZeroImputer.java b/src/main/java/org/opensearch/timeseries/dataprocessor/ZeroImputer.java deleted file mode 100644 index 1d0656de1..000000000 --- a/src/main/java/org/opensearch/timeseries/dataprocessor/ZeroImputer.java +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.timeseries.dataprocessor; - -import java.util.Arrays; - -/** - * fixing missing value (denoted using Double.NaN) using 0. - * The 2nd parameter of impute is ignored as we infer the number - * of imputed values using the number of Double.NaN. - */ -public class ZeroImputer extends Imputer { - - @Override - public double[] singleFeatureImpute(double[] samples, int numInterpolants) { - return Arrays.stream(samples).map(d -> Double.isNaN(d) ? 0.0 : d).toArray(); - } -} diff --git a/src/main/java/org/opensearch/ad/feature/AbstractRetriever.java b/src/main/java/org/opensearch/timeseries/feature/AbstractRetriever.java similarity index 99% rename from src/main/java/org/opensearch/ad/feature/AbstractRetriever.java rename to src/main/java/org/opensearch/timeseries/feature/AbstractRetriever.java index 886dbcbc4..5f2609ed5 100644 --- a/src/main/java/org/opensearch/ad/feature/AbstractRetriever.java +++ b/src/main/java/org/opensearch/timeseries/feature/AbstractRetriever.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.feature; +package org.opensearch.timeseries.feature; import java.util.Arrays; import java.util.Iterator; diff --git a/src/main/java/org/opensearch/ad/feature/CompositeRetriever.java b/src/main/java/org/opensearch/timeseries/feature/CompositeRetriever.java similarity index 93% rename from src/main/java/org/opensearch/ad/feature/CompositeRetriever.java rename to src/main/java/org/opensearch/timeseries/feature/CompositeRetriever.java index 3c9a1632a..f4cae0c0e 100644 --- a/src/main/java/org/opensearch/ad/feature/CompositeRetriever.java +++ b/src/main/java/org/opensearch/timeseries/feature/CompositeRetriever.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.feature; +package org.opensearch.timeseries.feature; import java.io.IOException; import java.time.Clock; @@ -27,7 +27,6 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.IndicesOptions; -import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; @@ -46,6 +45,7 @@ import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.Feature; import org.opensearch.timeseries.util.ParseUtils; @@ -66,7 +66,7 @@ public class CompositeRetriever extends AbstractRetriever { private final long dataStartEpoch; private final long dataEndEpoch; - private final AnomalyDetector anomalyDetector; + private final Config config; private final NamedXContentRegistry xContent; private final Client client; private final SecurityClientUtil clientUtil; @@ -78,11 +78,12 @@ public class CompositeRetriever extends AbstractRetriever { private Clock clock; private IndexNameExpressionResolver indexNameExpressionResolver; private ClusterService clusterService; + private AnalysisType context; public CompositeRetriever( long dataStartEpoch, long dataEndEpoch, - AnomalyDetector anomalyDetector, + Config config, NamedXContentRegistry xContent, Client client, SecurityClientUtil clientUtil, @@ -92,11 +93,12 @@ public CompositeRetriever( int maxEntitiesPerInterval, int pageSize, IndexNameExpressionResolver indexNameExpressionResolver, - ClusterService clusterService + ClusterService clusterService, + AnalysisType context ) { this.dataStartEpoch = dataStartEpoch; this.dataEndEpoch = dataEndEpoch; - this.anomalyDetector = anomalyDetector; + this.config = config; this.xContent = xContent; this.client = client; this.clientUtil = clientUtil; @@ -107,13 +109,14 @@ public CompositeRetriever( this.clock = clock; this.indexNameExpressionResolver = indexNameExpressionResolver; this.clusterService = clusterService; + this.context = context; } // a constructor that provide default value of clock public CompositeRetriever( long dataStartEpoch, long dataEndEpoch, - AnomalyDetector anomalyDetector, + Config anomalyDetector, NamedXContentRegistry xContent, Client client, SecurityClientUtil clientUtil, @@ -122,7 +125,8 @@ public CompositeRetriever( int maxEntitiesPerInterval, int pageSize, IndexNameExpressionResolver indexNameExpressionResolver, - ClusterService clusterService + ClusterService clusterService, + AnalysisType context ) { this( dataStartEpoch, @@ -137,7 +141,8 @@ public CompositeRetriever( maxEntitiesPerInterval, pageSize, indexNameExpressionResolver, - clusterService + clusterService, + context ); } @@ -147,21 +152,21 @@ public CompositeRetriever( * detector definition */ public PageIterator iterator() throws IOException { - RangeQueryBuilder rangeQuery = new RangeQueryBuilder(anomalyDetector.getTimeField()) + RangeQueryBuilder rangeQuery = new RangeQueryBuilder(config.getTimeField()) .gte(dataStartEpoch) .lt(dataEndEpoch) .format("epoch_millis"); - BoolQueryBuilder internalFilterQuery = new BoolQueryBuilder().filter(anomalyDetector.getFilterQuery()).filter(rangeQuery); + BoolQueryBuilder internalFilterQuery = new BoolQueryBuilder().filter(config.getFilterQuery()).filter(rangeQuery); // multiple categorical fields are supported CompositeAggregationBuilder composite = AggregationBuilders .composite( AGG_NAME_COMP, - anomalyDetector.getCategoryFields().stream().map(f -> new TermsValuesSourceBuilder(f).field(f)).collect(Collectors.toList()) + config.getCategoryFields().stream().map(f -> new TermsValuesSourceBuilder(f).field(f)).collect(Collectors.toList()) ) .size(pageSize); - for (Feature feature : anomalyDetector.getFeatureAttributes()) { + for (Feature feature : config.getFeatureAttributes()) { AggregatorFactories.Builder internalAgg = ParseUtils .parseAggregators(feature.getAggregation().toString(), xContent, feature.getId()); composite.subAggregation(internalAgg.getAggregatorFactories().iterator().next()); @@ -201,7 +206,7 @@ public void next(ActionListener listener) { // inject user role while searching. - SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0]), source); + SearchRequest searchRequest = new SearchRequest(config.getIndices().toArray(new String[0]), source); final ActionListener searchResponseListener = new ActionListener() { @Override public void onResponse(SearchResponse response) { @@ -219,9 +224,9 @@ public void onFailure(Exception e) { .asyncRequestWithInjectedSecurity( searchRequest, client::search, - anomalyDetector.getId(), + config.getId(), client, - AnalysisType.AD, + context, searchResponseListener ); } @@ -291,7 +296,7 @@ private Page analyzePage(SearchResponse response) { } */ for (Bucket bucket : composite.getBuckets()) { - Optional featureValues = parseBucket(bucket, anomalyDetector.getEnabledFeatureIds()); + Optional featureValues = parseBucket(bucket, config.getEnabledFeatureIds()); // bucket.getKey() returns a map of categorical field like "host" and its value like "server_1" if (featureValues.isPresent() && bucket.getKey() != null) { results.put(Entity.createEntityByReordering(bucket.getKey()), featureValues.get()); @@ -335,7 +340,7 @@ Optional getComposite(SearchResponse response) { // such index // [blah]","index":"blah","resource.id":"blah","resource.type":"index_or_alias","index_uuid":"_na_"},"status":404}% if (response == null || response.getAggregations() == null) { - List sourceIndices = anomalyDetector.getIndices(); + List sourceIndices = config.getIndices(); String[] concreteIndices = indexNameExpressionResolver .concreteIndexNames(clusterService.state(), IndicesOptions.lenientExpandOpen(), sourceIndices.toArray(new String[0])); if (concreteIndices.length == 0) { diff --git a/src/main/java/org/opensearch/ad/feature/FeatureManager.java b/src/main/java/org/opensearch/timeseries/feature/FeatureManager.java similarity index 89% rename from src/main/java/org/opensearch/ad/feature/FeatureManager.java rename to src/main/java/org/opensearch/timeseries/feature/FeatureManager.java index 469f8707e..b8cea9419 100644 --- a/src/main/java/org/opensearch/ad/feature/FeatureManager.java +++ b/src/main/java/org/opensearch/timeseries/feature/FeatureManager.java @@ -9,10 +9,9 @@ * GitHub history for details. */ -package org.opensearch.ad.feature; +package org.opensearch.timeseries.feature; import static java.util.Arrays.copyOfRange; -import static org.apache.commons.math3.linear.MatrixUtils.createRealMatrix; import java.io.IOException; import java.time.Clock; @@ -23,9 +22,11 @@ import java.util.AbstractMap.SimpleImmutableEntry; import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Collections; import java.util.Deque; import java.util.LinkedList; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Map.Entry; import java.util.Optional; @@ -41,13 +42,14 @@ import org.opensearch.action.support.ThreadedActionListener; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.core.action.ActionListener; +import org.opensearch.forecast.model.Forecaster; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.CleanState; import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.dataprocessor.Imputer; -import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.Entity; /** @@ -57,15 +59,13 @@ public class FeatureManager implements CleanState { private static final Logger logger = LogManager.getLogger(FeatureManager.class); - // Each anomaly detector has a queue of data points with timestamps (in epoch milliseconds). + // Each single-stream analysis has a queue of data points with timestamps (in epoch milliseconds). private final Map>>> detectorIdsToTimeShingles; private final SearchFeatureDao searchFeatureDao; - private final Imputer imputer; + public final Imputer imputer; private final Clock clock; - private final int maxTrainSamples; - private final int maxSampleStride; private final int trainSampleTimeRangeInHours; private final int minTrainSamples; private final double maxMissingPointsRate; @@ -82,8 +82,6 @@ public class FeatureManager implements CleanState { * @param searchFeatureDao DAO of features from search * @param imputer imputer of samples * @param clock clock for system time - * @param maxTrainSamples max number of samples from search - * @param maxSampleStride max stride between uninterpolated train samples * @param trainSampleTimeRangeInHours time range in hours for collect train samples * @param minTrainSamples min number of train samples * @param maxMissingPointsRate max proportion of shingle with missing points allowed to generate a shingle @@ -98,8 +96,6 @@ public FeatureManager( SearchFeatureDao searchFeatureDao, Imputer imputer, Clock clock, - int maxTrainSamples, - int maxSampleStride, int trainSampleTimeRangeInHours, int minTrainSamples, double maxMissingPointsRate, @@ -113,8 +109,6 @@ public FeatureManager( this.searchFeatureDao = searchFeatureDao; this.imputer = imputer; this.clock = clock; - this.maxTrainSamples = maxTrainSamples; - this.maxSampleStride = maxSampleStride; this.trainSampleTimeRangeInHours = trainSampleTimeRangeInHours; this.minTrainSamples = minTrainSamples; this.maxMissingPointsRate = maxMissingPointsRate; @@ -174,8 +168,35 @@ public void getCurrentFeatures(AnomalyDetector detector, long startTime, long en } } + public void getCurrentFeatures(Forecaster forecaster, long startTime, long endTime, ActionListener listener) { + List> missingRanges = Collections.singletonList(new SimpleImmutableEntry<>(startTime, endTime)); + try { + searchFeatureDao.getFeatureSamplesForPeriods(forecaster, missingRanges, AnalysisType.FORECAST, ActionListener.wrap(points -> { + // we only have one point + if (points.size() == 1) { + Optional point = points.get(0); + listener.onResponse(new SinglePointFeatures(point, Optional.empty())); + } else { + listener.onResponse(new SinglePointFeatures(Optional.empty(), Optional.empty())); + } + }, listener::onFailure)); + } catch (IOException e) { + listener.onFailure(new EndRunException(forecaster.getId(), CommonMessages.INVALID_SEARCH_QUERY_MSG, e, true)); + } + } + + public void getCurrentFeatures(Config config, long startTime, long endTime, ActionListener listener) { + if (config instanceof AnomalyDetector) { + getCurrentFeatures((AnomalyDetector) config, startTime, endTime, listener); + } else if (config instanceof Forecaster) { + getCurrentFeatures((Forecaster) config, startTime, endTime, listener); + } else { + throw new UnsupportedOperationException(String.format(Locale.ROOT, "config type %s is not supported.", config.getClass())); + } + } + private List> getMissingRangesInShingle( - AnomalyDetector detector, + Config detector, Map>> featuresMap, long endTime ) { @@ -207,7 +228,7 @@ private List> getMissingRangesInShingle( * @param listener onResponse is called with unprocessed features and processed features for the current data point. */ private void updateUnprocessedFeatures( - AnomalyDetector detector, + Config detector, Deque>> shingle, Map>> featuresMap, long endTime, @@ -221,17 +242,19 @@ private void updateUnprocessedFeatures( listener.onResponse(getProcessedFeatures(shingle, detector, endTime)); } - private double[][] filterAndFill(Deque>> shingle, long endTime, AnomalyDetector detector) { - int shingleSize = detector.getShingleSize(); + private double[][] filterAndFill(Deque>> shingle, long endTime, Config config) { + double[][] result = null; + + int shingleSize = config.getShingleSize(); Deque>> filteredShingle = shingle .stream() .filter(e -> e.getValue().isPresent()) .collect(Collectors.toCollection(ArrayDeque::new)); - double[][] result = null; + if (filteredShingle.size() >= shingleSize - getMaxMissingPoints(shingleSize)) { // Imputes missing data points with the values of neighboring data points. - long maxMillisecondsDifference = maxNeighborDistance * detector.getIntervalInMilliseconds(); - result = getNearbyPointsForShingle(detector, filteredShingle, endTime, maxMillisecondsDifference) + long maxMillisecondsDifference = maxNeighborDistance * config.getIntervalInMilliseconds(); + result = getNearbyPointsForShingle(config, filteredShingle, endTime, maxMillisecondsDifference) .map(e -> e.getValue().getValue().orElse(null)) .filter(d -> d != null) .toArray(double[][]::new); @@ -240,6 +263,7 @@ private double[][] filterAndFill(Deque>> shingle, result = null; } } + return result; } @@ -254,7 +278,7 @@ private double[][] filterAndFill(Deque>> shingle, * point value. */ private Stream>>> getNearbyPointsForShingle( - AnomalyDetector detector, + Config detector, Deque>> shingle, long endTime, long maxMillisecondsDifference @@ -281,7 +305,7 @@ private LongStream getFullShingleEndTimes(long endTime, long intervalMilli, int } /** - * Returns to listener data for cold-start training. + * Returns to listener data for cold-start training. Used in AD single-stream. * * Training data starts with getting samples from (costly) search. * Samples are increased in dimension via shingling. @@ -292,27 +316,37 @@ private LongStream getFullShingleEndTimes(long endTime, long intervalMilli, int */ public void getColdStartData(AnomalyDetector detector, ActionListener> listener) { ActionListener> latestTimeListener = ActionListener - .wrap(latest -> getColdStartSamples(latest, detector, listener), listener::onFailure); + .wrap(latest -> getColdStartSamples(latest, detector, AnalysisType.AD, listener), listener::onFailure); searchFeatureDao - .getLatestDataTime(detector, new ThreadedActionListener<>(logger, threadPool, adThreadPoolName, latestTimeListener, false)); + .getLatestDataTime( + detector, + Optional.empty(), + AnalysisType.AD, + new ThreadedActionListener<>(logger, threadPool, adThreadPoolName, latestTimeListener, false) + ); } - private void getColdStartSamples(Optional latest, AnomalyDetector detector, ActionListener> listener) { - int shingleSize = detector.getShingleSize(); + private void getColdStartSamples( + Optional latest, + Config config, + AnalysisType context, + ActionListener> listener + ) { + int shingleSize = config.getShingleSize(); if (latest.isPresent()) { - List> sampleRanges = getColdStartSampleRanges(detector, latest.get()); + List> sampleRanges = getColdStartSampleRanges(config, latest.get()); try { ActionListener>> getFeaturesListener = ActionListener .wrap(samples -> processColdStartSamples(samples, shingleSize, listener), listener::onFailure); searchFeatureDao .getFeatureSamplesForPeriods( - detector, + config, sampleRanges, - AnalysisType.AD, + context, new ThreadedActionListener<>(logger, threadPool, adThreadPoolName, getFeaturesListener, false) ); } catch (IOException e) { - listener.onFailure(new EndRunException(detector.getId(), CommonMessages.INVALID_SEARCH_QUERY_MSG, e, true)); + listener.onFailure(new EndRunException(config.getId(), CommonMessages.INVALID_SEARCH_QUERY_MSG, e, true)); } } else { listener.onResponse(Optional.empty()); @@ -361,7 +395,7 @@ private Optional fillAndShingle(LinkedList> shingle return result; } - private List> getColdStartSampleRanges(AnomalyDetector detector, long endMillis) { + private List> getColdStartSampleRanges(Config detector, long endMillis) { long interval = detector.getIntervalInMilliseconds(); int numSamples = Math.max((int) (Duration.ofHours(this.trainSampleTimeRangeInHours).toMillis() / interval), this.minTrainSamples); return IntStream @@ -624,18 +658,12 @@ private List> getPreviewRanges(List> ranges, private Entry getPreviewFeatures(double[][] samples, int stride, int shingleSize) { Entry unprocessedAndProcessed = Optional .of(samples) - .map(m -> transpose(m)) .map(m -> imputer.impute(m, stride * (samples.length - 1) + 1)) - .map(m -> transpose(m)) .map(m -> new SimpleImmutableEntry<>(copyOfRange(m, shingleSize - 1, m.length), batchShingle(m, shingleSize))) .get(); return unprocessedAndProcessed; } - public double[][] transpose(double[][] matrix) { - return createRealMatrix(matrix).transpose().getData(); - } - private long truncateToMinute(long epochMillis) { return Instant.ofEpochMilli(epochMillis).truncatedTo(ChronoUnit.MINUTES).toEpochMilli(); } @@ -688,11 +716,7 @@ public SinglePointFeatures getShingledFeatureForHistoricalAnalysis( return getProcessedFeatures(shingle, detector, endTime); } - private SinglePointFeatures getProcessedFeatures( - Deque>> shingle, - AnomalyDetector detector, - long endTime - ) { + private SinglePointFeatures getProcessedFeatures(Deque>> shingle, Config detector, long endTime) { int shingleSize = detector.getShingleSize(); Optional currentPoint = shingle.peekLast().getValue(); return new SinglePointFeatures( @@ -705,4 +729,18 @@ private SinglePointFeatures getProcessedFeatures( ); } + /** + * + * @param endTime End time of the stream + * @param intervalMilli interval between returned time + * @param startTime Start time of the stream + * @return a list of epoch timestamps from endTime with interval intervalMilli. The stream should stop when the number is earlier than startTime. + */ + private List getFullTrainingDataEndTimes(long endTime, long intervalMilli, long startTime) { + return LongStream + .iterate(startTime, i -> i + intervalMilli) + .takeWhile(i -> i <= endTime) + .boxed() // Convert LongStream to Stream + .collect(Collectors.toList()); // Collect to List + } } diff --git a/src/main/java/org/opensearch/ad/feature/Features.java b/src/main/java/org/opensearch/timeseries/feature/Features.java similarity index 98% rename from src/main/java/org/opensearch/ad/feature/Features.java rename to src/main/java/org/opensearch/timeseries/feature/Features.java index de347b78f..13cefc1d8 100644 --- a/src/main/java/org/opensearch/ad/feature/Features.java +++ b/src/main/java/org/opensearch/timeseries/feature/Features.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.feature; +package org.opensearch.timeseries.feature; import java.util.Arrays; import java.util.List; diff --git a/src/main/java/org/opensearch/timeseries/feature/SearchFeatureDao.java b/src/main/java/org/opensearch/timeseries/feature/SearchFeatureDao.java index 1ce44472f..e6c440477 100644 --- a/src/main/java/org/opensearch/timeseries/feature/SearchFeatureDao.java +++ b/src/main/java/org/opensearch/timeseries/feature/SearchFeatureDao.java @@ -38,7 +38,6 @@ import org.apache.logging.log4j.Logger; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; -import org.opensearch.ad.feature.AbstractRetriever; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -79,11 +78,10 @@ * DAO for features from search. */ public class SearchFeatureDao extends AbstractRetriever { + private static final Logger logger = LogManager.getLogger(SearchFeatureDao.class); - protected static final String AGG_NAME_MIN = "min_timefield"; protected static final String AGG_NAME_TOP = "top_agg"; - - private static final Logger logger = LogManager.getLogger(SearchFeatureDao.class); + protected static final String AGG_NAME_MIN = "min_timefield"; // Dependencies private final Client client; @@ -166,14 +164,23 @@ public SearchFeatureDao( /** * Returns to listener the epoch time of the latset data under the detector. * - * @param detector info about the data + * @param config info about the data * @param listener onResponse is called with the epoch time of the latset data under the detector */ - public void getLatestDataTime(AnomalyDetector detector, ActionListener> listener) { + public void getLatestDataTime(Config config, Optional entity, AnalysisType context, ActionListener> listener) { + BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery(); + + if (entity.isPresent()) { + for (TermQueryBuilder term : entity.get().getTermQueryForCustomerIndex()) { + internalFilterQuery.filter(term); + } + } + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() - .aggregation(AggregationBuilders.max(CommonName.AGG_NAME_MAX_TIME).field(detector.getTimeField())) + .query(internalFilterQuery) + .aggregation(AggregationBuilders.max(CommonName.AGG_NAME_MAX_TIME).field(config.getTimeField())) .size(0); - SearchRequest searchRequest = new SearchRequest().indices(detector.getIndices().toArray(new String[0])).source(searchSourceBuilder); + SearchRequest searchRequest = new SearchRequest().indices(config.getIndices().toArray(new String[0])).source(searchSourceBuilder); final ActionListener searchResponseListener = ActionListener .wrap(response -> listener.onResponse(ParseUtils.getLatestDataTime(response)), listener::onFailure); // using the original context in listener as user roles have no permissions for internal operations like fetching a @@ -182,9 +189,9 @@ public void getLatestDataTime(AnomalyDetector detector, ActionListenerasyncRequestWithInjectedSecurity( searchRequest, client::search, - detector.getId(), + config.getId(), client, - AnalysisType.AD, + context, searchResponseListener ); } @@ -484,7 +491,7 @@ public void getMinDataTime(Config config, Optional entity, AnalysisType BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery(); if (entity.isPresent()) { - for (TermQueryBuilder term : entity.get().getTermQueryBuilders()) { + for (TermQueryBuilder term : entity.get().getTermQueryForCustomerIndex()) { internalFilterQuery.filter(term); } } @@ -863,6 +870,36 @@ private SearchRequest createPreviewSearchRequest(Config config, ListThe method constructs a search request based on the provided parameters, executes the search, + * and processes the response to extract and format the relevant sample data. The resulting list + * of samples, ordered by time, is passed to the {@code listener} on successful retrieval. + * + *

In cases where the OpenSearch aggregations return null (e.g., no data matches the query), + * the method responds with an empty list. This method also applies a document count threshold + * to filter out buckets with insignificant data, based on the {@code includesEmptyBucket} parameter. + * + *

It's important to note that this method assumes ascending order for the date range bucket + * aggregation results by default and treats the {@code config.getEnabledFeatureIds()} to parse + * and format each bucket's data into the expected double array format. + */ public void getColdStartSamplesForPeriods( Config config, List> ranges, diff --git a/src/main/java/org/opensearch/ad/feature/SinglePointFeatures.java b/src/main/java/org/opensearch/timeseries/feature/SinglePointFeatures.java similarity index 97% rename from src/main/java/org/opensearch/ad/feature/SinglePointFeatures.java rename to src/main/java/org/opensearch/timeseries/feature/SinglePointFeatures.java index cbd7ef78b..9849a67f8 100644 --- a/src/main/java/org/opensearch/ad/feature/SinglePointFeatures.java +++ b/src/main/java/org/opensearch/timeseries/feature/SinglePointFeatures.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.feature; +package org.opensearch.timeseries.feature; import java.util.Optional; diff --git a/src/main/java/org/opensearch/timeseries/function/ResponseTransformer.java b/src/main/java/org/opensearch/timeseries/function/ResponseTransformer.java new file mode 100644 index 000000000..93c897718 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/function/ResponseTransformer.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.function; + +/** + * A functional interface for response transformation + * + * @param input type + * @param output type + */ +@FunctionalInterface +public interface ResponseTransformer { + R transform(T input); +} diff --git a/src/main/java/org/opensearch/timeseries/indices/IndexManagement.java b/src/main/java/org/opensearch/timeseries/indices/IndexManagement.java index 36134c263..2f9785db4 100644 --- a/src/main/java/org/opensearch/timeseries/indices/IndexManagement.java +++ b/src/main/java/org/opensearch/timeseries/indices/IndexManagement.java @@ -157,7 +157,7 @@ protected IndexManagement( this.threadPool = threadPool; this.clusterService.addLocalNodeClusterManagerListener(this); this.nodeFilter = nodeFilter; - this.settings = Settings.builder().put("index.hidden", true).build(); + this.settings = Settings.builder().put(IndexMetadata.SETTING_INDEX_HIDDEN, true).build(); this.maxUpdateRunningTimes = maxUpdateRunningTimes; this.indexType = indexType; this.maxPrimaryShards = maxPrimaryShards; @@ -259,7 +259,7 @@ protected void choosePrimaryShards(CreateIndexRequest request, boolean hiddenInd .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, getNumberOfPrimaryShards()) // 1 replica for better search performance and fail-over .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 1) - .put("index.hidden", hiddenIndex) + .put(IndexMetadata.SETTING_INDEX_HIDDEN, hiddenIndex) ); } @@ -492,7 +492,7 @@ public void initJobIndex(ActionListener actionListener) { // accordingly. // At least 1 replica for fail-over. .put(IndexMetadata.SETTING_AUTO_EXPAND_REPLICAS, minJobIndexReplicas + "-" + maxJobIndexReplicas) - .put("index.hidden", true) + .put(IndexMetadata.SETTING_INDEX_HIDDEN, true) ); adminClient.indices().create(request, actionListener); } catch (IOException e) { @@ -501,10 +501,42 @@ public void initJobIndex(ActionListener actionListener) { } } - public void validateCustomResultIndexAndExecute(String resultIndex, ExecutorFunction function, ActionListener listener) { + /** + * Validates the result index and executes the provided function. + * + *

+ * This method first checks if the mapping for the given result index is valid. If the mapping is not validated + * and is found to be invalid, the method logs a warning and notifies the listener of the failure. + *

+ * + *

+ * If the mapping is valid or has been previously validated, the method attempts to write and then immediately + * delete a dummy forecast result to the index. This is a workaround to verify the user's write permission on + * the custom result index, as there is currently no straightforward method to check for write permissions directly. + *

+ * + *

+ * If both write and delete operations are successful, the provided function is executed. If any step fails, + * the method logs an error and notifies the listener of the failure. + *

+ * + * @param The type of the action listener's response. + * @param resultIndex The custom result index to validate. + * @param function The function to be executed if validation is successful. + * @param mappingValidated Indicates whether the mapping for the result index has been previously validated. + * @param listener The listener to be notified of the success or failure of the operation. + * + * @throws IllegalArgumentException If the result index mapping is found to be invalid. + */ + public void validateResultIndexAndExecute( + String resultIndex, + ExecutorFunction function, + boolean mappingValidated, + ActionListener listener + ) { try { - if (!isValidResultIndexMapping(resultIndex)) { - logger.warn("Can't create detector with custom result index {} as its mapping is invalid", resultIndex); + if (!mappingValidated && !isValidResultIndexMapping(resultIndex)) { + logger.warn("Can't create analysis with custom result index {} as its mapping is invalid", resultIndex); listener.onFailure(new IllegalArgumentException(CommonMessages.INVALID_RESULT_INDEX_MAPPING + resultIndex)); return; } @@ -583,14 +615,14 @@ private void updateSettingIfNecessary(GroupedActionListener delegateListen ); for (IndexType timeseriesIndex : updates) { logger.info(new ParameterizedMessage("Check [{}]'s setting", timeseriesIndex.getIndexName())); - if (timeseriesIndex.isJobIndex()) { + if (timeseriesIndex.isJobIndex() && doesIndexExist(timeseriesIndex.getIndexName())) { updateJobIndexSettingIfNecessary( ADIndex.JOB.getIndexName(), indexStates.computeIfAbsent(timeseriesIndex, k -> new IndexState(k.getMapping())), conglomerateListeneer ); } else { - // we don't have settings to update for other indices + // we don't have settings to update for other cases IndexState indexState = indexStates.computeIfAbsent(timeseriesIndex, k -> new IndexState(k.getMapping())); indexState.settingUpToDate = true; logger.info(new ParameterizedMessage("Mark [{}]'s setting up-to-date", timeseriesIndex.getIndexName())); @@ -629,20 +661,20 @@ private void updateMappingIfNecessary(GroupedActionListener delegateListen updates.size() ); - for (IndexType adIndex : updates) { - logger.info(new ParameterizedMessage("Check [{}]'s mapping", adIndex.getIndexName())); - shouldUpdateIndex(adIndex, ActionListener.wrap(shouldUpdate -> { + for (IndexType index : updates) { + logger.info(new ParameterizedMessage("Check [{}]'s mapping", index.getIndexName())); + shouldUpdateIndex(index, ActionListener.wrap(shouldUpdate -> { if (shouldUpdate) { adminClient .indices() .putMapping( - new PutMappingRequest().indices(adIndex.getIndexName()).source(adIndex.getMapping(), XContentType.JSON), + new PutMappingRequest().indices(index.getIndexName()).source(index.getMapping(), XContentType.JSON), ActionListener.wrap(putMappingResponse -> { if (putMappingResponse.isAcknowledged()) { - logger.info(new ParameterizedMessage("Succeeded in updating [{}]'s mapping", adIndex.getIndexName())); - markMappingUpdated(adIndex); + logger.info(new ParameterizedMessage("Succeeded in updating [{}]'s mapping", index.getIndexName())); + markMappingUpdated(index); } else { - logger.error(new ParameterizedMessage("Fail to update [{}]'s mapping", adIndex.getIndexName())); + logger.error(new ParameterizedMessage("Fail to update [{}]'s mapping", index.getIndexName())); } conglomerateListeneer.onResponse(null); }, exception -> { @@ -650,7 +682,7 @@ private void updateMappingIfNecessary(GroupedActionListener delegateListen .error( new ParameterizedMessage( "Fail to update [{}]'s mapping due to [{}]", - adIndex.getIndexName(), + index.getIndexName(), exception.getMessage() ) ); @@ -661,14 +693,14 @@ private void updateMappingIfNecessary(GroupedActionListener delegateListen // index does not exist or the version is already up-to-date. // When creating index, new mappings will be used. // We don't need to update it. - logger.info(new ParameterizedMessage("We don't need to update [{}]'s mapping", adIndex.getIndexName())); - markMappingUpdated(adIndex); + logger.info(new ParameterizedMessage("We don't need to update [{}]'s mapping", index.getIndexName())); + markMappingUpdated(index); conglomerateListeneer.onResponse(null); } }, exception -> { logger .error( - new ParameterizedMessage("Fail to check whether we should update [{}]'s mapping", adIndex.getIndexName()), + new ParameterizedMessage("Fail to check whether we should update [{}]'s mapping", index.getIndexName()), exception ); conglomerateListeneer.onFailure(exception); @@ -737,7 +769,7 @@ public void initCustomResultIndexAndExecute(String resultIndex, ExecutorFunc initCustomResultIndexDirectly(resultIndex, ActionListener.wrap(response -> { if (response.isAcknowledged()) { logger.info("Successfully created result index {}", resultIndex); - validateCustomResultIndexAndExecute(resultIndex, function, listener); + validateResultIndexAndExecute(resultIndex, function, false, listener); } else { String error = "Creating result index with mappings call not acknowledged: " + resultIndex; logger.error(error); @@ -746,14 +778,14 @@ public void initCustomResultIndexAndExecute(String resultIndex, ExecutorFunc }, exception -> { if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { // It is possible the index has been created while we sending the create request - validateCustomResultIndexAndExecute(resultIndex, function, listener); + validateResultIndexAndExecute(resultIndex, function, false, listener); } else { logger.error("Failed to create result index " + resultIndex, exception); listener.onFailure(exception); } })); } else { - validateCustomResultIndexAndExecute(resultIndex, function, listener); + validateResultIndexAndExecute(resultIndex, function, false, listener); } } @@ -779,10 +811,10 @@ public void validateCustomIndexForBackendJob( injectSecurity.close(); listener.onFailure(e); }); - validateCustomResultIndexAndExecute(resultIndex, () -> { + validateResultIndexAndExecute(resultIndex, () -> { injectSecurity.close(); function.execute(); - }, wrappedListener); + }, true, wrappedListener); } catch (Exception e) { logger.error("Failed to validate custom index for backend job " + securityLogId, e); listener.onFailure(e); @@ -863,7 +895,6 @@ public boolean isValidResultIndexMapping(String resultIndex) { return false; } LinkedHashMap mapping = (LinkedHashMap) indexMapping.get(propertyName); - boolean correctResultIndexMapping = true; for (String fieldName : RESULT_FIELD_CONFIGS.keySet()) { @@ -874,6 +905,7 @@ public boolean isValidResultIndexMapping(String resultIndex) { // feature_id={type=keyword}}}}} // if it is a map of map, Object.equals can compare them regardless of order if (!mapping.containsKey(fieldName) || !defaultSchema.equals(mapping.get(fieldName))) { + logger.warn("mapping mismatch due to {}", fieldName); correctResultIndexMapping = false; break; } diff --git a/src/main/java/org/opensearch/timeseries/ml/CheckpointDao.java b/src/main/java/org/opensearch/timeseries/ml/CheckpointDao.java new file mode 100644 index 000000000..e79a5ed68 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ml/CheckpointDao.java @@ -0,0 +1,383 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.ml; + +import java.io.IOException; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.AbstractMap.SimpleImmutableEntry; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.Queue; + +import org.apache.commons.pool2.impl.GenericObjectPool; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.bulk.BulkAction; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.get.GetAction; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.get.MultiGetAction; +import org.opensearch.action.get.MultiGetRequest; +import org.opensearch.action.get.MultiGetResponse; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.reindex.BulkByScrollResponse; +import org.opensearch.index.reindex.DeleteByQueryAction; +import org.opensearch.index.reindex.DeleteByQueryRequest; +import org.opensearch.index.reindex.ScrollableHitSource; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.util.ClientUtil; + +import com.google.gson.Gson; + +import io.protostuff.LinkedBuffer; + +public abstract class CheckpointDao & TimeSeriesIndex, IndexManagementType extends IndexManagement> { + private static final Logger logger = LogManager.getLogger(CheckpointDao.class); + public static final String TIMEOUT_LOG_MSG = "Timeout while deleting checkpoints of"; + public static final String BULK_FAILURE_LOG_MSG = "Bulk failure while deleting checkpoints of"; + public static final String SEARCH_FAILURE_LOG_MSG = "Search failure while deleting checkpoints of"; + public static final String DOC_GOT_DELETED_LOG_MSG = "checkpoints docs get deleted"; + public static final String INDEX_DELETED_LOG_MSG = "Checkpoint index has been deleted. Has nothing to do:"; + + // dependencies + protected final Client client; + protected final ClientUtil clientUtil; + + // configuration + protected final String indexName; + + protected Gson gson; + + // we won't read/write a checkpoint larger than a threshold + protected final int maxCheckpointBytes; + + protected final GenericObjectPool serializeRCFBufferPool; + protected final int serializeRCFBufferSize; + + protected final IndexManagement indexUtil; + protected final Clock clock; + public static final String NOT_ABLE_TO_DELETE_CHECKPOINT_MSG = "Cannot delete all checkpoints of detector"; + + public CheckpointDao( + Client client, + ClientUtil clientUtil, + String indexName, + Gson gson, + int maxCheckpointBytes, + GenericObjectPool serializeRCFBufferPool, + int serializeRCFBufferSize, + IndexManagementType indexUtil, + Clock clock + ) { + this.client = client; + this.clientUtil = clientUtil; + this.indexName = indexName; + this.gson = gson; + this.maxCheckpointBytes = maxCheckpointBytes; + this.serializeRCFBufferPool = serializeRCFBufferPool; + this.serializeRCFBufferSize = serializeRCFBufferSize; + this.indexUtil = indexUtil; + this.clock = clock; + } + + protected void putModelCheckpoint(String modelId, Map source, ActionListener listener) { + if (indexUtil.doesCheckpointIndexExist()) { + saveModelCheckpointAsync(source, modelId, listener); + } else { + onCheckpointNotExist(source, modelId, listener); + } + } + + /** + * Update the model doc using fields in source. This ensures we won't touch + * the old checkpoint and nodes with old/new logic can coexist in a cluster. + * This is useful for introducing compact rcf new model format. + * + * @param source fields to update + * @param modelId model Id, used as doc id in the checkpoint index + * @param listener Listener to return response + */ + protected void saveModelCheckpointAsync(Map source, String modelId, ActionListener listener) { + + UpdateRequest updateRequest = new UpdateRequest(indexName, modelId); + updateRequest.doc(source); + // If the document does not already exist, the contents of the upsert element are inserted as a new document. + // If the document exists, update fields in the map + updateRequest.docAsUpsert(true); + clientUtil + .asyncRequest( + updateRequest, + client::update, + ActionListener.wrap(r -> listener.onResponse(null), listener::onFailure) + ); + } + + protected void onCheckpointNotExist(Map source, String modelId, ActionListener listener) { + indexUtil.initCheckpointIndex(ActionListener.wrap(initResponse -> { + if (initResponse.isAcknowledged()) { + saveModelCheckpointAsync(source, modelId, listener); + + } else { + throw new RuntimeException("Creating checkpoint with mappings call not acknowledged."); + } + }, exception -> { + if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { + // It is possible the index has been created while we sending the create request + saveModelCheckpointAsync(source, modelId, listener); + } else { + logger.error(String.format(Locale.ROOT, "Unexpected error creating index %s", indexName), exception); + } + })); + } + + protected Map.Entry checkoutOrNewBuffer() { + LinkedBuffer buffer = null; + boolean isCheckout = true; + try { + buffer = serializeRCFBufferPool.borrowObject(); + } catch (Exception e) { + logger.warn("Failed to borrow a buffer from pool", e); + } + if (buffer == null) { + buffer = LinkedBuffer.allocate(serializeRCFBufferSize); + isCheckout = false; + } + return new SimpleImmutableEntry(buffer, isCheckout); + } + + /** + * Deletes the model checkpoint for the model. + * + * @param modelId id of the model + * @param listener onReponse is called with null when the operation is completed + */ + public void deleteModelCheckpoint(String modelId, ActionListener listener) { + clientUtil + .asyncRequest( + new DeleteRequest(indexName, modelId), + client::delete, + ActionListener.wrap(r -> listener.onResponse(null), listener::onFailure) + ); + } + + protected void logFailure(BulkByScrollResponse response, String id) { + if (response.isTimedOut()) { + logger.warn(CheckpointDao.TIMEOUT_LOG_MSG + " {}", id); + } else if (!response.getBulkFailures().isEmpty()) { + logger.warn(CheckpointDao.BULK_FAILURE_LOG_MSG + " {}", id); + for (BulkItemResponse.Failure bulkFailure : response.getBulkFailures()) { + logger.warn(bulkFailure); + } + } else { + logger.warn(CheckpointDao.SEARCH_FAILURE_LOG_MSG + " {}", id); + for (ScrollableHitSource.SearchFailure searchFailure : response.getSearchFailures()) { + logger.warn(searchFailure); + } + } + } + + /** + * Determines whether to save the checkpoint based on various conditions. + * + * @param modelState The current state of the model, which includes the last checkpoint time. + * @param forceWrite Indicates if the checkpoint should be saved regardless of other conditions. + * @param checkpointInterval The interval at which checkpoints should be saved. + * @param clock The clock used to determine the current time (usually in UTC). + * + * @return true if both of the following conditions are met: + * 1. The model state is valid (the model is non-null or it has non-empty samples), and + * 2. Either forceWrite is true, or the last checkpoint time is not the minimum instant and the current time exceeds the last checkpoint time by at least the checkpoint interval. + * Returns false otherwise. + */ + public boolean shouldSave(ModelState modelState, boolean forceWrite, Duration checkpointInterval, Clock clock) { + if (modelState == null) { + return false; + } + + Instant lastCheckpointTime = modelState.getLastCheckpointTime(); + boolean isTimeForCheckpoint = lastCheckpointTime != null + && !lastCheckpointTime.equals(Instant.MIN) + && lastCheckpointTime.plus(checkpointInterval).isBefore(clock.instant()); + boolean hasValidSamples = modelState.getSamples() != null && !modelState.getSamples().isEmpty(); + boolean isModelStateValid = modelState.getModel().isPresent() || hasValidSamples; + return isModelStateValid && (isTimeForCheckpoint || forceWrite); + } + + public void batchWrite(BulkRequest request, ActionListener listener) { + if (indexUtil.doesCheckpointIndexExist()) { + clientUtil.execute(BulkAction.INSTANCE, request, listener); + } else { + indexUtil.initCheckpointIndex(ActionListener.wrap(initResponse -> { + if (initResponse.isAcknowledged()) { + clientUtil.execute(BulkAction.INSTANCE, request, listener); + } else { + // create index failure. Notify callers using listener. + listener.onFailure(new TimeSeriesException("Creating checkpoint with mappings call not acknowledged.")); + } + }, exception -> { + if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { + // It is possible the index has been created while we sending the create request + clientUtil.execute(BulkAction.INSTANCE, request, listener); + } else { + logger.error(String.format(Locale.ROOT, "Unexpected error creating checkpoint index"), exception); + listener.onFailure(exception); + } + })); + } + } + + /** + * Serialized samples + * @param samples input samples + * @return serialized object + */ + protected Optional toCheckpoint(Queue samples) { + if (samples == null || samples.isEmpty()) { + return Optional.empty(); + } + return Optional.of(samples.toArray(new Sample[0])); + } + + public void batchRead(MultiGetRequest request, ActionListener listener) { + clientUtil.execute(MultiGetAction.INSTANCE, request, listener); + } + + public void read(GetRequest request, ActionListener listener) { + clientUtil.execute(GetAction.INSTANCE, request, listener); + } + + /** + * Delete checkpoints associated with a config. Used in multi-entity detector. + * @param configId Config Id + */ + public void deleteModelCheckpointByConfigId(String configId) { + // A bulk delete request is performed for each batch of matching documents. If a + // search or bulk request is rejected, the requests are retried up to 10 times, + // with exponential back off. If the maximum retry limit is reached, processing + // halts and all failed requests are returned in the response. Any delete + // requests that completed successfully still stick, they are not rolled back. + DeleteByQueryRequest deleteRequest = createDeleteCheckpointRequest(configId); + logger.info("Delete checkpoints of config {}", configId); + client.execute(DeleteByQueryAction.INSTANCE, deleteRequest, ActionListener.wrap(response -> { + if (response.isTimedOut() || !response.getBulkFailures().isEmpty() || !response.getSearchFailures().isEmpty()) { + logFailure(response, configId); + } + // can return 0 docs get deleted because: + // 1) we cannot find matching docs + // 2) bad stats from OpenSearch. In this case, docs are deleted, but + // OpenSearch says deleted is 0. + logger.info("{} " + CheckpointDao.DOC_GOT_DELETED_LOG_MSG, response.getDeleted()); + }, exception -> { + if (exception instanceof IndexNotFoundException) { + logger.info(CheckpointDao.INDEX_DELETED_LOG_MSG + " {}", configId); + } else { + // Gonna eventually delete in daily cron. + logger.error(NOT_ABLE_TO_DELETE_CHECKPOINT_MSG, exception); + } + })); + } + + protected Optional> processRawCheckpoint(GetResponse response) { + try { + return Optional.ofNullable(response).filter(GetResponse::isExists).map(GetResponse::getSource); + } catch (Exception e) { + // Assuming a logger is available + logger.error("Error processing raw checkpoint", e); + return Optional.empty(); + } + } + + /** + * Process a checkpoint GetResponse and return the EntityModel object + * @param response Checkpoint Index GetResponse + * @param modelId Model Id + * @return a pair of entity model and its last checkpoint time + */ + public ModelState processHCGetResponse(GetResponse response, String modelId, String configId) { + Optional> checkpointString = processRawCheckpoint(response); + if (checkpointString.isPresent()) { + return fromEntityModelCheckpoint(checkpointString.get(), modelId, configId); + } else { + return null; + } + } + + /** + * Process a checkpoint GetResponse and return the EntityModel object + * @param response Checkpoint Index GetResponse + * @param modelId Model Id + * @return a pair of entity model and its last checkpoint time + */ + public ModelState processSingleStreamGetResponse(GetResponse response, String modelId, String configId) { + Optional> checkpointString = processRawCheckpoint(response); + if (checkpointString.isPresent()) { + return fromSingleStreamModelCheckpoint(checkpointString.get(), modelId, configId); + } else { + return null; + } + } + + protected abstract ModelState fromEntityModelCheckpoint(Map checkpoint, String modelId, String configId); + + protected abstract ModelState fromSingleStreamModelCheckpoint( + Map checkpoint, + String modelId, + String configId + ); + + public abstract Map toIndexSource(ModelState modelState) throws IOException; + + protected abstract DeleteByQueryRequest createDeleteCheckpointRequest(String configId); + + protected Deque loadSampleQueue(Map checkpoint, String modelId) { + Deque sampleQueue = new ArrayDeque<>(); + // Even though we we save sample_queue using array, after ser/der, we need to read it as List + // we start using SAMPLE_QUEUE after forecasting refactoring. Previously in AD, we use CommonName.ENTITY_SAMPLE + // to store samples. The refactoring moves samples out of EntityModel and makes it a first-level field. + List> samples = (List>) checkpoint.get(CommonName.SAMPLE_QUEUE); + if (samples != null) { + samples.forEach(sampleMap -> { + try { + Sample sample = Sample.extractSample(sampleMap); + if (sample != null) { + sampleQueue.add(sample); + } + } catch (Exception e) { + logger.warn("Exception while deserializing samples for " + modelId, e); + } + }); + } + // can be null when checkpoint corrupted (e.g., a checkpoint not recognized by current code + // due to bugs). Better redo training. + return sampleQueue; + } +} diff --git a/src/main/java/org/opensearch/timeseries/ml/MemoryAwareConcurrentHashmap.java b/src/main/java/org/opensearch/timeseries/ml/MemoryAwareConcurrentHashmap.java new file mode 100644 index 000000000..b477f454a --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ml/MemoryAwareConcurrentHashmap.java @@ -0,0 +1,102 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.ml; + +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; + +import org.opensearch.timeseries.MemoryTracker; +import org.opensearch.timeseries.MemoryTracker.Origin; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * A customized ConcurrentHashMap that can automatically consume and release memory. + * This enables minimum change to our single-stream code as we just have to replace + * the map implementation. + * + * Note: this is mainly used for single-stream configs. The key is model id. + */ +public class MemoryAwareConcurrentHashmap extends + ConcurrentHashMap> { + protected final MemoryTracker memoryTracker; + + public MemoryAwareConcurrentHashmap(MemoryTracker memoryTracker) { + this.memoryTracker = memoryTracker; + } + + @Override + public ModelState remove(Object key) { + ModelState deletedModelState = super.remove(key); + if (deletedModelState != null && deletedModelState.getModel().isPresent()) { + long memoryToRelease = memoryTracker.estimateTRCFModelSize(deletedModelState.getModel().get()); + memoryTracker.releaseMemory(memoryToRelease, true, Origin.REAL_TIME_DETECTOR); + } + return deletedModelState; + } + + @Override + public ModelState put(String key, ModelState value) { + ModelState previousAssociatedState = super.put(key, value); + if (value != null && value.getModel().isPresent()) { + long memoryToConsume = memoryTracker.estimateTRCFModelSize(value.getModel().get()); + memoryTracker.consumeMemory(memoryToConsume, true, Origin.REAL_TIME_DETECTOR); + } + return previousAssociatedState; + } + + /** + * Gets all of a config's model sizes hosted on a node + * + * @param configId config Id + * @return a map of model id to its memory size + */ + public Map getModelSize(String configId) { + Map res = new HashMap<>(); + super.entrySet() + .stream() + .filter(entry -> SingleStreamModelIdMapper.getConfigIdForModelId(entry.getKey()).equals(configId)) + .forEach(entry -> { + Optional modelOptional = entry.getValue().getModel(); + if (modelOptional.isPresent()) { + res.put(entry.getKey(), memoryTracker.estimateTRCFModelSize(modelOptional.get())); + } + }); + return res; + } + + /** + * Checks if a model exists for the given config. + * @param configId Config Id + * @return `true` if the model exists, `false` otherwise. + */ + public boolean doesModelExist(String configId) { + return super.entrySet() + .stream() + .filter(entry -> SingleStreamModelIdMapper.getConfigIdForModelId(entry.getKey()).equals(configId)) + .anyMatch(n -> true); + } + + public boolean hostIfPossible(String modelId, ModelState toUpdate) { + return Optional + .ofNullable(toUpdate) + .filter(state -> state.getModel().isPresent()) + .filter(state -> memoryTracker.isHostingAllowed(modelId, state.getModel().get())) + .map(state -> { + super.put(modelId, toUpdate); + return true; + }) + .orElse(false); + } +} diff --git a/src/main/java/org/opensearch/timeseries/ml/ModelColdStart.java b/src/main/java/org/opensearch/timeseries/ml/ModelColdStart.java new file mode 100644 index 000000000..43dd43883 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ml/ModelColdStart.java @@ -0,0 +1,579 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.ml; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.AbstractMap.SimpleImmutableEntry; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Deque; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.core.util.Throwables; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.support.ThreadedActionListener; +import org.opensearch.core.action.ActionListener; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.CleanState; +import org.opensearch.timeseries.MaintenanceState; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.caching.DoorKeeper; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.dataprocessor.ImputationOption; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker; +import org.opensearch.timeseries.ratelimit.FeatureRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.util.ExceptionUtil; + +import com.amazon.randomcutforest.config.ImputationMethod; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * The class bootstraps a model by performing a cold start + */ +public abstract class ModelColdStart & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CheckpointWriteWorkerType extends CheckpointWriteWorker> + implements + MaintenanceState, + CleanState { + private static final Logger logger = LogManager.getLogger(ModelColdStart.class); + + private final Duration modelTtl; + + // A bloom filter checked before cold start to ensure we don't repeatedly + // retry cold start of the same model. + // keys are detector ids. + protected Map doorKeepers; + protected Instant lastThrottledColdStartTime; + protected int coolDownMinutes; + protected final Clock clock; + protected final ThreadPool threadPool; + protected final int numMinSamples; + protected CheckpointWriteWorkerType checkpointWriteWorker; + // make sure rcf use a specific random seed. Otherwise, we will use a random random (not a typo) seed. + // this is mainly used for testing to make sure the model we trained and the reference rcf produce + // the same results + protected final long rcfSeed; + protected final int numberOfTrees; + protected final int rcfSampleSize; + protected final double thresholdMinPvalue; + protected final double initialAcceptFraction; + protected final NodeStateManager nodeStateManager; + protected final int defaulStrideLength; + protected final int defaultNumberOfSamples; + protected final SearchFeatureDao searchFeatureDao; + protected final FeatureManager featureManager; + protected final int maxRoundofColdStart; + protected final String threadPoolName; + protected final AnalysisType context; + + public ModelColdStart( + Duration modelTtl, + int coolDownMinutes, + Clock clock, + ThreadPool threadPool, + int numMinSamples, + CheckpointWriteWorkerType checkpointWriteWorker, + long rcfSeed, + int numberOfTrees, + int rcfSampleSize, + double thresholdMinPvalue, + NodeStateManager nodeStateManager, + int defaultSampleStride, + int defaultTrainSamples, + SearchFeatureDao searchFeatureDao, + FeatureManager featureManager, + int maxRoundofColdStart, + String threadPoolName, + AnalysisType context + ) { + this.modelTtl = modelTtl; + this.coolDownMinutes = coolDownMinutes; + this.clock = clock; + this.threadPool = threadPool; + this.numMinSamples = numMinSamples; + this.checkpointWriteWorker = checkpointWriteWorker; + this.rcfSeed = rcfSeed; + this.numberOfTrees = numberOfTrees; + this.rcfSampleSize = rcfSampleSize; + this.thresholdMinPvalue = thresholdMinPvalue; + + this.doorKeepers = new ConcurrentHashMap<>(); + this.lastThrottledColdStartTime = Instant.MIN; + this.initialAcceptFraction = numMinSamples * 1.0d / rcfSampleSize; + + this.nodeStateManager = nodeStateManager; + this.defaulStrideLength = defaultSampleStride; + this.defaultNumberOfSamples = defaultTrainSamples; + this.searchFeatureDao = searchFeatureDao; + this.featureManager = featureManager; + this.maxRoundofColdStart = maxRoundofColdStart; + this.threadPoolName = threadPoolName; + this.context = context; + } + + @Override + public void maintenance() { + doorKeepers.entrySet().stream().forEach(doorKeeperEntry -> { + String id = doorKeeperEntry.getKey(); + DoorKeeper doorKeeper = doorKeeperEntry.getValue(); + if (doorKeeper.expired(modelTtl)) { + doorKeepers.remove(id); + } else { + doorKeeper.maintenance(); + } + }); + } + + @Override + public void clear(String id) { + doorKeepers.remove(id); + } + + /** + * Train models + * @param coldStartRequest cold start request + * @param configId Config Id + * @param modelState Model state + * @param listener callback before the method returns whenever ColdStarter + * finishes training or encounters exceptions. The listener helps notify the + * cold start queue to pull another request (if any) to execute. We save the + * training data in result index so that the frontend can plot it. + */ + public void trainModel( + FeatureRequest coldStartRequest, + String configId, + ModelState modelState, + ActionListener> listener + ) { + nodeStateManager.getConfig(configId, context, ActionListener.wrap(configOptional -> { + if (false == configOptional.isPresent()) { + logger.warn(new ParameterizedMessage("Config [{}] is not available.", configId)); + listener.onFailure(new TimeSeriesException(configId, "fail to find config")); + return; + } + + Config config = configOptional.get(); + + String modelId = modelState.getModelId(); + + if (modelState.getSamples().size() < this.numMinSamples) { + coldStart(modelId, coldStartRequest, modelState, config, listener); + } else { + try { + trainModelFromExistingSamples(modelState, coldStartRequest.getEntity(), config, coldStartRequest.getTaskId()); + listener.onResponse(null); + } catch (Exception e) { + listener.onFailure(e); + } + } + }, listener::onFailure)); + } + + public void trainModelFromExistingSamples(ModelState modelState, Optional entity, Config config, String taskId) { + if (modelState.getSamples().size() >= this.numMinSamples) { + Deque samples = modelState.getSamples(); + trainModelFromDataSegments(new ArrayList<>(samples), entity, modelState, config, taskId); + // clear after use + modelState.clearSamples(); + } + } + + /** + * Training model + * @param modelId model Id corresponding to the entity + * @param coldStartRequest cold start request + * @param modelState model state + * @param config config accessor + * @param listener call back to send processed training data and last sample in the training data + */ + private void coldStart( + String modelId, + FeatureRequest coldStartRequest, + ModelState modelState, + Config config, + ActionListener> listener + ) { + logger.debug("Trigger cold start for {}", modelId); + + if (modelState == null) { + listener.onFailure(new IllegalArgumentException(String.format(Locale.ROOT, "Cannot have empty model state"))); + return; + } + + if (lastThrottledColdStartTime.plus(Duration.ofMinutes(coolDownMinutes)).isAfter(clock.instant())) { + listener.onResponse(null); + return; + } + + String configId = config.getId(); + boolean earlyExit = true; + try { + // Won't retry real-time cold start within 60 intervals for an entity + // coldStartRequest.getTaskId() == null in real-time cold start + + DoorKeeper doorKeeper = doorKeepers.computeIfAbsent(configId, id -> { + // reset every 60 intervals + return new DoorKeeper( + TimeSeriesSettings.DOOR_KEEPER_FOR_COLD_STARTER_MAX_INSERTION, + config.getIntervalDuration().multipliedBy(TimeSeriesSettings.DOOR_KEEPER_MAINTENANCE_FREQ), + clock, + TimeSeriesSettings.COLD_START_DOOR_KEEPER_COUNT_THRESHOLD + ); + }); + + if (doorKeeper.appearsMoreThanThreshold(modelId)) { + logger + .info( + "Won't retry real-time cold start within {} intervals for an model {}", + TimeSeriesSettings.DOOR_KEEPER_MAINTENANCE_FREQ, + modelId + ); + return; + } + + doorKeeper.put(modelId); + + ActionListener> coldStartCallBack = ActionListener.wrap(trainingData -> { + // existing samples might have different interval or duplicated data compared to training data we just grabbed. + // clear it before adding historical data. + modelState.clearSamples(); + if (trainingData != null && !trainingData.isEmpty()) { + int dataSize = trainingData.size(); + // only train models if we have enough samples + if (dataSize >= numMinSamples) { + // The function trainModelFromDataSegments will save a trained a model. trainModelFromDataSegments is called by + // multiple places, so I want to make the saving model implicit just in case I forgot. + List processedTrainingData = trainModelFromDataSegments( + trainingData, + coldStartRequest.getEntity(), + modelState, + config, + coldStartRequest.getTaskId() + ); + logger.info("Succeeded in training entity: {}", modelId); + listener.onResponse(processedTrainingData); + } else { + logger.info("Not enough data to train model: {}, currently we have {}", modelId, dataSize); + + trainingData.forEach(modelState::addSample); + // save to checkpoint + checkpointWriteWorker.write(modelState, true, RequestPriority.MEDIUM); + listener.onResponse(null); + } + } else { + logger.info("Cannot get training data for {}", modelId); + listener.onResponse(null); + } + + }, exception -> { + try { + logger.error(new ParameterizedMessage("Error while cold start {}", modelId), exception); + Throwable cause = Throwables.getRootCause(exception); + if (ExceptionUtil.isOverloaded(cause)) { + logger.error("too many requests"); + lastThrottledColdStartTime = Instant.now(); + } else if (cause instanceof TimeSeriesException || exception instanceof TimeSeriesException) { + // e.g., cannot find anomaly detector + nodeStateManager.setException(configId, exception); + } else { + nodeStateManager.setException(configId, new TimeSeriesException(configId, cause)); + } + listener.onFailure(exception); + } catch (Exception e) { + listener.onFailure(e); + } + }); + + threadPool + .executor(threadPoolName) + .execute( + () -> getColdStartData( + configId, + coldStartRequest, + new ThreadedActionListener<>(logger, threadPool, threadPoolName, coldStartCallBack, false) + ) + ); + earlyExit = false; + } finally { + if (earlyExit) { + listener.onResponse(null); + } + } + } + + /** + * Get training data for an entity. + * + * We first note the maximum and minimum timestamp, and sample at most 24 points + * (with 60 points apart between two neighboring samples) between those minimum + * and maximum timestamps. Samples can be missing. We only interpolate points + * between present neighboring samples. We then transform samples and interpolate + * points to shingles. Finally, full shingles will be used for cold start. + * + * @param configId config Id + * @param coldStartRequest cold start request + * @param listener A callback listener for receiving training data. + */ + private void getColdStartData(String configId, FeatureRequest coldStartRequest, ActionListener> listener) { + ActionListener> getDetectorListener = ActionListener.wrap(configOp -> { + if (!configOp.isPresent()) { + listener.onFailure(new EndRunException(configId, "Config is not available.", false)); + return; + } + Config config = configOp.get(); + + ActionListener> minTimeListener = ActionListener.wrap(earliest -> { + if (earliest.isPresent()) { + long startTimeMs = earliest.get().longValue(); + + // End time uses milliseconds as start time is assumed to be in milliseconds. + // Opensearch uses a set of preconfigured formats to recognize and parse these + // strings into a long value + // representing milliseconds-since-the-epoch in UTC. + // More on https://tinyurl.com/wub4fk92 + // also, since we want to use current feature to score, we don't use current interval + // [current start, current end] for training. So we fetch training data ending at current start + long endTimeMs = coldStartRequest.getDataStartTimeMillis(); + int numberOfSamples = selectNumberOfSamples(config); + + // we start with round 0 + getFeatures( + listener, + 0, + new ArrayList<>(), + config, + coldStartRequest.getEntity(), + numberOfSamples, + startTimeMs, + endTimeMs + ); + } else { + listener.onResponse(new ArrayList<>()); + } + }, listener::onFailure); + + searchFeatureDao + .getMinDataTime( + config, + coldStartRequest.getEntity(), + context, + new ThreadedActionListener<>(logger, threadPool, threadPoolName, minTimeListener, false) + ); + + }, listener::onFailure); + + nodeStateManager + .getConfig(configId, context, new ThreadedActionListener<>(logger, threadPool, threadPoolName, getDetectorListener, false)); + } + + /** + * Get the number of training samples to fetch from history. + * We require at least numMinSamples to let rcf output non-zero rcf scores. + * + * @return number of training samples + */ + private int selectNumberOfSamples(Config config) { + return Math.max(numMinSamples, config.getHistoryIntervals()); + } + + private void getFeatures( + ActionListener> listener, + int round, + List lastRounddataSample, + Config config, + Optional entity, + int numberOfSamples, + long startTimeMs, + long endTimeMs + ) { + if (startTimeMs >= endTimeMs || endTimeMs - startTimeMs < config.getIntervalInMilliseconds()) { + listener.onResponse(lastRounddataSample); + return; + } + + // Create ranges in ascending where the last sample's end time is the given endTimeMs. + // Sample ranges are also in ascending order in Opensearch's response. + List> sampleRanges = getTrainSampleRanges(config, startTimeMs, endTimeMs, numberOfSamples); + + if (sampleRanges.isEmpty()) { + listener.onResponse(lastRounddataSample); + return; + } + + ActionListener>> getFeaturelistener = ActionListener.wrap(featureSamples -> { + + int totalNumSamples = featureSamples.size(); + + if (totalNumSamples != sampleRanges.size()) { + String err = String + .format( + Locale.ROOT, + "length mismatch: totalNumSamples %d != time range length %d", + totalNumSamples, + sampleRanges.size() + ); + listener.onFailure(new IllegalArgumentException(err)); + return; + } + + // featuresSamples are in ascending order of time. + + List samples = new ArrayList<>(); + for (int index = 0; index < featureSamples.size(); index++) { + Optional featuresOptional = featureSamples.get(index); + if (featuresOptional.isPresent()) { + Entry curRange = sampleRanges.get(index); + samples + .add( + new Sample( + featuresOptional.get(), + Instant.ofEpochMilli(curRange.getKey()), + Instant.ofEpochMilli(curRange.getValue()) + ) + ); + } + } + + List concatenatedDataSample = null; + // make sure the following logic making sense via checking lastRoundFirstStartTime > 0 + if (lastRounddataSample != null && lastRounddataSample.size() > 0) { + concatenatedDataSample = new ArrayList<>(); + concatenatedDataSample.addAll(lastRounddataSample); + concatenatedDataSample.addAll(samples); + } else { + concatenatedDataSample = samples; + } + + // If the first round of probe provides numMinSamples points (note that if S0 is + // missing or all Si​ for some i > N is missing then we would miss a lot of points. + // Otherwise we can issue another round of query — if there is any sample in the + // second round then we would have numMinSamples points. If there is no sample + // in the second round then we should wait for real data. + // Note that even though we have imputation built in rcf, it is beneficial to require + // more samples at least during cold start. Garbage in, garbage out. + if (concatenatedDataSample.size() >= numMinSamples || round + 1 >= maxRoundofColdStart) { + listener.onResponse(concatenatedDataSample); + } else { + // the earliest sample's start time is the endTimeMs of next round of probe. + long earliestSampleStartTime = sampleRanges.get(0).getKey(); + getFeatures( + listener, + round + 1, + concatenatedDataSample, + config, + entity, + numberOfSamples, + startTimeMs, + earliestSampleStartTime + ); + } + }, listener::onFailure); + + try { + searchFeatureDao + .getColdStartSamplesForPeriods( + config, + sampleRanges, + entity, + // Accept empty bucket. + // 0, as returned by the engine should constitute a valid answer, “null” is a missing answer — it may be that 0 + // is meaningless in some case, but 0 is also meaningful in some cases. It may be that the query defining the + // metric is ill-formed, but that cannot be solved by cold-start strategy of the AD plugin — if we attempt to do + // that, we will have issues with legitimate interpretations of 0. + true, + context, + new ThreadedActionListener<>(logger, threadPool, threadPoolName, getFeaturelistener, false) + ); + } catch (Exception e) { + listener.onFailure(e); + } + } + + /** + * Get train samples within a time range. + * + * @param config accessor to config + * @param startMilli range start + * @param endMilli range end + * @param numberOfSamples maximum training samples to fetch + * @return list of sample time ranges in ascending order + */ + private List> getTrainSampleRanges(Config config, long startMilli, long endMilli, int numberOfSamples) { + long bucketSize = ((IntervalTimeConfiguration) config.getInterval()).toDuration().toMillis(); + int numBuckets = (int) Math.floor((endMilli - startMilli) / (double) bucketSize); + // adjust if numStrides is more than the max samples + int numIntervals = Math.min(numBuckets, numberOfSamples); + List> sampleRanges = Stream + .iterate(endMilli, i -> i - bucketSize) + .limit(numIntervals) + .map(time -> new SimpleImmutableEntry<>(time - bucketSize, time)) + .collect(Collectors.toList()); + + // Reverse the list to get time ranges in ascending order + Collections.reverse(sampleRanges); + + return sampleRanges; + } + + // Method to apply imputation method based on the imputation option + public static > T applyImputationMethod(Config config, T builder) { + ImputationOption imputationOption = config.getImputationOption(); + if (imputationOption == null) { + // by default using last known value + return builder.imputationMethod(ImputationMethod.PREVIOUS); + } else { + switch (imputationOption.getMethod()) { + case ZERO: + return builder.imputationMethod(ImputationMethod.ZERO); + case FIXED_VALUES: + // we did validate default fill is not empty and size matches enabled feature number in Config's constructor + return builder.imputationMethod(ImputationMethod.FIXED_VALUES).fillValues(imputationOption.getDefaultFill().get()); + case PREVIOUS: + return builder.imputationMethod(ImputationMethod.PREVIOUS); + case LINEAR: + return builder.imputationMethod(ImputationMethod.LINEAR); + default: + // by default using last known value + return builder.imputationMethod(ImputationMethod.PREVIOUS); + } + } + } + + protected abstract List trainModelFromDataSegments( + List dataPoints, + Optional entity, + ModelState state, + Config config, + String taskId + ); +} diff --git a/src/main/java/org/opensearch/timeseries/ml/ModelManager.java b/src/main/java/org/opensearch/timeseries/ml/ModelManager.java new file mode 100644 index 000000000..8d1ee996c --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ml/ModelManager.java @@ -0,0 +1,188 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.ml; + +import java.time.Clock; +import java.util.Arrays; +import java.util.Iterator; +import java.util.Map; +import java.util.Optional; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.core.action.ActionListener; +import org.opensearch.timeseries.MemoryTracker; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker; + +import com.amazon.randomcutforest.RandomCutForest; +import com.amazon.randomcutforest.parkservices.AnomalyDescriptor; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public abstract class ModelManager, IndexType extends Enum & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CheckpointWriteWorkerType extends CheckpointWriteWorker, ColdStarterType extends ModelColdStart> { + + private static final Logger LOG = LogManager.getLogger(ModelManager.class); + + public enum ModelType { + RCF("rcf"), + THRESHOLD("threshold"), + TRCF("trcf"), + RCFCASTER("rcf_caster"); + + private String name; + + ModelType(String name) { + this.name = name; + } + + public String getName() { + return name; + } + } + + protected final int rcfNumTrees; + protected final int rcfNumSamplesInTree; + protected final int rcfNumMinSamples; + protected ColdStarterType coldStarter; + protected MemoryTracker memoryTracker; + protected final Clock clock; + protected FeatureManager featureManager; + protected final CheckpointDaoType checkpointDao; + + public ModelManager( + int rcfNumTrees, + int rcfNumSamplesInTree, + int rcfNumMinSamples, + ColdStarterType coldStarter, + MemoryTracker memoryTracker, + Clock clock, + FeatureManager featureManager, + CheckpointDaoType checkpointDao + ) { + this.rcfNumTrees = rcfNumTrees; + this.rcfNumSamplesInTree = rcfNumSamplesInTree; + this.rcfNumMinSamples = rcfNumMinSamples; + this.coldStarter = coldStarter; + this.memoryTracker = memoryTracker; + this.clock = clock; + this.featureManager = featureManager; + this.checkpointDao = checkpointDao; + } + + public IntermediateResultType getResult( + Sample sample, + ModelState modelState, + String modelId, + Optional entity, + Config config, + String taskId + ) { + IntermediateResultType result = createEmptyResult(); + if (modelState != null) { + Optional entityModel = modelState.getModel(); + + if (entityModel.isEmpty()) { + coldStarter.trainModelFromExistingSamples(modelState, entity, config, taskId); + } + + if (modelState.getModel().isPresent()) { + result = score(sample, modelId, modelState, config); + } else { + modelState.addSample(sample); + } + } + return result; + } + + public void clearModels(String detectorId, Map models, ActionListener listener) { + Iterator id = models.keySet().iterator(); + clearModelForIterator(detectorId, models, id, listener); + } + + protected void clearModelForIterator(String detectorId, Map models, Iterator idIter, ActionListener listener) { + if (idIter.hasNext()) { + String modelId = idIter.next(); + if (SingleStreamModelIdMapper.getConfigIdForModelId(modelId).equals(detectorId)) { + models.remove(modelId); + checkpointDao + .deleteModelCheckpoint( + modelId, + ActionListener.wrap(r -> clearModelForIterator(detectorId, models, idIter, listener), listener::onFailure) + ); + } else { + clearModelForIterator(detectorId, models, idIter, listener); + } + } else { + listener.onResponse(null); + } + } + + @SuppressWarnings("unchecked") + public IntermediateResultType score( + Sample sample, + String modelId, + ModelState modelState, + Config config + ) { + + IntermediateResultType result = createEmptyResult(); + Optional model = modelState.getModel(); + try { + if (model != null && model.isPresent()) { + RCFModelType rcfModel = model.get(); + + if (!modelState.getSamples().isEmpty()) { + for (Sample unProcessedSample : modelState.getSamples()) { + // we are sure that the process method will indeed return an instance of RCFDescriptor. + rcfModel.process(unProcessedSample.getValueList(), unProcessedSample.getDataEndTime().getEpochSecond()); + } + modelState.clearSamples(); + } + + RCFDescriptor lastResult = (RCFDescriptor) rcfModel + .process(sample.getValueList(), sample.getDataEndTime().getEpochSecond()); + + if (lastResult != null) { + result = toResult(rcfModel.getForest(), lastResult); + } + } + } catch (Exception e) { + LOG + .error( + new ParameterizedMessage( + "Fail to score for [{}]: model Id [{}], feature [{}]", + modelState.getEntity().isEmpty() ? modelState.getConfigId() : modelState.getEntity().get(), + modelId, + Arrays.toString(sample.getValueList()) + ), + e + ); + throw e; + } finally { + modelState.setLastUsedTime(clock.instant()); + } + return result; + } + + protected abstract IntermediateResultType createEmptyResult(); + + protected abstract IntermediateResultType toResult( + RandomCutForest forecast, + RCFDescriptor castDescriptor + ); +} diff --git a/src/main/java/org/opensearch/ad/ml/ModelState.java b/src/main/java/org/opensearch/timeseries/ml/ModelState.java similarity index 63% rename from src/main/java/org/opensearch/ad/ml/ModelState.java rename to src/main/java/org/opensearch/timeseries/ml/ModelState.java index bb9050ecb..e2f914e8f 100644 --- a/src/main/java/org/opensearch/ad/ml/ModelState.java +++ b/src/main/java/org/opensearch/timeseries/ml/ModelState.java @@ -9,92 +9,84 @@ * GitHub history for details. */ -package org.opensearch.ad.ml; +package org.opensearch.timeseries.ml; import java.time.Clock; import java.time.Duration; import java.time.Instant; +import java.util.ArrayDeque; +import java.util.Deque; import java.util.HashMap; import java.util.Map; +import java.util.Optional; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.timeseries.ExpiringState; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.Entity; -/** - * A ML model and states such as usage. - */ -public class ModelState implements ExpiringState { - +public class ModelState implements org.opensearch.timeseries.ExpiringState { public static String MODEL_TYPE_KEY = "model_type"; public static String LAST_USED_TIME_KEY = "last_used_time"; public static String LAST_CHECKPOINT_TIME_KEY = "last_checkpoint_time"; public static String PRIORITY_KEY = "priority"; - private T model; - private String modelId; - private String detectorId; - private String modelType; + + protected T model; + protected String modelId; + protected String configId; + protected String modelType; // time when the ML model was used last time - private Instant lastUsedTime; - private Instant lastCheckpointTime; - private Clock clock; - private float priority; + protected Instant lastUsedTime; + protected Instant lastCheckpointTime; + protected Clock clock; + protected float priority; + protected Deque samples; + protected Optional entity; /** * Constructor. * * @param model ML model * @param modelId Id of model partition - * @param detectorId Id of detector this model partition is used for + * @param configId Id of analysis this model partition is used for * @param modelType type of model * @param clock UTC clock * @param priority Priority of the model state. Used in multi-entity detectors' cache. + * @param entity Entity info if this is a HC entity state + * @param samples existing samples that haven't been processed */ - public ModelState(T model, String modelId, String detectorId, String modelType, Clock clock, float priority) { + public ModelState( + T model, + String modelId, + String configId, + String modelType, + Clock clock, + float priority, + Optional entity, + Deque samples + ) { this.model = model; this.modelId = modelId; - this.detectorId = detectorId; + this.configId = configId; this.modelType = modelType; this.lastUsedTime = clock.instant(); // this is inaccurate until we find the last checkpoint time from disk this.lastCheckpointTime = Instant.MIN; this.clock = clock; this.priority = priority; + this.entity = entity; + this.samples = samples; } /** - * Create state with zero priority. Used in single-entity detector. + * Constructor. Used in single-stream analysis. * - * @param Model object's type - * @param model The actual model object - * @param modelId Model Id - * @param detectorId Detector Id - * @param modelType Model type like RCF model + * @param model ML model + * @param modelId Id of model partition + * @param configId Id of analysis this model partition is used for + * @param modelType type of model * @param clock UTC clock - * - * @return the created model state - */ - public static ModelState createSingleEntityModelState( - T model, - String modelId, - String detectorId, - String modelType, - Clock clock - ) { - return new ModelState<>(model, modelId, detectorId, modelType, clock, 0f); - } - - /** - * Returns the ML model. - * - * @return the ML model. */ - public T getModel() { - return this.model; - } - - public void setModel(T model) { - this.model = model; + public ModelState(T model, String modelId, String configId, String modelType, Clock clock) { + this(model, modelId, configId, modelType, clock, 0, Optional.empty(), new ArrayDeque<>()); } /** @@ -106,15 +98,6 @@ public String getModelId() { return modelId; } - /** - * Gets the detectorID of the model - * - * @return detectorId associated with the model - */ - public String getId() { - return detectorId; - } - /** * Gets the type of the model * @@ -172,16 +155,81 @@ public void setPriority(float priority) { this.priority = priority; } + @Override + public boolean expired(Duration stateTtl) { + return expired(lastUsedTime, stateTtl, clock.instant()); + } + + /** + * Gets the Config ID of the model + * + * @return the config id associated with the model + */ + public String getConfigId() { + return configId; + } + + /** + * In old checkpoint mapping, we don't have entity. It's fine we are missing + * entity as it is mostly used for debugging. + * @return entity + */ + public Optional getEntity() { + return entity; + } + + public Deque getSamples() { + return this.samples; + } + + public void addSample(Sample sample) { + if (this.samples == null) { + this.samples = new ArrayDeque<>(); + } + if (sample != null && sample.getValueList() != null && sample.getValueList().length != 0) { + this.samples.add(sample); + } + } + + /** + * Sets a model. + * + * @param model model instance + */ + public void setModel(T model) { + this.model = model; + } + + /** + * + * @return optional model. + */ + public Optional getModel() { + return Optional.ofNullable(this.model); + } + + public void clearSamples() { + if (samples != null) { + samples.clear(); + } + } + + public void clear() { + clearSamples(); + model = null; + } + /** * Gets the Model State as a map * * @return Map of ModelStates */ + @SuppressWarnings("serial") public Map getModelStateAsMap() { return new HashMap() { { put(CommonName.MODEL_ID_FIELD, modelId); - put(ADCommonName.DETECTOR_ID_KEY, detectorId); + put(CommonName.CONFIG_ID_KEY, configId); put(MODEL_TYPE_KEY, modelType); /* A stats API broadcasts requests to all nodes and renders node responses using toXContent. * @@ -195,18 +243,10 @@ public Map getModelStateAsMap() { if (lastCheckpointTime != Instant.MIN) { put(LAST_CHECKPOINT_TIME_KEY, lastCheckpointTime.toEpochMilli()); } - if (model != null && model instanceof EntityModel) { - EntityModel summary = (EntityModel) model; - if (summary.getEntity().isPresent()) { - put(CommonName.ENTITY_KEY, summary.getEntity().get().toStat()); - } + if (entity.isPresent()) { + put(CommonName.ENTITY_KEY, entity.get().toStat()); } } }; } - - @Override - public boolean expired(Duration stateTtl) { - return expired(lastUsedTime, stateTtl, clock.instant()); - } } diff --git a/src/main/java/org/opensearch/timeseries/ml/Sample.java b/src/main/java/org/opensearch/timeseries/ml/Sample.java new file mode 100644 index 000000000..bc1212596 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ml/Sample.java @@ -0,0 +1,143 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.ml; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.annotation.Generated; +import org.opensearch.timeseries.constant.CommonName; + +import com.google.common.base.Objects; + +public class Sample implements ToXContentObject { + private final double[] data; + private final Instant dataStartTime; + private final Instant dataEndTime; + + public Sample(double[] data, Instant dataStartTime, Instant dataEndTime) { + super(); + this.data = data; + this.dataStartTime = dataStartTime; + this.dataEndTime = dataEndTime; + } + + // Invalid sample + public Sample() { + this.data = new double[0]; + this.dataStartTime = this.dataEndTime = Instant.MIN; + } + + public double[] getValueList() { + return data; + } + + public Instant getDataStartTime() { + return dataStartTime; + } + + public Instant getDataEndTime() { + return dataEndTime; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + if (data != null) { + xContentBuilder.array(CommonName.VALUE_LIST_FIELD, data); + } + if (dataStartTime != null && dataStartTime != Instant.MIN) { + xContentBuilder.field(CommonName.DATA_START_TIME_FIELD, dataStartTime.toEpochMilli()); + } + if (dataEndTime != null && dataEndTime != Instant.MIN) { + xContentBuilder.field(CommonName.DATA_END_TIME_FIELD, dataEndTime.toEpochMilli()); + } + return xContentBuilder.endObject(); + } + + /** + * Extract Sample fields out of a serialized Map, which is what we get from a get checkpoint call. + * @param map serialized sample. + * Example input map: + * Key: last_processed_sample, Value type: java.util.HashMap + * Key: data_end_time, Value type: java.lang.Long + * Value: 1695825364700, Type: java.lang.Long + * Key: data_start_time, Value type: java.lang.Long + * Value: 1695825304700, Type: java.lang.Long + * Key: value_list, Value type: java.util.ArrayList + * Item type: java.lang.Double + * Value: 8840.0, Type: java.lang.Double + * @return a Sample. + */ + public static Sample extractSample(Map map) { + // Extract and convert values from the map + Long dataEndTimeLong = (Long) map.get(CommonName.DATA_END_TIME_FIELD); + Long dataStartTimeLong = (Long) map.get(CommonName.DATA_START_TIME_FIELD); + List valueList = (List) map.get(CommonName.VALUE_LIST_FIELD); + + // Check if all required keys are present in the map + if (dataEndTimeLong == null && dataStartTimeLong == null && valueList == null) { + return null; + } + + // Convert List to double[] + double[] data = valueList.stream().mapToDouble(Double::doubleValue).toArray(); + + // Convert long to Instant + Instant dataEndTime = Instant.ofEpochMilli(dataEndTimeLong); + Instant dataStartTime = Instant.ofEpochMilli(dataStartTimeLong); + + // Create a new Sample object and return it + return new Sample(data, dataStartTime, dataEndTime); + } + + public boolean isInvalid() { + return dataStartTime.compareTo(Instant.MIN) == 0 || dataEndTime.compareTo(Instant.MIN) == 0; + } + + @Override + public String toString() { + return "Sample [data=" + Arrays.toString(data) + ", dataStartTime=" + dataStartTime + ", dataEndTime=" + dataEndTime + "]"; + } + + @Generated + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Sample sample = (Sample) o; + // a few fields not included: + // 1)didn't include uiMetadata since toXContent/parse will produce a map of map + // and cause the parsed one not equal to the original one. This can be confusing. + // 2)didn't include id, schemaVersion, and lastUpdateTime as we deemed equality based on contents. + // Including id fails tests like AnomalyDetectorExecutionInput.testParseAnomalyDetectorExecutionInput. + return Arrays.equals(data, sample.data) + && dataStartTime.truncatedTo(ChronoUnit.MILLIS).equals(sample.dataStartTime.truncatedTo(ChronoUnit.MILLIS)) + && dataEndTime.truncatedTo(ChronoUnit.MILLIS).equals(sample.dataEndTime.truncatedTo(ChronoUnit.MILLIS)); + } + + @Generated + @Override + public int hashCode() { + return Objects.hashCode(data, dataStartTime.truncatedTo(ChronoUnit.MILLIS), dataEndTime.truncatedTo(ChronoUnit.MILLIS)); + } +} diff --git a/src/main/java/org/opensearch/timeseries/ml/SingleStreamModelIdMapper.java b/src/main/java/org/opensearch/timeseries/ml/SingleStreamModelIdMapper.java index c33c4818f..cf045f79d 100644 --- a/src/main/java/org/opensearch/timeseries/ml/SingleStreamModelIdMapper.java +++ b/src/main/java/org/opensearch/timeseries/ml/SingleStreamModelIdMapper.java @@ -22,9 +22,10 @@ * */ public class SingleStreamModelIdMapper { - protected static final String DETECTOR_ID_PATTERN = "(.*)_model_.+"; + protected static final String CONFIG_ID_PATTERN = "(.*)_model_.+"; protected static final String RCF_MODEL_ID_PATTERN = "%s_model_rcf_%d"; protected static final String THRESHOLD_MODEL_ID_PATTERN = "%s_model_threshold"; + protected static final String CASTER_MODEL_ID_PATTERN = "%s_model_caster"; /** * Returns the model ID for the RCF model partition. @@ -48,14 +49,24 @@ public static String getThresholdModelId(String detectorId) { } /** - * Gets the detector id from the model id. + * Returns the model ID for the rcf caster model. + * + * @param forecasterId ID of the forecaster for which the model is trained + * @return ID for the forecaster model + */ + public static String getCasterModelId(String forecasterId) { + return String.format(Locale.ROOT, CASTER_MODEL_ID_PATTERN, forecasterId); + } + + /** + * Gets the config id from the model id. * * @param modelId id of a model * @return id of the detector the model is for * @throws IllegalArgumentException if model id is invalid */ - public static String getDetectorIdForModelId(String modelId) { - Matcher matcher = Pattern.compile(DETECTOR_ID_PATTERN).matcher(modelId); + public static String getConfigIdForModelId(String modelId) { + Matcher matcher = Pattern.compile(CONFIG_ID_PATTERN).matcher(modelId); if (matcher.matches()) { return matcher.group(1); } else { @@ -70,7 +81,7 @@ public static String getDetectorIdForModelId(String modelId) { * @return thresholding model Id */ public static String getThresholdModelIdFromRCFModelId(String rcfModelId) { - String detectorId = getDetectorIdForModelId(rcfModelId); + String detectorId = getConfigIdForModelId(rcfModelId); return getThresholdModelId(detectorId); } } diff --git a/src/main/java/org/opensearch/timeseries/ml/TimeSeriesSingleStreamCheckpointDao.java b/src/main/java/org/opensearch/timeseries/ml/TimeSeriesSingleStreamCheckpointDao.java new file mode 100644 index 000000000..960234701 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ml/TimeSeriesSingleStreamCheckpointDao.java @@ -0,0 +1,16 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.ml; + +public class TimeSeriesSingleStreamCheckpointDao { + +} diff --git a/src/main/java/org/opensearch/timeseries/model/Config.java b/src/main/java/org/opensearch/timeseries/model/Config.java index 15f67d116..52cd1fd45 100644 --- a/src/main/java/org/opensearch/timeseries/model/Config.java +++ b/src/main/java/org/opensearch/timeseries/model/Config.java @@ -11,7 +11,10 @@ import java.time.Duration; import java.time.Instant; import java.util.List; +import java.util.Locale; import java.util.Map; +import java.util.Optional; +import java.util.function.Function; import java.util.stream.Collectors; import org.apache.logging.log4j.LogManager; @@ -32,14 +35,10 @@ import org.opensearch.timeseries.common.exception.ValidationException; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; -import org.opensearch.timeseries.dataprocessor.FixedValueImputer; import org.opensearch.timeseries.dataprocessor.ImputationMethod; import org.opensearch.timeseries.dataprocessor.ImputationOption; -import org.opensearch.timeseries.dataprocessor.Imputer; -import org.opensearch.timeseries.dataprocessor.LinearUniformImputer; -import org.opensearch.timeseries.dataprocessor.PreviousValueImputer; -import org.opensearch.timeseries.dataprocessor.ZeroImputer; import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.owasp.encoder.Encode; import com.google.common.base.Objects; import com.google.common.collect.ImmutableList; @@ -71,11 +70,9 @@ public abstract class Config implements Writeable, ToXContentObject { public static final String USER_FIELD = "user"; public static final String RESULT_INDEX_FIELD = "result_index"; public static final String IMPUTATION_OPTION_FIELD = "imputation_option"; - - private static final Imputer zeroImputer; - private static final Imputer previousImputer; - private static final Imputer linearImputer; - private static final Imputer linearImputerIntegerSensitive; + public static final String SEASONALITY_FIELD = "suggested_seasonality"; + public static final String RECENCY_EMPHASIS_FIELD = "recency_emphasis"; + public static final String HISTORY_INTERVAL_FIELD = "history"; protected String id; protected Long version; @@ -95,24 +92,28 @@ public abstract class Config implements Writeable, ToXContentObject { protected List categoryFields; protected User user; protected ImputationOption imputationOption; + // Aggregation period to smooth the emphasis on the most recent data. Aggregation period to smooth + // the emphasis of the most recent data. Useful for determining short/long term trends. Can be used + // similar to moving average computation https://en.wikipedia.org/wiki/Moving_average + // Recency emphasis is the average number of steps that a point will be included in the sample. + // Call the number of steps that a point is included in the sample the "lifetime" of the point + // (which may be 0). Over a finite time window, the distribution of the lifetime of a point is + // approximately exponential with parameter lambda. In an exponential distribution, the average + // is the reciprocal of the rate parameter (λ). Thus, 1 / timmeDecay is approximately the + // average number of steps that a point will be included in the sample. + protected Integer recencyEmphasis; // validation error protected String errorMessage; protected ValidationIssueType issueType; - protected Imputer imputer; + protected Integer seasonIntervals; + protected Integer historyIntervals; public static String INVALID_RESULT_INDEX_NAME_SIZE = "Result index name size must contains less than " + MAX_RESULT_INDEX_NAME_SIZE + " characters"; - static { - zeroImputer = new ZeroImputer(); - previousImputer = new PreviousValueImputer(); - linearImputer = new LinearUniformImputer(false); - linearImputerIntegerSensitive = new LinearUniformImputer(true); - } - protected Config( String id, Long version, @@ -131,7 +132,11 @@ protected Config( User user, String resultIndex, TimeConfiguration interval, - ImputationOption imputationOption + ImputationOption imputationOption, + Integer recencyEmphasis, + Integer seasonIntervals, + ShingleGetter shingleGetter, + Integer historyIntervals ) { if (Strings.isBlank(name)) { errorMessage = CommonMessages.EMPTY_NAME; @@ -158,17 +163,68 @@ protected Config( return; } + if (invalidSeasonality(seasonIntervals)) { + errorMessage = "Suggested seasonality must be a positive integer no larger than " + + TimeSeriesSettings.MAX_SHINGLE_SIZE * 2 + + ". Got " + + seasonIntervals; + issueType = ValidationIssueType.SUGGESTED_SEASONALITY_FIELD; + return; + } + errorMessage = validateCustomResultIndex(resultIndex); if (errorMessage != null) { issueType = ValidationIssueType.RESULT_INDEX; return; } - if (imputationOption != null - && imputationOption.getMethod() == ImputationMethod.FIXED_VALUES - && imputationOption.getDefaultFill().isEmpty()) { - issueType = ValidationIssueType.IMPUTATION; - errorMessage = "No given values for fixed value interpolation"; + if (imputationOption != null && imputationOption.getMethod() == ImputationMethod.FIXED_VALUES) { + Optional defaultFill = imputationOption.getDefaultFill(); + if (defaultFill.isEmpty()) { + issueType = ValidationIssueType.IMPUTATION; + errorMessage = "No given values for fixed value interpolation"; + return; + } + + // Calculate the number of enabled features + long expectedFeatures = features.stream().filter(Feature::getEnabled).count(); + + // Check if the length of the defaultFill array matches the number of expected features + if (defaultFill.get().length != expectedFeatures) { + issueType = ValidationIssueType.IMPUTATION; + errorMessage = String + .format( + Locale.ROOT, + "Incorrect number of values to fill. Got: %d. Expected: %d.", + defaultFill.get().length, + expectedFeatures + ); + return; + } + } + + if (recencyEmphasis != null && (recencyEmphasis <= 0)) { + issueType = ValidationIssueType.RECENCY_EMPHASIS; + errorMessage = "recency emphasis has to be a positive integer"; + return; + } + + errorMessage = validateDescription(description); + if (errorMessage != null) { + issueType = ValidationIssueType.DESCRIPTION; + return; + } + + if (historyIntervals != null && (historyIntervals <= 0 || historyIntervals > TimeSeriesSettings.MAX_HISTORY_INTERVALS)) { + issueType = ValidationIssueType.HISTORY; + errorMessage = "We cannot look back more than " + TimeSeriesSettings.MAX_HISTORY_INTERVALS + " intervals."; + return; + } + + List redundantNames = findRedundantNames(features); + if (redundantNames.size() > 0) { + issueType = ValidationIssueType.FEATURE_ATTRIBUTES; + errorMessage = redundantNames + " appear more than once. Feature name has to be unique"; return; } @@ -182,7 +238,7 @@ protected Config( this.filterQuery = filterQuery; this.interval = interval; this.windowDelay = windowDelay; - this.shingleSize = getShingleSize(shingleSize); + this.shingleSize = shingleGetter.getShingleSize(shingleSize); this.uiMetadata = uiMetadata; this.schemaVersion = schemaVersion; this.lastUpdateTime = lastUpdateTime; @@ -190,9 +246,16 @@ protected Config( this.user = user; this.customResultIndex = Strings.trimToNull(resultIndex); this.imputationOption = imputationOption; - this.imputer = createImputer(); this.issueType = null; this.errorMessage = null; + // If recencyEmphasis is null, use the default value from TimeSeriesSettings + this.recencyEmphasis = Optional.ofNullable(recencyEmphasis).orElse(TimeSeriesSettings.DEFAULT_RECENCY_EMPHASIS); + this.seasonIntervals = seasonIntervals; + this.historyIntervals = historyIntervals == null ? suggestHistory() : historyIntervals; + } + + public int suggestHistory() { + return TimeSeriesSettings.NUM_MIN_SAMPLES + this.shingleSize; } public Config(StreamInput input) throws IOException { @@ -226,7 +289,9 @@ public Config(StreamInput input) throws IOException { } else { this.imputationOption = null; } - this.imputer = createImputer(); + this.recencyEmphasis = input.readInt(); + this.seasonIntervals = input.readInt(); + this.historyIntervals = input.readInt(); } /* @@ -236,9 +301,7 @@ public Config(StreamInput input) throws IOException { * "Implicit super constructor Config() is undefined. * Must explicitly invoke another constructor". */ - public Config() { - this.imputer = null; - } + public Config() {} @Override public void writeTo(StreamOutput output) throws IOException { @@ -275,23 +338,25 @@ public void writeTo(StreamOutput output) throws IOException { } else { output.writeBoolean(false); } - } - - /** - * If the given shingle size is null, return default; - * otherwise, return the given shingle size. - * - * @param customShingleSize Given shingle size - * @return Shingle size - */ - protected static Integer getShingleSize(Integer customShingleSize) { - return customShingleSize == null ? TimeSeriesSettings.DEFAULT_SHINGLE_SIZE : customShingleSize; + output.writeInt(recencyEmphasis); + output.writeInt(seasonIntervals); + output.writeInt(historyIntervals); } public boolean invalidShingleSizeRange(Integer shingleSizeToTest) { return shingleSizeToTest != null && (shingleSizeToTest < 1 || shingleSizeToTest > TimeSeriesSettings.MAX_SHINGLE_SIZE); } + public boolean invalidSeasonality(Integer seasonalityToTest) { + if (seasonalityToTest == null) { + return false; + } + // shingle size = suggested seasonality / 2 + // given seasonality, we can reuse shingle size verification + // cannot be smaller than 1 + return invalidShingleSizeRange(Math.max(1, seasonalityToTest / TimeSeriesSettings.SEASONALITY_TO_SHINGLE_RATIO)); + } + /** * * @return either ValidationAspect.FORECASTER or ValidationAspect.DETECTOR @@ -302,10 +367,12 @@ public boolean invalidShingleSizeRange(Integer shingleSizeToTest) { @Generated @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (o == null || getClass() != o.getClass()) + } + if (o == null || getClass() != o.getClass()) { return false; + } Config config = (Config) o; // a few fields not included: // 1)didn't include uiMetadata since toXContent/parse will produce a map of map @@ -324,7 +391,10 @@ public boolean equals(Object o) { && Objects.equal(categoryFields, config.categoryFields) && Objects.equal(user, config.user) && Objects.equal(customResultIndex, config.customResultIndex) - && Objects.equal(imputationOption, config.imputationOption); + && Objects.equal(imputationOption, config.imputationOption) + && Objects.equal(recencyEmphasis, config.recencyEmphasis) + && Objects.equal(seasonIntervals, config.seasonIntervals) + && Objects.equal(historyIntervals, config.historyIntervals); } @Generated @@ -345,7 +415,10 @@ public int hashCode() { schemaVersion, user, customResultIndex, - imputationOption + imputationOption, + recencyEmphasis, + seasonIntervals, + historyIntervals ); } @@ -353,14 +426,16 @@ public int hashCode() { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder .field(NAME_FIELD, name) - .field(DESCRIPTION_FIELD, description) + .field(DESCRIPTION_FIELD, Encode.forHtml(description)) .field(TIMEFIELD_FIELD, timeField) .field(INDICES_FIELD, indices.toArray()) .field(FILTER_QUERY_FIELD, filterQuery) .field(WINDOW_DELAY_FIELD, windowDelay) .field(SHINGLE_SIZE_FIELD, shingleSize) .field(CommonName.SCHEMA_VERSION_FIELD, schemaVersion) - .field(FEATURE_ATTRIBUTES_FIELD, featureAttributes.toArray()); + .field(FEATURE_ATTRIBUTES_FIELD, featureAttributes.toArray()) + .field(RECENCY_EMPHASIS_FIELD, recencyEmphasis) + .field(HISTORY_INTERVAL_FIELD, historyIntervals); if (uiMetadata != null && !uiMetadata.isEmpty()) { builder.field(UI_METADATA_FIELD, uiMetadata); @@ -380,6 +455,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (imputationOption != null) { builder.field(IMPUTATION_OPTION_FIELD, imputationOption); } + if (seasonIntervals != null) { + builder.field(SEASONALITY_FIELD, seasonIntervals); + } return builder; } @@ -505,6 +583,16 @@ public String validateCustomResultIndex(String resultIndex) { return null; } + public String validateDescription(String description) { + if (Strings.isEmpty(description)) { + return null; + } + if (description.length() > TimeSeriesSettings.MAX_DESCRIPTION_LENGTH) { + return CommonMessages.DESCRIPTION_LENGTH_TOO_LONG; + } + return null; + } + public static boolean isHC(List categoryFields) { return categoryFields != null && categoryFields.size() > 0; } @@ -513,46 +601,19 @@ public ImputationOption getImputationOption() { return imputationOption; } - public Imputer getImputer() { - if (imputer != null) { - return imputer; - } - imputer = createImputer(); - return imputer; - } - - protected Imputer createImputer() { - Imputer imputer = null; - - // default interpolator is using last known value - if (imputationOption == null) { - return previousImputer; - } - - switch (imputationOption.getMethod()) { - case ZERO: - imputer = zeroImputer; - break; - case FIXED_VALUES: - // we did validate default fill is not empty in the constructor - imputer = new FixedValueImputer(imputationOption.getDefaultFill().get()); - break; - case PREVIOUS: - imputer = previousImputer; - break; - case LINEAR: - if (imputationOption.isIntegerSentive()) { - imputer = linearImputerIntegerSensitive; - } else { - imputer = linearImputer; - } - break; - default: - logger.error("unsupported method: " + imputationOption.getMethod()); - imputer = new PreviousValueImputer(); - break; - } - return imputer; + /** + * Retrieves the transform decay value. + * + * This method implements an inverse relationship between the recency emphasis and the transform decay value, + * such that the transform decay is set to 1 / recency emphasis. For example, a transform decay of 0.02 + * implies a recency emphasis of 50 observations (1/0.02). + * + * The transform decay value is crucial in determining the rate at which older data loses its influence in the model. + * + * @return The current transform decay value, dictating the rate of exponential decay in the model. + */ + public Double getTimeDecay() { + return 1.0 / recencyEmphasis; } protected void checkAndThrowValidationErrors(ValidationAspect validationAspect) { @@ -572,4 +633,40 @@ public static Config parseConfig(Class configClass, XContentPa throw new IllegalArgumentException("Unsupported config type. Supported config types are [AnomalyDetector, Forecaster]"); } } + + public Integer getSeasonIntervals() { + return seasonIntervals; + } + + public Integer getRecencyEmphasis() { + return recencyEmphasis; + } + + public Integer getHistoryIntervals() { + return historyIntervals; + } + + /** + * Identifies redundant feature names. + * + * @param features the list of features to check + * @return a list of redundant feature names + */ + public static List findRedundantNames(List features) { + // Group features by name and count occurrences + Map nameCounts = features + .stream() + .map(Feature::getName) + .collect(Collectors.groupingBy(Function.identity(), Collectors.counting())); + + // Filter names that appear more than once and collect them into a list + List redundantNames = nameCounts + .entrySet() + .stream() + .filter(entry -> entry.getValue() > 1) + .map(Map.Entry::getKey) + .collect(Collectors.toList()); + + return redundantNames; + } } diff --git a/src/main/java/org/opensearch/timeseries/model/ConfigProfile.java b/src/main/java/org/opensearch/timeseries/model/ConfigProfile.java new file mode 100644 index 000000000..833f33765 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/model/ConfigProfile.java @@ -0,0 +1,453 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.model; + +import java.io.IOException; + +import org.apache.commons.lang.builder.EqualsBuilder; +import org.apache.commons.lang.builder.HashCodeBuilder; +import org.apache.commons.lang.builder.ToStringBuilder; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.TaskProfile; +import org.opensearch.timeseries.constant.CommonName; + +public abstract class ConfigProfile> + implements + Writeable, + ToXContentObject, + Mergeable { + + protected ConfigState state; + protected String error; + protected ModelProfileOnNode[] modelProfile; + protected int shingleSize; + protected String coordinatingNode; + protected long totalSizeInBytes; + protected InitProgressProfile initProgress; + protected Long totalEntities; + protected Long activeEntities; + protected TaskProfileType taskProfile; + protected long modelCount; + protected String taskName; + + public ConfigProfile(StreamInput in) throws IOException { + if (in.readBoolean()) { + this.state = in.readEnum(ConfigState.class); + } + + this.error = in.readOptionalString(); + this.modelProfile = in.readOptionalArray(ModelProfileOnNode::new, ModelProfileOnNode[]::new); + this.shingleSize = in.readOptionalInt(); + this.coordinatingNode = in.readOptionalString(); + this.totalSizeInBytes = in.readOptionalLong(); + this.totalEntities = in.readOptionalLong(); + this.activeEntities = in.readOptionalLong(); + if (in.readBoolean()) { + this.initProgress = new InitProgressProfile(in); + } + if (in.readBoolean()) { + this.taskProfile = createTaskProfile(in); + } + this.modelCount = in.readVLong(); + } + + protected ConfigProfile() { + + } + + public XContentBuilder toXContent(XContentBuilder builder) throws IOException { + return toXContent(builder, ToXContent.EMPTY_PARAMS); + } + + public static abstract class Builder> { + protected ConfigState state = null; + protected String error = null; + protected ModelProfileOnNode[] modelProfile = null; + protected int shingleSize = -1; + protected String coordinatingNode = null; + protected long totalSizeInBytes = -1; + protected InitProgressProfile initProgress = null; + protected Long totalEntities; + protected Long activeEntities; + protected long modelCount = 0; + + public Builder() {} + + public Builder state(ConfigState state) { + this.state = state; + return this; + } + + public Builder error(String error) { + this.error = error; + return this; + } + + public Builder modelProfile(ModelProfileOnNode[] modelProfile) { + this.modelProfile = modelProfile; + return this; + } + + public Builder modelCount(long modelCount) { + this.modelCount = modelCount; + return this; + } + + public Builder shingleSize(int shingleSize) { + this.shingleSize = shingleSize; + return this; + } + + public Builder coordinatingNode(String coordinatingNode) { + this.coordinatingNode = coordinatingNode; + return this; + } + + public Builder totalSizeInBytes(long totalSizeInBytes) { + this.totalSizeInBytes = totalSizeInBytes; + return this; + } + + public Builder initProgress(InitProgressProfile initProgress) { + this.initProgress = initProgress; + return this; + } + + public Builder totalEntities(Long totalEntities) { + this.totalEntities = totalEntities; + return this; + } + + public Builder activeEntities(Long activeEntities) { + this.activeEntities = activeEntities; + return this; + } + + public abstract Builder taskProfile(TaskProfileType taskProfile); + + public abstract > ConfigProfileType build(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + if (state == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + out.writeEnum(state); + } + + out.writeOptionalString(error); + out.writeOptionalArray(modelProfile); + out.writeOptionalInt(shingleSize); + out.writeOptionalString(coordinatingNode); + out.writeOptionalLong(totalSizeInBytes); + out.writeOptionalLong(totalEntities); + out.writeOptionalLong(activeEntities); + if (initProgress == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + initProgress.writeTo(out); + } + if (taskProfile == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + taskProfile.writeTo(out); + } + out.writeVLong(modelCount); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + + if (state != null) { + xContentBuilder.field(CommonName.STATE, state); + } + if (error != null) { + xContentBuilder.field(CommonName.ERROR, error); + } + if (modelProfile != null && modelProfile.length > 0) { + xContentBuilder.startArray(CommonName.MODELS); + for (ModelProfileOnNode profile : modelProfile) { + profile.toXContent(xContentBuilder, params); + } + xContentBuilder.endArray(); + } + if (shingleSize != -1) { + xContentBuilder.field(CommonName.SHINGLE_SIZE, shingleSize); + } + if (coordinatingNode != null && !coordinatingNode.isEmpty()) { + xContentBuilder.field(CommonName.COORDINATING_NODE, coordinatingNode); + } + if (totalSizeInBytes != -1) { + xContentBuilder.field(CommonName.TOTAL_SIZE_IN_BYTES, totalSizeInBytes); + } + if (initProgress != null) { + xContentBuilder.field(CommonName.INIT_PROGRESS, initProgress); + } + if (totalEntities != null) { + xContentBuilder.field(CommonName.TOTAL_ENTITIES, totalEntities); + } + if (activeEntities != null) { + xContentBuilder.field(CommonName.ACTIVE_ENTITIES, activeEntities); + } + if (taskProfile != null) { + xContentBuilder.field(getTaskFieldName(), taskProfile); + } + if (modelCount > 0) { + xContentBuilder.field(CommonName.MODEL_COUNT, modelCount); + } + return xContentBuilder.endObject(); + } + + public ConfigState getState() { + return state; + } + + public void setState(ConfigState state) { + this.state = state; + } + + public String getError() { + return error; + } + + public void setError(String error) { + this.error = error; + } + + public ModelProfileOnNode[] getModelProfile() { + return modelProfile; + } + + public void setModelProfile(ModelProfileOnNode[] modelProfile) { + this.modelProfile = modelProfile; + } + + public int getShingleSize() { + return shingleSize; + } + + public void setShingleSize(int shingleSize) { + this.shingleSize = shingleSize; + } + + public String getCoordinatingNode() { + return coordinatingNode; + } + + public void setCoordinatingNode(String coordinatingNode) { + this.coordinatingNode = coordinatingNode; + } + + public long getTotalSizeInBytes() { + return totalSizeInBytes; + } + + public void setTotalSizeInBytes(long totalSizeInBytes) { + this.totalSizeInBytes = totalSizeInBytes; + } + + public InitProgressProfile getInitProgress() { + return initProgress; + } + + public void setInitProgress(InitProgressProfile initProgress) { + this.initProgress = initProgress; + } + + public Long getTotalEntities() { + return totalEntities; + } + + public void setTotalEntities(Long totalEntities) { + this.totalEntities = totalEntities; + } + + public Long getActiveEntities() { + return activeEntities; + } + + public void setActiveEntities(Long activeEntities) { + this.activeEntities = activeEntities; + } + + public TaskProfileType getTaskProfile() { + return taskProfile; + } + + public void setTaskProfile(TaskProfileType taskProfile) { + this.taskProfile = taskProfile; + } + + public long getModelCount() { + return modelCount; + } + + public void setModelCount(long modelCount) { + this.modelCount = modelCount; + } + + @Override + public void merge(Mergeable other) { + if (this == other || other == null || getClass() != other.getClass()) { + return; + } + ConfigProfile otherProfile = (ConfigProfile) other; + if (otherProfile.getState() != null) { + this.state = otherProfile.getState(); + } + if (otherProfile.getError() != null) { + this.error = otherProfile.getError(); + } + if (otherProfile.getCoordinatingNode() != null) { + this.coordinatingNode = otherProfile.getCoordinatingNode(); + } + if (otherProfile.getShingleSize() != -1) { + this.shingleSize = otherProfile.getShingleSize(); + } + if (otherProfile.getModelProfile() != null) { + this.modelProfile = otherProfile.getModelProfile(); + } + if (otherProfile.getTotalSizeInBytes() != -1) { + this.totalSizeInBytes = otherProfile.getTotalSizeInBytes(); + } + if (otherProfile.getInitProgress() != null) { + this.initProgress = otherProfile.getInitProgress(); + } + if (otherProfile.getTotalEntities() != null) { + this.totalEntities = otherProfile.getTotalEntities(); + } + if (otherProfile.getActiveEntities() != null) { + this.activeEntities = otherProfile.getActiveEntities(); + } + if (otherProfile.getTaskProfile() != null) { + this.taskProfile = otherProfile.getTaskProfile(); + } + if (otherProfile.getModelCount() > 0) { + this.modelCount = otherProfile.getModelCount(); + } + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + if (obj instanceof ConfigProfile) { + ConfigProfile other = (ConfigProfile) obj; + + EqualsBuilder equalsBuilder = new EqualsBuilder(); + if (state != null) { + equalsBuilder.append(state, other.state); + } + if (error != null) { + equalsBuilder.append(error, other.error); + } + if (modelProfile != null && modelProfile.length > 0) { + equalsBuilder.append(modelProfile, other.modelProfile); + } + if (shingleSize != -1) { + equalsBuilder.append(shingleSize, other.shingleSize); + } + if (coordinatingNode != null) { + equalsBuilder.append(coordinatingNode, other.coordinatingNode); + } + if (totalSizeInBytes != -1) { + equalsBuilder.append(totalSizeInBytes, other.totalSizeInBytes); + } + if (initProgress != null) { + equalsBuilder.append(initProgress, other.initProgress); + } + if (totalEntities != null) { + equalsBuilder.append(totalEntities, other.totalEntities); + } + if (activeEntities != null) { + equalsBuilder.append(activeEntities, other.activeEntities); + } + if (taskProfile != null) { + equalsBuilder.append(taskProfile, other.taskProfile); + } + if (modelCount > 0) { + equalsBuilder.append(modelCount, other.modelCount); + } + return equalsBuilder.isEquals(); + } + return false; + } + + @Override + public int hashCode() { + return new HashCodeBuilder() + .append(state) + .append(error) + .append(modelProfile) + .append(shingleSize) + .append(coordinatingNode) + .append(totalSizeInBytes) + .append(initProgress) + .append(totalEntities) + .append(activeEntities) + .append(taskProfile) + .append(modelCount) + .toHashCode(); + } + + @Override + public String toString() { + ToStringBuilder toStringBuilder = new ToStringBuilder(this); + + if (state != null) { + toStringBuilder.append(CommonName.STATE, state); + } + if (error != null) { + toStringBuilder.append(CommonName.ERROR, error); + } + if (modelProfile != null && modelProfile.length > 0) { + toStringBuilder.append(modelProfile); + } + if (shingleSize != -1) { + toStringBuilder.append(CommonName.SHINGLE_SIZE, shingleSize); + } + if (coordinatingNode != null) { + toStringBuilder.append(CommonName.COORDINATING_NODE, coordinatingNode); + } + if (totalSizeInBytes != -1) { + toStringBuilder.append(CommonName.TOTAL_SIZE_IN_BYTES, totalSizeInBytes); + } + if (initProgress != null) { + toStringBuilder.append(CommonName.INIT_PROGRESS, initProgress); + } + if (totalEntities != null) { + toStringBuilder.append(CommonName.TOTAL_ENTITIES, totalEntities); + } + if (activeEntities != null) { + toStringBuilder.append(CommonName.ACTIVE_ENTITIES, activeEntities); + } + if (taskProfile != null) { + toStringBuilder.append(getTaskFieldName(), taskProfile); + } + if (modelCount > 0) { + toStringBuilder.append(CommonName.MODEL_COUNT, modelCount); + } + return toStringBuilder.toString(); + } + + protected abstract TaskProfileType createTaskProfile(StreamInput in) throws IOException; + + protected abstract String getTaskFieldName(); +} diff --git a/src/main/java/org/opensearch/ad/model/DetectorState.java b/src/main/java/org/opensearch/timeseries/model/ConfigState.java similarity index 83% rename from src/main/java/org/opensearch/ad/model/DetectorState.java rename to src/main/java/org/opensearch/timeseries/model/ConfigState.java index a4959417b..4af52f2ee 100644 --- a/src/main/java/org/opensearch/ad/model/DetectorState.java +++ b/src/main/java/org/opensearch/timeseries/model/ConfigState.java @@ -9,9 +9,9 @@ * GitHub history for details. */ -package org.opensearch.ad.model; +package org.opensearch.timeseries.model; -public enum DetectorState { +public enum ConfigState { DISABLED, INIT, RUNNING diff --git a/src/main/java/org/opensearch/ad/model/DetectorValidationIssue.java b/src/main/java/org/opensearch/timeseries/model/ConfigValidationIssue.java similarity index 87% rename from src/main/java/org/opensearch/ad/model/DetectorValidationIssue.java rename to src/main/java/org/opensearch/timeseries/model/ConfigValidationIssue.java index 48586e7f8..895c070ec 100644 --- a/src/main/java/org/opensearch/ad/model/DetectorValidationIssue.java +++ b/src/main/java/org/opensearch/timeseries/model/ConfigValidationIssue.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.model; +package org.opensearch.timeseries.model; import java.io.IOException; import java.util.Map; @@ -19,14 +19,11 @@ import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.timeseries.model.IntervalTimeConfiguration; -import org.opensearch.timeseries.model.ValidationAspect; -import org.opensearch.timeseries.model.ValidationIssueType; import com.google.common.base.Objects; /** - * DetectorValidationIssue is a single validation issue found for detector. + * ConfigValidationIssue is a single validation issue found for config. * * For example, if detector's multiple features are using wrong type field or non existing field * the issue would be in `detector` aspect, not `model`; @@ -35,7 +32,7 @@ * subIssues are issues for each feature; * suggestion is how to fix the issue/subIssues found */ -public class DetectorValidationIssue implements ToXContentObject, Writeable { +public class ConfigValidationIssue implements ToXContentObject, Writeable { private static final String MESSAGE_FIELD = "message"; private static final String SUGGESTED_FIELD_NAME = "suggested_value"; private static final String SUB_ISSUES_FIELD_NAME = "sub_issues"; @@ -66,7 +63,7 @@ public IntervalTimeConfiguration getIntervalSuggestion() { return intervalSuggestion; } - public DetectorValidationIssue( + public ConfigValidationIssue( ValidationAspect aspect, ValidationIssueType type, String message, @@ -80,11 +77,11 @@ public DetectorValidationIssue( this.intervalSuggestion = intervalSuggestion; } - public DetectorValidationIssue(ValidationAspect aspect, ValidationIssueType type, String message) { + public ConfigValidationIssue(ValidationAspect aspect, ValidationIssueType type, String message) { this(aspect, type, message, null, null); } - public DetectorValidationIssue(StreamInput input) throws IOException { + public ConfigValidationIssue(StreamInput input) throws IOException { aspect = input.readEnum(ValidationAspect.class); type = input.readEnum(ValidationIssueType.class); message = input.readString(); @@ -139,7 +136,7 @@ public boolean equals(Object o) { return true; if (o == null || getClass() != o.getClass()) return false; - DetectorValidationIssue anotherIssue = (DetectorValidationIssue) o; + ConfigValidationIssue anotherIssue = (ConfigValidationIssue) o; return Objects.equal(getAspect(), anotherIssue.getAspect()) && Objects.equal(getMessage(), anotherIssue.getMessage()) && Objects.equal(getSubIssues(), anotherIssue.getSubIssues()) diff --git a/src/main/java/org/opensearch/timeseries/model/Entity.java b/src/main/java/org/opensearch/timeseries/model/Entity.java index f05f5dc2a..8fc6f77c9 100644 --- a/src/main/java/org/opensearch/timeseries/model/Entity.java +++ b/src/main/java/org/opensearch/timeseries/model/Entity.java @@ -24,6 +24,7 @@ import java.util.SortedMap; import java.util.TreeMap; +import org.apache.lucene.search.join.ScoreMode; import org.apache.lucene.util.SetOnce; import org.opensearch.common.Numbers; import org.opensearch.common.hash.MurmurHash3; @@ -38,6 +39,9 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.core.xcontent.XContentParser.Token; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.NestedQueryBuilder; +import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.timeseries.annotation.Generated; import org.opensearch.timeseries.constant.CommonName; @@ -339,9 +343,11 @@ public Map getAttributes() { } } * + * Used to query customer index + * *@return a list of term query builder */ - public List getTermQueryBuilders() { + public List getTermQueryForCustomerIndex() { List res = new ArrayList<>(); for (Map.Entry attribute : attributes.entrySet()) { res.add(new TermQueryBuilder(attribute.getKey(), attribute.getValue())); @@ -349,7 +355,7 @@ public List getTermQueryBuilders() { return res; } - public List getTermQueryBuilders(String pathPrefix) { + public List getTermQueryForCustomerIndex(String pathPrefix) { List res = new ArrayList<>(); for (Map.Entry attribute : attributes.entrySet()) { res.add(new TermQueryBuilder(pathPrefix + attribute.getKey(), attribute.getValue())); @@ -357,6 +363,62 @@ public List getTermQueryBuilders(String pathPrefix) { return res; } + /** + * Used to query result index. + * + * @return a list of term queries to locate documents containing the entity + */ + public List getTermQueryForResultIndex() { + String path = "entity"; + String entityName = path + ".name"; + String entityValue = path + ".value"; + + List res = new ArrayList<>(); + + for (Map.Entry attribute : attributes.entrySet()) { + /* + * each attribute pair corresponds to a nested query like + "nested": { + "query": { + "bool": { + "filter": [ + { + "term": { + "entity.name": { + "value": "turkey4", + "boost": 1 + } + } + }, + { + "term": { + "entity.value": { + "value": "Turkey", + "boost": 1 + } + } + } + ] + } + }, + "path": "entity", + "ignore_unmapped": false, + "score_mode": "none", + "boost": 1 + } + },*/ + BoolQueryBuilder nestedBoolQueryBuilder = new BoolQueryBuilder(); + + TermQueryBuilder entityNameFilterQuery = QueryBuilders.termQuery(entityName, attribute.getKey()); + nestedBoolQueryBuilder.filter(entityNameFilterQuery); + TermQueryBuilder entityValueFilterQuery = QueryBuilders.termQuery(entityValue, attribute.getValue()); + nestedBoolQueryBuilder.filter(entityValueFilterQuery); + + res.add(new NestedQueryBuilder(path, nestedBoolQueryBuilder, ScoreMode.None)); + } + return res; + } + /** * From json to Entity instance * @param entityValue json array consisting attributes diff --git a/src/main/java/org/opensearch/ad/model/EntityProfile.java b/src/main/java/org/opensearch/timeseries/model/EntityProfile.java similarity index 95% rename from src/main/java/org/opensearch/ad/model/EntityProfile.java rename to src/main/java/org/opensearch/timeseries/model/EntityProfile.java index 4f2306e96..9eee3c8a9 100644 --- a/src/main/java/org/opensearch/ad/model/EntityProfile.java +++ b/src/main/java/org/opensearch/timeseries/model/EntityProfile.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.model; +package org.opensearch.timeseries.model; import java.io.IOException; import java.util.Optional; @@ -17,12 +17,12 @@ import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; import org.apache.commons.lang.builder.ToStringBuilder; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.constant.CommonName; /** * Profile output for detector entity. @@ -168,13 +168,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(LAST_SAMPLE_TIMESTAMP, lastSampleTimestampMs); } if (initProgress != null) { - builder.field(ADCommonName.INIT_PROGRESS, initProgress); + builder.field(CommonName.INIT_PROGRESS, initProgress); } if (modelProfile != null) { - builder.field(ADCommonName.MODEL, modelProfile); + builder.field(CommonName.MODEL, modelProfile); } if (state != null && state != EntityState.UNKNOWN) { - builder.field(ADCommonName.STATE, state); + builder.field(CommonName.STATE, state); } builder.endObject(); return builder; @@ -213,13 +213,13 @@ public String toString() { builder.append(LAST_SAMPLE_TIMESTAMP, lastSampleTimestampMs); } if (initProgress != null) { - builder.append(ADCommonName.INIT_PROGRESS, initProgress); + builder.append(CommonName.INIT_PROGRESS, initProgress); } if (modelProfile != null) { - builder.append(ADCommonName.MODELS, modelProfile); + builder.append(CommonName.MODELS, modelProfile); } if (state != null && state != EntityState.UNKNOWN) { - builder.append(ADCommonName.STATE, state); + builder.append(CommonName.STATE, state); } return builder.toString(); } diff --git a/src/main/java/org/opensearch/ad/model/EntityProfileName.java b/src/main/java/org/opensearch/timeseries/model/EntityProfileName.java similarity index 75% rename from src/main/java/org/opensearch/ad/model/EntityProfileName.java rename to src/main/java/org/opensearch/timeseries/model/EntityProfileName.java index 84fd92987..c32636d5f 100644 --- a/src/main/java/org/opensearch/ad/model/EntityProfileName.java +++ b/src/main/java/org/opensearch/timeseries/model/EntityProfileName.java @@ -9,20 +9,20 @@ * GitHub history for details. */ -package org.opensearch.ad.model; +package org.opensearch.timeseries.model; import java.util.Collection; import java.util.Set; import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.timeseries.Name; +import org.opensearch.timeseries.constant.CommonName; public enum EntityProfileName implements Name { - INIT_PROGRESS(ADCommonName.INIT_PROGRESS), - ENTITY_INFO(ADCommonName.ENTITY_INFO), - STATE(ADCommonName.STATE), - MODELS(ADCommonName.MODELS); + INIT_PROGRESS(CommonName.INIT_PROGRESS), + ENTITY_INFO(CommonName.ENTITY_INFO), + STATE(CommonName.STATE), + MODELS(CommonName.MODELS); private String name; @@ -42,13 +42,13 @@ public String getName() { public static EntityProfileName getName(String name) { switch (name) { - case ADCommonName.INIT_PROGRESS: + case CommonName.INIT_PROGRESS: return INIT_PROGRESS; - case ADCommonName.ENTITY_INFO: + case CommonName.ENTITY_INFO: return ENTITY_INFO; - case ADCommonName.STATE: + case CommonName.STATE: return STATE; - case ADCommonName.MODELS: + case CommonName.MODELS: return MODELS; default: throw new IllegalArgumentException(ADCommonMessages.UNSUPPORTED_PROFILE_TYPE); diff --git a/src/main/java/org/opensearch/ad/model/EntityState.java b/src/main/java/org/opensearch/timeseries/model/EntityState.java similarity index 89% rename from src/main/java/org/opensearch/ad/model/EntityState.java rename to src/main/java/org/opensearch/timeseries/model/EntityState.java index 1e0d05d8e..36ab0fc0e 100644 --- a/src/main/java/org/opensearch/ad/model/EntityState.java +++ b/src/main/java/org/opensearch/timeseries/model/EntityState.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.model; +package org.opensearch.timeseries.model; public enum EntityState { UNKNOWN, diff --git a/src/main/java/org/opensearch/ad/model/ADEntityTaskProfile.java b/src/main/java/org/opensearch/timeseries/model/EntityTaskProfile.java similarity index 90% rename from src/main/java/org/opensearch/ad/model/ADEntityTaskProfile.java rename to src/main/java/org/opensearch/timeseries/model/EntityTaskProfile.java index 3d473d0e2..5971af22f 100644 --- a/src/main/java/org/opensearch/ad/model/ADEntityTaskProfile.java +++ b/src/main/java/org/opensearch/timeseries/model/EntityTaskProfile.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.model; +package org.opensearch.timeseries.model; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; @@ -22,12 +22,11 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.timeseries.model.Entity; /** - * HC detector's entity task profile. + * HC analysis's entity task profile. */ -public class ADEntityTaskProfile implements ToXContentObject, Writeable { +public class EntityTaskProfile implements ToXContentObject, Writeable { public static final String SHINGLE_SIZE_FIELD = "shingle_size"; public static final String RCF_TOTAL_UPDATES_FIELD = "rcf_total_updates"; @@ -37,7 +36,7 @@ public class ADEntityTaskProfile implements ToXContentObject, Writeable { public static final String NODE_ID_FIELD = "node_id"; public static final String ENTITY_FIELD = "entity"; public static final String TASK_ID_FIELD = "task_id"; - public static final String AD_TASK_TYPE_FIELD = "task_type"; + public static final String TASK_TYPE_FIELD = "task_type"; private Integer shingleSize; private Long rcfTotalUpdates; @@ -47,9 +46,9 @@ public class ADEntityTaskProfile implements ToXContentObject, Writeable { private String nodeId; private Entity entity; private String taskId; - private String adTaskType; + private String taskType; - public ADEntityTaskProfile( + public EntityTaskProfile( Integer shingleSize, Long rcfTotalUpdates, Boolean thresholdModelTrained, @@ -68,10 +67,10 @@ public ADEntityTaskProfile( this.nodeId = nodeId; this.entity = entity; this.taskId = taskId; - this.adTaskType = adTaskType; + this.taskType = adTaskType; } - public static ADEntityTaskProfile parse(XContentParser parser) throws IOException { + public static EntityTaskProfile parse(XContentParser parser) throws IOException { Integer shingleSize = null; Long rcfTotalUpdates = null; Boolean thresholdModelTrained = null; @@ -112,7 +111,7 @@ public static ADEntityTaskProfile parse(XContentParser parser) throws IOExceptio case TASK_ID_FIELD: taskId = parser.text(); break; - case AD_TASK_TYPE_FIELD: + case TASK_TYPE_FIELD: taskType = parser.text(); break; default: @@ -120,7 +119,7 @@ public static ADEntityTaskProfile parse(XContentParser parser) throws IOExceptio break; } } - return new ADEntityTaskProfile( + return new EntityTaskProfile( shingleSize, rcfTotalUpdates, thresholdModelTrained, @@ -133,7 +132,7 @@ public static ADEntityTaskProfile parse(XContentParser parser) throws IOExceptio ); } - public ADEntityTaskProfile(StreamInput input) throws IOException { + public EntityTaskProfile(StreamInput input) throws IOException { this.shingleSize = input.readOptionalInt(); this.rcfTotalUpdates = input.readOptionalLong(); this.thresholdModelTrained = input.readOptionalBoolean(); @@ -146,7 +145,7 @@ public ADEntityTaskProfile(StreamInput input) throws IOException { this.entity = null; } this.taskId = input.readOptionalString(); - this.adTaskType = input.readOptionalString(); + this.taskType = input.readOptionalString(); } @Override @@ -164,7 +163,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } out.writeOptionalString(taskId); - out.writeOptionalString(adTaskType); + out.writeOptionalString(taskType); } @Override @@ -194,8 +193,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (taskId != null) { xContentBuilder.field(TASK_ID_FIELD, taskId); } - if (adTaskType != null) { - xContentBuilder.field(AD_TASK_TYPE_FIELD, adTaskType); + if (taskType != null) { + xContentBuilder.field(TASK_TYPE_FIELD, taskType); } return xContentBuilder.endObject(); } @@ -265,11 +264,11 @@ public void setTaskId(String taskId) { } public String getAdTaskType() { - return adTaskType; + return taskType; } public void setAdTaskType(String adTaskType) { - this.adTaskType = adTaskType; + this.taskType = adTaskType; } @Override @@ -278,7 +277,7 @@ public boolean equals(Object o) { return true; if (o == null || getClass() != o.getClass()) return false; - ADEntityTaskProfile that = (ADEntityTaskProfile) o; + EntityTaskProfile that = (EntityTaskProfile) o; return Objects.equals(shingleSize, that.shingleSize) && Objects.equals(rcfTotalUpdates, that.rcfTotalUpdates) && Objects.equals(thresholdModelTrained, that.thresholdModelTrained) @@ -286,7 +285,7 @@ public boolean equals(Object o) { && Objects.equals(modelSizeInBytes, that.modelSizeInBytes) && Objects.equals(nodeId, that.nodeId) && Objects.equals(taskId, that.taskId) - && Objects.equals(adTaskType, that.adTaskType) + && Objects.equals(taskType, that.taskType) && Objects.equals(entity, that.entity); } @@ -302,7 +301,7 @@ public int hashCode() { nodeId, entity, taskId, - adTaskType + taskType ); } } diff --git a/src/main/java/org/opensearch/timeseries/model/IndexableResult.java b/src/main/java/org/opensearch/timeseries/model/IndexableResult.java index 7ccc58b59..0393122bd 100644 --- a/src/main/java/org/opensearch/timeseries/model/IndexableResult.java +++ b/src/main/java/org/opensearch/timeseries/model/IndexableResult.java @@ -39,17 +39,6 @@ public abstract class IndexableResult implements Writeable, ToXContentObject { protected final Optional optionalEntity; protected User user; protected final Integer schemaVersion; - /* - * model id for easy aggregations of entities. The front end needs to query - * for entities ordered by the descending/ascending order of feature values. - * After supporting multi-category fields, it is hard to write such queries - * since the entity information is stored in a nested object array. - * Also, the front end has all code/queries/ helper functions in place to - * rely on a single key per entity combo. Adding model id to forecast result - * to help the transition to multi-categorical field less painful. - */ - protected final String modelId; - protected final String entityId; protected final String taskId; public IndexableResult( @@ -63,7 +52,6 @@ public IndexableResult( Optional entity, User user, Integer schemaVersion, - String modelId, String taskId ) { this.configId = configId; @@ -76,9 +64,7 @@ public IndexableResult( this.optionalEntity = entity; this.user = user; this.schemaVersion = schemaVersion; - this.modelId = modelId; this.taskId = taskId; - this.entityId = getEntityId(entity, configId); } public IndexableResult(StreamInput input) throws IOException { @@ -104,9 +90,7 @@ public IndexableResult(StreamInput input) throws IOException { user = null; } this.schemaVersion = input.readInt(); - this.modelId = input.readOptionalString(); this.taskId = input.readOptionalString(); - this.entityId = input.readOptionalString(); } @Override @@ -134,9 +118,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); // user does not exist } out.writeInt(schemaVersion); - out.writeOptionalString(modelId); out.writeOptionalString(taskId); - out.writeOptionalString(entityId); } public String getConfigId() { @@ -171,18 +153,10 @@ public Optional getEntity() { return optionalEntity; } - public String getModelId() { - return modelId; - } - public String getTaskId() { return taskId; } - public String getEntityId() { - return entityId; - } - /** * entityId equals to model Id. It is hard to explain to users what * modelId is. entityId is more user friendly. @@ -209,9 +183,7 @@ public boolean equals(Object o) { && Objects.equal(executionStartTime, that.executionStartTime) && Objects.equal(executionEndTime, that.executionEndTime) && Objects.equal(error, that.error) - && Objects.equal(optionalEntity, that.optionalEntity) - && Objects.equal(modelId, that.modelId) - && Objects.equal(entityId, that.entityId); + && Objects.equal(optionalEntity, that.optionalEntity); } @Generated @@ -227,9 +199,7 @@ public int hashCode() { executionStartTime, executionEndTime, error, - optionalEntity, - modelId, - entityId + optionalEntity ); } @@ -245,8 +215,6 @@ public String toString() { .append("executionEndTime", executionEndTime) .append("error", error) .append("entity", optionalEntity) - .append("modelId", modelId) - .append("entityId", entityId) .toString(); } diff --git a/src/main/java/org/opensearch/ad/model/InitProgressProfile.java b/src/main/java/org/opensearch/timeseries/model/InitProgressProfile.java similarity index 99% rename from src/main/java/org/opensearch/ad/model/InitProgressProfile.java rename to src/main/java/org/opensearch/timeseries/model/InitProgressProfile.java index 4147f8ef4..1b2a83f4c 100644 --- a/src/main/java/org/opensearch/ad/model/InitProgressProfile.java +++ b/src/main/java/org/opensearch/timeseries/model/InitProgressProfile.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.model; +package org.opensearch.timeseries.model; import java.io.IOException; diff --git a/src/main/java/org/opensearch/timeseries/model/IntervalTimeConfiguration.java b/src/main/java/org/opensearch/timeseries/model/IntervalTimeConfiguration.java index eaa6301df..22c0fb416 100644 --- a/src/main/java/org/opensearch/timeseries/model/IntervalTimeConfiguration.java +++ b/src/main/java/org/opensearch/timeseries/model/IntervalTimeConfiguration.java @@ -103,6 +103,12 @@ public int hashCode() { return Objects.hashCode(interval, unit); } + @Generated + @Override + public String toString() { + return "IntervalTimeConfiguration [interval=" + interval + ", unit=" + unit + "]"; + } + public long getInterval() { return interval; } @@ -119,4 +125,13 @@ public ChronoUnit getUnit() { public Duration toDuration() { return Duration.of(interval, unit); } + + /** + * + * @param other interval to compare + * @return current interval is larger than or equal to the given interval + */ + public boolean gte(IntervalTimeConfiguration other) { + return toDuration().compareTo(other.toDuration()) >= 0; + } } diff --git a/src/main/java/org/opensearch/timeseries/model/Job.java b/src/main/java/org/opensearch/timeseries/model/Job.java index 958152e2c..d258279e7 100644 --- a/src/main/java/org/opensearch/timeseries/model/Job.java +++ b/src/main/java/org/opensearch/timeseries/model/Job.java @@ -12,13 +12,13 @@ package org.opensearch.timeseries.model; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.timeseries.settings.TimeSeriesSettings.DEFAULT_JOB_LOC_DURATION_SECONDS; import java.io.IOException; import java.time.Instant; import org.opensearch.commons.authuser.User; import org.opensearch.core.ParseField; +import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -31,6 +31,8 @@ import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; import org.opensearch.jobscheduler.spi.schedule.Schedule; import org.opensearch.jobscheduler.spi.schedule.ScheduleParser; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.util.ParseUtils; import com.google.common.base.Objects; @@ -44,7 +46,7 @@ enum ScheduleType { INTERVAL } - public static final String PARSE_FIELD_NAME = "AnomalyDetectorJob"; + public static final String PARSE_FIELD_NAME = "TimeSeriesJob"; public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( Job.class, new ParseField(PARSE_FIELD_NAME), @@ -62,7 +64,9 @@ enum ScheduleType { public static final String DISABLED_TIME_FIELD = "disabled_time"; public static final String USER_FIELD = "user"; private static final String RESULT_INDEX_FIELD = "result_index"; + private static final String TYPE_FIELD = "type"; + // name is config id private final String name; private final Schedule schedule; private final TimeConfiguration windowDelay; @@ -73,6 +77,7 @@ enum ScheduleType { private final Long lockDurationSeconds; private final User user; private String resultIndex; + private AnalysisType analysisType; public Job( String name, @@ -84,7 +89,8 @@ public Job( Instant lastUpdateTime, Long lockDurationSeconds, User user, - String resultIndex + String resultIndex, + AnalysisType type ) { this.name = name; this.schedule = schedule; @@ -96,6 +102,7 @@ public Job( this.lockDurationSeconds = lockDurationSeconds; this.user = user; this.resultIndex = resultIndex; + this.analysisType = type; } public Job(StreamInput input) throws IOException { @@ -117,6 +124,8 @@ public Job(StreamInput input) throws IOException { user = null; } resultIndex = input.readOptionalString(); + String typeStr = input.readOptionalString(); + this.analysisType = input.readEnum(AnalysisType.class); } @Override @@ -129,7 +138,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws .field(IS_ENABLED_FIELD, isEnabled) .field(ENABLED_TIME_FIELD, enabledTime.toEpochMilli()) .field(LAST_UPDATE_TIME_FIELD, lastUpdateTime.toEpochMilli()) - .field(LOCK_DURATION_SECONDS, lockDurationSeconds); + .field(LOCK_DURATION_SECONDS, lockDurationSeconds) + .field(TYPE_FIELD, analysisType); if (disabledTime != null) { xContentBuilder.field(DISABLED_TIME_FIELD, disabledTime.toEpochMilli()); } @@ -164,6 +174,7 @@ public void writeTo(StreamOutput output) throws IOException { output.writeBoolean(false); // user does not exist } output.writeOptionalString(resultIndex); + output.writeEnum(analysisType); } public static Job parse(XContentParser parser) throws IOException { @@ -175,9 +186,10 @@ public static Job parse(XContentParser parser) throws IOException { Instant enabledTime = null; Instant disabledTime = null; Instant lastUpdateTime = null; - Long lockDurationSeconds = DEFAULT_JOB_LOC_DURATION_SECONDS; + Long lockDurationSeconds = TimeSeriesSettings.DEFAULT_JOB_LOC_DURATION_SECONDS; User user = null; String resultIndex = null; + String analysisType = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -215,6 +227,9 @@ public static Job parse(XContentParser parser) throws IOException { case RESULT_INDEX_FIELD: resultIndex = parser.text(); break; + case TYPE_FIELD: + analysisType = parser.text(); + break; default: parser.skipChildren(); break; @@ -230,7 +245,10 @@ public static Job parse(XContentParser parser) throws IOException { lastUpdateTime, lockDurationSeconds, user, - resultIndex + resultIndex, + (Strings.isEmpty(analysisType) || AnalysisType.AD == AnalysisType.valueOf(analysisType)) + ? AnalysisType.AD + : AnalysisType.FORECAST ); } @@ -248,12 +266,13 @@ public boolean equals(Object o) { && Objects.equal(getDisabledTime(), that.getDisabledTime()) && Objects.equal(getLastUpdateTime(), that.getLastUpdateTime()) && Objects.equal(getLockDurationSeconds(), that.getLockDurationSeconds()) - && Objects.equal(getCustomResultIndex(), that.getCustomResultIndex()); + && Objects.equal(getCustomResultIndex(), that.getCustomResultIndex()) + && Objects.equal(getAnalysisType(), that.getAnalysisType()); } @Override public int hashCode() { - return Objects.hashCode(name, schedule, isEnabled, enabledTime, lastUpdateTime); + return Objects.hashCode(name, schedule, isEnabled, enabledTime, lastUpdateTime, analysisType); } @Override @@ -301,4 +320,8 @@ public User getUser() { public String getCustomResultIndex() { return resultIndex; } + + public AnalysisType getAnalysisType() { + return analysisType; + } } diff --git a/src/main/java/org/opensearch/ad/model/Mergeable.java b/src/main/java/org/opensearch/timeseries/model/Mergeable.java similarity index 89% rename from src/main/java/org/opensearch/ad/model/Mergeable.java rename to src/main/java/org/opensearch/timeseries/model/Mergeable.java index 980dad1a4..bdb9ef49e 100644 --- a/src/main/java/org/opensearch/ad/model/Mergeable.java +++ b/src/main/java/org/opensearch/timeseries/model/Mergeable.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.model; +package org.opensearch.timeseries.model; public interface Mergeable { void merge(Mergeable other); diff --git a/src/main/java/org/opensearch/timeseries/model/MergeableList.java b/src/main/java/org/opensearch/timeseries/model/MergeableList.java index fd9f26e84..188c0fa44 100644 --- a/src/main/java/org/opensearch/timeseries/model/MergeableList.java +++ b/src/main/java/org/opensearch/timeseries/model/MergeableList.java @@ -13,8 +13,6 @@ import java.util.List; -import org.opensearch.ad.model.Mergeable; - public class MergeableList implements Mergeable { private final List elements; diff --git a/src/main/java/org/opensearch/ad/model/ModelProfile.java b/src/main/java/org/opensearch/timeseries/model/ModelProfile.java similarity index 97% rename from src/main/java/org/opensearch/ad/model/ModelProfile.java rename to src/main/java/org/opensearch/timeseries/model/ModelProfile.java index 1d6d0ce85..63fdbcd02 100644 --- a/src/main/java/org/opensearch/ad/model/ModelProfile.java +++ b/src/main/java/org/opensearch/timeseries/model/ModelProfile.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.model; +package org.opensearch.timeseries.model; import java.io.IOException; @@ -22,7 +22,6 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.timeseries.constant.CommonName; -import org.opensearch.timeseries.model.Entity; /** * Used to show model information in profile API diff --git a/src/main/java/org/opensearch/ad/model/ModelProfileOnNode.java b/src/main/java/org/opensearch/timeseries/model/ModelProfileOnNode.java similarity index 95% rename from src/main/java/org/opensearch/ad/model/ModelProfileOnNode.java rename to src/main/java/org/opensearch/timeseries/model/ModelProfileOnNode.java index 1e45bcc7a..ed61342c8 100644 --- a/src/main/java/org/opensearch/ad/model/ModelProfileOnNode.java +++ b/src/main/java/org/opensearch/timeseries/model/ModelProfileOnNode.java @@ -9,19 +9,19 @@ * GitHub history for details. */ -package org.opensearch.ad.model; +package org.opensearch.timeseries.model; import java.io.IOException; import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; import org.apache.commons.lang.builder.ToStringBuilder; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.constant.CommonName; public class ModelProfileOnNode implements Writeable, ToXContent { // field name in toXContent @@ -98,7 +98,7 @@ public int hashCode() { @Override public String toString() { ToStringBuilder builder = new ToStringBuilder(this); - builder.append(ADCommonName.MODEL, modelProfile); + builder.append(CommonName.MODEL, modelProfile); builder.append(NODE_ID, nodeId); return builder.toString(); } diff --git a/src/main/java/org/opensearch/timeseries/model/ProfileName.java b/src/main/java/org/opensearch/timeseries/model/ProfileName.java new file mode 100644 index 000000000..dd3f1bac9 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/model/ProfileName.java @@ -0,0 +1,86 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.model; + +import java.util.Collection; +import java.util.Set; + +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.timeseries.Name; +import org.opensearch.timeseries.constant.CommonName; + +public enum ProfileName implements Name { + STATE(CommonName.STATE), + ERROR(CommonName.ERROR), + COORDINATING_NODE(CommonName.COORDINATING_NODE), + SHINGLE_SIZE(CommonName.SHINGLE_SIZE), + TOTAL_SIZE_IN_BYTES(CommonName.TOTAL_SIZE_IN_BYTES), + MODELS(CommonName.MODELS), + INIT_PROGRESS(CommonName.INIT_PROGRESS), + TOTAL_ENTITIES(CommonName.TOTAL_ENTITIES), + ACTIVE_ENTITIES(CommonName.ACTIVE_ENTITIES), + // AD only + AD_TASK(ADCommonName.AD_TASK), + // Forecast only + FORECAST_TASK(ForecastCommonName.FORECAST_TASK); + + private String name; + + ProfileName(String name) { + this.name = name; + } + + /** + * Get profile name + * + * @return name + */ + @Override + public String getName() { + return name; + } + + public static ProfileName getName(String name) { + switch (name) { + case CommonName.STATE: + return STATE; + case CommonName.ERROR: + return ERROR; + case CommonName.COORDINATING_NODE: + return COORDINATING_NODE; + case CommonName.SHINGLE_SIZE: + return SHINGLE_SIZE; + case CommonName.TOTAL_SIZE_IN_BYTES: + return TOTAL_SIZE_IN_BYTES; + case CommonName.MODELS: + return MODELS; + case CommonName.INIT_PROGRESS: + return INIT_PROGRESS; + case CommonName.TOTAL_ENTITIES: + return TOTAL_ENTITIES; + case CommonName.ACTIVE_ENTITIES: + return ACTIVE_ENTITIES; + case ADCommonName.AD_TASK: + return AD_TASK; + case ForecastCommonName.FORECAST_TASK: + return FORECAST_TASK; + default: + throw new IllegalArgumentException(ADCommonMessages.UNSUPPORTED_PROFILE_TYPE); + } + } + + public static Set getNames(Collection names) { + return Name.getNameFromCollection(names, ProfileName::getName); + } +} diff --git a/src/main/java/org/opensearch/timeseries/model/ShingleGetter.java b/src/main/java/org/opensearch/timeseries/model/ShingleGetter.java new file mode 100644 index 000000000..bb674aa9d --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/model/ShingleGetter.java @@ -0,0 +1,10 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.model; + +public interface ShingleGetter { + Integer getShingleSize(Integer customShingleSize); +} diff --git a/src/main/java/org/opensearch/timeseries/model/TaskState.java b/src/main/java/org/opensearch/timeseries/model/TaskState.java index 2b5c4240e..6f845d49a 100644 --- a/src/main/java/org/opensearch/timeseries/model/TaskState.java +++ b/src/main/java/org/opensearch/timeseries/model/TaskState.java @@ -50,13 +50,32 @@ * */ public enum TaskState { - CREATED, - INIT, - RUNNING, - FAILED, - STOPPED, - FINISHED; + // AD task state + CREATED("Created"), + INIT("Init"), + RUNNING("Running"), + FAILED("Failed"), + STOPPED("Stopped"), + FINISHED("Finished"), + + // Forecast task state + INIT_TEST("Initializing test"), + TEST_COMPLETE("Test complete"), + INIT_TEST_FAILED("Initializing test failed"), + INACTIVE("Inactive"); + + private final String description; + + // Constructor + TaskState(String description) { + this.description = description; + } + + // Getter + public String getDescription() { + return description; + } public static List NOT_ENDED_STATES = ImmutableList - .of(TaskState.CREATED.name(), TaskState.INIT.name(), TaskState.RUNNING.name()); + .of(TaskState.CREATED.name(), TaskState.INIT.name(), TaskState.RUNNING.name(), INIT_TEST.name()); } diff --git a/src/main/java/org/opensearch/timeseries/model/TimeSeriesTask.java b/src/main/java/org/opensearch/timeseries/model/TimeSeriesTask.java index fd57de7cd..0384f1356 100644 --- a/src/main/java/org/opensearch/timeseries/model/TimeSeriesTask.java +++ b/src/main/java/org/opensearch/timeseries/model/TimeSeriesTask.java @@ -41,6 +41,8 @@ public abstract class TimeSeriesTask implements ToXContentObject, Writeable { public static final String ESTIMATED_MINUTES_LEFT_FIELD = "estimated_minutes_left"; public static final String USER_FIELD = "user"; public static final String HISTORICAL_TASK_PREFIX = "HISTORICAL"; + public static final String RUN_ONCE_TASK_PREFIX = "RUN_ONCE"; + public static final String REAL_TIME_TASK_PREFIX = "REALTIME"; protected String configId = null; protected String taskId = null; @@ -200,6 +202,14 @@ public boolean isHistoricalTask() { return taskType.startsWith(TimeSeriesTask.HISTORICAL_TASK_PREFIX); } + public boolean isRunOnceTask() { + return taskType.startsWith(TimeSeriesTask.RUN_ONCE_TASK_PREFIX); + } + + public boolean isRealTimeTask() { + return taskType.startsWith(TimeSeriesTask.REAL_TIME_TASK_PREFIX); + } + /** * Get config level task id. If a task has no parent task, the task is config level task. * @return config level task id @@ -440,7 +450,7 @@ public int hashCode() { ); } - public abstract boolean isEntityTask(); + public abstract boolean isHistoricalEntityTask(); public String getEntityModelId() { return entity == null ? null : entity.getModelId(configId).orElse(null); diff --git a/src/main/java/org/opensearch/timeseries/model/ValidationIssueType.java b/src/main/java/org/opensearch/timeseries/model/ValidationIssueType.java index 01913a9c6..bd4a86cee 100644 --- a/src/main/java/org/opensearch/timeseries/model/ValidationIssueType.java +++ b/src/main/java/org/opensearch/timeseries/model/ValidationIssueType.java @@ -13,12 +13,14 @@ import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.transport.SearchTopForecastResultRequest; import org.opensearch.timeseries.Name; public enum ValidationIssueType implements Name { NAME(Config.NAME_FIELD), TIMEFIELD_FIELD(Config.TIMEFIELD_FIELD), SHINGLE_SIZE_FIELD(Config.SHINGLE_SIZE_FIELD), + SUGGESTED_SEASONALITY_FIELD(Config.SEASONALITY_FIELD), INDICES(Config.INDICES_FIELD), FEATURE_ATTRIBUTES(Config.FEATURE_ATTRIBUTES_FIELD), CATEGORY(Config.CATEGORY_FIELD), @@ -32,7 +34,11 @@ public enum ValidationIssueType implements Name { IMPUTATION(Config.IMPUTATION_OPTION_FIELD), DETECTION_INTERVAL(AnomalyDetector.DETECTION_INTERVAL_FIELD), FORECAST_INTERVAL(Forecaster.FORECAST_INTERVAL_FIELD), - HORIZON_SIZE(Forecaster.HORIZON_FIELD); + HORIZON_SIZE(Forecaster.HORIZON_FIELD), + SUBAGGREGATION(SearchTopForecastResultRequest.SUBAGGREGATIONS_FIELD), + RECENCY_EMPHASIS(Config.RECENCY_EMPHASIS_FIELD), + DESCRIPTION(Config.DESCRIPTION_FIELD), + HISTORY(Config.HISTORY_INTERVAL_FIELD); private String name; diff --git a/src/main/java/org/opensearch/ad/ratelimit/BatchWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/BatchWorker.java similarity index 91% rename from src/main/java/org/opensearch/ad/ratelimit/BatchWorker.java rename to src/main/java/org/opensearch/timeseries/ratelimit/BatchWorker.java index 7ba8b4383..41f62e243 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/BatchWorker.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/BatchWorker.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; import java.time.Clock; import java.time.Duration; @@ -24,8 +24,8 @@ import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.NodeStateManager; -import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.breaker.CircuitBreakerService; /** @@ -46,8 +46,9 @@ public BatchWorker( Setting maxHeapPercentForQueueSetting, ClusterService clusterService, Random random, - CircuitBreakerService adCircuitBreakerService, + CircuitBreakerService circuitBreakerService, ThreadPool threadPool, + String threadPoolName, Settings settings, float maxQueuedTaskRatio, Clock clock, @@ -58,7 +59,8 @@ public BatchWorker( Duration executionTtl, Setting batchSizeSetting, Duration stateTtl, - NodeStateManager nodeStateManager + NodeStateManager timeSeriesNodeStateManager, + AnalysisType context ) { super( queueName, @@ -67,8 +69,9 @@ public BatchWorker( maxHeapPercentForQueueSetting, clusterService, random, - adCircuitBreakerService, + circuitBreakerService, threadPool, + threadPoolName, settings, maxQueuedTaskRatio, clock, @@ -78,7 +81,8 @@ public BatchWorker( concurrencySetting, executionTtl, stateTtl, - nodeStateManager + timeSeriesNodeStateManager, + context ); this.batchSize = batchSizeSetting.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(batchSizeSetting, it -> batchSize = it); @@ -111,7 +115,7 @@ protected void execute(Runnable afterProcessCallback, Runnable emptyQueueCallbac ThreadedActionListener listener = new ThreadedActionListener<>( LOG, threadPool, - TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, + threadPoolName, getResponseListener(toProcess, batchRequest), false ); diff --git a/src/main/java/org/opensearch/ad/ratelimit/CheckPointMaintainRequestAdapter.java b/src/main/java/org/opensearch/timeseries/ratelimit/CheckPointMaintainRequestAdapter.java similarity index 69% rename from src/main/java/org/opensearch/ad/ratelimit/CheckPointMaintainRequestAdapter.java rename to src/main/java/org/opensearch/timeseries/ratelimit/CheckPointMaintainRequestAdapter.java index 91382a4b5..efbd27b84 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/CheckPointMaintainRequestAdapter.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/CheckPointMaintainRequestAdapter.java @@ -9,11 +9,10 @@ * GitHub history for details. */ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; import java.time.Clock; import java.time.Duration; -import java.time.Instant; import java.util.Map; import java.util.Optional; @@ -21,35 +20,42 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.opensearch.action.update.UpdateRequest; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelState; -import org.opensearch.ad.util.DateUtils; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Provider; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.common.Strings; +import org.opensearch.timeseries.caching.TimeSeriesCache; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.util.DateUtils; -public class CheckPointMaintainRequestAdapter { +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * Convert from ModelRequest to CheckpointWriteRequest + * + */ +public class CheckPointMaintainRequestAdapter & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CacheType extends TimeSeriesCache> { private static final Logger LOG = LogManager.getLogger(CheckPointMaintainRequestAdapter.class); - private CacheProvider cache; - private CheckpointDao checkpointDao; + private CheckpointDaoType checkpointDao; private String indexName; private Duration checkpointInterval; private Clock clock; + private Provider cache; public CheckPointMaintainRequestAdapter( - CacheProvider cache, - CheckpointDao checkpointDao, + CheckpointDaoType checkpointDao, String indexName, Setting checkpointIntervalSetting, Clock clock, ClusterService clusterService, - Settings settings + Settings settings, + Provider cache ) { - this.cache = cache; this.checkpointDao = checkpointDao; this.indexName = indexName; @@ -59,17 +65,17 @@ public CheckPointMaintainRequestAdapter( .addSettingsUpdateConsumer(checkpointIntervalSetting, it -> this.checkpointInterval = DateUtils.toDuration(it)); this.clock = clock; + this.cache = cache; } public Optional convert(CheckpointMaintainRequest request) { - String detectorId = request.getId(); - String modelId = request.getEntityModelId(); + String configId = request.getConfigId(); + String modelId = request.getModelId(); - Optional> stateToMaintain = cache.get().getForMaintainance(detectorId, modelId); - if (!stateToMaintain.isEmpty()) { - ModelState state = stateToMaintain.get(); - Instant instant = state.getLastCheckpointTime(); - if (!checkpointDao.shouldSave(instant, false, checkpointInterval, clock)) { + Optional> stateToMaintain = cache.get().getForMaintainance(configId, modelId); + if (stateToMaintain.isPresent()) { + ModelState state = stateToMaintain.get(); + if (!checkpointDao.shouldSave(state, false, checkpointInterval, clock)) { return Optional.empty(); } @@ -85,7 +91,7 @@ public Optional convert(CheckpointMaintainRequest reques .of( new CheckpointWriteRequest( request.getExpirationEpochMs(), - detectorId, + configId, request.getPriority(), // If the document does not already exist, the contents of the upsert element // are inserted as a new document. diff --git a/src/main/java/org/opensearch/ad/ratelimit/CheckpointMaintainRequest.java b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointMaintainRequest.java similarity index 58% rename from src/main/java/org/opensearch/ad/ratelimit/CheckpointMaintainRequest.java rename to src/main/java/org/opensearch/timeseries/ratelimit/CheckpointMaintainRequest.java index 28fdfcc91..479965240 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/CheckpointMaintainRequest.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointMaintainRequest.java @@ -9,17 +9,17 @@ * GitHub history for details. */ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; public class CheckpointMaintainRequest extends QueuedRequest { - private String entityModelId; + private String modelId; - public CheckpointMaintainRequest(long expirationEpochMs, String detectorId, RequestPriority priority, String entityModelId) { - super(expirationEpochMs, detectorId, priority); - this.entityModelId = entityModelId; + public CheckpointMaintainRequest(long expirationEpochMs, String configId, RequestPriority priority, String entityModelId) { + super(expirationEpochMs, configId, priority); + this.modelId = entityModelId; } - public String getEntityModelId() { - return entityModelId; + public String getModelId() { + return modelId; } } diff --git a/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointMaintainWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointMaintainWorker.java new file mode 100644 index 000000000..ba28043c9 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointMaintainWorker.java @@ -0,0 +1,91 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.ratelimit; + +import java.time.Clock; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.Random; +import java.util.function.Function; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; + +public abstract class CheckpointMaintainWorker extends ScheduledWorker { + + private Function> converter; + + public CheckpointMaintainWorker( + String workerName, + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + String threadPoolName, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + RateLimitedRequestWorker targetQueue, + Duration stateTtl, + NodeStateManager nodeStateManager, + Function> converter, + AnalysisType context + ) { + super( + workerName, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + threadPoolName, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + targetQueue, + stateTtl, + nodeStateManager, + context + ); + this.converter = converter; + } + + @Override + protected List transformRequests(List requests) { + List allRequests = new ArrayList<>(); + for (CheckpointMaintainRequest request : requests) { + Optional converted = converter.apply(request); + if (!converted.isEmpty()) { + allRequests.add(converted.get()); + } + } + return allRequests; + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/CheckpointReadWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointReadWorker.java similarity index 59% rename from src/main/java/org/opensearch/ad/ratelimit/CheckpointReadWorker.java rename to src/main/java/org/opensearch/timeseries/ratelimit/CheckpointReadWorker.java index d4f1f99af..d6c13cf19 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/CheckpointReadWorker.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointReadWorker.java @@ -9,10 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.ratelimit; - -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_CONCURRENCY; +package org.opensearch.timeseries.ratelimit; import java.time.Clock; import java.time.Duration; @@ -21,7 +18,6 @@ import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Map.Entry; import java.util.Optional; import java.util.Random; import java.util.Set; @@ -32,19 +28,8 @@ import org.opensearch.action.get.MultiGetItemResponse; import org.opensearch.action.get.MultiGetRequest; import org.opensearch.action.get.MultiGetResponse; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.indices.ADIndex; -import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.ml.ModelState; -import org.opensearch.ad.ml.ThresholdingResult; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.ad.stats.ADStats; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Provider; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; @@ -53,39 +38,44 @@ import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.caching.TimeSeriesCache; import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.IntermediateResult; +import org.opensearch.timeseries.ml.ModelColdStart; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.IndexableResult; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.Stats; import org.opensearch.timeseries.util.ExceptionUtil; -import org.opensearch.timeseries.util.ParseUtils; - -/** - * a queue for loading model checkpoint. The read is a multi-get query. Possible results are: - * a). If a checkpoint is not found, we forward that request to the cold start queue. - * b). When a request gets errors, the queue does not change its expiry time and puts - * that request to the end of the queue and automatically retries them before they expire. - * c) When a checkpoint is found, we load that point to memory and score the input - * data point and save the result if a complete model exists. Otherwise, we enqueue - * the sample. If we can host that model in memory (e.g., there is enough memory), - * we put the loaded model to cache. Otherwise (e.g., a cold entity), we write the - * updated checkpoint back to disk. - * - */ -public class CheckpointReadWorker extends BatchWorker { + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public abstract class CheckpointReadWorker, IndexType extends Enum & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointType extends CheckpointDao, CheckpointWriteWorkerType extends CheckpointWriteWorker, ColdStarterType extends ModelColdStart, ModelManagerType extends ModelManager, CacheType extends TimeSeriesCache, SaveResultStrategyType extends SaveResultStrategy, ColdStartWorkerType extends ColdStartWorker> + extends BatchWorker { + private static final Logger LOG = LogManager.getLogger(CheckpointReadWorker.class); - public static final String WORKER_NAME = "checkpoint-read"; - private final ModelManager modelManager; - private final CheckpointDao checkpointDao; - private final EntityColdStartWorker entityColdStartQueue; - private final ResultWriteWorker resultWriteQueue; - private final ADIndexManagement indexUtil; - private final CacheProvider cacheProvider; - private final CheckpointWriteWorker checkpointWriteQueue; - private final ADStats adStats; + + protected final ModelManagerType modelManager; + protected final CheckpointType checkpointDao; + protected final ColdStartWorkerType coldStartWorker; + protected final SaveResultStrategyType resultWriteWorker; + protected final IndexManagementType indexUtil; + protected final Stats timeSeriesStats; + protected final CheckpointWriteWorkerType checkpointWriteWorker; + protected final Provider> cacheProvider; + protected final String checkpointIndexName; + protected final StatNames modelCorruptionStat; public CheckpointReadWorker( + String workerName, long heapSizeInBytes, int singleRequestSizeInBytes, Setting maxHeapPercentForQueueSetting, @@ -93,6 +83,7 @@ public CheckpointReadWorker( Random random, CircuitBreakerService adCircuitBreakerService, ThreadPool threadPool, + String threadPoolName, Settings settings, float maxQueuedTaskRatio, Clock clock, @@ -100,19 +91,24 @@ public CheckpointReadWorker( float lowSegmentPruneRatio, int maintenanceFreqConstant, Duration executionTtl, - ModelManager modelManager, - CheckpointDao checkpointDao, - EntityColdStartWorker entityColdStartQueue, - ResultWriteWorker resultWriteQueue, + ModelManagerType modelManager, + CheckpointType checkpointDao, + ColdStartWorkerType entityColdStartWorker, NodeStateManager stateManager, - ADIndexManagement indexUtil, - CacheProvider cacheProvider, + IndexManagementType indexUtil, + Provider> cacheProvider, Duration stateTtl, - CheckpointWriteWorker checkpointWriteQueue, - ADStats adStats + CheckpointWriteWorkerType checkpointWriteWorker, + Stats timeSeriesStats, + Setting concurrencySetting, + Setting batchSizeSetting, + String checkpointIndexName, + StatNames modelCorruptionStat, + AnalysisType context, + SaveResultStrategyType resultWriteWorker ) { super( - WORKER_NAME, + workerName, heapSizeInBytes, singleRequestSizeInBytes, maxHeapPercentForQueueSetting, @@ -120,27 +116,31 @@ public CheckpointReadWorker( random, adCircuitBreakerService, threadPool, + threadPoolName, settings, maxQueuedTaskRatio, clock, mediumSegmentPruneRatio, lowSegmentPruneRatio, maintenanceFreqConstant, - AD_CHECKPOINT_READ_QUEUE_CONCURRENCY, + concurrencySetting, executionTtl, - AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE, + batchSizeSetting, stateTtl, - stateManager + stateManager, + context ); this.modelManager = modelManager; this.checkpointDao = checkpointDao; - this.entityColdStartQueue = entityColdStartQueue; - this.resultWriteQueue = resultWriteQueue; + this.coldStartWorker = entityColdStartWorker; this.indexUtil = indexUtil; this.cacheProvider = cacheProvider; - this.checkpointWriteQueue = checkpointWriteQueue; - this.adStats = adStats; + this.checkpointWriteWorker = checkpointWriteWorker; + this.timeSeriesStats = timeSeriesStats; + this.checkpointIndexName = checkpointIndexName; + this.modelCorruptionStat = modelCorruptionStat; + this.resultWriteWorker = resultWriteWorker; } @Override @@ -149,28 +149,29 @@ protected void executeBatchRequest(MultiGetRequest request, ActionListener toProcess) { + protected MultiGetRequest toBatchRequest(List toProcess) { MultiGetRequest multiGetRequest = new MultiGetRequest(); - for (EntityRequest request : toProcess) { - Optional modelId = request.getModelId(); - if (false == modelId.isPresent()) { + for (FeatureRequest request : toProcess) { + String modelId = request.getModelId(); + if (null == modelId) { continue; } - multiGetRequest.add(new MultiGetRequest.Item(ADCommonName.CHECKPOINT_INDEX_NAME, modelId.get())); + multiGetRequest.add(new MultiGetRequest.Item(checkpointIndexName, modelId)); } return multiGetRequest; } @Override - protected ActionListener getResponseListener(List toProcess, MultiGetRequest batchRequest) { + protected ActionListener getResponseListener(List toProcess, MultiGetRequest batchRequest) { return ActionListener.wrap(response -> { + final MultiGetItemResponse[] itemResponses = response.getResponses(); Map successfulRequests = new HashMap<>(); @@ -186,11 +187,11 @@ protected ActionListener getResponseListener(List getResponseListener(List getResponseListener(List modelId = origRequest.getModelId(); - if (modelId.isPresent() && notFoundModels.contains(modelId.get())) { + for (FeatureRequest origRequest : toProcess) { + String modelId = origRequest.getModelId(); + if (modelId != null && notFoundModels.contains(modelId)) { // submit to cold start queue - entityColdStartQueue.put(origRequest); + coldStartWorker.put(origRequest); } } } @@ -241,15 +242,17 @@ protected ActionListener getResponseListener(List modelId = origRequest.getModelId(); - if (modelId.isPresent() && stopDetectorRequests.containsKey(modelId.get())) { - String adID = origRequest.detectorId; + for (FeatureRequest origRequest : toProcess) { + String modelId = origRequest.getModelId(); + if (modelId != null && stopDetectorRequests.containsKey(modelId)) { + String configID = origRequest.getConfigId(); nodeStateManager .setException( - adID, - new EndRunException(adID, CommonMessages.BUG_RESPONSE, stopDetectorRequests.get(modelId.get()), false) + configID, + new EndRunException(configID, CommonMessages.BUG_RESPONSE, stopDetectorRequests.get(modelId), false) ); + // once one EndRunException is set, we can break; no point setting the exception repeatedly + break; } } } @@ -258,11 +261,10 @@ protected ActionListener getResponseListener(List { if (ExceptionUtil.isOverloaded(exception)) { - LOG.error("too many get AD model checkpoint requests or shard not available"); + LOG.error("too many get model checkpoint requests or shard not available"); setCoolDownStart(); } else if (ExceptionUtil.isRetryAble(exception)) { // retry all of them @@ -273,9 +275,9 @@ protected ActionListener getResponseListener(List toProcess, + List toProcess, Map successfulRequests, Set retryableRequests ) { @@ -287,42 +289,41 @@ private void processCheckpointIteration( // if false, finally will process next checkpoints boolean processNextInCallBack = false; try { - EntityFeatureRequest origRequest = toProcess.get(i); + FeatureRequest origRequest = toProcess.get(i); - Optional modelIdOptional = origRequest.getModelId(); - if (false == modelIdOptional.isPresent()) { + String modelId = origRequest.getModelId(); + if (null == modelId) { return; } - String detectorId = origRequest.getId(); - Entity entity = origRequest.getEntity(); - - String modelId = modelIdOptional.get(); + String configId = origRequest.getConfigId(); + Optional entity = origRequest.getEntity(); MultiGetItemResponse checkpointResponse = successfulRequests.get(modelId); if (checkpointResponse != null) { // successful requests - Optional> checkpoint = checkpointDao - .processGetResponse(checkpointResponse.getResponse(), modelId); + ModelState modelState = checkpointDao + .processHCGetResponse(checkpointResponse.getResponse(), modelId, configId); - if (false == checkpoint.isPresent()) { - // checkpoint is too big + if (null == modelState) { + // checkpoint is not available (e.g., too big or corrupted); cold start again + coldStartWorker.put(origRequest); return; } nodeStateManager .getConfig( - detectorId, - AnalysisType.AD, - onGetDetector( + configId, + context, + processIterationUsingConfig( origRequest, i, - detectorId, + configId, toProcess, successfulRequests, retryableRequests, - checkpoint, + modelState, entity, modelId ) @@ -339,39 +340,47 @@ private void processCheckpointIteration( } } - private ActionListener> onGetDetector( - EntityFeatureRequest origRequest, + protected ActionListener> processIterationUsingConfig( + FeatureRequest origRequest, int index, - String detectorId, - List toProcess, + String configId, + List toProcess, Map successfulRequests, Set retryableRequests, - Optional> checkpoint, - Entity entity, + ModelState restoredModelState, + Optional entity, String modelId ) { - return ActionListener.wrap(detectorOptional -> { - if (false == detectorOptional.isPresent()) { - LOG.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", detectorId)); + return ActionListener.wrap(configOptional -> { + if (configOptional.isEmpty()) { + LOG.warn(new ParameterizedMessage("Config [{}] is not available.", configId)); processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests); return; } - AnomalyDetector detector = (AnomalyDetector) detectorOptional.get(); + Config config = configOptional.get(); - ModelState modelState = modelManager - .processEntityCheckpoint(checkpoint, entity, modelId, detectorId, detector.getShingleSize()); - - ThresholdingResult result = null; + RCFResultType result = null; try { result = modelManager - .getAnomalyResultForEntity(origRequest.getCurrentFeature(), modelState, modelId, entity, detector.getShingleSize()); + .getResult( + new Sample( + origRequest.getCurrentFeature(), + Instant.ofEpochMilli(origRequest.getDataStartTimeMillis()), + Instant.ofEpochMilli(origRequest.getDataStartTimeMillis() + config.getIntervalInMilliseconds()) + ), + restoredModelState, + modelId, + entity, + config, + origRequest.getTaskId() + ); } catch (IllegalArgumentException e) { // fail to score likely due to model corruption. Re-cold start to recover. LOG.error(new ParameterizedMessage("Likely model corruption for [{}]", origRequest.getModelId()), e); - adStats.getStat(StatNames.MODEL_CORRUTPION_COUNT.getName()).increment(); - if (origRequest.getModelId().isPresent()) { - String entityModelId = origRequest.getModelId().get(); + timeSeriesStats.getStat(modelCorruptionStat.getName()).increment(); + if (null != origRequest.getModelId()) { + String entityModelId = origRequest.getModelId(); checkpointDao .deleteModelCheckpoint( entityModelId, @@ -383,56 +392,26 @@ private ActionListener> onGetDetector( ); } - entityColdStartQueue.put(origRequest); + coldStartWorker.put(origRequest); processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests); return; } - if (result != null && result.getRcfScore() > 0) { - RequestPriority requestPriority = result.getGrade() > 0 ? RequestPriority.HIGH : RequestPriority.MEDIUM; - - List resultsToSave = result - .toIndexableResults( - detector, - Instant.ofEpochMilli(origRequest.getDataStartTimeMillis()), - Instant.ofEpochMilli(origRequest.getDataStartTimeMillis() + detector.getIntervalInMilliseconds()), - Instant.now(), - Instant.now(), - ParseUtils.getFeatureData(origRequest.getCurrentFeature(), detector), - Optional.ofNullable(entity), - indexUtil.getSchemaVersion(ADIndex.RESULT), - modelId, - null, - null - ); - - for (AnomalyResult r : resultsToSave) { - resultWriteQueue - .put( - new ResultWriteRequest( - origRequest.getExpirationEpochMs(), - detectorId, - requestPriority, - r, - detector.getCustomResultIndex() - ) - ); - } - } + resultWriteWorker.saveResult(result, config, origRequest, modelId); // try to load to cache - boolean loaded = cacheProvider.get().hostIfPossible(detector, modelState); + boolean loaded = cacheProvider.get().hostIfPossible(config, restoredModelState); if (false == loaded) { // not in memory. Maybe cold entities or some other entities // have filled the slot while waiting for loading checkpoints. - checkpointWriteQueue.write(modelState, true, RequestPriority.LOW); + checkpointWriteWorker.write(restoredModelState, true, RequestPriority.LOW); } processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests); }, exception -> { LOG.error(new ParameterizedMessage("fail to get checkpoint [{}]", modelId, exception)); - nodeStateManager.setException(detectorId, exception); + nodeStateManager.setException(configId, exception); processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests); }); } diff --git a/src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteRequest.java b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointWriteRequest.java similarity index 94% rename from src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteRequest.java rename to src/main/java/org/opensearch/timeseries/ratelimit/CheckpointWriteRequest.java index 9c41e55be..02a374f82 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteRequest.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointWriteRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; import org.opensearch.action.update.UpdateRequest; diff --git a/src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointWriteWorker.java similarity index 68% rename from src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteWorker.java rename to src/main/java/org/opensearch/timeseries/ratelimit/CheckpointWriteWorker.java index a26cb8b94..2d486634a 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteWorker.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointWriteWorker.java @@ -1,22 +1,12 @@ /* + * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. */ -package org.opensearch.ad.ratelimit; - -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_CONCURRENCY; +package org.opensearch.timeseries.ratelimit; import java.time.Clock; import java.time.Duration; -import java.time.Instant; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -30,10 +20,6 @@ import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.update.UpdateRequest; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelState; -import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; @@ -43,58 +29,69 @@ import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.ModelState; import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.util.ExceptionUtil; -public class CheckpointWriteWorker extends BatchWorker { +public abstract class CheckpointWriteWorker & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao> + extends BatchWorker { private static final Logger LOG = LogManager.getLogger(CheckpointWriteWorker.class); - public static final String WORKER_NAME = "checkpoint-write"; - private final CheckpointDao checkpoint; - private final String indexName; - private final Duration checkpointInterval; + protected final CheckpointDaoType checkpoint; + protected final String indexName; + protected final Duration checkpointInterval; public CheckpointWriteWorker( - long heapSizeInBytes, - int singleRequestSizeInBytes, + String queueName, + long heapSize, + int singleRequestSize, Setting maxHeapPercentForQueueSetting, ClusterService clusterService, Random random, - CircuitBreakerService adCircuitBreakerService, + CircuitBreakerService circuitBreakerService, ThreadPool threadPool, + String threadPoolName, Settings settings, float maxQueuedTaskRatio, Clock clock, float mediumSegmentPruneRatio, float lowSegmentPruneRatio, int maintenanceFreqConstant, + Setting concurrencySetting, Duration executionTtl, - CheckpointDao checkpoint, + Setting batchSizeSetting, + Duration stateTtl, + NodeStateManager timeSeriesNodeStateManager, + CheckpointDaoType checkpoint, String indexName, Duration checkpointInterval, - NodeStateManager stateManager, - Duration stateTtl + AnalysisType context ) { super( - WORKER_NAME, - heapSizeInBytes, - singleRequestSizeInBytes, + queueName, + heapSize, + singleRequestSize, maxHeapPercentForQueueSetting, clusterService, random, - adCircuitBreakerService, + circuitBreakerService, threadPool, + threadPoolName, settings, maxQueuedTaskRatio, clock, mediumSegmentPruneRatio, lowSegmentPruneRatio, maintenanceFreqConstant, - AD_CHECKPOINT_WRITE_QUEUE_CONCURRENCY, + concurrencySetting, executionTtl, - AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE, + batchSizeSetting, stateTtl, - stateManager + timeSeriesNodeStateManager, + context ); this.checkpoint = checkpoint; this.indexName = indexName; @@ -133,7 +130,7 @@ protected ActionListener getResponseListener(List getResponseListener(List modelState, boolean forceWrite, RequestPriority priority) { - Instant instant = modelState.getLastCheckpointTime(); - if (!checkpoint.shouldSave(instant, forceWrite, checkpointInterval, clock)) { - return; - } - - if (modelState.getModel() != null) { - String detectorId = modelState.getId(); + public void write(ModelState modelState, boolean forceWrite, RequestPriority priority) { + if (checkpoint.shouldSave(modelState, forceWrite, checkpointInterval, clock)) { + String configId = modelState.getConfigId(); String modelId = modelState.getModelId(); - if (modelId == null || detectorId == null) { + if (modelId == null || configId == null) { return; } - nodeStateManager.getConfig(detectorId, AnalysisType.AD, onGetDetector(detectorId, modelId, modelState, priority)); + nodeStateManager.getConfig(configId, context, onGetConfig(configId, modelId, modelState, priority)); } } - private ActionListener> onGetDetector( - String detectorId, + private ActionListener> onGetConfig( + String configId, String modelId, - ModelState modelState, + ModelState modelState, RequestPriority priority ) { - return ActionListener.wrap(detectorOptional -> { - if (false == detectorOptional.isPresent()) { - LOG.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", detectorId)); + return ActionListener.wrap(configOptional -> { + if (false == configOptional.isPresent()) { + LOG.warn(new ParameterizedMessage("Config [{}] is not available.", configId)); return; } - AnomalyDetector detector = (AnomalyDetector) detectorOptional.get(); + Config config = configOptional.get(); try { Map source = checkpoint.toIndexSource(modelState); @@ -192,8 +184,8 @@ private ActionListener> onGetDetector( modelState.setLastCheckpointTime(clock.instant()); CheckpointWriteRequest request = new CheckpointWriteRequest( - System.currentTimeMillis() + detector.getIntervalInMilliseconds(), - detectorId, + System.currentTimeMillis() + config.getIntervalInMilliseconds(), + configId, priority, // If the document does not already exist, the contents of the upsert element // are inserted as a new document. @@ -214,22 +206,21 @@ private ActionListener> onGetDetector( LOG.error(new ParameterizedMessage("Exception while serializing models for [{}]", modelId), e); } - }, exception -> { LOG.error(new ParameterizedMessage("fail to get detector [{}]", detectorId), exception); }); + }, exception -> { LOG.error(new ParameterizedMessage("fail to get config [{}]", configId), exception); }); } - public void writeAll(List> modelStates, String detectorId, boolean forceWrite, RequestPriority priority) { - ActionListener> onGetForAll = ActionListener.wrap(detectorOptional -> { - if (false == detectorOptional.isPresent()) { - LOG.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", detectorId)); + public void writeAll(List> modelStates, String configId, boolean forceWrite, RequestPriority priority) { + ActionListener> onGetForAll = ActionListener.wrap(configOptional -> { + if (false == configOptional.isPresent()) { + LOG.warn(new ParameterizedMessage("Config [{}] is not available.", configId)); return; } - AnomalyDetector detector = (AnomalyDetector) detectorOptional.get(); + Config config = configOptional.get(); try { List allRequests = new ArrayList<>(); - for (ModelState state : modelStates) { - Instant instant = state.getLastCheckpointTime(); - if (!checkpoint.shouldSave(instant, forceWrite, checkpointInterval, clock)) { + for (ModelState state : modelStates) { + if (!checkpoint.shouldSave(state, forceWrite, checkpointInterval, clock)) { continue; } @@ -245,8 +236,8 @@ public void writeAll(List> modelStates, String detectorI allRequests .add( new CheckpointWriteRequest( - System.currentTimeMillis() + detector.getIntervalInMilliseconds(), - detectorId, + System.currentTimeMillis() + config.getIntervalInMilliseconds(), + configId, priority, // If the document does not already exist, the contents of the upsert element // are inserted as a new document. @@ -266,11 +257,11 @@ public void writeAll(List> modelStates, String detectorI // As we are gonna retry serializing either when the entity is // evicted out of cache or during the next maintenance period, // don't do anything when the exception happens. - LOG.info(new ParameterizedMessage("Exception while serializing models for [{}]", detectorId), e); + LOG.info(new ParameterizedMessage("Exception while serializing models for [{}]", configId), e); } - }, exception -> { LOG.error(new ParameterizedMessage("fail to get detector [{}]", detectorId), exception); }); + }, exception -> { LOG.error(new ParameterizedMessage("fail to get config [{}]", configId), exception); }); - nodeStateManager.getConfig(detectorId, AnalysisType.AD, onGetForAll); + nodeStateManager.getConfig(configId, context, onGetForAll); } } diff --git a/src/main/java/org/opensearch/timeseries/ratelimit/ColdEntityWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/ColdEntityWorker.java new file mode 100644 index 000000000..703360a3f --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ratelimit/ColdEntityWorker.java @@ -0,0 +1,100 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.ratelimit; + +import java.time.Clock; +import java.time.Duration; +import java.util.List; +import java.util.Random; +import java.util.stream.Collectors; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.caching.TimeSeriesCache; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.IntermediateResult; +import org.opensearch.timeseries.ml.ModelColdStart; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.model.IndexableResult; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public class ColdEntityWorker & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, RCFResultType extends IntermediateResult, ModelManagerType extends ModelManager, CheckpointWriteWorkerType extends CheckpointWriteWorker, ColdStarterType extends ModelColdStart, CacheType extends TimeSeriesCache, SaveResultStrategyType extends SaveResultStrategy, ColdStartWorkerType extends ColdStartWorker, CheckpointReadWorkerType extends CheckpointReadWorker> + extends ScheduledWorker { + + public ColdEntityWorker( + String workerName, + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + String threadPoolName, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + CheckpointReadWorkerType checkpointReadQueue, + Duration stateTtl, + NodeStateManager nodeStateManager, + Setting checkpointReadBatchSizeSetting, + Setting expectedColdEntityExecutionMillsSetting, + AnalysisType context + ) { + super( + workerName, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + threadPoolName, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + checkpointReadQueue, + stateTtl, + nodeStateManager, + context + ); + + this.batchSize = checkpointReadBatchSizeSetting.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(checkpointReadBatchSizeSetting, it -> this.batchSize = it); + + this.expectedExecutionTimeInMilliSecsPerRequest = expectedColdEntityExecutionMillsSetting.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(expectedColdEntityExecutionMillsSetting, it -> this.expectedExecutionTimeInMilliSecsPerRequest = it); + } + + @Override + protected List transformRequests(List requests) { + // guarantee we only send low priority requests + return requests.stream().filter(request -> request.getPriority() == RequestPriority.LOW).collect(Collectors.toList()); + } +} diff --git a/src/main/java/org/opensearch/timeseries/ratelimit/ColdStartWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/ColdStartWorker.java new file mode 100644 index 000000000..dccfca98c --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ratelimit/ColdStartWorker.java @@ -0,0 +1,203 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.ratelimit; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.List; +import java.util.Locale; +import java.util.Optional; +import java.util.Random; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.caching.TimeSeriesCache; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.IntermediateResult; +import org.opensearch.timeseries.ml.ModelColdStart; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.util.ExceptionUtil; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public abstract class ColdStartWorker & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CheckpointWriteWorkerType extends CheckpointWriteWorker, ColdStarterType extends ModelColdStart, CacheType extends TimeSeriesCache, IndexableResultType extends IndexableResult, IntermediateResultType extends IntermediateResult, ModelManagerType extends ModelManager, SaveResultStrategyType extends SaveResultStrategy> + extends SingleRequestWorker { + private static final Logger LOG = LogManager.getLogger(ColdStartWorker.class); + + protected final ColdStarterType coldStarter; + protected final CacheType cacheProvider; + private final ModelManagerType modelManager; + private final SaveResultStrategyType resultSaver; + + public ColdStartWorker( + String workerName, + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + String threadPoolName, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Setting concurrency, + Duration executionTtl, + ColdStarterType coldStarter, + Duration stateTtl, + NodeStateManager nodeStateManager, + CacheType cacheProvider, + AnalysisType context, + ModelManagerType modelManager, + SaveResultStrategyType resultSaver + ) { + super( + workerName, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + threadPoolName, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + concurrency, + executionTtl, + stateTtl, + nodeStateManager, + context + ); + this.coldStarter = coldStarter; + this.cacheProvider = cacheProvider; + this.modelManager = modelManager; + this.resultSaver = resultSaver; + } + + @Override + protected void executeRequest(FeatureRequest coldStartRequest, ActionListener listener) { + String configId = coldStartRequest.getConfigId(); + + String modelId = coldStartRequest.getModelId(); + + if (null == modelId) { + String error = String.format(Locale.ROOT, "Fail to get model id for request %s", coldStartRequest); + LOG.warn(error); + listener.onFailure(new RuntimeException(error)); + return; + } + ModelState modelState = createEmptyState(coldStartRequest, modelId, configId); + + ActionListener> coldStartListener = ActionListener.wrap(r -> { + nodeStateManager.getConfig(configId, context, ActionListener.wrap(configOptional -> { + try { + if (!configOptional.isPresent()) { + LOG + .error( + new ParameterizedMessage( + "fail to load trained model [{}] to cache due to the config not being found.", + modelState.getModelId() + ) + ); + return; + } + Config config = configOptional.get(); + + // score the current feature if training suceeded + if (modelState.getModel().isPresent()) { + String taskId = coldStartRequest.getTaskId(); + if (r != null) { + for (int i = 0; i < r.size(); i++) { + Sample entry = r.get(i); + IndexableResultType trainingResult = createIndexableResult( + config, + taskId, + modelId, + entry, + coldStartRequest.getEntity() + ); + resultSaver.saveResult(trainingResult, config); + } + } + + long dataStartTime = coldStartRequest.getDataStartTimeMillis(); + Sample currentSample = new Sample( + coldStartRequest.getCurrentFeature(), + Instant.ofEpochMilli(dataStartTime), + Instant.ofEpochMilli(dataStartTime + config.getIntervalInMilliseconds()) + ); + IntermediateResultType result = modelManager + .getResult(currentSample, modelState, modelId, coldStartRequest.getEntity(), config, taskId); + resultSaver.saveResult(result, config, coldStartRequest, modelId); + } + + // only load model to memory for real time analysis that has no task id + if (null == coldStartRequest.getTaskId()) { + cacheProvider.hostIfPossible(configOptional.get(), modelState); + } + + } finally { + listener.onResponse(null); + } + }, listener::onFailure)); + + }, e -> { + try { + if (ExceptionUtil.isOverloaded(e)) { + LOG.error("OpenSearch is overloaded"); + setCoolDownStart(); + } + nodeStateManager.setException(configId, e); + } finally { + listener.onFailure(e); + } + }); + + coldStarter.trainModel(coldStartRequest, configId, modelState, coldStartListener); + } + + protected abstract ModelState createEmptyState(FeatureRequest coldStartRequest, String modelId, String configId); + + protected abstract IndexableResultType createIndexableResult( + Config config, + String taskId, + String modelId, + Sample entry, + Optional entity + ); +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ConcurrentWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/ConcurrentWorker.java similarity index 92% rename from src/main/java/org/opensearch/ad/ratelimit/ConcurrentWorker.java rename to src/main/java/org/opensearch/timeseries/ratelimit/ConcurrentWorker.java index 3df70c935..45f1e424b 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/ConcurrentWorker.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/ConcurrentWorker.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; import java.time.Clock; import java.time.Duration; @@ -23,8 +23,8 @@ import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.NodeStateManager; -import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.breaker.CircuitBreakerService; /** @@ -53,7 +53,7 @@ public abstract class ConcurrentWorker extend * rate AD's usage on ES threadpools. * @param clusterService Cluster service accessor * @param random Random number generator - * @param adCircuitBreakerService AD Circuit breaker service + * @param circuitBreakerService Circuit breaker service * @param threadPool threadpool accessor * @param settings Cluster settings getter * @param maxQueuedTaskRatio maximum queued tasks ratio in ES threadpools @@ -74,8 +74,9 @@ public ConcurrentWorker( Setting maxHeapPercentForQueueSetting, ClusterService clusterService, Random random, - CircuitBreakerService adCircuitBreakerService, + CircuitBreakerService circuitBreakerService, ThreadPool threadPool, + String threadPoolName, Settings settings, float maxQueuedTaskRatio, Clock clock, @@ -85,7 +86,8 @@ public ConcurrentWorker( Setting concurrencySetting, Duration executionTtl, Duration stateTtl, - NodeStateManager nodeStateManager + NodeStateManager nodeStateManager, + AnalysisType context ) { super( queueName, @@ -94,8 +96,9 @@ public ConcurrentWorker( maxHeapPercentForQueueSetting, clusterService, random, - adCircuitBreakerService, + circuitBreakerService, threadPool, + threadPoolName, settings, maxQueuedTaskRatio, clock, @@ -103,7 +106,8 @@ public ConcurrentWorker( lowSegmentPruneRatio, maintenanceFreqConstant, stateTtl, - nodeStateManager + nodeStateManager, + context ); this.permits = new Semaphore(concurrencySetting.get(settings)); @@ -132,7 +136,7 @@ public void maintenance() { */ @Override protected void triggerProcess() { - threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME).execute(() -> { + threadPool.executor(threadPoolName).execute(() -> { if (permits.tryAcquire()) { try { lastExecuteTime = clock.instant(); diff --git a/src/main/java/org/opensearch/timeseries/ratelimit/FeatureRequest.java b/src/main/java/org/opensearch/timeseries/ratelimit/FeatureRequest.java new file mode 100644 index 000000000..0749381f4 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ratelimit/FeatureRequest.java @@ -0,0 +1,84 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.ratelimit; + +import java.util.Optional; + +import org.opensearch.timeseries.model.Entity; + +public class FeatureRequest extends QueuedRequest { + private final double[] currentFeature; + private final long dataStartTimeMillis; + protected final String modelId; + private final Optional entity; + private final String taskId; + + // used in HC + public FeatureRequest( + long expirationEpochMs, + String configId, + RequestPriority priority, + double[] currentFeature, + long dataStartTimeMs, + Entity entity, + String taskId + ) { + super(expirationEpochMs, configId, priority); + this.currentFeature = currentFeature; + this.dataStartTimeMillis = dataStartTimeMs; + this.modelId = entity.getModelId(configId).isEmpty() ? null : entity.getModelId(configId).get(); + this.entity = Optional.ofNullable(entity); + this.taskId = taskId; + } + + // used in single-stream + public FeatureRequest( + long expirationEpochMs, + String configId, + RequestPriority priority, + String modelId, + double[] currentFeature, + long dataStartTimeMs, + String taskId + ) { + super(expirationEpochMs, configId, priority); + this.currentFeature = currentFeature; + this.dataStartTimeMillis = dataStartTimeMs; + this.modelId = modelId; + this.entity = Optional.empty(); + this.taskId = taskId; + } + + public double[] getCurrentFeature() { + return currentFeature; + } + + public long getDataStartTimeMillis() { + return dataStartTimeMillis; + } + + public String getModelId() { + return modelId; + } + + public Optional getEntity() { + return entity; + } + + public String getTaskId() { + return taskId; + } + + public boolean isRunOnce() { + return taskId != null; + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/QueuedRequest.java b/src/main/java/org/opensearch/timeseries/ratelimit/QueuedRequest.java similarity index 77% rename from src/main/java/org/opensearch/ad/ratelimit/QueuedRequest.java rename to src/main/java/org/opensearch/timeseries/ratelimit/QueuedRequest.java index 66c440db9..a13a490de 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/QueuedRequest.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/QueuedRequest.java @@ -9,22 +9,22 @@ * GitHub history for details. */ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; public abstract class QueuedRequest { protected long expirationEpochMs; - protected String detectorId; + protected String configId; protected RequestPriority priority; /** * * @param expirationEpochMs Request expiry time in milliseconds - * @param detectorId Detector Id + * @param configId Detector Id * @param priority how urgent the request is */ - protected QueuedRequest(long expirationEpochMs, String detectorId, RequestPriority priority) { + protected QueuedRequest(long expirationEpochMs, String configId, RequestPriority priority) { this.expirationEpochMs = expirationEpochMs; - this.detectorId = detectorId; + this.configId = configId; this.priority = priority; } @@ -47,11 +47,11 @@ public void setPriority(RequestPriority priority) { this.priority = priority; } - public String getId() { - return detectorId; + public String getConfigId() { + return configId; } public void setDetectorId(String detectorId) { - this.detectorId = detectorId; + this.configId = detectorId; } } diff --git a/src/main/java/org/opensearch/ad/ratelimit/RateLimitedRequestWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/RateLimitedRequestWorker.java similarity index 89% rename from src/main/java/org/opensearch/ad/ratelimit/RateLimitedRequestWorker.java rename to src/main/java/org/opensearch/timeseries/ratelimit/RateLimitedRequestWorker.java index 911ae43a5..93df5b1ae 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/RateLimitedRequestWorker.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/RateLimitedRequestWorker.java @@ -9,9 +9,9 @@ * GitHub history for details. */ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_COOLDOWN_MINUTES; +import static org.opensearch.timeseries.settings.TimeSeriesSettings.COOLDOWN_MINUTES; import java.time.Clock; import java.time.Duration; @@ -39,10 +39,10 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.threadpool.ThreadPool; import org.opensearch.threadpool.ThreadPoolStats; +import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.ExpiringState; import org.opensearch.timeseries.MaintenanceState; import org.opensearch.timeseries.NodeStateManager; -import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.breaker.CircuitBreakerService; import org.opensearch.timeseries.common.exception.TimeSeriesException; @@ -156,6 +156,15 @@ public int clearExpiredRequests() { } return removed; } + + public boolean hasConfigId(String configId) { + for (RequestType request : content) { + if (configId.equals(request.getConfigId())) { + return true; + } + } + return false; + } } private static final Logger LOG = LogManager.getLogger(RateLimitedRequestWorker.class); @@ -175,8 +184,9 @@ public int clearExpiredRequests() { protected final ConcurrentSkipListMap requestQueues; private String lastSelectedRequestQueueId; protected Random random; - private CircuitBreakerService adCircuitBreakerService; + private CircuitBreakerService circuitBreakerService; protected ThreadPool threadPool; + protected String threadPoolName; protected Instant cooldownStart; protected int coolDownMinutes; private float maxQueuedTaskRatio; @@ -186,6 +196,7 @@ public int clearExpiredRequests() { protected int maintenanceFreqConstant; private final Duration stateTtl; protected final NodeStateManager nodeStateManager; + protected final AnalysisType context; public RateLimitedRequestWorker( String workerName, @@ -194,8 +205,9 @@ public RateLimitedRequestWorker( Setting maxHeapPercentForQueueSetting, ClusterService clusterService, Random random, - CircuitBreakerService adCircuitBreakerService, + CircuitBreakerService circuitBreakerService, ThreadPool threadPool, + String threadPoolName, Settings settings, float maxQueuedTaskRatio, Clock clock, @@ -203,7 +215,8 @@ public RateLimitedRequestWorker( float lowRequestQueuePruneRatio, int maintenanceFreqConstant, Duration stateTtl, - NodeStateManager nodeStateManager + NodeStateManager nodeStateManager, + AnalysisType context ) { this.heapSize = heapSizeInBytes; this.singleRequestSize = singleRequestSizeInBytes; @@ -218,8 +231,9 @@ public RateLimitedRequestWorker( this.workerName = workerName; this.random = random; - this.adCircuitBreakerService = adCircuitBreakerService; + this.circuitBreakerService = circuitBreakerService; this.threadPool = threadPool; + this.threadPoolName = threadPoolName; this.maxQueuedTaskRatio = maxQueuedTaskRatio; this.clock = clock; this.mediumRequestQueuePruneRatio = mediumRequestQueuePruneRatio; @@ -228,22 +242,24 @@ public RateLimitedRequestWorker( this.lastSelectedRequestQueueId = null; this.requestQueues = new ConcurrentSkipListMap<>(); this.cooldownStart = Instant.MIN; - this.coolDownMinutes = (int) (AD_COOLDOWN_MINUTES.get(settings).getMinutes()); + this.coolDownMinutes = (int) (COOLDOWN_MINUTES.get(settings).getMinutes()); this.maintenanceFreqConstant = maintenanceFreqConstant; this.stateTtl = stateTtl; this.nodeStateManager = nodeStateManager; + this.context = context; } - protected String getWorkerName() { + public String getWorkerName() { return workerName; } /** - * To add fairness to multiple detectors, HCAD allocates queues at a per - * detector granularity and pulls off requests across similar queues in a - * round-robin fashion. This way, if one detector has a much higher - * cardinality than other detectors, the unfinished portion of that - * detector’s workload times out, and other detectors’ workloads continue + * To add fairness to multiple analyses, HC allocates queues at a per + * analysis (e.g., detector or forecaster) granularity and pulls off + * requests across similar queues in a round-robin fashion. + * This way, if one analysis has a much higher + * cardinality than other analysis, the unfinished portion of that + * analysis's workload times out, and other analyses’ workloads continue * operating with predictable performance. For example, for loading checkpoints, * HCAD pulls off 10 requests from one detector’ queues, issues a mget request * to ES, wait for it to finish, and then does it again for other detectors’ @@ -305,7 +321,7 @@ protected void putOnly(RequestType request) { // just use the RequestQueue priority (i.e., low or high) as the key of the RequestQueue map. RequestQueue requestQueue = requestQueues .computeIfAbsent( - RequestPriority.MEDIUM == request.getPriority() ? request.getId() : request.getPriority().name(), + RequestPriority.MEDIUM == request.getPriority() ? request.getConfigId() : request.getPriority().name(), k -> new RequestQueue() ); @@ -429,7 +445,7 @@ private void maintainForMemory() { int exceededSize = exceededSize(); if (exceededSize > 0) { prune(requestQueues, exceededSize); - } else if (adCircuitBreakerService.isOpen()) { + } else if (circuitBreakerService.isOpen()) { // remove a few items in each RequestQueue prune(requestQueues); } @@ -551,7 +567,7 @@ protected void process() { } catch (Exception e) { LOG.error(new ParameterizedMessage("Fail to process requests in [{}].", this.workerName), e); } - }, new TimeValue(coolDownMinutes, TimeUnit.MINUTES), TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME); + }, new TimeValue(coolDownMinutes, TimeUnit.MINUTES), threadPoolName); } else { try { triggerProcess(); @@ -566,6 +582,29 @@ protected void process() { } } + /** + * + * @param configId Config Id + * @return whether there is any unfinished request belonging to a configId + */ + public boolean hasConfigId(String configId) { + for (Map.Entry requestQueueEntry : requestQueues.entrySet()) { + String requestId = requestQueueEntry.getKey(); + if (requestId.equals(RequestPriority.LOW.name()) || requestId.equals(RequestPriority.HIGH.name())) { + RequestQueue requests = requestQueueEntry.getValue(); + if (requests.hasConfigId(configId)) { + return true; + } + } else { + // requestId is config Id + if (requestId.equals(configId)) { + return true; + } + } + } + return false; + } + /** * How to execute requests is abstracted out and left to RateLimitedQueue's subclasses to implement. */ diff --git a/src/main/java/org/opensearch/ad/ratelimit/RequestPriority.java b/src/main/java/org/opensearch/timeseries/ratelimit/RequestPriority.java similarity index 88% rename from src/main/java/org/opensearch/ad/ratelimit/RequestPriority.java rename to src/main/java/org/opensearch/timeseries/ratelimit/RequestPriority.java index 3193d2285..29fb14523 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/RequestPriority.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/RequestPriority.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; public enum RequestPriority { LOW, diff --git a/src/main/java/org/opensearch/ad/ratelimit/ResultWriteRequest.java b/src/main/java/org/opensearch/timeseries/ratelimit/ResultWriteRequest.java similarity index 58% rename from src/main/java/org/opensearch/ad/ratelimit/ResultWriteRequest.java rename to src/main/java/org/opensearch/timeseries/ratelimit/ResultWriteRequest.java index a25bf3924..6d5a069f1 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/ResultWriteRequest.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/ResultWriteRequest.java @@ -9,34 +9,28 @@ * GitHub history for details. */ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; import java.io.IOException; -import org.opensearch.ad.model.AnomalyResult; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.timeseries.model.IndexableResult; -public class ResultWriteRequest extends QueuedRequest implements Writeable { - private final AnomalyResult result; +public abstract class ResultWriteRequest extends QueuedRequest implements Writeable { + private final ResultType result; // If resultIndex is null, result will be stored in default result index. private final String resultIndex; - public ResultWriteRequest( - long expirationEpochMs, - String detectorId, - RequestPriority priority, - AnomalyResult result, - String resultIndex - ) { - super(expirationEpochMs, detectorId, priority); + public ResultWriteRequest(long expirationEpochMs, String configId, RequestPriority priority, ResultType result, String resultIndex) { + super(expirationEpochMs, configId, priority); this.result = result; this.resultIndex = resultIndex; } - public ResultWriteRequest(StreamInput in) throws IOException { - this.result = new AnomalyResult(in); + public ResultWriteRequest(StreamInput in, Writeable.Reader resultReader) throws IOException { + this.result = resultReader.read(in); this.resultIndex = in.readOptionalString(); } @@ -46,11 +40,11 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(resultIndex); } - public AnomalyResult getResult() { + public ResultType getResult() { return result; } - public String getCustomResultIndex() { + public String getResultIndex() { return resultIndex; } } diff --git a/src/main/java/org/opensearch/ad/ratelimit/ResultWriteWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/ResultWriteWorker.java similarity index 59% rename from src/main/java/org/opensearch/ad/ratelimit/ResultWriteWorker.java rename to src/main/java/org/opensearch/timeseries/ratelimit/ResultWriteWorker.java index 02152b086..faaf7852e 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/ResultWriteWorker.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/ResultWriteWorker.java @@ -1,19 +1,11 @@ /* + * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. */ -package org.opensearch.ad.ratelimit; - -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_BATCH_SIZE; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_CONCURRENCY; +package org.opensearch.timeseries.ratelimit; +import java.io.IOException; import java.time.Clock; import java.time.Duration; import java.util.List; @@ -25,12 +17,8 @@ import org.apache.logging.log4j.message.ParameterizedMessage; import org.opensearch.action.DocWriteRequest; import org.opensearch.action.index.IndexRequest; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.ad.transport.ADResultBulkRequest; -import org.opensearch.ad.transport.ADResultBulkResponse; -import org.opensearch.ad.transport.handler.MultiEntityResultHandler; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.CheckedFunction; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.LoggingDeprecationHandler; @@ -44,63 +32,78 @@ import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.transport.ResultBulkRequest; +import org.opensearch.timeseries.transport.ResultBulkResponse; +import org.opensearch.timeseries.transport.handler.IndexMemoryPressureAwareResultHandler; import org.opensearch.timeseries.util.ExceptionUtil; -public class ResultWriteWorker extends BatchWorker { +public abstract class ResultWriteWorker, BatchRequestType extends ResultBulkRequest, IndexType extends Enum & TimeSeriesIndex, IndexManagementType extends IndexManagement, ResultHandlerType extends IndexMemoryPressureAwareResultHandler> + extends BatchWorker { private static final Logger LOG = LogManager.getLogger(ResultWriteWorker.class); - public static final String WORKER_NAME = "result-write"; - - private final MultiEntityResultHandler resultHandler; - private NamedXContentRegistry xContentRegistry; + protected final ResultHandlerType resultHandler; + protected NamedXContentRegistry xContentRegistry; + private CheckedFunction resultParser; public ResultWriteWorker( - long heapSizeInBytes, - int singleRequestSizeInBytes, + String queueName, + long heapSize, + int singleRequestSize, Setting maxHeapPercentForQueueSetting, ClusterService clusterService, Random random, CircuitBreakerService adCircuitBreakerService, ThreadPool threadPool, + String threadPoolName, Settings settings, float maxQueuedTaskRatio, Clock clock, float mediumSegmentPruneRatio, float lowSegmentPruneRatio, int maintenanceFreqConstant, + Setting concurrencySetting, Duration executionTtl, - MultiEntityResultHandler resultHandler, + Setting batchSizeSetting, + Duration stateTtl, + NodeStateManager timeSeriesNodeStateManager, + ResultHandlerType resultHandler, NamedXContentRegistry xContentRegistry, - NodeStateManager stateManager, - Duration stateTtl + CheckedFunction resultParser, + AnalysisType context ) { super( - WORKER_NAME, - heapSizeInBytes, - singleRequestSizeInBytes, + queueName, + heapSize, + singleRequestSize, maxHeapPercentForQueueSetting, clusterService, random, adCircuitBreakerService, threadPool, + threadPoolName, settings, maxQueuedTaskRatio, clock, mediumSegmentPruneRatio, lowSegmentPruneRatio, maintenanceFreqConstant, - AD_RESULT_WRITE_QUEUE_CONCURRENCY, + concurrencySetting, executionTtl, - AD_RESULT_WRITE_QUEUE_BATCH_SIZE, + batchSizeSetting, stateTtl, - stateManager + timeSeriesNodeStateManager, + context ); this.resultHandler = resultHandler; this.xContentRegistry = xContentRegistry; + this.resultParser = resultParser; } @Override - protected void executeBatchRequest(ADResultBulkRequest request, ActionListener listener) { + protected void executeBatchRequest(BatchRequestType request, ActionListener listener) { if (request.numberOfActions() < 1) { listener.onResponse(null); return; @@ -109,19 +112,7 @@ protected void executeBatchRequest(ADResultBulkRequest request, ActionListener toProcess) { - final ADResultBulkRequest bulkRequest = new ADResultBulkRequest(); - for (ResultWriteRequest request : toProcess) { - bulkRequest.add(request); - } - return bulkRequest; - } - - @Override - protected ActionListener getResponseListener( - List toProcess, - ADResultBulkRequest bulkRequest - ) { + protected ActionListener getResponseListener(List toProcess, BatchRequestType bulkRequest) { return ActionListener.wrap(adResultBulkResponse -> { if (adResultBulkResponse == null || false == adResultBulkResponse.getRetryRequests().isPresent()) { // all successful @@ -134,12 +125,12 @@ protected ActionListener getResponseListener( // retry all of them super.putAll(toProcess); } else if (ExceptionUtil.isOverloaded(exception)) { - LOG.error("too many get AD model checkpoint requests or shard not avialble"); + LOG.error("too many get model checkpoint requests or shard not avialble"); setCoolDownStart(); } - for (ResultWriteRequest request : toProcess) { - nodeStateManager.setException(request.getId(), exception); + for (ResultWriteRequestType request : toProcess) { + nodeStateManager.setException(request.getConfigId(), exception); } LOG.error("Fail to save results", exception); }); @@ -150,50 +141,18 @@ private void enqueueRetryRequestIteration(List requestToRetry, int return; } DocWriteRequest currentRequest = requestToRetry.get(index); - Optional resultToRetry = getAnomalyResult(currentRequest); + Optional resultToRetry = getResult(currentRequest); if (false == resultToRetry.isPresent()) { enqueueRetryRequestIteration(requestToRetry, index + 1); return; } - AnomalyResult result = resultToRetry.get(); - String detectorId = result.getConfigId(); - nodeStateManager.getConfig(detectorId, AnalysisType.AD, onGetDetector(requestToRetry, index, detectorId, result)); - } - private ActionListener> onGetDetector( - List requestToRetry, - int index, - String detectorId, - AnomalyResult resultToRetry - ) { - return ActionListener.wrap(detectorOptional -> { - if (false == detectorOptional.isPresent()) { - LOG.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", detectorId)); - enqueueRetryRequestIteration(requestToRetry, index + 1); - return; - } - - AnomalyDetector detector = (AnomalyDetector) detectorOptional.get(); - super.put( - new ResultWriteRequest( - // expire based on execute start time - resultToRetry.getExecutionStartTime().toEpochMilli() + detector.getIntervalInMilliseconds(), - detectorId, - resultToRetry.isHighPriority() ? RequestPriority.HIGH : RequestPriority.MEDIUM, - resultToRetry, - detector.getCustomResultIndex() - ) - ); - - enqueueRetryRequestIteration(requestToRetry, index + 1); - - }, exception -> { - LOG.error(new ParameterizedMessage("fail to get detector [{}]", detectorId), exception); - enqueueRetryRequestIteration(requestToRetry, index + 1); - }); + ResultType result = resultToRetry.get(); + String id = result.getConfigId(); + nodeStateManager.getConfig(id, context, onGetConfig(requestToRetry, index, id, result)); } - private Optional getAnomalyResult(DocWriteRequest request) { + protected Optional getResult(DocWriteRequest request) { try { if (false == (request instanceof IndexRequest)) { LOG.error(new ParameterizedMessage("We should only send IndexRquest, but get [{}].", request)); @@ -211,11 +170,52 @@ private Optional getAnomalyResult(DocWriteRequest request) { // org.opensearch.core.common.ParsingException: Failed to parse object: expecting token of type [START_OBJECT] but found // [null] xContentParser.nextToken(); - return Optional.of(AnomalyResult.parse(xContentParser)); + return Optional.of(resultParser.apply(xContentParser)); } } catch (Exception e) { LOG.error(new ParameterizedMessage("Fail to parse index request [{}]", request), e); } return Optional.empty(); } + + private ActionListener> onGetConfig( + List requestToRetry, + int index, + String id, + ResultType resultToRetry + ) { + return ActionListener.wrap(configOptional -> { + if (false == configOptional.isPresent()) { + LOG.warn(new ParameterizedMessage("Config [{}] is not available.", id)); + enqueueRetryRequestIteration(requestToRetry, index + 1); + return; + } + + Config config = configOptional.get(); + super.put( + createResultWriteRequest( + // expire based on execute start time + resultToRetry.getExecutionStartTime().toEpochMilli() + config.getIntervalInMilliseconds(), + id, + resultToRetry.isHighPriority() ? RequestPriority.HIGH : RequestPriority.MEDIUM, + resultToRetry, + config.getCustomResultIndex() + ) + ); + + enqueueRetryRequestIteration(requestToRetry, index + 1); + + }, exception -> { + LOG.error(new ParameterizedMessage("fail to get config [{}]", id), exception); + enqueueRetryRequestIteration(requestToRetry, index + 1); + }); + } + + protected abstract ResultWriteRequestType createResultWriteRequest( + long expirationEpochMs, + String configId, + RequestPriority priority, + ResultType result, + String resultIndex + ); } diff --git a/src/main/java/org/opensearch/timeseries/ratelimit/SaveResultStrategy.java b/src/main/java/org/opensearch/timeseries/ratelimit/SaveResultStrategy.java new file mode 100644 index 000000000..d5c907c16 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ratelimit/SaveResultStrategy.java @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.ratelimit; + +import java.time.Instant; +import java.util.Optional; + +import org.opensearch.timeseries.ml.IntermediateResult; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.IndexableResult; + +public interface SaveResultStrategy> { + void saveResult(RCFResultType result, Config config, FeatureRequest origRequest, String modelId); + + void saveResult( + RCFResultType result, + Config config, + Instant dataStart, + Instant dataEnd, + String modelId, + double[] currentData, + Optional entity, + String taskId + ); + + void saveResult(IndexableResultType result, Config config); +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ScheduledWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/ScheduledWorker.java similarity index 90% rename from src/main/java/org/opensearch/ad/ratelimit/ScheduledWorker.java rename to src/main/java/org/opensearch/timeseries/ratelimit/ScheduledWorker.java index 115d79882..04dfdd900 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/ScheduledWorker.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/ScheduledWorker.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; import java.time.Clock; import java.time.Duration; @@ -18,18 +18,19 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.ad.ratelimit.ADColdEntityWorker; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.NodeStateManager; -import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.breaker.CircuitBreakerService; public abstract class ScheduledWorker extends RateLimitedRequestWorker { - private static final Logger LOG = LogManager.getLogger(ColdEntityWorker.class); + private static final Logger LOG = LogManager.getLogger(ADColdEntityWorker.class); // the number of requests forwarded to the target queue protected volatile int batchSize; @@ -47,6 +48,7 @@ public ScheduledWorker( Random random, CircuitBreakerService adCircuitBreakerService, ThreadPool threadPool, + String threadPoolName, Settings settings, float maxQueuedTaskRatio, Clock clock, @@ -55,7 +57,8 @@ public ScheduledWorker( int maintenanceFreqConstant, RateLimitedRequestWorker targetQueue, Duration stateTtl, - NodeStateManager nodeStateManager + NodeStateManager nodeStateManager, + AnalysisType context ) { super( workerName, @@ -66,6 +69,7 @@ public ScheduledWorker( random, adCircuitBreakerService, threadPool, + threadPoolName, settings, maxQueuedTaskRatio, clock, @@ -73,7 +77,8 @@ public ScheduledWorker( lowSegmentPruneRatio, maintenanceFreqConstant, stateTtl, - nodeStateManager + nodeStateManager, + context ); this.targetQueue = targetQueue; @@ -114,7 +119,7 @@ private void pullRequests() { private synchronized void schedulePulling(TimeValue delay) { try { - threadPool.schedule(this::pullRequests, delay, TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME); + threadPool.schedule(this::pullRequests, delay, threadPoolName); } catch (Exception e) { LOG.error("Fail to schedule cold entity pulling", e); } diff --git a/src/main/java/org/opensearch/ad/ratelimit/SingleRequestWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/SingleRequestWorker.java similarity index 92% rename from src/main/java/org/opensearch/ad/ratelimit/SingleRequestWorker.java rename to src/main/java/org/opensearch/timeseries/ratelimit/SingleRequestWorker.java index e789e36fa..9b11db99c 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/SingleRequestWorker.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/SingleRequestWorker.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; import java.time.Clock; import java.time.Duration; @@ -24,6 +24,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.breaker.CircuitBreakerService; @@ -39,6 +40,7 @@ public SingleRequestWorker( Random random, CircuitBreakerService adCircuitBreakerService, ThreadPool threadPool, + String threadPoolName, Settings settings, float maxQueuedTaskRatio, Clock clock, @@ -48,7 +50,8 @@ public SingleRequestWorker( Setting concurrencySetting, Duration executionTtl, Duration stateTtl, - NodeStateManager nodeStateManager + NodeStateManager nodeStateManager, + AnalysisType context ) { super( queueName, @@ -59,6 +62,7 @@ public SingleRequestWorker( random, adCircuitBreakerService, threadPool, + threadPoolName, settings, maxQueuedTaskRatio, clock, @@ -68,7 +72,8 @@ public SingleRequestWorker( concurrencySetting, executionTtl, stateTtl, - nodeStateManager + nodeStateManager, + context ); } diff --git a/src/main/java/org/opensearch/timeseries/rest/RestJobAction.java b/src/main/java/org/opensearch/timeseries/rest/RestJobAction.java new file mode 100644 index 000000000..f31e6ce0c --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/rest/RestJobAction.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.timeseries.model.DateRange; + +import com.google.common.collect.ImmutableList; + +public abstract class RestJobAction extends BaseRestHandler { + protected DateRange parseInputDateRange(RestRequest request) throws IOException { + if (!request.hasContent()) { + return null; + } + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + DateRange dateRange = DateRange.parse(parser); + return dateRange; + } + + @Override + public List routes() { + return ImmutableList.of(); + } +} diff --git a/src/main/java/org/opensearch/timeseries/rest/RestStatsAction.java b/src/main/java/org/opensearch/timeseries/rest/RestStatsAction.java new file mode 100644 index 000000000..bb1585566 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/rest/RestStatsAction.java @@ -0,0 +1,90 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.rest; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; +import java.util.TreeSet; + +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.Strings; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.timeseries.stats.Stats; +import org.opensearch.timeseries.transport.StatsRequest; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; + +public abstract class RestStatsAction extends BaseRestHandler { + private Stats timeSeriesStats; + private DiscoveryNodeFilterer nodeFilter; + + /** + * Constructor + * + * @param timeSeriesStats TimeSeriesStats object + * @param nodeFilter util class to get eligible data nodes + */ + public RestStatsAction(Stats timeSeriesStats, DiscoveryNodeFilterer nodeFilter) { + this.timeSeriesStats = timeSeriesStats; + this.nodeFilter = nodeFilter; + } + + /** + * Creates a StatsRequest from a RestRequest + * + * @param request RestRequest + * @return StatsRequest Request containing stats to be retrieved + */ + protected StatsRequest getRequest(RestRequest request) { + // parse the nodes the user wants to query the stats for + String nodesIdsStr = request.param("nodeId"); + Set validStats = timeSeriesStats.getStats().keySet(); + + StatsRequest statsRequest = null; + if (!Strings.isEmpty(nodesIdsStr)) { + String[] nodeIdsArr = nodesIdsStr.split(","); + statsRequest = new StatsRequest(nodeIdsArr); + } else { + DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes(); + statsRequest = new StatsRequest(dataNodes); + } + + statsRequest.timeout(request.param("timeout")); + + // parse the stats the user wants to see + HashSet statsSet = null; + String statsStr = request.param("stat"); + if (!Strings.isEmpty(statsStr)) { + statsSet = new HashSet<>(Arrays.asList(statsStr.split(","))); + } + + if (statsSet == null) { + statsRequest.addAll(validStats); // retrieve all stats if none are specified + } else if (statsSet.size() == 1 && statsSet.contains(StatsRequest.ALL_STATS_KEY)) { + statsRequest.addAll(validStats); + } else if (statsSet.contains(StatsRequest.ALL_STATS_KEY)) { + throw new IllegalArgumentException( + "Request " + request.path() + " contains " + StatsRequest.ALL_STATS_KEY + " and individual stats" + ); + } else { + Set invalidStats = new TreeSet<>(); + for (String stat : statsSet) { + if (validStats.contains(stat)) { + statsRequest.addStat(stat); + } else { + invalidStats.add(stat); + } + } + + if (!invalidStats.isEmpty()) { + throw new IllegalArgumentException(unrecognized(request, invalidStats, statsRequest.getStatsToBeRetrieved(), "stat")); + } + } + return statsRequest; + } + +} diff --git a/src/main/java/org/opensearch/timeseries/rest/RestValidateAction.java b/src/main/java/org/opensearch/timeseries/rest/RestValidateAction.java new file mode 100644 index 000000000..fa546c3d9 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/rest/RestValidateAction.java @@ -0,0 +1,117 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestRequest; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.ConfigValidationIssue; +import org.opensearch.timeseries.rest.handler.AbstractTimeSeriesActionHandler; +import org.opensearch.timeseries.transport.ValidateConfigRequest; +import org.opensearch.timeseries.transport.ValidateConfigResponse; + +/** + * This class consists of the REST handler to validate anomaly detector configurations. + */ +public class RestValidateAction { + private AnalysisType context; + private Integer maxSingleStreamConfigs; + private Integer maxHCConfigs; + private Integer maxFeatures; + private Integer maxCategoricalFields; + private TimeValue requestTimeout; + + public RestValidateAction( + AnalysisType context, + Integer maxSingleStreamConfigs, + Integer maxHCConfigs, + Integer maxFeatures, + Integer maxCategoricalFields, + TimeValue requestTimeout + ) { + this.context = context; + this.maxSingleStreamConfigs = maxSingleStreamConfigs; + this.maxHCConfigs = maxHCConfigs; + this.maxFeatures = maxFeatures; + this.maxCategoricalFields = maxCategoricalFields; + this.requestTimeout = requestTimeout; + } + + public void sendValidationParseResponse(ConfigValidationIssue issue, RestChannel channel) throws IOException { + try { + BytesRestResponse restResponse = new BytesRestResponse( + RestStatus.OK, + new ValidateConfigResponse(issue).toXContent(channel.newBuilder()) + ); + channel.sendResponse(restResponse); + } catch (Exception e) { + channel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, e.getMessage())); + } + } + + private Boolean validationTypesAreAccepted(String validationType) { + Set typesInRequest = new HashSet<>(Arrays.asList(validationType.split(","))); + return (!Collections.disjoint(typesInRequest, AbstractTimeSeriesActionHandler.ALL_VALIDATION_ASPECTS_STRS)); + } + + public ValidateConfigRequest prepareRequest(RestRequest request, NodeClient client, String typesStr) throws IOException { + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + + // if type param isn't blank and isn't a part of possible validation types throws exception + if (!StringUtils.isBlank(typesStr)) { + if (!validationTypesAreAccepted(typesStr)) { + throw new IllegalStateException(CommonMessages.NOT_EXISTENT_VALIDATION_TYPE); + } + } + + Config config = null; + + if (context.isAD()) { + config = AnomalyDetector.parse(parser); + } else if (context.isForecast()) { + config = Forecaster.parse(parser); + } else { + throw new UnsupportedOperationException("This method is not supported"); + } + + ValidateConfigRequest validateAnomalyDetectorRequest = new ValidateConfigRequest( + context, + config, + typesStr, + maxSingleStreamConfigs, + maxHCConfigs, + maxFeatures, + requestTimeout, + maxCategoricalFields + ); + return validateAnomalyDetectorRequest; + } +} diff --git a/src/main/java/org/opensearch/timeseries/rest/handler/AbstractTimeSeriesActionHandler.java b/src/main/java/org/opensearch/timeseries/rest/handler/AbstractTimeSeriesActionHandler.java new file mode 100644 index 000000000..bb23c890e --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/rest/handler/AbstractTimeSeriesActionHandler.java @@ -0,0 +1,890 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.rest.handler; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.constant.CommonMessages.CATEGORICAL_FIELD_TYPE_ERR_MSG; +import static org.opensearch.timeseries.util.ParseUtils.parseAggregators; +import static org.opensearch.timeseries.util.RestHandlerUtils.XCONTENT_WITH_TYPE; +import static org.opensearch.timeseries.util.RestHandlerUtils.isExceptionCausedByInvalidQuery; + +import java.io.IOException; +import java.time.Clock; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.commons.lang.StringUtils; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.admin.indices.mapping.get.GetFieldMappingsAction; +import org.opensearch.action.admin.indices.mapping.get.GetFieldMappingsRequest; +import org.opensearch.action.admin.indices.mapping.get.GetFieldMappingsResponse; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.IndicesOptions; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.support.replication.ReplicationResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.rest.RestRequest; +import org.opensearch.search.aggregations.AggregatorFactories; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Feature; +import org.opensearch.timeseries.model.MergeableList; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.util.MultiResponsesDelegateActionListener; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.Sets; + +/** + * AbstractTimeSeriesActionHandler serves as the foundational base for handling various time series actions + * such as creating, updating, and validating configurations related to time series analysis. This class encapsulates + * common logic and utilities for managing time series indices, processing requests, and interacting with the OpenSearch cluster + * to execute time series tasks. + * + * Responsibilities include: + * - Validating and processing REST requests for time series configurations, ensuring they comply with predefined + * constraints and formats. + * - Managing interactions with the underlying time series indices, including index creation, document indexing, + * and configuration retrieval. + * - Serving as a base for specialized action handlers that implement specific logic for different types of time series tasks + * (e.g., anomaly detection, forecasting). + * - Handling security and permission validations for time series operations, leveraging OpenSearch's security features + * to ensure operations are performed by authorized users. + * + * The class is designed to be extended by specific action handlers that implement the abstract methods provided, + * allowing for flexible and modular enhancement of the time series capabilities within OpenSearch. + * + * Usage of this class requires extending it to implement the abstract methods, which include but are not limited to + * configuration validation, indexing logic, and model validation. Implementers will benefit from the common utilities + * and framework provided by this class, focusing on the unique logic pertinent to their specific time series task. + */ +public abstract class AbstractTimeSeriesActionHandler & TimeSeriesIndex, IndexManagementType extends IndexManagement, TaskCacheManagerType extends TaskCacheManager, TaskTypeEnum extends TaskType, TaskClass extends TimeSeriesTask, TaskManagerType extends TaskManager> + implements + Processor { + + protected final Logger logger = LogManager.getLogger(AbstractTimeSeriesActionHandler.class); + + public static final String NAME_REGEX = "[a-zA-Z0-9._-]+"; + public static final Integer MAX_NAME_SIZE = 64; + public static final String CATEGORY_NOT_FOUND_ERR_MSG = "Can't find the categorical field %s"; + + public static String INVALID_NAME_SIZE = "Name should be shortened. The maximum limit is " + + AbstractTimeSeriesActionHandler.MAX_NAME_SIZE + + " characters."; + + public static final Set ALL_VALIDATION_ASPECTS_STRS = Arrays + .asList(ValidationAspect.values()) + .stream() + .map(aspect -> aspect.getName()) + .collect(Collectors.toSet()); + + protected final Config config; + protected final IndexManagement timeSeriesIndices; + protected final boolean isDryRun; + protected final Client client; + protected final String id; + protected final SecurityClientUtil clientUtil; + protected final User user; + protected final RestRequest.Method method; + protected final ConfigUpdateConfirmer handler; + protected final ClusterService clusterService; + protected final NamedXContentRegistry xContentRegistry; + protected final TimeValue requestTimeout; + protected final WriteRequest.RefreshPolicy refreshPolicy; + protected final Long seqNo; + protected final Long primaryTerm; + protected final String validationType; + protected final SearchFeatureDao searchFeatureDao; + protected final Integer maxFeatures; + protected final Integer maxCategoricalFields; + protected final AnalysisType context; + protected final List batchTasks; + protected final boolean canUpdateEverything; + + protected final Integer maxSingleStreamConfigs; + protected final Integer maxHCConfigs; + protected final Clock clock; + protected final Settings settings; + + public AbstractTimeSeriesActionHandler( + Config config, + IndexManagement timeSeriesIndices, + boolean isDryRun, + Client client, + String id, + SecurityClientUtil clientUtil, + User user, + RestRequest.Method method, + ClusterService clusterService, + NamedXContentRegistry xContentRegistry, + TransportService transportService, + TimeValue requestTimeout, + WriteRequest.RefreshPolicy refreshPolicy, + Long seqNo, + Long primaryTerm, + String validationType, + SearchFeatureDao searchFeatureDao, + Integer maxFeatures, + Integer maxCategoricalFields, + AnalysisType context, + TaskManagerType taskManager, + List batchTasks, + boolean canUpdateCategoryField, + Integer maxSingleStreamConfigs, + Integer maxHCConfigs, + Clock clock, + Settings settings + ) { + this.config = config; + this.timeSeriesIndices = timeSeriesIndices; + this.isDryRun = isDryRun; + this.client = client; + this.id = id == null ? "" : id; + this.clientUtil = clientUtil; + this.user = user; + this.method = method; + this.clusterService = clusterService; + this.xContentRegistry = xContentRegistry; + this.requestTimeout = requestTimeout; + this.refreshPolicy = refreshPolicy; + this.seqNo = seqNo; + this.primaryTerm = primaryTerm; + this.validationType = validationType; + this.searchFeatureDao = searchFeatureDao; + this.maxFeatures = maxFeatures; + this.maxCategoricalFields = maxCategoricalFields; + this.context = context; + this.batchTasks = batchTasks; + this.canUpdateEverything = canUpdateCategoryField; + this.maxSingleStreamConfigs = maxSingleStreamConfigs; + this.maxHCConfigs = maxHCConfigs; + this.clock = clock; + this.settings = settings; + this.handler = new ConfigUpdateConfirmer<>(taskManager, transportService); + } + + /** + * Start function to process create/update/validate config request. + * + * If validation type is detector/forecaster then all validation in this class involves validation + * checks against the configurations. + * Any issues raised here would block user from creating the config (e.g., anomaly detector). + * If validation Aspect is of type model then further non-blocker validation will be executed + * after the blocker validation is executed. Any issues that are raised for model validation + * are simply warnings for the user in terms of how configuration could be changed to lead to + * a higher likelihood of model training completing successfully. + * + * For custom index validation, if config is not using custom result index, check if config + * index exist first, if not, will create first. Otherwise, check if custom + * result index exists or not. If exists, will check if index mapping matches + * config result index mapping and if user has correct permission to write index. + * If doesn't exist, will create custom result index with result index + * mapping. + */ + @Override + public void start(ActionListener listener) { + String resultIndex = config.getCustomResultIndex(); + // use default detector result index which is system index + if (resultIndex == null) { + createOrUpdateConfig(listener); + return; + } + + if (this.isDryRun) { + if (timeSeriesIndices.doesIndexExist(resultIndex)) { + timeSeriesIndices + .validateResultIndexAndExecute( + resultIndex, + () -> createOrUpdateConfig(listener), + false, + ActionListener.wrap(r -> createOrUpdateConfig(listener), ex -> { + logger.error(ex); + listener.onFailure(createValidationException(ex.getMessage(), ValidationIssueType.RESULT_INDEX)); + return; + }) + ); + return; + } else { + createOrUpdateConfig(listener); + return; + } + } + // use custom result index if not validating and resultIndex not null + timeSeriesIndices.initCustomResultIndexAndExecute(resultIndex, () -> createOrUpdateConfig(listener), listener); + } + + // if isDryRun is true then this method is being executed through Validation API meaning actual + // index won't be created, only validation checks will be executed throughout the class + private void createOrUpdateConfig(ActionListener listener) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + if (!timeSeriesIndices.doesConfigIndexExist() && !this.isDryRun) { + logger.info("Config Indices do not exist"); + timeSeriesIndices + .initConfigIndex( + ActionListener + .wrap( + response -> onCreateMappingsResponse(response, false, listener), + exception -> listener.onFailure(exception) + ) + ); + } else { + logger.info("DryRun variable " + this.isDryRun); + validateName(this.isDryRun, listener); + } + } catch (Exception e) { + logger.error("Failed to create or update forecaster " + id, e); + listener.onFailure(e); + } + } + + protected void validateName(boolean indexingDryRun, ActionListener listener) { + if (!config.getName().matches(NAME_REGEX)) { + listener.onFailure(createValidationException(CommonMessages.INVALID_NAME, ValidationIssueType.NAME)); + return; + + } + if (config.getName().length() > MAX_NAME_SIZE) { + listener.onFailure(createValidationException(AbstractTimeSeriesActionHandler.INVALID_NAME_SIZE, ValidationIssueType.NAME)); + return; + } + validateTimeField(indexingDryRun, listener); + } + + protected void validateTimeField(boolean indexingDryRun, ActionListener listener) { + String givenTimeField = config.getTimeField(); + GetFieldMappingsRequest getMappingsRequest = new GetFieldMappingsRequest(); + getMappingsRequest.indices(config.getIndices().toArray(new String[0])).fields(givenTimeField); + getMappingsRequest.indicesOptions(IndicesOptions.strictExpand()); + + // comments explaining fieldMappingResponse parsing can be found inside validateCategoricalField(String, boolean) + ActionListener mappingsListener = ActionListener.wrap(getMappingsResponse -> { + boolean foundField = false; + Map> mappingsByIndex = getMappingsResponse.mappings(); + + for (Map mappingsByField : mappingsByIndex.values()) { + for (Map.Entry field2Metadata : mappingsByField.entrySet()) { + + GetFieldMappingsResponse.FieldMappingMetadata fieldMetadata = field2Metadata.getValue(); + if (fieldMetadata != null) { + // sourceAsMap returns sth like {host2={type=keyword}} with host2 being a nested field + Map fieldMap = fieldMetadata.sourceAsMap(); + if (fieldMap != null) { + for (Object type : fieldMap.values()) { + if (type instanceof Map) { + foundField = true; + Map metadataMap = (Map) type; + String typeName = (String) metadataMap.get(CommonName.TYPE); + if (!typeName.equals(CommonName.DATE_TYPE)) { + listener + .onFailure( + new ValidationException( + String.format(Locale.ROOT, CommonMessages.INVALID_TIMESTAMP, givenTimeField), + ValidationIssueType.TIMEFIELD_FIELD, + ValidationAspect.FORECASTER + ) + ); + return; + } + } + } + } + } + } + } + if (!foundField) { + listener + .onFailure( + new ValidationException( + String.format(Locale.ROOT, CommonMessages.NON_EXISTENT_TIMESTAMP, givenTimeField), + ValidationIssueType.TIMEFIELD_FIELD, + ValidationAspect.FORECASTER + ) + ); + return; + } + prepareConfigIndexing(indexingDryRun, listener); + }, error -> { + String message = String.format(Locale.ROOT, "Fail to get the index mapping of %s", config.getIndices()); + logger.error(message, error); + listener.onFailure(new IllegalArgumentException(message)); + }); + clientUtil + .executeWithInjectedSecurity(GetFieldMappingsAction.INSTANCE, getMappingsRequest, user, client, context, mappingsListener); + } + + /** + * Prepare for indexing a new config. + * @param indexingDryRun if this is dryrun for indexing; when validation, it is true; when create/update, it is false + */ + protected void prepareConfigIndexing(boolean indexingDryRun, ActionListener listener) { + if (method == RestRequest.Method.PUT) { + handler + .confirmJobRunning( + clusterService, + client, + id, + listener, + () -> updateConfig(id, indexingDryRun, listener), + xContentRegistry + ); + } else { + createConfig(indexingDryRun, listener); + } + } + + protected void updateConfig(String id, boolean indexingDryRun, ActionListener listener) { + GetRequest request = new GetRequest(CommonName.CONFIG_INDEX, id); + client + .get( + request, + ActionListener + .wrap( + response -> onGetConfigResponse(response, indexingDryRun, id, listener), + exception -> listener.onFailure(exception) + ) + ); + } + + private void onGetConfigResponse(GetResponse response, boolean indexingDryRun, String id, ActionListener listener) { + if (!response.isExists()) { + listener.onFailure(new OpenSearchStatusException(CommonMessages.FAIL_TO_FIND_CONFIG_MSG + id, RestStatus.NOT_FOUND)); + return; + } + try (XContentParser parser = RestHandlerUtils.createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Config existingConfig = parse(parser, response); + // If category field changed, frontend may not be able to render AD result for different config types correctly. + // For example, if an anomaly detector changed from HC to single entity detector, AD result page may show multiple anomaly + // result points on the same time point if there are multiple entities have anomaly results. + // If single-category HC changed category field from IP to error type, the AD result page may show both IP and error type + // in top N entities list. That's confusing. + // So we decide to block updating detector category field. + // for forecasting, we will not show results after forecaster configuration change (excluding changes like description) + // thus it is safe to allow updating everything. In the future, we might change AD to allow such behavior. + if (!canUpdateEverything) { + if (!ParseUtils.listEqualsWithoutConsideringOrder(existingConfig.getCategoryFields(), config.getCategoryFields())) { + listener.onFailure(new OpenSearchStatusException(CommonMessages.CAN_NOT_CHANGE_CATEGORY_FIELD, RestStatus.BAD_REQUEST)); + return; + } + if (!Objects.equals(existingConfig.getCustomResultIndex(), config.getCustomResultIndex())) { + listener + .onFailure( + new OpenSearchStatusException(CommonMessages.CAN_NOT_CHANGE_CUSTOM_RESULT_INDEX, RestStatus.BAD_REQUEST) + ); + return; + } + } + + ActionListener confirmBatchRunningListener = ActionListener + .wrap( + r -> searchConfigInputIndices(id, indexingDryRun, listener), + // can't update config if there is task running + listener::onFailure + ); + + handler.confirmBatchRunning(id, batchTasks, confirmBatchRunningListener); + } catch (IOException e) { + String message = "Failed to parse anomaly detector " + id; + logger.error(message, e); + listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); + } + + } + + protected void validateAgainstExistingHCConfig(String detectorId, boolean indexingDryRun, ActionListener listener) { + if (timeSeriesIndices.doesConfigIndexExist()) { + QueryBuilder query = QueryBuilders.boolQuery().filter(QueryBuilders.existsQuery(Config.CATEGORY_FIELD)); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeout); + + SearchRequest searchRequest = new SearchRequest(CommonName.CONFIG_INDEX).source(searchSourceBuilder); + client + .search( + searchRequest, + ActionListener + .wrap( + response -> onSearchHCConfigResponse(response, detectorId, indexingDryRun, listener), + exception -> listener.onFailure(exception) + ) + ); + } else { + validateCategoricalField(detectorId, indexingDryRun, listener); + } + + } + + protected void createConfig(boolean indexingDryRun, ActionListener listener) { + try { + List categoricalFields = config.getCategoryFields(); + if (categoricalFields != null && categoricalFields.size() > 0) { + validateAgainstExistingHCConfig(null, indexingDryRun, listener); + } else { + if (timeSeriesIndices.doesConfigIndexExist()) { + QueryBuilder query = QueryBuilders.matchAllQuery(); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeout); + + SearchRequest searchRequest = new SearchRequest(CommonName.CONFIG_INDEX).source(searchSourceBuilder); + + client + .search( + searchRequest, + ActionListener + .wrap( + response -> onSearchSingleStreamConfigResponse(response, indexingDryRun, listener), + exception -> listener.onFailure(exception) + ) + ); + } else { + searchConfigInputIndices(null, indexingDryRun, listener); + } + + } + } catch (Exception e) { + listener.onFailure(e); + } + } + + protected void onSearchSingleStreamConfigResponse(SearchResponse response, boolean indexingDryRun, ActionListener listener) + throws IOException { + if (response.getHits().getTotalHits().value >= getMaxSingleStreamConfigs()) { + String errorMsgSingleEntity = getExceedMaxSingleStreamConfigsErrorMsg(getMaxSingleStreamConfigs()); + logger.error(errorMsgSingleEntity); + if (indexingDryRun) { + listener.onFailure(createValidationException(errorMsgSingleEntity, ValidationIssueType.GENERAL_SETTINGS)); + return; + } + listener.onFailure(new IllegalArgumentException(errorMsgSingleEntity)); + } else { + searchConfigInputIndices(null, indexingDryRun, listener); + } + } + + protected void onSearchHCConfigResponse(SearchResponse response, String detectorId, boolean indexingDryRun, ActionListener listener) + throws IOException { + if (response.getHits().getTotalHits().value >= getMaxHCConfigs()) { + String errorMsg = getExceedMaxHCConfigsErrorMsg(getMaxHCConfigs()); + logger.error(errorMsg); + if (indexingDryRun) { + listener.onFailure(createValidationException(errorMsg, ValidationIssueType.GENERAL_SETTINGS)); + return; + } + listener.onFailure(new IllegalArgumentException(errorMsg)); + } else { + validateCategoricalField(detectorId, indexingDryRun, listener); + } + } + + @SuppressWarnings("unchecked") + protected void validateCategoricalField(String detectorId, boolean indexingDryRun, ActionListener listener) { + List categoryField = config.getCategoryFields(); + + if (categoryField == null) { + searchConfigInputIndices(detectorId, indexingDryRun, listener); + return; + } + + // we only support a certain number of categorical field + // If there is more fields than required, Config's constructor + // throws validation exception before reaching this line + int maxCategoryFields = maxCategoricalFields; + if (categoryField.size() > maxCategoryFields) { + listener + .onFailure( + createValidationException(CommonMessages.getTooManyCategoricalFieldErr(maxCategoryFields), ValidationIssueType.CATEGORY) + ); + return; + } + + String categoryField0 = categoryField.get(0); + + GetFieldMappingsRequest getMappingsRequest = new GetFieldMappingsRequest(); + getMappingsRequest.indices(config.getIndices().toArray(new String[0])).fields(categoryField.toArray(new String[0])); + getMappingsRequest.indicesOptions(IndicesOptions.strictExpand()); + + ActionListener mappingsListener = ActionListener.wrap(getMappingsResponse -> { + // example getMappingsResponse: + // GetFieldMappingsResponse{mappings={server-metrics={_doc={service=FieldMappingMetadata{fullName='service', + // source=org.opensearch.core.common.bytes.BytesArray@7ba87dbd}}}}} + // for nested field, it would be + // GetFieldMappingsResponse{mappings={server-metrics={_doc={host_nest.host2=FieldMappingMetadata{fullName='host_nest.host2', + // source=org.opensearch.core.common.bytes.BytesArray@8fb4de08}}}}} + boolean foundField = false; + + // Review why the change from FieldMappingMetadata to GetFieldMappingsResponse.FieldMappingMetadata + Map> mappingsByIndex = getMappingsResponse.mappings(); + + for (Map mappingsByField : mappingsByIndex.values()) { + for (Map.Entry field2Metadata : mappingsByField.entrySet()) { + // example output: + // host_nest.host2=FieldMappingMetadata{fullName='host_nest.host2', + // source=org.opensearch.core.common.bytes.BytesArray@8fb4de08} + + // Review why the change from FieldMappingMetadata to GetFieldMappingsResponse.FieldMappingMetadata + + GetFieldMappingsResponse.FieldMappingMetadata fieldMetadata = field2Metadata.getValue(); + + if (fieldMetadata != null) { + // sourceAsMap returns sth like {host2={type=keyword}} with host2 being a nested field + Map fieldMap = fieldMetadata.sourceAsMap(); + if (fieldMap != null) { + for (Object type : fieldMap.values()) { + if (type != null && type instanceof Map) { + foundField = true; + Map metadataMap = (Map) type; + String typeName = (String) metadataMap.get(CommonName.TYPE); + if (!typeName.equals(CommonName.KEYWORD_TYPE) && !typeName.equals(CommonName.IP_TYPE)) { + listener + .onFailure( + createValidationException(CATEGORICAL_FIELD_TYPE_ERR_MSG, ValidationIssueType.CATEGORY) + ); + return; + } + } + } + } + + } + } + } + + if (foundField == false) { + listener + .onFailure( + createValidationException( + String.format(Locale.ROOT, CATEGORY_NOT_FOUND_ERR_MSG, categoryField0), + ValidationIssueType.CATEGORY + ) + ); + return; + } + + searchConfigInputIndices(detectorId, indexingDryRun, listener); + }, error -> { + String message = String.format(Locale.ROOT, "Fail to get the index mapping of %s", config.getIndices()); + logger.error(message, error); + listener.onFailure(new IllegalArgumentException(message)); + }); + + clientUtil + .executeWithInjectedSecurity(GetFieldMappingsAction.INSTANCE, getMappingsRequest, user, client, context, mappingsListener); + } + + protected void searchConfigInputIndices(String detectorId, boolean indexingDryRun, ActionListener listener) { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .query(QueryBuilders.matchAllQuery()) + .size(0) + .timeout(requestTimeout); + + SearchRequest searchRequest = new SearchRequest(config.getIndices().toArray(new String[0])).source(searchSourceBuilder); + + ActionListener searchResponseListener = ActionListener + .wrap( + searchResponse -> onSearchConfigInputIndicesResponse(searchResponse, detectorId, indexingDryRun, listener), + exception -> listener.onFailure(exception) + ); + + clientUtil.asyncRequestWithInjectedSecurity(searchRequest, client::search, user, client, context, searchResponseListener); + } + + protected void onSearchConfigInputIndicesResponse( + SearchResponse response, + String detectorId, + boolean indexingDryRun, + ActionListener listener + ) throws IOException { + if (response.getHits().getTotalHits().value == 0) { + String errorMsg = getNoDocsInUserIndexErrorMsg(Arrays.toString(config.getIndices().toArray(new String[0]))); + logger.error(errorMsg); + if (indexingDryRun) { + listener.onFailure(createValidationException(errorMsg, ValidationIssueType.INDICES)); + return; + } + listener.onFailure(new IllegalArgumentException(errorMsg)); + } else { + validateConfigFeatures(detectorId, indexingDryRun, listener); + } + } + + protected void checkConfigNameExists(String configId, boolean indexingDryRun, ActionListener listener) throws IOException { + if (timeSeriesIndices.doesConfigIndexExist()) { + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + // src/main/resources/mappings/config.json#L14 + boolQueryBuilder.must(QueryBuilders.termQuery("name.keyword", config.getName())); + if (StringUtils.isNotBlank(configId)) { + boolQueryBuilder.mustNot(QueryBuilders.termQuery(RestHandlerUtils._ID, configId)); + } + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(boolQueryBuilder).timeout(requestTimeout); + SearchRequest searchRequest = new SearchRequest(CommonName.CONFIG_INDEX).source(searchSourceBuilder); + client + .search( + searchRequest, + ActionListener + .wrap( + searchResponse -> onSearchConfigNameResponse(searchResponse, config.getName(), indexingDryRun, listener), + exception -> listener.onFailure(exception) + ) + ); + } else { + tryIndexingConfig(indexingDryRun, listener); + } + + } + + protected void onSearchConfigNameResponse(SearchResponse response, String name, boolean indexingDryRun, ActionListener listener) + throws IOException { + if (response.getHits().getTotalHits().value > 0) { + String errorMsg = getDuplicateConfigErrorMsg(name); + logger.warn(errorMsg); + if (indexingDryRun) { + listener.onFailure(createValidationException(errorMsg, ValidationIssueType.NAME)); + } else { + listener.onFailure(new OpenSearchStatusException(errorMsg, RestStatus.CONFLICT)); + } + } else { + tryIndexingConfig(indexingDryRun, listener); + } + } + + protected void tryIndexingConfig(boolean indexingDryRun, ActionListener listener) throws IOException { + if (!indexingDryRun) { + indexConfig(id, listener); + } else { + finishConfigValidationOrContinueToModelValidation(listener); + } + } + + protected Set getValidationTypes(String validationType) { + if (StringUtils.isBlank(validationType)) { + return getDefaultValidationType(); + } else { + Set typesInRequest = new HashSet<>(Arrays.asList(validationType.split(","))); + return ValidationAspect + .getNames(Sets.intersection(AbstractTimeSeriesActionHandler.ALL_VALIDATION_ASPECTS_STRS, typesInRequest)); + } + } + + protected void finishConfigValidationOrContinueToModelValidation(ActionListener listener) { + logger.info("Skipping indexing detector. No blocking issue found so far."); + if (!getValidationTypes(validationType).contains(ValidationAspect.MODEL)) { + listener.onResponse(null); + } else { + validateModel(listener); + } + } + + @SuppressWarnings("unchecked") + protected void indexConfig(String id, ActionListener listener) throws IOException { + Config copiedConfig = copyConfig(user, config); + IndexRequest indexRequest = new IndexRequest(CommonName.CONFIG_INDEX) + .setRefreshPolicy(refreshPolicy) + .source(copiedConfig.toXContent(XContentFactory.jsonBuilder(), XCONTENT_WITH_TYPE)) + .setIfSeqNo(seqNo) + .setIfPrimaryTerm(primaryTerm) + .timeout(requestTimeout); + if (StringUtils.isNotBlank(id)) { + indexRequest.id(id); + } + + client.index(indexRequest, new ActionListener() { + @Override + public void onResponse(IndexResponse indexResponse) { + String errorMsg = checkShardsFailure(indexResponse); + if (errorMsg != null) { + listener.onFailure(new OpenSearchStatusException(errorMsg, indexResponse.status())); + return; + } + listener.onResponse(createIndexConfigResponse(indexResponse, copiedConfig)); + } + + @Override + public void onFailure(Exception e) { + logger.warn("Failed to update config", e); + if (e.getMessage() != null && e.getMessage().contains("version conflict")) { + listener.onFailure(new IllegalArgumentException("There was a problem updating the config:[" + id + "]")); + } else { + listener.onFailure(e); + } + } + }); + } + + protected void onCreateMappingsResponse(CreateIndexResponse response, boolean indexingDryRun, ActionListener listener) { + if (response.isAcknowledged()) { + logger.info("Created {} with mappings.", CommonName.CONFIG_INDEX); + prepareConfigIndexing(indexingDryRun, listener); + } else { + logger.warn("Created {} with mappings call not acknowledged.", CommonName.CONFIG_INDEX); + listener + .onFailure( + new OpenSearchStatusException( + "Created " + CommonName.CONFIG_INDEX + "with mappings call not acknowledged.", + RestStatus.INTERNAL_SERVER_ERROR + ) + ); + } + } + + protected String checkShardsFailure(IndexResponse response) { + StringBuilder failureReasons = new StringBuilder(); + if (response.getShardInfo().getFailed() > 0) { + for (ReplicationResponse.ShardInfo.Failure failure : response.getShardInfo().getFailures()) { + failureReasons.append(failure); + } + return failureReasons.toString(); + } + return null; + } + + /** + * Validate config/syntax, and runtime error of config features + * @param id config id + * @param indexingDryRun if false, then will eventually index detector; true, skip indexing detector + * @throws IOException when fail to parse feature aggregation + */ + // TODO: move this method to util class so that it can be re-usable for more use cases + // https://github.com/opensearch-project/anomaly-detection/issues/39 + protected void validateConfigFeatures(String id, boolean indexingDryRun, ActionListener listener) throws IOException { + if (config != null && (config.getFeatureAttributes() == null || config.getFeatureAttributes().isEmpty())) { + checkConfigNameExists(id, indexingDryRun, listener); + return; + } + // checking configuration/syntax error of detector features + String error = RestHandlerUtils.checkFeaturesSyntax(config, maxFeatures); + if (StringUtils.isNotBlank(error)) { + if (indexingDryRun) { + listener.onFailure(createValidationException(error, ValidationIssueType.FEATURE_ATTRIBUTES)); + return; + } + listener.onFailure(new OpenSearchStatusException(error, RestStatus.BAD_REQUEST)); + return; + } + // checking runtime error from feature query + ActionListener>> validateFeatureQueriesListener = ActionListener.wrap(response -> { + checkConfigNameExists(id, indexingDryRun, listener); + }, exception -> { listener.onFailure(createValidationException(exception.getMessage(), ValidationIssueType.FEATURE_ATTRIBUTES)); }); + MultiResponsesDelegateActionListener>> multiFeatureQueriesResponseListener = + new MultiResponsesDelegateActionListener>>( + validateFeatureQueriesListener, + config.getFeatureAttributes().size(), + getFeatureErrorMsg(config.getName()), + false + ); + + for (Feature feature : config.getFeatureAttributes()) { + SearchSourceBuilder ssb = new SearchSourceBuilder().size(1).query(QueryBuilders.matchAllQuery()); + AggregatorFactories.Builder internalAgg = parseAggregators( + feature.getAggregation().toString(), + xContentRegistry, + feature.getId() + ); + ssb.aggregation(internalAgg.getAggregatorFactories().iterator().next()); + SearchRequest searchRequest = new SearchRequest().indices(config.getIndices().toArray(new String[0])).source(ssb); + ActionListener searchResponseListener = ActionListener.wrap(response -> { + Optional aggFeatureResult = searchFeatureDao.parseResponse(response, Arrays.asList(feature.getId())); + if (aggFeatureResult.isPresent()) { + multiFeatureQueriesResponseListener + .onResponse( + new MergeableList>(new ArrayList>(Arrays.asList(aggFeatureResult))) + ); + } else { + String errorMessage = CommonMessages.FEATURE_WITH_EMPTY_DATA_MSG + feature.getName(); + logger.error(errorMessage); + multiFeatureQueriesResponseListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.BAD_REQUEST)); + } + }, e -> { + String errorMessage; + if (isExceptionCausedByInvalidQuery(e)) { + errorMessage = CommonMessages.FEATURE_WITH_INVALID_QUERY_MSG + feature.getName(); + } else { + errorMessage = CommonMessages.UNKNOWN_SEARCH_QUERY_EXCEPTION_MSG + feature.getName(); + } + logger.error(errorMessage, e); + multiFeatureQueriesResponseListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.BAD_REQUEST, e)); + }); + clientUtil.asyncRequestWithInjectedSecurity(searchRequest, client::search, user, client, context, searchResponseListener); + } + } + + protected Integer getMaxSingleStreamConfigs() { + return maxSingleStreamConfigs; + } + + protected Integer getMaxHCConfigs() { + return maxHCConfigs; + } + + protected abstract TimeSeriesException createValidationException(String msg, ValidationIssueType type); + + protected abstract Config parse(XContentParser parser, GetResponse response) throws IOException; + + protected abstract String getExceedMaxSingleStreamConfigsErrorMsg(int maxSingleStreamConfigs); + + protected abstract String getExceedMaxHCConfigsErrorMsg(int maxHCConfigs); + + protected abstract String getNoDocsInUserIndexErrorMsg(String suppliedIndices); + + protected abstract String getDuplicateConfigErrorMsg(String nane); + + protected abstract String getFeatureErrorMsg(String id); + + protected abstract Config copyConfig(User user, Config config); + + protected abstract T createIndexConfigResponse(IndexResponse indexResponse, Config config); + + protected abstract Set getDefaultValidationType(); + + /** + * Validate model + * @param listener listener to return response + */ + protected abstract void validateModel(ActionListener listener); +} diff --git a/src/main/java/org/opensearch/timeseries/rest/handler/ConfigUpdateConfirmer.java b/src/main/java/org/opensearch/timeseries/rest/handler/ConfigUpdateConfirmer.java new file mode 100644 index 000000000..8e676113c --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/rest/handler/ConfigUpdateConfirmer.java @@ -0,0 +1,140 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.rest.handler; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.List; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportService; + +/** + * Get job to make sure job has been stopped before updating a config. + */ +public class ConfigUpdateConfirmer & TimeSeriesIndex, IndexManagementType extends IndexManagement, TaskCacheManagerType extends TaskCacheManager, TaskTypeEnum extends TaskType, TaskClass extends TimeSeriesTask, TaskManagerType extends TaskManager> { + + private final Logger logger = LogManager.getLogger(ConfigUpdateConfirmer.class); + + private final TaskManagerType taskManager; + private final TransportService transportService; + + public ConfigUpdateConfirmer(TaskManagerType taskManager, TransportService transportService) { + this.taskManager = taskManager; + this.transportService = transportService; + } + + /** + * Get job for update/delete config. + * If job exist, will return error message; otherwise, execute function. + * + * @param clusterService OS cluster service + * @param client OS node client + * @param id job identifier + * @param listener Listener to send response + * @param function time series function + * @param xContentRegistry Registry which is used for XContentParser + */ + public void confirmJobRunning( + ClusterService clusterService, + Client client, + String id, + ActionListener listener, + ExecutorFunction function, + NamedXContentRegistry xContentRegistry + ) { + // forecasting and ad share the same job index + if (clusterService.state().metadata().indices().containsKey(CommonName.JOB_INDEX)) { + GetRequest request = new GetRequest(CommonName.JOB_INDEX).id(id); + client + .get( + request, + ActionListener.wrap(response -> onGetJobResponseForWrite(response, listener, function, xContentRegistry), exception -> { + logger.error("Fail to get job: " + id, exception); + listener.onFailure(exception); + }) + ); + } else { + function.execute(); + } + } + + private void onGetJobResponseForWrite( + GetResponse response, + ActionListener listener, + ExecutorFunction function, + NamedXContentRegistry xContentRegistry + ) { + if (response.isExists()) { + String jobId = response.getId(); + if (jobId != null) { + // check if job is running, if yes, we can't delete the config + try ( + XContentParser parser = RestHandlerUtils + .createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Job adJob = Job.parse(parser); + if (adJob.isEnabled()) { + listener.onFailure(new OpenSearchStatusException("Job is running: " + jobId, RestStatus.BAD_REQUEST)); + return; + } + } catch (IOException e) { + String message = "Failed to parse job " + jobId; + logger.error(message, e); + listener.onFailure(new OpenSearchStatusException(message, RestStatus.BAD_REQUEST)); + } + } + } + function.execute(); + } + + /** + * Confirm if any historical or run once is running. If there is still any left over tasks running, + * listener returns failure complaining task running. Otherwise, listener response returns null + * (indicating no batch running). + * @param configId Config id + * @param tasks tasks to check. + * @param listener to return response or failure. + */ + public void confirmBatchRunning(String configId, List tasks, ActionListener listener) { + taskManager.getAndExecuteOnLatestConfigLevelTask(configId, tasks, (task) -> { + if (task.isPresent() && !task.get().isDone()) { + // can't update config if there is task running + listener.onFailure(new OpenSearchStatusException("Run once or historical is running", RestStatus.BAD_REQUEST)); + } else { + listener.onResponse(null); + } + }, transportService, false, listener); // false means don't reset task state as inactive/stopped state + } +} diff --git a/src/main/java/org/opensearch/timeseries/rest/handler/HistogramAggregationHelper.java b/src/main/java/org/opensearch/timeseries/rest/handler/HistogramAggregationHelper.java new file mode 100644 index 000000000..eeaff0dc1 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/rest/handler/HistogramAggregationHelper.java @@ -0,0 +1,127 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.rest.handler; + +import java.time.Duration; +import java.time.ZonedDateTime; +import java.util.List; +import java.util.Optional; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.bucket.histogram.DateHistogramInterval; +import org.opensearch.search.aggregations.bucket.histogram.Histogram; +import org.opensearch.search.aggregations.bucket.histogram.Histogram.Bucket; +import org.opensearch.search.aggregations.bucket.histogram.LongBounds; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.TimeConfiguration; +import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.timeseries.settings.TimeSeriesSettings; + +/** + * the class provides helper methods specifically for histogram aggregations + * + */ +public class HistogramAggregationHelper { + protected static final Logger logger = LogManager.getLogger(HistogramAggregationHelper.class); + + protected static final String AGGREGATION = "agg"; + + private Config config; + private final TimeValue requestTimeout; + + public HistogramAggregationHelper(Config config, TimeValue requestTimeout) { + this.config = config; + this.requestTimeout = requestTimeout; + } + + public Histogram checkBucketResultErrors(SearchResponse response) { + Aggregations aggs = response.getAggregations(); + if (aggs == null) { + // This would indicate some bug or some opensearch core changes that we are not aware of (we don't keep up-to-date with + // the large amounts of changes there). For this reason I'm not throwing a SearchException but instead a validation exception + // which will be converted to validation response. + logger.warn("Unexpected null aggregation."); + throw new ValidationException( + CommonMessages.MODEL_VALIDATION_FAILED_UNEXPECTEDLY, + ValidationIssueType.AGGREGATION, + ValidationAspect.MODEL + ); + } + Histogram aggregate = aggs.get(AGGREGATION); + if (aggregate == null) { + throw new IllegalArgumentException("Failed to find valid aggregation result"); + } + return aggregate; + } + + public AggregationBuilder getBucketAggregation(int intervalInMinutes, LongBounds timeStampBound) { + return AggregationBuilders + .dateHistogram(AGGREGATION) + .field(config.getTimeField()) + .minDocCount(1) + .hardBounds(timeStampBound) + .fixedInterval(DateHistogramInterval.minutes(intervalInMinutes)); + } + + public Long timeConfigToMilliSec(TimeConfiguration timeConfig) { + return Optional.ofNullable((IntervalTimeConfiguration) timeConfig).map(t -> t.toDuration().toMillis()).orElse(0L); + } + + public LongBounds getTimeRangeBounds(long endMillis, long intervalInMillis) { + Long startMillis = endMillis - (getNumberOfSamples(intervalInMillis) * intervalInMillis); + return new LongBounds(startMillis, endMillis); + } + + public int getNumberOfSamples(long intervalInMillis) { + return Math + .max( + (int) (Duration.ofHours(TimeSeriesSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS).toMillis() / intervalInMillis), + TimeSeriesSettings.MIN_TRAIN_SAMPLES + ); + } + + /** + * @param histogram buckets returned via Date historgram aggregation + * @param intervalInMillis suggested interval to use + * @return the number of buckets having data + */ + public double processBucketAggregationResults(Histogram histogram, long intervalInMillis, Config config) { + // In all cases, when the specified end time does not exist, the actual end time is the closest available time after the specified + // end. + // so we only have non-empty buckets + List bucketsInResponse = histogram.getBuckets(); + if (bucketsInResponse.size() >= config.getShingleSize() + TimeSeriesSettings.NUM_MIN_SAMPLES) { + long minTimestampMillis = convertKeyToEpochMillis(bucketsInResponse.get(0).getKey()); + long maxTimestampMillis = convertKeyToEpochMillis(bucketsInResponse.get(bucketsInResponse.size() - 1).getKey()); + double totalBuckets = (maxTimestampMillis - minTimestampMillis) / intervalInMillis; + return histogram.getBuckets().size() / totalBuckets; + } + return 0; + } + + public SearchSourceBuilder getSearchSourceBuilder(QueryBuilder query, AggregationBuilder aggregation) { + return new SearchSourceBuilder().query(query).aggregation(aggregation).size(0).timeout(requestTimeout); + } + + public static long convertKeyToEpochMillis(Object key) { + return key instanceof ZonedDateTime ? ((ZonedDateTime) key).toInstant().toEpochMilli() + : key instanceof Double ? ((Double) key).longValue() + : key instanceof Long ? (Long) key + : -1L; + } +} diff --git a/src/main/java/org/opensearch/timeseries/rest/handler/IndexJobActionHandler.java b/src/main/java/org/opensearch/timeseries/rest/handler/IndexJobActionHandler.java new file mode 100644 index 000000000..92cb6ad65 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/rest/handler/IndexJobActionHandler.java @@ -0,0 +1,594 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.rest.handler; + +import static org.opensearch.action.DocWriteResponse.Result.CREATED; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.util.RestHandlerUtils.createXContentParserFromRegistry; + +import java.io.IOException; +import java.time.Duration; +import java.time.Instant; +import java.util.List; +import java.util.Locale; +import java.util.Optional; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.ActionType; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; +import org.opensearch.jobscheduler.spi.schedule.Schedule; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.ExecuteResultResponseRecorder; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.transport.JobResponse; +import org.opensearch.timeseries.transport.ProfileResponse; +import org.opensearch.timeseries.transport.ResultRequest; +import org.opensearch.timeseries.transport.ResultResponse; +import org.opensearch.timeseries.transport.StopConfigRequest; +import org.opensearch.timeseries.transport.StopConfigResponse; +import org.opensearch.timeseries.util.ExceptionUtil; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportService; + +import com.google.common.base.Throwables; + +/** + * job REST action handler to process POST/PUT request. + */ +public abstract class IndexJobActionHandler & TimeSeriesIndex, IndexManagementType extends IndexManagement, TaskCacheManagerType extends TaskCacheManager, TaskTypeEnum extends TaskType, TaskClass extends TimeSeriesTask, TaskManagerType extends TaskManager, IndexableResultType extends IndexableResult, ProfileActionType extends ActionType, ExecuteResultResponseRecorderType extends ExecuteResultResponseRecorder> { + + private final IndexManagementType indexManagement; + private final Client client; + private final NamedXContentRegistry xContentRegistry; + protected final TaskManagerType taskManager; + + private final Logger logger = LogManager.getLogger(IndexJobActionHandler.class); + private final TimeValue requestTimeout; + private final ExecuteResultResponseRecorderType recorder; + private final ActionType> resultAction; + private final AnalysisType analysisType; + private final String stateIndex; + private final ActionType stopConfigAction; + protected final NodeStateManager nodeStateManager; + + /** + * Constructor function. + * + * @param client ES node client that executes actions on the local node + * @param indexManagement index manager + * @param xContentRegistry Registry which is used for XContentParser + * @param taskManager task manager + * @param recorder Utility to record AnomalyResultAction execution result + * @param resultAction result action + * @param analysisType analysis type + * @param stateIndex State index name + * @param stopConfigAction Stop config action + * @param nodeStateManager Node state manager + * @param settings Node settings + * @param timeoutSetting timeout setting + */ + public IndexJobActionHandler( + Client client, + IndexManagementType indexManagement, + NamedXContentRegistry xContentRegistry, + TaskManagerType taskManager, + ExecuteResultResponseRecorderType recorder, + ActionType> resultAction, + AnalysisType analysisType, + String stateIndex, + ActionType stopConfigAction, + NodeStateManager nodeStateManager, + Settings settings, + Setting timeoutSetting + ) { + this.client = client; + this.indexManagement = indexManagement; + this.xContentRegistry = xContentRegistry; + this.taskManager = taskManager; + this.recorder = recorder; + this.resultAction = resultAction; + this.analysisType = analysisType; + this.stateIndex = stateIndex; + this.stopConfigAction = stopConfigAction; + this.nodeStateManager = nodeStateManager; + this.requestTimeout = timeoutSetting.get(settings); + } + + /** + * Start job. + * 1. If job doesn't exist, create new job. + * 2. If job exists: a). if job enabled, return error message; b). if job disabled, enable job. + * @param config config accessor + * @param listener Listener to send responses + */ + public void startJob(Config config, TransportService transportService, ActionListener listener) { + // this start listener is created & injected throughout the job handler so that whenever the job response is received, + // there's the extra step of trying to index results and update detector state with a 60s delay. + ActionListener startListener = ActionListener.wrap(r -> { + try { + Instant executionEndTime = Instant.now(); + IntervalTimeConfiguration schedule = (IntervalTimeConfiguration) config.getInterval(); + Instant executionStartTime = executionEndTime.minus(schedule.getInterval(), schedule.getUnit()); + ResultRequest getRequest = createResultRequest( + config.getId(), + executionStartTime.toEpochMilli(), + executionEndTime.toEpochMilli() + ); + client + .execute( + resultAction, + getRequest, + ActionListener + .wrap(response -> recorder.indexResult(executionStartTime, executionEndTime, response, config), exception -> { + + recorder + .indexResultException( + executionStartTime, + executionEndTime, + Throwables.getStackTraceAsString(exception), + null, + config + ); + }) + ); + } catch (Exception ex) { + listener.onFailure(ex); + return; + } + listener.onResponse(r); + + }, listener::onFailure); + if (!indexManagement.doesJobIndexExist()) { + indexManagement.initJobIndex(ActionListener.wrap(response -> { + if (response.isAcknowledged()) { + logger.info("Created {} with mappings.", CommonName.CONFIG_INDEX); + createJob(config, transportService, startListener); + } else { + logger.warn("Created {} with mappings call not acknowledged.", CommonName.CONFIG_INDEX); + startListener + .onFailure( + new OpenSearchStatusException( + "Created " + CommonName.CONFIG_INDEX + " with mappings call not acknowledged.", + RestStatus.INTERNAL_SERVER_ERROR + ) + ); + } + }, exception -> startListener.onFailure(exception))); + } else { + createJob(config, transportService, startListener); + } + } + + private void createJob(Config config, TransportService transportService, ActionListener listener) { + try { + IntervalTimeConfiguration interval = (IntervalTimeConfiguration) config.getInterval(); + Schedule schedule = new IntervalSchedule(Instant.now(), (int) interval.getInterval(), interval.getUnit()); + Duration duration = Duration.of(interval.getInterval(), interval.getUnit()); + + Job job = new Job( + config.getId(), + schedule, + config.getWindowDelay(), + true, + Instant.now(), + null, + Instant.now(), + duration.getSeconds(), + config.getUser(), + config.getCustomResultIndex(), + analysisType + ); + + getJobForWrite(config, job, transportService, listener); + } catch (Exception e) { + String message = "Failed to parse job " + config.getId(); + logger.error(message, e); + listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); + } + } + + private void getJobForWrite(Config config, Job job, TransportService transportService, ActionListener listener) { + GetRequest getRequest = new GetRequest(CommonName.JOB_INDEX).id(config.getId()); + + client + .get( + getRequest, + ActionListener + .wrap( + response -> onGetJobForWrite(response, config, job, transportService, listener), + exception -> listener.onFailure(exception) + ) + ); + } + + private void onGetJobForWrite( + GetResponse response, + Config config, + Job job, + TransportService transportService, + ActionListener listener + ) throws IOException { + if (response.isExists()) { + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Job currentAdJob = Job.parse(parser); + if (currentAdJob.isEnabled()) { + listener + .onFailure( + new OpenSearchStatusException("Anomaly detector job is already running: " + config.getId(), RestStatus.OK) + ); + return; + } else { + Job newJob = new Job( + job.getName(), + job.getSchedule(), + job.getWindowDelay(), + job.isEnabled(), + Instant.now(), + currentAdJob.getDisabledTime(), + Instant.now(), + job.getLockDurationSeconds(), + job.getUser(), + job.getCustomResultIndex(), + job.getAnalysisType() + ); + // Get latest realtime task and check its state before index job. Will reset running realtime task + // as STOPPED first if job disabled, then start new job and create new realtime task. + startConfig( + config, + null, + job.getUser(), + transportService, + ActionListener.wrap(r -> { indexJob(newJob, null, listener); }, e -> { + // Have logged error message in ADTaskManager#startDetector + listener.onFailure(e); + }) + ); + } + } catch (IOException e) { + String message = "Failed to parse anomaly detector job " + job.getName(); + logger.error(message, e); + listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); + } + } else { + startConfig( + config, + null, + job.getUser(), + transportService, + ActionListener.wrap(r -> { indexJob(job, null, listener); }, e -> listener.onFailure(e)) + ); + } + } + + /** + * Start config. + * For historical analysis, this method will be called on coordinating node. + * For realtime task, we won't know AD job coordinating node until AD job starts. So + * this method will be called on vanilla node. + * + * Will init task index if not exist and write new AD task to index. If task index + * exists, will check if there is task running. If no running task, reset old task + * as not latest and clean old tasks which exceeds max old task doc limitation. + * Then find out node with least load and dispatch task to that node(worker node). + * + * @param config anomaly detector + * @param dateRange detection date range + * @param user user + * @param transportService transport service + * @param listener action listener + */ + public void startConfig( + Config config, + DateRange dateRange, + User user, + TransportService transportService, + ActionListener listener + ) { + try { + if (indexManagement.doesStateIndexExist()) { + // If state index exist, check if latest AD task is running + taskManager.getAndExecuteOnLatestConfigLevelTask(config, dateRange, false, user, transportService, listener); + } else { + // If state index doesn't exist, create index and execute detector. + indexManagement.initStateIndex(ActionListener.wrap(r -> { + if (r.isAcknowledged()) { + logger.info("Created {} with mappings.", stateIndex); + taskManager.updateLatestFlagOfOldTasksAndCreateNewTask(config, dateRange, false, user, TaskState.CREATED, listener); + } else { + String error = String.format(Locale.ROOT, CommonMessages.CREATE_INDEX_NOT_ACKNOWLEDGED, stateIndex); + logger.warn(error); + listener.onFailure(new OpenSearchStatusException(error, RestStatus.INTERNAL_SERVER_ERROR)); + } + }, e -> { + if (ExceptionsHelper.unwrapCause(e) instanceof ResourceAlreadyExistsException) { + taskManager.updateLatestFlagOfOldTasksAndCreateNewTask(config, dateRange, false, user, TaskState.CREATED, listener); + } else { + logger.error("Failed to init anomaly detection state index", e); + listener.onFailure(e); + } + })); + } + } catch (Exception e) { + logger.error("Failed to start detector " + config.getId(), e); + listener.onFailure(e); + } + } + + private void indexJob(Job job, ExecutorFunction function, ActionListener listener) throws IOException { + IndexRequest indexRequest = new IndexRequest(CommonName.JOB_INDEX) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .source(job.toXContent(XContentFactory.jsonBuilder(), RestHandlerUtils.XCONTENT_WITH_TYPE)) + .timeout(requestTimeout) + .id(job.getName()); + client + .index( + indexRequest, + ActionListener + .wrap( + response -> onIndexAnomalyDetectorJobResponse(response, function, listener), + exception -> listener.onFailure(exception) + ) + ); + } + + private void onIndexAnomalyDetectorJobResponse( + IndexResponse response, + ExecutorFunction function, + ActionListener listener + ) { + if (response == null || (response.getResult() != CREATED && response.getResult() != UPDATED)) { + String errorMsg = ExceptionUtil.getShardsFailure(response); + listener.onFailure(new OpenSearchStatusException(errorMsg, response.status())); + return; + } + if (function != null) { + function.execute(); + } else { + JobResponse anomalyDetectorJobResponse = new JobResponse(response.getId()); + listener.onResponse(anomalyDetectorJobResponse); + } + } + + /** + * Stop config job. + * 1.If job not exists, return error message + * 2.If job exists: a).if job state is disabled, return error message; b).if job state is enabled, disable job. + * + * @param configId config identifier + * @param listener Listener to send responses + */ + public void stopJob(String configId, TransportService transportService, ActionListener listener) { + GetRequest getRequest = new GetRequest(CommonName.JOB_INDEX).id(configId); + + client.get(getRequest, ActionListener.wrap(response -> { + if (response.isExists()) { + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Job job = Job.parse(parser); + if (!job.isEnabled()) { + taskManager.stopLatestRealtimeTask(configId, TaskState.STOPPED, null, transportService, listener); + } else { + Job newJob = new Job( + job.getName(), + job.getSchedule(), + job.getWindowDelay(), + false, // disable job + job.getEnabledTime(), + Instant.now(), + Instant.now(), + job.getLockDurationSeconds(), + job.getUser(), + job.getCustomResultIndex(), + job.getAnalysisType() + ); + indexJob( + newJob, + () -> client + .execute( + stopConfigAction, + new StopConfigRequest(configId), + stopConfigListener(configId, transportService, listener) + ), + listener + ); + } + } catch (IOException e) { + String message = "Failed to parse job " + configId; + logger.error(message, e); + listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); + } + } else { + listener.onResponse(new JobResponse(configId)); + } + }, exception -> { + if (exception instanceof IndexNotFoundException) { + listener.onResponse(new JobResponse(configId)); + } else { + listener.onFailure(exception); + } + })); + } + + private ActionListener stopConfigListener( + String configId, + TransportService transportService, + ActionListener listener + ) { + return new ActionListener() { + @Override + public void onResponse(StopConfigResponse stopDetectorResponse) { + if (stopDetectorResponse.success()) { + logger.info("model deleted successfully for config {}", configId); + // e.g., StopDetectorTransportAction will send out DeleteModelAction which will clear all realtime cache. + // Pass null transport service to method "stopLatestRealtimeTask" to not re-clear coordinating node cache. + taskManager.stopLatestRealtimeTask(configId, TaskState.STOPPED, null, null, listener); + } else { + logger.error("Failed to delete model for config {}", configId); + // If failed to clear all realtime cache, will try to re-clear coordinating node cache. + taskManager + .stopLatestRealtimeTask( + configId, + TaskState.FAILED, + new OpenSearchStatusException("Failed to delete model", RestStatus.INTERNAL_SERVER_ERROR), + transportService, + listener + ); + } + } + + @Override + public void onFailure(Exception e) { + logger.error("Failed to delete model for config " + configId, e); + // If failed to clear all realtime cache, will try to re-clear coordinating node cache. + taskManager + .stopLatestRealtimeTask( + configId, + TaskState.FAILED, + new OpenSearchStatusException("Failed to execute stop config action", RestStatus.INTERNAL_SERVER_ERROR), + transportService, + listener + ); + } + }; + } + + /** + * Start config. Will create schedule job for realtime analysis, + * and start task for historical/run once. + * + * @param configId config id + * @param dateRange historical analysis date range + * @param user user + * @param transportService transport service + * @param context thread context + * @param listener action listener + */ + public void startConfig( + String configId, + DateRange dateRange, + User user, + TransportService transportService, + ThreadContext.StoredContext context, + ActionListener listener + ) { + // upgrade index mapping + indexManagement.update(); + + nodeStateManager.getConfig(configId, analysisType, (config) -> { + if (!config.isPresent()) { + listener.onFailure(new OpenSearchStatusException(CommonMessages.FAIL_TO_FIND_CONFIG_MSG + configId, RestStatus.NOT_FOUND)); + return; + } + + // Validate if config is ready to start. Will return null if ready to start. + String errorMessage = validateConfig(config.get()); + if (errorMessage != null) { + listener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.BAD_REQUEST)); + return; + } + String resultIndex = config.get().getCustomResultIndex(); + if (resultIndex == null) { + startRealtimeOrHistoricalAnalysis(dateRange, user, transportService, listener, config); + return; + } + context.restore(); + indexManagement + .initCustomResultIndexAndExecute( + resultIndex, + () -> startRealtimeOrHistoricalAnalysis(dateRange, user, transportService, listener, config), + listener + ); + + }, listener); + } + + private String validateConfig(Config detector) { + String error = null; + if (detector.getFeatureAttributes().size() == 0) { + error = "Can't start job as no features configured"; + } else if (detector.getEnabledFeatureIds().size() == 0) { + error = "Can't start job as no enabled features configured"; + } + return error; + } + + private void startRealtimeOrHistoricalAnalysis( + DateRange dateRange, + User user, + TransportService transportService, + ActionListener listener, + Optional config + ) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + if (dateRange == null) { + // start realtime job + startJob(config.get(), transportService, listener); + } else { + // start historical analysis task + taskManager.startHistorical(config.get(), dateRange, user, transportService, listener); + } + } catch (Exception e) { + logger.error("Failed to stash context", e); + listener.onFailure(e); + } + } + + protected abstract ResultRequest createResultRequest(String configID, long start, long end); + + protected abstract List getBatchConfigTaskTypes(); + + public abstract void stopConfig( + String configId, + boolean historical, + User user, + TransportService transportService, + ActionListener listener + ); +} diff --git a/src/main/java/org/opensearch/timeseries/rest/handler/IntervalCalculation.java b/src/main/java/org/opensearch/timeseries/rest/handler/IntervalCalculation.java new file mode 100644 index 000000000..e3c8a403e --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/rest/handler/IntervalCalculation.java @@ -0,0 +1,431 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.rest.handler; + +import java.io.IOException; +import java.time.Clock; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.bucket.histogram.Histogram; +import org.opensearch.search.aggregations.bucket.histogram.LongBounds; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.util.SecurityClientUtil; + +public class IntervalCalculation { + private final Logger logger = LogManager.getLogger(IntervalCalculation.class); + + private final Config config; + private final TimeValue requestTimeout; + private final HistogramAggregationHelper histogramAggHelper; + private final Client client; + private final SecurityClientUtil clientUtil; + private final User user; + private final AnalysisType context; + private final Clock clock; + private final FullBucketRatePredicate acceptanceCriteria; + + public IntervalCalculation( + Config config, + TimeValue requestTimeout, + Client client, + SecurityClientUtil clientUtil, + User user, + AnalysisType context, + Clock clock + ) { + this.config = config; + this.requestTimeout = requestTimeout; + this.histogramAggHelper = new HistogramAggregationHelper(config, requestTimeout); + this.client = client; + this.clientUtil = clientUtil; + this.user = user; + this.context = context; + this.clock = clock; + this.acceptanceCriteria = new FullBucketRatePredicate(); + + } + + public void findInterval(long latestTime, Map topEntity, ActionListener listener) { + ActionListener> minimumIntervalListener = ActionListener.wrap(minIntervalAndValidity -> { + if (minIntervalAndValidity.getRight()) { + // the minimum interval is also the interval passing acceptance criteria and we can return immediately + listener.onResponse(minIntervalAndValidity.getLeft()); + } else if (minIntervalAndValidity.getLeft() == null) { + // the minimum interval is too large + listener.onResponse(null); + } else { + // starting exploring larger interval + getBucketAggregates(latestTime, topEntity, minIntervalAndValidity.getLeft(), listener); + } + }, listener::onFailure); + // we use 1 minute = 60000 milliseconds to find minimum interval + LongBounds longBounds = histogramAggHelper.getTimeRangeBounds(latestTime, 60000); + findMinimumInterval(topEntity, longBounds, minimumIntervalListener); + } + + private void getBucketAggregates( + long latestTime, + Map topEntity, + IntervalTimeConfiguration minimumInterval, + ActionListener listener + ) throws IOException { + + try { + int newIntervalInMinutes = increaseAndGetNewInterval(minimumInterval); + LongBounds timeStampBounds = histogramAggHelper.getTimeRangeBounds(latestTime, newIntervalInMinutes); + SearchRequest searchRequest = composeIntervalQuery(topEntity, newIntervalInMinutes, timeStampBounds); + ActionListener intervalListener = ActionListener + .wrap(interval -> listener.onResponse(interval), exception -> { + listener.onFailure(exception); + logger.error("Failed to get interval recommendation", exception); + }); + final ActionListener searchResponseListener = new IntervalRecommendationListener( + intervalListener, + searchRequest.source(), + (IntervalTimeConfiguration) config.getInterval(), + clock.millis() + TimeSeriesSettings.TOP_VALIDATE_TIMEOUT_IN_MILLIS, + latestTime, + timeStampBounds + ); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + user, + client, + context, + searchResponseListener + ); + } catch (ValidationException ex) { + listener.onFailure(ex); + } + } + + /** + * + * @param oldInterval + * @return new interval in minutes + */ + private int increaseAndGetNewInterval(IntervalTimeConfiguration oldInterval) { + return (int) Math + .ceil( + IntervalTimeConfiguration.getIntervalInMinute(oldInterval) + * TimeSeriesSettings.INTERVAL_RECOMMENDATION_INCREASING_MULTIPLIER + ); + } + + /** + * ActionListener class to handle execution of multiple bucket aggregations one after the other + * Bucket aggregation with different interval lengths are executed one by one to check if the data is dense enough + * We only need to execute the next query if the previous one led to data that is too sparse. + */ + class IntervalRecommendationListener implements ActionListener { + private final ActionListener intervalListener; + SearchSourceBuilder searchSourceBuilder; + IntervalTimeConfiguration currentIntervalToTry; + private final long expirationEpochMs; + private final long latestTime; + private LongBounds currentTimeStampBounds; + + IntervalRecommendationListener( + ActionListener intervalListener, + SearchSourceBuilder searchSourceBuilder, + IntervalTimeConfiguration currentIntervalToTry, + long expirationEpochMs, + long latestTime, + LongBounds timeStampBounds + ) { + this.intervalListener = intervalListener; + this.searchSourceBuilder = searchSourceBuilder; + this.currentIntervalToTry = currentIntervalToTry; + this.expirationEpochMs = expirationEpochMs; + this.latestTime = latestTime; + this.currentTimeStampBounds = timeStampBounds; + } + + @Override + public void onResponse(SearchResponse response) { + try { + Histogram aggregate = null; + try { + aggregate = histogramAggHelper.checkBucketResultErrors(response); + } catch (ValidationException e) { + intervalListener.onFailure(e); + } + + if (aggregate == null) { + intervalListener.onResponse(null); + return; + } + + int newIntervalMinute = increaseAndGetNewInterval(currentIntervalToTry); + double fullBucketRate = histogramAggHelper.processBucketAggregationResults(aggregate, newIntervalMinute * 60000, config); + // If rate is above success minimum then return interval suggestion. + if (fullBucketRate > TimeSeriesSettings.INTERVAL_BUCKET_MINIMUM_SUCCESS_RATE) { + intervalListener.onResponse(this.currentIntervalToTry); + } else if (expirationEpochMs < clock.millis()) { + intervalListener + .onFailure( + new ValidationException( + CommonMessages.TIMEOUT_ON_INTERVAL_REC, + ValidationIssueType.TIMEOUT, + ValidationAspect.MODEL + ) + ); + logger.info(CommonMessages.TIMEOUT_ON_INTERVAL_REC); + // keep trying higher intervals as new interval is below max, and we aren't decreasing yet + } else if (newIntervalMinute < TimeSeriesSettings.MAX_INTERVAL_REC_LENGTH_IN_MINUTES) { + searchWithDifferentInterval(newIntervalMinute); + // The below block is executed only the first time when new interval is above max and + // we aren't decreasing yet, at this point we will start decreasing for the first time + // if we are inside the below block + } else { + // newIntervalMinute >= MAX_INTERVAL_REC_LENGTH_IN_MINUTES + intervalListener.onResponse(null); + } + + } catch (Exception e) { + onFailure(e); + } + } + + private void searchWithDifferentInterval(int newIntervalMinuteValue) { + this.currentIntervalToTry = new IntervalTimeConfiguration(newIntervalMinuteValue, ChronoUnit.MINUTES); + this.currentTimeStampBounds = histogramAggHelper.getTimeRangeBounds(latestTime, newIntervalMinuteValue); + // Searching again using an updated interval + SearchSourceBuilder updatedSearchSourceBuilder = histogramAggHelper + .getSearchSourceBuilder( + searchSourceBuilder.query(), + histogramAggHelper.getBucketAggregation(newIntervalMinuteValue, currentTimeStampBounds) + ); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + new SearchRequest().indices(config.getIndices().toArray(new String[0])).source(updatedSearchSourceBuilder), + client::search, + user, + client, + context, + this + ); + } + + @Override + public void onFailure(Exception e) { + logger.error("Failed to recommend new interval", e); + intervalListener + .onFailure( + new ValidationException( + CommonMessages.MODEL_VALIDATION_FAILED_UNEXPECTEDLY, + ValidationIssueType.AGGREGATION, + ValidationAspect.MODEL + ) + ); + } + } + + /** + * This method calculates median timestamp difference as minimum interval. + * + * + * Using the median timestamp difference as a minimum sampling interval is a heuristic approach + * that can be beneficial in specific contexts, especially when dealing with irregularly spaced data. + * + * Advantages: + * 1. Robustness: The median is less sensitive to outliers compared to the mean. This makes it a + * more stable metric in the presence of irregular data points or anomalies. + * 2. Reflects Typical Intervals: The median provides a measure of the "typical" interval between + * data points, which can be useful when there are varying intervals. + * + * Disadvantages: + * 1. Not Standard in Signal Processing: Traditional signal processing often relies on fixed + * sampling rates determined by the Nyquist-Shannon sampling theorem. The median-based approach + * is more of a data-driven heuristic. + * 2. May Not Capture All Features: Depending on the nature of the data, using the median interval + * might miss some rapid events or features in the data. + * + * In summary, while not a standard practice, using the median timestamp difference as a sampling + * interval can be a practical approach in scenarios where data arrival is irregular and there's + * a need to balance between capturing data features and avoiding over-sampling. + * + * @param topEntity top entity to use + * @param timeStampBounds Used to determine start and end date range to search for data + * @param listener returns minimum interval and whether the interval passes data density test + */ + private void findMinimumInterval( + Map topEntity, + LongBounds timeStampBounds, + ActionListener> listener + ) { + try { + SearchRequest searchRequest = composeIntervalQuery(topEntity, 1, timeStampBounds); + final ActionListener searchResponseListener = ActionListener.wrap(response -> { + Histogram aggregate = null; + try { + aggregate = histogramAggHelper.checkBucketResultErrors(response); + } catch (ValidationException e) { + listener.onFailure(e); + } + + if (aggregate == null) { + // fail to find the minimum interval. Return one minute. + logger.warn("Fail to get aggregated result"); + listener.onResponse(Pair.of(new IntervalTimeConfiguration(1, ChronoUnit.MINUTES), Boolean.FALSE)); + return; + } + // In all cases, when the specified end time does not exist, the actual end time is the closest available time after the + // specified end. + // so we only have non-empty buckets + // in the original order, buckets are sorted in the ascending order of timestamps. + // Since the stream processing preserves the order of elements, we don't need to sort timestamps again. + List timestamps = aggregate + .getBuckets() + .stream() + .map(entry -> HistogramAggregationHelper.convertKeyToEpochMillis(entry.getKey())) + .collect(Collectors.toList()); + + if (timestamps.isEmpty()) { + logger.warn("empty data, return one minute by default"); + listener.onResponse(Pair.of(new IntervalTimeConfiguration(1, ChronoUnit.MINUTES), Boolean.FALSE)); + return; + } + + double medianDifference = calculateMedianDifference(timestamps); + long minimumMinutes = millisecondsToCeilMinutes(((Double) medianDifference).longValue()); + if (minimumMinutes > TimeSeriesSettings.MAX_INTERVAL_REC_LENGTH_IN_MINUTES) { + logger.warn("The minimum interval is too large: {}", minimumMinutes); + listener.onResponse(Pair.of(null, false)); + return; + } + listener + .onResponse( + Pair + .of( + new IntervalTimeConfiguration(minimumMinutes, ChronoUnit.MINUTES), + acceptanceCriteria.test(aggregate, minimumMinutes) + ) + ); + }, listener::onFailure); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + user, + client, + context, + searchResponseListener + ); + } catch (Exception e) { + listener.onFailure(e); + } + } + + private static double calculateMedianDifference(List timestamps) { + List differences = new ArrayList<>(); + + for (int i = 1; i < timestamps.size(); i++) { + differences.add(timestamps.get(i) - timestamps.get(i - 1)); + } + + Collections.sort(differences); + + int middle = differences.size() / 2; + if (differences.size() % 2 == 0) { + // If even number of differences, return the average of the two middle values + return (differences.get(middle - 1) + differences.get(middle)) / 2.0; + } else { + // If odd number of differences, return the middle value + return differences.get(middle); + } + } + + /** + * Convert a duration in milliseconds to the nearest minute value that is greater than + * or equal to the given duration. + * + * For example, a duration of 123456 milliseconds is slightly more than 2 minutes. + * So, it gets rounded up and the method returns 3. + * + * @param milliseconds The duration in milliseconds. + * @return The rounded up value in minutes. + */ + private static long millisecondsToCeilMinutes(long milliseconds) { + // Since there are 60,000 milliseconds in a minute, we divide by 60,000 to get + // the number of complete minutes. We add 59,999 before division to ensure + // that any duration that exceeds a whole minute but is less than the next + // whole minute is rounded up to the next minute. + return (milliseconds + 59999) / 60000; + } + + private SearchRequest composeIntervalQuery(Map topEntity, int intervalInMinutes, LongBounds timeStampBounds) { + AggregationBuilder aggregation = histogramAggHelper.getBucketAggregation(intervalInMinutes, timeStampBounds); + BoolQueryBuilder query = QueryBuilders.boolQuery().filter(config.getFilterQuery()); + if (config.isHighCardinality()) { + if (topEntity.isEmpty()) { + throw new ValidationException( + CommonMessages.CATEGORY_FIELD_TOO_SPARSE, + ValidationIssueType.CATEGORY, + ValidationAspect.MODEL + ); + } + for (Map.Entry entry : topEntity.entrySet()) { + query.filter(QueryBuilders.termQuery(entry.getKey(), entry.getValue())); + } + } + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .query(query) + .aggregation(aggregation) + .size(0) + .timeout(requestTimeout); + return new SearchRequest(config.getIndices().toArray(new String[0])).source(searchSourceBuilder); + } + + interface HistogramPredicate { + boolean test(Histogram histogram, long minimumMinutes); + } + + class FullBucketRatePredicate implements HistogramPredicate { + + @Override + public boolean test(Histogram histogram, long minimumMinutes) { + double fullBucketRate = histogramAggHelper.processBucketAggregationResults(histogram, minimumMinutes * 60000, config); + // If rate is above success minimum then return true. + return fullBucketRate > TimeSeriesSettings.INTERVAL_BUCKET_MINIMUM_SUCCESS_RATE; + } + + } +} diff --git a/src/main/java/org/opensearch/timeseries/rest/handler/LatestTimeRetriever.java b/src/main/java/org/opensearch/timeseries/rest/handler/LatestTimeRetriever.java new file mode 100644 index 000000000..5d0393842 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/rest/handler/LatestTimeRetriever.java @@ -0,0 +1,186 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.rest.handler; + +import java.time.Instant; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.BucketOrder; +import org.opensearch.search.aggregations.PipelineAggregatorBuilders; +import org.opensearch.search.aggregations.bucket.MultiBucketsAggregation; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation; +import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; +import org.opensearch.search.aggregations.bucket.histogram.LongBounds; +import org.opensearch.search.aggregations.bucket.terms.Terms; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.FieldSortBuilder; +import org.opensearch.search.sort.SortOrder; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.util.SecurityClientUtil; + +public class LatestTimeRetriever { + public static final Logger logger = LogManager.getLogger(LatestTimeRetriever.class); + + protected static final String AGG_NAME_TOP = "top_agg"; + + private final Config config; + // private final ActionListener> listener; + private final HistogramAggregationHelper histogramAggHelper; + private final SecurityClientUtil clientUtil; + private final Client client; + private final User user; + private final AnalysisType context; + private final SearchFeatureDao searchFeatureDao; + + public LatestTimeRetriever( + Config config, + TimeValue requestTimeout, + SecurityClientUtil clientUtil, + Client client, + User user, + AnalysisType context, + SearchFeatureDao searchFeatureDao + ) { + this.config = config; + this.histogramAggHelper = new HistogramAggregationHelper(config, requestTimeout); + this.clientUtil = clientUtil; + this.client = client; + this.user = user; + this.context = context; + this.searchFeatureDao = searchFeatureDao; + } + + /** + * Need to first check if HC analysis or not before retrieving latest date time. + * If the config is HC then we will find the top entity and treat as single stream for + * validation purposes + * @param listener to return latest time and entity attributes if the config is HC + */ + public void checkIfHC(ActionListener, Map>> listener) { + ActionListener> topEntityListener = ActionListener + .wrap( + topEntity -> searchFeatureDao + .getLatestDataTime( + config, + Optional.of(Entity.createEntityByReordering(topEntity)), + context, + ActionListener.wrap(latestTime -> listener.onResponse(Pair.of(latestTime, topEntity)), listener::onFailure) + ), + exception -> { + listener.onFailure(exception); + logger.error("Failed to get top entity for categorical field", exception); + } + ); + if (config.isHighCardinality()) { + getTopEntity(topEntityListener); + } else { + topEntityListener.onResponse(Collections.emptyMap()); + } + } + + // For single category HCs, this method uses bucket aggregation and sort to get the category field + // that have the highest document count in order to use that top entity for further validation + // For multi-category HCs we use a composite aggregation to find the top fields for the entity + // with the highest doc count. + public void getTopEntity(ActionListener> topEntityListener) { + // Look at data back to the lower bound given the max interval we recommend or one given + long maxIntervalInMinutes = Math.max(TimeSeriesSettings.MAX_INTERVAL_REC_LENGTH_IN_MINUTES, config.getIntervalInMinutes()); + LongBounds timeRangeBounds = histogramAggHelper.getTimeRangeBounds(Instant.now().toEpochMilli(), maxIntervalInMinutes * 60000); + RangeQueryBuilder rangeQuery = new RangeQueryBuilder(config.getTimeField()) + .from(timeRangeBounds.getMin()) + .to(timeRangeBounds.getMax()); + AggregationBuilder bucketAggs; + Map topKeys = new HashMap<>(); + if (config.getCategoryFields().size() == 1) { + bucketAggs = AggregationBuilders.terms(AGG_NAME_TOP).field(config.getCategoryFields().get(0)).order(BucketOrder.count(true)); + } else { + bucketAggs = AggregationBuilders + .composite( + AGG_NAME_TOP, + config.getCategoryFields().stream().map(f -> new TermsValuesSourceBuilder(f).field(f)).collect(Collectors.toList()) + ) + .size(1000) + .subAggregation( + PipelineAggregatorBuilders + .bucketSort("bucketSort", Collections.singletonList(new FieldSortBuilder("_count").order(SortOrder.DESC))) + .size(1) + ); + } + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .query(rangeQuery) + .aggregation(bucketAggs) + .trackTotalHits(false) + .size(0); + SearchRequest searchRequest = new SearchRequest().indices(config.getIndices().toArray(new String[0])).source(searchSourceBuilder); + final ActionListener searchResponseListener = ActionListener.wrap(response -> { + Aggregations aggs = response.getAggregations(); + if (aggs == null) { + topEntityListener.onResponse(Collections.emptyMap()); + return; + } + if (config.getCategoryFields().size() == 1) { + Terms entities = aggs.get(AGG_NAME_TOP); + Object key = entities + .getBuckets() + .stream() + .max(Comparator.comparingInt(entry -> (int) entry.getDocCount())) + .map(MultiBucketsAggregation.Bucket::getKeyAsString) + .orElse(null); + topKeys.put(config.getCategoryFields().get(0), key); + } else { + CompositeAggregation compositeAgg = aggs.get(AGG_NAME_TOP); + topKeys + .putAll( + compositeAgg + .getBuckets() + .stream() + .flatMap(bucket -> bucket.getKey().entrySet().stream()) // this would create a flattened stream of map entries + .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue())) + ); + } + for (Map.Entry entry : topKeys.entrySet()) { + if (entry.getValue() == null) { + topEntityListener.onResponse(Collections.emptyMap()); + return; + } + } + topEntityListener.onResponse(topKeys); + }, topEntityListener::onFailure); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + user, + client, + context, + searchResponseListener + ); + } +} diff --git a/src/main/java/org/opensearch/timeseries/rest/handler/ModelValidationActionHandler.java b/src/main/java/org/opensearch/timeseries/rest/handler/ModelValidationActionHandler.java new file mode 100644 index 000000000..7e49f3ead --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/rest/handler/ModelValidationActionHandler.java @@ -0,0 +1,482 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.rest.handler; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.CONFIG_BUCKET_MINIMUM_SUCCESS_RATE; + +import java.io.IOException; +import java.time.Clock; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.bucket.histogram.Histogram; +import org.opensearch.search.aggregations.bucket.histogram.LongBounds; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Feature; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.MergeableList; +import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.transport.ValidateConfigResponse; +import org.opensearch.timeseries.util.MultiResponsesDelegateActionListener; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.SecurityClientUtil; + +/** + *

This class executes all validation checks that are not blocking on the 'model' level. + * This mostly involves checking if the data is generally dense enough to complete model training + * which is based on if enough buckets in the last x intervals have at least 1 document present.

+ *

Initially different bucket aggregations are executed with with every configuration applied and with + * different varying intervals in order to find the best interval for the data. If no interval is found with all + * configuration applied then each configuration is tested sequentially for sparsity

+ */ +// TODO: Add more UT and IT +public class ModelValidationActionHandler { + + protected final Config config; + protected final ClusterService clusterService; + protected final Logger logger = LogManager.getLogger(ModelValidationActionHandler.class); + protected final TimeValue requestTimeout; + protected final Client client; + protected final SecurityClientUtil clientUtil; + protected final NamedXContentRegistry xContentRegistry; + protected final ActionListener listener; + protected final Clock clock; + protected final String validationType; + protected final Settings settings; + protected final User user; + protected final AnalysisType context; + private final HistogramAggregationHelper histogramAggHelper; + private final IntervalCalculation intervalCalculation; + // time range bounds to verify configured interval makes sense or not + private LongBounds timeRangeToSearchForConfiguredInterval; + private final LatestTimeRetriever latestTimeRetriever; + + /** + * Constructor function. + * + * @param clusterService ClusterService + * @param client OS node client that executes actions on the local node + * @param clientUtil client util + * @param listener OS channel used to construct bytes / builder based outputs, and send responses + * @param config config instance + * @param requestTimeout request time out configuration + * @param xContentRegistry Registry which is used for XContentParser + * @param searchFeatureDao Search feature DAO + * @param validationType Specified type for validation + * @param clock clock object to know when to timeout + * @param settings Node settings + * @param user User info + * @param context Analysis type + */ + public ModelValidationActionHandler( + ClusterService clusterService, + Client client, + SecurityClientUtil clientUtil, + ActionListener listener, + Config config, + TimeValue requestTimeout, + NamedXContentRegistry xContentRegistry, + SearchFeatureDao searchFeatureDao, + String validationType, + Clock clock, + Settings settings, + User user, + AnalysisType context + ) { + this.clusterService = clusterService; + this.client = client; + this.clientUtil = clientUtil; + this.listener = listener; + this.config = config; + this.requestTimeout = requestTimeout; + this.xContentRegistry = xContentRegistry; + this.validationType = validationType; + this.clock = clock; + this.settings = settings; + this.user = user; + this.context = context; + this.histogramAggHelper = new HistogramAggregationHelper(config, requestTimeout); + this.intervalCalculation = new IntervalCalculation(config, requestTimeout, client, clientUtil, user, context, clock); + // calculate the bounds in a lazy manner + this.timeRangeToSearchForConfiguredInterval = null; + this.latestTimeRetriever = new LatestTimeRetriever(config, requestTimeout, clientUtil, client, user, context, searchFeatureDao); + } + + public void start() { + ActionListener, Map>> latestTimeListener = ActionListener + .wrap( + latestEntityAttributes -> getSampleRangesForValidationChecks( + latestEntityAttributes.getLeft(), + config, + listener, + latestEntityAttributes.getRight() + ), + exception -> { + listener.onFailure(exception); + logger.error("Failed to create search request for last data point", exception); + } + ); + latestTimeRetriever.checkIfHC(latestTimeListener); + } + + private void getSampleRangesForValidationChecks( + Optional latestTime, + Config config, + ActionListener listener, + Map topEntity + ) { + if (!latestTime.isPresent() || latestTime.get() <= 0) { + listener + .onFailure( + new ValidationException( + CommonMessages.TIME_FIELD_NOT_ENOUGH_HISTORICAL_DATA, + ValidationIssueType.TIMEFIELD_FIELD, + ValidationAspect.MODEL + ) + ); + return; + } + long timeRangeEnd = Math.min(Instant.now().toEpochMilli(), latestTime.get()); + intervalCalculation + .findInterval( + timeRangeEnd, + topEntity, + ActionListener.wrap(interval -> processIntervalRecommendation(interval, latestTime.get()), listener::onFailure) + ); + } + + private void processIntervalRecommendation(IntervalTimeConfiguration interval, long latestTime) { + // if interval suggestion is null that means no interval could be found with all the configurations + // applied, our next step then is to check density just with the raw data and then add each configuration + // one at a time to try and find root cause of low density + if (interval == null) { + checkRawDataSparsity(latestTime); + } else { + if (((IntervalTimeConfiguration) config.getInterval()).gte(interval)) { + logger.info("Using the current interval there is enough dense data "); + // Check if there is a window delay recommendation if everything else is successful and send exception + if (Instant.now().toEpochMilli() - latestTime > histogramAggHelper.timeConfigToMilliSec(config.getWindowDelay())) { + sendWindowDelayRec(latestTime); + return; + } + // The rate of buckets with at least 1 doc with given interval is above the success rate + listener.onResponse(null); + return; + } + // return response with interval recommendation + listener + .onFailure( + new ValidationException( + CommonMessages.INTERVAL_REC + interval.getInterval(), + ValidationIssueType.DETECTION_INTERVAL, + ValidationAspect.MODEL, + interval + ) + ); + } + } + + public AggregationBuilder getBucketAggregation(long latestTime) { + IntervalTimeConfiguration interval = (IntervalTimeConfiguration) config.getInterval(); + long intervalInMinutes = IntervalTimeConfiguration.getIntervalInMinute(interval); + if (timeRangeToSearchForConfiguredInterval == null) { + timeRangeToSearchForConfiguredInterval = histogramAggHelper.getTimeRangeBounds(latestTime, intervalInMinutes * 60000); + } + + return histogramAggHelper.getBucketAggregation((int) intervalInMinutes, timeRangeToSearchForConfiguredInterval); + } + + private void checkRawDataSparsity(long latestTime) { + AggregationBuilder aggregation = getBucketAggregation(latestTime); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().aggregation(aggregation).size(0).timeout(requestTimeout); + SearchRequest searchRequest = new SearchRequest(config.getIndices().toArray(new String[0])).source(searchSourceBuilder); + final ActionListener searchResponseListener = ActionListener + .wrap(response -> processRawDataResults(response, latestTime), listener::onFailure); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + user, + client, + context, + searchResponseListener + ); + } + + public double processBucketAggregationResults(Histogram buckets, long latestTime) { + long intervalInMillis = config.getIntervalInMilliseconds(); + return histogramAggHelper.processBucketAggregationResults(buckets, intervalInMillis, config); + } + + private void processRawDataResults(SearchResponse response, long latestTime) { + Histogram aggregate = null; + try { + aggregate = histogramAggHelper.checkBucketResultErrors(response); + } catch (ValidationException e) { + listener.onFailure(e); + } + + if (aggregate == null) { + return; + } + double fullBucketRate = processBucketAggregationResults(aggregate, latestTime); + if (fullBucketRate < TimeSeriesSettings.INTERVAL_BUCKET_MINIMUM_SUCCESS_RATE) { + listener + .onFailure( + new ValidationException(CommonMessages.RAW_DATA_TOO_SPARSE, ValidationIssueType.INDICES, ValidationAspect.MODEL) + ); + } else { + checkDataFilterSparsity(latestTime); + } + } + + private void checkDataFilterSparsity(long latestTime) { + AggregationBuilder aggregation = getBucketAggregation(latestTime); + BoolQueryBuilder query = QueryBuilders.boolQuery().filter(config.getFilterQuery()); + SearchSourceBuilder searchSourceBuilder = histogramAggHelper.getSearchSourceBuilder(query, aggregation); + SearchRequest searchRequest = new SearchRequest(config.getIndices().toArray(new String[0])).source(searchSourceBuilder); + final ActionListener searchResponseListener = ActionListener + .wrap(response -> processDataFilterResults(response, latestTime), listener::onFailure); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + user, + client, + context, + searchResponseListener + ); + } + + private void processDataFilterResults(SearchResponse response, long latestTime) { + Histogram aggregate = null; + try { + aggregate = histogramAggHelper.checkBucketResultErrors(response); + } catch (ValidationException e) { + listener.onFailure(e); + } + + if (aggregate == null) { + return; + } + double fullBucketRate = processBucketAggregationResults(aggregate, latestTime); + if (fullBucketRate < CONFIG_BUCKET_MINIMUM_SUCCESS_RATE) { + listener + .onFailure( + new ValidationException( + CommonMessages.FILTER_QUERY_TOO_SPARSE, + ValidationIssueType.FILTER_QUERY, + ValidationAspect.MODEL + ) + ); + // blocks below are executed if data is dense enough with filter query applied. + // If HCAD then category fields will be added to bucket aggregation to see if they + // are the root cause of the issues and if not the feature queries will be checked for sparsity + } else if (config.isHighCardinality()) { + getTopEntityForCategoryField(latestTime); + } else { + try { + checkFeatureQueryDelegate(latestTime); + } catch (Exception ex) { + logger.error(ex); + listener.onFailure(ex); + } + } + } + + private void getTopEntityForCategoryField(long latestTime) { + ActionListener> getTopEntityListener = ActionListener + .wrap(topEntity -> checkCategoryFieldSparsity(topEntity, latestTime), exception -> { + listener.onFailure(exception); + logger.error("Failed to get top entity for categorical field", exception); + return; + }); + latestTimeRetriever.getTopEntity(getTopEntityListener); + } + + private void checkCategoryFieldSparsity(Map topEntity, long latestTime) { + BoolQueryBuilder query = QueryBuilders.boolQuery().filter(config.getFilterQuery()); + for (Map.Entry entry : topEntity.entrySet()) { + query.filter(QueryBuilders.termQuery(entry.getKey(), entry.getValue())); + } + AggregationBuilder aggregation = getBucketAggregation(latestTime); + SearchSourceBuilder searchSourceBuilder = histogramAggHelper.getSearchSourceBuilder(query, aggregation); + SearchRequest searchRequest = new SearchRequest(config.getIndices().toArray(new String[0])).source(searchSourceBuilder); + final ActionListener searchResponseListener = ActionListener + .wrap(response -> processTopEntityResults(response, latestTime), listener::onFailure); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + user, + client, + context, + searchResponseListener + ); + } + + private void processTopEntityResults(SearchResponse response, long latestTime) { + Histogram aggregate = null; + try { + aggregate = histogramAggHelper.checkBucketResultErrors(response); + } catch (ValidationException e) { + listener.onFailure(e); + } + + if (aggregate == null) { + return; + } + double fullBucketRate = processBucketAggregationResults(aggregate, latestTime); + if (fullBucketRate < CONFIG_BUCKET_MINIMUM_SUCCESS_RATE) { + listener + .onFailure( + new ValidationException(CommonMessages.CATEGORY_FIELD_TOO_SPARSE, ValidationIssueType.CATEGORY, ValidationAspect.MODEL) + ); + } else { + try { + checkFeatureQueryDelegate(latestTime); + } catch (Exception ex) { + logger.error(ex); + listener.onFailure(ex); + } + } + } + + private void checkFeatureQueryDelegate(long latestTime) throws IOException { + ActionListener> validateFeatureQueriesListener = ActionListener.wrap(response -> { + windowDelayRecommendation(latestTime); + }, exception -> { + listener + .onFailure(new ValidationException(exception.getMessage(), ValidationIssueType.FEATURE_ATTRIBUTES, ValidationAspect.MODEL)); + }); + MultiResponsesDelegateActionListener> multiFeatureQueriesResponseListener = + new MultiResponsesDelegateActionListener<>( + validateFeatureQueriesListener, + config.getFeatureAttributes().size(), + CommonMessages.FEATURE_QUERY_TOO_SPARSE, + false + ); + + for (Feature feature : config.getFeatureAttributes()) { + AggregationBuilder aggregation = getBucketAggregation(latestTime); + BoolQueryBuilder query = QueryBuilders.boolQuery().filter(config.getFilterQuery()); + List featureFields = ParseUtils.getFieldNamesForFeature(feature, xContentRegistry); + for (String featureField : featureFields) { + query.filter(QueryBuilders.existsQuery(featureField)); + } + SearchSourceBuilder searchSourceBuilder = histogramAggHelper.getSearchSourceBuilder(query, aggregation); + SearchRequest searchRequest = new SearchRequest(config.getIndices().toArray(new String[0])).source(searchSourceBuilder); + final ActionListener searchResponseListener = ActionListener.wrap(response -> { + try { + Histogram aggregate = histogramAggHelper.checkBucketResultErrors(response); + if (aggregate == null) { + return; + } + double fullBucketRate = processBucketAggregationResults(aggregate, latestTime); + if (fullBucketRate < CONFIG_BUCKET_MINIMUM_SUCCESS_RATE) { + multiFeatureQueriesResponseListener + .onFailure( + new ValidationException( + CommonMessages.FEATURE_QUERY_TOO_SPARSE, + ValidationIssueType.FEATURE_ATTRIBUTES, + ValidationAspect.MODEL + ) + ); + } else { + multiFeatureQueriesResponseListener + .onResponse(new MergeableList<>(new ArrayList<>(Collections.singletonList(new double[] { fullBucketRate })))); + } + } catch (ValidationException e) { + listener.onFailure(e); + } + + }, e -> { + logger.error(e); + multiFeatureQueriesResponseListener + .onFailure(new OpenSearchStatusException(CommonMessages.FEATURE_QUERY_TOO_SPARSE, RestStatus.BAD_REQUEST, e)); + }); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + user, + client, + context, + searchResponseListener + ); + } + } + + private void sendWindowDelayRec(long latestTimeInMillis) { + long minutesSinceLastStamp = (long) Math.ceil((Instant.now().toEpochMilli() - latestTimeInMillis) / 60000.0); + listener + .onFailure( + new ValidationException( + String.format(Locale.ROOT, CommonMessages.WINDOW_DELAY_REC, minutesSinceLastStamp, minutesSinceLastStamp), + ValidationIssueType.WINDOW_DELAY, + ValidationAspect.MODEL, + new IntervalTimeConfiguration(minutesSinceLastStamp, ChronoUnit.MINUTES) + ) + ); + } + + private void windowDelayRecommendation(long latestTime) { + // Check if there is a better window-delay to recommend and if one was recommended + // then send exception and return, otherwise continue to let user know data is too sparse as explained below + if (Instant.now().toEpochMilli() - latestTime > histogramAggHelper.timeConfigToMilliSec(config.getWindowDelay())) { + sendWindowDelayRec(latestTime); + return; + } + // This case has been reached if following conditions are met: + // 1. no interval recommendation was found that leads to a bucket success rate of >= 0.75 + // 2. bucket success rate with the given interval and just raw data is also below 0.75. + // 3. no single configuration during the following checks reduced the bucket success rate below 0.25 + // This means the rate with all configs applied or just raw data was below 0.75 but the rate when checking each configuration at + // a time was always above 0.25 meaning the best suggestion is to simply ingest more data or change interval since + // we have no more insight regarding the root cause of the lower density. + listener + .onFailure(new ValidationException(CommonMessages.RAW_DATA_TOO_SPARSE, ValidationIssueType.INDICES, ValidationAspect.MODEL)); + } + +} diff --git a/src/main/java/org/opensearch/timeseries/rest/handler/Processor.java b/src/main/java/org/opensearch/timeseries/rest/handler/Processor.java new file mode 100644 index 000000000..548f31bba --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/rest/handler/Processor.java @@ -0,0 +1,26 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.rest.handler; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.action.ActionResponse; + +/** + * Represents a processor capable of initiating a certain process + * and then notifying a listener upon completion. + * + * @param the type of response expected after processing, which must be a subtype of ActionResponse. + */ +public interface Processor { + + /** + * Starts the processing action. Once the processing is completed, + * the provided listener is notified with the outcome. + * + * @param listener the listener to be notified upon the completion of the processing action. + */ + public void start(ActionListener listener); +} diff --git a/src/main/java/org/opensearch/timeseries/settings/TimeSeriesEnabledSetting.java b/src/main/java/org/opensearch/timeseries/settings/TimeSeriesEnabledSetting.java new file mode 100644 index 000000000..3e7499175 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/settings/TimeSeriesEnabledSetting.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.settings; + +import static java.util.Collections.unmodifiableMap; +import static org.opensearch.common.settings.Setting.Property.Dynamic; +import static org.opensearch.common.settings.Setting.Property.NodeScope; + +import java.util.HashMap; +import java.util.Map; + +import org.opensearch.common.settings.Setting; + +public class TimeSeriesEnabledSetting extends DynamicNumericSetting { + + /** + * Singleton instance + */ + private static TimeSeriesEnabledSetting INSTANCE; + + /** + * Settings name + */ + public static final String BREAKER_ENABLED = "plugins.timeseries.breaker.enabled"; + + public static final Map> settings = unmodifiableMap(new HashMap>() { + { + /** + * forecast breaker enable/disable setting + */ + put(BREAKER_ENABLED, Setting.boolSetting(BREAKER_ENABLED, true, NodeScope, Dynamic)); + } + }); + + private TimeSeriesEnabledSetting(Map> settings) { + super(settings); + } + + public static synchronized TimeSeriesEnabledSetting getInstance() { + if (INSTANCE == null) { + INSTANCE = new TimeSeriesEnabledSetting(settings); + } + return INSTANCE; + } + + /** + * Whether circuit breaker is enabled or not. If disabled, an open circuit breaker wouldn't cause a real-time job to be stopped. + * @return whether circuit breaker is enabled or not. + */ + public static boolean isBreakerEnabled() { + return TimeSeriesEnabledSetting.getInstance().getSettingValue(TimeSeriesEnabledSetting.BREAKER_ENABLED); + } +} diff --git a/src/main/java/org/opensearch/timeseries/settings/TimeSeriesSettings.java b/src/main/java/org/opensearch/timeseries/settings/TimeSeriesSettings.java index 56bbe187a..408d9fa1b 100644 --- a/src/main/java/org/opensearch/timeseries/settings/TimeSeriesSettings.java +++ b/src/main/java/org/opensearch/timeseries/settings/TimeSeriesSettings.java @@ -19,17 +19,43 @@ public class TimeSeriesSettings { // max shingle size we have seen from external users // the larger shingle size, the harder to fill in a complete shingle - public static final int MAX_SHINGLE_SIZE = 60; + public static final int MAX_SHINGLE_SIZE = 64; + + // shingle size = seasonality / 2 + public static final int SEASONALITY_TO_SHINGLE_RATIO = 2; public static final String CONFIG_INDEX_MAPPING_FILE = "mappings/config.json"; public static final String JOBS_INDEX_MAPPING_FILE = "mappings/job.json"; - // 100,000 insertions costs roughly 1KB. + /** + * Memory Usage Estimation for a Map<String, Integer> with 100,000 entries: + * + * 1. HashMap Object Overhead: This can vary, but let's assume it's about 36 bytes. + * 2. Array Overhead: + * - The array size will be the nearest power of 2 greater than or equal to 100,000 / load factor. + * - Assuming a load factor of 0.75, the array size will be 2^17 = 131,072. + * - The memory usage will be 131,072 * 4 bytes = 524,288 bytes. + * 3. Entry Overhead: Each entry has an overhead of about 32 bytes (object header, hash code, and three references). + * 4. Key Overhead: + * - Each key has an overhead of about 36 bytes (object header, length, hash cache) plus the character data. + * - Assuming the character data is 64 bytes, the total key overhead per entry is 100 bytes. + * 5. Value Overhead: Each Integer object has an overhead of about 16 bytes (object header plus int value). + * + * Total Memory Usage Formula: + * Total Memory Usage = HashMap Object Overhead + Array Overhead + + * (Entry Overhead + Key Overhead + Value Overhead) * Number of Entries + * + * Plugging in the numbers: + * Total Memory Usage = 36 + 524,288 + (32 + 100 + 16) * 100,000 + * ≈ 14,965 kilobytes (≈ 15 MB) + * + * Note: + * This estimation is quite simplistic and the actual memory usage may be different based on the JVM implementation, + * the actual Map implementation being used, and other factors. + */ public static final int DOOR_KEEPER_FOR_COLD_STARTER_MAX_INSERTION = 100_000; - public static final double DOOR_KEEPER_FALSE_POSITIVE_RATE = 0.01; - // clean up door keeper every 60 intervals public static final int DOOR_KEEPER_MAINTENANCE_FREQ = 60; @@ -40,6 +66,10 @@ public class TimeSeriesSettings { // only has to do one update/scoring per interval public static final double REAL_TIME_BOUNDING_BOX_CACHE_RATIO = 0; + // max number of historical buckets for cold start. Corresponds to max buckets in OpenSearch. + // We send one query including one bucket per interval. So we don't want to surpass OS limit. + public static final int MAX_HISTORY_INTERVALS = 10000; + // ====================================== // Historical analysis // ====================================== @@ -92,7 +122,7 @@ public class TimeSeriesSettings { public static int CHECKPOINT_WRITE_QUEUE_SIZE_IN_BYTES = 200_000; /** - * ResultWriteRequest consists of index request (roughly 1KB), and QueuedRequest + * ADResultWriteRequest consists of index request (roughly 1KB), and QueuedRequest * fields (148 bytes, read comments of ENTITY_REQUEST_SIZE_CONSTANT). * Plus Java object size (12 bytes), we have roughly 1160 bytes per request * @@ -104,18 +134,18 @@ public class TimeSeriesSettings { public static int RESULT_WRITE_QUEUE_SIZE_IN_BYTES = 1160; /** - * FeatureRequest has entityName (# category fields * 256, the recommended limit - * of a keyword field length), model Id (roughly 256 bytes), and QueuedRequest - * fields including config Id(roughly 128 bytes), dataStartTimeMillis (long, + * FeatureRequest has entity (max 2 category fields * 256, the recommended limit + * of a keyword field length, 512 bytes), model Id (roughly 256 bytes), runOnce + * boolean (roughly 8 bytes), dataStartTimeMillis (long, * 8 bytes), and currentFeature (16 bytes, assume two features on average). - * Plus Java object size (12 bytes), we have roughly 932 bytes per request + * Plus Java object size (12 bytes), we have roughly 812 bytes per request * assuming we have 2 categorical fields (plan to support 2 categorical fields now). * We don't want the total size exceeds 0.1% of the heap. - * We can have at most 0.1% heap / 932 = heap / 932,000. + * We can have at most 0.1% heap / 812 = heap / 812,000. * For t3.small, 0.1% heap is of 1MB. The queue's size is up to - * 10^ 6 / 932 = 1072 + * 10^ 6 / 812 = 1231 */ - public static int FEATURE_REQUEST_SIZE_IN_BYTES = 932; + public static int FEATURE_REQUEST_SIZE_IN_BYTES = 812; /** * CheckpointMaintainRequest has model Id (roughly 256 bytes), and QueuedRequest @@ -146,9 +176,9 @@ public class TimeSeriesSettings { // RCF public static final int NUM_SAMPLES_PER_TREE = 256; - public static final int NUM_TREES = 30; + public static final int NUM_TREES = 50; - public static final double TIME_DECAY = 0.0001; + public static final int DEFAULT_RECENCY_EMPHASIS = 10 * NUM_SAMPLES_PER_TREE; // If we have 32 + shingleSize (hopefully recent) values, RCF can get up and running. It will be noisy — // there is a reason that default size is 256 (+ shingle size), but it may be more useful for people to @@ -158,6 +188,11 @@ public class TimeSeriesSettings { // for a batch operation, we want all of the bounding box in-place for speed public static final double BATCH_BOUNDING_BOX_CACHE_RATIO = 1; + // feature processing + public static final int TRAIN_SAMPLE_TIME_RANGE_IN_HOURS = 24; + + public static final int MIN_TRAIN_SAMPLES = 512; + // ====================================== // Cold start setting // ====================================== @@ -209,4 +244,28 @@ public class TimeSeriesSettings { // such as "there are at least 10000 entities", the default is set to 10,000. That is, requests will count the // total entities up to 10,000. public static final int MAX_TOTAL_ENTITIES_TO_TRACK = 10_000; + + // ====================================== + // Validate Detector API setting + // ====================================== + public static final long TOP_VALIDATE_TIMEOUT_IN_MILLIS = 10_000; + + public static final double INTERVAL_BUCKET_MINIMUM_SUCCESS_RATE = 0.75; + + public static final double INTERVAL_RECOMMENDATION_INCREASING_MULTIPLIER = 1.2; + + public static final long MAX_INTERVAL_REC_LENGTH_IN_MINUTES = 60L; + + public static final int MAX_DESCRIPTION_LENGTH = 1000; + + // ====================================== + // Cache setting + // ====================================== + // We don't want to retry cold start once it exceeds the threshold. + // It is larger than 1 since cx may have ingested new data or the + // system is unstable + public static final int COLD_START_DOOR_KEEPER_COUNT_THRESHOLD = 3; + + // we don't admit model to cache before it exceeds the threshold + public static final int CACHE_DOOR_KEEPER_COUNT_THRESHOLD = 1; } diff --git a/src/main/java/org/opensearch/ad/stats/InternalStatNames.java b/src/main/java/org/opensearch/timeseries/stats/InternalStatNames.java similarity index 95% rename from src/main/java/org/opensearch/ad/stats/InternalStatNames.java rename to src/main/java/org/opensearch/timeseries/stats/InternalStatNames.java index 56ff012a5..356a7828d 100644 --- a/src/main/java/org/opensearch/ad/stats/InternalStatNames.java +++ b/src/main/java/org/opensearch/timeseries/stats/InternalStatNames.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.stats; +package org.opensearch.timeseries.stats; /** * Enum containing names of all internal stats which will not be returned diff --git a/src/main/java/org/opensearch/timeseries/stats/StatNames.java b/src/main/java/org/opensearch/timeseries/stats/StatNames.java index a72e3f1b0..8ea32dffe 100644 --- a/src/main/java/org/opensearch/timeseries/stats/StatNames.java +++ b/src/main/java/org/opensearch/timeseries/stats/StatNames.java @@ -19,30 +19,46 @@ * AD stats REST API. */ public enum StatNames { - AD_EXECUTE_REQUEST_COUNT("ad_execute_request_count"), - AD_EXECUTE_FAIL_COUNT("ad_execute_failure_count"), - AD_HC_EXECUTE_REQUEST_COUNT("ad_hc_execute_request_count"), - AD_HC_EXECUTE_FAIL_COUNT("ad_hc_execute_failure_count"), - DETECTOR_COUNT("detector_count"), - SINGLE_ENTITY_DETECTOR_COUNT("single_entity_detector_count"), - MULTI_ENTITY_DETECTOR_COUNT("multi_entity_detector_count"), - ANOMALY_DETECTORS_INDEX_STATUS("anomaly_detectors_index_status"), - ANOMALY_RESULTS_INDEX_STATUS("anomaly_results_index_status"), - MODELS_CHECKPOINT_INDEX_STATUS("models_checkpoint_index_status"), - ANOMALY_DETECTION_JOB_INDEX_STATUS("anomaly_detection_job_index_status"), - ANOMALY_DETECTION_STATE_STATUS("anomaly_detection_state_status"), - MODEL_INFORMATION("models"), - AD_EXECUTING_BATCH_TASK_COUNT("ad_executing_batch_task_count"), - AD_CANCELED_BATCH_TASK_COUNT("ad_canceled_batch_task_count"), - AD_TOTAL_BATCH_TASK_EXECUTION_COUNT("ad_total_batch_task_execution_count"), - AD_BATCH_TASK_FAILURE_COUNT("ad_batch_task_failure_count"), - MODEL_COUNT("model_count"), - MODEL_CORRUTPION_COUNT("model_corruption_count"); + // common stats + CONFIG_INDEX_STATUS("config_index_status", StatType.TIMESERIES), + JOB_INDEX_STATUS("job_index_status", StatType.TIMESERIES), + // AD stats + AD_EXECUTE_REQUEST_COUNT("ad_execute_request_count", StatType.AD), + AD_EXECUTE_FAIL_COUNT("ad_execute_failure_count", StatType.AD), + AD_HC_EXECUTE_REQUEST_COUNT("ad_hc_execute_request_count", StatType.AD), + AD_HC_EXECUTE_FAIL_COUNT("ad_hc_execute_failure_count", StatType.AD), + DETECTOR_COUNT("detector_count", StatType.AD), + SINGLE_STREAM_DETECTOR_COUNT("single_stream_detector_count", StatType.AD), + HC_DETECTOR_COUNT("hc_detector_count", StatType.AD), + ANOMALY_RESULTS_INDEX_STATUS("anomaly_results_index_status", StatType.AD), + AD_MODELS_CHECKPOINT_INDEX_STATUS("anomaly_models_checkpoint_index_status", StatType.AD), + ANOMALY_DETECTION_STATE_STATUS("anomaly_detection_state_status", StatType.AD), + MODEL_INFORMATION("models", StatType.AD), + AD_EXECUTING_BATCH_TASK_COUNT("ad_executing_batch_task_count", StatType.AD), + AD_CANCELED_BATCH_TASK_COUNT("ad_canceled_batch_task_count", StatType.AD), + AD_TOTAL_BATCH_TASK_EXECUTION_COUNT("ad_total_batch_task_execution_count", StatType.AD), + AD_BATCH_TASK_FAILURE_COUNT("ad_batch_task_failure_count", StatType.AD), + MODEL_COUNT("model_count", StatType.AD), + AD_MODEL_CORRUTPION_COUNT("ad_model_corruption_count", StatType.AD), + // forecast stats + FORECAST_EXECUTE_REQUEST_COUNT("forecast_execute_request_count", StatType.FORECAST), + FORECAST_EXECUTE_FAIL_COUNT("forecast_execute_failure_count", StatType.FORECAST), + FORECAST_HC_EXECUTE_REQUEST_COUNT("forecast_hc_execute_request_count", StatType.FORECAST), + FORECAST_HC_EXECUTE_FAIL_COUNT("forecast_hc_execute_failure_count", StatType.FORECAST), + FORECAST_RESULTS_INDEX_STATUS("forecast_results_index_status", StatType.FORECAST), + FORECAST_MODELS_CHECKPOINT_INDEX_STATUS("forecast_models_checkpoint_index_status", StatType.FORECAST), + FORECAST_STATE_STATUS("forecastn_state_status", StatType.FORECAST), + FORECASTER_COUNT("forecaster_count", StatType.FORECAST), + SINGLE_STREAM_FORECASTER_COUNT("single_stream_forecaster_count", StatType.FORECAST), + HC_FORECASTER_COUNT("hc_forecaster_count", StatType.FORECAST), + FORECAST_MODEL_CORRUTPION_COUNT("forecast_model_corruption_count", StatType.FORECAST); - private String name; + private final String name; + private final StatType type; - StatNames(String name) { + StatNames(String name, StatType type) { this.name = name; + this.type = type; } /** @@ -54,6 +70,10 @@ public String getName() { return name; } + public StatType getType() { + return type; + } + /** * Get set of stat names * diff --git a/src/main/java/org/opensearch/timeseries/stats/StatType.java b/src/main/java/org/opensearch/timeseries/stats/StatType.java new file mode 100644 index 000000000..cca482bc7 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/stats/StatType.java @@ -0,0 +1,18 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.stats; + +public enum StatType { + AD, + FORECAST, + TIMESERIES +} diff --git a/src/main/java/org/opensearch/timeseries/stats/Stats.java b/src/main/java/org/opensearch/timeseries/stats/Stats.java new file mode 100644 index 000000000..f9b168392 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/stats/Stats.java @@ -0,0 +1,80 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.stats; + +import java.util.HashMap; +import java.util.Map; + +public class Stats { + private Map> stats; + + /** + * Constructor + * + * @param stats Map of the stats that are to be kept + */ + public Stats(Map> stats) { + this.stats = stats; + } + + /** + * Get the stats + * + * @return all of the stats + */ + public Map> getStats() { + return stats; + } + + /** + * Get individual stat by stat name + * + * @param key Name of stat + * @return TimeSeriesStat + * @throws IllegalArgumentException thrown on illegal statName + */ + public TimeSeriesStat getStat(String key) throws IllegalArgumentException { + if (!stats.keySet().contains(key)) { + throw new IllegalArgumentException("Stat=\"" + key + "\" does not exist"); + } + return stats.get(key); + } + + /** + * Get a map of the stats that are kept at the node level + * + * @return Map of stats kept at the node level + */ + public Map> getNodeStats() { + return getClusterOrNodeStats(false); + } + + /** + * Get a map of the stats that are kept at the cluster level + * + * @return Map of stats kept at the cluster level + */ + public Map> getClusterStats() { + return getClusterOrNodeStats(true); + } + + private Map> getClusterOrNodeStats(Boolean getClusterStats) { + Map> statsMap = new HashMap<>(); + + for (Map.Entry> entry : stats.entrySet()) { + if (entry.getValue().isClusterLevel() == getClusterStats) { + statsMap.put(entry.getKey(), entry.getValue()); + } + } + return statsMap; + } +} diff --git a/src/main/java/org/opensearch/ad/stats/ADStat.java b/src/main/java/org/opensearch/timeseries/stats/TimeSeriesStat.java similarity index 86% rename from src/main/java/org/opensearch/ad/stats/ADStat.java rename to src/main/java/org/opensearch/timeseries/stats/TimeSeriesStat.java index 531205907..e10ab9127 100644 --- a/src/main/java/org/opensearch/ad/stats/ADStat.java +++ b/src/main/java/org/opensearch/timeseries/stats/TimeSeriesStat.java @@ -9,17 +9,17 @@ * GitHub history for details. */ -package org.opensearch.ad.stats; +package org.opensearch.timeseries.stats; import java.util.function.Supplier; -import org.opensearch.ad.stats.suppliers.CounterSupplier; -import org.opensearch.ad.stats.suppliers.SettableSupplier; +import org.opensearch.timeseries.stats.suppliers.CounterSupplier; +import org.opensearch.timeseries.stats.suppliers.SettableSupplier; /** * Class represents a stat the plugin keeps track of */ -public class ADStat { +public class TimeSeriesStat { private Boolean clusterLevel; private Supplier supplier; @@ -29,7 +29,7 @@ public class ADStat { * @param clusterLevel whether the stat has clusterLevel scope or nodeLevel scope * @param supplier supplier that returns the stat's value */ - public ADStat(Boolean clusterLevel, Supplier supplier) { + public TimeSeriesStat(Boolean clusterLevel, Supplier supplier) { this.clusterLevel = clusterLevel; this.supplier = supplier; } diff --git a/src/main/java/org/opensearch/ad/stats/suppliers/CounterSupplier.java b/src/main/java/org/opensearch/timeseries/stats/suppliers/CounterSupplier.java similarity index 95% rename from src/main/java/org/opensearch/ad/stats/suppliers/CounterSupplier.java rename to src/main/java/org/opensearch/timeseries/stats/suppliers/CounterSupplier.java index 39acd94ff..0953e9450 100644 --- a/src/main/java/org/opensearch/ad/stats/suppliers/CounterSupplier.java +++ b/src/main/java/org/opensearch/timeseries/stats/suppliers/CounterSupplier.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.stats.suppliers; +package org.opensearch.timeseries.stats.suppliers; import java.util.concurrent.atomic.LongAdder; import java.util.function.Supplier; diff --git a/src/main/java/org/opensearch/ad/stats/suppliers/IndexStatusSupplier.java b/src/main/java/org/opensearch/timeseries/stats/suppliers/IndexStatusSupplier.java similarity index 92% rename from src/main/java/org/opensearch/ad/stats/suppliers/IndexStatusSupplier.java rename to src/main/java/org/opensearch/timeseries/stats/suppliers/IndexStatusSupplier.java index ab9177cb5..1da433108 100644 --- a/src/main/java/org/opensearch/ad/stats/suppliers/IndexStatusSupplier.java +++ b/src/main/java/org/opensearch/timeseries/stats/suppliers/IndexStatusSupplier.java @@ -9,11 +9,11 @@ * GitHub history for details. */ -package org.opensearch.ad.stats.suppliers; +package org.opensearch.timeseries.stats.suppliers; import java.util.function.Supplier; -import org.opensearch.ad.util.IndexUtils; +import org.opensearch.timeseries.util.IndexUtils; /** * IndexStatusSupplier provides the status of an index as the value diff --git a/src/main/java/org/opensearch/ad/stats/suppliers/SettableSupplier.java b/src/main/java/org/opensearch/timeseries/stats/suppliers/SettableSupplier.java similarity index 94% rename from src/main/java/org/opensearch/ad/stats/suppliers/SettableSupplier.java rename to src/main/java/org/opensearch/timeseries/stats/suppliers/SettableSupplier.java index b39ecdde5..e5e60c6ba 100644 --- a/src/main/java/org/opensearch/ad/stats/suppliers/SettableSupplier.java +++ b/src/main/java/org/opensearch/timeseries/stats/suppliers/SettableSupplier.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.stats.suppliers; +package org.opensearch.timeseries.stats.suppliers; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Supplier; diff --git a/src/main/java/org/opensearch/timeseries/task/RealtimeTaskCache.java b/src/main/java/org/opensearch/timeseries/task/RealtimeTaskCache.java index 5fe0c3850..8765e19c9 100644 --- a/src/main/java/org/opensearch/timeseries/task/RealtimeTaskCache.java +++ b/src/main/java/org/opensearch/timeseries/task/RealtimeTaskCache.java @@ -35,8 +35,8 @@ public class RealtimeTaskCache { // track last job run time, will clean up cache if no access after 2 intervals private long lastJobRunTime; - // detector interval in milliseconds. - private long detectorIntervalInMillis; + // interval in milliseconds. + private long intervalInMillis; // we query result index to check if there are any result generated for detector to tell whether it passed initialization of not. // To avoid repeated query when there is no data, record whether we have done that or not. @@ -47,7 +47,7 @@ public RealtimeTaskCache(String state, Float initProgress, String error, long de this.initProgress = initProgress; this.error = error; this.lastJobRunTime = Instant.now().toEpochMilli(); - this.detectorIntervalInMillis = detectorIntervalInMillis; + this.intervalInMillis = detectorIntervalInMillis; this.queriedResultIndex = false; } @@ -88,6 +88,6 @@ public void setQueriedResultIndex(boolean queriedResultIndex) { } public boolean expired() { - return lastJobRunTime + 2 * detectorIntervalInMillis < Instant.now().toEpochMilli(); + return lastJobRunTime + 2 * intervalInMillis < Instant.now().toEpochMilli(); } } diff --git a/src/main/java/org/opensearch/timeseries/task/TaskCacheManager.java b/src/main/java/org/opensearch/timeseries/task/TaskCacheManager.java index fe08f94c8..d0a87d9a2 100644 --- a/src/main/java/org/opensearch/timeseries/task/TaskCacheManager.java +++ b/src/main/java/org/opensearch/timeseries/task/TaskCacheManager.java @@ -15,9 +15,13 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.ad.task.ADTaskManager; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.forecast.task.ForecastTaskManager; import org.opensearch.timeseries.model.TaskState; +import org.opensearch.transport.TransportService; public class TaskCacheManager { private final Logger logger = LogManager.getLogger(TaskCacheManager.class); @@ -39,7 +43,7 @@ public class TaskCacheManager { protected volatile Integer maxCachedDeletedTask; /** * This field is to cache deleted detector IDs. Hourly cron will poll this queue - * and clean AD results. Check ADTaskManager#cleanResultOfDeletedConfig() + * and clean AD results. Check {@link ADTaskManager#cleanResultOfDeletedConfig} *

Node: any data node servers delete detector request

*/ protected Queue deletedConfigs; @@ -146,16 +150,16 @@ public boolean isRealtimeTaskChangeNeeded(String detectorId, String newState, Fl * * If realtime task cache doesn't exist, will do nothing. Next realtime job run will re-init * realtime task cache when it finds task cache not inited yet. - * Check ADTaskManager#initCacheWithCleanupIfRequired(String, AnomalyDetector, TransportService, ActionListener), - * ADTaskManager#updateLatestRealtimeTaskOnCoordinatingNode(String, String, Long, Long, String, ActionListener) * - * @param detectorId detector id + * Check {@link TaskManager#initRealtimeTaskCacheAndCleanupStaleCache(String, Config, TransportService, ActionListener)} + * + * @param configId detector id * @param newState new task state * @param newInitProgress new init progress * @param newError new error */ - public void updateRealtimeTaskCache(String detectorId, String newState, Float newInitProgress, String newError) { - RealtimeTaskCache realtimeTaskCache = realtimeTaskCaches.get(detectorId); + public void updateRealtimeTaskCache(String configId, String newState, Float newInitProgress, String newError) { + RealtimeTaskCache realtimeTaskCache = realtimeTaskCaches.get(configId); if (realtimeTaskCache != null) { if (newState != null) { realtimeTaskCache.setState(newState); @@ -168,47 +172,47 @@ public void updateRealtimeTaskCache(String detectorId, String newState, Float ne } if (newState != null && !TaskState.NOT_ENDED_STATES.contains(newState)) { // If task is done, will remove its realtime task cache. - logger.info("Realtime task done with state {}, remove RT task cache for detector ", newState, detectorId); - removeRealtimeTaskCache(detectorId); + logger.info("Realtime task done with state {}, remove RT task cache for config ", newState, configId); + removeRealtimeTaskCache(configId); } } else { - logger.debug("Realtime task cache is not inited yet for detector {}", detectorId); + logger.debug("Realtime task cache is not inited yet for config {}", configId); } } - public void refreshRealtimeJobRunTime(String detectorId) { - RealtimeTaskCache taskCache = realtimeTaskCaches.get(detectorId); + public void refreshRealtimeJobRunTime(String configId) { + RealtimeTaskCache taskCache = realtimeTaskCaches.get(configId); if (taskCache != null) { taskCache.setLastJobRunTime(Instant.now().toEpochMilli()); } } /** - * Get detector IDs from realtime task cache. - * @return array of detector id + * Get config IDs from realtime task cache. + * @return array of config id */ - public String[] getDetectorIdsInRealtimeTaskCache() { + public String[] getConfigIdsInRealtimeTaskCache() { return realtimeTaskCaches.keySet().toArray(new String[0]); } /** * Remove detector's realtime task from cache. - * @param detectorId detector id + * @param configId config id */ - public void removeRealtimeTaskCache(String detectorId) { - if (realtimeTaskCaches.containsKey(detectorId)) { - logger.info("Delete realtime cache for detector {}", detectorId); - realtimeTaskCaches.remove(detectorId); + public void removeRealtimeTaskCache(String configId) { + if (realtimeTaskCaches.containsKey(configId)) { + logger.info("Delete realtime cache for config {}", configId); + realtimeTaskCaches.remove(configId); } } /** - * We query result index to check if there are any result generated for detector to tell whether it passed initialization of not. + * We query result index to check if there are any result generated for config to tell whether it passed initialization of not. * To avoid repeated query when there is no data, record whether we have done that or not. - * @param id detector id + * @param configId config id */ - public void markResultIndexQueried(String id) { - RealtimeTaskCache realtimeTaskCache = realtimeTaskCaches.get(id); + public void markResultIndexQueried(String configId) { + RealtimeTaskCache realtimeTaskCache = realtimeTaskCaches.get(configId); // we initialize a real time cache at the beginning of AnomalyResultTransportAction if it // cannot be found. If the cache is empty, we will return early and wait it for it to be // initialized. @@ -218,13 +222,13 @@ public void markResultIndexQueried(String id) { } /** - * We query result index to check if there are any result generated for detector to tell whether it passed initialization of not. + * We query result index to check if there are any result generated for config to tell whether it passed initialization of not. * - * @param id detector id + * @param configId config id * @return whether we have queried result index or not. */ - public boolean hasQueriedResultIndex(String id) { - RealtimeTaskCache realtimeTaskCache = realtimeTaskCaches.get(id); + public boolean hasQueriedResultIndex(String configId) { + RealtimeTaskCache realtimeTaskCache = realtimeTaskCaches.get(configId); if (realtimeTaskCache != null) { return realtimeTaskCache.hasQueriedResultIndex(); } diff --git a/src/main/java/org/opensearch/timeseries/task/TaskManager.java b/src/main/java/org/opensearch/timeseries/task/TaskManager.java new file mode 100644 index 000000000..7424ffb13 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/task/TaskManager.java @@ -0,0 +1,1085 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.task; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.constant.CommonMessages.CONFIG_IS_RUNNING; +import static org.opensearch.timeseries.model.TaskState.NOT_ENDED_STATES; +import static org.opensearch.timeseries.model.TaskType.taskTypeToString; +import static org.opensearch.timeseries.util.RestHandlerUtils.XCONTENT_WITH_TYPE; +import static org.opensearch.timeseries.util.RestHandlerUtils.createXContentParserFromRegistry; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.search.join.ScoreMode; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.bulk.BulkAction; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskType; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.NestedQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.query.TermsQueryBuilder; +import org.opensearch.index.reindex.DeleteByQueryAction; +import org.opensearch.index.reindex.DeleteByQueryRequest; +import org.opensearch.index.reindex.UpdateByQueryAction; +import org.opensearch.index.reindex.UpdateByQueryRequest; +import org.opensearch.script.Script; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.SortOrder; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.common.exception.DuplicateTaskException; +import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.common.exception.TaskCancelledException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.function.BiCheckedFunction; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.function.ResponseTransformer; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.transport.JobResponse; +import org.opensearch.timeseries.util.ExceptionUtil; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.ImmutableMap; + +public abstract class TaskManager & TimeSeriesIndex, IndexManagementType extends IndexManagement> { + protected static int DEFAULT_MAINTAIN_INTERVAL_IN_SECONDS = 5; + + private final Logger logger = LogManager.getLogger(TaskManager.class); + + protected final TaskCacheManagerType taskCacheManager; + protected final ClusterService clusterService; + protected final Client client; + protected final String stateIndex; + private final List realTimeTaskTypes; + private final List historicalTaskTypes; + private final List runOnceTaskTypes; + protected final IndexManagementType indexManagement; + protected final NodeStateManager nodeStateManager; + protected final AnalysisType analysisType; + protected final NamedXContentRegistry xContentRegistry; + protected final String configIdFieldName; + + protected volatile Integer maxOldTaskDocsPerConfig; + + protected final ThreadPool threadPool; + private final String allResultIndexPattern; + private final String batchTaskThreadPoolName; + private volatile boolean deleteResultWhenDeleteConfig; + private final TaskState stopped; + + public TaskManager( + TaskCacheManagerType taskCacheManager, + ClusterService clusterService, + Client client, + String stateIndex, + List realTimeTaskTypes, + List historicalTaskTypes, + List runOnceTaskTypes, + IndexManagementType indexManagement, + NodeStateManager nodeStateManager, + AnalysisType analysisType, + NamedXContentRegistry xContentRegistry, + String configIdFieldName, + Setting maxOldADTaskDocsPerConfigSetting, + Settings settings, + ThreadPool threadPool, + String allResultIndexPattern, + String batchTaskThreadPoolName, + Setting deleteResultWhenDeleteConfigSetting, + TaskState stopped + ) { + this.taskCacheManager = taskCacheManager; + this.clusterService = clusterService; + this.client = client; + this.stateIndex = stateIndex; + this.realTimeTaskTypes = realTimeTaskTypes; + this.historicalTaskTypes = historicalTaskTypes; + this.runOnceTaskTypes = runOnceTaskTypes; + this.indexManagement = indexManagement; + this.nodeStateManager = nodeStateManager; + this.analysisType = analysisType; + this.xContentRegistry = xContentRegistry; + this.configIdFieldName = configIdFieldName; + + this.maxOldTaskDocsPerConfig = maxOldADTaskDocsPerConfigSetting.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(maxOldADTaskDocsPerConfigSetting, it -> maxOldTaskDocsPerConfig = it); + + this.threadPool = threadPool; + this.allResultIndexPattern = allResultIndexPattern; + this.batchTaskThreadPoolName = batchTaskThreadPoolName; + + this.deleteResultWhenDeleteConfig = deleteResultWhenDeleteConfigSetting.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(deleteResultWhenDeleteConfigSetting, it -> deleteResultWhenDeleteConfig = it); + + this.stopped = stopped; + } + + public boolean skipUpdateRealtimeTask(String configId, String error) { + RealtimeTaskCache realtimeTaskCache = taskCacheManager.getRealtimeTaskCache(configId); + return realtimeTaskCache != null + && realtimeTaskCache.getInitProgress() != null + && realtimeTaskCache.getInitProgress().floatValue() == 1.0 + && Objects.equals(error, realtimeTaskCache.getError()); + } + + public boolean isHCRealtimeTaskStartInitializing(String detectorId) { + RealtimeTaskCache realtimeTaskCache = taskCacheManager.getRealtimeTaskCache(detectorId); + return realtimeTaskCache != null + && realtimeTaskCache.getInitProgress() != null + && realtimeTaskCache.getInitProgress().floatValue() > 0; + } + + /** + * Maintain running realtime tasks. Check if realtime task cache expires or not. Remove realtime + * task cache directly if expired. + */ + public void maintainRunningRealtimeTasks() { + String[] configIds = taskCacheManager.getConfigIdsInRealtimeTaskCache(); + if (configIds == null || configIds.length == 0) { + return; + } + for (int i = 0; i < configIds.length; i++) { + String configId = configIds[i]; + RealtimeTaskCache taskCache = taskCacheManager.getRealtimeTaskCache(configId); + if (taskCache != null && taskCache.expired()) { + taskCacheManager.removeRealtimeTaskCache(configId); + } + } + } + + public void refreshRealtimeJobRunTime(String detectorId) { + taskCacheManager.refreshRealtimeJobRunTime(detectorId); + } + + public void removeRealtimeTaskCache(String detectorId) { + taskCacheManager.removeRealtimeTaskCache(detectorId); + } + + /** + * Update realtime task cache on realtime config's coordinating node. + * + * @param configId config id + * @param state new state + * @param rcfTotalUpdates rcf total updates + * @param intervalInMinutes config interval in minutes + * @param error error + * @param listener action listener + */ + public void updateLatestRealtimeTaskOnCoordinatingNode( + String configId, + String state, + Long rcfTotalUpdates, + Long intervalInMinutes, + String error, + ActionListener listener + ) { + Float initProgress = null; + String newState = null; + // calculate init progress and task state with RCF total updates + if (intervalInMinutes != null && rcfTotalUpdates != null) { + newState = TaskState.INIT.name(); + if (rcfTotalUpdates < TimeSeriesSettings.NUM_MIN_SAMPLES) { + initProgress = (float) rcfTotalUpdates / TimeSeriesSettings.NUM_MIN_SAMPLES; + } else { + newState = TaskState.RUNNING.name(); + initProgress = 1.0f; + } + } + // Check if new state is not null and override state calculated with rcf total updates + if (state != null) { + newState = state; + } + + error = Optional.ofNullable(error).orElse(""); + if (!taskCacheManager.isRealtimeTaskChangeNeeded(configId, newState, initProgress, error)) { + // If task not changed, no need to update, just return + listener.onResponse(null); + return; + } + Map updatedFields = new HashMap<>(); + updatedFields.put(TimeSeriesTask.COORDINATING_NODE_FIELD, clusterService.localNode().getId()); + if (initProgress != null) { + updatedFields.put(TimeSeriesTask.INIT_PROGRESS_FIELD, initProgress); + updatedFields + .put( + TimeSeriesTask.ESTIMATED_MINUTES_LEFT_FIELD, + Math.max(0, TimeSeriesSettings.NUM_MIN_SAMPLES - rcfTotalUpdates) * intervalInMinutes + ); + } + if (newState != null) { + updatedFields.put(TimeSeriesTask.STATE_FIELD, newState); + } + if (error != null) { + updatedFields.put(TimeSeriesTask.ERROR_FIELD, error); + } + Float finalInitProgress = initProgress; + // Variable used in lambda expression should be final or effectively final + String finalError = error; + String finalNewState = newState; + updateLatestTask(configId, realTimeTaskTypes, updatedFields, ActionListener.wrap(r -> { + logger.debug("Updated latest realtime AD task successfully for config {}", configId); + taskCacheManager.updateRealtimeTaskCache(configId, finalNewState, finalInitProgress, finalError); + listener.onResponse(r); + }, e -> { + logger.error("Failed to update realtime task for config " + configId, e); + listener.onFailure(e); + })); + } + + /** + * Update latest task of a config. + * + * @param configId config id + * @param taskTypes task types + * @param updatedFields updated fields, key: filed name, value: new value + * @param listener action listener + */ + public void updateLatestTask( + String configId, + List taskTypes, + Map updatedFields, + ActionListener listener + ) { + getAndExecuteOnLatestConfigLevelTask(configId, taskTypes, (task) -> { + if (task.isPresent()) { + updateTask(task.get().getTaskId(), updatedFields, listener); + } else { + listener.onFailure(new ResourceNotFoundException(configId, CommonMessages.CAN_NOT_FIND_LATEST_TASK)); + } + }, null, false, listener); + } + + public void getAndExecuteOnLatestConfigLevelTask( + Config config, + DateRange dateRange, + boolean runOnce, + User user, + TransportService transportService, + ActionListener listener + ) { + getAndExecuteOnLatestConfigLevelTask(config.getId(), getTaskTypes(dateRange), (task) -> { + if (!task.isPresent() || task.get().isDone()) { + updateLatestFlagOfOldTasksAndCreateNewTask(config, dateRange, runOnce, user, TaskState.CREATED, listener); + } else { + listener.onFailure(new OpenSearchStatusException(CONFIG_IS_RUNNING, RestStatus.BAD_REQUEST)); + } + }, transportService, true, listener); + } + + public void updateLatestFlagOfOldTasksAndCreateNewTask( + Config config, + DateRange dateRange, + boolean runOnce, + User user, + TaskState initialState, + ActionListener listener + ) { + UpdateByQueryRequest updateByQueryRequest = new UpdateByQueryRequest(); + updateByQueryRequest.indices(stateIndex); + BoolQueryBuilder query = new BoolQueryBuilder(); + query.filter(new TermQueryBuilder(configIdFieldName, config.getId())); + query.filter(new TermQueryBuilder(TimeSeriesTask.IS_LATEST_FIELD, true)); + // make sure we reset all latest task as false when user switch from single entity to HC, vice versa. + query.filter(new TermsQueryBuilder(TimeSeriesTask.TASK_TYPE_FIELD, taskTypeToString(getTaskTypes(dateRange, true, runOnce)))); + updateByQueryRequest.setQuery(query); + updateByQueryRequest.setRefresh(true); + String script = String.format(Locale.ROOT, "ctx._source.%s=%s;", TimeSeriesTask.IS_LATEST_FIELD, false); + updateByQueryRequest.setScript(new Script(script)); + + client.execute(UpdateByQueryAction.INSTANCE, updateByQueryRequest, ActionListener.wrap(r -> { + List bulkFailures = r.getBulkFailures(); + if (bulkFailures.isEmpty()) { + // Realtime AD coordinating node is chosen by job scheduler, we won't know it until realtime AD job + // runs. Just set realtime AD coordinating node as null here, and AD job runner will reset correct + // coordinating node once realtime job starts. + // For historical analysis, this method will be called on coordinating node, so we can set coordinating + // node as local node. + String coordinatingNode = dateRange == null ? null : clusterService.localNode().getId(); + createNewTask(config, dateRange, runOnce, user, coordinatingNode, initialState, listener); + } else { + logger.error("Failed to update old task's state for detector: {}, response: {} ", config.getId(), r.toString()); + listener.onFailure(bulkFailures.get(0).getCause()); + } + }, e -> { + logger.error("Failed to reset old tasks as not latest for detector " + config.getId(), e); + listener.onFailure(e); + })); + } + + /** + * Get latest task and execute consumer function. + * [Important!] Make sure listener returns in function + * + * @param configId config id + * @param taskTypes task types + * @param function consumer function + * @param transportService transport service + * @param resetTaskState reset task state or not + * @param listener action listener + * @param action listener response type + */ + public void getAndExecuteOnLatestConfigLevelTask( + String configId, + List taskTypes, + Consumer> function, + TransportService transportService, + boolean resetTaskState, + ActionListener listener + ) { + getAndExecuteOnLatestConfigTask(configId, null, null, taskTypes, function, transportService, resetTaskState, listener); + } + + /** + * Get one latest task and execute consumer function. + * [Important!] Make sure listener returns in function + * + * @param configId config id + * @param parentTaskId parent task id + * @param entity entity value + * @param taskTypes task types + * @param function consumer function + * @param transportService transport service + * @param resetTaskState reset task state or not + * @param listener action listener + * @param action listener response type + */ + public void getAndExecuteOnLatestConfigTask( + String configId, + String parentTaskId, + Entity entity, + List taskTypes, + Consumer> function, + TransportService transportService, + boolean resetTaskState, + ActionListener listener + ) { + getAndExecuteOnLatestTasks(configId, parentTaskId, entity, taskTypes, (taskList) -> { + if (taskList != null && taskList.size() > 0) { + function.accept(Optional.ofNullable(taskList.get(0))); + } else { + function.accept(Optional.empty()); + } + }, transportService, resetTaskState, 1, listener); + } + + public List getTaskTypes(DateRange dateRange) { + return getTaskTypes(dateRange, false, false); + } + + /** + * Update latest realtime task. + * + * @param configId config id + * @param state task state + * @param error error + * @param transportService transport service + * @param listener action listener + */ + public void stopLatestRealtimeTask( + String configId, + TaskState state, + Exception error, + TransportService transportService, + ActionListener listener + ) { + getAndExecuteOnLatestConfigLevelTask(configId, realTimeTaskTypes, (adTask) -> { + if (adTask.isPresent() && !adTask.get().isDone()) { + Map updatedFields = new HashMap<>(); + updatedFields.put(TimeSeriesTask.STATE_FIELD, state.name()); + if (error != null) { + updatedFields.put(TimeSeriesTask.ERROR_FIELD, error.getMessage()); + } + ExecutorFunction function = () -> updateTask(adTask.get().getTaskId(), updatedFields, ActionListener.wrap(r -> { + if (error == null) { + listener.onResponse(new JobResponse(configId)); + } else { + listener.onFailure(error); + } + }, e -> { listener.onFailure(e); })); + + String coordinatingNode = adTask.get().getCoordinatingNode(); + if (coordinatingNode != null && transportService != null) { + cleanConfigCache(adTask.get(), transportService, function, listener); + } else { + function.execute(); + } + } else { + listener.onFailure(new OpenSearchStatusException("job is already stopped: " + configId, RestStatus.OK)); + } + }, null, false, listener); + } + + protected void resetTaskStateAsStopped( + TimeSeriesTask task, + ExecutorFunction function, + TransportService transportService, + ActionListener listener + ) { + cleanConfigCache(task, transportService, () -> { + String taskId = task.getTaskId(); + Map updatedFields = ImmutableMap.of(TimeSeriesTask.STATE_FIELD, stopped.name()); + updateTask(taskId, updatedFields, ActionListener.wrap(r -> { + task.setState(stopped.name()); + if (function != null) { + function.execute(); + } + // For realtime anomaly detection, we only create config level task, no entity level realtime task. + if (isHistoricalHCTask(task)) { + // Reset running entity tasks as STOPPED + resetEntityTasksAsStopped(taskId); + } + }, e -> { + logger.error("Failed to update task state as stopped for task " + taskId, e); + listener.onFailure(e); + })); + }, listener); + } + + /** + * Get latest config tasks and execute consumer function. + * [Important!] Make sure listener returns in function + * + * @param configId config id + * @param parentTaskId parent task id + * @param entity entity value + * @param taskTypes task types + * @param function consumer function + * @param transportService transport service + * @param resetTaskState reset task state or not + * @param size return how many tasks + * @param listener action listener + * @param response type of action listener + */ + public void getAndExecuteOnLatestTasks( + String configId, + String parentTaskId, + Entity entity, + List taskTypes, + Consumer> function, + TransportService transportService, + boolean resetTaskState, + int size, + ActionListener listener + ) { + BoolQueryBuilder query = new BoolQueryBuilder(); + query.filter(new TermQueryBuilder(configIdFieldName, configId)); + query.filter(new TermQueryBuilder(TimeSeriesTask.IS_LATEST_FIELD, true)); + if (parentTaskId != null) { + query.filter(new TermQueryBuilder(TimeSeriesTask.PARENT_TASK_ID_FIELD, parentTaskId)); + } + if (taskTypes != null && taskTypes.size() > 0) { + query.filter(new TermsQueryBuilder(TimeSeriesTask.TASK_TYPE_FIELD, TaskType.taskTypeToString(taskTypes))); + } + if (entity != null && !ParseUtils.isNullOrEmpty(entity.getAttributes())) { + String path = "entity"; + String entityKeyFieldName = path + ".name"; + String entityValueFieldName = path + ".value"; + + for (Map.Entry attribute : entity.getAttributes().entrySet()) { + BoolQueryBuilder entityBoolQuery = new BoolQueryBuilder(); + TermQueryBuilder entityKeyFilterQuery = QueryBuilders.termQuery(entityKeyFieldName, attribute.getKey()); + TermQueryBuilder entityValueFilterQuery = QueryBuilders.termQuery(entityValueFieldName, attribute.getValue()); + + entityBoolQuery.filter(entityKeyFilterQuery).filter(entityValueFilterQuery); + NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(path, entityBoolQuery, ScoreMode.None); + query.filter(nestedQueryBuilder); + } + } + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query(query).sort(TimeSeriesTask.EXECUTION_START_TIME_FIELD, SortOrder.DESC).size(size); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.source(sourceBuilder); + searchRequest.indices(stateIndex); + + client.search(searchRequest, ActionListener.wrap(r -> { + // https://github.com/opendistro-for-elasticsearch/anomaly-detection/pull/359#discussion_r558653132 + // getTotalHits will be null when we track_total_hits is false in the query request. + // Add more checking here to cover some unknown cases. + List tsTasks = new ArrayList<>(); + if (r == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) { + // don't throw exception here as consumer functions need to handle missing task + // in different way. + function.accept(tsTasks); + return; + } + BiCheckedFunction parserMethod = getTaskParser(); + Iterator iterator = r.getHits().iterator(); + while (iterator.hasNext()) { + SearchHit searchHit = iterator.next(); + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, searchHit.getSourceRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + TaskClass tsTask = parserMethod.apply(parser, searchHit.getId()); + tsTasks.add(tsTask); + } catch (Exception e) { + String message = "Failed to parse task for config " + configId + ", task id " + searchHit.getId(); + logger.error(message, e); + listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); + } + } + if (resetTaskState) { + resetLatestConfigTaskState(tsTasks, function, transportService, listener); + } else { + function.accept(tsTasks); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + function.accept(new ArrayList<>()); + } else { + logger.error("Failed to search task for config " + configId, e); + listener.onFailure(e); + } + })); + } + + protected void resetRealtimeConfigTaskState( + List runningRealtimeTasks, + ExecutorFunction function, + TransportService transportService, + ActionListener listener + ) { + if (ParseUtils.isNullOrEmpty(runningRealtimeTasks)) { + function.execute(); + return; + } + TimeSeriesTask tsTask = runningRealtimeTasks.get(0); + String configId = tsTask.getConfigId(); + GetRequest getJobRequest = new GetRequest(CommonName.JOB_INDEX).id(configId); + client.get(getJobRequest, ActionListener.wrap(r -> { + if (r.isExists()) { + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Job job = Job.parse(parser); + if (!job.isEnabled()) { + logger.debug("job is disabled, reset realtime task as stopped for config {}", configId); + resetTaskStateAsStopped(tsTask, function, transportService, listener); + } else { + function.execute(); + } + } catch (IOException e) { + logger.error(" Failed to parse job " + configId, e); + listener.onFailure(e); + } + } else { + logger.debug("job is not found, reset realtime task as stopped for config {}", configId); + resetTaskStateAsStopped(tsTask, function, transportService, listener); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + logger.debug("job is not found, reset realtime task as stopped for config {}", configId); + resetTaskStateAsStopped(tsTask, function, transportService, listener); + } else { + logger.error("Fail to get realtime job for config " + configId, e); + listener.onFailure(e); + } + })); + } + + /** + * Handle exceptions for task. Update task state and record error message. + * + * @param task AD task + * @param e exception + */ + public void handleTaskException(TaskClass task, Exception e) { + // TODO: handle timeout exception + String state = TaskState.FAILED.name(); + Map updatedFields = new HashMap<>(); + if (e instanceof DuplicateTaskException) { + // If user send multiple start detector request, we will meet race condition. + // Cache manager will put first request in cache and throw DuplicateTaskException + // for the second request. We will delete the second task. + logger + .warn( + "There is already one running task for config, configId:" + + task.getConfigId() + + ". Will delete task " + + task.getTaskId() + ); + deleteTask(task.getTaskId()); + return; + } + if (e instanceof TaskCancelledException) { + logger.info("task cancelled, taskId: {}, configId: {}", task.getTaskId(), task.getConfigId()); + state = stopped.name(); + String stoppedBy = ((TaskCancelledException) e).getCancelledBy(); + if (stoppedBy != null) { + updatedFields.put(TimeSeriesTask.STOPPED_BY_FIELD, stoppedBy); + } + } else { + logger.error("Failed to execute batch task, task id: " + task.getTaskId() + ", config id: " + task.getConfigId(), e); + } + updatedFields.put(TimeSeriesTask.ERROR_FIELD, ExceptionUtil.getErrorMessage(e)); + updatedFields.put(TimeSeriesTask.STATE_FIELD, state); + updatedFields.put(TimeSeriesTask.EXECUTION_END_TIME_FIELD, Instant.now().toEpochMilli()); + updateTask(task.getTaskId(), updatedFields); + } + + /** + * Update task with specific fields. + * + * @param taskId task id + * @param updatedFields updated fields, key: filed name, value: new value + */ + public void updateTask(String taskId, Map updatedFields) { + updateTask(taskId, updatedFields, ActionListener.wrap(response -> { + if (response.status() == RestStatus.OK) { + logger.debug("Updated task successfully: {}, task id: {}", response.status(), taskId); + } else { + logger.error("Failed to update task {}, status: {}", taskId, response.status()); + } + }, e -> { logger.error("Failed to update task: " + taskId, e); })); + } + + /** + * Update task for specific fields. + * + * @param taskId task id + * @param updatedFields updated fields, key: filed name, value: new value + * @param listener action listener + */ + public void updateTask(String taskId, Map updatedFields, ActionListener listener) { + UpdateRequest updateRequest = new UpdateRequest(stateIndex, taskId); + Map updatedContent = new HashMap<>(); + updatedContent.putAll(updatedFields); + updatedContent.put(TimeSeriesTask.LAST_UPDATE_TIME_FIELD, Instant.now().toEpochMilli()); + updateRequest.doc(updatedContent); + updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.update(updateRequest, listener); + } + + /** + * Delete task with task id. + * + * @param taskId task id + */ + public void deleteTask(String taskId) { + deleteTask(taskId, ActionListener.wrap(r -> { logger.info("Deleted task {} with status: {}", taskId, r.status()); }, e -> { + logger.error("Failed to delete task " + taskId, e); + })); + } + + /** + * Delete task with task id. + * + * @param taskId task id + * @param listener action listener + */ + public void deleteTask(String taskId, ActionListener listener) { + DeleteRequest deleteRequest = new DeleteRequest(stateIndex, taskId); + client.delete(deleteRequest, listener); + } + + /** + * Create config task directly without checking index exists of not. + * [Important!] Make sure listener returns in function + * + * @param tsTask Time series task + * @param function consumer function + * @param listener action listener + * @param action listener response type + */ + public void createTaskDirectly(TaskClass tsTask, Consumer function, ActionListener listener) { + IndexRequest request = new IndexRequest(stateIndex); + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + request + .source(tsTask.toXContent(builder, RestHandlerUtils.XCONTENT_WITH_TYPE)) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.index(request, ActionListener.wrap(r -> function.accept(r), e -> { + logger.error("Failed to create task for config " + tsTask.getConfigId(), e); + listener.onFailure(e); + })); + } catch (Exception e) { + logger.error("Failed to create task for config " + tsTask.getConfigId(), e); + listener.onFailure(e); + } + } + + protected void cleanOldConfigTaskDocs( + IndexResponse response, + TaskClass tsTask, + ResponseTransformer responseTransformer, + ActionListener delegatedListener + ) { + BoolQueryBuilder query = new BoolQueryBuilder(); + query.filter(new TermQueryBuilder(configIdFieldName, tsTask.getConfigId())); + query.filter(new TermQueryBuilder(TimeSeriesTask.IS_LATEST_FIELD, false)); + + if (tsTask.isHistoricalTask()) { + // If historical task, only delete detector level task. It may take longer time to delete entity tasks. + // We will delete child task (entity task) of config level task in hourly cron job. + query.filter(new TermsQueryBuilder(TimeSeriesTask.TASK_TYPE_FIELD, taskTypeToString(historicalTaskTypes))); + } else if (tsTask.isRunOnceTask()) { + // We don't have entity level task for run once detection, so will delete all tasks. + query.filter(new TermsQueryBuilder(TimeSeriesTask.TASK_TYPE_FIELD, taskTypeToString(runOnceTaskTypes))); + } else { + // We don't have entity level task for realtime detection, so will delete all tasks. + query.filter(new TermsQueryBuilder(TimeSeriesTask.TASK_TYPE_FIELD, taskTypeToString(realTimeTaskTypes))); + } + + SearchRequest searchRequest = new SearchRequest(); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder + .query(query) + .sort(TimeSeriesTask.EXECUTION_START_TIME_FIELD, SortOrder.DESC) + // Search query "from" starts from maxOldTaskDocsPerConfig. + .from(maxOldTaskDocsPerConfig) + .size(MAX_OLD_AD_TASK_DOCS); + searchRequest.source(sourceBuilder).indices(stateIndex); + String configId = tsTask.getConfigId(); + deleteTaskDocs(configId, searchRequest, () -> { + if (tsTask.isHistoricalTask()) { + // run batch result action for historical analysis + runBatchResultAction(response, tsTask, responseTransformer, delegatedListener); + } else { + // use the responseTransformer to transform the response + T transformedResponse = responseTransformer.transform(response); + delegatedListener.onResponse(transformedResponse); + } + }, delegatedListener); + } + + public void deleteTaskDocs(String configId, SearchRequest searchRequest, ExecutorFunction function, ActionListener listener) { + ActionListener searchListener = ActionListener.wrap(r -> { + Iterator iterator = r.getHits().iterator(); + if (iterator.hasNext()) { + BulkRequest bulkRequest = new BulkRequest(); + while (iterator.hasNext()) { + SearchHit searchHit = iterator.next(); + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, searchHit.getSourceRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + TimeSeriesTask task = null; + if (analysisType.isAD()) { + task = ADTask.parse(parser, searchHit.getId()); + } else { + task = ForecastTask.parse(parser, searchHit.getId()); + } + + logger.debug("Delete old task: {} of config: {}", task.getTaskId(), task.getConfigId()); + bulkRequest.add(new DeleteRequest(stateIndex).id(task.getTaskId())); + } catch (Exception e) { + listener.onFailure(e); + } + } + client.execute(BulkAction.INSTANCE, bulkRequest, ActionListener.wrap(res -> { + logger.info("Old tasks deleted for config {}", configId); + BulkItemResponse[] bulkItemResponses = res.getItems(); + if (bulkItemResponses != null && bulkItemResponses.length > 0) { + for (BulkItemResponse bulkItemResponse : bulkItemResponses) { + if (!bulkItemResponse.isFailed()) { + logger.debug("Add config task into cache. Task id: {}", bulkItemResponse.getId()); + // add deleted task in cache and delete its child tasks and results + taskCacheManager.addDeletedTask(bulkItemResponse.getId()); + } + } + } + // delete child tasks and results of this task + cleanChildTasksAndResultsOfDeletedTask(); + function.execute(); + }, e -> { + logger.warn("Failed to clean tasks for config " + configId, e); + listener.onFailure(e); + })); + } else { + function.execute(); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + function.execute(); + } else { + listener.onFailure(e); + } + }); + + client.search(searchRequest, searchListener); + } + + /** + * Poll deleted config task from cache and delete its child tasks and results. + */ + public void cleanChildTasksAndResultsOfDeletedTask() { + if (!taskCacheManager.hasDeletedTask()) { + return; + } + threadPool.schedule(() -> { + String taskId = taskCacheManager.pollDeletedTask(); + if (taskId == null) { + return; + } + DeleteByQueryRequest deleteResultsRequest = new DeleteByQueryRequest(allResultIndexPattern); + deleteResultsRequest.setQuery(new TermsQueryBuilder(CommonName.TASK_ID_FIELD, taskId)); + client.execute(DeleteByQueryAction.INSTANCE, deleteResultsRequest, ActionListener.wrap(res -> { + logger.debug("Successfully deleted results of task " + taskId); + DeleteByQueryRequest deleteChildTasksRequest = new DeleteByQueryRequest(stateIndex); + deleteChildTasksRequest.setQuery(new TermsQueryBuilder(TimeSeriesTask.PARENT_TASK_ID_FIELD, taskId)); + + client.execute(DeleteByQueryAction.INSTANCE, deleteChildTasksRequest, ActionListener.wrap(r -> { + logger.debug("Successfully deleted child tasks of task " + taskId); + cleanChildTasksAndResultsOfDeletedTask(); + }, e -> { logger.error("Failed to delete child tasks of task " + taskId, e); })); + }, ex -> { logger.error("Failed to delete results for task " + taskId, ex); })); + }, TimeValue.timeValueSeconds(DEFAULT_MAINTAIN_INTERVAL_IN_SECONDS), batchTaskThreadPoolName); + } + + protected void resetEntityTasksAsStopped(String configTaskId) { + UpdateByQueryRequest updateByQueryRequest = new UpdateByQueryRequest(); + updateByQueryRequest.indices(stateIndex); + BoolQueryBuilder query = new BoolQueryBuilder(); + query.filter(new TermQueryBuilder(TimeSeriesTask.PARENT_TASK_ID_FIELD, configTaskId)); + query.filter(new TermQueryBuilder(TimeSeriesTask.TASK_TYPE_FIELD, ADTaskType.HISTORICAL_HC_ENTITY.name())); + query.filter(new TermsQueryBuilder(TimeSeriesTask.STATE_FIELD, NOT_ENDED_STATES)); + updateByQueryRequest.setQuery(query); + updateByQueryRequest.setRefresh(true); + String script = String.format(Locale.ROOT, "ctx._source.%s='%s';", TimeSeriesTask.STATE_FIELD, TaskState.INACTIVE.name()); + updateByQueryRequest.setScript(new Script(script)); + + client.execute(UpdateByQueryAction.INSTANCE, updateByQueryRequest, ActionListener.wrap(r -> { + List bulkFailures = r.getBulkFailures(); + if (ParseUtils.isNullOrEmpty(bulkFailures)) { + logger.debug("Updated {} child entity tasks state for config task {}", r.getUpdated(), configTaskId); + } else { + logger.error("Failed to update child entity task's state for config task {} ", configTaskId); + } + }, e -> logger.error("Exception happened when update child entity task's state for config task " + configTaskId, e))); + } + + /** + * Set old task's latest flag as false. + * @param tasks list of tasks + */ + public void resetLatestFlagAsFalse(List tasks) { + if (tasks == null || tasks.size() == 0) { + return; + } + BulkRequest bulkRequest = new BulkRequest(); + tasks.forEach(task -> { + try { + task.setLatest(false); + task.setLastUpdateTime(Instant.now()); + IndexRequest indexRequest = new IndexRequest(stateIndex) + .id(task.getTaskId()) + .source(task.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), XCONTENT_WITH_TYPE)); + bulkRequest.add(indexRequest); + } catch (Exception e) { + logger.error("Fail to parse task task to XContent, task id " + task.getTaskId(), e); + } + }); + + bulkRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.execute(BulkAction.INSTANCE, bulkRequest, ActionListener.wrap(res -> { + BulkItemResponse[] bulkItemResponses = res.getItems(); + if (bulkItemResponses != null && bulkItemResponses.length > 0) { + for (BulkItemResponse bulkItemResponse : bulkItemResponses) { + if (!bulkItemResponse.isFailed()) { + logger.warn("Reset tasks latest flag as false Successfully. Task id: {}", bulkItemResponse.getId()); + } else { + logger.warn("Failed to reset tasks latest flag as false. Task id: " + bulkItemResponse.getId()); + } + } + } + }, e -> { logger.warn("Failed to reset AD tasks latest flag as false", e); })); + } + + /** + * Delete tasks docs. + * [Important!] Make sure listener returns in function + * + * @param configId config id + * @param function time series function + * @param listener action listener + */ + public void deleteTasks(String configId, ExecutorFunction function, ActionListener listener) { + DeleteByQueryRequest request = new DeleteByQueryRequest(stateIndex); + + BoolQueryBuilder query = new BoolQueryBuilder(); + query.filter(new TermQueryBuilder(configIdFieldName, configId)); + + request.setQuery(query); + client.execute(DeleteByQueryAction.INSTANCE, request, ActionListener.wrap(r -> { + if (r.getBulkFailures() == null || r.getBulkFailures().size() == 0) { + logger.info("tasks deleted for config {}", configId); + deleteResultOfConfig(configId); + function.execute(); + } else { + listener.onFailure(new OpenSearchStatusException("Failed to delete all tasks", RestStatus.INTERNAL_SERVER_ERROR)); + } + }, e -> { + logger.info("Failed to delete tasks for " + configId, e); + if (e instanceof IndexNotFoundException) { + deleteResultOfConfig(configId); + function.execute(); + } else { + listener.onFailure(e); + } + })); + } + + public void deleteResultOfConfig(String configId) { + if (!deleteResultWhenDeleteConfig) { + logger.info("Won't delete result for {} as delete result setting is disabled", configId); + return; + } + logger.info("Start to delete results of config {}", configId); + DeleteByQueryRequest deleteADResultsRequest = new DeleteByQueryRequest(allResultIndexPattern); + deleteADResultsRequest.setQuery(new TermQueryBuilder(configIdFieldName, configId)); + client.execute(DeleteByQueryAction.INSTANCE, deleteADResultsRequest, ActionListener.wrap(response -> { + logger.debug("Successfully deleted results of config " + configId); + }, exception -> { + logger.error("Failed to delete results of config " + configId, exception); + taskCacheManager.addDeletedConfig(configId); + })); + } + + /** + * Clean results of deleted config. + */ + public void cleanResultOfDeletedConfig() { + String detectorId = taskCacheManager.pollDeletedConfig(); + if (detectorId != null) { + deleteResultOfConfig(detectorId); + } + } + + public abstract void startHistorical( + Config config, + DateRange dateRange, + User user, + TransportService transportService, + ActionListener listener + ); + + protected abstract TaskType getTaskType(Config config, DateRange dateRange, boolean runOnce); + + protected abstract void createNewTask( + Config config, + DateRange dateRange, + boolean runOnce, + User user, + String coordinatingNode, + TaskState initialState, + ActionListener listener + ); + + public abstract void cleanConfigCache( + TimeSeriesTask task, + TransportService transportService, + ExecutorFunction function, + ActionListener listener + ); + + protected abstract boolean isHistoricalHCTask(TimeSeriesTask task); + + protected abstract void resetLatestConfigTaskState( + List tasks, + Consumer> function, + TransportService transportService, + ActionListener listener + ); + + protected abstract void onIndexConfigTaskResponse( + IndexResponse response, + TaskClass adTask, + BiConsumer> function, + ActionListener listener + ); + + protected abstract void runBatchResultAction( + IndexResponse response, + TaskClass tsTask, + ResponseTransformer responseTransformer, + ActionListener listener + ); + + protected abstract BiCheckedFunction getTaskParser(); + + /** + * the function initializes the real time cache and only performs cleanup if it is deemed necessary. + * @param configId config id + * @param config config accessor + * @param transportService Transport service + * @param listener listener to return back init success or not + */ + public abstract void initRealtimeTaskCacheAndCleanupStaleCache( + String configId, + Config config, + TransportService transportService, + ActionListener listener + ); + + public abstract void createRunOnceTaskAndCleanupStaleTasks( + String configId, + Config config, + TransportService transportService, + ActionListener listener + ); + + public abstract List getTaskTypes(DateRange dateRange, boolean resetLatestTaskStateFlag, boolean runOnce); +} diff --git a/src/main/java/org/opensearch/timeseries/transport/AbstractSingleStreamResultTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/AbstractSingleStreamResultTransportAction.java new file mode 100644 index 000000000..39ed3c5cf --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/AbstractSingleStreamResultTransportAction.java @@ -0,0 +1,244 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.transport; + +import java.time.Instant; +import java.util.List; +import java.util.Optional; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.forecast.transport.ForecastSingleStreamResultTransportAction; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.caching.CacheBuffer; +import org.opensearch.timeseries.caching.CacheProvider; +import org.opensearch.timeseries.caching.PriorityCache; +import org.opensearch.timeseries.caching.TimeSeriesCache; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.IntermediateResult; +import org.opensearch.timeseries.ml.ModelColdStart; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.ratelimit.CheckpointMaintainWorker; +import org.opensearch.timeseries.ratelimit.CheckpointReadWorker; +import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker; +import org.opensearch.timeseries.ratelimit.ColdStartWorker; +import org.opensearch.timeseries.ratelimit.FeatureRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.ratelimit.ResultWriteRequest; +import org.opensearch.timeseries.ratelimit.ResultWriteWorker; +import org.opensearch.timeseries.ratelimit.SaveResultStrategy; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.Stats; +import org.opensearch.timeseries.transport.handler.IndexMemoryPressureAwareResultHandler; +import org.opensearch.timeseries.util.ExceptionUtil; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.transport.TransportService; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public abstract class AbstractSingleStreamResultTransportAction & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CheckpointWriterType extends CheckpointWriteWorker, CheckpointMaintainerType extends CheckpointMaintainWorker, CacheBufferType extends CacheBuffer, PriorityCacheType extends PriorityCache, CacheProviderType extends CacheProvider, ResultType extends IndexableResult, RCFResultType extends IntermediateResult, ColdStarterType extends ModelColdStart, ModelManagerType extends ModelManager, CacheType extends TimeSeriesCache, SaveResultStrategyType extends SaveResultStrategy, ColdStartWorkerType extends ColdStartWorker, CheckpointReadWorkerType extends CheckpointReadWorker, ResultWriteRequestType extends ResultWriteRequest, BatchRequestType extends ResultBulkRequest, ResultHandlerType extends IndexMemoryPressureAwareResultHandler, ResultWriteWorkerType extends ResultWriteWorker> + extends HandledTransportAction { + private static final Logger LOG = LogManager.getLogger(ForecastSingleStreamResultTransportAction.class); + protected CircuitBreakerService circuitBreakerService; + protected CacheProviderType cache; + protected final NodeStateManager stateManager; + protected CheckpointReadWorkerType checkpointReadQueue; + protected ModelManagerType modelManager; + protected IndexManagementType indexUtil; + protected ResultWriteWorkerType resultWriteQueue; + protected Stats stats; + protected ColdStartWorkerType coldStartWorker; + protected IndexType resultIndex; + protected AnalysisType analysisType; + + public AbstractSingleStreamResultTransportAction( + TransportService transportService, + ActionFilters actionFilters, + CircuitBreakerService circuitBreakerService, + CacheProviderType cache, + NodeStateManager stateManager, + CheckpointReadWorkerType checkpointReadQueue, + ModelManagerType modelManager, + IndexManagementType indexUtil, + ResultWriteWorkerType resultWriteQueue, + Stats stats, + ColdStartWorkerType forecastColdStartQueue, + String resultAction, + IndexType resultIndex, + AnalysisType analysisType + ) { + super(resultAction, transportService, actionFilters, SingleStreamResultRequest::new); + this.circuitBreakerService = circuitBreakerService; + this.cache = cache; + this.stateManager = stateManager; + this.checkpointReadQueue = checkpointReadQueue; + this.modelManager = modelManager; + this.indexUtil = indexUtil; + this.resultWriteQueue = resultWriteQueue; + this.stats = stats; + this.coldStartWorker = forecastColdStartQueue; + this.resultIndex = resultIndex; + this.analysisType = analysisType; + } + + @Override + protected void doExecute(Task task, SingleStreamResultRequest request, ActionListener listener) { + if (circuitBreakerService.isOpen()) { + listener.onFailure(new LimitExceededException(request.getConfigId(), CommonMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG, false)); + return; + } + + try { + String configId = request.getConfigId(); + + Optional previousException = stateManager.fetchExceptionAndClear(configId); + + if (previousException.isPresent()) { + Exception exception = previousException.get(); + LOG.error("Previous exception of {}: {}", configId, exception); + if (exception instanceof EndRunException) { + EndRunException endRunException = (EndRunException) exception; + if (endRunException.isEndNow()) { + listener.onFailure(exception); + return; + } + } + + listener = ExceptionUtil.wrapListener(listener, exception, configId); + } + + stateManager.getConfig(configId, analysisType, onGetConfig(listener, configId, request, previousException)); + } catch (Exception exception) { + LOG.error("fail to get entity's forecasts", exception); + listener.onFailure(exception); + } + } + + public ActionListener> onGetConfig( + ActionListener listener, + String forecasterId, + SingleStreamResultRequest request, + Optional prevException + ) { + return ActionListener.wrap(configOptional -> { + if (!configOptional.isPresent()) { + listener.onFailure(new EndRunException(forecasterId, "Config " + forecasterId + " is not available.", false)); + return; + } + + Config config = configOptional.get(); + + Instant executionStartTime = Instant.now(); + + String modelId = request.getModelId(); + double[] datapoint = request.getDataPoint(); + ModelState modelState = cache.get().get(modelId, config); + if (modelState == null) { + // cache miss + checkpointReadQueue + .put( + new FeatureRequest( + System.currentTimeMillis() + config.getIntervalInMilliseconds(), + forecasterId, + RequestPriority.MEDIUM, + request.getModelId(), + datapoint, + request.getStart(), + request.getTaskId() + ) + ); + } else { + try { + RCFResultType result = modelManager + .getResult( + new Sample(datapoint, Instant.ofEpochMilli(request.getStart()), Instant.ofEpochMilli(request.getEnd())), + modelState, + modelId, + Optional.empty(), + config, + request.getTaskId() + ); + // result.getRcfScore() = 0 means the model is not initialized + if (result.getRcfScore() > 0) { + List indexableResults = result + .toIndexableResults( + config, + Instant.ofEpochMilli(request.getStart()), + Instant.ofEpochMilli(request.getEnd()), + executionStartTime, + Instant.now(), + ParseUtils.getFeatureData(datapoint, config), + Optional.empty(), + indexUtil.getSchemaVersion(resultIndex), + modelId, + null, + null + ); + + for (ResultType r : indexableResults) { + resultWriteQueue.put(createResultWriteRequest(config, r)); + } + } + } catch (IllegalArgumentException e) { + // fail to score likely due to model corruption. Re-cold start to recover. + LOG.error(new ParameterizedMessage("Likely model corruption for [{}]", modelId), e); + stats.getStat(StatNames.FORECAST_MODEL_CORRUTPION_COUNT.getName()).increment(); + cache.get().removeModel(forecasterId, modelId); + coldStartWorker + .put( + new FeatureRequest( + System.currentTimeMillis() + config.getIntervalInMilliseconds(), + forecasterId, + RequestPriority.MEDIUM, + modelId, + datapoint, + request.getStart(), + request.getTaskId() + ) + ); + } + } + + // respond back + if (prevException.isPresent()) { + listener.onFailure(prevException.get()); + } else { + listener.onResponse(new AcknowledgedResponse(true)); + } + }, exception -> { + LOG + .error( + new ParameterizedMessage( + "fail to get entity's forecasts for forecaster [{}]: start: [{}], end: [{}]", + forecasterId, + request.getStart(), + request.getEnd() + ), + exception + ); + listener.onFailure(exception); + }); + } + + public abstract ResultWriteRequestType createResultWriteRequest(Config config, ResultType result); +} diff --git a/src/main/java/org/opensearch/timeseries/transport/BaseDeleteConfigTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseDeleteConfigTransportAction.java new file mode 100644 index 000000000..d7e1c355f --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/BaseDeleteConfigTransportAction.java @@ -0,0 +1,247 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.transport; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_DELETE_CONFIG; +import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute; +import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; + +import java.io.IOException; +import java.util.List; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportService; + +public abstract class BaseDeleteConfigTransportAction & TimeSeriesIndex, IndexManagementType extends IndexManagement, TaskManagerType extends TaskManager, ConfigType extends Config> + extends HandledTransportAction { + + private static final Logger LOG = LogManager.getLogger(BaseDeleteConfigTransportAction.class); + + private final Client client; + private final ClusterService clusterService; + private final TransportService transportService; + private NamedXContentRegistry xContentRegistry; + private final TaskManagerType taskManager; + private volatile Boolean filterByEnabled; + private final NodeStateManager nodeStateManager; + private final AnalysisType analysisType; + private final String stateIndex; + private final Class configTypeClass; + private final List batchTaskTypes; + + public BaseDeleteConfigTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ClusterService clusterService, + Settings settings, + NamedXContentRegistry xContentRegistry, + NodeStateManager nodeStateManager, + TaskManagerType taskManager, + String deleteConfigAction, + Setting filterByBackendRoleSetting, + AnalysisType analysisType, + String stateIndex, + Class configTypeClass, + List historicalTaskTypes + ) { + super(deleteConfigAction, transportService, actionFilters, DeleteConfigRequest::new); + this.transportService = transportService; + this.client = client; + this.clusterService = clusterService; + this.xContentRegistry = xContentRegistry; + this.taskManager = taskManager; + this.nodeStateManager = nodeStateManager; + filterByEnabled = filterByBackendRoleSetting.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(filterByBackendRoleSetting, it -> filterByEnabled = it); + + this.analysisType = analysisType; + this.stateIndex = stateIndex; + this.configTypeClass = configTypeClass; + this.batchTaskTypes = historicalTaskTypes; + } + + @Override + protected void doExecute(Task task, DeleteConfigRequest request, ActionListener actionListener) { + String configId = request.getConfigID(); + LOG.info("Delete job {}", configId); + User user = ParseUtils.getUserContext(client); + ActionListener listener = wrapRestActionListener(actionListener, FAIL_TO_DELETE_CONFIG); + // By the time request reaches here, the user permissions are validated by Security plugin. + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + resolveUserAndExecute( + user, + configId, + filterByEnabled, + listener, + (input) -> nodeStateManager.getConfig(configId, analysisType, config -> { + if (!config.isPresent()) { + // In a mixed cluster, if delete detector request routes to node running AD1.0, then it will + // not delete detector tasks. User can re-delete these deleted detector after cluster upgraded, + // in that case, the detector is not present. + LOG.info("Can't find config {}", configId); + taskManager.deleteTasks(configId, () -> deleteJobDoc(configId, listener), listener); + return; + } + // Check if there is realtime job or batch analysis task running. If none of these running, we + // can delete the config. + getJob(configId, listener, () -> { + taskManager.getAndExecuteOnLatestConfigLevelTask(configId, batchTaskTypes, configTask -> { + if (configTask.isPresent() && !configTask.get().isDone()) { + listener + .onFailure(new OpenSearchStatusException("Run once or historical is running", RestStatus.BAD_REQUEST)); + } else { + taskManager.deleteTasks(configId, () -> deleteJobDoc(configId, listener), listener); + } + // false means don't reset task state as inactive/stopped state. We are checking if task has finished or not. + // So no need to reset task state. + }, transportService, false, listener); + }); + }, listener), + client, + clusterService, + xContentRegistry, + configTypeClass + ); + } catch (Exception e) { + LOG.error(e); + listener.onFailure(e); + } + } + + private void deleteJobDoc(String configId, ActionListener listener) { + LOG.info("Delete job {}", configId); + DeleteRequest deleteRequest = new DeleteRequest(CommonName.JOB_INDEX, configId) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.delete(deleteRequest, ActionListener.wrap(response -> { + if (response.getResult() == DocWriteResponse.Result.DELETED || response.getResult() == DocWriteResponse.Result.NOT_FOUND) { + deleteStateDoc(configId, listener); + } else { + String message = "Fail to delete job " + configId; + LOG.error(message); + listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); + } + }, exception -> { + LOG.error("Failed to delete job for " + configId, exception); + if (exception instanceof IndexNotFoundException) { + deleteStateDoc(configId, listener); + } else { + LOG.error("Failed to delete job", exception); + listener.onFailure(exception); + } + })); + } + + private void deleteStateDoc(String configId, ActionListener listener) { + LOG.info("Delete config state {}", configId); + DeleteRequest deleteRequest = new DeleteRequest(stateIndex, configId); + client.delete(deleteRequest, ActionListener.wrap(response -> { + // whether deleted state doc or not, continue as state doc may not exist + deleteConfigDoc(configId, listener); + }, exception -> { + if (exception instanceof IndexNotFoundException) { + deleteConfigDoc(configId, listener); + } else { + LOG.error("Failed to delete state", exception); + listener.onFailure(exception); + } + })); + } + + private void deleteConfigDoc(String configId, ActionListener listener) { + LOG.info("Delete config {}", configId); + DeleteRequest deleteRequest = new DeleteRequest(CommonName.CONFIG_INDEX, configId) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.delete(deleteRequest, new ActionListener() { + @Override + public void onResponse(DeleteResponse deleteResponse) { + listener.onResponse(deleteResponse); + } + + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }); + } + + private void getJob(String configId, ActionListener listener, ExecutorFunction function) { + if (clusterService.state().metadata().indices().containsKey(CommonName.JOB_INDEX)) { + GetRequest request = new GetRequest(CommonName.JOB_INDEX).id(configId); + client.get(request, ActionListener.wrap(response -> onGetJobResponseForWrite(response, listener, function), exception -> { + LOG.error("Fail to get job: " + configId, exception); + listener.onFailure(exception); + })); + } else { + function.execute(); + } + } + + private void onGetJobResponseForWrite(GetResponse response, ActionListener listener, ExecutorFunction function) + throws IOException { + if (response.isExists()) { + String jobId = response.getId(); + if (jobId != null) { + // check if job is running on the config, if yes, we can't delete the config + try ( + XContentParser parser = RestHandlerUtils + .createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Job adJob = Job.parse(parser); + if (adJob.isEnabled()) { + listener.onFailure(new OpenSearchStatusException("job is running: " + jobId, RestStatus.BAD_REQUEST)); + } else { + function.execute(); + } + } catch (IOException e) { + String message = "Failed to parse job " + jobId; + LOG.error(message, e); + function.execute(); + } + } + } else { + function.execute(); + } + } +} diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseDeleteModelTransportAction.java similarity index 52% rename from src/main/java/org/opensearch/ad/transport/DeleteModelTransportAction.java rename to src/main/java/org/opensearch/timeseries/transport/BaseDeleteModelTransportAction.java index 10aa64725..8a638e401 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteModelTransportAction.java +++ b/src/main/java/org/opensearch/timeseries/transport/BaseDeleteModelTransportAction.java @@ -1,15 +1,9 @@ /* + * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.List; @@ -19,44 +13,44 @@ import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.nodes.TransportNodesAction; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.ml.EntityColdStarter; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.inject.Inject; -import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.caching.CacheProvider; +import org.opensearch.timeseries.caching.TimeSeriesCache; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.ModelColdStart; +import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker; +import org.opensearch.timeseries.task.TaskCacheManager; import org.opensearch.transport.TransportService; -public class DeleteModelTransportAction extends - TransportNodesAction { - private static final Logger LOG = LogManager.getLogger(DeleteModelTransportAction.class); +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public class BaseDeleteModelTransportAction, CacheProviderType extends CacheProvider, TaskCacheManagerType extends TaskCacheManager, IndexType extends Enum & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CheckpointWriteWorkerType extends CheckpointWriteWorker, ModelColdStartType extends ModelColdStart> + extends TransportNodesAction { + + private static final Logger LOG = LogManager.getLogger(BaseDeleteModelTransportAction.class); private NodeStateManager nodeStateManager; - private ModelManager modelManager; - private FeatureManager featureManager; - private CacheProvider cache; - private ADTaskCacheManager adTaskCacheManager; - private EntityColdStarter coldStarter; - - @Inject - public DeleteModelTransportAction( + private CacheProviderType cache; + private TaskCacheManagerType adTaskCacheManager; + private ModelColdStartType coldStarter; + + public BaseDeleteModelTransportAction( ThreadPool threadPool, ClusterService clusterService, TransportService transportService, ActionFilters actionFilters, NodeStateManager nodeStateManager, - ModelManager modelManager, - FeatureManager featureManager, - CacheProvider cache, - ADTaskCacheManager adTaskCacheManager, - EntityColdStarter coldStarter + CacheProviderType cache, + TaskCacheManagerType taskCacheManager, + ModelColdStartType coldStarter, + String deleteModelAction ) { super( - DeleteModelAction.NAME, + deleteModelAction, threadPool, clusterService, transportService, @@ -67,10 +61,8 @@ public DeleteModelTransportAction( DeleteModelNodeResponse.class ); this.nodeStateManager = nodeStateManager; - this.modelManager = modelManager; - this.featureManager = featureManager; this.cache = cache; - this.adTaskCacheManager = adTaskCacheManager; + this.adTaskCacheManager = taskCacheManager; this.coldStarter = coldStarter; } @@ -104,34 +96,18 @@ protected DeleteModelNodeResponse newNodeResponse(StreamInput in) throws IOExcep @Override protected DeleteModelNodeResponse nodeOperation(DeleteModelNodeRequest request) { - String adID = request.getAdID(); - LOG.info("Delete model for {}", adID); - // delete in-memory models and model checkpoint - modelManager - .clear( - adID, - ActionListener - .wrap( - r -> LOG.info("Deleted model for [{}] with response [{}] ", adID, r), - e -> LOG.error("Fail to delete model for " + adID, e) - ) - ); + String configID = request.getConfigID(); + LOG.info("Delete model for {}", configID); + nodeStateManager.clear(configID); - // delete buffered shingle data - featureManager.clear(adID); + cache.get().clear(configID); - // delete transport state - nodeStateManager.clear(adID); - - cache.get().clear(adID); - - coldStarter.clear(adID); + coldStarter.clear(configID); // delete realtime task cache - adTaskCacheManager.removeRealtimeTaskCache(adID); + adTaskCacheManager.removeRealtimeTaskCache(configID); - LOG.info("Finished deleting {}", adID); + LOG.info("Finished deleting {}", configID); return new DeleteModelNodeResponse(clusterService.localNode()); } - } diff --git a/src/main/java/org/opensearch/ad/transport/EntityProfileTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseEntityProfileTransportAction.java similarity index 70% rename from src/main/java/org/opensearch/ad/transport/EntityProfileTransportAction.java rename to src/main/java/org/opensearch/timeseries/transport/BaseEntityProfileTransportAction.java index fedfb2aa7..68bea6e1b 100644 --- a/src/main/java/org/opensearch/ad/transport/EntityProfileTransportAction.java +++ b/src/main/java/org/opensearch/timeseries/transport/BaseEntityProfileTransportAction.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.Locale; @@ -20,34 +20,37 @@ import org.apache.logging.log4j.Logger; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; -import org.opensearch.ad.cluster.HashRing; -import org.opensearch.ad.model.EntityProfileName; -import org.opensearch.ad.model.ModelProfile; -import org.opensearch.ad.model.ModelProfileOnNode; -import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.caching.CacheProvider; +import org.opensearch.timeseries.caching.TimeSeriesCache; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.EntityProfileName; +import org.opensearch.timeseries.model.ModelProfile; +import org.opensearch.timeseries.model.ModelProfileOnNode; import org.opensearch.transport.TransportException; import org.opensearch.transport.TransportRequestOptions; import org.opensearch.transport.TransportResponseHandler; import org.opensearch.transport.TransportService; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + /** * Transport action to get entity profile. */ -public class EntityProfileTransportAction extends HandledTransportAction { +public class BaseEntityProfileTransportAction, CacheProviderType extends CacheProvider> + extends HandledTransportAction { - private static final Logger LOG = LogManager.getLogger(EntityProfileTransportAction.class); + private static final Logger LOG = LogManager.getLogger(BaseEntityProfileTransportAction.class); public static final String NO_NODE_FOUND_MSG = "Cannot find model hosting node"; public static final String NO_MODEL_ID_FOUND_MSG = "Cannot find model id"; static final String FAIL_TO_GET_ENTITY_PROFILE_MSG = "Cannot get entity profile info"; @@ -56,33 +59,36 @@ public class EntityProfileTransportAction extends HandledTransportAction requestTimeOut ) { - super(EntityProfileAction.NAME, transportService, actionFilters, EntityProfileRequest::new); + super(entityProfileAction, transportService, actionFilters, EntityProfileRequest::new); this.transportService = transportService; this.hashRing = hashRing; this.option = TransportRequestOptions .builder() .withType(TransportRequestOptions.Type.REG) - .withTimeout(AnomalyDetectorSettings.AD_REQUEST_TIMEOUT.get(settings)) + .withTimeout(requestTimeOut.get(settings)) .build(); this.clusterService = clusterService; this.cacheProvider = cacheProvider; + this.entityProfileAction = entityProfileAction; } @Override protected void doExecute(Task task, EntityProfileRequest request, ActionListener listener) { - String adID = request.getAdID(); + String adID = request.getConfigID(); Entity entityValue = request.getEntityValue(); Optional modelIdOptional = entityValue.getModelId(adID); if (false == modelIdOptional.isPresent()) { @@ -91,7 +97,7 @@ protected void doExecute(Task task, EntityProfileRequest request, ActionListener } // we use entity's toString (e.g., app_0) to find its node // This should be consistent with how we land a model node in AnomalyResultTransportAction - Optional node = hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(entityValue.toString()); + Optional node = hashRing.getOwningNodeWithSameLocalVersionForRealtime(entityValue.toString()); if (false == node.isPresent()) { listener.onFailure(new TimeSeriesException(adID, NO_NODE_FOUND_MSG)); return; @@ -100,12 +106,12 @@ protected void doExecute(Task task, EntityProfileRequest request, ActionListener String modelId = modelIdOptional.get(); DiscoveryNode localNode = clusterService.localNode(); if (localNode.getId().equals(nodeId)) { - EntityCache cache = cacheProvider.get(); + CacheType cache = cacheProvider.get(); Set profilesToCollect = request.getProfilesToCollect(); EntityProfileResponse.Builder builder = new EntityProfileResponse.Builder(); if (profilesToCollect.contains(EntityProfileName.ENTITY_INFO)) { builder.setActive(cache.isActive(adID, modelId)); - builder.setLastActiveMs(cache.getLastActiveMs(adID, modelId)); + builder.setLastActiveMs(cache.getLastActiveTime(adID, modelId)); } if (profilesToCollect.contains(EntityProfileName.INIT_PROGRESS) || profilesToCollect.contains(EntityProfileName.STATE)) { builder.setTotalUpdates(cache.getTotalUpdates(adID, modelId)); @@ -126,35 +132,29 @@ protected void doExecute(Task task, EntityProfileRequest request, ActionListener try { transportService - .sendRequest( - node.get(), - EntityProfileAction.NAME, - request, - option, - new TransportResponseHandler() { - - @Override - public EntityProfileResponse read(StreamInput in) throws IOException { - return new EntityProfileResponse(in); - } - - @Override - public void handleResponse(EntityProfileResponse response) { - listener.onResponse(response); - } - - @Override - public void handleException(TransportException exp) { - listener.onFailure(exp); - } - - @Override - public String executor() { - return ThreadPool.Names.SAME; - } + .sendRequest(node.get(), entityProfileAction, request, option, new TransportResponseHandler() { + + @Override + public EntityProfileResponse read(StreamInput in) throws IOException { + return new EntityProfileResponse(in); + } + + @Override + public void handleResponse(EntityProfileResponse response) { + listener.onResponse(response); + } + + @Override + public void handleException(TransportException exp) { + listener.onFailure(exp); + } + @Override + public String executor() { + return ThreadPool.Names.SAME; } - ); + + }); } catch (Exception e) { LOG.error(String.format(Locale.ROOT, "Fail to get entity profile for detector {}, entity {}", adID, entityValue), e); listener.onFailure(new TimeSeriesException(adID, FAIL_TO_GET_ENTITY_PROFILE_MSG, e)); diff --git a/src/main/java/org/opensearch/timeseries/transport/BaseGetConfigTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseGetConfigTransportAction.java new file mode 100644 index 000000000..9e1ece6f2 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/BaseGetConfigTransportAction.java @@ -0,0 +1,516 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.transport; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.forecast.constant.ForecastCommonMessages.FAIL_TO_GET_FORECASTER; +import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute; +import static org.opensearch.timeseries.util.RestHandlerUtils.PROFILE; +import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionType; +import org.opensearch.action.get.MultiGetItemResponse; +import org.opensearch.action.get.MultiGetRequest; +import org.opensearch.action.get.MultiGetResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.CheckedConsumer; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.Strings; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.EntityProfileRunner; +import org.opensearch.timeseries.Name; +import org.opensearch.timeseries.ProfileRunner; +import org.opensearch.timeseries.TaskProfile; +import org.opensearch.timeseries.TaskProfileRunner; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.ConfigProfile; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.EntityProfile; +import org.opensearch.timeseries.model.EntityProfileName; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.ProfileName; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.Sets; + +public abstract class BaseGetConfigTransportAction & TimeSeriesIndex, IndexManagementType extends IndexManagement, TaskManagerType extends TaskManager, ConfigType extends Config, EntityProfileActionType extends ActionType, EntityProfileRunnerType extends EntityProfileRunner, TaskProfileType extends TaskProfile, ConfigProfileType extends ConfigProfile, ProfileActionType extends ActionType, TaskProfileRunnerType extends TaskProfileRunner, ProfileRunnerType extends ProfileRunner> + extends HandledTransportAction { + + private static final Logger LOG = LogManager.getLogger(BaseGetConfigTransportAction.class); + + protected final ClusterService clusterService; + protected final Client client; + protected final SecurityClientUtil clientUtil; + protected final Set allProfileTypeStrs; + protected final Set allProfileTypes; + protected final Set defaultDetectorProfileTypes; + protected final Set allEntityProfileTypeStrs; + protected final Set allEntityProfileTypes; + protected final Set defaultEntityProfileTypes; + protected final NamedXContentRegistry xContentRegistry; + protected final DiscoveryNodeFilterer nodeFilter; + protected final TransportService transportService; + protected volatile Boolean filterByEnabled; + protected final TaskManagerType taskManager; + private final Class configTypeClass; + private final String configParseFieldName; + private final List allTaskTypes; + private final String singleStreamRealTimeTaskName; + private final String hcRealTImeTaskName; + private final String singleStreamHistoricalTaskname; + private final String hcHistoricalTaskName; + private final TaskProfileRunnerType taskProfileRunner; + + public BaseGetConfigTransportAction( + TransportService transportService, + DiscoveryNodeFilterer nodeFilter, + ActionFilters actionFilters, + ClusterService clusterService, + Client client, + SecurityClientUtil clientUtil, + Settings settings, + NamedXContentRegistry xContentRegistry, + TaskManagerType forecastTaskManager, + String getConfigAction, + Class configTypeClass, + String configParseFieldName, + List allTaskTypes, + String hcRealTImeTaskName, + String singleStreamRealTimeTaskName, + String hcHistoricalTaskName, + String singleStreamHistoricalTaskname, + Setting filterByBackendRoleEnableSetting, + TaskProfileRunnerType taskProfileRunner + ) { + super(getConfigAction, transportService, actionFilters, GetConfigRequest::new); + this.clusterService = clusterService; + this.client = client; + this.clientUtil = clientUtil; + + List allProfiles = Arrays.asList(ProfileName.values()); + this.allProfileTypes = EnumSet.copyOf(allProfiles); + this.allProfileTypeStrs = Name.getListStrs(allProfiles); + List defaultProfiles = Arrays.asList(ProfileName.ERROR, ProfileName.STATE); + this.defaultDetectorProfileTypes = new HashSet<>(defaultProfiles); + + List allEntityProfiles = Arrays.asList(EntityProfileName.values()); + this.allEntityProfileTypes = EnumSet.copyOf(allEntityProfiles); + this.allEntityProfileTypeStrs = Name.getListStrs(allEntityProfiles); + List defaultEntityProfiles = Arrays.asList(EntityProfileName.STATE); + this.defaultEntityProfileTypes = new HashSet<>(defaultEntityProfiles); + + this.xContentRegistry = xContentRegistry; + this.nodeFilter = nodeFilter; + filterByEnabled = filterByBackendRoleEnableSetting.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(filterByBackendRoleEnableSetting, it -> filterByEnabled = it); + this.transportService = transportService; + this.taskManager = forecastTaskManager; + this.configTypeClass = configTypeClass; + this.configParseFieldName = configParseFieldName; + this.allTaskTypes = allTaskTypes; + this.hcRealTImeTaskName = hcRealTImeTaskName; + this.singleStreamRealTimeTaskName = singleStreamRealTimeTaskName; + this.hcHistoricalTaskName = hcHistoricalTaskName; + this.singleStreamHistoricalTaskname = singleStreamHistoricalTaskname; + this.taskProfileRunner = taskProfileRunner; + } + + @Override + public void doExecute(Task task, GetConfigRequest request, ActionListener actionListener) { + String configID = request.getConfigID(); + User user = ParseUtils.getUserContext(client); + ActionListener listener = wrapRestActionListener(actionListener, FAIL_TO_GET_FORECASTER); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + resolveUserAndExecute( + user, + configID, + filterByEnabled, + listener, + (config) -> getExecute(request, listener), + client, + clusterService, + xContentRegistry, + configTypeClass + ); + } catch (Exception e) { + LOG.error(e); + listener.onFailure(e); + } + } + + protected void getConfigAndJob( + String configID, + boolean returnJob, + boolean returnTask, + Optional realtimeConfigTask, + Optional historicalConfigTask, + ActionListener listener + ) { + MultiGetRequest.Item configItem = new MultiGetRequest.Item(CommonName.CONFIG_INDEX, configID); + MultiGetRequest multiGetRequest = new MultiGetRequest().add(configItem); + if (returnJob) { + MultiGetRequest.Item adJobItem = new MultiGetRequest.Item(CommonName.JOB_INDEX, configID); + multiGetRequest.add(adJobItem); + } + client + .multiGet( + multiGetRequest, + onMultiGetResponse(listener, returnJob, returnTask, realtimeConfigTask, historicalConfigTask, configID) + ); + } + + public void getExecute(GetConfigRequest request, ActionListener listener) { + String configID = request.getConfigID(); + String typesStr = request.getTypeStr(); + String rawPath = request.getRawPath(); + Entity entity = request.getEntity(); + boolean all = request.isAll(); + boolean returnJob = request.isReturnJob(); + boolean returnTask = request.isReturnTask(); + + try { + if (!Strings.isEmpty(typesStr) || rawPath.endsWith(PROFILE) || rawPath.endsWith(PROFILE + "/")) { + getExecuteProfile(request, entity, typesStr, all, configID, listener); + } else { + if (returnTask) { + taskManager.getAndExecuteOnLatestTasks(configID, null, null, allTaskTypes, (taskList) -> { + Optional realtimeTask = Optional.empty(); + Optional historicalTask = Optional.empty(); + if (taskList != null && taskList.size() > 0) { + Map tasks = new HashMap<>(); + List duplicateTasks = new ArrayList<>(); + for (TaskClass task : taskList) { + if (tasks.containsKey(task.getTaskType())) { + LOG + .info( + "Found duplicate latest task of config {}, task id: {}, task type: {}", + configID, + task.getTaskType(), + task.getTaskId() + ); + duplicateTasks.add(task); + continue; + } + tasks.put(task.getTaskType(), task); + } + if (duplicateTasks.size() > 0) { + taskManager.resetLatestFlagAsFalse(duplicateTasks); + } + + if (tasks.containsKey(hcRealTImeTaskName)) { + realtimeTask = Optional.ofNullable(tasks.get(hcRealTImeTaskName)); + } else if (tasks.containsKey(singleStreamRealTimeTaskName)) { + realtimeTask = Optional.ofNullable(tasks.get(singleStreamRealTimeTaskName)); + } + if (tasks.containsKey(hcHistoricalTaskName)) { + historicalTask = Optional.ofNullable(tasks.get(hcHistoricalTaskName)); + } else if (tasks.containsKey(singleStreamHistoricalTaskname)) { + historicalTask = Optional.ofNullable(tasks.get(singleStreamHistoricalTaskname)); + } else { + // AD needs to provides custom behavior for bwc, while forecasting can inherit + // the empty implementation + fillInHistoricalTaskforBwc(tasks, historicalTask); + } + } + getConfigAndJob(configID, returnJob, returnTask, realtimeTask, historicalTask, listener); + }, transportService, true, 2, listener); + } else { + getConfigAndJob(configID, returnJob, returnTask, Optional.empty(), Optional.empty(), listener); + } + } + } catch (Exception e) { + LOG.error(e); + listener.onFailure(e); + } + } + + private ActionListener onMultiGetResponse( + ActionListener listener, + boolean returnJob, + boolean returnTask, + Optional realtimeTask, + Optional historicalTask, + String configId + ) { + return new ActionListener() { + @Override + public void onResponse(MultiGetResponse multiGetResponse) { + MultiGetItemResponse[] responses = multiGetResponse.getResponses(); + ConfigType config = null; + Job job = null; + String id = null; + long version = 0; + long seqNo = 0; + long primaryTerm = 0; + + for (MultiGetItemResponse response : responses) { + if (CommonName.CONFIG_INDEX.equals(response.getIndex())) { + if (response.getResponse() == null || !response.getResponse().isExists()) { + listener + .onFailure( + new OpenSearchStatusException(CommonMessages.FAIL_TO_FIND_CONFIG_MSG + configId, RestStatus.NOT_FOUND) + ); + return; + } + id = response.getId(); + version = response.getResponse().getVersion(); + primaryTerm = response.getResponse().getPrimaryTerm(); + seqNo = response.getResponse().getSeqNo(); + if (!response.getResponse().isSourceEmpty()) { + try ( + XContentParser parser = RestHandlerUtils + .createXContentParserFromRegistry(xContentRegistry, response.getResponse().getSourceAsBytesRef()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + config = parser.namedObject(configTypeClass, configParseFieldName, null); + } catch (Exception e) { + String message = "Failed to parse config " + configId; + listener.onFailure(buildInternalServerErrorResponse(e, message)); + return; + } + } + } else if (CommonName.JOB_INDEX.equals(response.getIndex())) { + if (response.getResponse() != null + && response.getResponse().isExists() + && !response.getResponse().isSourceEmpty()) { + try ( + XContentParser parser = RestHandlerUtils + .createXContentParserFromRegistry(xContentRegistry, response.getResponse().getSourceAsBytesRef()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + job = Job.parse(parser); + } catch (Exception e) { + String message = "Failed to parse job " + configId; + listener.onFailure(buildInternalServerErrorResponse(e, message)); + return; + } + } + } + } + listener + .onResponse( + createResponse( + version, + id, + primaryTerm, + seqNo, + config, + job, + returnJob, + realtimeTask, + historicalTask, + returnTask, + RestStatus.OK, + null, + null, + false + ) + ); + } + + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }; + } + + protected void fillInHistoricalTaskforBwc(Map tasks, Optional historicalAdTask) {} + + protected void getExecuteProfile( + GetConfigRequest request, + Entity entity, + String typesStr, + boolean all, + String configId, + ActionListener listener + ) { + if (entity != null) { + Set entityProfilesToCollect = getEntityProfilesToCollect(typesStr, all); + EntityProfileRunnerType profileRunner = createEntityProfileRunner( + client, + clientUtil, + xContentRegistry, + TimeSeriesSettings.NUM_MIN_SAMPLES + ); + profileRunner.profile(configId, entity, entityProfilesToCollect, ActionListener.wrap(profile -> { + listener + .onResponse( + createResponse( + 0, + null, + 0, + 0, + null, + null, + false, + Optional.empty(), + Optional.empty(), + false, + null, + null, + profile, + true + ) + ); + }, e -> listener.onFailure(e))); + } else { + Set profilesToCollect = getProfilesToCollect(typesStr, all); + ProfileRunnerType profileRunner = createProfileRunner( + client, + clientUtil, + xContentRegistry, + nodeFilter, + TimeSeriesSettings.NUM_MIN_SAMPLES, + transportService, + taskManager, + taskProfileRunner + ); + profileRunner.profile(configId, getProfileActionListener(listener), profilesToCollect); + } + + } + + protected abstract GetConfigResponseType createResponse( + long version, + String id, + long primaryTerm, + long seqNo, + ConfigType config, + Job job, + boolean returnJob, + Optional realtimeTask, + Optional historicalTask, + boolean returnTask, + RestStatus restStatus, + ConfigProfileType detectorProfile, + EntityProfile entityProfile, + boolean profileResponse + ); + + protected OpenSearchStatusException buildInternalServerErrorResponse(Exception e, String errorMsg) { + LOG.error(errorMsg, e); + return new OpenSearchStatusException(errorMsg, RestStatus.INTERNAL_SERVER_ERROR); + } + + /** + * + * @param typesStr a list of input profile types separated by comma + * @param all whether we should return all profile in the response + * @return profiles to collect for an entity + */ + protected Set getEntityProfilesToCollect(String typesStr, boolean all) { + if (all) { + return this.allEntityProfileTypes; + } else if (Strings.isEmpty(typesStr)) { + return this.defaultEntityProfileTypes; + } else { + // Filter out unsupported types + Set typesInRequest = new HashSet<>(Arrays.asList(typesStr.split(","))); + return EntityProfileName.getNames(Sets.intersection(allEntityProfileTypeStrs, typesInRequest)); + } + } + + /** + * + * @param typesStr a list of input profile types separated by comma + * @param all whether we should return all profile in the response + * @return profiles to collect for a detector + */ + protected Set getProfilesToCollect(String typesStr, boolean all) { + if (all) { + return this.allProfileTypes; + } else if (Strings.isEmpty(typesStr)) { + return this.defaultDetectorProfileTypes; + } else { + // Filter out unsupported types + Set typesInRequest = new HashSet<>(Arrays.asList(typesStr.split(","))); + return ProfileName.getNames(Sets.intersection(allProfileTypeStrs, typesInRequest)); + } + } + + protected ActionListener getProfileActionListener(ActionListener listener) { + return ActionListener.wrap(new CheckedConsumer() { + @Override + public void accept(ConfigProfileType profile) throws Exception { + listener + .onResponse( + createResponse( + 0, + null, + 0, + 0, + null, + null, + false, + Optional.empty(), + Optional.empty(), + false, + null, + profile, + null, + true + ) + ); + } + }, exception -> { listener.onFailure(exception); }); + } + + protected abstract EntityProfileRunnerType createEntityProfileRunner( + Client client, + SecurityClientUtil clientUtil, + NamedXContentRegistry xContentRegistry, + long requiredSamples + ); + + protected abstract ProfileRunnerType createProfileRunner( + Client client, + SecurityClientUtil clientUtil, + NamedXContentRegistry xContentRegistry, + DiscoveryNodeFilterer nodeFilter, + long requiredSamples, + TransportService transportService, + TaskManagerType taskManager, + TaskProfileRunnerType taskProfileRunner + ); +} diff --git a/src/main/java/org/opensearch/timeseries/transport/BaseJobTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseJobTransportAction.java new file mode 100644 index 000000000..99f4a69b3 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/BaseJobTransportAction.java @@ -0,0 +1,133 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.transport; + +import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute; +import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionType; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.ExecuteResultResponseRecorder; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.rest.handler.IndexJobActionHandler; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportService; + +public abstract class BaseJobTransportAction & TimeSeriesIndex, IndexManagementType extends IndexManagement, TaskCacheManagerType extends TaskCacheManager, TaskTypeEnum extends TaskType, TaskClass extends TimeSeriesTask, TaskManagerType extends TaskManager, IndexableResultType extends IndexableResult, ProfileActionType extends ActionType, ExecuteResultResponseRecorderType extends ExecuteResultResponseRecorder, IndexJobActionHandlerType extends IndexJobActionHandler> + extends HandledTransportAction { + private final Logger logger = LogManager.getLogger(BaseJobTransportAction.class); + + private final Client client; + private final ClusterService clusterService; + private final Settings settings; + private final NamedXContentRegistry xContentRegistry; + private volatile Boolean filterByEnabled; + private final TransportService transportService; + private final Setting requestTimeOutSetting; + private final String failtoStartMsg; + private final String failtoStopMsg; + private final Class configClass; + private final IndexJobActionHandlerType indexJobActionHandlerType; + + public BaseJobTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ClusterService clusterService, + Settings settings, + NamedXContentRegistry xContentRegistry, + Setting filterByBackendRoleSettng, + String jobActionName, + Setting requestTimeOutSetting, + String failtoStartMsg, + String failtoStopMsg, + Class configClass, + IndexJobActionHandlerType indexJobActionHandlerType + ) { + super(jobActionName, transportService, actionFilters, JobRequest::new); + this.transportService = transportService; + this.client = client; + this.clusterService = clusterService; + this.settings = settings; + this.xContentRegistry = xContentRegistry; + filterByEnabled = filterByBackendRoleSettng.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(filterByBackendRoleSettng, it -> filterByEnabled = it); + this.requestTimeOutSetting = requestTimeOutSetting; + this.failtoStartMsg = failtoStartMsg; + this.failtoStopMsg = failtoStopMsg; + this.configClass = configClass; + this.indexJobActionHandlerType = indexJobActionHandlerType; + } + + @Override + protected void doExecute(Task task, JobRequest request, ActionListener actionListener) { + String configId = request.getConfigID(); + DateRange dateRange = request.getDateRange(); + boolean historical = request.isHistorical(); + String rawPath = request.getRawPath(); + TimeValue requestTimeout = requestTimeOutSetting.get(settings); + String errorMessage = rawPath.endsWith(RestHandlerUtils.START_JOB) ? failtoStartMsg : failtoStopMsg; + ActionListener listener = wrapRestActionListener(actionListener, errorMessage); + + // By the time request reaches here, the user permissions are validated by Security plugin. + User user = ParseUtils.getUserContext(client); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + resolveUserAndExecute( + user, + configId, + filterByEnabled, + listener, + (config) -> executeConfig(listener, configId, dateRange, historical, rawPath, requestTimeout, user, context), + client, + clusterService, + xContentRegistry, + configClass + ); + } catch (Exception e) { + logger.error(e); + listener.onFailure(e); + } + } + + private void executeConfig( + ActionListener listener, + String configId, + DateRange dateRange, + boolean historical, + String rawPath, + TimeValue requestTimeout, + User user, + ThreadContext.StoredContext context + ) { + if (rawPath.endsWith(RestHandlerUtils.START_JOB)) { + indexJobActionHandlerType.startConfig(configId, dateRange, user, transportService, context, listener); + } else if (rawPath.endsWith(RestHandlerUtils.STOP_JOB)) { + indexJobActionHandlerType.stopConfig(configId, historical, user, transportService, listener); + } + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ProfileTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseProfileTransportAction.java similarity index 50% rename from src/main/java/org/opensearch/ad/transport/ProfileTransportAction.java rename to src/main/java/org/opensearch/timeseries/transport/BaseProfileTransportAction.java index af1bbed50..398e03994 100644 --- a/src/main/java/org/opensearch/ad/transport/ProfileTransportAction.java +++ b/src/main/java/org/opensearch/timeseries/transport/BaseProfileTransportAction.java @@ -9,9 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; - -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_MODEL_SIZE_PER_NODE; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.List; @@ -23,26 +21,26 @@ import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.nodes.TransportNodesAction; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.model.DetectorProfileName; -import org.opensearch.ad.model.ModelProfile; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.caching.CacheProvider; +import org.opensearch.timeseries.caching.TimeSeriesCache; +import org.opensearch.timeseries.model.ModelProfile; +import org.opensearch.timeseries.model.ProfileName; import org.opensearch.transport.TransportService; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + /** * This class contains the logic to extract the stats from the nodes */ -public class ProfileTransportAction extends TransportNodesAction { - private static final Logger LOG = LogManager.getLogger(ProfileTransportAction.class); - private ModelManager modelManager; - private FeatureManager featureManager; - private CacheProvider cacheProvider; +public class BaseProfileTransportAction, CacheProviderType extends CacheProvider> + extends TransportNodesAction { + private static final Logger LOG = LogManager.getLogger(BaseProfileTransportAction.class); + private CacheProviderType cacheProvider; // the number of models to return. Defaults to 10. private volatile int numModelsToReturn; @@ -53,24 +51,22 @@ public class ProfileTransportAction extends TransportNodesAction maxModelNumberPerNode ) { super( - ProfileAction.NAME, + profileAction, threadPool, clusterService, transportService, @@ -80,11 +76,9 @@ public ProfileTransportAction( ThreadPool.Names.MANAGEMENT, ProfileNodeResponse.class ); - this.modelManager = modelManager; - this.featureManager = featureManager; this.cacheProvider = cacheProvider; - this.numModelsToReturn = AD_MAX_MODEL_SIZE_PER_NODE.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_MAX_MODEL_SIZE_PER_NODE, it -> this.numModelsToReturn = it); + this.numModelsToReturn = maxModelNumberPerNode.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(maxModelNumberPerNode, it -> this.numModelsToReturn = it); } @Override @@ -104,41 +98,33 @@ protected ProfileNodeResponse newNodeResponse(StreamInput in) throws IOException @Override protected ProfileNodeResponse nodeOperation(ProfileNodeRequest request) { - String detectorId = request.getId(); - Set profiles = request.getProfilesToBeRetrieved(); + String configId = request.getConfigId(); + Set profiles = request.getProfilesToBeRetrieved(); int shingleSize = -1; long activeEntity = 0; long totalUpdates = 0; Map modelSize = null; List modelProfiles = null; int modelCount = 0; - if (request.isForMultiEntityDetector()) { - if (profiles.contains(DetectorProfileName.ACTIVE_ENTITIES)) { - activeEntity = cacheProvider.get().getActiveEntities(detectorId); - } - if (profiles.contains(DetectorProfileName.INIT_PROGRESS)) { - totalUpdates = cacheProvider.get().getTotalUpdates(detectorId);// get toal updates - } - if (profiles.contains(DetectorProfileName.TOTAL_SIZE_IN_BYTES)) { - modelSize = cacheProvider.get().getModelSize(detectorId); - } - // need to provide entity info for HCAD - if (profiles.contains(DetectorProfileName.MODELS)) { - modelProfiles = cacheProvider.get().getAllModelProfile(detectorId); - modelCount = modelProfiles.size(); - int limit = Math.min(numModelsToReturn, modelCount); - if (limit != modelCount) { - LOG.info("model number limit reached"); - modelProfiles = modelProfiles.subList(0, limit); - } - } - } else { - if (profiles.contains(DetectorProfileName.COORDINATING_NODE) || profiles.contains(DetectorProfileName.SHINGLE_SIZE)) { - shingleSize = featureManager.getShingleSize(detectorId); - } + if (profiles.contains(ProfileName.ACTIVE_ENTITIES)) { + activeEntity = cacheProvider.get().getActiveEntities(configId); + } - if (profiles.contains(DetectorProfileName.TOTAL_SIZE_IN_BYTES) || profiles.contains(DetectorProfileName.MODELS)) { - modelSize = modelManager.getModelSize(detectorId); + // state profile requires totalUpdates as well + if (profiles.contains(ProfileName.INIT_PROGRESS) || profiles.contains(ProfileName.STATE)) { + totalUpdates = cacheProvider.get().getTotalUpdates(configId);// get toal updates + } + if (profiles.contains(ProfileName.TOTAL_SIZE_IN_BYTES)) { + modelSize = cacheProvider.get().getModelSize(configId); + } + // need to provide entity info for HCAD + if (profiles.contains(ProfileName.MODELS)) { + modelProfiles = cacheProvider.get().getAllModelProfile(configId); + modelCount = modelProfiles.size(); + int limit = Math.min(numModelsToReturn, modelCount); + if (limit != modelCount) { + LOG.info("model number limit reached"); + modelProfiles = modelProfiles.subList(0, limit); } } diff --git a/src/main/java/org/opensearch/timeseries/transport/BaseSearchConfigInfoTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseSearchConfigInfoTransportAction.java new file mode 100644 index 000000000..536bb1466 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/BaseSearchConfigInfoTransportAction.java @@ -0,0 +1,111 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.transport; + +import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermsQueryBuilder; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportService; + +public abstract class BaseSearchConfigInfoTransportAction extends + HandledTransportAction { + private static final Logger LOG = LogManager.getLogger(BaseSearchConfigInfoTransportAction.class); + private final Client client; + + public BaseSearchConfigInfoTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + String searchConfigActionName + ) { + super(searchConfigActionName, transportService, actionFilters, SearchConfigInfoRequest::new); + this.client = client; + } + + @Override + protected void doExecute(Task task, SearchConfigInfoRequest request, ActionListener actionListener) { + String name = request.getName(); + String rawPath = request.getRawPath(); + ActionListener listener = wrapRestActionListener(actionListener, CommonMessages.FAIL_TO_GET_CONFIG_INFO); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + SearchRequest searchRequest = new SearchRequest().indices(CommonName.CONFIG_INDEX); + if (rawPath.endsWith(RestHandlerUtils.COUNT)) { + // Count detectors + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchRequest.source(searchSourceBuilder); + client.search(searchRequest, new ActionListener() { + + @Override + public void onResponse(SearchResponse searchResponse) { + SearchConfigInfoResponse response = new SearchConfigInfoResponse( + searchResponse.getHits().getTotalHits().value, + false + ); + listener.onResponse(response); + } + + @Override + public void onFailure(Exception e) { + if (e.getClass() == IndexNotFoundException.class) { + // Anomaly Detectors index does not exist + // Could be that user is creating first detector + SearchConfigInfoResponse response = new SearchConfigInfoResponse(0, false); + listener.onResponse(response); + } else { + listener.onFailure(e); + } + } + }); + } else { + // Match name with existing detectors + TermsQueryBuilder query = QueryBuilders.termsQuery("name.keyword", name); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query); + searchRequest.source(searchSourceBuilder); + client.search(searchRequest, new ActionListener() { + + @Override + public void onResponse(SearchResponse searchResponse) { + boolean nameExists = false; + nameExists = searchResponse.getHits().getTotalHits().value > 0; + SearchConfigInfoResponse response = new SearchConfigInfoResponse(0, nameExists); + listener.onResponse(response); + } + + @Override + public void onFailure(Exception e) { + if (e.getClass() == IndexNotFoundException.class) { + // Anomaly Detectors index does not exist + // Could be that user is creating first detector + SearchConfigInfoResponse response = new SearchConfigInfoResponse(0, false); + listener.onResponse(response); + } else { + listener.onFailure(e); + } + } + }); + } + } catch (Exception e) { + LOG.error(e); + listener.onFailure(e); + } + } +} diff --git a/src/main/java/org/opensearch/timeseries/transport/BaseStatsNodesTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseStatsNodesTransportAction.java new file mode 100644 index 000000000..ac43684ba --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/BaseStatsNodesTransportAction.java @@ -0,0 +1,92 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.transport; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.nodes.TransportNodesAction; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.stats.Stats; +import org.opensearch.transport.TransportService; + +public class BaseStatsNodesTransportAction extends + TransportNodesAction { + + private Stats stats; + + /** + * Constructor + * + * @param threadPool ThreadPool to use + * @param clusterService ClusterService + * @param transportService TransportService + * @param actionFilters Action Filters + * @param stats TimeSeriesStats object + */ + public BaseStatsNodesTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + Stats stats, + String statsNodesActionName + ) { + super( + statsNodesActionName, + threadPool, + clusterService, + transportService, + actionFilters, + StatsRequest::new, + StatsNodeRequest::new, + ThreadPool.Names.MANAGEMENT, + StatsNodeResponse.class + ); + this.stats = stats; + } + + @Override + protected StatsNodesResponse newResponse(StatsRequest request, List responses, List failures) { + return new StatsNodesResponse(clusterService.getClusterName(), responses, failures); + } + + @Override + protected StatsNodeRequest newNodeRequest(StatsRequest request) { + return new StatsNodeRequest(request); + } + + @Override + protected StatsNodeResponse newNodeResponse(StreamInput in) throws IOException { + return new StatsNodeResponse(in); + } + + @Override + protected StatsNodeResponse nodeOperation(StatsNodeRequest request) { + return createADStatsNodeResponse(request.getADStatsRequest()); + } + + protected StatsNodeResponse createADStatsNodeResponse(StatsRequest statsRequest) { + Map statValues = new HashMap<>(); + Set statsToBeRetrieved = statsRequest.getStatsToBeRetrieved(); + + for (String statName : stats.getNodeStats().keySet()) { + if (statsToBeRetrieved.contains(statName)) { + statValues.put(statName, stats.getStats().get(statName).getValue()); + } + } + + return new StatsNodeResponse(clusterService.localNode(), statValues); + } + +} diff --git a/src/main/java/org/opensearch/timeseries/transport/BaseStatsTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseStatsTransportAction.java new file mode 100644 index 000000000..72c533e2e --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/BaseStatsTransportAction.java @@ -0,0 +1,126 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.transport; + +import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; + +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.stats.Stats; +import org.opensearch.timeseries.util.MultiResponsesDelegateActionListener; +import org.opensearch.transport.TransportService; + +public abstract class BaseStatsTransportAction extends HandledTransportAction { + public final Logger logger = LogManager.getLogger(BaseStatsTransportAction.class); + + protected final Client client; + protected final Stats stats; + protected final ClusterService clusterService; + + public BaseStatsTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + Stats stats, + ClusterService clusterService, + String statsAction + + ) { + super(statsAction, transportService, actionFilters, StatsRequest::new); + this.client = client; + this.stats = stats; + this.clusterService = clusterService; + } + + @Override + protected void doExecute(Task task, StatsRequest request, ActionListener actionListener) { + ActionListener listener = wrapRestActionListener(actionListener, CommonMessages.FAIL_TO_GET_STATS); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + getStats(client, listener, request); + } catch (Exception e) { + logger.error(e); + listener.onFailure(e); + } + } + + /** + * Make the 2 requests to get the node and cluster statistics + * + * @param client Client + * @param listener Listener to send response + * @param statsRequest Request containing stats to be retrieved + */ + public void getStats(Client client, ActionListener listener, StatsRequest statsRequest) { + // Use MultiResponsesDelegateActionListener to execute 2 async requests and create the response once they finish + MultiResponsesDelegateActionListener delegateListener = new MultiResponsesDelegateActionListener<>( + getRestStatsListener(listener), + 2, + "Unable to return Stats", + false + ); + + getClusterStats(client, delegateListener, statsRequest); + getNodeStats(client, delegateListener, statsRequest); + } + + /** + * Listener sends response once Node Stats and Cluster Stats are gathered + * + * @param listener Listener to send response + * @return ActionListener for StatsResponse + */ + public ActionListener getRestStatsListener(ActionListener listener) { + return ActionListener + .wrap( + statsResponse -> { listener.onResponse(new StatsTimeSeriesResponse(statsResponse)); }, + exception -> listener.onFailure(new OpenSearchStatusException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)) + ); + } + + /** + * Collect Cluster Stats into map to be retrieved + * + * @param statsRequest Request containing stats to be retrieved + * @return Map containing Cluster Stats + */ + protected Map getClusterStatsMap(StatsRequest statsRequest) { + Map clusterStats = new HashMap<>(); + Set statsToBeRetrieved = statsRequest.getStatsToBeRetrieved(); + stats + .getClusterStats() + .entrySet() + .stream() + .filter(s -> statsToBeRetrieved.contains(s.getKey())) + .forEach(s -> clusterStats.put(s.getKey(), s.getValue().getValue())); + return clusterStats; + } + + protected abstract void getClusterStats( + Client client, + MultiResponsesDelegateActionListener listener, + StatsRequest adStatsRequest + ); + + protected abstract void getNodeStats( + Client client, + MultiResponsesDelegateActionListener listener, + StatsRequest adStatsRequest + ); +} diff --git a/src/main/java/org/opensearch/timeseries/transport/BaseSuggestConfigParamTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseSuggestConfigParamTransportAction.java new file mode 100644 index 000000000..3bc2decb1 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/BaseSuggestConfigParamTransportAction.java @@ -0,0 +1,167 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.transport; + +import static org.opensearch.timeseries.util.ParseUtils.checkFilterByBackendRoles; + +import java.time.Clock; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.forecast.transport.SuggestName; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.Name; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.rest.handler.IntervalCalculation; +import org.opensearch.timeseries.rest.handler.LatestTimeRetriever; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.Sets; + +public abstract class BaseSuggestConfigParamTransportAction extends + HandledTransportAction { + public static final Logger logger = LogManager.getLogger(BaseSuggestConfigParamTransportAction.class); + + protected final Client client; + protected final SecurityClientUtil clientUtil; + protected final SearchFeatureDao searchFeatureDao; + protected volatile Boolean filterByEnabled; + protected Clock clock; + protected AnalysisType context; + protected final Set allSuggestParamStrs; + + public BaseSuggestConfigParamTransportAction( + String actionName, + Client client, + SecurityClientUtil clientUtil, + ClusterService clusterService, + Settings settings, + ActionFilters actionFilters, + TransportService transportService, + Setting filterByBackendRoleSetting, + AnalysisType context, + SearchFeatureDao searchFeatureDao + ) { + super(actionName, transportService, actionFilters, SuggestConfigParamRequest::new); + this.client = client; + this.clientUtil = clientUtil; + this.filterByEnabled = filterByBackendRoleSetting.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(filterByBackendRoleSetting, it -> filterByEnabled = it); + this.clock = Clock.systemUTC(); + this.context = context; + this.searchFeatureDao = searchFeatureDao; + List allSuggestParams = Arrays.asList(SuggestName.values()); + this.allSuggestParamStrs = Name.getListStrs(allSuggestParams); + } + + @Override + protected void doExecute(Task task, SuggestConfigParamRequest request, ActionListener listener) { + User user = ParseUtils.getUserContext(client); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + resolveUserAndExecute(user, listener, () -> suggestExecute(request, user, context, listener)); + } catch (Exception e) { + logger.error(e); + listener.onFailure(e); + } + } + + public void resolveUserAndExecute(User requestedUser, ActionListener listener, ExecutorFunction function) { + try { + // Check if user has backend roles + // When filter by is enabled, block users who do not have backend roles. + if (filterByEnabled) { + String error = checkFilterByBackendRoles(requestedUser); + if (error != null) { + listener.onFailure(new TimeSeriesException(error)); + return; + } + } + // Validate analysis + function.execute(); + } catch (Exception e) { + listener.onFailure(e); + } + } + + protected void suggestInterval(Config config, User user, TimeValue timeout, ActionListener listener) { + IntervalCalculation intervalCalculation = new IntervalCalculation(config, timeout, client, clientUtil, user, context, clock); + LatestTimeRetriever latestTimeRetriever = new LatestTimeRetriever( + config, + timeout, + clientUtil, + client, + user, + context, + searchFeatureDao + ); + + ActionListener intervalSuggestionListener = ActionListener + .wrap( + interval -> listener.onResponse(new SuggestConfigParamResponse.Builder().interval(interval).build()), + listener::onFailure + ); + ActionListener, Map>> latestTimeListener = ActionListener.wrap(latestEntityAttributes -> { + Optional latestTime = latestEntityAttributes.getLeft(); + if (latestTime.isPresent()) { + intervalCalculation.findInterval(latestTime.get(), latestEntityAttributes.getRight(), intervalSuggestionListener); + } else { + listener.onFailure(new TimeSeriesException("Empty data. Cannot find a good interval.")); + } + + }, exception -> { + listener.onFailure(exception); + logger.error("Failed to create search request for last data point", exception); + }); + + latestTimeRetriever.checkIfHC(latestTimeListener); + } + + protected void suggestHistory(Config config, ActionListener listener) { + listener.onResponse(new SuggestConfigParamResponse.Builder().history(config.suggestHistory()).build()); + } + + public abstract void suggestExecute( + SuggestConfigParamRequest request, + User user, + ThreadContext.StoredContext storedContext, + ActionListener listener + ); + + /** + * + * @param typesStr a list of input suggest types separated by comma + * @return parameters to suggest for a forecaster + */ + protected Set getParametersToSuggest(String typesStr) { + // Filter out unsupported params + Set typesInRequest = new HashSet<>(Arrays.asList(typesStr.split(","))); + return SuggestName.getNames(Sets.intersection(allSuggestParamStrs, typesInRequest)); + } +} diff --git a/src/main/java/org/opensearch/timeseries/transport/BaseValidateConfigTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseValidateConfigTransportAction.java new file mode 100644 index 000000000..5b21a750b --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/BaseValidateConfigTransportAction.java @@ -0,0 +1,228 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.transport; + +import static org.opensearch.timeseries.util.ParseUtils.checkFilterByBackendRoles; + +import java.time.Clock; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.ConfigValidationIssue; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.timeseries.rest.handler.Processor; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +public abstract class BaseValidateConfigTransportAction & TimeSeriesIndex, IndexManagementType extends IndexManagement> + extends HandledTransportAction { + public static final Logger logger = LogManager.getLogger(BaseValidateConfigTransportAction.class); + + protected final Client client; + protected final SecurityClientUtil clientUtil; + protected final ClusterService clusterService; + protected final NamedXContentRegistry xContentRegistry; + protected final IndexManagementType indexManagement; + protected final SearchFeatureDao searchFeatureDao; + protected volatile Boolean filterByEnabled; + protected Clock clock; + protected Settings settings; + + public BaseValidateConfigTransportAction( + String actionName, + Client client, + SecurityClientUtil clientUtil, + ClusterService clusterService, + NamedXContentRegistry xContentRegistry, + Settings settings, + IndexManagementType indexManagement, + ActionFilters actionFilters, + TransportService transportService, + SearchFeatureDao searchFeatureDao, + Setting filterByBackendRoleSetting + ) { + super(actionName, transportService, actionFilters, ValidateConfigRequest::new); + this.client = client; + this.clientUtil = clientUtil; + this.clusterService = clusterService; + this.xContentRegistry = xContentRegistry; + this.indexManagement = indexManagement; + this.filterByEnabled = filterByBackendRoleSetting.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(filterByBackendRoleSetting, it -> filterByEnabled = it); + this.searchFeatureDao = searchFeatureDao; + this.clock = Clock.systemUTC(); + this.settings = settings; + } + + @Override + protected void doExecute(Task task, ValidateConfigRequest request, ActionListener listener) { + User user = ParseUtils.getUserContext(client); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + resolveUserAndExecute(user, listener, () -> validateExecute(request, user, context, listener)); + } catch (Exception e) { + logger.error(e); + listener.onFailure(e); + } + } + + public void resolveUserAndExecute(User requestedUser, ActionListener listener, ExecutorFunction function) { + try { + // Check if user has backend roles + // When filter by is enabled, block users validating detectors who do not have backend roles. + if (filterByEnabled) { + String error = checkFilterByBackendRoles(requestedUser); + if (error != null) { + listener.onFailure(new TimeSeriesException(error)); + return; + } + } + // Validate analysis + function.execute(); + } catch (Exception e) { + listener.onFailure(e); + } + } + + protected void checkIndicesAndExecute( + List indices, + ExecutorFunction function, + ActionListener listener + ) { + SearchRequest searchRequest = new SearchRequest() + .indices(indices.toArray(new String[0])) + .source(new SearchSourceBuilder().size(1).query(QueryBuilders.matchAllQuery())); + client.search(searchRequest, ActionListener.wrap(r -> function.execute(), e -> { + if (e instanceof IndexNotFoundException) { + // IndexNotFoundException is converted to a ADValidationException that gets + // parsed to a DetectorValidationIssue that is returned to + // the user as a response indicating index doesn't exist + ConfigValidationIssue issue = parseValidationException( + new ValidationException(ADCommonMessages.INDEX_NOT_FOUND, ValidationIssueType.INDICES, ValidationAspect.DETECTOR) + ); + listener.onResponse(new ValidateConfigResponse(issue)); + return; + } + logger.error(e); + listener.onFailure(e); + })); + } + + protected Map getFeatureSubIssuesFromErrorMessage(String errorMessage) { + Map result = new HashMap<>(); + String[] subIssueMessagesSuffix = errorMessage.split(", "); + for (int i = 0; i < subIssueMessagesSuffix.length; i++) { + result.put(subIssueMessagesSuffix[i].split(": ")[1], subIssueMessagesSuffix[i].split(": ")[0]); + } + return result; + } + + public ConfigValidationIssue parseValidationException(ValidationException exception) { + String originalErrorMessage = exception.getMessage(); + String errorMessage = ""; + Map subIssues = null; + IntervalTimeConfiguration intervalSuggestion = exception.getIntervalSuggestion(); + switch (exception.getType()) { + case FEATURE_ATTRIBUTES: + int firstLeftBracketIndex = originalErrorMessage.indexOf("["); + int lastRightBracketIndex = originalErrorMessage.lastIndexOf("]"); + if (firstLeftBracketIndex != -1) { + // if feature issue messages are between square brackets like + // [Feature has issue: A, Feature has issue: B] + errorMessage = originalErrorMessage.substring(firstLeftBracketIndex + 1, lastRightBracketIndex); + subIssues = getFeatureSubIssuesFromErrorMessage(errorMessage); + } else { + // features having issue like over max feature limit, duplicate feature name, etc. + errorMessage = originalErrorMessage; + } + break; + case NAME: + case CATEGORY: + case DETECTION_INTERVAL: + case FILTER_QUERY: + case TIMEFIELD_FIELD: + case SHINGLE_SIZE_FIELD: + case WINDOW_DELAY: + case RESULT_INDEX: + case GENERAL_SETTINGS: + case AGGREGATION: + case TIMEOUT: + case INDICES: + case FORECAST_INTERVAL: + case IMPUTATION: + case HORIZON_SIZE: + case RECENCY_EMPHASIS: + errorMessage = originalErrorMessage; + break; + } + return new ConfigValidationIssue(exception.getAspect(), exception.getType(), errorMessage, subIssues, intervalSuggestion); + } + + public void validateExecute( + ValidateConfigRequest request, + User user, + ThreadContext.StoredContext storedContext, + ActionListener listener + ) { + storedContext.restore(); + Config detector = request.getConfig(); + ActionListener validateListener = ActionListener.wrap(response -> { + logger.debug("Result of validation process " + response); + // forcing response to be empty + listener.onResponse(new ValidateConfigResponse((ConfigValidationIssue) null)); + }, exception -> { + if (exception instanceof ValidationException) { + // ADValidationException is converted as validation issues returned as response to user + ConfigValidationIssue issue = parseValidationException((ValidationException) exception); + listener.onResponse(new ValidateConfigResponse(issue)); + return; + } + logger.error(exception); + listener.onFailure(exception); + }); + checkIndicesAndExecute(detector.getIndices(), () -> { + try { + createProcessor(detector, request, user).start(validateListener); + } catch (Exception exception) { + String errorMessage = String + .format(Locale.ROOT, "Unknown exception caught while validating detector %s", request.getConfig()); + logger.error(errorMessage, exception); + listener.onFailure(exception); + } + }, listener); + } + + protected abstract Processor createProcessor(Config detector, ValidateConfigRequest request, User user); +} diff --git a/src/main/java/org/opensearch/timeseries/transport/BooleanNodeResponse.java b/src/main/java/org/opensearch/timeseries/transport/BooleanNodeResponse.java new file mode 100644 index 000000000..c6b4f1285 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/BooleanNodeResponse.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.transport; + +import java.io.IOException; + +import org.opensearch.action.support.nodes.BaseNodeResponse; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class BooleanNodeResponse extends BaseNodeResponse { + private final boolean answer; + + public BooleanNodeResponse(StreamInput in) throws IOException { + super(in); + answer = in.readBoolean(); + } + + public BooleanNodeResponse(DiscoveryNode node, boolean answer) { + super(node); + this.answer = answer; + } + + public boolean isAnswerTrue() { + return answer; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeBoolean(answer); + } +} diff --git a/src/main/java/org/opensearch/timeseries/transport/BooleanResponse.java b/src/main/java/org/opensearch/timeseries/transport/BooleanResponse.java new file mode 100644 index 000000000..8eb18475a --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/BooleanResponse.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.transport; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.nodes.BaseNodesResponse; +import org.opensearch.cluster.ClusterName; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentFragment; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.constant.CommonName; + +public class BooleanResponse extends BaseNodesResponse implements ToXContentFragment { + private final boolean answer; + + public BooleanResponse(StreamInput in) throws IOException { + super(in); + answer = in.readBoolean(); + } + + public BooleanResponse(ClusterName clusterName, List nodes, List failures) { + super(clusterName, nodes, failures); + this.answer = nodes.stream().anyMatch(response -> response.isAnswerTrue()); + ; + } + + public boolean isAnswerTrue() { + return answer; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeBoolean(answer); + } + + @Override + protected List readNodesFrom(StreamInput in) throws IOException { + return in.readList(BooleanNodeResponse::new); + } + + @Override + protected void writeNodesTo(StreamOutput out, List nodes) throws IOException { + out.writeList(nodes); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(CommonName.ANSWER_FIELD, answer); + return builder; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/CronNodeRequest.java b/src/main/java/org/opensearch/timeseries/transport/CronNodeRequest.java similarity index 93% rename from src/main/java/org/opensearch/ad/transport/CronNodeRequest.java rename to src/main/java/org/opensearch/timeseries/transport/CronNodeRequest.java index a5362ff46..aef33bb3c 100644 --- a/src/main/java/org/opensearch/ad/transport/CronNodeRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/CronNodeRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; diff --git a/src/main/java/org/opensearch/ad/transport/CronNodeResponse.java b/src/main/java/org/opensearch/timeseries/transport/CronNodeResponse.java similarity index 93% rename from src/main/java/org/opensearch/ad/transport/CronNodeResponse.java rename to src/main/java/org/opensearch/timeseries/transport/CronNodeResponse.java index f1e9fb0e1..b83e049d3 100644 --- a/src/main/java/org/opensearch/ad/transport/CronNodeResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/CronNodeResponse.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; @@ -20,7 +20,7 @@ import org.opensearch.core.xcontent.XContentBuilder; public class CronNodeResponse extends BaseNodeResponse implements ToXContentObject { - static String NODE_ID = "node_id"; + public static String NODE_ID = "node_id"; public CronNodeResponse(StreamInput in) throws IOException { super(in); diff --git a/src/main/java/org/opensearch/ad/transport/CronRequest.java b/src/main/java/org/opensearch/timeseries/transport/CronRequest.java similarity index 95% rename from src/main/java/org/opensearch/ad/transport/CronRequest.java rename to src/main/java/org/opensearch/timeseries/transport/CronRequest.java index 0f91ae676..9f1add649 100644 --- a/src/main/java/org/opensearch/ad/transport/CronRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/CronRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; diff --git a/src/main/java/org/opensearch/ad/transport/CronResponse.java b/src/main/java/org/opensearch/timeseries/transport/CronResponse.java similarity index 94% rename from src/main/java/org/opensearch/ad/transport/CronResponse.java rename to src/main/java/org/opensearch/timeseries/transport/CronResponse.java index 13332c3af..56998f2cf 100644 --- a/src/main/java/org/opensearch/ad/transport/CronResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/CronResponse.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.List; @@ -23,7 +23,7 @@ import org.opensearch.core.xcontent.XContentBuilder; public class CronResponse extends BaseNodesResponse implements ToXContentFragment { - static String NODES_JSON_KEY = "nodes"; + public static String NODES_JSON_KEY = "nodes"; public CronResponse(StreamInput in) throws IOException { super(in); diff --git a/src/main/java/org/opensearch/ad/transport/CronTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/CronTransportAction.java similarity index 62% rename from src/main/java/org/opensearch/ad/transport/CronTransportAction.java rename to src/main/java/org/opensearch/timeseries/transport/CronTransportAction.java index 82075d035..fb8912703 100644 --- a/src/main/java/org/opensearch/ad/transport/CronTransportAction.java +++ b/src/main/java/org/opensearch/timeseries/transport/CronTransportAction.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.List; @@ -19,27 +19,34 @@ import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.nodes.TransportNodesAction; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.ml.EntityColdStarter; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.ml.ADColdStart; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.CronAction; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.forecast.caching.ForecastCacheProvider; +import org.opensearch.forecast.ml.ForecastColdStart; +import org.opensearch.forecast.task.ForecastTaskManager; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.feature.FeatureManager; import org.opensearch.transport.TransportService; public class CronTransportAction extends TransportNodesAction { private final Logger LOG = LogManager.getLogger(CronTransportAction.class); private NodeStateManager transportStateManager; - private ModelManager modelManager; + private ADModelManager adModelManager; private FeatureManager featureManager; - private CacheProvider cacheProvider; - private EntityColdStarter entityColdStarter; + private ADCacheProvider adCacheProvider; + private ForecastCacheProvider forecastCacheProvider; + private ADColdStart adEntityColdStarter; + private ForecastColdStart forecastColdStarter; private ADTaskManager adTaskManager; + private ForecastTaskManager forecastTaskManager; @Inject public CronTransportAction( @@ -48,11 +55,14 @@ public CronTransportAction( TransportService transportService, ActionFilters actionFilters, NodeStateManager tarnsportStatemanager, - ModelManager modelManager, + ADModelManager adModelManager, FeatureManager featureManager, - CacheProvider cacheProvider, - EntityColdStarter entityColdStarter, - ADTaskManager adTaskManager + ADCacheProvider adCacheProvider, + ForecastCacheProvider forecastCacheProvider, + ADColdStart adEntityColdStarter, + ForecastColdStart forecastColdStarter, + ADTaskManager adTaskManager, + ForecastTaskManager forecastTaskManager ) { super( CronAction.NAME, @@ -66,11 +76,14 @@ public CronTransportAction( CronNodeResponse.class ); this.transportStateManager = tarnsportStatemanager; - this.modelManager = modelManager; + this.adModelManager = adModelManager; this.featureManager = featureManager; - this.cacheProvider = cacheProvider; - this.entityColdStarter = entityColdStarter; + this.adCacheProvider = adCacheProvider; + this.forecastCacheProvider = forecastCacheProvider; + this.adEntityColdStarter = adEntityColdStarter; + this.forecastColdStarter = forecastColdStarter; this.adTaskManager = adTaskManager; + this.forecastTaskManager = forecastTaskManager; } @Override @@ -97,27 +110,27 @@ protected CronNodeResponse newNodeResponse(StreamInput in) throws IOException { */ @Override protected CronNodeResponse nodeOperation(CronNodeRequest request) { - LOG.info("Start running AD hourly cron."); + LOG.info("Start running hourly cron."); + // ====================== + // AD + // ====================== // makes checkpoints for hosted models and stop hosting models not actively // used. // for single-entity detector - modelManager - .maintenance(ActionListener.wrap(v -> LOG.debug("model maintenance done"), e -> LOG.error("Error maintaining model", e))); + adModelManager + .maintenance(ActionListener.wrap(v -> LOG.debug("model maintenance done"), e -> LOG.error("Error maintaining ad model", e))); // for multi-entity detector - cacheProvider.get().maintenance(); + adCacheProvider.get().maintenance(); // delete unused buffered shingle data featureManager.maintenance(); - // delete unused transport state - transportStateManager.maintenance(); - - entityColdStarter.maintenance(); + adEntityColdStarter.maintenance(); // clean child tasks and AD results of deleted detector level task - adTaskManager.cleanChildTasksAndADResultsOfDeletedTask(); + adTaskManager.cleanChildTasksAndResultsOfDeletedTask(); // clean AD results of deleted detector - adTaskManager.cleanADResultOfDeletedDetector(); + adTaskManager.cleanResultOfDeletedConfig(); // maintain running historical tasks: reset task state as stopped if not running and clean stale running entities adTaskManager.maintainRunningHistoricalTasks(transportService, 100); @@ -125,6 +138,22 @@ protected CronNodeResponse nodeOperation(CronNodeRequest request) { // maintain running realtime tasks: clean stale running realtime task cache adTaskManager.maintainRunningRealtimeTasks(); + // ====================== + // Forecast + // ====================== + forecastCacheProvider.get().maintenance(); + forecastColdStarter.maintenance(); + // clean child tasks and forecast results of deleted forecaster level task + forecastTaskManager.cleanChildTasksAndResultsOfDeletedTask(); + forecastTaskManager.cleanResultOfDeletedConfig(); + forecastTaskManager.maintainRunningRealtimeTasks(); + + // ====================== + // Common + // ====================== + // delete unused transport state + transportStateManager.maintenance(); + return new CronNodeResponse(clusterService.localNode()); } } diff --git a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorRequest.java b/src/main/java/org/opensearch/timeseries/transport/DeleteConfigRequest.java similarity index 60% rename from src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorRequest.java rename to src/main/java/org/opensearch/timeseries/transport/DeleteConfigRequest.java index f87b6e0a1..93980ce83 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/DeleteConfigRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import static org.opensearch.action.ValidateActions.addValidationError; @@ -17,40 +17,40 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.timeseries.constant.CommonMessages; -public class DeleteAnomalyDetectorRequest extends ActionRequest { +public class DeleteConfigRequest extends ActionRequest { - private String detectorID; + private String configID; - public DeleteAnomalyDetectorRequest(StreamInput in) throws IOException { + public DeleteConfigRequest(StreamInput in) throws IOException { super(in); - this.detectorID = in.readString(); + this.configID = in.readString(); } - public DeleteAnomalyDetectorRequest(String detectorID) { + public DeleteConfigRequest(String detectorID) { super(); - this.detectorID = detectorID; + this.configID = detectorID; } - public String getDetectorID() { - return detectorID; + public String getConfigID() { + return configID; } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeString(detectorID); + out.writeString(configID); } @Override public ActionRequestValidationException validate() { ActionRequestValidationException validationException = null; - if (Strings.isEmpty(detectorID)) { - validationException = addValidationError(ADCommonMessages.AD_ID_MISSING_MSG, validationException); + if (Strings.isEmpty(configID)) { + validationException = addValidationError(CommonMessages.CONFIG_ID_MISSING_MSG, validationException); } return validationException; } diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelNodeRequest.java b/src/main/java/org/opensearch/timeseries/transport/DeleteModelNodeRequest.java similarity index 67% rename from src/main/java/org/opensearch/ad/transport/DeleteModelNodeRequest.java rename to src/main/java/org/opensearch/timeseries/transport/DeleteModelNodeRequest.java index d10eef4c3..6af6b9fcc 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteModelNodeRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/DeleteModelNodeRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; @@ -22,26 +22,26 @@ */ public class DeleteModelNodeRequest extends TransportRequest { - private String adID; + private String configID; DeleteModelNodeRequest() {} - DeleteModelNodeRequest(StreamInput in) throws IOException { + public DeleteModelNodeRequest(StreamInput in) throws IOException { super(in); - this.adID = in.readString(); + this.configID = in.readString(); } - DeleteModelNodeRequest(DeleteModelRequest request) { - this.adID = request.getAdID(); + public DeleteModelNodeRequest(DeleteModelRequest request) { + this.configID = request.getAdID(); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeString(adID); + out.writeString(configID); } - public String getAdID() { - return adID; + public String getConfigID() { + return configID; } } diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelNodeResponse.java b/src/main/java/org/opensearch/timeseries/transport/DeleteModelNodeResponse.java similarity index 96% rename from src/main/java/org/opensearch/ad/transport/DeleteModelNodeResponse.java rename to src/main/java/org/opensearch/timeseries/transport/DeleteModelNodeResponse.java index c71e7368c..a57cb0d30 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteModelNodeResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/DeleteModelNodeResponse.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelRequest.java b/src/main/java/org/opensearch/timeseries/transport/DeleteModelRequest.java similarity index 77% rename from src/main/java/org/opensearch/ad/transport/DeleteModelRequest.java rename to src/main/java/org/opensearch/timeseries/transport/DeleteModelRequest.java index 9ec58acda..d6b119e6a 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteModelRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/DeleteModelRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import static org.opensearch.action.ValidateActions.addValidationError; @@ -17,24 +17,24 @@ import org.opensearch.action.ActionRequestValidationException; import org.opensearch.action.support.nodes.BaseNodesRequest; -import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; /** * Request should be sent from the handler logic of transport delete detector API * */ public class DeleteModelRequest extends BaseNodesRequest implements ToXContentObject { - private String adID; + private String configID; public String getAdID() { - return adID; + return configID; } public DeleteModelRequest() { @@ -43,25 +43,25 @@ public DeleteModelRequest() { public DeleteModelRequest(StreamInput in) throws IOException { super(in); - this.adID = in.readString(); + this.configID = in.readString(); } public DeleteModelRequest(String adID, DiscoveryNode... nodes) { super(nodes); - this.adID = adID; + this.configID = adID; } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeString(adID); + out.writeString(configID); } @Override public ActionRequestValidationException validate() { ActionRequestValidationException validationException = null; - if (Strings.isEmpty(adID)) { - validationException = addValidationError(ADCommonMessages.AD_ID_MISSING_MSG, validationException); + if (Strings.isEmpty(configID)) { + validationException = addValidationError(CommonMessages.CONFIG_ID_MISSING_MSG, validationException); } return validationException; } @@ -69,7 +69,7 @@ public ActionRequestValidationException validate() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(ADCommonName.ID_JSON_KEY, adID); + builder.field(CommonName.CONFIG_ID_KEY, configID); builder.endObject(); return builder; } diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelResponse.java b/src/main/java/org/opensearch/timeseries/transport/DeleteModelResponse.java similarity index 97% rename from src/main/java/org/opensearch/ad/transport/DeleteModelResponse.java rename to src/main/java/org/opensearch/timeseries/transport/DeleteModelResponse.java index f2cbe2468..a2154481a 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteModelResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/DeleteModelResponse.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.List; diff --git a/src/main/java/org/opensearch/ad/transport/EntityProfileRequest.java b/src/main/java/org/opensearch/timeseries/transport/EntityProfileRequest.java similarity index 85% rename from src/main/java/org/opensearch/ad/transport/EntityProfileRequest.java rename to src/main/java/org/opensearch/timeseries/transport/EntityProfileRequest.java index 2aba165a7..edee7f379 100644 --- a/src/main/java/org/opensearch/ad/transport/EntityProfileRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/EntityProfileRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import static org.opensearch.action.ValidateActions.addValidationError; @@ -19,28 +19,27 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.model.EntityProfileName; import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.EntityProfileName; public class EntityProfileRequest extends ActionRequest implements ToXContentObject { public static final String ENTITY = "entity"; public static final String PROFILES = "profiles"; - private String adID; + private String configID; // changed from String to Entity since 1.1 private Entity entityValue; private Set profilesToCollect; public EntityProfileRequest(StreamInput in) throws IOException { super(in); - adID = in.readString(); + configID = in.readString(); entityValue = new Entity(in); int size = in.readVInt(); @@ -54,13 +53,13 @@ public EntityProfileRequest(StreamInput in) throws IOException { public EntityProfileRequest(String adID, Entity entityValue, Set profilesToCollect) { super(); - this.adID = adID; + this.configID = adID; this.entityValue = entityValue; this.profilesToCollect = profilesToCollect; } - public String getAdID() { - return adID; + public String getConfigID() { + return configID; } public Entity getEntityValue() { @@ -74,7 +73,7 @@ public Set getProfilesToCollect() { @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeString(adID); + out.writeString(configID); entityValue.writeTo(out); out.writeVInt(profilesToCollect.size()); @@ -86,8 +85,8 @@ public void writeTo(StreamOutput out) throws IOException { @Override public ActionRequestValidationException validate() { ActionRequestValidationException validationException = null; - if (Strings.isEmpty(adID)) { - validationException = addValidationError(ADCommonMessages.AD_ID_MISSING_MSG, validationException); + if (Strings.isEmpty(configID)) { + validationException = addValidationError(CommonMessages.CONFIG_ID_MISSING_MSG, validationException); } if (entityValue == null) { validationException = addValidationError("Entity value is missing", validationException); @@ -101,7 +100,7 @@ public ActionRequestValidationException validate() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(ADCommonName.ID_JSON_KEY, adID); + builder.field(CommonName.CONFIG_ID_KEY, configID); builder.field(ENTITY, entityValue); builder.field(PROFILES, profilesToCollect); builder.endObject(); diff --git a/src/main/java/org/opensearch/ad/transport/EntityProfileResponse.java b/src/main/java/org/opensearch/timeseries/transport/EntityProfileResponse.java similarity index 95% rename from src/main/java/org/opensearch/ad/transport/EntityProfileResponse.java rename to src/main/java/org/opensearch/timeseries/transport/EntityProfileResponse.java index 1b8b51da2..8d7dc5843 100644 --- a/src/main/java/org/opensearch/ad/transport/EntityProfileResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/EntityProfileResponse.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.Optional; @@ -17,13 +17,13 @@ import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; import org.apache.commons.lang.builder.ToStringBuilder; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.model.ModelProfileOnNode; import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.ModelProfileOnNode; public class EntityProfileResponse extends ActionResponse implements ToXContentObject { public static final String ACTIVE = "active"; @@ -128,7 +128,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(TOTAL_UPDATES, totalUpdates); } if (modelProfile != null) { - builder.field(ADCommonName.MODEL, modelProfile); + builder.field(CommonName.MODEL, modelProfile); } builder.endObject(); return builder; @@ -140,7 +140,7 @@ public String toString() { builder.append(ACTIVE, isActive); builder.append(LAST_ACTIVE_TS, lastActiveMs); builder.append(TOTAL_UPDATES, totalUpdates); - builder.append(ADCommonName.MODEL, modelProfile); + builder.append(CommonName.MODEL, modelProfile); return builder.toString(); } diff --git a/src/main/java/org/opensearch/timeseries/transport/EntityResultProcessor.java b/src/main/java/org/opensearch/timeseries/transport/EntityResultProcessor.java new file mode 100644 index 000000000..6f5294d4c --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/EntityResultProcessor.java @@ -0,0 +1,266 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.transport; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.caching.CacheProvider; +import org.opensearch.timeseries.caching.TimeSeriesCache; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.IntermediateResult; +import org.opensearch.timeseries.ml.ModelColdStart; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.ratelimit.CheckpointReadWorker; +import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker; +import org.opensearch.timeseries.ratelimit.ColdEntityWorker; +import org.opensearch.timeseries.ratelimit.ColdStartWorker; +import org.opensearch.timeseries.ratelimit.FeatureRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.ratelimit.SaveResultStrategy; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.Stats; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * Shared code to implement an entity result transportation + * (e.g., EntityForecastResultTransportAction) + * + */ +public class EntityResultProcessor, IndexType extends Enum & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CheckpointWriteWorkerType extends CheckpointWriteWorker, ModelColdStartType extends ModelColdStart, ModelManagerType extends ModelManager, CacheType extends TimeSeriesCache, SaveResultStrategyType extends SaveResultStrategy, ColdStartWorkerType extends ColdStartWorker, HCCheckpointReadWorkerType extends CheckpointReadWorker, ColdEntityWorkerType extends ColdEntityWorker> { + + private static final Logger LOG = LogManager.getLogger(EntityResultProcessor.class); + + private CacheProvider cache; + private ModelManagerType modelManager; + private Stats stats; + private ColdStartWorkerType entityColdStartWorker; + private HCCheckpointReadWorkerType checkpointReadQueue; + private ColdEntityWorkerType coldEntityQueue; + private SaveResultStrategyType saveResultStrategy; + private StatNames modelCorruptionStat; + + public EntityResultProcessor( + CacheProvider cache, + ModelManagerType manager, + Stats stats, + ColdStartWorkerType entityColdStartWorker, + HCCheckpointReadWorkerType checkpointReadQueue, + ColdEntityWorkerType coldEntityQueue, + SaveResultStrategyType saveResultStrategy, + StatNames modelCorruptionStat + ) { + this.cache = cache; + this.modelManager = manager; + this.stats = stats; + this.entityColdStartWorker = entityColdStartWorker; + this.checkpointReadQueue = checkpointReadQueue; + this.coldEntityQueue = coldEntityQueue; + this.saveResultStrategy = saveResultStrategy; + this.modelCorruptionStat = modelCorruptionStat; + } + + public ActionListener> onGetConfig( + ActionListener listener, + String forecasterId, + EntityResultRequest request, + Optional prevException, + AnalysisType analysisType + ) { + return ActionListener.wrap(configOptional -> { + if (!configOptional.isPresent()) { + listener.onFailure(new EndRunException(forecasterId, "Config " + forecasterId + " is not available.", false)); + return; + } + + Config config = configOptional.get(); + + if (request.getEntities() == null) { + listener.onFailure(new EndRunException(forecasterId, "Fail to get any entities from request.", false)); + return; + } + + Map cacheMissEntities = new HashMap<>(); + for (Entry entityEntry : request.getEntities().entrySet()) { + Entity entity = entityEntry.getKey(); + + if (isEntityFromOldNodeMsg(entity) && config.getCategoryFields() != null && config.getCategoryFields().size() == 1) { + Map attrValues = entity.getAttributes(); + // handle a request from a version before OpenSearch 1.1. + entity = Entity.createSingleAttributeEntity(config.getCategoryFields().get(0), attrValues.get(CommonName.EMPTY_FIELD)); + } + + Optional modelIdOptional = entity.getModelId(forecasterId); + if (modelIdOptional.isEmpty()) { + continue; + } + + String modelId = modelIdOptional.get(); + double[] datapoint = entityEntry.getValue(); + ModelState entityModel = cache.get().get(modelId, config); + if (entityModel == null) { + // cache miss + cacheMissEntities.put(entity, datapoint); + continue; + } + try { + IntermediateResultType result = modelManager + .getResult( + new Sample(datapoint, Instant.ofEpochMilli(request.getStart()), Instant.ofEpochMilli(request.getEnd())), + entityModel, + modelId, + Optional.ofNullable(entity), + config, + request.getTaskId() + ); + + saveResultStrategy + .saveResult( + result, + config, + Instant.ofEpochMilli(request.getStart()), + Instant.ofEpochMilli(request.getEnd()), + modelId, + datapoint, + Optional.of(entity), + request.getTaskId() + ); + } catch (IllegalArgumentException e) { + // fail to score likely due to model corruption. Re-cold start to recover. + LOG.error(new ParameterizedMessage("Likely model corruption for [{}]", modelId), e); + stats.getStat(modelCorruptionStat.getName()).increment(); + cache.get().removeModel(forecasterId, modelId); + entityColdStartWorker + .put( + new FeatureRequest( + System.currentTimeMillis() + config.getIntervalInMilliseconds(), + forecasterId, + RequestPriority.MEDIUM, + datapoint, + request.getStart(), + entity, + request.getTaskId() + ) + ); + } + } + + // split hot and cold entities + Pair, List> hotColdEntities = cache + .get() + .selectUpdateCandidate(cacheMissEntities.keySet(), forecasterId, config); + + List hotEntityRequests = new ArrayList<>(); + List coldEntityRequests = new ArrayList<>(); + + for (Entity hotEntity : hotColdEntities.getLeft()) { + double[] hotEntityValue = cacheMissEntities.get(hotEntity); + if (hotEntityValue == null) { + LOG.error(new ParameterizedMessage("feature value should not be null: [{}]", hotEntity)); + continue; + } + hotEntityRequests + .add( + new FeatureRequest( + System.currentTimeMillis() + config.getIntervalInMilliseconds(), + forecasterId, + // hot entities has MEDIUM priority + RequestPriority.MEDIUM, + hotEntityValue, + request.getStart(), + hotEntity, + request.getTaskId() + ) + ); + } + + for (Entity coldEntity : hotColdEntities.getRight()) { + double[] coldEntityValue = cacheMissEntities.get(coldEntity); + if (coldEntityValue == null) { + LOG.error(new ParameterizedMessage("feature value should not be null: [{}]", coldEntity)); + continue; + } + coldEntityRequests + .add( + new FeatureRequest( + System.currentTimeMillis() + config.getIntervalInMilliseconds(), + forecasterId, + // cold entities has LOW priority + RequestPriority.LOW, + coldEntityValue, + request.getStart(), + coldEntity, + request.getTaskId() + ) + ); + } + + checkpointReadQueue.putAll(hotEntityRequests); + coldEntityQueue.putAll(coldEntityRequests); + // respond back + if (prevException.isPresent()) { + listener.onFailure(prevException.get()); + } else { + listener.onResponse(new AcknowledgedResponse(true)); + } + }, exception -> { + LOG + .error( + new ParameterizedMessage( + "fail to get entity's analysis result for config [{}]: start: [{}], end: [{}]", + forecasterId, + request.getStart(), + request.getEnd() + ), + exception + ); + listener.onFailure(exception); + }); + } + + /** + * Whether the received entity comes from an node that doesn't support multi-category fields. + * This can happen during rolling-upgrade or blue/green deployment. + * + * Specifically, when receiving an EntityResultRequest from an incompatible node, + * EntityResultRequest(StreamInput in) gets an String that represents an entity. + * But Entity class requires both an category field name and value. Since we + * don't have access to detector config in EntityResultRequest(StreamInput in), + * we put CommonName.EMPTY_FIELD as the placeholder. In this method, + * we use the same CommonName.EMPTY_FIELD to check if the deserialized entity + * comes from an incompatible node. If it is, we will add the field name back + * as EntityResultTranportAction has access to the detector config object. + * + * @param categoricalValues deserialized Entity from inbound message. + * @return Whether the received entity comes from an node that doesn't support multi-category fields. + */ + private boolean isEntityFromOldNodeMsg(Entity categoricalValues) { + Map attrValues = categoricalValues.getAttributes(); + return (attrValues != null && attrValues.containsKey(CommonName.EMPTY_FIELD)); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/EntityResultRequest.java b/src/main/java/org/opensearch/timeseries/transport/EntityResultRequest.java similarity index 69% rename from src/main/java/org/opensearch/ad/transport/EntityResultRequest.java rename to src/main/java/org/opensearch/timeseries/transport/EntityResultRequest.java index 91041f447..8177178f4 100644 --- a/src/main/java/org/opensearch/ad/transport/EntityResultRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/EntityResultRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import static org.opensearch.action.ValidateActions.addValidationError; @@ -17,32 +17,31 @@ import java.util.Locale; import java.util.Map; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.Entity; public class EntityResultRequest extends ActionRequest implements ToXContentObject { - private static final Logger LOG = LogManager.getLogger(EntityResultRequest.class); - private String detectorId; + protected String configId; // changed from Map to Map - private Map entities; - private long start; - private long end; + protected Map entities; + // data start/end time epoch + protected long start; + protected long end; + protected AnalysisType analysisType; + protected String taskId; public EntityResultRequest(StreamInput in) throws IOException { super(in); - this.detectorId = in.readString(); + this.configId = in.readString(); // guarded with version check. Just in case we receive requests from older node where we use String // to represent an entity @@ -50,18 +49,33 @@ public EntityResultRequest(StreamInput in) throws IOException { this.start = in.readLong(); this.end = in.readLong(); + + // newly added + if (in.available() > 0) { + analysisType = in.readEnum(AnalysisType.class); + taskId = in.readOptionalString(); + } } - public EntityResultRequest(String detectorId, Map entities, long start, long end) { + public EntityResultRequest( + String configId, + Map entities, + long start, + long end, + AnalysisType analysisType, + String taskId + ) { super(); - this.detectorId = detectorId; + this.configId = configId; this.entities = entities; this.start = start; this.end = end; + this.analysisType = analysisType; + this.taskId = taskId; } - public String getId() { - return this.detectorId; + public String getConfigId() { + return this.configId; } public Map getEntities() { @@ -76,23 +90,33 @@ public long getEnd() { return this.end; } + public AnalysisType getAnalysisType() { + return analysisType; + } + + public String getTaskId() { + return taskId; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeString(this.detectorId); + out.writeString(this.configId); // guarded with version check. Just in case we send requests to older node where we use String // to represent an entity out.writeMap(entities, (s, e) -> e.writeTo(s), StreamOutput::writeDoubleArray); out.writeLong(this.start); out.writeLong(this.end); + out.writeEnum(analysisType); + out.writeOptionalString(taskId); } @Override public ActionRequestValidationException validate() { ActionRequestValidationException validationException = null; - if (Strings.isEmpty(detectorId)) { - validationException = addValidationError(ADCommonMessages.AD_ID_MISSING_MSG, validationException); + if (Strings.isEmpty(configId)) { + validationException = addValidationError(CommonMessages.CONFIG_ID_MISSING_MSG, validationException); } if (start <= 0 || end <= 0 || start > end) { validationException = addValidationError( @@ -106,7 +130,7 @@ public ActionRequestValidationException validate() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(ADCommonName.ID_JSON_KEY, detectorId); + builder.field(CommonName.CONFIG_ID_KEY, configId); builder.field(CommonName.START_JSON_KEY, start); builder.field(CommonName.END_JSON_KEY, end); builder.startArray(CommonName.ENTITIES_JSON_KEY); @@ -119,6 +143,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } } builder.endArray(); + builder.field(CommonName.ANALYSIS_TYPE_FIELD, analysisType); + builder.field(CommonName.TASK_ID_FIELD, taskId); builder.endObject(); return builder; } diff --git a/src/main/java/org/opensearch/timeseries/transport/ForecastRunOnceProfileNodeRequest.java b/src/main/java/org/opensearch/timeseries/transport/ForecastRunOnceProfileNodeRequest.java new file mode 100644 index 000000000..4c2895378 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/ForecastRunOnceProfileNodeRequest.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.transport; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.forecast.transport.ForecastRunOnceProfileRequest; +import org.opensearch.transport.TransportRequest; + +public class ForecastRunOnceProfileNodeRequest extends TransportRequest { + private final ForecastRunOnceProfileRequest request; + + public ForecastRunOnceProfileNodeRequest(StreamInput in) throws IOException { + super(in); + request = new ForecastRunOnceProfileRequest(in); + } + + public ForecastRunOnceProfileNodeRequest(ForecastRunOnceProfileRequest request) { + this.request = request; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + request.writeTo(out); + } + + public String getConfigId() { + return request.getConfigId(); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorRequest.java b/src/main/java/org/opensearch/timeseries/transport/GetConfigRequest.java similarity index 86% rename from src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorRequest.java rename to src/main/java/org/opensearch/timeseries/transport/GetConfigRequest.java index aef29626d..1aed87c66 100644 --- a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/GetConfigRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; @@ -19,9 +19,9 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.timeseries.model.Entity; -public class GetAnomalyDetectorRequest extends ActionRequest { +public class GetConfigRequest extends ActionRequest { - private String detectorID; + private String configID; private long version; private boolean returnJob; private boolean returnTask; @@ -30,9 +30,9 @@ public class GetAnomalyDetectorRequest extends ActionRequest { private boolean all; private Entity entity; - public GetAnomalyDetectorRequest(StreamInput in) throws IOException { + public GetConfigRequest(StreamInput in) throws IOException { super(in); - detectorID = in.readString(); + configID = in.readString(); version = in.readLong(); returnJob = in.readBoolean(); returnTask = in.readBoolean(); @@ -44,7 +44,7 @@ public GetAnomalyDetectorRequest(StreamInput in) throws IOException { } } - public GetAnomalyDetectorRequest( + public GetConfigRequest( String detectorID, long version, boolean returnJob, @@ -55,7 +55,7 @@ public GetAnomalyDetectorRequest( Entity entity ) { super(); - this.detectorID = detectorID; + this.configID = detectorID; this.version = version; this.returnJob = returnJob; this.returnTask = returnTask; @@ -65,8 +65,8 @@ public GetAnomalyDetectorRequest( this.entity = entity; } - public String getDetectorID() { - return detectorID; + public String getConfigID() { + return configID; } public long getVersion() { @@ -100,7 +100,7 @@ public Entity getEntity() { @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeString(detectorID); + out.writeString(configID); out.writeLong(version); out.writeBoolean(returnJob); out.writeBoolean(returnTask); diff --git a/src/main/java/org/opensearch/timeseries/transport/JobRequest.java b/src/main/java/org/opensearch/timeseries/transport/JobRequest.java new file mode 100644 index 000000000..98b56930f --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/JobRequest.java @@ -0,0 +1,98 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.transport; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.timeseries.model.DateRange; + +public class JobRequest extends ActionRequest { + + private String configID; + private DateRange dateRange; + private boolean historical; + private String rawPath; + + public JobRequest(StreamInput in) throws IOException { + super(in); + configID = in.readString(); + rawPath = in.readString(); + if (in.readBoolean()) { + dateRange = new DateRange(in); + } + historical = in.readBoolean(); + } + + public JobRequest(String detectorID, String rawPath) { + this(detectorID, null, false, rawPath); + } + + /** + * Constructor function. + * + * The dateRange and historical boolean can be passed in individually. + * The historical flag is for stopping analysis, the dateRange is for + * starting analysis. It's ok if historical is true but dateRange is + * null. + * + * @param configID config identifier + * @param dateRange analysis date range + * @param historical historical analysis or not + * @param rawPath raw request path + */ + public JobRequest(String configID, DateRange dateRange, boolean historical, String rawPath) { + super(); + this.configID = configID; + this.dateRange = dateRange; + this.historical = historical; + this.rawPath = rawPath; + } + + public String getConfigID() { + return configID; + } + + public DateRange getDateRange() { + return dateRange; + } + + public String getRawPath() { + return rawPath; + } + + public boolean isHistorical() { + return historical; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(configID); + out.writeString(rawPath); + if (dateRange != null) { + out.writeBoolean(true); + dateRange.writeTo(out); + } else { + out.writeBoolean(false); + } + out.writeBoolean(historical); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ProfileNodeRequest.java b/src/main/java/org/opensearch/timeseries/transport/ProfileNodeRequest.java similarity index 74% rename from src/main/java/org/opensearch/ad/transport/ProfileNodeRequest.java rename to src/main/java/org/opensearch/timeseries/transport/ProfileNodeRequest.java index d3db87d33..a5ebfb61a 100644 --- a/src/main/java/org/opensearch/ad/transport/ProfileNodeRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/ProfileNodeRequest.java @@ -9,14 +9,14 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.Set; -import org.opensearch.ad.model.DetectorProfileName; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.timeseries.model.ProfileName; import org.opensearch.transport.TransportRequest; /** @@ -39,8 +39,8 @@ public ProfileNodeRequest(ProfileRequest request) { this.request = request; } - public String getId() { - return request.getId(); + public String getConfigId() { + return request.getConfigId(); } /** @@ -48,16 +48,17 @@ public String getId() { * * @return the set that contains the profile names marked for retrieval */ - public Set getProfilesToBeRetrieved() { + public Set getProfilesToBeRetrieved() { return request.getProfilesToBeRetrieved(); } /** * - * @return Whether this is about a multi-entity detector or not + * @return Whether the models are stored in priority cache. AD single stream models are stored in ModelManager. + * Other models are stored in priority cache. */ - public boolean isForMultiEntityDetector() { - return request.isForMultiEntityDetector(); + public boolean isModelInPriorityCache() { + return request.isModelInPriorityCache(); } @Override diff --git a/src/main/java/org/opensearch/ad/transport/ProfileNodeResponse.java b/src/main/java/org/opensearch/timeseries/transport/ProfileNodeResponse.java similarity index 92% rename from src/main/java/org/opensearch/ad/transport/ProfileNodeResponse.java rename to src/main/java/org/opensearch/timeseries/transport/ProfileNodeResponse.java index 9517f6add..37be94232 100644 --- a/src/main/java/org/opensearch/ad/transport/ProfileNodeResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/ProfileNodeResponse.java @@ -9,21 +9,20 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.List; import java.util.Map; import org.opensearch.action.support.nodes.BaseNodeResponse; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.model.ModelProfile; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentFragment; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.ModelProfile; /** * Profile response on a node @@ -137,12 +136,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } builder.endObject(); - builder.field(ADCommonName.SHINGLE_SIZE, shingleSize); - builder.field(ADCommonName.ACTIVE_ENTITIES, activeEntities); - builder.field(ADCommonName.TOTAL_UPDATES, totalUpdates); + builder.field(CommonName.SHINGLE_SIZE, shingleSize); + builder.field(CommonName.ACTIVE_ENTITIES, activeEntities); + builder.field(CommonName.TOTAL_UPDATES, totalUpdates); - builder.field(ADCommonName.MODEL_COUNT, modelCount); - builder.startArray(ADCommonName.MODELS); + builder.field(CommonName.MODEL_COUNT, modelCount); + builder.startArray(CommonName.MODELS); for (ModelProfile modelProfile : modelProfiles) { builder.startObject(); modelProfile.toXContent(builder, params); diff --git a/src/main/java/org/opensearch/ad/transport/ProfileRequest.java b/src/main/java/org/opensearch/timeseries/transport/ProfileRequest.java similarity index 56% rename from src/main/java/org/opensearch/ad/transport/ProfileRequest.java rename to src/main/java/org/opensearch/timeseries/transport/ProfileRequest.java index ea779e733..07cdd3d3c 100644 --- a/src/main/java/org/opensearch/ad/transport/ProfileRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/ProfileRequest.java @@ -9,73 +9,68 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.HashSet; import java.util.Set; import org.opensearch.action.support.nodes.BaseNodesRequest; -import org.opensearch.ad.model.DetectorProfileName; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.timeseries.model.ProfileName; /** * implements a request to obtain profiles about an AD detector */ public class ProfileRequest extends BaseNodesRequest { - private Set profilesToBeRetrieved; - private String detectorId; - private boolean forMultiEntityDetector; + private Set profilesToBeRetrieved; + private String configId; + private boolean modelInPriorityCache; public ProfileRequest(StreamInput in) throws IOException { super(in); int size = in.readVInt(); - profilesToBeRetrieved = new HashSet(); + profilesToBeRetrieved = new HashSet(); if (size != 0) { for (int i = 0; i < size; i++) { - profilesToBeRetrieved.add(in.readEnum(DetectorProfileName.class)); + profilesToBeRetrieved.add(in.readEnum(ProfileName.class)); } } - detectorId = in.readString(); - forMultiEntityDetector = in.readBoolean(); + configId = in.readString(); + modelInPriorityCache = in.readBoolean(); } /** * Constructor * - * @param detectorId detector's id + * @param configId config id * @param profilesToBeRetrieved profiles to be retrieved - * @param forMultiEntityDetector whether the request is for a multi-entity detector + * @param forHC whether the request is for an high-cardinality analysis * @param nodes nodes of nodes' profiles to be retrieved */ - public ProfileRequest( - String detectorId, - Set profilesToBeRetrieved, - boolean forMultiEntityDetector, - DiscoveryNode... nodes - ) { + public ProfileRequest(String configId, Set profilesToBeRetrieved, boolean forHC, DiscoveryNode... nodes) { super(nodes); - this.detectorId = detectorId; + this.configId = configId; this.profilesToBeRetrieved = profilesToBeRetrieved; - this.forMultiEntityDetector = forMultiEntityDetector; + this.modelInPriorityCache = forHC; } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeVInt(profilesToBeRetrieved.size()); - for (DetectorProfileName profile : profilesToBeRetrieved) { + for (ProfileName profile : profilesToBeRetrieved) { out.writeEnum(profile); } - out.writeString(detectorId); - out.writeBoolean(forMultiEntityDetector); + out.writeString(configId); + out.writeBoolean(modelInPriorityCache); } - public String getId() { - return detectorId; + public String getConfigId() { + return configId; } /** @@ -83,15 +78,16 @@ public String getId() { * * @return the set that contains the profile names marked for retrieval */ - public Set getProfilesToBeRetrieved() { + public Set getProfilesToBeRetrieved() { return profilesToBeRetrieved; } /** * - * @return Whether this is about a multi-entity detector or not + * @return Whether the models are stored in priority cache. AD single stream models are stored in ModelManager. + * Other models are stored in priority cache. */ - public boolean isForMultiEntityDetector() { - return forMultiEntityDetector; + public boolean isModelInPriorityCache() { + return modelInPriorityCache; } } diff --git a/src/main/java/org/opensearch/ad/transport/ProfileResponse.java b/src/main/java/org/opensearch/timeseries/transport/ProfileResponse.java similarity index 90% rename from src/main/java/org/opensearch/ad/transport/ProfileResponse.java rename to src/main/java/org/opensearch/timeseries/transport/ProfileResponse.java index 11ba28163..9d1680430 100644 --- a/src/main/java/org/opensearch/ad/transport/ProfileResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/ProfileResponse.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.ArrayList; @@ -20,14 +20,14 @@ import org.apache.logging.log4j.Logger; import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.nodes.BaseNodesResponse; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.model.ModelProfile; -import org.opensearch.ad.model.ModelProfileOnNode; import org.opensearch.cluster.ClusterName; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentFragment; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.ModelProfile; +import org.opensearch.timeseries.model.ModelProfileOnNode; /** * This class consists of the aggregated responses from the nodes @@ -35,13 +35,13 @@ public class ProfileResponse extends BaseNodesResponse implements ToXContentFragment { private static final Logger LOG = LogManager.getLogger(ProfileResponse.class); // filed name in toXContent - static final String COORDINATING_NODE = ADCommonName.COORDINATING_NODE; - static final String SHINGLE_SIZE = ADCommonName.SHINGLE_SIZE; - static final String TOTAL_SIZE = ADCommonName.TOTAL_SIZE_IN_BYTES; - static final String ACTIVE_ENTITY = ADCommonName.ACTIVE_ENTITIES; - static final String MODELS = ADCommonName.MODELS; - static final String TOTAL_UPDATES = ADCommonName.TOTAL_UPDATES; - static final String MODEL_COUNT = ADCommonName.MODEL_COUNT; + public static final String COORDINATING_NODE = CommonName.COORDINATING_NODE; + public static final String SHINGLE_SIZE = CommonName.SHINGLE_SIZE; + public static final String TOTAL_SIZE = CommonName.TOTAL_SIZE_IN_BYTES; + static final String ACTIVE_ENTITY = CommonName.ACTIVE_ENTITIES; + public static final String MODELS = CommonName.MODELS; + static final String TOTAL_UPDATES = CommonName.TOTAL_UPDATES; + static final String MODEL_COUNT = CommonName.MODEL_COUNT; // changed from ModelProfile to ModelProfileOnNode since Opensearch 1.1 private ModelProfileOnNode[] modelProfile; diff --git a/src/main/java/org/opensearch/timeseries/transport/ResultBulkRequest.java b/src/main/java/org/opensearch/timeseries/transport/ResultBulkRequest.java new file mode 100644 index 000000000..cd8efc9de --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/ResultBulkRequest.java @@ -0,0 +1,88 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.transport; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.ValidateActions; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.ratelimit.ResultWriteRequest; + +public class ResultBulkRequest> extends + ActionRequest + implements + Writeable { + private final List results; + + public ResultBulkRequest() { + results = new ArrayList<>(); + } + + public ResultBulkRequest(StreamInput in, Writeable.Reader reader) throws IOException { + super(in); + int size = in.readVInt(); + results = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + results.add(reader.read(in)); + } + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (results.isEmpty()) { + validationException = ValidateActions.addValidationError(CommonMessages.NO_REQUESTS_ADDED_ERR, validationException); + } + return validationException; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeVInt(results.size()); + for (ResultWriteRequestType result : results) { + result.writeTo(out); + } + } + + /** + * + * @return all of the results to send + */ + public List getAnomalyResults() { + return results; + } + + /** + * Add result to send + * @param resultWriteRequest The result write request + */ + public void add(ResultWriteRequestType resultWriteRequest) { + results.add(resultWriteRequest); + } + + /** + * + * @return total index requests + */ + public int numberOfActions() { + return results.size(); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADResultBulkResponse.java b/src/main/java/org/opensearch/timeseries/transport/ResultBulkResponse.java similarity index 86% rename from src/main/java/org/opensearch/ad/transport/ADResultBulkResponse.java rename to src/main/java/org/opensearch/timeseries/transport/ResultBulkResponse.java index 70768311c..570bddca2 100644 --- a/src/main/java/org/opensearch/ad/transport/ADResultBulkResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/ResultBulkResponse.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.ArrayList; @@ -21,7 +21,7 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -public class ADResultBulkResponse extends ActionResponse { +public class ResultBulkResponse extends ActionResponse { public static final String RETRY_REQUESTS_JSON_KEY = "retry_requests"; private List retryRequests; @@ -30,15 +30,15 @@ public class ADResultBulkResponse extends ActionResponse { * * @param retryRequests a list of requests to retry */ - public ADResultBulkResponse(List retryRequests) { + public ResultBulkResponse(List retryRequests) { this.retryRequests = retryRequests; } - public ADResultBulkResponse() { + public ResultBulkResponse() { this.retryRequests = null; } - public ADResultBulkResponse(StreamInput in) throws IOException { + public ResultBulkResponse(StreamInput in) throws IOException { int size = in.readInt(); if (size > 0) { retryRequests = new ArrayList<>(size); diff --git a/src/main/java/org/opensearch/timeseries/transport/ResultBulkTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/ResultBulkTransportAction.java new file mode 100644 index 000000000..f070c38c6 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/ResultBulkTransportAction.java @@ -0,0 +1,121 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.transport; + +import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; +import static org.opensearch.index.IndexingPressure.MAX_INDEXING_BYTES; + +import java.io.IOException; +import java.util.List; +import java.util.Random; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.bulk.BulkAction; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.IndexingPressure; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.ratelimit.ResultWriteRequest; +import org.opensearch.timeseries.util.BulkUtil; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportService; + +@SuppressWarnings("rawtypes") +public abstract class ResultBulkTransportAction, ResultBulkRequestType extends ResultBulkRequest> + extends HandledTransportAction { + private static final Logger LOG = LogManager.getLogger(ResultBulkTransportAction.class); + protected IndexingPressure indexingPressure; + private final long primaryAndCoordinatingLimits; + protected float softLimit; + protected float hardLimit; + protected String indexName; + private Client client; + protected Random random; + + public ResultBulkTransportAction( + String actionName, + TransportService transportService, + ActionFilters actionFilters, + IndexingPressure indexingPressure, + Settings settings, + Client client, + float softLimit, + float hardLimit, + String indexName, + Writeable.Reader requestReader + ) { + super(actionName, transportService, actionFilters, requestReader, ThreadPool.Names.SAME); + this.indexingPressure = indexingPressure; + this.primaryAndCoordinatingLimits = MAX_INDEXING_BYTES.get(settings).getBytes(); + this.client = client; + + this.softLimit = softLimit; + this.hardLimit = hardLimit; + this.indexName = indexName; + + // random seed is 42. Can be any number + this.random = new Random(42); + } + + @Override + protected void doExecute(Task task, ResultBulkRequestType request, ActionListener listener) { + // Concurrent indexing memory limit = 10% of heap + // indexing pressure = indexing bytes / indexing limit + // Write all until index pressure (global indexing memory pressure) is less than 80% of 10% of heap. Otherwise, index + // all non-zero anomaly grade index requests and index zero anomaly grade index requests with probability (1 - index pressure). + long totalBytes = indexingPressure.getCurrentCombinedCoordinatingAndPrimaryBytes() + indexingPressure.getCurrentReplicaBytes(); + float indexingPressurePercent = (float) totalBytes / primaryAndCoordinatingLimits; + @SuppressWarnings("rawtypes") + List results = request.getAnomalyResults(); + + if (results == null || results.size() < 1) { + listener.onResponse(new ResultBulkResponse()); + } + + BulkRequest bulkRequest = prepareBulkRequest(indexingPressurePercent, request); + + if (bulkRequest.numberOfActions() > 0) { + client.execute(BulkAction.INSTANCE, bulkRequest, ActionListener.wrap(bulkResponse -> { + List failedRequests = BulkUtil.getFailedIndexRequest(bulkRequest, bulkResponse); + listener.onResponse(new ResultBulkResponse(failedRequests)); + }, e -> { + LOG.error("Failed to bulk index AD result", e); + listener.onFailure(e); + })); + } else { + listener.onResponse(new ResultBulkResponse()); + } + } + + protected abstract BulkRequest prepareBulkRequest(float indexingPressurePercent, ResultBulkRequestType request); + + protected void addResult(BulkRequest bulkRequest, ToXContentObject result, String resultIndex) { + String index = resultIndex == null ? indexName : resultIndex; + try (XContentBuilder builder = jsonBuilder()) { + IndexRequest indexRequest = new IndexRequest(index).source(result.toXContent(builder, RestHandlerUtils.XCONTENT_WITH_TYPE)); + bulkRequest.add(indexRequest); + } catch (IOException e) { + LOG.error("Failed to prepare bulk index request for index " + index, e); + } + } +} diff --git a/src/main/java/org/opensearch/timeseries/transport/ResultProcessor.java b/src/main/java/org/opensearch/timeseries/transport/ResultProcessor.java new file mode 100644 index 000000000..74ee17eb3 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/ResultProcessor.java @@ -0,0 +1,884 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.transport; + +import java.net.ConnectException; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.OpenSearchTimeoutException; +import org.opensearch.action.ActionListenerResponseHandler; +import org.opensearch.action.search.SearchPhaseExecutionException; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.action.support.IndicesOptions; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.block.ClusterBlockLevel; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.transport.NetworkExceptionHelper; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.NotSerializableExceptionWrapper; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.node.NodeClosedException; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.cluster.HashRing; +import org.opensearch.timeseries.common.exception.ClientException; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.InternalFailure; +import org.opensearch.timeseries.common.exception.NotSerializedExceptionName; +import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.feature.CompositeRetriever; +import org.opensearch.timeseries.feature.CompositeRetriever.PageIterator; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.feature.SinglePointFeatures; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.SingleStreamModelIdMapper; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.Stats; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.util.ExceptionUtil; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.ActionNotFoundTransportException; +import org.opensearch.transport.ConnectTransportException; +import org.opensearch.transport.NodeNotConnectedException; +import org.opensearch.transport.ReceiveTimeoutTransportException; +import org.opensearch.transport.TransportRequestOptions; +import org.opensearch.transport.TransportService; + +public abstract class ResultProcessor, TaskCacheManagerType extends TaskCacheManager, TaskTypeEnum extends TaskType, TaskClass extends TimeSeriesTask, IndexType extends Enum & TimeSeriesIndex, IndexManagementType extends IndexManagement, TaskManagerType extends TaskManager> { + + private static final Logger LOG = LogManager.getLogger(ResultProcessor.class); + + static final String WAIT_FOR_THRESHOLD_ERR_MSG = "Exception in waiting for threshold result"; + + static final String NO_ACK_ERR = "no acknowledgements from model hosting nodes."; + + public static final String TROUBLE_QUERYING_ERR_MSG = "Having trouble querying data: "; + + public static final String NULL_RESPONSE = "Received null response from"; + + public static final String INDEX_READ_BLOCKED = "Cannot read user index due to read block."; + + public static final String READ_WRITE_BLOCKED = "Cannot read/write due to global block."; + + public static final String NODE_UNRESPONSIVE_ERR_MSG = "Model node is unresponsive. Mute node"; + + protected final TransportRequestOptions option; + private String entityResultAction; + protected Class transportResultResponseClazz; + private StatNames hcRequestCountStat; + private String threadPoolName; + // within an interval, how many percents are used to process requests. + // 1.0 means we use all of the detection interval to process requests. + // to ensure we don't block next interval, it is better to set it less than 1.0. + private final float intervalRatioForRequest; + private int maxEntitiesPerInterval; + private int pageSize; + protected final ThreadPool threadPool; + private final HashRing hashRing; + protected final NodeStateManager nodeStateManager; + protected final TransportService transportService; + private final Stats timeSeriesStats; + private final TaskManagerType realTimeTaskManager; + private NamedXContentRegistry xContentRegistry; + private final Client client; + private final SecurityClientUtil clientUtil; + private Settings settings; + private final IndexNameExpressionResolver indexNameExpressionResolver; + private final ClusterService clusterService; + protected final FeatureManager featureManager; + protected final AnalysisType analysisType; + protected final String singleStreamActionName; + + protected boolean runOnce; + + public ResultProcessor( + Setting requestTimeoutSetting, + float intervalRatioForRequests, + String entityResultAction, + StatNames hcRequestCountStat, + Settings settings, + ClusterService clusterService, + ThreadPool threadPool, + String threadPoolName, + HashRing hashRing, + NodeStateManager nodeStateManager, + TransportService transportService, + Stats timeSeriesStats, + TaskManagerType realTimeTaskManager, + NamedXContentRegistry xContentRegistry, + Client client, + SecurityClientUtil clientUtil, + IndexNameExpressionResolver indexNameExpressionResolver, + Class transportResultResponseClazz, + FeatureManager featureManager, + Setting maxEntitiesPerIntervalSetting, + Setting pageSizeSetting, + AnalysisType context, + boolean runOnce, + String singleStreamActionName + ) { + this.option = TransportRequestOptions + .builder() + .withType(TransportRequestOptions.Type.REG) + .withTimeout(requestTimeoutSetting.get(settings)) + .build(); + this.intervalRatioForRequest = intervalRatioForRequests; + + this.maxEntitiesPerInterval = maxEntitiesPerIntervalSetting.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(maxEntitiesPerIntervalSetting, it -> maxEntitiesPerInterval = it); + + this.pageSize = pageSizeSetting.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(pageSizeSetting, it -> pageSize = it); + + this.entityResultAction = entityResultAction; + this.hcRequestCountStat = hcRequestCountStat; + this.threadPool = threadPool; + this.hashRing = hashRing; + this.nodeStateManager = nodeStateManager; + this.transportService = transportService; + this.timeSeriesStats = timeSeriesStats; + this.realTimeTaskManager = realTimeTaskManager; + this.xContentRegistry = xContentRegistry; + this.client = client; + this.clientUtil = clientUtil; + this.settings = settings; + this.indexNameExpressionResolver = indexNameExpressionResolver; + this.clusterService = clusterService; + this.transportResultResponseClazz = transportResultResponseClazz; + this.featureManager = featureManager; + this.analysisType = context; + this.threadPoolName = threadPoolName; + this.runOnce = runOnce; + this.singleStreamActionName = singleStreamActionName; + } + + /** + * didn't use ActionListener.wrap so that I can + * 1) use this to refer to the listener inside the listener + * 2) pass parameters using constructors + * + */ + class PageListener implements ActionListener { + private PageIterator pageIterator; + private String configId; + private long dataStartTime; + private long dataEndTime; + private Runnable finishRunnable; + private String taskId; + + PageListener( + PageIterator pageIterator, + String detectorId, + long dataStartTime, + long dataEndTime, + Runnable finishRunnable, + String taskId + ) { + this.pageIterator = pageIterator; + this.configId = detectorId; + this.dataStartTime = dataStartTime; + this.dataEndTime = dataEndTime; + this.finishRunnable = finishRunnable; + this.taskId = taskId; + } + + @Override + public void onResponse(CompositeRetriever.Page entityFeatures) { + if (pageIterator.hasNext()) { + pageIterator.next(this); + } else { + finishRunnable.run(); + } + if (entityFeatures != null && false == entityFeatures.isEmpty()) { + // wrap expensive operation inside ad threadpool + threadPool.executor(threadPoolName).execute(() -> { + try { + + Set>> node2Entities = entityFeatures + .getResults() + .entrySet() + .stream() + .filter(e -> hashRing.getOwningNodeWithSameLocalVersionForRealtime(e.getKey().toString()).isPresent()) + .collect( + Collectors + .groupingBy( + // from entity name to its node + e -> hashRing.getOwningNodeWithSameLocalVersionForRealtime(e.getKey().toString()).get(), + Collectors.toMap(Entry::getKey, Entry::getValue) + ) + ) + .entrySet(); + + Iterator>> iterator = node2Entities.iterator(); + + while (iterator.hasNext()) { + Entry> entry = iterator.next(); + DiscoveryNode modelNode = entry.getKey(); + if (modelNode == null) { + iterator.remove(); + continue; + } + String modelNodeId = modelNode.getId(); + if (nodeStateManager.isMuted(modelNodeId, configId)) { + LOG + .info( + String + .format( + Locale.ROOT, + ResultProcessor.NODE_UNRESPONSIVE_ERR_MSG + " %s for detector %s", + modelNodeId, + configId + ) + ); + iterator.remove(); + } + } + + final AtomicReference failure = new AtomicReference<>(); + node2Entities.stream().forEach(nodeEntity -> { + DiscoveryNode node = nodeEntity.getKey(); + transportService + .sendRequest( + node, + entityResultAction, + new EntityResultRequest( + configId, + nodeEntity.getValue(), + dataStartTime, + dataEndTime, + analysisType, + taskId + ), + option, + new ActionListenerResponseHandler<>( + new ErrorResponseListener(node.getId(), configId, failure), + AcknowledgedResponse::new, + ThreadPool.Names.SAME + ) + ); + }); + + } catch (Exception e) { + LOG.error("Unexpected exception", e); + handleException(e); + } + }); + } + } + + @Override + public void onFailure(Exception e) { + try { + LOG.error("Unexpetected exception", e); + handleException(e); + } finally { + // make sure we return listener + finishRunnable.run(); + } + } + + private void handleException(Exception e) { + Exception convertedException = convertedQueryFailureException(e, configId); + if (false == (convertedException instanceof TimeSeriesException)) { + Throwable cause = ExceptionsHelper.unwrapCause(convertedException); + convertedException = new InternalFailure(configId, cause); + } + nodeStateManager.setException(configId, convertedException); + } + } + + public ActionListener> onGetConfig( + ActionListener listener, + String configID, + TransportResultRequestType request, + Optional> hcDetectors + ) { + return ActionListener.wrap(configOptional -> { + if (!configOptional.isPresent()) { + listener.onFailure(new EndRunException(configID, "config is not available.", true)); + return; + } + + Config config = configOptional.get(); + // no stat increment in runOnce where hcDetectors is empty. + if (config.isHighCardinality() && hcDetectors.isPresent()) { + hcDetectors.get().add(configID); + timeSeriesStats.getStat(hcRequestCountStat.getName()).increment(); + } + + if (request.getStart() <= 0) { + long duration = config.getIntervalInMilliseconds(); + long executionStartTime = request.getEnd() - duration; + + request.setStart(executionStartTime); + } + long delayMillis = Optional + .ofNullable((IntervalTimeConfiguration) config.getWindowDelay()) + .map(t -> t.toDuration().toMillis()) + .orElse(0L); + long dataStartTime = request.getStart() - delayMillis; + long dataEndTime = request.getEnd() - delayMillis; + + if (runOnce) { + realTimeTaskManager.createRunOnceTaskAndCleanupStaleTasks(configID, config, transportService, ActionListener.wrap(r -> { + if (r == null) { + LOG.error("Unexpected empty new task for " + configID); + listener + .onFailure( + new OpenSearchStatusException( + "Failed to bootstrap run once task for " + configID, + RestStatus.INTERNAL_SERVER_ERROR + ) + ); + return; + } + executeAnalysis(listener, configID, request, config, dataStartTime, dataEndTime, r.getTaskId()); + }, e -> { + LOG.error("Failed to init run once task for " + configID, e); + listener + .onFailure( + new OpenSearchStatusException( + "Failed to bootstrap run once task for " + configID, + RestStatus.INTERNAL_SERVER_ERROR + ) + ); + })); + } else { + realTimeTaskManager + .initRealtimeTaskCacheAndCleanupStaleCache( + configID, + config, + transportService, + ActionListener + .runAfter( + initRealtimeTaskListener(configID), + () -> executeAnalysis(listener, configID, request, config, dataStartTime, dataEndTime, null) + ) + ); + } + + }, exception -> ResultProcessor.handleExecuteException(exception, listener, configID)); + } + + private ActionListener initRealtimeTaskListener(String configId) { + return ActionListener.wrap(r -> { + if (r) { + LOG.debug("Realtime task initied for config {}", configId); + } + }, e -> LOG.error("Failed to init realtime task for " + configId, e)); + } + + private void executeAnalysis( + ActionListener listener, + String configID, + ResultRequest request, + Config config, + long dataStartTime, + long dataEndTime, + String taskId + ) { + // HC logic starts here + if (config.isHighCardinality()) { + Optional previousException = nodeStateManager.fetchExceptionAndClear(configID); + if (previousException.isPresent()) { + Exception exception = previousException.get(); + LOG.error(new ParameterizedMessage("Previous exception of [{}]", configID), exception); + if (exception instanceof EndRunException) { + EndRunException endRunException = (EndRunException) exception; + if (endRunException.isEndNow()) { + listener.onFailure(exception); + return; + } + } + } + + // assume request are in epoch milliseconds + long nextDetectionStartTime = request.getEnd() + (long) (config.getIntervalInMilliseconds() * intervalRatioForRequest); + + CompositeRetriever compositeRetriever = new CompositeRetriever( + dataStartTime, + dataEndTime, + config, + xContentRegistry, + client, + clientUtil, + nextDetectionStartTime, + settings, + maxEntitiesPerInterval, + pageSize, + indexNameExpressionResolver, + clusterService, + analysisType + ); + + PageIterator pageIterator = null; + + try { + pageIterator = compositeRetriever.iterator(); + } catch (Exception e) { + listener.onFailure(new EndRunException(config.getId(), CommonMessages.INVALID_SEARCH_QUERY_MSG, e, false)); + return; + } + + Runnable finishRunnable = () -> { + // When pagination finishes or the time is up, + // return response or exceptions. + if (previousException.isPresent()) { + listener.onFailure(previousException.get()); + } else { + listener + .onResponse( + createResultResponse(new ArrayList(), null, null, config.getIntervalInMinutes(), true, taskId) + ); + } + }; + + PageListener getEntityFeatureslistener = new PageListener( + pageIterator, + configID, + dataStartTime, + dataEndTime, + finishRunnable, + taskId + ); + if (pageIterator.hasNext()) { + pageIterator.next(getEntityFeatureslistener); + } + + return; + } + + // HC logic ends and single entity logic starts here + // We are going to use only 1 model partition for a single stream detector. + // That's why we use 0 here. + String rcfModelID = SingleStreamModelIdMapper.getRcfModelId(configID, 0); + Optional asRCFNode = hashRing.getOwningNodeWithSameLocalVersionForRealtime(rcfModelID); + if (!asRCFNode.isPresent()) { + listener.onFailure(new InternalFailure(configID, "RCF model node is not available.")); + return; + } + + DiscoveryNode rcfNode = asRCFNode.get(); + + if (!shouldStart(listener, configID, config, rcfNode.getId(), rcfModelID)) { + return; + } + + featureManager + .getCurrentFeatures( + config, + dataStartTime, + dataEndTime, + onFeatureResponseForSingleStreamConfig(configID, config, listener, rcfModelID, rcfNode, dataStartTime, dataEndTime, taskId) + ); + } + + protected void handleQueryFailure(Exception exception, ActionListener listener, String adID) { + Exception convertedQueryFailureException = convertedQueryFailureException(exception, adID); + + if (convertedQueryFailureException instanceof EndRunException) { + // invalid feature query + listener.onFailure(convertedQueryFailureException); + } else { + ResultProcessor.handleExecuteException(convertedQueryFailureException, listener, adID); + } + } + + /** + * Convert a query related exception to EndRunException + * + * These query exception can happen during the starting phase of the OpenSearch + * process. Thus, set the stopNow parameter of these EndRunException to false + * and confirm the EndRunException is not a false positive. + * + * @param exception Exception + * @param adID detector Id + * @return the converted exception if the exception is query related + */ + private Exception convertedQueryFailureException(Exception exception, String adID) { + if (ExceptionUtil.isIndexNotAvailable(exception)) { + return new EndRunException(adID, ResultProcessor.TROUBLE_QUERYING_ERR_MSG + exception.getMessage(), false) + .countedInStats(false); + } else if (exception instanceof SearchPhaseExecutionException && invalidQuery((SearchPhaseExecutionException) exception)) { + // This is to catch invalid aggregation on wrong field type. For example, + // sum aggregation on text field. We should end detector run for such case. + return new EndRunException( + adID, + CommonMessages.INVALID_SEARCH_QUERY_MSG + " " + ((SearchPhaseExecutionException) exception).getDetailedMessage(), + exception, + false + ).countedInStats(false); + } + + return exception; + } + + protected void findException(Throwable cause, String configID, AtomicReference failure, String nodeId) { + if (cause == null) { + LOG.error(new ParameterizedMessage("Null input exception")); + return; + } + if (cause instanceof Error) { + // we cannot do anything with Error. + LOG.error(new ParameterizedMessage("Error during prediction for {}: ", configID), cause); + return; + } + + Exception causeException = (Exception) cause; + + if (causeException instanceof TimeSeriesException) { + failure.set(causeException); + } else if (causeException instanceof NotSerializableExceptionWrapper) { + // we only expect this happens on AD exceptions + Optional actualException = NotSerializedExceptionName + .convertWrappedTimeSeriesException((NotSerializableExceptionWrapper) causeException, configID); + if (actualException.isPresent()) { + TimeSeriesException adException = actualException.get(); + failure.set(adException); + if (adException instanceof ResourceNotFoundException) { + // During a rolling upgrade or blue/green deployment, ResourceNotFoundException might be caused by old node using RCF + // 1.0 + // cannot recognize new checkpoint produced by the coordinating node using compact RCF. Add pressure to mute the node + // after consecutive failures. + nodeStateManager.addPressure(nodeId, configID); + } + } else { + // some unexpected bugs occur while predicting anomaly + failure.set(new EndRunException(configID, CommonMessages.BUG_RESPONSE, causeException, false)); + } + } else if (causeException instanceof OpenSearchTimeoutException) { + // we can have OpenSearchTimeoutException when a node tries to load RCF or + // threshold model + failure.set(new InternalFailure(configID, causeException)); + } else if (causeException instanceof IllegalArgumentException) { + // we can have IllegalArgumentException when a model is corrupted + failure.set(new InternalFailure(configID, causeException)); + } else { + // some unexpected bug occurred or cluster is unstable (e.g., ClusterBlockException) or index is red (e.g. + // NoShardAvailableActionException) while predicting anomaly + failure.set(new EndRunException(configID, CommonMessages.BUG_RESPONSE, causeException, false)); + } + } + + private boolean invalidQuery(SearchPhaseExecutionException ex) { + // If all shards return bad request and failure cause is IllegalArgumentException, we + // consider the feature query is invalid and will not count the error in failure stats. + for (ShardSearchFailure failure : ex.shardFailures()) { + if (RestStatus.BAD_REQUEST != failure.status() || !(failure.getCause() instanceof IllegalArgumentException)) { + return false; + } + } + return true; + } + + /** + * Handle a prediction failure. Possibly (i.e., we don't always need to do that) + * convert the exception to a form that AD can recognize and handle and sets the + * input failure reference to the converted exception. + * + * @param e prediction exception + * @param adID Detector Id + * @param nodeID Node Id + * @param failure Parameter to receive the possibly converted function for the + * caller to deal with + */ + protected void handlePredictionFailure(Exception e, String adID, String nodeID, AtomicReference failure) { + LOG.error(new ParameterizedMessage("Received an error from node {} while doing model inference for {}", nodeID, adID), e); + if (e == null) { + return; + } + Throwable cause = ExceptionsHelper.unwrapCause(e); + if (hasConnectionIssue(cause)) { + handleConnectionException(nodeID, adID); + } else { + findException(cause, adID, failure, nodeID); + } + } + + /** + * Check if the input exception indicates connection issues. + * During blue-green deployment, we may see ActionNotFoundTransportException. + * Count that as connection issue and isolate that node if it continues to happen. + * + * @param e exception + * @return true if we get disconnected from the node or the node is not in the + * right state (being closed) or transport request times out (sent from TimeoutHandler.run) + */ + private boolean hasConnectionIssue(Throwable e) { + return e instanceof ConnectTransportException + || e instanceof NodeClosedException + || e instanceof ReceiveTimeoutTransportException + || e instanceof NodeNotConnectedException + || e instanceof ConnectException + || NetworkExceptionHelper.isCloseConnectionException(e) + || e instanceof ActionNotFoundTransportException; + } + + private void handleConnectionException(String node, String detectorId) { + final DiscoveryNodes nodes = clusterService.state().nodes(); + if (!nodes.nodeExists(node)) { + hashRing.buildCirclesForRealtime(); + return; + } + // rebuilding is not done or node is unresponsive + nodeStateManager.addPressure(node, detectorId); + } + + /** + * Since we need to read from customer index and write to anomaly result index, + * we need to make sure we can read and write. + * + * @param state Cluster state + * @return whether we have global block or not + */ + private boolean checkGlobalBlock(ClusterState state) { + return state.blocks().globalBlockedException(ClusterBlockLevel.READ) != null + || state.blocks().globalBlockedException(ClusterBlockLevel.WRITE) != null; + } + + /** + * Similar to checkGlobalBlock, we check block on the indices level. + * + * @param state Cluster state + * @param level block level + * @param indices the indices on which to check block + * @return whether any of the index has block on the level. + */ + private boolean checkIndicesBlocked(ClusterState state, ClusterBlockLevel level, String... indices) { + // the original index might be an index expression with wildcards like "log*", + // so we need to expand the expression to concrete index name + String[] concreteIndices = indexNameExpressionResolver.concreteIndexNames(state, IndicesOptions.lenientExpandOpen(), indices); + + return state.blocks().indicesBlockedException(level, concreteIndices) != null; + } + + /** + * Check if we should start anomaly prediction. + * + * @param listener listener to respond back to AnomalyResultRequest. + * @param adID detector ID + * @param detector detector instance corresponds to adID + * @param rcfNodeId the rcf model hosting node ID for adID + * @param rcfModelID the rcf model ID for adID + * @return if we can start anomaly prediction. + */ + private boolean shouldStart( + ActionListener listener, + String adID, + Config detector, + String rcfNodeId, + String rcfModelID + ) { + ClusterState state = clusterService.state(); + if (checkGlobalBlock(state)) { + listener.onFailure(new InternalFailure(adID, ResultProcessor.READ_WRITE_BLOCKED)); + return false; + } + + if (nodeStateManager.isMuted(rcfNodeId, adID)) { + listener + .onFailure( + new InternalFailure( + adID, + String + .format(Locale.ROOT, ResultProcessor.NODE_UNRESPONSIVE_ERR_MSG + " %s for rcf model %s", rcfNodeId, rcfModelID) + ) + ); + return false; + } + + if (checkIndicesBlocked(state, ClusterBlockLevel.READ, detector.getIndices().toArray(new String[0]))) { + listener.onFailure(new InternalFailure(adID, ResultProcessor.INDEX_READ_BLOCKED)); + return false; + } + + return true; + } + + public static void handleExecuteException(Exception ex, ActionListener listener, String id) { + if (ex instanceof ClientException) { + listener.onFailure(ex); + } else if (ex instanceof TimeSeriesException) { + listener.onFailure(new InternalFailure((TimeSeriesException) ex)); + } else { + Throwable cause = ExceptionsHelper.unwrapCause(ex); + listener.onFailure(new InternalFailure(id, cause)); + } + } + + public class ErrorResponseListener implements ActionListener { + private String nodeId; + private final String configId; + private AtomicReference failure; + + public ErrorResponseListener(String nodeId, String configId, AtomicReference failure) { + this.nodeId = nodeId; + this.configId = configId; + this.failure = failure; + } + + @Override + public void onResponse(AcknowledgedResponse response) { + try { + if (response.isAcknowledged() == false) { + LOG.error("Cannot send entities' features to {} for {}", nodeId, configId); + nodeStateManager.addPressure(nodeId, configId); + } else { + nodeStateManager.resetBackpressureCounter(nodeId, configId); + } + } catch (Exception ex) { + LOG.error("Unexpected exception: {} for {}", ex, configId); + handleException(ex); + } + } + + @Override + public void onFailure(Exception e) { + try { + // e.g., we have connection issues with all of the nodes while restarting clusters + LOG.error(new ParameterizedMessage("Cannot send entities' features to {} for {}", nodeId, configId), e); + + handleException(e); + + } catch (Exception ex) { + LOG.error("Unexpected exception: {} for {}", ex, configId); + handleException(ex); + } + } + + private void handleException(Exception e) { + handlePredictionFailure(e, configId, nodeId, failure); + if (failure.get() != null) { + nodeStateManager.setException(configId, failure.get()); + } + } + } + + protected ActionListener onFeatureResponseForSingleStreamConfig( + String configId, + Config config, + ActionListener listener, + String rcfModelId, + DiscoveryNode rcfNode, + long dataStartTime, + long dataEndTime, + String taskId + ) { + return ActionListener.wrap(featureOptional -> { + Optional previousException = nodeStateManager.fetchExceptionAndClear(configId); + if (previousException.isPresent()) { + Exception exception = previousException.get(); + LOG.error(new ParameterizedMessage("Previous exception of [{}]", configId), exception); + if (exception instanceof EndRunException) { + EndRunException endRunException = (EndRunException) exception; + if (endRunException.isEndNow()) { + listener.onFailure(exception); + return; + } + } + } + + if (featureOptional.getUnprocessedFeatures().isEmpty()) { + // Feature not available is common when we have data holes. Respond empty response + // and don't log to avoid bloating our logs. + LOG.debug("No data in current window between {} and {} for {}", dataStartTime, dataEndTime, configId); + listener + .onResponse(createResultResponse(new ArrayList(), "No data in current window", null, null, false, taskId)); + return; + } + + final AtomicReference failure = new AtomicReference(); + + LOG.info("Sending single stream request to {} for model {}", rcfNode.getId(), rcfModelId); + + transportService + .sendRequest( + rcfNode, + singleStreamActionName, + new SingleStreamResultRequest( + configId, + rcfModelId, + dataStartTime, + dataEndTime, + featureOptional.getUnprocessedFeatures().get(), + taskId + ), + option, + new ActionListenerResponseHandler<>( + new ErrorResponseListener(rcfNode.getId(), configId, failure), + AcknowledgedResponse::new, + ThreadPool.Names.SAME + ) + ); + + if (previousException.isPresent()) { + listener.onFailure(previousException.get()); + } else if (!featureOptional.getUnprocessedFeatures().isPresent()) { + // Feature not available is common when we have data holes. Respond empty response + // and don't log to avoid bloating our logs. + LOG.debug("No data in current window between {} and {} for {}", dataStartTime, dataEndTime, configId); + listener + .onResponse(createResultResponse(new ArrayList(), "No data in current window", null, null, false, taskId)); + } else { + listener + .onResponse( + createResultResponse(new ArrayList(), null, null, config.getIntervalInMinutes(), true, taskId) + ); + } + }, exception -> { handleQueryFailure(exception, listener, configId); }); + } + + protected abstract ResultResponseType createResultResponse( + List features, + String error, + Long rcfTotalUpdates, + Long configInterval, + Boolean isHC, + String taskId + ); +} diff --git a/src/main/java/org/opensearch/timeseries/transport/ResultRequest.java b/src/main/java/org/opensearch/timeseries/transport/ResultRequest.java new file mode 100644 index 000000000..c1e6a345f --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/ResultRequest.java @@ -0,0 +1,64 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.transport; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; + +public abstract class ResultRequest extends ActionRequest implements ToXContentObject { + protected String configId; + // time range start and end. Unit: epoch milliseconds + protected long start; + protected long end; + + public ResultRequest(StreamInput in) throws IOException { + super(in); + configId = in.readString(); + start = in.readLong(); + end = in.readLong(); + } + + public ResultRequest(String configID, long start, long end) { + super(); + this.configId = configID; + this.start = start; + this.end = end; + } + + public long getStart() { + return start; + } + + public void setStart(long start) { + this.start = start; + } + + public long getEnd() { + return end; + } + + public String getConfigId() { + return configId; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(configId); + out.writeLong(start); + out.writeLong(end); + } +} diff --git a/src/main/java/org/opensearch/timeseries/transport/ResultResponse.java b/src/main/java/org/opensearch/timeseries/transport/ResultResponse.java new file mode 100644 index 000000000..38e566f3d --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/ResultResponse.java @@ -0,0 +1,101 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.transport; + +import java.io.IOException; +import java.time.Instant; +import java.util.List; + +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.model.IndexableResult; + +public abstract class ResultResponse extends ActionResponse implements ToXContentObject { + + protected String error; + protected List features; + protected Long rcfTotalUpdates; + protected Long configIntervalInMinutes; + protected Boolean isHC; + protected String taskId; + + public ResultResponse( + List features, + String error, + Long rcfTotalUpdates, + Long configInterval, + Boolean isHC, + String taskId + ) { + this.error = error; + this.features = features; + this.rcfTotalUpdates = rcfTotalUpdates; + this.configIntervalInMinutes = configInterval; + this.isHC = isHC; + this.taskId = taskId; + } + + /** + * Leave it as implementation detail in subclass as how to deserialize TimeSeriesResultResponse + * @param in deserialization stream + * @throws IOException when deserialization errs + */ + public ResultResponse(StreamInput in) throws IOException { + super(in); + } + + public String getError() { + return error; + } + + public List getFeatures() { + return features; + } + + public Long getRcfTotalUpdates() { + return rcfTotalUpdates; + } + + public Long getConfigIntervalInMinutes() { + return configIntervalInMinutes; + } + + public Boolean isHC() { + return isHC; + } + + public String getTaskId() { + return taskId; + } + + /** + * + * @return whether we should save the response to result index + */ + public boolean shouldSave() { + return error != null; + } + + public abstract List toIndexableResults( + String configId, + Instant dataStartInstant, + Instant dataEndInstant, + Instant executionStartInstant, + Instant executionEndInstant, + Integer schemaVersion, + User user, + String error + ); +} diff --git a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoRequest.java b/src/main/java/org/opensearch/timeseries/transport/SearchConfigInfoRequest.java similarity index 80% rename from src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoRequest.java rename to src/main/java/org/opensearch/timeseries/transport/SearchConfigInfoRequest.java index 8289619c1..1592dd594 100644 --- a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/SearchConfigInfoRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; @@ -18,18 +18,18 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -public class SearchAnomalyDetectorInfoRequest extends ActionRequest { +public class SearchConfigInfoRequest extends ActionRequest { private String name; private String rawPath; - public SearchAnomalyDetectorInfoRequest(StreamInput in) throws IOException { + public SearchConfigInfoRequest(StreamInput in) throws IOException { super(in); name = in.readOptionalString(); rawPath = in.readString(); } - public SearchAnomalyDetectorInfoRequest(String name, String rawPath) throws IOException { + public SearchConfigInfoRequest(String name, String rawPath) throws IOException { super(); this.name = name; this.rawPath = rawPath; diff --git a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoResponse.java b/src/main/java/org/opensearch/timeseries/transport/SearchConfigInfoResponse.java similarity index 83% rename from src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoResponse.java rename to src/main/java/org/opensearch/timeseries/transport/SearchConfigInfoResponse.java index 852c39d1a..67b44953e 100644 --- a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/SearchConfigInfoResponse.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; @@ -20,17 +20,17 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.timeseries.util.RestHandlerUtils; -public class SearchAnomalyDetectorInfoResponse extends ActionResponse implements ToXContentObject { +public class SearchConfigInfoResponse extends ActionResponse implements ToXContentObject { private long count; private boolean nameExists; - public SearchAnomalyDetectorInfoResponse(StreamInput in) throws IOException { + public SearchConfigInfoResponse(StreamInput in) throws IOException { super(in); count = in.readLong(); nameExists = in.readBoolean(); } - public SearchAnomalyDetectorInfoResponse(long count, boolean nameExists) { + public SearchConfigInfoResponse(long count, boolean nameExists) { this.count = count; this.nameExists = nameExists; } diff --git a/src/main/java/org/opensearch/timeseries/transport/SingleStreamResultRequest.java b/src/main/java/org/opensearch/timeseries/transport/SingleStreamResultRequest.java new file mode 100644 index 000000000..4028e1565 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/SingleStreamResultRequest.java @@ -0,0 +1,124 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.transport; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; +import java.util.Locale; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.Strings; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; + +public class SingleStreamResultRequest extends ActionRequest implements ToXContentObject { + private final String configId; + private final String modelId; + + // data start/end time epoch in milliseconds + private final long startMillis; + private final long endMillis; + private final double[] datapoint; + private final String taskId; + + public SingleStreamResultRequest(String configId, String modelId, long start, long end, double[] datapoint, String taskId) { + super(); + this.configId = configId; + this.modelId = modelId; + this.startMillis = start; + this.endMillis = end; + this.datapoint = datapoint; + this.taskId = taskId; + } + + public SingleStreamResultRequest(StreamInput in) throws IOException { + super(in); + this.configId = in.readString(); + this.modelId = in.readString(); + this.startMillis = in.readLong(); + this.endMillis = in.readLong(); + this.datapoint = in.readDoubleArray(); + this.taskId = in.readOptionalString(); + } + + public String getConfigId() { + return this.configId; + } + + public String getModelId() { + return modelId; + } + + public long getStart() { + return this.startMillis; + } + + public long getEnd() { + return this.endMillis; + } + + public double[] getDataPoint() { + return this.datapoint; + } + + public String getTaskId() { + return taskId; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(this.configId); + out.writeString(this.modelId); + out.writeLong(this.startMillis); + out.writeLong(this.endMillis); + out.writeDoubleArray(datapoint); + out.writeOptionalString(this.taskId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CommonName.CONFIG_ID_KEY, configId); + builder.field(CommonName.MODEL_ID_KEY, modelId); + builder.field(CommonName.START_JSON_KEY, startMillis); + builder.field(CommonName.END_JSON_KEY, endMillis); + builder.array(CommonName.VALUE_LIST_FIELD, datapoint); + builder.field(CommonName.RUN_ONCE_FIELD, taskId); + builder.endObject(); + return builder; + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (Strings.isEmpty(configId)) { + validationException = addValidationError(CommonMessages.CONFIG_ID_MISSING_MSG, validationException); + } + if (Strings.isEmpty(modelId)) { + validationException = addValidationError(CommonMessages.MODEL_ID_MISSING_MSG, validationException); + } + if (startMillis <= 0 || endMillis <= 0 || startMillis > endMillis) { + validationException = addValidationError( + String.format(Locale.ROOT, "%s: start %d, end %d", CommonMessages.INVALID_TIMESTAMP_ERR_MSG, startMillis, endMillis), + validationException + ); + } + return validationException; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADStatsNodeRequest.java b/src/main/java/org/opensearch/timeseries/transport/StatsNodeRequest.java similarity index 69% rename from src/main/java/org/opensearch/ad/transport/ADStatsNodeRequest.java rename to src/main/java/org/opensearch/timeseries/transport/StatsNodeRequest.java index 099bc7db1..a61135d1a 100644 --- a/src/main/java/org/opensearch/ad/transport/ADStatsNodeRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/StatsNodeRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; @@ -18,21 +18,21 @@ import org.opensearch.transport.TransportRequest; /** - * ADStatsNodeRequest to get a nodes stat + * StatsNodeRequest to get a nodes stat */ -public class ADStatsNodeRequest extends TransportRequest { - private ADStatsRequest request; +public class StatsNodeRequest extends TransportRequest { + private StatsRequest request; /** * Constructor */ - public ADStatsNodeRequest() { + public StatsNodeRequest() { super(); } - public ADStatsNodeRequest(StreamInput in) throws IOException { + public StatsNodeRequest(StreamInput in) throws IOException { super(in); - this.request = new ADStatsRequest(in); + this.request = new StatsRequest(in); } /** @@ -40,7 +40,7 @@ public ADStatsNodeRequest(StreamInput in) throws IOException { * * @param request ADStatsRequest */ - public ADStatsNodeRequest(ADStatsRequest request) { + public StatsNodeRequest(StatsRequest request) { this.request = request; } @@ -49,7 +49,7 @@ public ADStatsNodeRequest(ADStatsRequest request) { * * @return ADStatsRequest for this node */ - public ADStatsRequest getADStatsRequest() { + public StatsRequest getADStatsRequest() { return request; } diff --git a/src/main/java/org/opensearch/ad/transport/ADStatsNodeResponse.java b/src/main/java/org/opensearch/timeseries/transport/StatsNodeResponse.java similarity index 85% rename from src/main/java/org/opensearch/ad/transport/ADStatsNodeResponse.java rename to src/main/java/org/opensearch/timeseries/transport/StatsNodeResponse.java index f5296cf17..a1b5b180f 100644 --- a/src/main/java/org/opensearch/ad/transport/ADStatsNodeResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/StatsNodeResponse.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.Map; @@ -24,7 +24,7 @@ /** * ADStatsNodeResponse */ -public class ADStatsNodeResponse extends BaseNodeResponse implements ToXContentFragment { +public class StatsNodeResponse extends BaseNodeResponse implements ToXContentFragment { private Map statsMap; @@ -34,7 +34,7 @@ public class ADStatsNodeResponse extends BaseNodeResponse implements ToXContentF * @param in StreamInput * @throws IOException throws an IO exception if the StreamInput cannot be read from */ - public ADStatsNodeResponse(StreamInput in) throws IOException { + public StatsNodeResponse(StreamInput in) throws IOException { super(in); this.statsMap = in.readMap(StreamInput::readString, StreamInput::readGenericValue); } @@ -45,7 +45,7 @@ public ADStatsNodeResponse(StreamInput in) throws IOException { * @param node node * @param statsToValues Mapping of stat name to value */ - public ADStatsNodeResponse(DiscoveryNode node, Map statsToValues) { + public StatsNodeResponse(DiscoveryNode node, Map statsToValues) { super(node); this.statsMap = statsToValues; } @@ -57,9 +57,9 @@ public ADStatsNodeResponse(DiscoveryNode node, Map statsToValues * @return ADStatsNodeResponse object corresponding to the input stream * @throws IOException throws an IO exception if the StreamInput cannot be read from */ - public static ADStatsNodeResponse readStats(StreamInput in) throws IOException { + public static StatsNodeResponse readStats(StreamInput in) throws IOException { - return new ADStatsNodeResponse(in); + return new StatsNodeResponse(in); } /** diff --git a/src/main/java/org/opensearch/ad/transport/ADStatsNodesResponse.java b/src/main/java/org/opensearch/timeseries/transport/StatsNodesResponse.java similarity index 66% rename from src/main/java/org/opensearch/ad/transport/ADStatsNodesResponse.java rename to src/main/java/org/opensearch/timeseries/transport/StatsNodesResponse.java index 2dbdff03c..7a8ff9901 100644 --- a/src/main/java/org/opensearch/ad/transport/ADStatsNodesResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/StatsNodesResponse.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.List; @@ -24,9 +24,9 @@ import org.opensearch.core.xcontent.XContentBuilder; /** - * ADStatsNodesResponse consists of the aggregated responses from the nodes + * StatsNodesResponse consists of the aggregated responses from the nodes */ -public class ADStatsNodesResponse extends BaseNodesResponse implements ToXContentObject { +public class StatsNodesResponse extends BaseNodesResponse implements ToXContentObject { private static final String NODES_KEY = "nodes"; @@ -36,18 +36,18 @@ public class ADStatsNodesResponse extends BaseNodesResponse * @param in StreamInput * @throws IOException thrown when unable to read from stream */ - public ADStatsNodesResponse(StreamInput in) throws IOException { - super(new ClusterName(in), in.readList(ADStatsNodeResponse::readStats), in.readList(FailedNodeException::new)); + public StatsNodesResponse(StreamInput in) throws IOException { + super(new ClusterName(in), in.readList(StatsNodeResponse::readStats), in.readList(FailedNodeException::new)); } /** * Constructor * * @param clusterName name of cluster - * @param nodes List of ADStatsNodeResponses from nodes + * @param nodes List of StatsNodeResponse from nodes * @param failures List of failures from nodes */ - public ADStatsNodesResponse(ClusterName clusterName, List nodes, List failures) { + public StatsNodesResponse(ClusterName clusterName, List nodes, List failures) { super(clusterName, nodes, failures); } @@ -57,13 +57,13 @@ public void writeTo(StreamOutput out) throws IOException { } @Override - public void writeNodesTo(StreamOutput out, List nodes) throws IOException { + public void writeNodesTo(StreamOutput out, List nodes) throws IOException { out.writeList(nodes); } @Override - public List readNodesFrom(StreamInput in) throws IOException { - return in.readList(ADStatsNodeResponse::readStats); + public List readNodesFrom(StreamInput in) throws IOException { + return in.readList(StatsNodeResponse::readStats); } @Override @@ -71,7 +71,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws String nodeId; DiscoveryNode node; builder.startObject(NODES_KEY); - for (ADStatsNodeResponse adStats : getNodes()) { + for (StatsNodeResponse adStats : getNodes()) { node = adStats.getNode(); nodeId = node.getId(); builder.startObject(nodeId); diff --git a/src/main/java/org/opensearch/ad/transport/ADStatsRequest.java b/src/main/java/org/opensearch/timeseries/transport/StatsRequest.java similarity index 86% rename from src/main/java/org/opensearch/ad/transport/ADStatsRequest.java rename to src/main/java/org/opensearch/timeseries/transport/StatsRequest.java index 32301e526..f8b5a8896 100644 --- a/src/main/java/org/opensearch/ad/transport/ADStatsRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/StatsRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.HashSet; @@ -21,9 +21,9 @@ import org.opensearch.core.common.io.stream.StreamOutput; /** - * ADStatsRequest implements a request to obtain stats about the AD plugin + * StatsRequest implements a request to obtain stats about the time series analytics plugin */ -public class ADStatsRequest extends BaseNodesRequest { +public class StatsRequest extends BaseNodesRequest { /** * Key indicating all stats should be retrieved @@ -32,7 +32,7 @@ public class ADStatsRequest extends BaseNodesRequest { private Set statsToBeRetrieved; - public ADStatsRequest(StreamInput in) throws IOException { + public StatsRequest(StreamInput in) throws IOException { super(in); statsToBeRetrieved = in.readSet(StreamInput::readString); } @@ -42,7 +42,7 @@ public ADStatsRequest(StreamInput in) throws IOException { * * @param nodeIds nodeIds of nodes' stats to be retrieved */ - public ADStatsRequest(String... nodeIds) { + public StatsRequest(String... nodeIds) { super(nodeIds); statsToBeRetrieved = new HashSet<>(); } @@ -52,7 +52,7 @@ public ADStatsRequest(String... nodeIds) { * * @param nodes nodes of nodes' stats to be retrieved */ - public ADStatsRequest(DiscoveryNode... nodes) { + public StatsRequest(DiscoveryNode... nodes) { super(nodes); statsToBeRetrieved = new HashSet<>(); } diff --git a/src/main/java/org/opensearch/ad/stats/ADStatsResponse.java b/src/main/java/org/opensearch/timeseries/transport/StatsResponse.java similarity index 60% rename from src/main/java/org/opensearch/ad/stats/ADStatsResponse.java rename to src/main/java/org/opensearch/timeseries/transport/StatsResponse.java index f90e451f9..414951e19 100644 --- a/src/main/java/org/opensearch/ad/stats/ADStatsResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/StatsResponse.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.stats; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.Map; @@ -17,19 +17,18 @@ import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; import org.apache.commons.lang.builder.ToStringBuilder; -import org.opensearch.ad.model.Mergeable; -import org.opensearch.ad.transport.ADStatsNodesResponse; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.model.Mergeable; /** - * ADStatsResponse contains logic to merge the node stats and cluster stats together and return them to user + * StatsResponse contains logic to merge the node stats and cluster stats together and return them to user */ -public class ADStatsResponse implements ToXContentObject, Mergeable { - private ADStatsNodesResponse adStatsNodesResponse; +public class StatsResponse implements ToXContentObject, Mergeable { + private StatsNodesResponse statsNodesResponse; private Map clusterStats; /** @@ -53,23 +52,23 @@ public void setClusterStats(Map clusterStats) { /** * Get cluster stats * - * @return ADStatsNodesResponse + * @return StatsNodesResponse */ - public ADStatsNodesResponse getADStatsNodesResponse() { - return adStatsNodesResponse; + public StatsNodesResponse getStatsNodesResponse() { + return statsNodesResponse; } /** - * Sets adStatsNodesResponse + * Sets statsNodesResponse * - * @param adStatsNodesResponse AD Stats Response from Nodes + * @param statsNodesResponse Stats Response from Nodes */ - public void setADStatsNodesResponse(ADStatsNodesResponse adStatsNodesResponse) { - this.adStatsNodesResponse = adStatsNodesResponse; + public void setStatsNodesResponse(StatsNodesResponse statsNodesResponse) { + this.statsNodesResponse = statsNodesResponse; } /** - * Convert ADStatsResponse to XContent + * Convert StatsResponse to XContent * * @param builder XContentBuilder * @return XContentBuilder @@ -79,15 +78,15 @@ public XContentBuilder toXContent(XContentBuilder builder) throws IOException { return toXContent(builder, ToXContent.EMPTY_PARAMS); } - public ADStatsResponse() {} + public StatsResponse() {} - public ADStatsResponse(StreamInput in) throws IOException { - adStatsNodesResponse = new ADStatsNodesResponse(in); + public StatsResponse(StreamInput in) throws IOException { + statsNodesResponse = new StatsNodesResponse(in); clusterStats = in.readMap(); } public void writeTo(StreamOutput out) throws IOException { - adStatsNodesResponse.writeTo(out); + statsNodesResponse.writeTo(out); out.writeMap(clusterStats); } @@ -97,7 +96,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws for (Map.Entry clusterStat : clusterStats.entrySet()) { builder.field(clusterStat.getKey(), clusterStat.getValue()); } - adStatsNodesResponse.toXContent(xContentBuilder, ToXContent.EMPTY_PARAMS); + statsNodesResponse.toXContent(xContentBuilder, ToXContent.EMPTY_PARAMS); return xContentBuilder.endObject(); } @@ -107,10 +106,10 @@ public void merge(Mergeable other) { return; } - ADStatsResponse otherResponse = (ADStatsResponse) other; + StatsResponse otherResponse = (StatsResponse) other; - if (otherResponse.adStatsNodesResponse != null) { - this.adStatsNodesResponse = otherResponse.adStatsNodesResponse; + if (otherResponse.statsNodesResponse != null) { + this.statsNodesResponse = otherResponse.statsNodesResponse; } if (otherResponse.clusterStats != null) { @@ -127,23 +126,17 @@ public boolean equals(Object obj) { if (getClass() != obj.getClass()) return false; - ADStatsResponse other = (ADStatsResponse) obj; - return new EqualsBuilder() - .append(adStatsNodesResponse, other.adStatsNodesResponse) - .append(clusterStats, other.clusterStats) - .isEquals(); + StatsResponse other = (StatsResponse) obj; + return new EqualsBuilder().append(statsNodesResponse, other.statsNodesResponse).append(clusterStats, other.clusterStats).isEquals(); } @Override public int hashCode() { - return new HashCodeBuilder().append(adStatsNodesResponse).append(clusterStats).toHashCode(); + return new HashCodeBuilder().append(statsNodesResponse).append(clusterStats).toHashCode(); } @Override public String toString() { - return new ToStringBuilder(this) - .append("adStatsNodesResponse", adStatsNodesResponse) - .append("clusterStats", clusterStats) - .toString(); + return new ToStringBuilder(this).append("statsNodesResponse", statsNodesResponse).append("clusterStats", clusterStats).toString(); } } diff --git a/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorResponse.java b/src/main/java/org/opensearch/timeseries/transport/StatsTimeSeriesResponse.java similarity index 57% rename from src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorResponse.java rename to src/main/java/org/opensearch/timeseries/transport/StatsTimeSeriesResponse.java index c3a108454..ebde3a5f0 100644 --- a/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/StatsTimeSeriesResponse.java @@ -9,41 +9,40 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; -import org.opensearch.ad.stats.ADStatsResponse; import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -public class StatsAnomalyDetectorResponse extends ActionResponse implements ToXContentObject { - private ADStatsResponse adStatsResponse; +public class StatsTimeSeriesResponse extends ActionResponse implements ToXContentObject { + private StatsResponse statsResponse; - public StatsAnomalyDetectorResponse(StreamInput in) throws IOException { + public StatsTimeSeriesResponse(StreamInput in) throws IOException { super(in); - adStatsResponse = new ADStatsResponse(in); + statsResponse = new StatsResponse(in); } - public StatsAnomalyDetectorResponse(ADStatsResponse adStatsResponse) { - this.adStatsResponse = adStatsResponse; + public StatsTimeSeriesResponse(StatsResponse adStatsResponse) { + this.statsResponse = adStatsResponse; } @Override public void writeTo(StreamOutput out) throws IOException { - adStatsResponse.writeTo(out); + statsResponse.writeTo(out); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - adStatsResponse.toXContent(builder, params); + statsResponse.toXContent(builder, params); return builder; } - protected ADStatsResponse getAdStatsResponse() { - return adStatsResponse; + public StatsResponse getAdStatsResponse() { + return statsResponse; } } diff --git a/src/main/java/org/opensearch/ad/transport/StopDetectorRequest.java b/src/main/java/org/opensearch/timeseries/transport/StopConfigRequest.java similarity index 64% rename from src/main/java/org/opensearch/ad/transport/StopDetectorRequest.java rename to src/main/java/org/opensearch/timeseries/transport/StopConfigRequest.java index 71563a2cd..da70786a3 100644 --- a/src/main/java/org/opensearch/ad/transport/StopDetectorRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/StopConfigRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import static org.opensearch.action.ValidateActions.addValidationError; @@ -19,8 +19,6 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -28,43 +26,45 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; -public class StopDetectorRequest extends ActionRequest implements ToXContentObject { +public class StopConfigRequest extends ActionRequest implements ToXContentObject { - private String adID; + private String configID; - public StopDetectorRequest() {} + public StopConfigRequest() {} - public StopDetectorRequest(StreamInput in) throws IOException { + public StopConfigRequest(StreamInput in) throws IOException { super(in); - this.adID = in.readString(); + this.configID = in.readString(); } - public StopDetectorRequest(String adID) { + public StopConfigRequest(String configID) { super(); - this.adID = adID; + this.configID = configID; } - public String getAdID() { - return adID; + public String getConfigID() { + return configID; } - public StopDetectorRequest adID(String adID) { - this.adID = adID; + public StopConfigRequest adID(String configID) { + this.configID = configID; return this; } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeString(adID); + out.writeString(configID); } @Override public ActionRequestValidationException validate() { ActionRequestValidationException validationException = null; - if (Strings.isEmpty(adID)) { - validationException = addValidationError(ADCommonMessages.AD_ID_MISSING_MSG, validationException); + if (Strings.isEmpty(configID)) { + validationException = addValidationError(CommonMessages.CONFIG_ID_MISSING_MSG, validationException); } return validationException; } @@ -72,20 +72,20 @@ public ActionRequestValidationException validate() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(ADCommonName.ID_JSON_KEY, adID); + builder.field(CommonName.CONFIG_ID_KEY, configID); builder.endObject(); return builder; } - public static StopDetectorRequest fromActionRequest(final ActionRequest actionRequest) { - if (actionRequest instanceof StopDetectorRequest) { - return (StopDetectorRequest) actionRequest; + public static StopConfigRequest fromActionRequest(final ActionRequest actionRequest) { + if (actionRequest instanceof StopConfigRequest) { + return (StopConfigRequest) actionRequest; } try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { - return new StopDetectorRequest(input); + return new StopConfigRequest(input); } } catch (IOException e) { throw new IllegalArgumentException("failed to parse ActionRequest into StopDetectorRequest", e); diff --git a/src/main/java/org/opensearch/ad/transport/StopDetectorResponse.java b/src/main/java/org/opensearch/timeseries/transport/StopConfigResponse.java similarity index 78% rename from src/main/java/org/opensearch/ad/transport/StopDetectorResponse.java rename to src/main/java/org/opensearch/timeseries/transport/StopConfigResponse.java index 00ca68649..d5ab03781 100644 --- a/src/main/java/org/opensearch/ad/transport/StopDetectorResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/StopConfigResponse.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -23,15 +23,15 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -public class StopDetectorResponse extends ActionResponse implements ToXContentObject { +public class StopConfigResponse extends ActionResponse implements ToXContentObject { public static final String SUCCESS_JSON_KEY = "success"; private boolean success; - public StopDetectorResponse(boolean success) { + public StopConfigResponse(boolean success) { this.success = success; } - public StopDetectorResponse(StreamInput in) throws IOException { + public StopConfigResponse(StreamInput in) throws IOException { super(in); success = in.readBoolean(); } @@ -53,15 +53,15 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } - public static StopDetectorResponse fromActionResponse(final ActionResponse actionResponse) { - if (actionResponse instanceof StopDetectorResponse) { - return (StopDetectorResponse) actionResponse; + public static StopConfigResponse fromActionResponse(final ActionResponse actionResponse) { + if (actionResponse instanceof StopConfigResponse) { + return (StopConfigResponse) actionResponse; } try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (InputStreamStreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { - return new StopDetectorResponse(input); + return new StopConfigResponse(input); } } catch (IOException e) { throw new IllegalArgumentException("failed to parse ActionResponse into StopDetectorResponse", e); diff --git a/src/main/java/org/opensearch/timeseries/transport/SuggestConfigParamRequest.java b/src/main/java/org/opensearch/timeseries/transport/SuggestConfigParamRequest.java new file mode 100644 index 000000000..3c7b9f45a --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/SuggestConfigParamRequest.java @@ -0,0 +1,80 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.transport; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.model.Config; + +public class SuggestConfigParamRequest extends ActionRequest { + + private final AnalysisType context; + private final Config config; + private final String param; + private final TimeValue requestTimeout; + + public SuggestConfigParamRequest(StreamInput in) throws IOException { + super(in); + context = in.readEnum(AnalysisType.class); + if (context.isAD()) { + config = new AnomalyDetector(in); + } else if (context.isForecast()) { + config = new Forecaster(in); + } else { + throw new UnsupportedOperationException("This method is not supported"); + } + + param = in.readString(); + requestTimeout = in.readTimeValue(); + } + + public SuggestConfigParamRequest(AnalysisType context, Config config, String param, TimeValue requestTimeout) { + this.context = context; + this.config = config; + this.param = param; + this.requestTimeout = requestTimeout; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeEnum(context); + config.writeTo(out); + out.writeString(param); + out.writeTimeValue(requestTimeout); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + public Config getConfig() { + return config; + } + + public String getParam() { + return param; + } + + public TimeValue getRequestTimeout() { + return requestTimeout; + } +} diff --git a/src/main/java/org/opensearch/timeseries/transport/SuggestConfigParamResponse.java b/src/main/java/org/opensearch/timeseries/transport/SuggestConfigParamResponse.java new file mode 100644 index 000000000..f091b7611 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/SuggestConfigParamResponse.java @@ -0,0 +1,138 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.transport; + +import java.io.IOException; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.Mergeable; + +public class SuggestConfigParamResponse extends ActionResponse implements ToXContentObject, Mergeable { + public static final String INTERVAL_FIELD = "interval"; + public static final String HORIZON_FIELD = "horizon"; + public static final String HISTORY_FIELD = "history"; + + private IntervalTimeConfiguration interval; + private Integer horizon; + private Integer history; + + public IntervalTimeConfiguration getInterval() { + return interval; + } + + public Integer getHorizon() { + return horizon; + } + + public Integer getHistory() { + return history; + } + + public SuggestConfigParamResponse(IntervalTimeConfiguration interval, Integer horizon, Integer history) { + this.interval = interval; + this.horizon = horizon; + this.history = history; + } + + public SuggestConfigParamResponse(StreamInput in) throws IOException { + super(in); + if (in.readBoolean()) { + this.interval = IntervalTimeConfiguration.readFrom(in); + } else { + this.interval = null; + } + this.horizon = in.readOptionalInt(); + this.history = in.readOptionalInt(); + } + + public static class Builder { + protected IntervalTimeConfiguration interval = null; + protected Integer horizon = null; + protected Integer history = null; + + public Builder() {} + + public Builder interval(IntervalTimeConfiguration interval) { + this.interval = interval; + return this; + } + + public Builder horizon(Integer horizon) { + this.horizon = horizon; + return this; + } + + public Builder history(Integer history) { + this.history = history; + return this; + } + + public SuggestConfigParamResponse build() { + return new SuggestConfigParamResponse(interval, horizon, history); + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + if (interval != null) { + out.writeBoolean(true); + interval.writeTo(out); + } else { + out.writeBoolean(false); + } + out.writeOptionalInt(horizon); + out.writeOptionalInt(history); + } + + public XContentBuilder toXContent(XContentBuilder builder) throws IOException { + return toXContent(builder, ToXContent.EMPTY_PARAMS); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + if (interval != null) { + xContentBuilder.field(INTERVAL_FIELD, interval); + } + if (horizon != null) { + xContentBuilder.field(HORIZON_FIELD, horizon); + } + if (history != null) { + xContentBuilder.field(HISTORY_FIELD, history); + } + + return xContentBuilder.endObject(); + } + + @Override + public void merge(Mergeable other) { + if (this == other || other == null || getClass() != other.getClass()) { + return; + } + SuggestConfigParamResponse otherProfile = (SuggestConfigParamResponse) other; + if (otherProfile.getInterval() != null) { + this.interval = otherProfile.getInterval(); + } + if (otherProfile.getHorizon() != null) { + this.horizon = otherProfile.getHorizon(); + } + if (otherProfile.getHistory() != null) { + this.history = otherProfile.getHistory(); + } + } +} diff --git a/src/main/java/org/opensearch/timeseries/transport/ValidateConfigRequest.java b/src/main/java/org/opensearch/timeseries/transport/ValidateConfigRequest.java new file mode 100644 index 000000000..0ad5e86e4 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/ValidateConfigRequest.java @@ -0,0 +1,123 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.transport; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.model.Config; + +public class ValidateConfigRequest extends ActionRequest { + + private final AnalysisType context; + private final Config config; + private final String validationType; + private final Integer maxSingleStreamConfigs; + private final Integer maxHCConfigs; + private final Integer maxFeatures; + private final TimeValue requestTimeout; + // added during refactoring for forecasting. It is fine we add a new field + // since the request is handled by the same node. + private final Integer maxCategoricalFields; + + public ValidateConfigRequest(StreamInput in) throws IOException { + super(in); + context = in.readEnum(AnalysisType.class); + if (context.isAD()) { + config = new AnomalyDetector(in); + } else if (context.isForecast()) { + config = new Forecaster(in); + } else { + throw new UnsupportedOperationException("This method is not supported"); + } + + validationType = in.readString(); + maxSingleStreamConfigs = in.readInt(); + maxHCConfigs = in.readInt(); + maxFeatures = in.readInt(); + requestTimeout = in.readTimeValue(); + maxCategoricalFields = in.readInt(); + } + + public ValidateConfigRequest( + AnalysisType context, + Config config, + String validationType, + Integer maxSingleStreamConfigs, + Integer maxHCConfigs, + Integer maxFeatures, + TimeValue requestTimeout, + Integer maxCategoricalFields + ) { + this.context = context; + this.config = config; + this.validationType = validationType; + this.maxSingleStreamConfigs = maxSingleStreamConfigs; + this.maxHCConfigs = maxHCConfigs; + this.maxFeatures = maxFeatures; + this.requestTimeout = requestTimeout; + this.maxCategoricalFields = maxCategoricalFields; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeEnum(context); + config.writeTo(out); + out.writeString(validationType); + out.writeInt(maxSingleStreamConfigs); + out.writeInt(maxHCConfigs); + out.writeInt(maxFeatures); + out.writeTimeValue(requestTimeout); + out.writeInt(maxCategoricalFields); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + public Config getConfig() { + return config; + } + + public String getValidationType() { + return validationType; + } + + public Integer getMaxSingleEntityAnomalyDetectors() { + return maxSingleStreamConfigs; + } + + public Integer getMaxMultiEntityAnomalyDetectors() { + return maxHCConfigs; + } + + public Integer getMaxAnomalyFeatures() { + return maxFeatures; + } + + public TimeValue getRequestTimeout() { + return requestTimeout; + } + + public Integer getMaxCategoricalFields() { + return maxCategoricalFields; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorResponse.java b/src/main/java/org/opensearch/timeseries/transport/ValidateConfigResponse.java similarity index 75% rename from src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorResponse.java rename to src/main/java/org/opensearch/timeseries/transport/ValidateConfigResponse.java index d89022241..f3321024e 100644 --- a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/ValidateConfigResponse.java @@ -9,33 +9,33 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; -import org.opensearch.ad.model.DetectorValidationIssue; import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.model.ConfigValidationIssue; -public class ValidateAnomalyDetectorResponse extends ActionResponse implements ToXContentObject { - private DetectorValidationIssue issue; +public class ValidateConfigResponse extends ActionResponse implements ToXContentObject { + private ConfigValidationIssue issue; - public DetectorValidationIssue getIssue() { + public ConfigValidationIssue getIssue() { return issue; } - public ValidateAnomalyDetectorResponse(DetectorValidationIssue issue) { + public ValidateConfigResponse(ConfigValidationIssue issue) { this.issue = issue; } - public ValidateAnomalyDetectorResponse(StreamInput in) throws IOException { + public ValidateConfigResponse(StreamInput in) throws IOException { super(in); if (in.readBoolean()) { - issue = new DetectorValidationIssue(in); + issue = new ConfigValidationIssue(in); } } diff --git a/src/main/java/org/opensearch/timeseries/transport/handler/IndexMemoryPressureAwareResultHandler.java b/src/main/java/org/opensearch/timeseries/transport/handler/IndexMemoryPressureAwareResultHandler.java new file mode 100644 index 000000000..628d95d1c --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/handler/IndexMemoryPressureAwareResultHandler.java @@ -0,0 +1,83 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.transport.handler; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; + +/** + * Different from ResultIndexingHandler and ResultBulkIndexingHandler, this class uses + * customized transport action to bulk index results. These transport action will + * reduce traffic when index memory pressure is high. + * + * + * @param Batch request type + * @param Batch response type + * @param forecasting or AD result index + * @param Index management class + */ +public abstract class IndexMemoryPressureAwareResultHandler & TimeSeriesIndex, IndexManagementType extends IndexManagement> { + private static final Logger LOG = LogManager.getLogger(IndexMemoryPressureAwareResultHandler.class); + + protected final Client client; + protected final IndexManagementType timeSeriesIndices; + + public IndexMemoryPressureAwareResultHandler(Client client, IndexManagementType timeSeriesIndices) { + this.client = client; + this.timeSeriesIndices = timeSeriesIndices; + } + + /** + * Execute the bulk request + * @param currentBulkRequest The bulk request + * @param listener callback after flushing + */ + public void flush(BatchRequestType currentBulkRequest, ActionListener listener) { + try { + // Only create custom result index when creating detector, won’t recreate custom AD result index in realtime + // job and historical analysis later if it’s deleted. If user delete the custom AD result index, and AD plugin + // recreate it, that may bring confusion. + if (!timeSeriesIndices.doesDefaultResultIndexExist()) { + timeSeriesIndices.initDefaultResultIndexDirectly(ActionListener.wrap(initResponse -> { + if (initResponse.isAcknowledged()) { + bulk(currentBulkRequest, listener); + } else { + LOG.warn("Creating result index with mappings call not acknowledged."); + listener.onFailure(new TimeSeriesException("", "Creating result index with mappings call not acknowledged.")); + } + }, exception -> { + if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { + // It is possible the index has been created while we sending the create request + bulk(currentBulkRequest, listener); + } else { + LOG.warn("Unexpected error creating result index", exception); + listener.onFailure(exception); + } + })); + } else { + bulk(currentBulkRequest, listener); + } + } catch (Exception e) { + LOG.warn("Error in bulking results", e); + listener.onFailure(e); + } + } + + public abstract void bulk(BatchRequestType currentBulkRequest, ActionListener listener); +} diff --git a/src/main/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandler.java b/src/main/java/org/opensearch/timeseries/transport/handler/ResultBulkIndexingHandler.java similarity index 52% rename from src/main/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandler.java rename to src/main/java/org/opensearch/timeseries/transport/handler/ResultBulkIndexingHandler.java index d61fd1794..4cec1c127 100644 --- a/src/main/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandler.java +++ b/src/main/java/org/opensearch/timeseries/transport/handler/ResultBulkIndexingHandler.java @@ -9,11 +9,9 @@ * GitHub history for details. */ -package org.opensearch.ad.transport.handler; +package org.opensearch.timeseries.transport.handler; -import static org.opensearch.ad.constant.ADCommonName.ANOMALY_RESULT_INDEX_ALIAS; import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; -import static org.opensearch.timeseries.constant.CommonMessages.CAN_NOT_FIND_RESULT_INDEX; import java.util.List; @@ -24,107 +22,130 @@ import org.opensearch.action.bulk.BulkRequestBuilder; import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.index.IndexRequest; -import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.ad.util.IndexUtils; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.IndexableResult; import org.opensearch.timeseries.util.ClientUtil; +import org.opensearch.timeseries.util.IndexUtils; import org.opensearch.timeseries.util.RestHandlerUtils; -public class AnomalyResultBulkIndexHandler extends AnomalyIndexHandler { - private static final Logger LOG = LogManager.getLogger(AnomalyResultBulkIndexHandler.class); +/** + * + * Utility method to bulk index results + * + */ +public class ResultBulkIndexingHandler & TimeSeriesIndex, IndexManagementType extends IndexManagement> + extends ResultIndexingHandler { - private ADIndexManagement anomalyDetectionIndices; + private static final Logger LOG = LogManager.getLogger(ResultBulkIndexingHandler.class); - public AnomalyResultBulkIndexHandler( + public ResultBulkIndexingHandler( Client client, Settings settings, ThreadPool threadPool, + String indexName, + IndexManagementType timeSeriesIndices, ClientUtil clientUtil, IndexUtils indexUtils, ClusterService clusterService, - ADIndexManagement anomalyDetectionIndices + Setting backOffDelaySetting, + Setting maxRetrySetting ) { - super(client, settings, threadPool, ANOMALY_RESULT_INDEX_ALIAS, anomalyDetectionIndices, clientUtil, indexUtils, clusterService); - this.anomalyDetectionIndices = anomalyDetectionIndices; + super( + client, + settings, + threadPool, + indexName, + timeSeriesIndices, + clientUtil, + indexUtils, + clusterService, + backOffDelaySetting, + maxRetrySetting + ); } /** - * Bulk index anomaly results. Create anomaly result index first if it doesn't exist. + * Bulk index results. Create result index first if it doesn't exist. * - * @param resultIndex anomaly result index - * @param anomalyResults anomaly results + * @param resultIndex result index + * @param results results to save + * @param configId Config Id * @param listener action listener */ - public void bulkIndexAnomalyResult(String resultIndex, List anomalyResults, ActionListener listener) { - if (anomalyResults == null || anomalyResults.size() == 0) { + public void bulk(String resultIndex, List results, String configId, ActionListener listener) { + if (results == null || results.size() == 0) { listener.onResponse(null); return; } - String detectorId = anomalyResults.get(0).getConfigId(); + try { if (resultIndex != null) { - // Only create custom AD result index when create detector, won’t recreate custom AD result index in realtime + // Only create custom result index when creating detector, won’t recreate custom AD result index in realtime // job and historical analysis later if it’s deleted. If user delete the custom AD result index, and AD plugin // recreate it, that may bring confusion. - if (!anomalyDetectionIndices.doesIndexExist(resultIndex)) { - throw new EndRunException(detectorId, CAN_NOT_FIND_RESULT_INDEX + resultIndex, true); + if (!timeSeriesIndices.doesIndexExist(resultIndex)) { + throw new EndRunException(configId, CommonMessages.CAN_NOT_FIND_RESULT_INDEX + resultIndex, true); } - if (!anomalyDetectionIndices.isValidResultIndexMapping(resultIndex)) { - throw new EndRunException(detectorId, "wrong index mapping of custom AD result index", true); + if (!timeSeriesIndices.isValidResultIndexMapping(resultIndex)) { + throw new EndRunException(configId, "wrong index mapping of custom result index", true); } - bulkSaveDetectorResult(resultIndex, anomalyResults, listener); + bulk(resultIndex, results, listener); return; } - if (!anomalyDetectionIndices.doesDefaultResultIndexExist()) { - anomalyDetectionIndices.initDefaultResultIndexDirectly(ActionListener.wrap(response -> { + if (!timeSeriesIndices.doesDefaultResultIndexExist()) { + timeSeriesIndices.initDefaultResultIndexDirectly(ActionListener.wrap(response -> { if (response.isAcknowledged()) { - bulkSaveDetectorResult(anomalyResults, listener); + bulk(results, listener); } else { - String error = "Creating anomaly result index with mappings call not acknowledged"; + String error = "Creating result index with mappings call not acknowledged"; LOG.error(error); listener.onFailure(new TimeSeriesException(error)); } }, exception -> { if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { // It is possible the index has been created while we sending the create request - bulkSaveDetectorResult(anomalyResults, listener); + bulk(results, listener); } else { listener.onFailure(exception); } })); } else { - bulkSaveDetectorResult(anomalyResults, listener); + bulk(results, listener); } } catch (TimeSeriesException e) { listener.onFailure(e); } catch (Exception e) { - String error = "Failed to bulk index anomaly result"; + String error = "Failed to bulk index result"; LOG.error(error, e); listener.onFailure(new TimeSeriesException(error, e)); } } - private void bulkSaveDetectorResult(List anomalyResults, ActionListener listener) { - bulkSaveDetectorResult(ANOMALY_RESULT_INDEX_ALIAS, anomalyResults, listener); + private void bulk(List anomalyResults, ActionListener listener) { + bulk(defaultResultIndexName, anomalyResults, listener); } - private void bulkSaveDetectorResult(String resultIndex, List anomalyResults, ActionListener listener) { + private void bulk(String resultIndex, List results, ActionListener listener) { BulkRequestBuilder bulkRequestBuilder = client.prepareBulk(); - anomalyResults.forEach(anomalyResult -> { + results.forEach(analysisResult -> { try (XContentBuilder builder = jsonBuilder()) { IndexRequest indexRequest = new IndexRequest(resultIndex) - .source(anomalyResult.toXContent(builder, RestHandlerUtils.XCONTENT_WITH_TYPE)); + .source(analysisResult.toXContent(builder, RestHandlerUtils.XCONTENT_WITH_TYPE)); bulkRequestBuilder.add(indexRequest); } catch (Exception e) { - String error = "Failed to prepare request to bulk index anomaly results"; + String error = "Failed to prepare request to bulk index results"; LOG.error(error, e); throw new TimeSeriesException(error); } @@ -132,16 +153,15 @@ private void bulkSaveDetectorResult(String resultIndex, List anom client.bulk(bulkRequestBuilder.request(), ActionListener.wrap(r -> { if (r.hasFailures()) { String failureMessage = r.buildFailureMessage(); - LOG.warn("Failed to bulk index AD result " + failureMessage); + LOG.warn("Failed to bulk index result " + failureMessage); listener.onFailure(new TimeSeriesException(failureMessage)); } else { listener.onResponse(r); } }, e -> { - LOG.error("bulk index ad result failed", e); + LOG.error("bulk index result failed", e); listener.onFailure(e); })); } - } diff --git a/src/main/java/org/opensearch/ad/transport/handler/AnomalyIndexHandler.java b/src/main/java/org/opensearch/timeseries/transport/handler/ResultIndexingHandler.java similarity index 74% rename from src/main/java/org/opensearch/ad/transport/handler/AnomalyIndexHandler.java rename to src/main/java/org/opensearch/timeseries/transport/handler/ResultIndexingHandler.java index 9d539f797..01c482903 100644 --- a/src/main/java/org/opensearch/ad/transport/handler/AnomalyIndexHandler.java +++ b/src/main/java/org/opensearch/timeseries/transport/handler/ResultIndexingHandler.java @@ -9,10 +9,9 @@ * GitHub history for details. */ -package org.opensearch.ad.transport.handler; +package org.opensearch.timeseries.transport.handler; import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; -import static org.opensearch.timeseries.constant.CommonMessages.CAN_NOT_FIND_RESULT_INDEX; import java.util.Iterator; import java.util.Locale; @@ -25,38 +24,40 @@ import org.opensearch.action.bulk.BackoffPolicy; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; -import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.util.BulkUtil; -import org.opensearch.ad.util.IndexUtils; import org.opensearch.client.Client; import org.opensearch.cluster.block.ClusterBlockLevel; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; -import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.util.BulkUtil; import org.opensearch.timeseries.util.ClientUtil; +import org.opensearch.timeseries.util.IndexUtils; import org.opensearch.timeseries.util.RestHandlerUtils; -public class AnomalyIndexHandler { - private static final Logger LOG = LogManager.getLogger(AnomalyIndexHandler.class); - static final String FAIL_TO_SAVE_ERR_MSG = "Fail to save %s: "; - static final String SUCCESS_SAVING_MSG = "Succeed in saving %s"; - static final String CANNOT_SAVE_ERR_MSG = "Cannot save %s due to write block."; - static final String RETRY_SAVING_ERR_MSG = "Retry in saving %s: "; +public class ResultIndexingHandler & TimeSeriesIndex, IndexManagementType extends IndexManagement> { + private static final Logger LOG = LogManager.getLogger(ResultIndexingHandler.class); + public static final String FAIL_TO_SAVE_ERR_MSG = "Fail to save %s: "; + public static final String SUCCESS_SAVING_MSG = "Succeed in saving %s"; + public static final String CANNOT_SAVE_ERR_MSG = "Cannot save %s due to write block."; + public static final String RETRY_SAVING_ERR_MSG = "Retry in saving %s: "; protected final Client client; protected final ThreadPool threadPool; protected final BackoffPolicy savingBackoffPolicy; - protected final String indexName; - protected final ADIndexManagement anomalyDetectionIndices; + protected final String defaultResultIndexName; + protected final IndexManagementType timeSeriesIndices; // whether save to a specific doc id or not. False by default. protected boolean fixedDoc; protected final ClientUtil clientUtil; @@ -70,30 +71,28 @@ public class AnomalyIndexHandler { * @param settings accessor for node settings. * @param threadPool used to invoke specific threadpool to execute * @param indexName name of index to save to - * @param anomalyDetectionIndices anomaly detection indices + * @param timeSeriesIndices anomaly detection indices * @param clientUtil client wrapper * @param indexUtils Index util classes * @param clusterService accessor to ES cluster service */ - public AnomalyIndexHandler( + public ResultIndexingHandler( Client client, Settings settings, ThreadPool threadPool, String indexName, - ADIndexManagement anomalyDetectionIndices, + IndexManagementType timeSeriesIndices, ClientUtil clientUtil, IndexUtils indexUtils, - ClusterService clusterService + ClusterService clusterService, + Setting backOffDelaySetting, + Setting maxRetrySetting ) { this.client = client; this.threadPool = threadPool; - this.savingBackoffPolicy = BackoffPolicy - .exponentialBackoff( - AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY.get(settings), - AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF.get(settings) - ); - this.indexName = indexName; - this.anomalyDetectionIndices = anomalyDetectionIndices; + this.savingBackoffPolicy = BackoffPolicy.exponentialBackoff(backOffDelaySetting.get(settings), maxRetrySetting.get(settings)); + this.defaultResultIndexName = indexName; + this.timeSeriesIndices = timeSeriesIndices; this.fixedDoc = false; this.clientUtil = clientUtil; this.indexUtils = indexUtils; @@ -111,8 +110,8 @@ public void setFixedDoc(boolean fixedDoc) { } // TODO: check if user has permission to index. - public void index(T toSave, String detectorId, String customIndexName) { - if (indexUtils.checkIndicesBlocked(clusterService.state(), ClusterBlockLevel.WRITE, this.indexName)) { + public void index(ResultType toSave, String detectorId, String customIndexName) { + if (indexUtils.checkIndicesBlocked(clusterService.state(), ClusterBlockLevel.WRITE, this.defaultResultIndexName)) { LOG.warn(String.format(Locale.ROOT, CANNOT_SAVE_ERR_MSG, detectorId)); return; } @@ -122,17 +121,17 @@ public void index(T toSave, String detectorId, String customIndexName) { // Only create custom AD result index when create detector, won’t recreate custom AD result index in realtime // job and historical analysis later if it’s deleted. If user delete the custom AD result index, and AD plugin // recreate it, that may bring confusion. - if (!anomalyDetectionIndices.doesIndexExist(customIndexName)) { - throw new EndRunException(detectorId, CAN_NOT_FIND_RESULT_INDEX + customIndexName, true); + if (!timeSeriesIndices.doesIndexExist(customIndexName)) { + throw new EndRunException(detectorId, CommonMessages.CAN_NOT_FIND_RESULT_INDEX + customIndexName, true); } - if (!anomalyDetectionIndices.isValidResultIndexMapping(customIndexName)) { + if (!timeSeriesIndices.isValidResultIndexMapping(customIndexName)) { throw new EndRunException(detectorId, "wrong index mapping of custom AD result index", true); } save(toSave, detectorId, customIndexName); return; } - if (!anomalyDetectionIndices.doesDefaultResultIndexExist()) { - anomalyDetectionIndices + if (!timeSeriesIndices.doesDefaultResultIndexExist()) { + timeSeriesIndices .initDefaultResultIndexDirectly( ActionListener.wrap(initResponse -> onCreateIndexResponse(initResponse, toSave, detectorId), exception -> { if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { @@ -141,7 +140,7 @@ public void index(T toSave, String detectorId, String customIndexName) { } else { throw new TimeSeriesException( detectorId, - String.format(Locale.ROOT, "Unexpected error creating index %s", indexName), + String.format(Locale.ROOT, "Unexpected error creating index %s", defaultResultIndexName), exception ); } @@ -153,32 +152,32 @@ public void index(T toSave, String detectorId, String customIndexName) { } catch (Exception e) { throw new TimeSeriesException( detectorId, - String.format(Locale.ROOT, "Error in saving %s for detector %s", indexName, detectorId), + String.format(Locale.ROOT, "Error in saving %s for detector %s", defaultResultIndexName, detectorId), e ); } } - private void onCreateIndexResponse(CreateIndexResponse response, T toSave, String detectorId) { + private void onCreateIndexResponse(CreateIndexResponse response, ResultType toSave, String detectorId) { if (response.isAcknowledged()) { save(toSave, detectorId); } else { throw new TimeSeriesException( detectorId, - String.format(Locale.ROOT, "Creating %s with mappings call not acknowledged.", indexName) + String.format(Locale.ROOT, "Creating %s with mappings call not acknowledged.", defaultResultIndexName) ); } } - protected void save(T toSave, String detectorId) { - save(toSave, detectorId, indexName); + protected void save(ResultType toSave, String detectorId) { + save(toSave, detectorId, defaultResultIndexName); } // TODO: Upgrade custom result index mapping to latest version? // It may bring some issue if we upgrade the custom result index mapping while user is using that index // for other use cases. One easy solution is to tell user only use custom result index for AD plugin. // For the first release of custom result index, it's not a issue. Will leave this to next phase. - protected void save(T toSave, String detectorId, String indexName) { + protected void save(ResultType toSave, String detectorId, String indexName) { try (XContentBuilder builder = jsonBuilder()) { IndexRequest indexRequest = new IndexRequest(indexName).source(toSave.toXContent(builder, RestHandlerUtils.XCONTENT_WITH_TYPE)); if (fixedDoc) { @@ -192,9 +191,9 @@ protected void save(T toSave, String detectorId, String indexName) { } } - void saveIteration(IndexRequest indexRequest, String detectorId, Iterator backoff) { + void saveIteration(IndexRequest indexRequest, String configId, Iterator backoff) { clientUtil.asyncRequest(indexRequest, client::index, ActionListener.wrap(response -> { - LOG.debug(String.format(Locale.ROOT, SUCCESS_SAVING_MSG, detectorId)); + LOG.debug(String.format(Locale.ROOT, SUCCESS_SAVING_MSG, configId)); }, exception -> { // OpenSearch has a thread pool and a queue for write per node. A thread // pool will have N number of workers ready to handle the requests. When a @@ -210,13 +209,13 @@ void saveIteration(IndexRequest indexRequest, String detectorId, Iterator saveIteration(BulkUtil.cloneIndexRequest(indexRequest), detectorId, backoff), + () -> saveIteration(BulkUtil.cloneIndexRequest(indexRequest), configId, backoff), nextDelay, ThreadPool.Names.SAME ); diff --git a/src/main/java/org/opensearch/timeseries/transport/handler/SearchHandler.java b/src/main/java/org/opensearch/timeseries/transport/handler/SearchHandler.java new file mode 100644 index 000000000..e4c9a893e --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/handler/SearchHandler.java @@ -0,0 +1,82 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.transport.handler; + +import static org.opensearch.timeseries.util.ParseUtils.isAdmin; +import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.util.ParseUtils; + +/** + * Handle general search request, check user role and return search response. + */ +public class SearchHandler { + private final Logger logger = LogManager.getLogger(SearchHandler.class); + private final Client client; + private volatile Boolean filterEnabled; + + public SearchHandler(Settings settings, ClusterService clusterService, Client client, Setting filterByBackendRoleSetting) { + this.client = client; + filterEnabled = filterByBackendRoleSetting.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(filterByBackendRoleSetting, it -> filterEnabled = it); + } + + /** + * Validate user role, add backend role filter if filter enabled + * and execute search. + * + * @param request search request + * @param actionListener action listerner + */ + public void search(SearchRequest request, ActionListener actionListener) { + User user = ParseUtils.getUserContext(client); + ActionListener listener = wrapRestActionListener(actionListener, CommonMessages.FAIL_TO_SEARCH); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + validateRole(request, user, listener); + } catch (Exception e) { + logger.error(e); + listener.onFailure(e); + } + } + + private void validateRole(SearchRequest request, User user, ActionListener listener) { + if (user == null || !filterEnabled || isAdmin(user)) { + // Case 1: user == null when 1. Security is disabled. 2. When user is super-admin + // Case 2: If Security is enabled and filter is disabled, proceed with search as + // user is already authenticated to hit this API. + // case 3: user is admin which means we don't have to check backend role filtering + client.search(request, listener); + } else { + // Security is enabled, filter is enabled and user isn't admin + try { + ParseUtils.addUserBackendRolesFilter(user, request.source()); + logger.debug("Filtering result by " + user.getBackendRoles()); + client.search(request, listener); + } catch (Exception e) { + listener.onFailure(e); + } + } + } + +} diff --git a/src/main/java/org/opensearch/ad/util/BulkUtil.java b/src/main/java/org/opensearch/timeseries/util/BulkUtil.java similarity index 96% rename from src/main/java/org/opensearch/ad/util/BulkUtil.java rename to src/main/java/org/opensearch/timeseries/util/BulkUtil.java index b754b1951..c2b275a1f 100644 --- a/src/main/java/org/opensearch/ad/util/BulkUtil.java +++ b/src/main/java/org/opensearch/timeseries/util/BulkUtil.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.util; +package org.opensearch.timeseries.util; import java.util.ArrayList; import java.util.HashSet; @@ -23,7 +23,6 @@ import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.index.IndexRequest; -import org.opensearch.timeseries.util.ExceptionUtil; public class BulkUtil { private static final Logger logger = LogManager.getLogger(BulkUtil.class); diff --git a/src/main/java/org/opensearch/ad/util/DateUtils.java b/src/main/java/org/opensearch/timeseries/util/DateUtils.java similarity index 96% rename from src/main/java/org/opensearch/ad/util/DateUtils.java rename to src/main/java/org/opensearch/timeseries/util/DateUtils.java index e7cfc21ce..a76fc5bcb 100644 --- a/src/main/java/org/opensearch/ad/util/DateUtils.java +++ b/src/main/java/org/opensearch/timeseries/util/DateUtils.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.util; +package org.opensearch.timeseries.util; import java.time.Duration; import java.time.Instant; diff --git a/src/main/java/org/opensearch/timeseries/util/DiscoveryNodeFilterer.java b/src/main/java/org/opensearch/timeseries/util/DiscoveryNodeFilterer.java index ca3ba4eba..80ffd5c9f 100644 --- a/src/main/java/org/opensearch/timeseries/util/DiscoveryNodeFilterer.java +++ b/src/main/java/org/opensearch/timeseries/util/DiscoveryNodeFilterer.java @@ -17,10 +17,10 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.timeseries.constant.CommonName; /** * Util class to filter unwanted node types @@ -91,8 +91,8 @@ public boolean test(DiscoveryNode discoveryNode) { return discoveryNode.isDataNode() && discoveryNode .getAttributes() - .getOrDefault(ADCommonName.BOX_TYPE_KEY, ADCommonName.HOT_BOX_TYPE) - .equals(ADCommonName.HOT_BOX_TYPE); + .getOrDefault(CommonName.BOX_TYPE_KEY, CommonName.HOT_BOX_TYPE) + .equals(CommonName.HOT_BOX_TYPE); } } } diff --git a/src/main/java/org/opensearch/ad/util/IndexUtils.java b/src/main/java/org/opensearch/timeseries/util/IndexUtils.java similarity index 89% rename from src/main/java/org/opensearch/ad/util/IndexUtils.java rename to src/main/java/org/opensearch/timeseries/util/IndexUtils.java index c93511849..cf845dc88 100644 --- a/src/main/java/org/opensearch/ad/util/IndexUtils.java +++ b/src/main/java/org/opensearch/timeseries/util/IndexUtils.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.util; +package org.opensearch.timeseries.util; import java.util.List; import java.util.Locale; @@ -17,7 +17,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.action.support.IndicesOptions; -import org.opensearch.client.Client; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.block.ClusterBlockLevel; import org.opensearch.cluster.health.ClusterIndexHealth; @@ -25,7 +24,6 @@ import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; -import org.opensearch.timeseries.util.ClientUtil; public class IndexUtils { /** @@ -41,28 +39,17 @@ public class IndexUtils { private static final Logger logger = LogManager.getLogger(IndexUtils.class); - private Client client; - private ClientUtil clientUtil; private ClusterService clusterService; private final IndexNameExpressionResolver indexNameExpressionResolver; /** * Inject annotation required by Guice to instantiate EntityResultTransportAction (transitive dependency) * - * @param client Client to make calls to OpenSearch - * @param clientUtil AD Client utility * @param clusterService ES ClusterService * @param indexNameExpressionResolver index name resolver */ @Inject - public IndexUtils( - Client client, - ClientUtil clientUtil, - ClusterService clusterService, - IndexNameExpressionResolver indexNameExpressionResolver - ) { - this.client = client; - this.clientUtil = clientUtil; + public IndexUtils(ClusterService clusterService, IndexNameExpressionResolver indexNameExpressionResolver) { this.clusterService = clusterService; this.indexNameExpressionResolver = indexNameExpressionResolver; } diff --git a/src/main/java/org/opensearch/timeseries/util/MultiResponsesDelegateActionListener.java b/src/main/java/org/opensearch/timeseries/util/MultiResponsesDelegateActionListener.java index 5d0998d27..7dd830435 100644 --- a/src/main/java/org/opensearch/timeseries/util/MultiResponsesDelegateActionListener.java +++ b/src/main/java/org/opensearch/timeseries/util/MultiResponsesDelegateActionListener.java @@ -19,8 +19,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.ad.model.Mergeable; import org.opensearch.core.action.ActionListener; +import org.opensearch.timeseries.model.Mergeable; /** * A listener wrapper to help send multiple requests asynchronously and return one final responses together diff --git a/src/main/java/org/opensearch/timeseries/util/ParseUtils.java b/src/main/java/org/opensearch/timeseries/util/ParseUtils.java index 0978a0de5..5be698e9b 100644 --- a/src/main/java/org/opensearch/timeseries/util/ParseUtils.java +++ b/src/main/java/org/opensearch/timeseries/util/ParseUtils.java @@ -11,7 +11,6 @@ package org.opensearch.timeseries.util; -import static org.opensearch.ad.constant.ADCommonName.EPOCH_MILLIS_FORMAT; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.search.aggregations.AggregationBuilders.dateRange; import static org.opensearch.search.aggregations.AggregatorFactories.VALID_AGG_NAME; @@ -24,6 +23,7 @@ import java.util.Collection; import java.util.HashSet; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Map.Entry; import java.util.Objects; @@ -35,11 +35,10 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.search.join.ScoreMode; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; import org.opensearch.action.search.SearchResponse; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.transport.GetAnomalyDetectorResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.xcontent.LoggingDeprecationHandler; @@ -47,7 +46,9 @@ import org.opensearch.commons.ConfigConstants; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.ParsingException; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; @@ -69,7 +70,6 @@ import org.opensearch.search.aggregations.bucket.range.DateRangeAggregationBuilder; import org.opensearch.search.aggregations.metrics.Max; import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.timeseries.common.exception.ResourceNotFoundException; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; @@ -302,23 +302,23 @@ public static AggregatorFactories.Builder parseAggregators(XContentParser parser } public static SearchSourceBuilder generateInternalFeatureQuery( - AnomalyDetector detector, + Config config, long startTime, long endTime, NamedXContentRegistry xContentRegistry ) throws IOException { - RangeQueryBuilder rangeQuery = new RangeQueryBuilder(detector.getTimeField()) + RangeQueryBuilder rangeQuery = new RangeQueryBuilder(config.getTimeField()) .from(startTime) .to(endTime) .format("epoch_millis") .includeLower(true) .includeUpper(false); - BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().must(rangeQuery).must(detector.getFilterQuery()); + BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().must(rangeQuery).must(config.getFilterQuery()); SearchSourceBuilder internalSearchSourceBuilder = new SearchSourceBuilder().query(internalFilterQuery); - if (detector.getFeatureAttributes() != null) { - for (Feature feature : detector.getFeatureAttributes()) { + if (config.getFeatureAttributes() != null) { + for (Feature feature : config.getFeatureAttributes()) { AggregatorFactories.Builder internalAgg = parseAggregators( feature.getAggregation().toString(), xContentRegistry, @@ -366,7 +366,7 @@ public static SearchSourceBuilder generateColdStartQuery( BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().filter(config.getFilterQuery()); if (entity.isPresent()) { - for (TermQueryBuilder term : entity.get().getTermQueryBuilders()) { + for (TermQueryBuilder term : entity.get().getTermQueryForCustomerIndex()) { internalFilterQuery.filter(term); } } @@ -393,12 +393,12 @@ public static SearchSourceBuilder generateColdStartQuery( /** * Map feature data to its Id and name * @param currentFeature Feature data - * @param detector Detector Config object + * @param config Config object * @return a list of feature data with Id and name */ - public static List getFeatureData(double[] currentFeature, AnomalyDetector detector) { - List featureIds = detector.getEnabledFeatureIds(); - List featureNames = detector.getEnabledFeatureNames(); + public static List getFeatureData(double[] currentFeature, Config config) { + List featureIds = config.getEnabledFeatureIds(); + List featureNames = config.getEnabledFeatureNames(); int featureLen = featureIds.size(); List featureData = new ArrayList<>(); for (int i = 0; i < featureLen; i++) { @@ -425,6 +425,7 @@ public static SearchSourceBuilder addUserBackendRolesFilter(User user, SearchSou } else if (query instanceof BoolQueryBuilder) { ((BoolQueryBuilder) query).filter(boolQueryBuilder); } else { + // e.g., wild card query throw new TimeSeriesException("Search API does not support queries other than BoolQuery"); } return searchSourceBuilder; @@ -444,7 +445,20 @@ public static User getUserContext(Client client) { return User.parse(userStr); } - public static void resolveUserAndExecute( + /** + * run the given function based on given user + * @param Config response type. Can be either GetAnomalyDetectorResponse or GetForecasterResponse + * @param requestedUser requested user + * @param configId config Id + * @param filterByEnabled filter by backend is enabled + * @param listener listener. We didn't provide the generic type of listener and therefore can return anything using the listener. + * @param function Function to execute + * @param client Client to OS. + * @param clusterService Cluster service of OS. + * @param xContentRegistry Used to deserialize the get config response. + * @param configTypeClass the class of the ConfigType, used by the ConfigFactory to parse the correct type of Config + */ + public static void resolveUserAndExecute( User requestedUser, String configId, boolean filterByEnabled, @@ -491,10 +505,10 @@ public static void resolveUserAndExecute( * @param filterByBackendRole filter by backend role or not * @param configTypeClass the class of the ConfigType, used by the ConfigFactory to parse the correct type of Config */ - public static void getConfig( + public static void getConfig( User requestUser, String configId, - ActionListener listener, + ActionListener listener, Consumer function, Client client, ClusterService clusterService, @@ -520,7 +534,7 @@ public static void getConfig( configTypeClass ), exception -> { - logger.error("Failed to get anomaly detector: " + configId, exception); + logger.error("Failed to get config: " + configId, exception); listener.onFailure(exception); } ) @@ -542,6 +556,7 @@ public static void getConfig( * provided the user holds the requisite permissions. * * @param The type of Config to be processed in this method, which extends from the Config base type. + * @param The type of ActionResponse to be used, which extends from the ActionResponse base type. * @param response The GetResponse from the getConfig request. This contains the information about the config that is to be processed. * @param requestUser The User from the request. This user's permissions will be checked to ensure they have access to the config. * @param configId The ID of the config. This is used for logging and error messages. @@ -551,11 +566,11 @@ public static void getConfig( * @param filterByBackendRole A boolean indicating whether to filter by backend role. If true, the user's backend roles will be checked to ensure they have access to the config. * @param configTypeClass The class of the ConfigType, used by the ConfigFactory to parse the correct type of Config. */ - public static void onGetConfigResponse( + public static void onGetConfigResponse( GetResponse response, User requestUser, String configId, - ActionListener listener, + ActionListener listener, Consumer function, NamedXContentRegistry xContentRegistry, boolean filterByBackendRole, @@ -574,13 +589,17 @@ public static void onGetConfigResponse( function.accept(config); } else { logger.debug("User: " + requestUser.getName() + " does not have permissions to access config: " + configId); - listener.onFailure(new TimeSeriesException(CommonMessages.NO_PERMISSION_TO_ACCESS_CONFIG + configId)); + listener + .onFailure( + new OpenSearchStatusException(CommonMessages.NO_PERMISSION_TO_ACCESS_CONFIG + configId, RestStatus.FORBIDDEN) + ); } } catch (Exception e) { - listener.onFailure(new TimeSeriesException(CommonMessages.FAIL_TO_GET_USER_INFO + configId)); + logger.error("Fail to parse user out of config", e); + listener.onFailure(new OpenSearchStatusException(CommonMessages.FAIL_TO_GET_USER_INFO + configId, RestStatus.BAD_REQUEST)); } } else { - listener.onFailure(new ResourceNotFoundException(configId, FAIL_TO_FIND_CONFIG_MSG + configId)); + listener.onFailure(new OpenSearchStatusException(FAIL_TO_FIND_CONFIG_MSG + configId, RestStatus.NOT_FOUND)); } } @@ -596,7 +615,7 @@ public static boolean isAdmin(User user) { return user.getRoles().contains("all_access"); } - private static boolean checkUserPermissions(User requestedUser, User resourceUser, String detectorId) throws Exception { + private static boolean checkUserPermissions(User requestedUser, User resourceUser, String configId) throws Exception { if (resourceUser.getBackendRoles() == null || requestedUser.getBackendRoles() == null) { return false; } @@ -609,8 +628,8 @@ private static boolean checkUserPermissions(User requestedUser, User resourceUse + requestedUser.getName() + " has backend role: " + backendRole - + " permissions to access detector: " - + detectorId + + " permissions to access config: " + + configId ); return true; } @@ -618,20 +637,19 @@ private static boolean checkUserPermissions(User requestedUser, User resourceUse return false; } - public static boolean checkFilterByBackendRoles(User requestedUser, ActionListener listener) { + public static String checkFilterByBackendRoles(User requestedUser) { if (requestedUser == null) { - return false; + return "Filter by backend roles is enabled and User is null"; } if (requestedUser.getBackendRoles().isEmpty()) { - listener - .onFailure( - new TimeSeriesException( - "Filter by backend roles is enabled and User " + requestedUser.getName() + " does not have backend roles configured" - ) + return String + .format( + Locale.ROOT, + "Filter by backend roles is enabled and User %s does not have backend roles configured", + requestedUser.getName() ); - return false; } - return true; + return null; } /** @@ -651,7 +669,7 @@ public static Optional getLatestDataTime(SearchResponse searchResponse) { /** * Generate batch query request for feature aggregation on given date range. * - * @param detector anomaly detector + * @param config config accessor * @param entity entity * @param startTime start time * @param endTime end time @@ -661,46 +679,46 @@ public static Optional getLatestDataTime(SearchResponse searchResponse) { * @throws TimeSeriesException throw AD exception if no enabled feature */ public static SearchSourceBuilder batchFeatureQuery( - AnomalyDetector detector, + Config config, Entity entity, long startTime, long endTime, NamedXContentRegistry xContentRegistry ) throws IOException { - RangeQueryBuilder rangeQuery = new RangeQueryBuilder(detector.getTimeField()) + RangeQueryBuilder rangeQuery = new RangeQueryBuilder(config.getTimeField()) .from(startTime) .to(endTime) - .format(EPOCH_MILLIS_FORMAT) + .format(CommonName.EPOCH_MILLIS_FORMAT) .includeLower(true) .includeUpper(false); - BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().must(rangeQuery).must(detector.getFilterQuery()); + BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().must(rangeQuery).must(config.getFilterQuery()); - if (detector.isHighCardinality() && entity != null && entity.getAttributes().size() > 0) { + if (config.isHighCardinality() && entity != null && entity.getAttributes().size() > 0) { entity .getAttributes() .entrySet() .forEach(attr -> { internalFilterQuery.filter(new TermQueryBuilder(attr.getKey(), attr.getValue())); }); } - long intervalSeconds = ((IntervalTimeConfiguration) detector.getInterval()).toDuration().getSeconds(); + long intervalSeconds = ((IntervalTimeConfiguration) config.getInterval()).toDuration().getSeconds(); List> sources = new ArrayList<>(); sources .add( new DateHistogramValuesSourceBuilder(CommonName.DATE_HISTOGRAM) - .field(detector.getTimeField()) + .field(config.getTimeField()) .fixedInterval(DateHistogramInterval.seconds((int) intervalSeconds)) ); CompositeAggregationBuilder aggregationBuilder = new CompositeAggregationBuilder(CommonName.FEATURE_AGGS, sources) .size(MAX_BATCH_TASK_PIECE_SIZE); - if (detector.getEnabledFeatureIds().size() == 0) { + if (config.getEnabledFeatureIds().size() == 0) { throw new TimeSeriesException("No enabled feature configured").countedInStats(false); } - for (Feature feature : detector.getFeatureAttributes()) { + for (Feature feature : config.getFeatureAttributes()) { if (feature.getEnabled()) { AggregatorFactories.Builder internalAgg = parseAggregators( feature.getAggregation().toString(), @@ -776,9 +794,9 @@ public static List parseAggregationRequest(XContentParser parser) throws return fieldNames; } - public static List getFeatureFieldNames(AnomalyDetector detector, NamedXContentRegistry xContentRegistry) throws IOException { + public static List getFeatureFieldNames(Config config, NamedXContentRegistry xContentRegistry) throws IOException { List featureFields = new ArrayList<>(); - for (Feature feature : detector.getFeatureAttributes()) { + for (Feature feature : config.getFeatureAttributes()) { featureFields.add(getFieldNamesForFeature(feature, xContentRegistry).get(0)); } return featureFields; diff --git a/src/main/java/org/opensearch/timeseries/util/QueryUtil.java b/src/main/java/org/opensearch/timeseries/util/QueryUtil.java new file mode 100644 index 000000000..e98a5d248 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/util/QueryUtil.java @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.util; + +import java.util.Collections; + +import org.opensearch.script.Script; +import org.opensearch.script.ScriptType; + +import com.google.common.collect.ImmutableMap; + +public class QueryUtil { + /** + * Generates the painless script to fetch results that have an entity name matching the passed-in category field. + * + * @param categoryField the category field to be used as a source + * @return the painless script used to get all docs with entity name values matching the category field + */ + public static Script getScriptForCategoryField(String categoryField) { + StringBuilder builder = new StringBuilder() + .append("String value = null;") + .append("if (params == null || params._source == null || params._source.entity == null) {") + .append("return \"\"") + .append("}") + .append("for (item in params._source.entity) {") + .append("if (item[\"name\"] == params[\"categoryField\"]) {") + .append("value = item['value'];") + .append("break;") + .append("}") + .append("}") + .append("return value;"); + + // The last argument contains the K/V pair to inject the categoryField value into the script + return new Script( + ScriptType.INLINE, + "painless", + builder.toString(), + Collections.emptyMap(), + ImmutableMap.of("categoryField", categoryField) + ); + } +} diff --git a/src/main/java/org/opensearch/timeseries/util/RestHandlerUtils.java b/src/main/java/org/opensearch/timeseries/util/RestHandlerUtils.java index 45e318aa2..47ba48dba 100644 --- a/src/main/java/org/opensearch/timeseries/util/RestHandlerUtils.java +++ b/src/main/java/org/opensearch/timeseries/util/RestHandlerUtils.java @@ -17,6 +17,7 @@ import java.io.IOException; import java.util.HashSet; import java.util.List; +import java.util.Optional; import java.util.Set; import org.apache.commons.lang.ArrayUtils; @@ -45,7 +46,9 @@ import org.opensearch.timeseries.common.exception.ResourceNotFoundException; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.Feature; import com.google.common.base.Throwables; @@ -63,11 +66,8 @@ public final class RestHandlerUtils { public static final String _PRIMARY_TERM = "_primary_term"; public static final String IF_PRIMARY_TERM = "if_primary_term"; public static final String REFRESH = "refresh"; - public static final String DETECTOR_ID = "detectorID"; public static final String RESULT_INDEX = "resultIndex"; - public static final String ANOMALY_DETECTOR = "anomaly_detector"; - public static final String ANOMALY_DETECTOR_JOB = "anomaly_detector_job"; - public static final String REALTIME_TASK = "realtime_detection_task"; + public static final String REALTIME_TASK = "realtime_task"; public static final String HISTORICAL_ANALYSIS_TASK = "historical_analysis_task"; public static final String RUN = "_run"; public static final String PREVIEW = "_preview"; @@ -79,16 +79,31 @@ public final class RestHandlerUtils { public static final String COUNT = "count"; public static final String MATCH = "match"; public static final String RESULTS = "results"; - public static final String TOP_ANOMALIES = "_topAnomalies"; public static final String VALIDATE = "_validate"; + public static final String SEARCH = "_search"; public static final ToXContent.MapParams XCONTENT_WITH_TYPE = new ToXContent.MapParams(ImmutableMap.of("with_type", "true")); + public static final String REST_STATUS = "rest_status"; + public static final String RUN_ONCE = "_run_once"; + public static final String SUGGEST = "_suggest"; + public static final String RUN_ONCE_TASK = "run_once_task"; public static final String OPENSEARCH_DASHBOARDS_USER_AGENT = "OpenSearch Dashboards"; public static final String[] UI_METADATA_EXCLUDE = new String[] { Config.UI_METADATA_FIELD }; + public static final String NODE_ID = "nodeId"; + public static final String STATS = "stats"; + public static final String STAT = "stat"; + + // AD constants + public static final String DETECTOR_ID = "detectorID"; + public static final String ANOMALY_DETECTOR = "anomaly_detector"; + public static final String ANOMALY_DETECTOR_JOB = "anomaly_detector_job"; + public static final String TOP_ANOMALIES = "_topAnomalies"; + // forecast constants public static final String FORECASTER_ID = "forecasterID"; public static final String FORECASTER = "forecaster"; - public static final String REST_STATUS = "rest_status"; + public static final String FORECASTER_JOB = "forecaster_job"; + public static final String TOP_FORECASTS = "_topForecasts"; private RestHandlerUtils() {} @@ -247,4 +262,32 @@ public static boolean isProperExceptionToReturn(Throwable e) { private static String coalesceToEmpty(@Nullable String s) { return s == null ? "" : s; } + + public static Entity buildEntity(RestRequest request, String detectorId) throws IOException { + if (org.opensearch.core.common.Strings.isEmpty(detectorId)) { + throw new IllegalStateException(CommonMessages.CONFIG_ID_MISSING_MSG); + } + + String entityName = request.param(CommonName.CATEGORICAL_FIELD); + String entityValue = request.param(CommonName.ENTITY_KEY); + + if (entityName != null && entityValue != null) { + // single-stream profile request: + // GET + // _plugins/_anomaly_detection/detectors//_profile/init_progress?category_field=&entity= + return Entity.createSingleAttributeEntity(entityName, entityValue); + } else if (request.hasContent()) { + /* + * HCAD profile request: GET + * _plugins/_anomaly_detection/detectors//_profile/init_progress { + * "entity": [{ "name": "clientip", "value": "13.24.0.0" }] } + */ + Optional entity = Entity.fromJsonObject(request.contentParser()); + if (entity.isPresent()) { + return entity.get(); + } + } + // not a valid profile request with correct entity information + return null; + } } diff --git a/src/main/java/org/opensearch/timeseries/util/TaskUtil.java b/src/main/java/org/opensearch/timeseries/util/TaskUtil.java new file mode 100644 index 000000000..a92d81043 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/util/TaskUtil.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.util; + +import static org.opensearch.ad.model.ADTaskType.ALL_HISTORICAL_TASK_TYPES; +import static org.opensearch.ad.model.ADTaskType.HISTORICAL_DETECTOR_TASK_TYPES; +import static org.opensearch.ad.model.ADTaskType.REALTIME_TASK_TYPES; + +import java.util.List; + +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.TaskType; + +public class TaskUtil { + public static List getTaskTypes(DateRange dateRange, boolean resetLatestTaskStateFlag, AnalysisType analysisType) { + if (analysisType == AnalysisType.FORECAST) { + if (dateRange == null) { + return ForecastTaskType.REALTIME_TASK_TYPES; + } else { + throw new UnsupportedOperationException("Forecasting does not support historical tasks"); + } + } else { + if (dateRange == null) { + return REALTIME_TASK_TYPES; + } else { + if (resetLatestTaskStateFlag) { + // return all task types include HC entity task to make sure we can reset all tasks latest flag + return ALL_HISTORICAL_TASK_TYPES; + } else { + return HISTORICAL_DETECTOR_TASK_TYPES; + } + } + } + + } +} diff --git a/src/main/resources/mappings/anomaly-checkpoint.json b/src/main/resources/mappings/anomaly-checkpoint.json index 5e515a803..af485860a 100644 --- a/src/main/resources/mappings/anomaly-checkpoint.json +++ b/src/main/resources/mappings/anomaly-checkpoint.json @@ -1,7 +1,7 @@ { "dynamic": true, "_meta": { - "schema_version": 3 + "schema_version": 4 }, "properties": { "detectorId": { @@ -35,6 +35,30 @@ }, "modelV2": { "type": "text" + }, + "samples": { + "type": "nested", + "properties": { + "value_list": { + "type": "nested", + "properties": { + "feature_id": { + "type": "keyword" + }, + "data": { + "type": "double" + } + } + }, + "data_start_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "data_end_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + } + } } } } diff --git a/src/main/resources/mappings/anomaly-results.json b/src/main/resources/mappings/anomaly-results.json index 8c377e78e..3fad67ec2 100644 --- a/src/main/resources/mappings/anomaly-results.json +++ b/src/main/resources/mappings/anomaly-results.json @@ -1,7 +1,7 @@ { "dynamic": false, "_meta": { - "schema_version": 5 + "schema_version": 6 }, "properties": { "detector_id": { @@ -25,6 +25,9 @@ "feature_id": { "type": "keyword" }, + "feature_name": { + "type": "keyword" + }, "data": { "type": "double" } diff --git a/src/main/resources/mappings/config.json b/src/main/resources/mappings/config.json index 7db1e6d08..c64a697e7 100644 --- a/src/main/resources/mappings/config.json +++ b/src/main/resources/mappings/config.json @@ -150,6 +150,23 @@ }, "detector_type": { "type": "keyword" + }, + "forecast_interval": { + "properties": { + "period": { + "properties": { + "interval": { + "type": "integer" + }, + "unit": { + "type": "keyword" + } + } + } + } + }, + "horizon": { + "type": "integer" } } } diff --git a/src/main/resources/mappings/forecast-results.json b/src/main/resources/mappings/forecast-results.json index 745d308ad..6e6bbdc92 100644 --- a/src/main/resources/mappings/forecast-results.json +++ b/src/main/resources/mappings/forecast-results.json @@ -1,5 +1,5 @@ { - "dynamic": true, + "dynamic": false, "_meta": { "schema_version": 1 }, @@ -13,6 +13,9 @@ "feature_id": { "type": "keyword" }, + "feature_name": { + "type": "keyword" + }, "data": { "type": "double" } @@ -95,9 +98,6 @@ "task_id": { "type": "keyword" }, - "model_id": { - "type": "keyword" - }, "entity_id": { "type": "keyword" }, diff --git a/src/main/resources/mappings/job.json b/src/main/resources/mappings/job.json index fb26d56d2..5783c701d 100644 --- a/src/main/resources/mappings/job.json +++ b/src/main/resources/mappings/job.json @@ -1,7 +1,7 @@ { "dynamic": false, "_meta": { - "schema_version": 3 + "schema_version": 4 }, "properties": { "schema_version": { @@ -100,6 +100,9 @@ } } } + }, + "type": { + "type": "keyword" } } } diff --git a/src/test/java/org/opensearch/StreamInputOutputTests.java b/src/test/java/org/opensearch/StreamInputOutputTests.java index 82ff5cc24..1fff02fa3 100644 --- a/src/test/java/org/opensearch/StreamInputOutputTests.java +++ b/src/test/java/org/opensearch/StreamInputOutputTests.java @@ -26,15 +26,7 @@ import java.util.Set; import org.opensearch.action.FailedNodeException; -import org.opensearch.ad.model.EntityProfileName; -import org.opensearch.ad.model.ModelProfile; -import org.opensearch.ad.model.ModelProfileOnNode; -import org.opensearch.ad.transport.EntityProfileAction; -import org.opensearch.ad.transport.EntityProfileRequest; -import org.opensearch.ad.transport.EntityProfileResponse; -import org.opensearch.ad.transport.EntityResultRequest; -import org.opensearch.ad.transport.ProfileNodeResponse; -import org.opensearch.ad.transport.ProfileResponse; +import org.opensearch.ad.transport.ADEntityProfileAction; import org.opensearch.ad.transport.RCFResultResponse; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; @@ -42,7 +34,16 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.EntityProfileName; +import org.opensearch.timeseries.model.ModelProfile; +import org.opensearch.timeseries.model.ModelProfileOnNode; +import org.opensearch.timeseries.transport.EntityProfileRequest; +import org.opensearch.timeseries.transport.EntityProfileResponse; +import org.opensearch.timeseries.transport.EntityResultRequest; +import org.opensearch.timeseries.transport.ProfileNodeResponse; +import org.opensearch.timeseries.transport.ProfileResponse; /** * Put in core package so that we can using Version's package private constructor @@ -98,7 +99,7 @@ private void setUpEntityResultRequest() { entities.put(entity, feature); start = 10L; end = 20L; - entityResultRequest = new EntityResultRequest(detectorId, entities, start, end); + entityResultRequest = new EntityResultRequest(detectorId, entities, start, end, AnalysisType.AD, null); } /** @@ -111,7 +112,7 @@ public void testDeSerializeEntityResultRequest() throws IOException { StreamInput streamInput = output.bytes().streamInput(); EntityResultRequest readRequest = new EntityResultRequest(streamInput); - assertThat(readRequest.getId(), equalTo(detectorId)); + assertThat(readRequest.getConfigId(), equalTo(detectorId)); assertThat(readRequest.getStart(), equalTo(start)); assertThat(readRequest.getEnd(), equalTo(end)); assertTrue(areEqualWithArrayValue(readRequest.getEntities(), entities)); @@ -133,7 +134,7 @@ public void testDeserializeEntityProfileRequest() throws IOException { StreamInput streamInput = output.bytes().streamInput(); EntityProfileRequest readRequest = new EntityProfileRequest(streamInput); - assertThat(readRequest.getAdID(), equalTo(detectorId)); + assertThat(readRequest.getConfigID(), equalTo(detectorId)); assertThat(readRequest.getEntityValue(), equalTo(entity)); assertThat(readRequest.getProfilesToCollect(), equalTo(profilesToCollect)); } @@ -157,7 +158,7 @@ public void testDeserializeEntityProfileResponse() throws IOException { entityProfileResponse.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); - EntityProfileResponse readResponse = EntityProfileAction.INSTANCE.getResponseReader().read(streamInput); + EntityProfileResponse readResponse = ADEntityProfileAction.INSTANCE.getResponseReader().read(streamInput); assertThat(readResponse.getModelProfile(), equalTo(entityProfileResponse.getModelProfile())); assertThat(readResponse.getLastActiveMs(), equalTo(entityProfileResponse.getLastActiveMs())); assertThat(readResponse.getTotalUpdates(), equalTo(entityProfileResponse.getTotalUpdates())); diff --git a/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java b/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java index aa2f30b02..72b9caf64 100644 --- a/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java +++ b/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java @@ -22,9 +22,9 @@ import static org.mockito.Mockito.when; import java.io.IOException; -import java.time.Clock; import java.util.Arrays; import java.util.Locale; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -54,7 +54,6 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; -import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.core.action.ActionResponse; import org.opensearch.rest.RestRequest; @@ -64,6 +63,7 @@ import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.feature.SearchFeatureDao; import org.opensearch.timeseries.util.SecurityClientUtil; @@ -78,7 +78,6 @@ */ public class IndexAnomalyDetectorActionHandlerTests extends AbstractTimeSeriesTest { static ThreadPool threadPool; - private ThreadContext threadContext; private String TEXT_FIELD_TYPE = "text"; private IndexAnomalyDetectorActionHandler handler; private ClusterService clusterService; @@ -96,11 +95,11 @@ public class IndexAnomalyDetectorActionHandlerTests extends AbstractTimeSeriesTe private Integer maxSingleEntityAnomalyDetectors; private Integer maxMultiEntityAnomalyDetectors; private Integer maxAnomalyFeatures; + private Integer maxCategoricalFields; private Settings settings; private RestRequest.Method method; private ADTaskManager adTaskManager; private SearchFeatureDao searchFeatureDao; - private Clock clock; @BeforeClass public static void beforeClass() { @@ -122,7 +121,6 @@ public void setUp() throws Exception { settings = Settings.EMPTY; clusterService = mock(ClusterService.class); clientMock = spy(new NodeClient(settings, threadPool)); - clock = mock(Clock.class); NodeStateManager nodeStateManager = mock(NodeStateManager.class); clientUtil = new SecurityClientUtil(nodeStateManager, settings); transportService = mock(TransportService.class); @@ -149,6 +147,8 @@ public void setUp() throws Exception { maxAnomalyFeatures = 5; + maxCategoricalFields = 2; + method = RestRequest.Method.POST; adTaskManager = mock(ADTaskManager.class); @@ -160,7 +160,6 @@ public void setUp() throws Exception { clientMock, clientUtil, transportService, - channel, anomalyDetectionIndices, detectorId, seqNo, @@ -171,6 +170,7 @@ public void setUp() throws Exception { maxSingleEntityAnomalyDetectors, maxMultiEntityAnomalyDetectors, maxAnomalyFeatures, + maxCategoricalFields, method, xContentRegistry(), null, @@ -188,8 +188,7 @@ public void testThreeCategoricalFields() throws IOException { ); } - @SuppressWarnings("unchecked") - public void testMoreThanTenThousandSingleEntityDetectors() throws IOException { + public void testMoreThanTenThousandSingleEntityDetectors() throws IOException, InterruptedException { SearchResponse mockResponse = mock(SearchResponse.class); int totalHits = 1001; when(mockResponse.getHits()).thenReturn(TestHelpers.createSearchHits(totalHits)); @@ -211,7 +210,6 @@ public void testMoreThanTenThousandSingleEntityDetectors() throws IOException { clientSpy, clientUtil, transportService, - channel, anomalyDetectionIndices, detectorId, seqNo, @@ -223,6 +221,7 @@ public void testMoreThanTenThousandSingleEntityDetectors() throws IOException { maxSingleEntityAnomalyDetectors, maxMultiEntityAnomalyDetectors, maxAnomalyFeatures, + maxCategoricalFields, method, xContentRegistry(), null, @@ -231,7 +230,9 @@ public void testMoreThanTenThousandSingleEntityDetectors() throws IOException { Settings.EMPTY ); - handler.start(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> inProgressLatch.countDown(), e -> inProgressLatch.countDown())); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); verify(clientMock, never()).execute(eq(GetMappingsAction.INSTANCE), any(), any()); verify(channel).onFailure(response.capture()); @@ -240,14 +241,14 @@ public void testMoreThanTenThousandSingleEntityDetectors() throws IOException { String errorMsg = String .format( Locale.ROOT, - IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_SINGLE_ENTITY_DETECTORS_PREFIX_MSG, + IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_SINGLE_STREAM_DETECTORS_PREFIX_MSG, maxSingleEntityAnomalyDetectors ); assertTrue(value.getMessage().contains(errorMsg)); } @SuppressWarnings("unchecked") - public void testTextField() throws IOException { + public void testTextField() throws IOException, InterruptedException { String field = "a"; AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field)); @@ -289,7 +290,6 @@ public void doE client, clientUtil, transportService, - channel, anomalyDetectionIndices, detectorId, seqNo, @@ -300,6 +300,7 @@ public void doE maxSingleEntityAnomalyDetectors, maxMultiEntityAnomalyDetectors, maxAnomalyFeatures, + maxCategoricalFields, method, xContentRegistry(), null, @@ -310,16 +311,18 @@ public void doE ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); - handler.start(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> inProgressLatch.countDown(), e -> inProgressLatch.countDown())); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); verify(channel).onFailure(response.capture()); Exception value = response.getValue(); assertTrue(value instanceof Exception); - assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.CATEGORICAL_FIELD_TYPE_ERR_MSG)); + assertTrue(value.getMessage().contains(CommonMessages.CATEGORICAL_FIELD_TYPE_ERR_MSG)); } @SuppressWarnings("unchecked") - private void testValidTypeTemplate(String filedTypeName) throws IOException { + private void testValidTypeTemplate(String filedTypeName) throws IOException, InterruptedException { String field = "a"; AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field)); @@ -376,7 +379,6 @@ public void doE clientSpy, clientUtil, transportService, - channel, anomalyDetectionIndices, detectorId, seqNo, @@ -387,6 +389,7 @@ public void doE maxSingleEntityAnomalyDetectors, maxMultiEntityAnomalyDetectors, maxAnomalyFeatures, + maxCategoricalFields, method, xContentRegistry(), null, @@ -397,7 +400,9 @@ public void doE ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); - handler.start(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> inProgressLatch.countDown(), e -> inProgressLatch.countDown())); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); verify(clientSpy, times(2)).execute(eq(GetFieldMappingsAction.INSTANCE), any(), any()); verify(channel).onFailure(response.capture()); @@ -406,16 +411,16 @@ public void doE assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.NO_DOCS_IN_USER_INDEX_MSG)); } - public void testIpField() throws IOException { + public void testIpField() throws IOException, InterruptedException { testValidTypeTemplate(CommonName.IP_TYPE); } - public void testKeywordField() throws IOException { + public void testKeywordField() throws IOException, InterruptedException { testValidTypeTemplate(CommonName.KEYWORD_TYPE); } @SuppressWarnings("unchecked") - private void testUpdateTemplate(String fieldTypeName) throws IOException { + private void testUpdateTemplate(String fieldTypeName) throws IOException, InterruptedException { String field = "a"; AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field)); @@ -474,7 +479,6 @@ public void doE clientSpy, clientUtil, transportService, - channel, anomalyDetectionIndices, detectorId, seqNo, @@ -485,6 +489,7 @@ public void doE maxSingleEntityAnomalyDetectors, maxMultiEntityAnomalyDetectors, maxAnomalyFeatures, + maxCategoricalFields, RestRequest.Method.PUT, xContentRegistry(), null, @@ -495,7 +500,9 @@ public void doE ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); - handler.start(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> inProgressLatch.countDown(), e -> inProgressLatch.countDown())); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); verify(clientSpy, times(1)).execute(eq(GetFieldMappingsAction.INSTANCE), any(), any()); verify(channel).onFailure(response.capture()); @@ -503,22 +510,22 @@ public void doE if (fieldTypeName.equals(CommonName.IP_TYPE) || fieldTypeName.equals(CommonName.KEYWORD_TYPE)) { assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.NO_DOCS_IN_USER_INDEX_MSG)); } else { - assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.CATEGORICAL_FIELD_TYPE_ERR_MSG)); + assertTrue(value.getMessage().contains(CommonMessages.CATEGORICAL_FIELD_TYPE_ERR_MSG)); } } @Ignore - public void testUpdateIpField() throws IOException { + public void testUpdateIpField() throws IOException, InterruptedException { testUpdateTemplate(CommonName.IP_TYPE); } @Ignore - public void testUpdateKeywordField() throws IOException { + public void testUpdateKeywordField() throws IOException, InterruptedException { testUpdateTemplate(CommonName.KEYWORD_TYPE); } @Ignore - public void testUpdateTextField() throws IOException { + public void testUpdateTextField() throws IOException, InterruptedException { testUpdateTemplate(TEXT_FIELD_TYPE); } @@ -558,7 +565,7 @@ public void doE } @SuppressWarnings("unchecked") - public void testMoreThanTenMultiEntityDetectors() throws IOException { + public void testMoreThanTenMultiEntityDetectors() throws IOException, InterruptedException { String field = "a"; AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field)); SearchResponse detectorResponse = mock(SearchResponse.class); @@ -580,7 +587,6 @@ public void testMoreThanTenMultiEntityDetectors() throws IOException { clientSpy, clientUtil, transportService, - channel, anomalyDetectionIndices, detectorId, seqNo, @@ -591,6 +597,7 @@ public void testMoreThanTenMultiEntityDetectors() throws IOException { maxSingleEntityAnomalyDetectors, maxMultiEntityAnomalyDetectors, maxAnomalyFeatures, + maxCategoricalFields, method, xContentRegistry(), null, @@ -599,24 +606,22 @@ public void testMoreThanTenMultiEntityDetectors() throws IOException { Settings.EMPTY ); - handler.start(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> inProgressLatch.countDown(), e -> inProgressLatch.countDown())); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); verify(clientSpy, times(1)).search(any(SearchRequest.class), any()); verify(channel).onFailure(response.capture()); Exception value = response.getValue(); assertTrue(value instanceof IllegalArgumentException); String errorMsg = String - .format( - Locale.ROOT, - IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG, - maxMultiEntityAnomalyDetectors - ); + .format(Locale.ROOT, IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_HC_DETECTORS_PREFIX_MSG, maxMultiEntityAnomalyDetectors); assertTrue(value.getMessage().contains(errorMsg)); } @Ignore @SuppressWarnings("unchecked") - public void testTenMultiEntityDetectorsUpdateSingleEntityAdToMulti() throws IOException { + public void testTenMultiEntityDetectorsUpdateSingleEntityAdToMulti() throws IOException, InterruptedException { int totalHits = 10; AnomalyDetector existingDetector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, null); GetResponse getDetectorResponse = TestHelpers @@ -668,7 +673,6 @@ public void testTenMultiEntityDetectorsUpdateSingleEntityAdToMulti() throws IOEx clientMock, clientUtil, transportService, - channel, anomalyDetectionIndices, detectorId, seqNo, @@ -679,6 +683,7 @@ public void testTenMultiEntityDetectorsUpdateSingleEntityAdToMulti() throws IOEx maxSingleEntityAnomalyDetectors, maxMultiEntityAnomalyDetectors, maxAnomalyFeatures, + maxCategoricalFields, RestRequest.Method.PUT, xContentRegistry(), null, @@ -687,7 +692,9 @@ public void testTenMultiEntityDetectorsUpdateSingleEntityAdToMulti() throws IOEx Settings.EMPTY ); - handler.start(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> inProgressLatch.countDown(), e -> inProgressLatch.countDown())); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); verify(clientMock, times(1)).search(any(SearchRequest.class), any()); @@ -695,12 +702,12 @@ public void testTenMultiEntityDetectorsUpdateSingleEntityAdToMulti() throws IOEx verify(channel).onFailure(response.capture()); Exception value = response.getValue(); assertTrue(value instanceof IllegalArgumentException); - assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG)); + assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_HC_DETECTORS_PREFIX_MSG)); } @Ignore @SuppressWarnings("unchecked") - public void testTenMultiEntityDetectorsUpdateExistingMultiEntityAd() throws IOException { + public void testTenMultiEntityDetectorsUpdateExistingMultiEntityAd() throws IOException, InterruptedException { int totalHits = 10; AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList("a")); GetResponse getDetectorResponse = TestHelpers.createGetResponse(detector, detector.getId(), CommonName.CONFIG_INDEX); @@ -751,7 +758,6 @@ public void testTenMultiEntityDetectorsUpdateExistingMultiEntityAd() throws IOEx clientMock, clientUtil, transportService, - channel, anomalyDetectionIndices, detectorId, seqNo, @@ -762,6 +768,7 @@ public void testTenMultiEntityDetectorsUpdateExistingMultiEntityAd() throws IOEx maxSingleEntityAnomalyDetectors, maxMultiEntityAnomalyDetectors, maxAnomalyFeatures, + maxCategoricalFields, RestRequest.Method.PUT, xContentRegistry(), null, @@ -770,7 +777,9 @@ public void testTenMultiEntityDetectorsUpdateExistingMultiEntityAd() throws IOEx Settings.EMPTY ); - handler.start(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> inProgressLatch.countDown(), e -> inProgressLatch.countDown())); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); verify(clientMock, times(0)).search(any(SearchRequest.class), any()); diff --git a/src/test/java/org/opensearch/action/admin/indices/mapping/get/ValidateAnomalyDetectorActionHandlerTests.java b/src/test/java/org/opensearch/action/admin/indices/mapping/get/ValidateAnomalyDetectorActionHandlerTests.java index 4873d1501..413892755 100644 --- a/src/test/java/org/opensearch/action/admin/indices/mapping/get/ValidateAnomalyDetectorActionHandlerTests.java +++ b/src/test/java/org/opensearch/action/admin/indices/mapping/get/ValidateAnomalyDetectorActionHandlerTests.java @@ -23,6 +23,8 @@ import java.time.Clock; import java.util.Arrays; import java.util.Locale; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -31,13 +33,15 @@ import org.mockito.MockitoAnnotations; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.WriteRequest; +import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.rest.handler.AbstractAnomalyDetectorActionHandler; import org.opensearch.ad.rest.handler.IndexAnomalyDetectorActionHandler; import org.opensearch.ad.rest.handler.ValidateAnomalyDetectorActionHandler; +import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; -import org.opensearch.ad.transport.ValidateAnomalyDetectorResponse; import org.opensearch.client.Client; import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.service.ClusterService; @@ -53,6 +57,8 @@ import org.opensearch.timeseries.common.exception.ValidationException; import org.opensearch.timeseries.feature.SearchFeatureDao; import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.transport.ValidateConfigResponse; import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.TransportService; @@ -60,9 +66,9 @@ public class ValidateAnomalyDetectorActionHandlerTests extends AbstractTimeSeriesTest { - protected AbstractAnomalyDetectorActionHandler handler; + protected ValidateAnomalyDetectorActionHandler handler; protected ClusterService clusterService; - protected ActionListener channel; + protected ActionListener channel; protected TransportService transportService; protected ADIndexManagement anomalyDetectionIndices; protected String detectorId; @@ -74,9 +80,10 @@ public class ValidateAnomalyDetectorActionHandlerTests extends AbstractTimeSerie protected Integer maxSingleEntityAnomalyDetectors; protected Integer maxMultiEntityAnomalyDetectors; protected Integer maxAnomalyFeatures; + protected Integer maxCategoricalFields; protected Settings settings; protected RestRequest.Method method; - protected ADTaskManager adTaskManager; + protected TaskManager adTaskManager; protected SearchFeatureDao searchFeatureDao; protected Clock clock; @@ -116,6 +123,7 @@ public void setUp() throws Exception { maxSingleEntityAnomalyDetectors = 1000; maxMultiEntityAnomalyDetectors = 10; maxAnomalyFeatures = 5; + maxCategoricalFields = 10; method = RestRequest.Method.POST; adTaskManager = mock(ADTaskManager.class); searchFeatureDao = mock(SearchFeatureDao.class); @@ -126,7 +134,7 @@ public void setUp() throws Exception { } @SuppressWarnings("unchecked") - public void testValidateMoreThanThousandSingleEntityDetectorLimit() throws IOException { + public void testValidateMoreThanThousandSingleEntityDetectorLimit() throws IOException, InterruptedException { SearchResponse mockResponse = mock(SearchResponse.class); int totalHits = maxSingleEntityAnomalyDetectors + 1; when(mockResponse.getHits()).thenReturn(TestHelpers.createSearchHits(totalHits)); @@ -150,13 +158,13 @@ public void testValidateMoreThanThousandSingleEntityDetectorLimit() throws IOExc clusterService, clientSpy, clientUtil, - channel, anomalyDetectionIndices, singleEntityDetector, requestTimeout, maxSingleEntityAnomalyDetectors, maxMultiEntityAnomalyDetectors, maxAnomalyFeatures, + maxCategoricalFields, method, xContentRegistry(), null, @@ -165,7 +173,9 @@ public void testValidateMoreThanThousandSingleEntityDetectorLimit() throws IOExc clock, settings ); - handler.start(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> inProgressLatch.countDown(), e -> inProgressLatch.countDown())); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); verify(clientSpy, never()).execute(eq(GetMappingsAction.INSTANCE), any(), any()); verify(channel).onFailure(response.capture()); @@ -174,14 +184,14 @@ public void testValidateMoreThanThousandSingleEntityDetectorLimit() throws IOExc String errorMsg = String .format( Locale.ROOT, - IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_SINGLE_ENTITY_DETECTORS_PREFIX_MSG, + IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_SINGLE_STREAM_DETECTORS_PREFIX_MSG, maxSingleEntityAnomalyDetectors ); assertTrue(value.getMessage().contains(errorMsg)); } @SuppressWarnings("unchecked") - public void testValidateMoreThanTenMultiEntityDetectorsLimit() throws IOException { + public void testValidateMoreThanTenMultiEntityDetectorsLimit() throws IOException, InterruptedException { String field = "a"; AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field)); @@ -204,13 +214,13 @@ public void testValidateMoreThanTenMultiEntityDetectorsLimit() throws IOExceptio clusterService, clientSpy, clientUtil, - channel, anomalyDetectionIndices, detector, requestTimeout, maxSingleEntityAnomalyDetectors, maxMultiEntityAnomalyDetectors, maxAnomalyFeatures, + maxCategoricalFields, method, xContentRegistry(), null, @@ -219,18 +229,16 @@ public void testValidateMoreThanTenMultiEntityDetectorsLimit() throws IOExceptio clock, Settings.EMPTY ); - handler.start(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> inProgressLatch.countDown(), e -> inProgressLatch.countDown())); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); verify(clientSpy, never()).execute(eq(GetMappingsAction.INSTANCE), any(), any()); verify(channel).onFailure(response.capture()); Exception value = response.getValue(); assertTrue(value instanceof ValidationException); String errorMsg = String - .format( - Locale.ROOT, - IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG, - maxMultiEntityAnomalyDetectors - ); + .format(Locale.ROOT, IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_HC_DETECTORS_PREFIX_MSG, maxMultiEntityAnomalyDetectors); assertTrue(value.getMessage().contains(errorMsg)); } } diff --git a/src/test/java/org/opensearch/ad/AbstractProfileRunnerTests.java b/src/test/java/org/opensearch/ad/AbstractProfileRunnerTests.java index a98eef88d..69fb5176d 100644 --- a/src/test/java/org/opensearch/ad/AbstractProfileRunnerTests.java +++ b/src/test/java/org/opensearch/ad/AbstractProfileRunnerTests.java @@ -19,7 +19,6 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import java.time.Clock; import java.util.Arrays; import java.util.HashSet; import java.util.Optional; @@ -33,7 +32,6 @@ import org.opensearch.action.get.GetResponse; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.DetectorProfileName; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.transport.AnomalyResultTests; import org.opensearch.client.Client; @@ -44,6 +42,7 @@ import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.model.ProfileName; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.TransportService; @@ -78,12 +77,12 @@ protected enum ErrorResultStatus { protected TransportService transportService; protected ADTaskManager adTaskManager; - protected static Set stateOnly; - protected static Set stateNError; - protected static Set modelProfile; - protected static Set stateInitProgress; - protected static Set totalInitProgress; - protected static Set initProgressErrorProfile; + protected static Set stateOnly; + protected static Set stateNError; + protected static Set modelProfile; + protected static Set stateInitProgress; + protected static Set totalInitProgress; + protected static Set initProgressErrorProfile; protected static String noFullShingleError = "No full shingle in current detection window"; protected static String stoppedError = @@ -113,32 +112,23 @@ protected enum ErrorResultStatus { protected int detectorIntervalMin; protected GetResponse detectorGetReponse; protected String messaingExceptionError = "blah"; + protected ADTaskProfileRunner taskProfileRunner; @BeforeClass public static void setUpOnce() { - stateOnly = new HashSet(); - stateOnly.add(DetectorProfileName.STATE); - stateNError = new HashSet(); - stateNError.add(DetectorProfileName.ERROR); - stateNError.add(DetectorProfileName.STATE); - stateInitProgress = new HashSet(); - stateInitProgress.add(DetectorProfileName.INIT_PROGRESS); - stateInitProgress.add(DetectorProfileName.STATE); - modelProfile = new HashSet( - Arrays - .asList( - DetectorProfileName.SHINGLE_SIZE, - DetectorProfileName.MODELS, - DetectorProfileName.COORDINATING_NODE, - DetectorProfileName.TOTAL_SIZE_IN_BYTES - ) - ); - totalInitProgress = new HashSet( - Arrays.asList(DetectorProfileName.TOTAL_ENTITIES, DetectorProfileName.INIT_PROGRESS) - ); - initProgressErrorProfile = new HashSet( - Arrays.asList(DetectorProfileName.INIT_PROGRESS, DetectorProfileName.ERROR) + stateOnly = new HashSet(); + stateOnly.add(ProfileName.STATE); + stateNError = new HashSet(); + stateNError.add(ProfileName.ERROR); + stateNError.add(ProfileName.STATE); + stateInitProgress = new HashSet(); + stateInitProgress.add(ProfileName.INIT_PROGRESS); + stateInitProgress.add(ProfileName.STATE); + modelProfile = new HashSet( + Arrays.asList(ProfileName.SHINGLE_SIZE, ProfileName.MODELS, ProfileName.COORDINATING_NODE, ProfileName.TOTAL_SIZE_IN_BYTES) ); + totalInitProgress = new HashSet(Arrays.asList(ProfileName.TOTAL_ENTITIES, ProfileName.INIT_PROGRESS)); + initProgressErrorProfile = new HashSet(Arrays.asList(ProfileName.INIT_PROGRESS, ProfileName.ERROR)); clusterName = "test-cluster-name"; discoveryNode1 = new DiscoveryNode( "nodeName1", @@ -163,7 +153,7 @@ public void setUp() throws Exception { super.setUp(); client = mock(Client.class); when(client.threadPool()).thenReturn(threadPool); - Clock clock = mock(Clock.class); + taskProfileRunner = mock(ADTaskProfileRunner.class); nodeFilter = mock(DiscoveryNodeFilterer.class); clusterService = mock(ClusterService.class); @@ -178,7 +168,7 @@ public void setUp() throws Exception { Consumer> function = (Consumer>) args[2]; function.accept(Optional.of(TestHelpers.randomAdTask())); return null; - }).when(adTaskManager).getAndExecuteOnLatestDetectorLevelTask(any(), any(), any(), any(), anyBoolean(), any()); + }).when(adTaskManager).getAndExecuteOnLatestConfigLevelTask(any(), any(), any(), any(), anyBoolean(), any()); detectorIntervalMin = 3; detectorGetReponse = mock(GetResponse.class); diff --git a/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java b/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java index ed5be8fb0..106465dc6 100644 --- a/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java +++ b/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java @@ -23,7 +23,6 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; -import static org.opensearch.timeseries.settings.TimeSeriesSettings.NUM_MIN_SAMPLES; import java.io.IOException; import java.time.Instant; @@ -53,6 +52,7 @@ import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchResponse; import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyResult; @@ -61,7 +61,6 @@ import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.transport.AnomalyResultAction; import org.opensearch.ad.transport.AnomalyResultResponse; -import org.opensearch.ad.transport.handler.AnomalyIndexHandler; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; @@ -84,6 +83,8 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.JobProcessor; +import org.opensearch.timeseries.JobRunner; import org.opensearch.timeseries.MemoryTracker; import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.TestHelpers; @@ -94,6 +95,7 @@ import org.opensearch.timeseries.model.IntervalTimeConfiguration; import org.opensearch.timeseries.model.Job; import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler; import org.opensearch.timeseries.util.ClientUtil; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; @@ -118,7 +120,9 @@ public class AnomalyDetectorJobRunnerTests extends AbstractTimeSeriesTest { @Mock private JobExecutionContext context; - private AnomalyDetectorJobRunner runner = AnomalyDetectorJobRunner.getJobRunnerInstance(); + private JobRunner runner = JobRunner.getJobRunnerInstance(); + + private ADJobProcessor adJobProcessor = ADJobProcessor.getInstance(); @Mock private ThreadPool mockedThreadPool; @@ -129,7 +133,7 @@ public class AnomalyDetectorJobRunnerTests extends AbstractTimeSeriesTest { private Iterator backoff; @Mock - private AnomalyIndexHandler anomalyResultHandler; + private ResultBulkIndexingHandler anomalyResultHandler; @Mock private ADTaskManager adTaskManager; @@ -163,7 +167,7 @@ public static void tearDownAfterClass() { @Before public void setup() throws Exception { super.setUp(); - super.setUpLog4jForJUnit(AnomalyDetectorJobRunner.class); + super.setUpLog4jForJUnit(JobProcessor.class); MockitoAnnotations.initMocks(this); ThreadFactory threadFactory = OpenSearchExecutors.daemonThreadFactory(OpenSearchExecutors.threadName("node1", "test-ad")); ThreadContext threadContext = new ThreadContext(Settings.EMPTY); @@ -171,9 +175,9 @@ public void setup() throws Exception { Mockito.doReturn(executorService).when(mockedThreadPool).executor(anyString()); Mockito.doReturn(mockedThreadPool).when(client).threadPool(); Mockito.doReturn(threadContext).when(mockedThreadPool).getThreadContext(); - runner.setThreadPool(mockedThreadPool); - runner.setClient(client); - runner.setAdTaskManager(adTaskManager); + adJobProcessor.setThreadPool(mockedThreadPool); + adJobProcessor.setClient(client); + adJobProcessor.setTaskManager(adTaskManager); Settings settings = Settings .builder() @@ -183,11 +187,11 @@ public void setup() throws Exception { .build(); setUpJobParameter(); - runner.setSettings(settings); + adJobProcessor.registerSettings(settings); anomalyDetectionIndices = mock(ADIndexManagement.class); - runner.setAnomalyDetectionIndices(anomalyDetectionIndices); + adJobProcessor.setIndexManagement(anomalyDetectionIndices); lockService = new LockService(client, clusterService); doReturn(lockService).when(context).getLockService(); @@ -235,7 +239,7 @@ public void setup() throws Exception { listener.onResponse(Optional.of(detector)); return null; }).when(nodeStateManager).getConfig(any(String.class), eq(AnalysisType.AD), any(ActionListener.class)); - runner.setNodeStateManager(nodeStateManager); + adJobProcessor.setNodeStateManager(nodeStateManager); recorder = new ExecuteADResultResponseRecorder( anomalyDetectionIndices, @@ -248,7 +252,7 @@ public void setup() throws Exception { adTaskCacheManager, 32 ); - runner.setExecuteADResultResponseRecorder(recorder); + adJobProcessor.setExecuteResultResponseRecorder(recorder); } @Rule @@ -293,7 +297,7 @@ public void testRunJobWithLockDuration() throws InterruptedException { @Test public void testRunAdJobWithNullLock() { LockModel lock = null; - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), recorder, detector); + adJobProcessor.runJob(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), recorder, detector); verify(client, never()).execute(any(), any(), any()); } @@ -301,7 +305,7 @@ public void testRunAdJobWithNullLock() { public void testRunAdJobWithLock() { LockModel lock = new LockModel("indexName", "jobId", Instant.now(), 10, false); - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), recorder, detector); + adJobProcessor.runJob(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), recorder, detector); verify(client, times(1)).execute(any(), any(), any()); } @@ -311,7 +315,7 @@ public void testRunAdJobWithExecuteException() { doThrow(RuntimeException.class).when(client).execute(any(), any(), any()); - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), recorder, detector); + adJobProcessor.runJob(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), recorder, detector); verify(client, times(1)).execute(any(), any(), any()); assertTrue(testAppender.containsMessage("Failed to execute AD job")); } @@ -320,8 +324,8 @@ public void testRunAdJobWithExecuteException() { public void testRunAdJobWithEndRunExceptionNow() { LockModel lock = new LockModel("indexName", "jobId", Instant.now(), 10, false); Exception exception = new EndRunException(jobParameter.getName(), randomAlphaOfLength(5), true); - runner - .handleAdException( + adJobProcessor + .handleException( jobParameter, lockService, lock, @@ -406,7 +410,8 @@ private void testRunAdJobWithEndRunExceptionNowAndStopAdJob(boolean jobExists, b Instant.now(), 60L, TestHelpers.randomUser(), - jobParameter.getCustomResultIndex() + jobParameter.getCustomResultIndex(), + AnalysisType.AD ).toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS) ), Collections.emptyMap(), @@ -430,8 +435,8 @@ private void testRunAdJobWithEndRunExceptionNowAndStopAdJob(boolean jobExists, b return null; }).when(client).index(any(IndexRequest.class), any()); - runner - .handleAdException( + adJobProcessor + .handleException( jobParameter, lockService, lock, @@ -454,8 +459,8 @@ public void testRunAdJobWithEndRunExceptionNowAndGetJobException() { return null; }).when(client).get(any(GetRequest.class), any()); - runner - .handleAdException( + adJobProcessor + .handleException( jobParameter, lockService, lock, @@ -488,8 +493,8 @@ public void testRunAdJobWithEndRunExceptionNowAndFailToGetJob() { return null; }).when(client).get(any(), any()); - runner - .handleAdException( + adJobProcessor + .handleException( jobParameter, lockService, lock, @@ -519,10 +524,10 @@ public void testRunAdJobWithEndRunExceptionNotNowAndRetryUntilStop() throws Inte }).when(client).execute(any(), any(), any()); for (int i = 0; i < 3; i++) { - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + adJobProcessor.runJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); assertEquals(i + 1, testAppender.countMessage("EndRunException happened for")); } - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + adJobProcessor.runJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); assertEquals(1, testAppender.countMessage("JobRunner will stop AD job due to EndRunException retry exceeds upper limit")); } @@ -564,7 +569,8 @@ public Instant confirmInitializedSetup() { new double[] { randomDouble(), randomDouble() }, new double[][] { new double[] { randomDouble(), randomDouble() } }, new double[] { randomDouble() }, - randomDoubleBetween(1.1, 10.0, true) + randomDoubleBetween(1.1, 10.0, true), + null ); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -586,7 +592,7 @@ public void testFailtoFindDetector() { LockModel lock = new LockModel(CommonName.JOB_INDEX, jobParameter.getName(), Instant.now(), 10, false); - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + adJobProcessor.runJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); verify(client, times(1)).execute(eq(AnomalyResultAction.INSTANCE), any(), any()); verify(adTaskCacheManager, times(1)).hasQueriedResultIndex(anyString()); @@ -615,7 +621,7 @@ public void testFailtoFindJob() { LockModel lock = new LockModel(CommonName.JOB_INDEX, jobParameter.getName(), Instant.now(), 10, false); - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + adJobProcessor.runJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); verify(client, times(1)).execute(eq(AnomalyResultAction.INSTANCE), any(), any()); verify(adTaskCacheManager, times(1)).hasQueriedResultIndex(anyString()); @@ -638,7 +644,7 @@ public void testEmptyDetector() { LockModel lock = new LockModel(CommonName.JOB_INDEX, jobParameter.getName(), Instant.now(), 10, false); - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + adJobProcessor.runJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); verify(client, times(1)).execute(eq(AnomalyResultAction.INSTANCE), any(), any()); verify(adTaskCacheManager, times(1)).hasQueriedResultIndex(anyString()); @@ -667,7 +673,7 @@ public void testEmptyJob() { LockModel lock = new LockModel(CommonName.JOB_INDEX, jobParameter.getName(), Instant.now(), 10, false); - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + adJobProcessor.runJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); verify(client, times(1)).execute(eq(AnomalyResultAction.INSTANCE), any(), any()); verify(adTaskCacheManager, times(1)).hasQueriedResultIndex(anyString()); @@ -756,7 +762,7 @@ public void testMarkResultIndexQueried() throws IOException { LockModel lock = new LockModel(CommonName.JOB_INDEX, jobParameter.getName(), Instant.now(), 10, false); - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + adJobProcessor.runJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); verify(client, times(1)).execute(eq(AnomalyResultAction.INSTANCE), any(), any()); verify(nodeStateManager, times(1)).getConfig(any(String.class), eq(AnalysisType.AD), any(ActionListener.class)); @@ -766,7 +772,7 @@ public void testMarkResultIndexQueried() throws IOException { ArgumentCaptor totalUpdates = ArgumentCaptor.forClass(Long.class); verify(adTaskManager, times(1)) .updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), totalUpdates.capture(), any(), any(), any()); - assertEquals(NUM_MIN_SAMPLES, totalUpdates.getValue().longValue()); + assertEquals(TimeSeriesSettings.NUM_MIN_SAMPLES, totalUpdates.getValue().longValue()); assertEquals(true, adTaskCacheManager.hasQueriedResultIndex(detector.getId())); } } diff --git a/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java b/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java index cb88bde96..3f5e14f26 100644 --- a/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java +++ b/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java @@ -42,13 +42,7 @@ import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.DetectorInternalState; import org.opensearch.ad.model.DetectorProfile; -import org.opensearch.ad.model.DetectorProfileName; -import org.opensearch.ad.model.DetectorState; -import org.opensearch.ad.model.InitProgressProfile; -import org.opensearch.ad.model.ModelProfileOnNode; -import org.opensearch.ad.transport.ProfileAction; -import org.opensearch.ad.transport.ProfileNodeResponse; -import org.opensearch.ad.transport.ProfileResponse; +import org.opensearch.ad.transport.ADProfileAction; import org.opensearch.ad.transport.RCFPollingAction; import org.opensearch.ad.transport.RCFPollingResponse; import org.opensearch.cluster.ClusterName; @@ -65,8 +59,15 @@ import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.ConfigProfile; +import org.opensearch.timeseries.model.ConfigState; +import org.opensearch.timeseries.model.InitProgressProfile; import org.opensearch.timeseries.model.IntervalTimeConfiguration; import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.ModelProfileOnNode; +import org.opensearch.timeseries.model.ProfileName; +import org.opensearch.timeseries.transport.ProfileNodeResponse; +import org.opensearch.timeseries.transport.ProfileResponse; import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.RemoteTransportException; @@ -114,7 +115,8 @@ private void setUpClientGet( nodeFilter, requiredSamples, transportService, - adTaskManager + adTaskManager, + taskProfileRunner ); doAnswer(invocation -> { @@ -208,7 +210,7 @@ public void testDetectorNotExist() throws IOException, InterruptedException { public void testDisabledJobIndexTemplate(JobStatus status) throws IOException, InterruptedException { setUpClientGet(DetectorStatus.EXIST, status, RCFPollingStatus.EMPTY, ErrorResultStatus.NO_ERROR); - DetectorProfile expectedProfile = new DetectorProfile.Builder().state(DetectorState.DISABLED).build(); + ConfigProfile expectedProfile = new DetectorProfile.Builder().state(ConfigState.DISABLED).build(); final CountDownLatch inProgressLatch = new CountDownLatch(1); runner.profile(detector.getId(), ActionListener.wrap(response -> { @@ -229,10 +231,10 @@ public void testJobDisabled() throws IOException, InterruptedException { testDisabledJobIndexTemplate(JobStatus.DISABLED); } - public void testInitOrRunningStateTemplate(RCFPollingStatus status, DetectorState expectedState) throws IOException, + public void testInitOrRunningStateTemplate(RCFPollingStatus status, ConfigState expectedState) throws IOException, InterruptedException { setUpClientGet(DetectorStatus.EXIST, JobStatus.ENABLED, status, ErrorResultStatus.NO_ERROR); - DetectorProfile expectedProfile = new DetectorProfile.Builder().state(expectedState).build(); + ConfigProfile expectedProfile = new DetectorProfile.Builder().state(expectedState).build(); final CountDownLatch inProgressLatch = new CountDownLatch(1); runner.profile(detector.getId(), ActionListener.wrap(response -> { @@ -250,37 +252,37 @@ public void testInitOrRunningStateTemplate(RCFPollingStatus status, DetectorStat } public void testResultNotExist() throws IOException, InterruptedException { - testInitOrRunningStateTemplate(RCFPollingStatus.INIT_NOT_EXIT, DetectorState.INIT); + testInitOrRunningStateTemplate(RCFPollingStatus.INIT_NOT_EXIT, ConfigState.INIT); } public void testRemoteResultNotExist() throws IOException, InterruptedException { - testInitOrRunningStateTemplate(RCFPollingStatus.REMOTE_INIT_NOT_EXIT, DetectorState.INIT); + testInitOrRunningStateTemplate(RCFPollingStatus.REMOTE_INIT_NOT_EXIT, ConfigState.INIT); } public void testCheckpointIndexNotExist() throws IOException, InterruptedException { - testInitOrRunningStateTemplate(RCFPollingStatus.INDEX_NOT_FOUND, DetectorState.INIT); + testInitOrRunningStateTemplate(RCFPollingStatus.INDEX_NOT_FOUND, ConfigState.INIT); } public void testRemoteCheckpointIndexNotExist() throws IOException, InterruptedException { - testInitOrRunningStateTemplate(RCFPollingStatus.REMOTE_INDEX_NOT_FOUND, DetectorState.INIT); + testInitOrRunningStateTemplate(RCFPollingStatus.REMOTE_INDEX_NOT_FOUND, ConfigState.INIT); } public void testResultEmpty() throws IOException, InterruptedException { - testInitOrRunningStateTemplate(RCFPollingStatus.EMPTY, DetectorState.INIT); + testInitOrRunningStateTemplate(RCFPollingStatus.EMPTY, ConfigState.INIT); } public void testResultGreaterThanZero() throws IOException, InterruptedException { - testInitOrRunningStateTemplate(RCFPollingStatus.INIT_DONE, DetectorState.RUNNING); + testInitOrRunningStateTemplate(RCFPollingStatus.INIT_DONE, ConfigState.RUNNING); } @SuppressWarnings("unchecked") public void testErrorStateTemplate( RCFPollingStatus initStatus, ErrorResultStatus status, - DetectorState state, + ConfigState state, String error, JobStatus jobStatus, - Set profilesToCollect + Set profilesToCollect ) throws IOException, InterruptedException { ADTask adTask = TestHelpers.randomAdTask(); @@ -291,18 +293,18 @@ public void testErrorStateTemplate( Consumer> function = (Consumer>) args[2]; function.accept(Optional.of(adTask)); return null; - }).when(adTaskManager).getAndExecuteOnLatestDetectorLevelTask(any(), any(), any(), any(), anyBoolean(), any()); + }).when(adTaskManager).getAndExecuteOnLatestConfigLevelTask(any(), any(), any(), any(), anyBoolean(), any()); setUpClientExecuteRCFPollingAction(initStatus); setUpClientGet(DetectorStatus.EXIST, jobStatus, initStatus, status); DetectorProfile.Builder builder = new DetectorProfile.Builder(); - if (profilesToCollect.contains(DetectorProfileName.STATE)) { + if (profilesToCollect.contains(ProfileName.STATE)) { builder.state(state); } - if (profilesToCollect.contains(DetectorProfileName.ERROR)) { + if (profilesToCollect.contains(ProfileName.ERROR)) { builder.error(error); } - DetectorProfile expectedProfile = builder.build(); + ConfigProfile expectedProfile = builder.build(); final CountDownLatch inProgressLatch = new CountDownLatch(1); runner.profile(detector.getId(), ActionListener.wrap(response -> { @@ -322,7 +324,7 @@ public void testErrorStateTemplate( public void testErrorStateTemplate( RCFPollingStatus initStatus, ErrorResultStatus status, - DetectorState state, + ConfigState state, String error, JobStatus jobStatus ) throws IOException, @@ -331,14 +333,14 @@ public void testErrorStateTemplate( } public void testRunningNoError() throws IOException, InterruptedException { - testErrorStateTemplate(RCFPollingStatus.INIT_DONE, ErrorResultStatus.NO_ERROR, DetectorState.RUNNING, null, JobStatus.ENABLED); + testErrorStateTemplate(RCFPollingStatus.INIT_DONE, ErrorResultStatus.NO_ERROR, ConfigState.RUNNING, null, JobStatus.ENABLED); } public void testRunningWithError() throws IOException, InterruptedException { testErrorStateTemplate( RCFPollingStatus.INIT_DONE, ErrorResultStatus.SHINGLE_ERROR, - DetectorState.RUNNING, + ConfigState.RUNNING, noFullShingleError, JobStatus.ENABLED ); @@ -348,7 +350,7 @@ public void testDisabledForStateError() throws IOException, InterruptedException testErrorStateTemplate( RCFPollingStatus.INITTING, ErrorResultStatus.STOPPED_ERROR, - DetectorState.DISABLED, + ConfigState.DISABLED, stoppedError, JobStatus.DISABLED ); @@ -358,7 +360,7 @@ public void testDisabledForStateInit() throws IOException, InterruptedException testErrorStateTemplate( RCFPollingStatus.INITTING, ErrorResultStatus.STOPPED_ERROR, - DetectorState.DISABLED, + ConfigState.DISABLED, stoppedError, JobStatus.DISABLED, stateInitProgress @@ -369,7 +371,7 @@ public void testInitWithError() throws IOException, InterruptedException { testErrorStateTemplate( RCFPollingStatus.EMPTY, ErrorResultStatus.SHINGLE_ERROR, - DetectorState.INIT, + ConfigState.INIT, noFullShingleError, JobStatus.ENABLED ); @@ -448,7 +450,7 @@ private void setUpClientExecuteProfileAction() { listener.onResponse(profileResponse); return null; - }).when(client).execute(any(ProfileAction.class), any(), any()); + }).when(client).execute(any(ADProfileAction.class), any(), any()); } @@ -541,7 +543,7 @@ public void testProfileModels() throws InterruptedException, IOException { public void testInitProgress() throws IOException, InterruptedException { setUpClientGet(DetectorStatus.EXIST, JobStatus.ENABLED, RCFPollingStatus.INITTING, ErrorResultStatus.NO_ERROR); - DetectorProfile expectedProfile = new DetectorProfile.Builder().state(DetectorState.INIT).build(); + ConfigProfile expectedProfile = new DetectorProfile.Builder().state(ConfigState.INIT).build(); // 123 / 128 rounded to 96% InitProgressProfile profile = new InitProgressProfile("96%", neededSamples * detectorIntervalMin, neededSamples); @@ -560,7 +562,7 @@ public void testInitProgress() throws IOException, InterruptedException { public void testInitProgressFailImmediately() throws IOException, InterruptedException { setUpClientGet(DetectorStatus.NO_DOC, JobStatus.ENABLED, RCFPollingStatus.INITTING, ErrorResultStatus.NO_ERROR); - DetectorProfile expectedProfile = new DetectorProfile.Builder().state(DetectorState.INIT).build(); + ConfigProfile expectedProfile = new DetectorProfile.Builder().state(ConfigState.INIT).build(); // 123 / 128 rounded to 96% InitProgressProfile profile = new InitProgressProfile("96%", neededSamples * detectorIntervalMin, neededSamples); @@ -579,8 +581,8 @@ public void testInitProgressFailImmediately() throws IOException, InterruptedExc public void testInitNoUpdateNoIndex() throws IOException, InterruptedException { setUpClientGet(DetectorStatus.EXIST, JobStatus.ENABLED, RCFPollingStatus.EMPTY, ErrorResultStatus.NO_ERROR); - DetectorProfile expectedProfile = new DetectorProfile.Builder() - .state(DetectorState.INIT) + ConfigProfile expectedProfile = new DetectorProfile.Builder() + .state(ConfigState.INIT) .initProgress(new InitProgressProfile("0%", detectorIntervalMin * requiredSamples, requiredSamples)) .build(); final CountDownLatch inProgressLatch = new CountDownLatch(1); @@ -601,8 +603,8 @@ public void testInitNoUpdateNoIndex() throws IOException, InterruptedException { public void testInitNoIndex() throws IOException, InterruptedException { setUpClientGet(DetectorStatus.EXIST, JobStatus.ENABLED, RCFPollingStatus.INDEX_NOT_FOUND, ErrorResultStatus.NO_ERROR); - DetectorProfile expectedProfile = new DetectorProfile.Builder() - .state(DetectorState.INIT) + ConfigProfile expectedProfile = new DetectorProfile.Builder() + .state(ConfigState.INIT) .initProgress(new InitProgressProfile("0%", 0, requiredSamples)) .build(); final CountDownLatch inProgressLatch = new CountDownLatch(1); @@ -624,7 +626,16 @@ public void testInitNoIndex() throws IOException, InterruptedException { public void testInvalidRequiredSamples() { expectThrows( IllegalArgumentException.class, - () -> new AnomalyDetectorProfileRunner(client, clientUtil, xContentRegistry(), nodeFilter, 0, transportService, adTaskManager) + () -> new AnomalyDetectorProfileRunner( + client, + clientUtil, + xContentRegistry(), + nodeFilter, + 0, + transportService, + adTaskManager, + taskProfileRunner + ) ); } diff --git a/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java b/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java index 020281ac6..98528cbae 100644 --- a/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java +++ b/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java @@ -306,7 +306,11 @@ public ToXContentObject[] getConfig(String detectorId, BasicHeader header, boole null, detector.getUser(), detector.getCustomResultIndex(), - detector.getImputationOption() + detector.getImputationOption(), + detector.getRecencyEmphasis(), + detector.getSeasonIntervals(), + detector.getHistoryIntervals(), + null ), detectorJob, historicalAdTask, @@ -639,7 +643,11 @@ protected AnomalyDetector cloneDetector(AnomalyDetector anomalyDetector, String anomalyDetector.getCategoryFields(), null, resultIndex, - anomalyDetector.getImputationOption() + anomalyDetector.getImputationOption(), + anomalyDetector.getRecencyEmphasis(), + anomalyDetector.getSeasonIntervals(), + anomalyDetector.getHistoryIntervals(), + null ); return detector; } diff --git a/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java b/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java index d94226baa..0ec38eb9a 100644 --- a/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java +++ b/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java @@ -33,14 +33,7 @@ import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.EntityProfile; -import org.opensearch.ad.model.EntityProfileName; -import org.opensearch.ad.model.EntityState; -import org.opensearch.ad.model.InitProgressProfile; -import org.opensearch.ad.model.ModelProfile; -import org.opensearch.ad.model.ModelProfileOnNode; -import org.opensearch.ad.transport.EntityProfileAction; -import org.opensearch.ad.transport.EntityProfileResponse; +import org.opensearch.ad.transport.ADEntityProfileAction; import org.opensearch.client.Client; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; @@ -60,8 +53,15 @@ import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.EntityProfile; +import org.opensearch.timeseries.model.EntityProfileName; +import org.opensearch.timeseries.model.EntityState; +import org.opensearch.timeseries.model.InitProgressProfile; import org.opensearch.timeseries.model.IntervalTimeConfiguration; import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.ModelProfile; +import org.opensearch.timeseries.model.ModelProfileOnNode; +import org.opensearch.timeseries.transport.EntityProfileResponse; import org.opensearch.timeseries.util.SecurityClientUtil; public class EntityProfileRunnerTests extends AbstractTimeSeriesTest { @@ -69,7 +69,7 @@ public class EntityProfileRunnerTests extends AbstractTimeSeriesTest { private int detectorIntervalMin; private Client client; private SecurityClientUtil clientUtil; - private EntityProfileRunner runner; + private ADEntityProfileRunner runner; private Set state; private Set initNInfo; private Set model; @@ -139,7 +139,7 @@ public void setUp() throws Exception { }).when(nodeStateManager).getConfig(any(String.class), eq(AnalysisType.AD), any(ActionListener.class)); clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); - runner = new EntityProfileRunner(client, clientUtil, xContentRegistry(), requiredSamples); + runner = new ADEntityProfileRunner(client, clientUtil, xContentRegistry(), requiredSamples); doAnswer(invocation -> { Object[] args = invocation.getArguments(); @@ -220,7 +220,7 @@ private void setUpExecuteEntityProfileAction(InittedEverResultStatus initted) { listener.onResponse(profileResponseBuilder.build()); return null; - }).when(client).execute(any(EntityProfileAction.class), any(), any()); + }).when(client).execute(any(ADEntityProfileAction.class), any(), any()); doAnswer(invocation -> { Object[] args = invocation.getArguments(); @@ -400,7 +400,7 @@ public void testNotMultiEntityDetector() throws IOException, InterruptedExceptio assertTrue("Should not reach here", false); inProgressLatch.countDown(); }, exception -> { - assertTrue(exception.getMessage().contains(EntityProfileRunner.NOT_HC_DETECTOR_ERR_MSG)); + assertTrue(exception.getMessage().contains(ADEntityProfileRunner.NOT_HC_DETECTOR_ERR_MSG)); inProgressLatch.countDown(); })); assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); diff --git a/src/test/java/org/opensearch/ad/HistoricalAnalysisIntegTestCase.java b/src/test/java/org/opensearch/ad/HistoricalAnalysisIntegTestCase.java index 9b8356081..d2c27a68b 100644 --- a/src/test/java/org/opensearch/ad/HistoricalAnalysisIntegTestCase.java +++ b/src/test/java/org/opensearch/ad/HistoricalAnalysisIntegTestCase.java @@ -12,11 +12,6 @@ package org.opensearch.ad; import static org.opensearch.ad.model.ADTask.DETECTOR_ID_FIELD; -import static org.opensearch.ad.model.ADTask.EXECUTION_START_TIME_FIELD; -import static org.opensearch.ad.model.ADTask.IS_LATEST_FIELD; -import static org.opensearch.ad.model.ADTask.PARENT_TASK_ID_FIELD; -import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM; -import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; import static org.opensearch.timeseries.util.RestHandlerUtils.START_JOB; import java.io.IOException; @@ -40,7 +35,6 @@ import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.transport.AnomalyDetectorJobAction; -import org.opensearch.ad.transport.AnomalyDetectorJobRequest; import org.opensearch.core.rest.RestStatus; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; @@ -56,6 +50,8 @@ import org.opensearch.timeseries.model.Feature; import org.opensearch.timeseries.model.Job; import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.transport.JobRequest; import org.opensearch.timeseries.transport.JobResponse; import com.google.common.collect.ImmutableList; @@ -180,14 +176,14 @@ public List searchADTasks(String detectorId, String parentTaskId, Boolea BoolQueryBuilder query = new BoolQueryBuilder(); query.filter(new TermQueryBuilder(DETECTOR_ID_FIELD, detectorId)); if (isLatest != null) { - query.filter(new TermQueryBuilder(IS_LATEST_FIELD, isLatest)); + query.filter(new TermQueryBuilder(TimeSeriesTask.IS_LATEST_FIELD, isLatest)); } if (parentTaskId != null) { - query.filter(new TermQueryBuilder(PARENT_TASK_ID_FIELD, parentTaskId)); + query.filter(new TermQueryBuilder(TimeSeriesTask.PARENT_TASK_ID_FIELD, parentTaskId)); } SearchRequest searchRequest = new SearchRequest(); SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - sourceBuilder.query(query).sort(EXECUTION_START_TIME_FIELD, SortOrder.DESC).trackTotalHits(true).size(size); + sourceBuilder.query(query).sort(TimeSeriesTask.EXECUTION_START_TIME_FIELD, SortOrder.DESC).trackTotalHits(true).size(size); searchRequest.source(sourceBuilder).indices(ADCommonName.DETECTION_STATE_INDEX); SearchResponse searchResponse = client().search(searchRequest).actionGet(); Iterator iterator = searchResponse.getHits().iterator(); @@ -224,28 +220,14 @@ public ADTask startHistoricalAnalysis(Instant startTime, Instant endTime) throws AnomalyDetector detector = TestHelpers .randomDetector(ImmutableList.of(maxValueFeature()), testIndex, detectionIntervalInMinutes, timeField); String detectorId = createDetector(detector); - AnomalyDetectorJobRequest request = new AnomalyDetectorJobRequest( - detectorId, - dateRange, - true, - UNASSIGNED_SEQ_NO, - UNASSIGNED_PRIMARY_TERM, - START_JOB - ); + JobRequest request = new JobRequest(detectorId, dateRange, true, START_JOB); JobResponse response = client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000); return getADTask(response.getId()); } public ADTask startHistoricalAnalysis(String detectorId, Instant startTime, Instant endTime) throws IOException { DateRange dateRange = new DateRange(startTime, endTime); - AnomalyDetectorJobRequest request = new AnomalyDetectorJobRequest( - detectorId, - dateRange, - true, - UNASSIGNED_SEQ_NO, - UNASSIGNED_PRIMARY_TERM, - START_JOB - ); + JobRequest request = new JobRequest(detectorId, dateRange, true, START_JOB); JobResponse response = client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000); return getADTask(response.getId()); } diff --git a/src/test/java/org/opensearch/ad/HistoricalAnalysisRestTestCase.java b/src/test/java/org/opensearch/ad/HistoricalAnalysisRestTestCase.java index 35bf1a29f..588598400 100644 --- a/src/test/java/org/opensearch/ad/HistoricalAnalysisRestTestCase.java +++ b/src/test/java/org/opensearch/ad/HistoricalAnalysisRestTestCase.java @@ -28,6 +28,7 @@ import org.apache.hc.core5.http.message.BasicHeader; import org.junit.Before; import org.opensearch.ad.mock.model.MockSimpleLog; +import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.ADTaskProfile; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.client.Response; @@ -36,6 +37,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.timeseries.TaskProfile; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.model.DateRange; import org.opensearch.timeseries.model.Feature; @@ -251,9 +253,9 @@ protected List waitUntilTaskDone(String detectorId) throws InterruptedEx protected List waitUntilTaskReachState(String detectorId, Set targetStates) throws InterruptedException { List results = new ArrayList<>(); int i = 0; - ADTaskProfile adTaskProfile = null; + TaskProfile adTaskProfile = null; // Increase retryTimes if some task can't reach done state - while ((adTaskProfile == null || !targetStates.contains(adTaskProfile.getAdTask().getState())) && i < MAX_RETRY_TIMES) { + while ((adTaskProfile == null || !targetStates.contains(adTaskProfile.getTask().getState())) && i < MAX_RETRY_TIMES) { try { adTaskProfile = getADTaskProfile(detectorId); } catch (Exception e) { diff --git a/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java b/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java index 05a63e3df..bc20f29c1 100644 --- a/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java +++ b/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java @@ -18,7 +18,6 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; -import java.time.Clock; import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; @@ -48,13 +47,9 @@ import org.opensearch.ad.model.AnomalyResult; import org.opensearch.ad.model.DetectorInternalState; import org.opensearch.ad.model.DetectorProfile; -import org.opensearch.ad.model.DetectorProfileName; -import org.opensearch.ad.model.DetectorState; import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.ADProfileAction; import org.opensearch.ad.transport.AnomalyResultTests; -import org.opensearch.ad.transport.ProfileAction; -import org.opensearch.ad.transport.ProfileNodeResponse; -import org.opensearch.ad.transport.ProfileResponse; import org.opensearch.ad.util.*; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; @@ -66,7 +61,12 @@ import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.ConfigProfile; +import org.opensearch.timeseries.model.ConfigState; import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.ProfileName; +import org.opensearch.timeseries.transport.ProfileNodeResponse; +import org.opensearch.timeseries.transport.ProfileResponse; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.TransportService; @@ -79,7 +79,7 @@ public class MultiEntityProfileRunnerTests extends AbstractTimeSeriesTest { private int requiredSamples; private AnomalyDetector detector; private String detectorId; - private Set stateNError; + private Set stateNError; private DetectorInternalState.Builder result; private String node1; private String nodeName1; @@ -97,6 +97,7 @@ public class MultiEntityProfileRunnerTests extends AbstractTimeSeriesTest { private Job job; private TransportService transportService; private ADTaskManager adTaskManager; + private ADTaskProfileRunner taskProfileRunner; enum InittedEverResultStatus { INITTED, @@ -119,7 +120,7 @@ public static void tearDownAfterClass() { public void setUp() throws Exception { super.setUp(); client = mock(Client.class); - Clock clock = mock(Clock.class); + taskProfileRunner = mock(ADTaskProfileRunner.class); NodeStateManager nodeStateManager = mock(NodeStateManager.class); clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); nodeFilter = mock(DiscoveryNodeFilterer.class); @@ -137,7 +138,7 @@ public void setUp() throws Exception { function.accept(Optional.of(TestHelpers.randomAdTask())); return null; - }).when(adTaskManager).getAndExecuteOnLatestDetectorLevelTask(any(), any(), any(), any(), anyBoolean(), any()); + }).when(adTaskManager).getAndExecuteOnLatestConfigLevelTask(any(), any(), any(), any(), anyBoolean(), any()); runner = new AnomalyDetectorProfileRunner( client, clientUtil, @@ -145,7 +146,8 @@ public void setUp() throws Exception { nodeFilter, requiredSamples, transportService, - adTaskManager + adTaskManager, + taskProfileRunner ); doAnswer(invocation -> { @@ -165,9 +167,9 @@ public void setUp() throws Exception { return null; }).when(client).get(any(), any()); - stateNError = new HashSet(); - stateNError.add(DetectorProfileName.ERROR); - stateNError.add(DetectorProfileName.STATE); + stateNError = new HashSet(); + stateNError.add(ProfileName.ERROR); + stateNError.add(ProfileName.STATE); } @SuppressWarnings("unchecked") @@ -248,7 +250,7 @@ private void setUpClientExecuteProfileAction(InittedEverResultStatus initted) { listener.onResponse(profileResponse); return null; - }).when(client).execute(any(ProfileAction.class), any(), any()); + }).when(client).execute(any(ADProfileAction.class), any(), any()); } @@ -285,7 +287,7 @@ public void testInit() throws InterruptedException { final CountDownLatch inProgressLatch = new CountDownLatch(1); - DetectorProfile expectedProfile = new DetectorProfile.Builder().state(DetectorState.INIT).build(); + ConfigProfile expectedProfile = new DetectorProfile.Builder().state(ConfigState.INIT).build(); runner.profile(detectorId, ActionListener.wrap(response -> { assertEquals(expectedProfile, response); inProgressLatch.countDown(); @@ -302,7 +304,7 @@ public void testRunning() throws InterruptedException { final CountDownLatch inProgressLatch = new CountDownLatch(1); - DetectorProfile expectedProfile = new DetectorProfile.Builder().state(DetectorState.RUNNING).build(); + ConfigProfile expectedProfile = new DetectorProfile.Builder().state(ConfigState.RUNNING).build(); runner.profile(detectorId, ActionListener.wrap(response -> { assertEquals(expectedProfile, response); inProgressLatch.countDown(); @@ -323,7 +325,7 @@ public void testResultIndexFinalTruth() throws InterruptedException { final CountDownLatch inProgressLatch = new CountDownLatch(1); - DetectorProfile expectedProfile = new DetectorProfile.Builder().state(DetectorState.RUNNING).build(); + ConfigProfile expectedProfile = new DetectorProfile.Builder().state(ConfigState.RUNNING).build(); runner.profile(detectorId, ActionListener.wrap(response -> { assertEquals(expectedProfile, response); inProgressLatch.countDown(); diff --git a/src/test/java/org/opensearch/ad/bwc/ADBackwardsCompatibilityIT.java b/src/test/java/org/opensearch/ad/bwc/ADBackwardsCompatibilityIT.java index 0f513b502..0c1a6812d 100644 --- a/src/test/java/org/opensearch/ad/bwc/ADBackwardsCompatibilityIT.java +++ b/src/test/java/org/opensearch/ad/bwc/ADBackwardsCompatibilityIT.java @@ -42,13 +42,13 @@ import org.opensearch.ad.mock.model.MockSimpleLog; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.ADTaskType; -import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.rest.ADRestTestUtils; import org.opensearch.client.Response; import org.opensearch.common.settings.Settings; import org.opensearch.core.rest.RestStatus; import org.opensearch.test.rest.OpenSearchRestTestCase; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.Job; import org.opensearch.timeseries.util.ExceptionUtil; import org.opensearch.timeseries.util.RestHandlerUtils; @@ -435,7 +435,7 @@ private List startAnomalyDetector(Response response, boolean historicalD Map responseMap = entityAsMap(response); String detectorId = (String) responseMap.get("_id"); int version = (int) responseMap.get("_version"); - assertNotEquals("response is missing Id", AnomalyDetector.NO_ID, detectorId); + assertNotEquals("response is missing Id", Config.NO_ID, detectorId); assertTrue("incorrect version", version > 0); Response startDetectorResponse = TestHelpers diff --git a/src/test/java/org/opensearch/ad/caching/AbstractCacheTest.java b/src/test/java/org/opensearch/ad/caching/AbstractCacheTest.java index 2c990682a..8b72773ea 100644 --- a/src/test/java/org/opensearch/ad/caching/AbstractCacheTest.java +++ b/src/test/java/org/opensearch/ad/caching/AbstractCacheTest.java @@ -19,34 +19,38 @@ import java.time.Instant; import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Optional; import java.util.Random; import org.junit.Before; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelManager.ModelType; -import org.opensearch.ad.ml.ModelState; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.ratelimit.CheckpointMaintainWorker; -import org.opensearch.ad.ratelimit.CheckpointWriteWorker; +import org.opensearch.ad.ratelimit.ADCheckpointMaintainWorker; +import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.MemoryTracker; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.settings.TimeSeriesSettings; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +import test.org.opensearch.ad.util.MLUtil; + public class AbstractCacheTest extends AbstractTimeSeriesTest { protected String modelId1, modelId2, modelId3, modelId4; protected Entity entity1, entity2, entity3, entity4; - protected ModelState modelState1, modelState2, modelState3, modelState4; + protected ModelState modelState1, modelState2, modelState3, modelState4; protected String detectorId; protected AnomalyDetector detector; protected Clock clock; protected Duration detectorDuration; protected float initialPriority; - protected CacheBuffer cacheBuffer; + protected ADCacheBuffer cacheBuffer; protected long memoryPerEntity; protected MemoryTracker memoryTracker; - protected CheckpointWriteWorker checkpointWriteQueue; - protected CheckpointMaintainWorker checkpointMaintainQueue; + protected ADCheckpointWriteWorker checkpointWriteQueue; + protected ADCheckpointMaintainWorker checkpointMaintainQueue; protected Random random; protected int shingleSize; @@ -85,58 +89,66 @@ public void setUp() throws Exception { memoryPerEntity = 81920; memoryTracker = mock(MemoryTracker.class); - checkpointWriteQueue = mock(CheckpointWriteWorker.class); - checkpointMaintainQueue = mock(CheckpointMaintainWorker.class); + checkpointWriteQueue = mock(ADCheckpointWriteWorker.class); + checkpointMaintainQueue = mock(ADCheckpointMaintainWorker.class); - cacheBuffer = new CacheBuffer( - 1, + cacheBuffer = new ADCacheBuffer( 1, - memoryPerEntity, - memoryTracker, clock, + memoryTracker, + 1, TimeSeriesSettings.HOURLY_MAINTENANCE, - detectorId, + memoryPerEntity, checkpointWriteQueue, checkpointMaintainQueue, + detectorId, Duration.ofHours(12).toHoursPart() ); initialPriority = cacheBuffer.getPriorityTracker().getUpdatedPriority(0); - modelState1 = new ModelState<>( - new EntityModel(entity1, new ArrayDeque<>(), null), + modelState1 = new ModelState( + MLUtil.createNonEmptyModel(detectorId, 0, entity1).getLeft(), modelId1, detectorId, - ModelType.ENTITY.getName(), + ModelManager.ModelType.TRCF.getName(), clock, - 0 + 0, + Optional.of(entity1), + new ArrayDeque<>() ); - modelState2 = new ModelState<>( - new EntityModel(entity2, new ArrayDeque<>(), null), + modelState2 = new ModelState( + MLUtil.createNonEmptyModel(detectorId, 0, entity2).getLeft(), modelId2, detectorId, - ModelType.ENTITY.getName(), + ModelManager.ModelType.TRCF.getName(), clock, - 0 + 0, + Optional.of(entity2), + new ArrayDeque<>() ); - modelState3 = new ModelState<>( - new EntityModel(entity3, new ArrayDeque<>(), null), + modelState3 = new ModelState( + MLUtil.createNonEmptyModel(detectorId, 0, entity3).getLeft(), modelId3, detectorId, - ModelType.ENTITY.getName(), + ModelManager.ModelType.TRCF.getName(), clock, - 0 + 0, + Optional.of(entity3), + new ArrayDeque<>() ); - modelState4 = new ModelState<>( - new EntityModel(entity4, new ArrayDeque<>(), null), + modelState4 = new ModelState( + MLUtil.createNonEmptyModel(detectorId, 0, entity4).getLeft(), modelId4, detectorId, - ModelType.ENTITY.getName(), + ModelManager.ModelType.TRCF.getName(), clock, - 0 + 0, + Optional.of(entity4), + new ArrayDeque<>() ); } } diff --git a/src/test/java/org/opensearch/ad/caching/CacheBufferTests.java b/src/test/java/org/opensearch/ad/caching/CacheBufferTests.java index 265560ab5..0eb4ac947 100644 --- a/src/test/java/org/opensearch/ad/caching/CacheBufferTests.java +++ b/src/test/java/org/opensearch/ad/caching/CacheBufferTests.java @@ -22,8 +22,8 @@ import java.util.Optional; import org.mockito.ArgumentCaptor; -import org.opensearch.ad.ratelimit.CheckpointMaintainRequest; import org.opensearch.timeseries.MemoryTracker; +import org.opensearch.timeseries.ratelimit.CheckpointMaintainRequest; import test.org.opensearch.ad.util.MLUtil; import test.org.opensearch.ad.util.RandomModelStateConfig; @@ -69,7 +69,7 @@ public void testRemovalCandidate2() throws InterruptedException { cacheBuffer.put(modelId2, modelState2); cacheBuffer.put(modelId2, modelState2); cacheBuffer.put(modelId4, modelState4); - assertTrue(cacheBuffer.getModel(modelId2).isPresent()); + assertTrue(cacheBuffer.getModelState(modelId2) != null); ArgumentCaptor memoryReleased = ArgumentCaptor.forClass(Long.class); ArgumentCaptor reserved = ArgumentCaptor.forClass(Boolean.class); @@ -93,10 +93,10 @@ public void testCanRemove() { String modelId2 = "2"; String modelId3 = "3"; assertTrue(cacheBuffer.dedicatedCacheAvailable()); - assertTrue(!cacheBuffer.canReplaceWithinDetector(100)); + assertTrue(!cacheBuffer.canReplaceWithinConfig(100)); cacheBuffer.put(modelId1, MLUtil.randomModelState(new RandomModelStateConfig.Builder().priority(initialPriority).build())); - assertTrue(cacheBuffer.canReplaceWithinDetector(100)); + assertTrue(cacheBuffer.canReplaceWithinConfig(100)); assertTrue(!cacheBuffer.dedicatedCacheAvailable()); assertTrue(!cacheBuffer.canRemove()); cacheBuffer.put(modelId2, MLUtil.randomModelState(new RandomModelStateConfig.Builder().priority(initialPriority).build())); @@ -117,7 +117,7 @@ public void testMaintenance() { cacheBuffer.put(modelId3, MLUtil.randomModelState(new RandomModelStateConfig.Builder().priority(initialPriority).build())); cacheBuffer.maintenance(); assertEquals(3, cacheBuffer.getActiveEntities()); - assertEquals(3, cacheBuffer.getAllModels().size()); + assertEquals(3, cacheBuffer.getAllModelStates().size()); // the year of 2122, 100 years later to simulate we are gonna remove all cached entries when(clock.instant()).thenReturn(Instant.ofEpochSecond(4814540761L)); cacheBuffer.maintenance(); @@ -167,7 +167,7 @@ public void testMaintainByHourSaveOne() { verify(checkpointMaintainQueue, times(1)).putAll(savedStates.capture()); List toSave = savedStates.getValue(); assertEquals(1, toSave.size()); - assertEquals(modelId1, toSave.get(0).getEntityModelId()); + assertEquals(modelId1, toSave.get(0).getModelId()); } /** diff --git a/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java b/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java index 4154687cf..bcbf54b67 100644 --- a/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java +++ b/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java @@ -35,6 +35,7 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Optional; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; @@ -44,11 +45,8 @@ import org.apache.logging.log4j.Logger; import org.junit.Before; import org.mockito.ArgumentCaptor; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.ml.ModelManager.ModelType; -import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.settings.AnomalyDetectorSettings; @@ -65,15 +63,20 @@ import org.opensearch.timeseries.breaker.CircuitBreakerService; import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.settings.TimeSeriesSettings; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + public class PriorityCacheTests extends AbstractCacheTest { private static final Logger LOG = LogManager.getLogger(PriorityCacheTests.class); - EntityCache entityCache; - CheckpointDao checkpoint; - ModelManager modelManager; + ADPriorityCache entityCache; + ADCheckpointDao checkpoint; + ADModelManager modelManager; ClusterService clusterService; Settings settings; @@ -87,9 +90,9 @@ public class PriorityCacheTests extends AbstractCacheTest { public void setUp() throws Exception { super.setUp(); - checkpoint = mock(CheckpointDao.class); + checkpoint = mock(ADCheckpointDao.class); - modelManager = mock(ModelManager.class); + modelManager = mock(ADModelManager.class); clusterService = mock(ClusterService.class); ClusterSettings settings = new ClusterSettings( @@ -115,7 +118,7 @@ public void setUp() throws Exception { threadPool = mock(ThreadPool.class); setUpADThreadPool(threadPool); - EntityCache cache = new PriorityCache( + ADPriorityCache cache = new ADPriorityCache( checkpoint, dedicatedCacheSize, AnomalyDetectorSettings.AD_CHECKPOINT_TTL, @@ -126,14 +129,14 @@ public void setUp() throws Exception { clusterService, TimeSeriesSettings.HOURLY_MAINTENANCE, threadPool, - checkpointWriteQueue, TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, - checkpointMaintainQueue, Settings.EMPTY, - AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ + AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ, + checkpointWriteQueue, + checkpointMaintainQueue ); - CacheProvider cacheProvider = new CacheProvider(); + ADCacheProvider cacheProvider = new ADCacheProvider(); cacheProvider.set(cache); entityCache = cacheProvider.get(); @@ -171,7 +174,7 @@ public void testCacheHit() { memoryTracker = spy(new MemoryTracker(jvmService, modelMaxPercen, clusterService, mock(CircuitBreakerService.class))); - EntityCache cache = new PriorityCache( + ADPriorityCache cache = new ADPriorityCache( checkpoint, dedicatedCacheSize, AnomalyDetectorSettings.AD_CHECKPOINT_TTL, @@ -182,14 +185,14 @@ public void testCacheHit() { clusterService, TimeSeriesSettings.HOURLY_MAINTENANCE, threadPool, - checkpointWriteQueue, TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, - checkpointMaintainQueue, Settings.EMPTY, - AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ + AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ, + checkpointWriteQueue, + checkpointMaintainQueue ); - CacheProvider cacheProvider = new CacheProvider(); + ADCacheProvider cacheProvider = new ADCacheProvider(); cacheProvider.set(cache); entityCache = cacheProvider.get(); @@ -200,13 +203,14 @@ public void testCacheHit() { entityCache.hostIfPossible(detector, modelState1); assertEquals(1, entityCache.getTotalActiveEntities()); assertEquals(1, entityCache.getAllModels().size()); - ModelState hitState = entityCache.get(modelState1.getModelId(), detector); - assertEquals(detectorId, hitState.getId()); - EntityModel model = hitState.getModel(); - assertEquals(false, model.getTrcf().isPresent()); - assertTrue(model.getSamples().isEmpty()); - modelState1.getModel().addSample(point); - assertTrue(Arrays.equals(point, model.getSamples().peek())); + ModelState hitState = entityCache.get(modelState1.getModelId(), detector); + assertEquals(detectorId, hitState.getConfigId()); + Optional model = hitState.getModel(); + assertTrue(model.isEmpty()); + assertTrue(hitState.getSamples().isEmpty()); + Sample sample = new Sample(point, Instant.now(), Instant.now()); + modelState1.addSample(sample); + assertTrue(Arrays.equals(point, hitState.getSamples().peek().getValueList())); ArgumentCaptor memoryConsumed = ArgumentCaptor.forClass(Long.class); ArgumentCaptor reserved = ArgumentCaptor.forClass(Boolean.class); @@ -260,12 +264,14 @@ public void testSharedCache() { entityCache.get(modelId3, detector2); } modelState3 = new ModelState<>( - new EntityModel(entity3, new ArrayDeque<>(), null), + null, modelId3, detectorId2, - ModelType.ENTITY.getName(), + ModelManager.ModelType.TRCF.getName(), clock, - 0 + 0, + Optional.of(entity3), + new ArrayDeque<>() ); entityCache.hostIfPossible(detector2, modelState3); @@ -276,12 +282,14 @@ public void testSharedCache() { entityCache.get(modelId4, detector2); } modelState4 = new ModelState<>( - new EntityModel(entity4, new ArrayDeque<>(), null), + null, modelId4, detectorId2, - ModelType.ENTITY.getName(), + ModelManager.ModelType.TRCF.getName(), clock, - 0 + 0, + Optional.of(entity4), + new ArrayDeque<>() ); entityCache.hostIfPossible(detector2, modelState4); assertEquals(2, entityCache.getActiveEntities(detectorId2)); @@ -303,7 +311,7 @@ public void testReplace() { entityCache.hostIfPossible(detector, modelState1); assertEquals(1, entityCache.getActiveEntities(detectorId)); when(memoryTracker.canAllocate(anyLong())).thenReturn(false); - ModelState state = null; + ModelState state = null; for (int i = 0; i < 4; i++) { entityCache.get(modelId2, detector); @@ -366,7 +374,7 @@ public void testClear() { assertEquals(2, entityCache.getTotalActiveEntities()); assertTrue(entityCache.isActive(detectorId, modelId1)); assertEquals(0, entityCache.getTotalUpdates(detectorId)); - modelState1.getModel().addSample(point); + modelState1.addSample(new Sample(point, Instant.now(), Instant.now())); assertEquals(1, entityCache.getTotalUpdates(detectorId)); assertEquals(1, entityCache.getTotalUpdates(detectorId, modelId1)); entityCache.clear(detectorId); @@ -538,21 +546,25 @@ public void testSelectToReplaceInCache() { private void replaceInOtherCacheSetUp() { Entity entity5 = Entity.createSingleAttributeEntity("attributeName1", "attributeVal5"); Entity entity6 = Entity.createSingleAttributeEntity("attributeName1", "attributeVal6"); - ModelState modelState5 = new ModelState<>( - new EntityModel(entity5, new ArrayDeque<>(), null), + ModelState modelState5 = new ModelState<>( + null, entity5.getModelId(detectorId2).get(), detectorId2, - ModelType.ENTITY.getName(), + ModelManager.ModelType.TRCF.getName(), clock, - 0 + 0, + Optional.of(entity5), + new ArrayDeque<>() ); - ModelState modelState6 = new ModelState<>( - new EntityModel(entity6, new ArrayDeque<>(), null), + ModelState modelState6 = new ModelState<>( + null, entity6.getModelId(detectorId2).get(), detectorId2, - ModelType.ENTITY.getName(), + ModelManager.ModelType.TRCF.getName(), clock, - 0 + 0, + Optional.of(entity6), + new ArrayDeque<>() ); for (int i = 0; i < 3; i++) { @@ -660,7 +672,7 @@ public void testLongDetectorInterval() { String modelId = entity1.getModelId(detectorId).get(); // record last access time 1000 assertTrue(null == entityCache.get(modelId, detector)); - assertEquals(-1, entityCache.getLastActiveMs(detectorId, modelId)); + assertEquals(-1, entityCache.getLastActiveTime(detectorId, modelId)); // 2 hour = 7200 seconds have passed long currentTimeEpoch = 8200; when(clock.instant()).thenReturn(Instant.ofEpochSecond(currentTimeEpoch)); @@ -669,7 +681,7 @@ public void testLongDetectorInterval() { // door keeper still has the record and won't blocks entity state being created entityCache.get(modelId, detector); // * 1000 to convert to milliseconds - assertEquals(currentTimeEpoch * 1000, entityCache.getLastActiveMs(detectorId, modelId)); + assertEquals(currentTimeEpoch * 1000, entityCache.getLastActiveTime(detectorId, modelId)); } finally { ADEnabledSetting.getInstance().setSettingValue(ADEnabledSetting.DOOR_KEEPER_IN_CACHE_ENABLED, false); } @@ -725,7 +737,7 @@ public void testRemoveEntityModel() { assertTrue(null != entityCache.get(entity2.getModelId(detectorId).get(), detector)); - entityCache.removeEntityModel(detectorId, entity2.getModelId(detectorId).get()); + entityCache.removeModel(detectorId, entity2.getModelId(detectorId).get()); assertTrue(null == entityCache.get(entity2.getModelId(detectorId).get(), detector)); diff --git a/src/test/java/org/opensearch/ad/caching/PriorityTrackerTests.java b/src/test/java/org/opensearch/ad/caching/PriorityTrackerTests.java index 4e721d68e..09cc23bd6 100644 --- a/src/test/java/org/opensearch/ad/caching/PriorityTrackerTests.java +++ b/src/test/java/org/opensearch/ad/caching/PriorityTrackerTests.java @@ -21,6 +21,7 @@ import org.junit.Before; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.caching.PriorityTracker; public class PriorityTrackerTests extends OpenSearchTestCase { Clock clock; diff --git a/src/test/java/org/opensearch/ad/client/AnomalyDetectionClientTests.java b/src/test/java/org/opensearch/ad/client/AnomalyDetectionClientTests.java index c5acc7064..78ecba766 100644 --- a/src/test/java/org/opensearch/ad/client/AnomalyDetectionClientTests.java +++ b/src/test/java/org/opensearch/ad/client/AnomalyDetectionClientTests.java @@ -13,10 +13,10 @@ import org.mockito.MockitoAnnotations; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; -import org.opensearch.ad.transport.GetAnomalyDetectorRequest; import org.opensearch.ad.transport.GetAnomalyDetectorResponse; import org.opensearch.common.lucene.uid.Versions; import org.opensearch.core.action.ActionListener; +import org.opensearch.timeseries.transport.GetConfigRequest; public class AnomalyDetectionClientTests { @@ -48,7 +48,7 @@ public void searchAnomalyResults(SearchRequest searchRequest, ActionListener listener) { + public void getDetectorProfile(GetConfigRequest profileRequest, ActionListener listener) { listener.onResponse(profileResponse); } }; @@ -66,16 +66,7 @@ public void searchAnomalyResults() { @Test public void getDetectorProfile() { - GetAnomalyDetectorRequest profileRequest = new GetAnomalyDetectorRequest( - "foo", - Versions.MATCH_ANY, - true, - false, - "", - "", - false, - null - ); + GetConfigRequest profileRequest = new GetConfigRequest("foo", Versions.MATCH_ANY, true, false, "", "", false, null); assertEquals(profileResponse, anomalyDetectionClient.getDetectorProfile(profileRequest).actionGet()); } diff --git a/src/test/java/org/opensearch/ad/client/AnomalyDetectionNodeClientTests.java b/src/test/java/org/opensearch/ad/client/AnomalyDetectionNodeClientTests.java index c142e5e3d..79443a0e8 100644 --- a/src/test/java/org/opensearch/ad/client/AnomalyDetectionNodeClientTests.java +++ b/src/test/java/org/opensearch/ad/client/AnomalyDetectionNodeClientTests.java @@ -33,9 +33,7 @@ import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyDetectorType; import org.opensearch.ad.model.DetectorProfile; -import org.opensearch.ad.model.DetectorState; import org.opensearch.ad.transport.GetAnomalyDetectorAction; -import org.opensearch.ad.transport.GetAnomalyDetectorRequest; import org.opensearch.ad.transport.GetAnomalyDetectorResponse; import org.opensearch.client.Client; import org.opensearch.common.lucene.uid.Versions; @@ -46,7 +44,9 @@ import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.ConfigState; import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.transport.GetConfigRequest; import com.google.common.collect.ImmutableList; @@ -148,16 +148,7 @@ public void testGetDetectorProfile_NoIndices() throws ExecutionException, Interr deleteIndexIfExists(ALL_AD_RESULTS_INDEX_PATTERN); deleteIndexIfExists(ADCommonName.DETECTION_STATE_INDEX); - GetAnomalyDetectorRequest profileRequest = new GetAnomalyDetectorRequest( - "foo", - Versions.MATCH_ANY, - true, - false, - "", - "", - false, - null - ); + GetConfigRequest profileRequest = new GetConfigRequest("foo", Versions.MATCH_ANY, true, false, "", "", false, null); OpenSearchStatusException exception = expectThrows( OpenSearchStatusException.class, @@ -190,7 +181,7 @@ public void testGetDetectorProfile_Populated() throws IOException { // Setting up mock profile to test that the state is returned correctly in the client response DetectorProfile mockProfile = mock(DetectorProfile.class); - when(mockProfile.getState()).thenReturn(DetectorState.DISABLED); + when(mockProfile.getState()).thenReturn(ConfigState.DISABLED); GetAnomalyDetectorResponse response = new GetAnomalyDetectorResponse( 1234, @@ -213,16 +204,7 @@ public void testGetDetectorProfile_Populated() throws IOException { return null; }).when(clientSpy).execute(any(GetAnomalyDetectorAction.class), any(), any()); - GetAnomalyDetectorRequest profileRequest = new GetAnomalyDetectorRequest( - detectorId, - Versions.MATCH_ANY, - true, - false, - "", - "", - false, - null - ); + GetConfigRequest profileRequest = new GetConfigRequest(detectorId, Versions.MATCH_ANY, true, false, "", "", false, null); GetAnomalyDetectorResponse response = adClient.getDetectorProfile(profileRequest).actionGet(10000); @@ -230,7 +212,7 @@ public void testGetDetectorProfile_Populated() throws IOException { assertNotEquals(null, response.getDetectorProfile()); assertEquals(null, response.getAdJob()); assertEquals(detector.getName(), response.getDetector().getName()); - assertEquals(DetectorState.DISABLED, response.getDetectorProfile().getState()); + assertEquals(ConfigState.DISABLED, response.getDetectorProfile().getState()); verify(clientSpy, times(1)).execute(any(GetAnomalyDetectorAction.class), any(), any()); } diff --git a/src/test/java/org/opensearch/ad/cluster/ADClusterEventListenerTests.java b/src/test/java/org/opensearch/ad/cluster/ADClusterEventListenerTests.java index 88546e5ce..ba6ee0374 100644 --- a/src/test/java/org/opensearch/ad/cluster/ADClusterEventListenerTests.java +++ b/src/test/java/org/opensearch/ad/cluster/ADClusterEventListenerTests.java @@ -27,7 +27,6 @@ import org.junit.Before; import org.junit.BeforeClass; import org.opensearch.Version; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.cluster.ClusterChangedEvent; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.ClusterState; @@ -38,6 +37,9 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.gateway.GatewayService; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.cluster.ClusterEventListener; +import org.opensearch.timeseries.cluster.HashRing; +import org.opensearch.timeseries.constant.CommonName; public class ADClusterEventListenerTests extends AbstractTimeSeriesTest { private final String clusterManagerNodeId = "clusterManagerNode"; @@ -45,7 +47,7 @@ public class ADClusterEventListenerTests extends AbstractTimeSeriesTest { private final String clusterName = "multi-node-cluster"; private ClusterService clusterService; - private ADClusterEventListener listener; + private ClusterEventListener listener; private HashRing hashRing; private ClusterState oldClusterState; private ClusterState newClusterState; @@ -66,7 +68,7 @@ public static void tearDownAfterClass() { @Before public void setUp() throws Exception { super.setUp(); - super.setUpLog4jForJUnit(ADClusterEventListener.class); + super.setUpLog4jForJUnit(ClusterEventListener.class); clusterService = createClusterService(threadPool); hashRing = mock(HashRing.class); @@ -98,7 +100,7 @@ public void setUp() throws Exception { ) .build(); - listener = new ADClusterEventListener(clusterService, hashRing); + listener = new ClusterEventListener(clusterService, hashRing); } @Override @@ -114,12 +116,12 @@ public void tearDown() throws Exception { public void testUnchangedClusterState() { listener.clusterChanged(new ClusterChangedEvent("foo", oldClusterState, oldClusterState)); - assertTrue(!testAppender.containsMessage(ADClusterEventListener.NODE_CHANGED_MSG)); + assertTrue(!testAppender.containsMessage(ClusterEventListener.NODE_CHANGED_MSG)); } public void testIsWarmNode() { HashMap attributesForNode1 = new HashMap<>(); - attributesForNode1.put(ADCommonName.BOX_TYPE_KEY, ADCommonName.WARM_BOX_TYPE); + attributesForNode1.put(CommonName.BOX_TYPE_KEY, CommonName.WARM_BOX_TYPE); dataNode1 = new DiscoveryNode(dataNode1Id, buildNewFakeTransportAddress(), attributesForNode1, BUILT_IN_ROLES, Version.CURRENT); ClusterState warmNodeClusterState = ClusterState @@ -134,7 +136,7 @@ public void testIsWarmNode() { .blocks(ClusterBlocks.builder().addGlobalBlock(GatewayService.STATE_NOT_RECOVERED_BLOCK)) .build(); listener.clusterChanged(new ClusterChangedEvent("foo", warmNodeClusterState, oldClusterState)); - assertTrue(testAppender.containsMessage(ADClusterEventListener.NOT_RECOVERED_MSG)); + assertTrue(testAppender.containsMessage(ClusterEventListener.NOT_RECOVERED_MSG)); } public void testNotRecovered() { @@ -150,7 +152,7 @@ public void testNotRecovered() { .blocks(ClusterBlocks.builder().addGlobalBlock(GatewayService.STATE_NOT_RECOVERED_BLOCK)) .build(); listener.clusterChanged(new ClusterChangedEvent("foo", blockedClusterState, oldClusterState)); - assertTrue(testAppender.containsMessage(ADClusterEventListener.NOT_RECOVERED_MSG)); + assertTrue(testAppender.containsMessage(ClusterEventListener.NOT_RECOVERED_MSG)); } class ListenerRunnable implements Runnable { @@ -170,7 +172,7 @@ public void testInProgress() { }).when(hashRing).buildCircles(any(), any()); new Thread(new ListenerRunnable()).start(); listener.clusterChanged(new ClusterChangedEvent("bar", newClusterState, oldClusterState)); - assertTrue(testAppender.containsMessage(ADClusterEventListener.IN_PROGRESS_MSG)); + assertTrue(testAppender.containsMessage(ClusterEventListener.IN_PROGRESS_MSG)); } public void testNodeAdded() { @@ -182,10 +184,10 @@ public void testNodeAdded() { doAnswer(invocation -> Optional.of(clusterManagerNode)) .when(hashRing) - .getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class)); + .getOwningNodeWithSameLocalVersionForRealtime(any(String.class)); listener.clusterChanged(new ClusterChangedEvent("foo", newClusterState, oldClusterState)); - assertTrue(testAppender.containsMessage(ADClusterEventListener.NODE_CHANGED_MSG)); + assertTrue(testAppender.containsMessage(ClusterEventListener.NODE_CHANGED_MSG)); assertTrue(testAppender.containsMessage("node removed: false, node added: true")); } @@ -203,7 +205,7 @@ public void testNodeRemoved() { .build(); listener.clusterChanged(new ClusterChangedEvent("foo", newClusterState, twoDataNodeClusterState)); - assertTrue(testAppender.containsMessage(ADClusterEventListener.NODE_CHANGED_MSG)); + assertTrue(testAppender.containsMessage(ClusterEventListener.NODE_CHANGED_MSG)); assertTrue(testAppender.containsMessage("node removed: true, node added: true")); } } diff --git a/src/test/java/org/opensearch/ad/cluster/ADDataMigratorTests.java b/src/test/java/org/opensearch/ad/cluster/ADDataMigratorTests.java index 66fbd3e78..928a0ddaf 100644 --- a/src/test/java/org/opensearch/ad/cluster/ADDataMigratorTests.java +++ b/src/test/java/org/opensearch/ad/cluster/ADDataMigratorTests.java @@ -47,6 +47,7 @@ import org.opensearch.search.aggregations.InternalAggregations; import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.cluster.ADDataMigrator; import org.opensearch.timeseries.constant.CommonName; public class ADDataMigratorTests extends ADUnitTestCase { diff --git a/src/test/java/org/opensearch/ad/cluster/ADVersionUtilTests.java b/src/test/java/org/opensearch/ad/cluster/ADVersionUtilTests.java index aa5fcc55b..79f1cd26d 100644 --- a/src/test/java/org/opensearch/ad/cluster/ADVersionUtilTests.java +++ b/src/test/java/org/opensearch/ad/cluster/ADVersionUtilTests.java @@ -13,22 +13,23 @@ import org.opensearch.Version; import org.opensearch.ad.ADUnitTestCase; +import org.opensearch.timeseries.cluster.VersionUtil; public class ADVersionUtilTests extends ADUnitTestCase { public void testParseVersionFromString() { - Version version = ADVersionUtil.fromString("2.1.0.0"); + Version version = VersionUtil.fromString("2.1.0.0"); assertEquals(Version.V_2_1_0, version); - version = ADVersionUtil.fromString("2.1.0"); + version = VersionUtil.fromString("2.1.0"); assertEquals(Version.V_2_1_0, version); } public void testParseVersionFromStringWithNull() { - expectThrows(IllegalArgumentException.class, () -> ADVersionUtil.fromString(null)); + expectThrows(IllegalArgumentException.class, () -> VersionUtil.fromString(null)); } public void testParseVersionFromStringWithWrongFormat() { - expectThrows(IllegalArgumentException.class, () -> ADVersionUtil.fromString("1.1")); + expectThrows(IllegalArgumentException.class, () -> VersionUtil.fromString("1.1")); } } diff --git a/src/test/java/org/opensearch/ad/cluster/ClusterManagerEventListenerTests.java b/src/test/java/org/opensearch/ad/cluster/ClusterManagerEventListenerTests.java index 9c2e79236..33ab3958e 100644 --- a/src/test/java/org/opensearch/ad/cluster/ClusterManagerEventListenerTests.java +++ b/src/test/java/org/opensearch/ad/cluster/ClusterManagerEventListenerTests.java @@ -24,11 +24,10 @@ import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Locale; import org.junit.Before; -import org.opensearch.ad.cluster.diskcleanup.ModelCheckpointIndexRetention; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -36,9 +35,14 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.forecast.settings.ForecastSettings; import org.opensearch.threadpool.Scheduler.Cancellable; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.cluster.ClusterManagerEventListener; +import org.opensearch.timeseries.cluster.HourlyCron; +import org.opensearch.timeseries.cluster.diskcleanup.BaseModelCheckpointIndexRetention; +import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.util.ClientUtil; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; @@ -69,13 +73,13 @@ public void setUp() throws Exception { checkpointIndexRetentionCancellable = mock(Cancellable.class); when(threadPool.scheduleWithFixedDelay(any(HourlyCron.class), any(TimeValue.class), any(String.class))) .thenReturn(hourlyCancellable); - when(threadPool.scheduleWithFixedDelay(any(ModelCheckpointIndexRetention.class), any(TimeValue.class), any(String.class))) + when(threadPool.scheduleWithFixedDelay(any(BaseModelCheckpointIndexRetention.class), any(TimeValue.class), any(String.class))) .thenReturn(checkpointIndexRetentionCancellable); client = mock(Client.class); clock = mock(Clock.class); clientUtil = mock(ClientUtil.class); HashMap ignoredAttributes = new HashMap(); - ignoredAttributes.put(ADCommonName.BOX_TYPE_KEY, ADCommonName.WARM_BOX_TYPE); + ignoredAttributes.put(CommonName.BOX_TYPE_KEY, CommonName.WARM_BOX_TYPE); nodeFilter = new DiscoveryNodeFilterer(clusterService); clusterManagerService = new ClusterManagerEventListener( @@ -86,6 +90,7 @@ public void setUp() throws Exception { clientUtil, nodeFilter, AnomalyDetectorSettings.AD_CHECKPOINT_TTL, + ForecastSettings.FORECAST_CHECKPOINT_TTL, Settings.EMPTY ); } @@ -95,7 +100,10 @@ public void testOnOffClusterManager() { assertThat(hourlyCancellable, is(notNullValue())); assertThat(checkpointIndexRetentionCancellable, is(notNullValue())); assertTrue(!clusterManagerService.getHourlyCron().isCancelled()); - assertTrue(!clusterManagerService.getCheckpointIndexRetentionCron().isCancelled()); + List checkpointIndexRetention = clusterManagerService.getCheckpointIndexRetentionCron(); + for (Cancellable cancellable : checkpointIndexRetention) { + assertTrue(!cancellable.isCancelled()); + } clusterManagerService.offClusterManager(); assertThat(clusterManagerService.getCheckpointIndexRetentionCron(), is(nullValue())); assertThat(clusterManagerService.getHourlyCron(), is(nullValue())); diff --git a/src/test/java/org/opensearch/ad/cluster/DailyCronTests.java b/src/test/java/org/opensearch/ad/cluster/DailyCronTests.java index a57a4c649..1ad8834e4 100644 --- a/src/test/java/org/opensearch/ad/cluster/DailyCronTests.java +++ b/src/test/java/org/opensearch/ad/cluster/DailyCronTests.java @@ -28,6 +28,7 @@ import org.opensearch.index.reindex.BulkByScrollResponse; import org.opensearch.index.reindex.DeleteByQueryAction; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.cluster.DailyCron; import org.opensearch.timeseries.util.ClientUtil; public class DailyCronTests extends AbstractTimeSeriesTest { diff --git a/src/test/java/org/opensearch/ad/cluster/HashRingTests.java b/src/test/java/org/opensearch/ad/cluster/HashRingTests.java index 69bb38a57..12826aea1 100644 --- a/src/test/java/org/opensearch/ad/cluster/HashRingTests.java +++ b/src/test/java/org/opensearch/ad/cluster/HashRingTests.java @@ -36,8 +36,7 @@ import org.opensearch.action.admin.cluster.node.info.NodesInfoResponse; import org.opensearch.action.admin.cluster.node.info.PluginsAndModules; import org.opensearch.ad.ADUnitTestCase; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; import org.opensearch.client.ClusterAdminClient; @@ -50,6 +49,8 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; import org.opensearch.plugins.PluginInfo; +import org.opensearch.timeseries.cluster.ADDataMigrator; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; @@ -74,7 +75,7 @@ public class HashRingTests extends ADUnitTestCase { private DiscoveryNode localNode; private DiscoveryNode newNode; private DiscoveryNode warmNode; - private ModelManager modelManager; + private ADModelManager modelManager; @Override @Before @@ -86,7 +87,7 @@ public void setUp() throws Exception { newNodeId = "newNode"; newNode = createNode(newNodeId, "127.0.0.2", 9201, emptyMap()); warmNodeId = "warmNode"; - warmNode = createNode(warmNodeId, "127.0.0.3", 9202, ImmutableMap.of(ADCommonName.BOX_TYPE_KEY, ADCommonName.WARM_BOX_TYPE)); + warmNode = createNode(warmNodeId, "127.0.0.3", 9202, ImmutableMap.of(CommonName.BOX_TYPE_KEY, CommonName.WARM_BOX_TYPE)); settings = Settings.builder().put(AD_COOLDOWN_MINUTES.getKey(), TimeValue.timeValueSeconds(5)).build(); ClusterSettings clusterSettings = clusterSetting(settings, AD_COOLDOWN_MINUTES); @@ -107,7 +108,7 @@ public void setUp() throws Exception { when(adminClient.cluster()).thenReturn(clusterAdminClient); String modelId = "123_model_threshold"; - modelManager = mock(ModelManager.class); + modelManager = mock(ADModelManager.class); doAnswer(invocation -> { Set res = new HashSet<>(); res.add(modelId); @@ -121,7 +122,7 @@ public void testGetOwningNodeWithEmptyResult() throws UnknownHostException { DiscoveryNode node1 = createNode(Integer.toString(1), "127.0.0.4", 9204, emptyMap()); doReturn(node1).when(clusterService).localNode(); - Optional node = hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD("http-latency-rcf-1"); + Optional node = hashRing.getOwningNodeWithSameLocalVersionForRealtime("http-latency-rcf-1"); assertFalse(node.isPresent()); } @@ -130,10 +131,10 @@ public void testGetOwningNode() throws UnknownHostException { // Add first node, hashRing.buildCircles(delta, ActionListener.wrap(r -> { - Optional node = hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD("http-latency-rcf-1"); + Optional node = hashRing.getOwningNodeWithSameLocalVersionForRealtime("http-latency-rcf-1"); assertTrue(node.isPresent()); assertTrue(asList(newNodeId, localNodeId).contains(node.get().getId())); - DiscoveryNode[] nodesWithSameLocalAdVersion = hashRing.getNodesWithSameLocalAdVersion(); + DiscoveryNode[] nodesWithSameLocalAdVersion = hashRing.getNodesWithSameLocalVersion(); Set nodesWithSameLocalAdVersionIds = new HashSet<>(); for (DiscoveryNode n : nodesWithSameLocalAdVersion) { nodesWithSameLocalAdVersionIds.add(n.getId()); @@ -143,10 +144,10 @@ public void testGetOwningNode() throws UnknownHostException { assertEquals( "Wrong hash ring size for historical analysis", 2, - hashRing.getNodesWithSameAdVersion(Version.V_2_1_0, false).size() + hashRing.getNodesWithSameVersion(Version.V_2_1_0, false).size() ); // Circles for realtime AD will change as it's eligible to build for when its empty - assertEquals("Wrong hash ring size for realtime AD", 2, hashRing.getNodesWithSameAdVersion(Version.V_2_1_0, true).size()); + assertEquals("Wrong hash ring size for realtime AD", 2, hashRing.getNodesWithSameVersion(Version.V_2_1_0, true).size()); }, e -> { logger.error("building hash ring failed", e); assertFalse("Build hash ring failed", true); @@ -162,10 +163,10 @@ public void testGetOwningNode() throws UnknownHostException { assertEquals( "Wrong hash ring size for historical analysis", 3, - hashRing.getNodesWithSameAdVersion(Version.V_2_1_0, false).size() + hashRing.getNodesWithSameVersion(Version.V_2_1_0, false).size() ); // Circles for realtime AD will not change as it's eligible to rebuild - assertEquals("Wrong hash ring size for realtime AD", 2, hashRing.getNodesWithSameAdVersion(Version.V_2_1_0, true).size()); + assertEquals("Wrong hash ring size for realtime AD", 2, hashRing.getNodesWithSameVersion(Version.V_2_1_0, true).size()); }, e -> { logger.error("building hash ring failed", e); @@ -183,9 +184,9 @@ public void testGetOwningNode() throws UnknownHostException { assertEquals( "Wrong hash ring size for historical analysis", 4, - hashRing.getNodesWithSameAdVersion(Version.V_2_1_0, false).size() + hashRing.getNodesWithSameVersion(Version.V_2_1_0, false).size() ); - assertEquals("Wrong hash ring size for realtime AD", 4, hashRing.getNodesWithSameAdVersion(Version.V_2_1_0, true).size()); + assertEquals("Wrong hash ring size for realtime AD", 4, hashRing.getNodesWithSameVersion(Version.V_2_1_0, true).size()); }, e -> { logger.error("building hash ring failed", e); assertFalse("Failed to build hash ring", true); @@ -194,7 +195,7 @@ public void testGetOwningNode() throws UnknownHostException { public void testGetAllEligibleDataNodesWithKnownAdVersionAndGetNodeByAddress() { setupNodeDelta(); - hashRing.getAllEligibleDataNodesWithKnownAdVersion(nodes -> { + hashRing.getAllEligibleDataNodesWithKnownVersion(nodes -> { assertEquals("Wrong hash ring size for historical analysis", 2, nodes.length); Optional node = hashRing.getNodeByAddress(newNode.getAddress()); assertTrue(node.isPresent()); @@ -205,7 +206,7 @@ public void testGetAllEligibleDataNodesWithKnownAdVersionAndGetNodeByAddress() { public void testBuildAndGetOwningNodeWithSameLocalAdVersion() { setupNodeDelta(); hashRing - .buildAndGetOwningNodeWithSameLocalAdVersion( + .buildAndGetOwningNodeWithSameLocalVersion( "testModelId", node -> { assertTrue(node.isPresent()); }, ActionListener.wrap(r -> {}, e -> { diff --git a/src/test/java/org/opensearch/ad/cluster/HourlyCronTests.java b/src/test/java/org/opensearch/ad/cluster/HourlyCronTests.java index 2806138d9..831b546da 100644 --- a/src/test/java/org/opensearch/ad/cluster/HourlyCronTests.java +++ b/src/test/java/org/opensearch/ad/cluster/HourlyCronTests.java @@ -27,10 +27,7 @@ import org.opensearch.OpenSearchException; import org.opensearch.Version; import org.opensearch.action.FailedNodeException; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.transport.CronAction; -import org.opensearch.ad.transport.CronNodeResponse; -import org.opensearch.ad.transport.CronResponse; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.ClusterState; @@ -39,6 +36,10 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.cluster.HourlyCron; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.transport.CronNodeResponse; +import org.opensearch.timeseries.transport.CronResponse; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import test.org.opensearch.ad.util.ClusterCreation; @@ -59,7 +60,7 @@ public void templateHourlyCron(HourlyCronTestExecutionMode mode) { ClusterState state = ClusterCreation.state(1); when(clusterService.state()).thenReturn(state); HashMap ignoredAttributes = new HashMap(); - ignoredAttributes.put(ADCommonName.BOX_TYPE_KEY, ADCommonName.WARM_BOX_TYPE); + ignoredAttributes.put(CommonName.BOX_TYPE_KEY, CommonName.WARM_BOX_TYPE); DiscoveryNodeFilterer nodeFilter = new DiscoveryNodeFilterer(clusterService); Client client = mock(Client.class); diff --git a/src/test/java/org/opensearch/ad/cluster/diskcleanup/IndexCleanupTests.java b/src/test/java/org/opensearch/ad/cluster/diskcleanup/IndexCleanupTests.java index 0748fe122..399da125e 100644 --- a/src/test/java/org/opensearch/ad/cluster/diskcleanup/IndexCleanupTests.java +++ b/src/test/java/org/opensearch/ad/cluster/diskcleanup/IndexCleanupTests.java @@ -37,6 +37,7 @@ import org.opensearch.index.reindex.DeleteByQueryAction; import org.opensearch.index.store.StoreStats; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.cluster.diskcleanup.IndexCleanup; import org.opensearch.timeseries.util.ClientUtil; public class IndexCleanupTests extends AbstractTimeSeriesTest { diff --git a/src/test/java/org/opensearch/ad/cluster/diskcleanup/ModelCheckpointIndexRetentionTests.java b/src/test/java/org/opensearch/ad/cluster/diskcleanup/ModelCheckpointIndexRetentionTests.java index b95757925..eca572199 100644 --- a/src/test/java/org/opensearch/ad/cluster/diskcleanup/ModelCheckpointIndexRetentionTests.java +++ b/src/test/java/org/opensearch/ad/cluster/diskcleanup/ModelCheckpointIndexRetentionTests.java @@ -27,8 +27,11 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.indices.ADIndex; import org.opensearch.core.action.ActionListener; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.cluster.diskcleanup.BaseModelCheckpointIndexRetention; +import org.opensearch.timeseries.cluster.diskcleanup.IndexCleanup; public class ModelCheckpointIndexRetentionTests extends AbstractTimeSeriesTest { @@ -39,7 +42,7 @@ public class ModelCheckpointIndexRetentionTests extends AbstractTimeSeriesTest { @Mock IndexCleanup indexCleanup; - ModelCheckpointIndexRetention modelCheckpointIndexRetention; + BaseModelCheckpointIndexRetention modelCheckpointIndexRetention; @SuppressWarnings("unchecked") @Before @@ -47,7 +50,12 @@ public void setUp() throws Exception { super.setUp(); super.setUpLog4jForJUnit(IndexCleanup.class); MockitoAnnotations.initMocks(this); - modelCheckpointIndexRetention = new ModelCheckpointIndexRetention(defaultCheckpointTtl, clock, indexCleanup); + modelCheckpointIndexRetention = new BaseModelCheckpointIndexRetention( + defaultCheckpointTtl, + clock, + indexCleanup, + ADIndex.CHECKPOINT.getIndexName() + ); doAnswer(invocation -> { Object[] args = invocation.getArguments(); ActionListener listener = (ActionListener) args[2]; diff --git a/src/test/java/org/opensearch/ad/e2e/AbstractSyntheticDataTest.java b/src/test/java/org/opensearch/ad/e2e/AbstractSyntheticDataTest.java index 4330118b6..b74f1ea58 100644 --- a/src/test/java/org/opensearch/ad/e2e/AbstractSyntheticDataTest.java +++ b/src/test/java/org/opensearch/ad/e2e/AbstractSyntheticDataTest.java @@ -11,9 +11,9 @@ package org.opensearch.ad.e2e; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_BACKOFF_MINUTES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_RETRY_FOR_UNRESPONSIVE_NODE; import static org.opensearch.timeseries.TestHelpers.toHttpEntity; -import static org.opensearch.timeseries.settings.TimeSeriesSettings.BACKOFF_MINUTES; -import static org.opensearch.timeseries.settings.TimeSeriesSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE; import java.io.File; import java.io.FileReader; @@ -61,8 +61,8 @@ protected void disableResourceNotFoundFaultTolerence() throws IOException { settingCommand.startObject(); settingCommand.startObject("persistent"); - settingCommand.field(MAX_RETRY_FOR_UNRESPONSIVE_NODE.getKey(), 100_000); - settingCommand.field(BACKOFF_MINUTES.getKey(), 0); + settingCommand.field(AD_MAX_RETRY_FOR_UNRESPONSIVE_NODE.getKey(), 100_000); + settingCommand.field(AD_BACKOFF_MINUTES.getKey(), 0); settingCommand.endObject(); settingCommand.endObject(); Request request = new Request("PUT", "/_cluster/settings"); diff --git a/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java b/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java index b78647c11..01549b02d 100644 --- a/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java +++ b/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java @@ -61,7 +61,10 @@ import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.dataprocessor.Imputer; import org.opensearch.timeseries.dataprocessor.LinearUniformImputer; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.feature.Features; import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.feature.SinglePointFeatures; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.IntervalTimeConfiguration; @@ -142,8 +145,6 @@ public void setup() { searchFeatureDao, imputer, clock, - maxTrainSamples, - maxSampleStride, trainSampleTimeRangeInHours, minTrainSamples, maxMissingPointsRate, @@ -203,7 +204,7 @@ public void getColdStartData_returnExpectedToListener( ActionListener> listener = invocation.getArgument(1); listener.onResponse(Optional.ofNullable(latestTime)); return null; - }).when(searchFeatureDao).getLatestDataTime(eq(detector), any(ActionListener.class)); + }).when(searchFeatureDao).getLatestDataTime(eq(detector), eq(Optional.empty()), eq(AnalysisType.AD), any(ActionListener.class)); if (latestTime != null) { doAnswer(invocation -> { ActionListener>> listener = invocation.getArgument(3); @@ -220,8 +221,6 @@ public void getColdStartData_returnExpectedToListener( searchFeatureDao, imputer, clock, - maxTrainSamples, - maxSampleStride, trainSampleTimeRangeInHours, minTrainSamples, 0.5, /*maxMissingPointsRate*/ @@ -248,7 +247,7 @@ public void getColdStartData_throwToListener_whenSearchFail() { ActionListener> listener = invocation.getArgument(1); listener.onFailure(new RuntimeException()); return null; - }).when(searchFeatureDao).getLatestDataTime(eq(detector), any(ActionListener.class)); + }).when(searchFeatureDao).getLatestDataTime(eq(detector), eq(Optional.empty()), eq(AnalysisType.AD), any(ActionListener.class)); ActionListener> listener = mock(ActionListener.class); featureManager.getColdStartData(detector, listener); @@ -263,7 +262,7 @@ public void getColdStartData_throwToListener_onQueryCreationError() throws Excep ActionListener> listener = invocation.getArgument(1); listener.onResponse(Optional.ofNullable(0L)); return null; - }).when(searchFeatureDao).getLatestDataTime(eq(detector), any(ActionListener.class)); + }).when(searchFeatureDao).getLatestDataTime(eq(detector), eq(Optional.empty()), eq(AnalysisType.AD), any(ActionListener.class)); doThrow(IOException.class) .when(searchFeatureDao) .getFeatureSamplesForPeriods(eq(detector), any(), eq(AnalysisType.AD), any(ActionListener.class)); diff --git a/src/test/java/org/opensearch/ad/feature/FeaturesTests.java b/src/test/java/org/opensearch/ad/feature/FeaturesTests.java index 7a6b3b8e1..f7716e81b 100644 --- a/src/test/java/org/opensearch/ad/feature/FeaturesTests.java +++ b/src/test/java/org/opensearch/ad/feature/FeaturesTests.java @@ -20,6 +20,7 @@ import org.junit.Test; import org.junit.runner.RunWith; +import org.opensearch.timeseries.feature.Features; import junitparams.JUnitParamsRunner; import junitparams.Parameters; diff --git a/src/test/java/org/opensearch/ad/indices/CustomIndexTests.java b/src/test/java/org/opensearch/ad/indices/CustomIndexTests.java index 53bea9015..ebeded321 100644 --- a/src/test/java/org/opensearch/ad/indices/CustomIndexTests.java +++ b/src/test/java/org/opensearch/ad/indices/CustomIndexTests.java @@ -142,7 +142,7 @@ private Map createMapping() { entity_nested_mapping.put("name", Collections.singletonMap("type", "keyword")); entity_nested_mapping.put("value", Collections.singletonMap("type", "keyword")); entity_mapping.put(CommonName.PROPERTIES, entity_nested_mapping); - mappings.put(CommonName.ENTITY_FIELD, entity_mapping); + mappings.put(CommonName.ENTITY_KEY, entity_mapping); Map error_mapping = new HashMap<>(); error_mapping.put("type", "text"); @@ -188,7 +188,7 @@ private Map createMapping() { attribution_nested_mapping.put("feature_id", Collections.singletonMap("type", "keyword")); mappings.put(AnomalyResult.RELEVANT_ATTRIBUTION_FIELD, attribution_mapping); - mappings.put(CommonName.SCHEMA_VERSION_FIELD, Collections.singletonMap("type", "integer")); + mappings.put(org.opensearch.timeseries.constant.CommonName.SCHEMA_VERSION_FIELD, Collections.singletonMap("type", "integer")); mappings.put(CommonName.TASK_ID_FIELD, Collections.singletonMap("type", "keyword")); diff --git a/src/test/java/org/opensearch/ad/indices/UpdateMappingTests.java b/src/test/java/org/opensearch/ad/indices/UpdateMappingTests.java index 10b00426a..0a837603d 100644 --- a/src/test/java/org/opensearch/ad/indices/UpdateMappingTests.java +++ b/src/test/java/org/opensearch/ad/indices/UpdateMappingTests.java @@ -55,7 +55,6 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.index.IndexNotFoundException; import org.opensearch.timeseries.AbstractTimeSeriesTest; -import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; @@ -168,7 +167,7 @@ public void testUpdateMapping() throws IOException { put(ADIndexManagement.META, new HashMap() { { // version 1 will cause update - put(CommonName.SCHEMA_VERSION_FIELD, 1); + put(org.opensearch.timeseries.constant.CommonName.SCHEMA_VERSION_FIELD, 1); } }); } diff --git a/src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java b/src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java index 1a86e45d4..9901ddd84 100644 --- a/src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java +++ b/src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java @@ -21,6 +21,7 @@ import java.time.temporal.ChronoUnit; import java.util.Collections; import java.util.HashSet; +import java.util.List; import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -29,9 +30,8 @@ import org.opensearch.Version; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; -import org.opensearch.ad.feature.FeatureManager; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.ratelimit.CheckpointWriteWorker; +import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; @@ -52,12 +52,16 @@ import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.dataprocessor.Imputer; import org.opensearch.timeseries.dataprocessor.LinearUniformImputer; +import org.opensearch.timeseries.feature.FeatureManager; import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.IntervalTimeConfiguration; import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.util.ClientUtil; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; import com.google.common.collect.ImmutableList; public class AbstractCosineDataTest extends AbstractTimeSeriesTest { @@ -65,37 +69,40 @@ public class AbstractCosineDataTest extends AbstractTimeSeriesTest { String modelId; String entityName; String detectorId; - ModelState modelState; + ModelState modelState; Clock clock; float priority; - EntityColdStarter entityColdStarter; + ADColdStart entityColdStarter; NodeStateManager stateManager; SearchFeatureDao searchFeatureDao; Imputer imputer; - CheckpointDao checkpoint; FeatureManager featureManager; Settings settings; ThreadPool threadPool; AtomicBoolean released; Runnable releaseSemaphore; - ActionListener listener; + ActionListener> listener; CountDownLatch inProgressLatch; - CheckpointWriteWorker checkpointWriteQueue; + ADCheckpointWriteWorker checkpointWriteQueue; Entity entity; AnomalyDetector detector; long rcfSeed; - ModelManager modelManager; + ADModelManager modelManager; ClientUtil clientUtil; ClusterService clusterService; ClusterSettings clusterSettings; DiscoveryNode discoveryNode; Set> nodestateSetting; + int detectorInterval = 1; + int shingleSize; @SuppressWarnings("unchecked") @Override public void setUp() throws Exception { super.setUp(); - numMinSamples = TimeSeriesSettings.NUM_MIN_SAMPLES; + // numMinSamples should be larger than shingleSize; otherwise, we will get rcf exception + numMinSamples = 3; + shingleSize = 2; clock = mock(Clock.class); when(clock.instant()).thenReturn(Instant.now()); @@ -110,8 +117,9 @@ public void setUp() throws Exception { detector = TestHelpers.AnomalyDetectorBuilder .newInstance() - .setDetectionInterval(new IntervalTimeConfiguration(1, ChronoUnit.MINUTES)) + .setDetectionInterval(new IntervalTimeConfiguration(detectorInterval, ChronoUnit.MINUTES)) .setCategoryFields(ImmutableList.of(randomAlphaOfLength(5))) + .setShingleSize(shingleSize) .build(); when(clock.millis()).thenReturn(1602401500000L); doAnswer(invocation -> { @@ -153,16 +161,13 @@ public void setUp() throws Exception { imputer = new LinearUniformImputer(true); searchFeatureDao = mock(SearchFeatureDao.class); - checkpoint = mock(CheckpointDao.class); featureManager = new FeatureManager( searchFeatureDao, imputer, clock, - AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, - AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, - AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, - AnomalyDetectorSettings.MIN_TRAIN_SAMPLES, + TimeSeriesSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, + TimeSeriesSettings.MIN_TRAIN_SAMPLES, AnomalyDetectorSettings.MAX_SHINGLE_PROPORTION_MISSING, AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE, AnomalyDetectorSettings.PREVIEW_SAMPLE_RATE, @@ -172,28 +177,27 @@ public void setUp() throws Exception { TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME ); - checkpointWriteQueue = mock(CheckpointWriteWorker.class); + checkpointWriteQueue = mock(ADCheckpointWriteWorker.class); rcfSeed = 2051L; - entityColdStarter = new EntityColdStarter( + entityColdStarter = new ADColdStart( clock, threadPool, stateManager, TimeSeriesSettings.NUM_SAMPLES_PER_TREE, TimeSeriesSettings.NUM_TREES, - TimeSeriesSettings.TIME_DECAY, numMinSamples, AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, - imputer, searchFeatureDao, TimeSeriesSettings.THRESHOLD_MIN_PVALUE, featureManager, - settings, + // settings, TimeSeriesSettings.HOURLY_MAINTENANCE, checkpointWriteQueue, rcfSeed, - TimeSeriesSettings.MAX_COLD_START_ROUNDS + TimeSeriesSettings.MAX_COLD_START_ROUNDS, + 1 ); detectorId = "123"; @@ -204,19 +208,13 @@ public void setUp() throws Exception { released = new AtomicBoolean(); - inProgressLatch = new CountDownLatch(1); - releaseSemaphore = () -> { - released.set(true); - inProgressLatch.countDown(); - }; - listener = ActionListener.wrap(releaseSemaphore); + resetListener(); - modelManager = new ModelManager( - mock(CheckpointDao.class), + modelManager = new ADModelManager( + mock(ADCheckpointDao.class), mock(Clock.class), TimeSeriesSettings.NUM_TREES, TimeSeriesSettings.NUM_SAMPLES_PER_TREE, - TimeSeriesSettings.TIME_DECAY, TimeSeriesSettings.NUM_MIN_SAMPLES, TimeSeriesSettings.THRESHOLD_MIN_PVALUE, AnomalyDetectorSettings.MIN_PREVIEW_SIZE, @@ -230,8 +228,17 @@ public void setUp() throws Exception { ); } + protected void resetListener() { + inProgressLatch = new CountDownLatch(1); + releaseSemaphore = () -> { + released.set(true); + inProgressLatch.countDown(); + }; + listener = ActionListener.wrap(releaseSemaphore); + } + protected void checkSemaphoreRelease() throws InterruptedException { - assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + assertTrue(inProgressLatch.await(30, TimeUnit.SECONDS)); assertTrue(released.get()); } @@ -239,12 +246,14 @@ public int searchInsert(long[] timestamps, long target) { int pivot, left = 0, right = timestamps.length - 1; while (left <= right) { pivot = left + (right - left) / 2; - if (timestamps[pivot] == target) + if (timestamps[pivot] == target) { return pivot; - if (target < timestamps[pivot]) + } + if (target < timestamps[pivot]) { right = pivot - 1; - else + } else { left = pivot + 1; + } } return left; } diff --git a/src/test/java/org/opensearch/ad/ml/CheckpointDaoTests.java b/src/test/java/org/opensearch/ad/ml/CheckpointDaoTests.java index 72358af10..b482422ee 100644 --- a/src/test/java/org/opensearch/ad/ml/CheckpointDaoTests.java +++ b/src/test/java/org/opensearch/ad/ml/CheckpointDaoTests.java @@ -23,7 +23,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.action.DocWriteResponse.Result.UPDATED; -import static org.opensearch.ad.ml.CheckpointDao.FIELD_MODELV2; +import static org.opensearch.ad.ml.ADCheckpointDao.FIELD_MODELV2; import java.io.BufferedReader; import java.io.File; @@ -40,19 +40,15 @@ import java.time.Clock; import java.time.Duration; import java.time.Instant; -import java.time.Month; -import java.time.OffsetDateTime; -import java.time.ZoneOffset; import java.util.ArrayList; import java.util.Arrays; +import java.util.Deque; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Map.Entry; import java.util.NoSuchElementException; import java.util.Optional; -import java.util.Queue; import java.util.Random; import java.util.Set; import java.util.concurrent.CountDownLatch; @@ -101,7 +97,11 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.util.ClientUtil; @@ -127,7 +127,7 @@ public class CheckpointDaoTests extends OpenSearchTestCase { private static final Logger logger = LogManager.getLogger(CheckpointDaoTests.class); - private CheckpointDao checkpointDao; + private ADCheckpointDao checkpointDao; // dependencies @Mock(answer = Answers.RETURNS_DEEP_STUBS) @@ -162,6 +162,7 @@ public class CheckpointDaoTests extends OpenSearchTestCase { private ThresholdedRandomCutForestMapper trcfMapper; private V1JsonToV3StateConverter converter; double anomalyRate; + private Instant now; @Before public void setup() { @@ -169,12 +170,12 @@ public void setup() { indexName = "testIndexName"; - // gson = PowerMockito.mock(Gson.class); gson = new GsonBuilder().serializeSpecialFloatingPointValues().create(); thresholdingModelClass = HybridThresholdingModel.class; - when(clock.instant()).thenReturn(Instant.now()); + now = Instant.now(); + when(clock.instant()).thenReturn(now); mapper = new RandomCutForestMapper(); mapper.setSaveExecutorContextEnabled(true); @@ -211,10 +212,9 @@ public PooledObject wrap(LinkedBuffer obj) { serializeRCFBufferPool.setTimeBetweenEvictionRuns(TimeSeriesSettings.HOURLY_MAINTENANCE); anomalyRate = 0.005; - checkpointDao = new CheckpointDao( + checkpointDao = new ADCheckpointDao( client, clientUtil, - indexName, gson, mapper, converter, @@ -225,7 +225,8 @@ public PooledObject wrap(LinkedBuffer obj) { maxCheckpointBytes, serializeRCFBufferPool, TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES, - anomalyRate + anomalyRate, + clock ); when(indexUtil.doesCheckpointIndexExist()).thenReturn(true); @@ -281,7 +282,7 @@ private void verifyPutModelCheckpointAsync() { checkpointDao.putTRCFCheckpoint(modelId, createTRCF(), listener); UpdateRequest updateRequest = requestCaptor.getValue(); - assertEquals(indexName, updateRequest.index()); + assertEquals(ADCommonName.CHECKPOINT_INDEX_NAME, updateRequest.index()); assertEquals(modelId, updateRequest.id()); IndexRequest indexRequest = updateRequest.doc(); Set expectedSourceKeys = new HashSet(Arrays.asList(FIELD_MODELV2, CommonName.TIMESTAMP)); @@ -380,7 +381,7 @@ public void test_getModelCheckpoint_returnExpectedToListener() { checkpointDao.getTRCFModel(modelId, listener); GetRequest capturedGetRequest = getRequest.get(); - assertEquals(indexName, capturedGetRequest.index()); + assertEquals(ADCommonName.CHECKPOINT_INDEX_NAME, capturedGetRequest.index()); assertEquals(modelId, capturedGetRequest.id()); ArgumentCaptor> responseCaptor = ArgumentCaptor.forClass(Optional.class); verify(listener).onResponse(responseCaptor.capture()); @@ -434,7 +435,7 @@ public void test_getModelCheckpoint_Bwc() { checkpointDao.getTRCFModel(modelId, listener); GetRequest capturedGetRequest = getRequest.get(); - assertEquals(indexName, capturedGetRequest.index()); + assertEquals(ADCommonName.CHECKPOINT_INDEX_NAME, capturedGetRequest.index()); assertEquals(modelId, capturedGetRequest.id()); ArgumentCaptor> responseCaptor = ArgumentCaptor.forClass(Optional.class); verify(listener).onResponse(responseCaptor.capture()); @@ -461,12 +462,8 @@ public void test_getModelCheckpoint_returnEmptyToListener_whenModelNotFound() { checkpointDao.getTRCFModel(modelId, listener); GetRequest getRequest = requestCaptor.getValue(); - assertEquals(indexName, getRequest.index()); + assertEquals(ADCommonName.CHECKPOINT_INDEX_NAME, getRequest.index()); assertEquals(modelId, getRequest.id()); - // ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(Exception.class); - // verify(listener).onFailure(responseCaptor.capture()); - // Exception exception = responseCaptor.getValue(); - // assertTrue(exception instanceof ResourceNotFoundException); ArgumentCaptor> responseCaptor = ArgumentCaptor.forClass(Optional.class); verify(listener).onResponse(responseCaptor.capture()); assertTrue(!responseCaptor.getValue().isPresent()); @@ -485,7 +482,7 @@ public void test_deleteModelCheckpoint_callListener_whenCompleted() { checkpointDao.deleteModelCheckpoint(modelId, listener); DeleteRequest deleteRequest = requestCaptor.getValue(); - assertEquals(indexName, deleteRequest.index()); + assertEquals(ADCommonName.CHECKPOINT_INDEX_NAME, deleteRequest.index()); assertEquals(modelId, deleteRequest.id()); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(Void.class); @@ -494,50 +491,29 @@ public void test_deleteModelCheckpoint_callListener_whenCompleted() { assertEquals(null, response); } - @SuppressWarnings("unchecked") + // @SuppressWarnings("unchecked") public void test_restore() throws IOException { - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); - EntityModel modelToSave = state.getModel(); - - GetResponse getResponse = mock(GetResponse.class); - when(getResponse.isExists()).thenReturn(true); - Map source = new HashMap<>(); - source.put(CheckpointDao.DETECTOR_ID, state.getId()); - source.put(CheckpointDao.FIELD_MODELV2, checkpointDao.toCheckpoint(modelToSave, modelId).get()); - source.put(CommonName.TIMESTAMP, "2020-10-11T22:58:23.610392Z"); - when(getResponse.getSource()).thenReturn(source); + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ThresholdedRandomCutForest modelToSave = state.getModel().get(); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + Map source = checkpointDao.toIndexSource(state); + ModelState modelState = checkpointDao + .processHCGetResponse(TestHelpers.createGetResponse(source, modelId, "blah"), modelId, "123"); + // ModelState modelState = checkpointDao + // .processHCGetResponse(TestHelpers.createGetResponse(source, modelId, "blah"), modelId, ADCheckpointDao.DETECTOR_ID); - listener.onResponse(getResponse); - return null; - }).when(clientUtil).asyncRequest(any(GetRequest.class), any(BiConsumer.class), any(ActionListener.class)); + Instant utcTime = modelState.getLastCheckpointTime(); + // Oct 11, 2020 22:58:23 UTC + assertEquals(now, utcTime);// Instant.ofEpochSecond(1602457103) - ActionListener>> listener = mock(ActionListener.class); - checkpointDao.deserializeModelCheckpoint(modelId, listener); + ThresholdedRandomCutForest model = modelState.getModel().get(); + assertEquals(modelToSave.getForest().getTotalUpdates(), model.getForest().getTotalUpdates()); - ArgumentCaptor>> responseCaptor = ArgumentCaptor.forClass(Optional.class); - verify(listener).onResponse(responseCaptor.capture()); - Optional> response = responseCaptor.getValue(); - assertTrue(response.isPresent()); - Entry entry = response.get(); - OffsetDateTime utcTime = entry.getValue().atOffset(ZoneOffset.UTC); - assertEquals(2020, utcTime.getYear()); - assertEquals(Month.OCTOBER, utcTime.getMonth()); - assertEquals(11, utcTime.getDayOfMonth()); - assertEquals(22, utcTime.getHour()); - assertEquals(58, utcTime.getMinute()); - assertEquals(23, utcTime.getSecond()); - - EntityModel model = entry.getKey(); - Queue queue = model.getSamples(); - Queue samplesToSave = modelToSave.getSamples(); + Deque queue = modelState.getSamples(); + Deque samplesToSave = state.getSamples(); assertEquals(samplesToSave.size(), queue.size()); - assertTrue(Arrays.equals(samplesToSave.peek(), queue.peek())); - logger.info(modelToSave.getTrcf()); - logger.info(model.getTrcf()); - assertEquals(modelToSave.getTrcf().get().getForest().getTotalUpdates(), model.getTrcf().get().getForest().getTotalUpdates()); + assertEquals(samplesToSave.peek(), queue.peek()); } public void test_batch_write_no_index() { @@ -644,7 +620,7 @@ public void test_batch_write_no_init() throws InterruptedException { final CountDownLatch processingLatch = new CountDownLatch(1); checkpointDao - .batchWrite(new BulkRequest(), ActionListener.wrap(response -> processingLatch.countDown(), e -> { assertTrue(false); })); + .batchWrite(new BulkRequest(), ActionListener.wrap(response -> processingLatch.countDown(), e -> { assertTrue(false); })); // we don't expect the waiting time elapsed before the count reached zero assertTrue(processingLatch.await(100, TimeUnit.SECONDS)); @@ -679,10 +655,9 @@ public void test_batch_read() throws InterruptedException { } public void test_too_large_checkpoint() throws IOException { - checkpointDao = new CheckpointDao( + checkpointDao = new ADCheckpointDao( client, clientUtil, - indexName, gson, mapper, converter, @@ -693,16 +668,19 @@ public void test_too_large_checkpoint() throws IOException { 1, // make the max checkpoint size 1 byte only serializeRCFBufferPool, TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES, - anomalyRate + anomalyRate, + clock ); - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); assertTrue(checkpointDao.toIndexSource(state).isEmpty()); } public void test_to_index_source() throws IOException { - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); Map source = checkpointDao.toIndexSource(state); assertTrue(!source.isEmpty()); @@ -716,10 +694,9 @@ public void test_to_index_source() throws IOException { public void testBorrowFromPoolFailure() throws Exception { GenericObjectPool mockSerializeRCFBufferPool = mock(GenericObjectPool.class); when(mockSerializeRCFBufferPool.borrowObject()).thenThrow(NoSuchElementException.class); - checkpointDao = new CheckpointDao( + checkpointDao = new ADCheckpointDao( client, clientUtil, - indexName, gson, mapper, converter, @@ -730,21 +707,22 @@ public void testBorrowFromPoolFailure() throws Exception { 1, // make the max checkpoint size 1 byte only mockSerializeRCFBufferPool, TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES, - anomalyRate + anomalyRate, + clock ); - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); - assertTrue(!checkpointDao.toCheckpoint(state.getModel(), modelId).get().isEmpty()); + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + assertTrue(!checkpointDao.toCheckpoint(state.getModel().get(), modelId).get().isEmpty()); } public void testMapperFailure() throws IOException { ThresholdedRandomCutForestMapper mockMapper = mock(ThresholdedRandomCutForestMapper.class); when(mockMapper.toState(any())).thenThrow(RuntimeException.class); - checkpointDao = new CheckpointDao( + checkpointDao = new ADCheckpointDao( client, clientUtil, - indexName, gson, mapper, converter, @@ -755,44 +733,42 @@ public void testMapperFailure() throws IOException { 1, // make the max checkpoint size 1 byte only serializeRCFBufferPool, TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES, - anomalyRate + anomalyRate, + clock ); // make sure sample size is not 0 otherwise sample size won't be written to checkpoint - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(1).build()); - String json = checkpointDao.toCheckpoint(state.getModel(), modelId).get(); - assertEquals(null, JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); - assertTrue(null != JsonDeserializer.getChildNode(json, CommonName.ENTITY_SAMPLE)); - // assertTrue(null != JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_THRESHOLD)); - // assertNotNull(JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(1).build()); + String json = checkpointDao.toCheckpoint(state.getModel().get(), modelId).get(); + assertEquals(null, JsonDeserializer.getChildNode(json, ADCheckpointDao.ENTITY_TRCF)); } public void testEmptySample() throws IOException { - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); - String json = checkpointDao.toCheckpoint(state.getModel(), modelId).get(); - // assertTrue(null != JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); + String json = checkpointDao.toCheckpoint(state.getModel().get(), modelId).get(); assertEquals(null, JsonDeserializer.getChildNode(json, CommonName.ENTITY_SAMPLE)); - // assertTrue(null != JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_THRESHOLD)); - assertNotNull(JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); + assertNotNull(JsonDeserializer.getChildNode(json, ADCheckpointDao.ENTITY_TRCF)); } public void testToCheckpointErcfCheckoutFail() throws Exception { when(serializeRCFBufferPool.borrowObject()).thenThrow(RuntimeException.class); - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); - String json = checkpointDao.toCheckpoint(state.getModel(), modelId).get(); + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); + String json = checkpointDao.toCheckpoint(state.getModel().get(), modelId).get(); - assertNotNull(JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); + assertNotNull(JsonDeserializer.getChildNode(json, ADCheckpointDao.ENTITY_TRCF)); } @SuppressWarnings("unchecked") private void setUpMockTrcf() { trcfMapper = mock(ThresholdedRandomCutForestMapper.class); trcfSchema = mock(Schema.class); - checkpointDao = new CheckpointDao( + checkpointDao = new ADCheckpointDao( client, clientUtil, - indexName, gson, mapper, converter, @@ -803,7 +779,8 @@ private void setUpMockTrcf() { maxCheckpointBytes, serializeRCFBufferPool, TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES, - anomalyRate + anomalyRate, + clock ); } @@ -811,10 +788,11 @@ public void testToCheckpointTrcfCheckoutBufferFail() throws Exception { setUpMockTrcf(); when(trcfMapper.toState(any())).thenThrow(RuntimeException.class).thenReturn(null); - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); - String json = checkpointDao.toCheckpoint(state.getModel(), modelId).get(); + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); + String json = checkpointDao.toCheckpoint(state.getModel().get(), modelId).get(); - assertNotNull(JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); + assertNotNull(JsonDeserializer.getChildNode(json, ADCheckpointDao.ENTITY_TRCF)); } public void testToCheckpointTrcfFailNewBuffer() throws Exception { @@ -822,10 +800,11 @@ public void testToCheckpointTrcfFailNewBuffer() throws Exception { doReturn(null).when(serializeRCFBufferPool).borrowObject(); when(trcfMapper.toState(any())).thenThrow(RuntimeException.class); - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); - String json = checkpointDao.toCheckpoint(state.getModel(), modelId).get(); + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); + String json = checkpointDao.toCheckpoint(state.getModel().get(), modelId).get(); - assertNull(JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); + assertNull(JsonDeserializer.getChildNode(json, ADCheckpointDao.ENTITY_TRCF)); } public void testToCheckpointTrcfCheckoutBufferInvalidateFail() throws Exception { @@ -833,42 +812,52 @@ public void testToCheckpointTrcfCheckoutBufferInvalidateFail() throws Exception when(trcfMapper.toState(any())).thenThrow(RuntimeException.class).thenReturn(null); doThrow(RuntimeException.class).when(serializeRCFBufferPool).invalidateObject(any()); - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); - String json = checkpointDao.toCheckpoint(state.getModel(), modelId).get(); + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); + String json = checkpointDao.toCheckpoint(state.getModel().get(), modelId).get(); - assertNotNull(JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); + assertNotNull(JsonDeserializer.getChildNode(json, ADCheckpointDao.ENTITY_TRCF)); } public void testFromEntityModelCheckpointWithTrcf() throws Exception { - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); - String model = checkpointDao.toCheckpoint(state.getModel(), modelId).get(); + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + String model = checkpointDao.toCheckpoint(state.getModel().get(), modelId).get(); - Map entity = new HashMap<>(); - entity.put(FIELD_MODELV2, model); - entity.put(CommonName.TIMESTAMP, Instant.now().toString()); - Optional> result = checkpointDao.fromEntityModelCheckpoint(entity, this.modelId); + Map source = new HashMap<>(); + source.put(ADCheckpointDao.DETECTOR_ID, state.getConfigId()); + source.put(FIELD_MODELV2, model); + source.put(CommonName.TIMESTAMP, Instant.now().toString()); - assertTrue(result.isPresent()); - Entry pair = result.get(); - EntityModel entityModel = pair.getKey(); - assertTrue(entityModel.getTrcf().isPresent()); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSource()).thenReturn(source); + + ModelState result = checkpointDao + .processHCGetResponse(getResponse, this.modelId, ADCheckpointDao.DETECTOR_ID); + + assertTrue(result != null); + assertTrue(result.getModel().isPresent()); } public void testFromEntityModelCheckpointTrcfMapperFail() throws Exception { setUpMockTrcf(); when(trcfMapper.toModel(any())).thenThrow(RuntimeException.class); - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); - String model = checkpointDao.toCheckpoint(state.getModel(), modelId).get(); + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + String model = checkpointDao.toCheckpoint(state.getModel().get(), modelId).get(); - Map entity = new HashMap<>(); - entity.put(FIELD_MODELV2, model); - entity.put(CommonName.TIMESTAMP, Instant.now().toString()); - Optional> result = checkpointDao.fromEntityModelCheckpoint(entity, this.modelId); + Map source = new HashMap<>(); + source.put(FIELD_MODELV2, model); + source.put(CommonName.TIMESTAMP, Instant.now().toString()); - assertTrue(result.isPresent()); - Entry pair = result.get(); - EntityModel entityModel = pair.getKey(); - assertFalse(entityModel.getTrcf().isPresent()); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSource()).thenReturn(source); + + ModelState result = checkpointDao + .processHCGetResponse(getResponse, this.modelId, ADCheckpointDao.DETECTOR_ID); + + assertTrue(result != null); + assertTrue(result.getModel().isEmpty()); } private Pair, Instant> setUp1_0Model(String checkpointFileName) throws FileNotFoundException, @@ -896,20 +885,22 @@ public void testFromEntityModelCheckpointBWC() throws FileNotFoundException, IOE Pair, Instant> modelPair = setUp1_0Model("checkpoint_2.json"); Instant now = modelPair.getRight(); - Optional> result = checkpointDao.fromEntityModelCheckpoint(modelPair.getLeft(), this.modelId); - assertTrue(result.isPresent()); - Entry pair = result.get(); - assertEquals(now, pair.getValue()); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSource()).thenReturn(modelPair.getLeft()); - EntityModel entityModel = pair.getKey(); + ModelState result = checkpointDao + .processHCGetResponse(getResponse, this.modelId, ADCheckpointDao.DETECTOR_ID); + assertTrue(result != null); + assertEquals(now, result.getLastCheckpointTime()); + + Deque samples = result.getSamples(); - Queue samples = entityModel.getSamples(); assertEquals(6, samples.size()); - double[] firstSample = samples.peek(); + double[] firstSample = samples.peek().getValueList(); assertEquals(1, firstSample.length); assertEquals(0.6832234717598454, firstSample[0], 1e-10); - ThresholdedRandomCutForest trcf = entityModel.getTrcf().get(); + ThresholdedRandomCutForest trcf = result.getModel().get(); RandomCutForest forest = trcf.getForest(); assertEquals(1, forest.getDimensions()); assertEquals(10, forest.getNumberOfTrees()); @@ -926,10 +917,9 @@ public void testFromEntityModelCheckpointBWC() throws FileNotFoundException, IOE public void testFromEntityModelCheckpointModelTooLarge() throws FileNotFoundException, IOException, URISyntaxException { Pair, Instant> modelPair = setUp1_0Model("checkpoint_2.json"); - checkpointDao = new CheckpointDao( + checkpointDao = new ADCheckpointDao( client, clientUtil, - indexName, gson, mapper, converter, @@ -940,43 +930,60 @@ public void testFromEntityModelCheckpointModelTooLarge() throws FileNotFoundExce 100_000, // checkpoint_2.json is of 224603 bytes. serializeRCFBufferPool, TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES, - anomalyRate + anomalyRate, + clock ); - Optional> result = checkpointDao.fromEntityModelCheckpoint(modelPair.getLeft(), this.modelId); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSource()).thenReturn(modelPair.getLeft()); + ModelState result = checkpointDao + .processHCGetResponse(getResponse, this.modelId, ADCheckpointDao.DETECTOR_ID); // checkpoint is only configured to take in 1 MB checkpoint at most. But the checkpoint here is of 1408047 bytes. - assertTrue(!result.isPresent()); + assertTrue(result == null); } // test no model is present in checkpoint public void testFromEntityModelCheckpointEmptyModel() throws FileNotFoundException, IOException, URISyntaxException { Map entity = new HashMap<>(); + entity.put(ADCheckpointDao.DETECTOR_ID, ADCheckpointDao.DETECTOR_ID); entity.put(CommonName.TIMESTAMP, Instant.now().toString()); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSource()).thenReturn(entity); - Optional> result = checkpointDao.fromEntityModelCheckpoint(entity, this.modelId); - assertTrue(!result.isPresent()); + ModelState result = checkpointDao + .processHCGetResponse(getResponse, this.modelId, ADCheckpointDao.DETECTOR_ID); + assertTrue(result == null); } public void testFromEntityModelCheckpointEmptySamples() throws FileNotFoundException, IOException, URISyntaxException { Pair, Instant> modelPair = setUp1_0Model("checkpoint_1.json"); - Optional> result = checkpointDao.fromEntityModelCheckpoint(modelPair.getLeft(), this.modelId); - assertTrue(result.isPresent()); - Queue samples = result.get().getKey().getSamples(); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSource()).thenReturn(modelPair.getLeft()); + ModelState result = checkpointDao + .processHCGetResponse(getResponse, this.modelId, ADCheckpointDao.DETECTOR_ID); + assertTrue(result != null); + Deque samples = result.getSamples(); assertEquals(0, samples.size()); } public void testFromEntityModelCheckpointNoRCF() throws FileNotFoundException, IOException, URISyntaxException { Pair, Instant> modelPair = setUp1_0Model("checkpoint_3.json"); - Optional> result = checkpointDao.fromEntityModelCheckpoint(modelPair.getLeft(), this.modelId); - assertTrue(result.isPresent()); - assertTrue(!result.get().getKey().getTrcf().isPresent()); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSource()).thenReturn(modelPair.getLeft()); + ModelState result = checkpointDao + .processHCGetResponse(getResponse, this.modelId, ADCheckpointDao.DETECTOR_ID); + assertTrue(result != null); + assertTrue(result.getModel().isEmpty()); } public void testFromEntityModelCheckpointNoThreshold() throws FileNotFoundException, IOException, URISyntaxException { Pair, Instant> modelPair = setUp1_0Model("checkpoint_4.json"); - Optional> result = checkpointDao.fromEntityModelCheckpoint(modelPair.getLeft(), this.modelId); - assertTrue(result.isPresent()); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSource()).thenReturn(modelPair.getLeft()); + ModelState result = checkpointDao + .processHCGetResponse(getResponse, this.modelId, ADCheckpointDao.DETECTOR_ID); + assertTrue(result != null); - ThresholdedRandomCutForest trcf = result.get().getKey().getTrcf().get(); + ThresholdedRandomCutForest trcf = result.getModel().get(); RandomCutForest forest = trcf.getForest(); assertEquals(1, forest.getDimensions()); assertEquals(10, forest.getNumberOfTrees()); @@ -984,19 +991,18 @@ public void testFromEntityModelCheckpointNoThreshold() throws FileNotFoundExcept } public void testFromEntityModelCheckpointWithEntity() throws Exception { - ModelState state = MLUtil + ModelState state = MLUtil .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).entityAttributes(true).build()); Map content = checkpointDao.toIndexSource(state); // Opensearch will convert from java.time.ZonedDateTime to String. Here I am converting to simulate that content.put(CommonName.TIMESTAMP, "2021-09-23T05:00:37.93195Z"); - Optional> result = checkpointDao.fromEntityModelCheckpoint(content, this.modelId); + ModelState result = checkpointDao + .processHCGetResponse(TestHelpers.createGetResponse(content, modelId, "blah"), this.modelId, ADCheckpointDao.DETECTOR_ID); - assertTrue(result.isPresent()); - Entry pair = result.get(); - EntityModel entityModel = pair.getKey(); - assertTrue(entityModel.getEntity().isPresent()); - assertEquals(state.getModel().getEntity().get(), entityModel.getEntity().get()); + assertTrue(result != null); + assertTrue(result.getEntity().isPresent()); + assertEquals(state.getEntity().get(), result.getEntity().get()); } private double[] getPoint(int dimensions, Random random) { @@ -1067,37 +1073,46 @@ public void testDeserializeTRCFModel() throws Exception { coldStartData.add(sample4); coldStartData.add(sample5); - // This scores were generated with the sample data but on RCF3.0-rc1 and we are comparing them - // to the scores generated by the imported RCF3.0-rc2.1 + // This scores were generated with the sample data on RCF4.0. RCF4.0 changed implementation + // and we are seeing different rcf scores between 4.0 and 3.8. This is verified by switching + // rcf version between 3.8 and 4.0 while other code in AD unchanged. But we get different scores. List scores = new ArrayList<>(); - scores.add(4.814651669367903); - scores.add(5.566968073093689); - scores.add(5.919907610660049); - scores.add(5.770278090352401); - scores.add(5.319779117320102); - - List grade = new ArrayList<>(); - grade.add(1.0); - grade.add(0.0); - grade.add(0.0); - grade.add(0.0); - grade.add(0.0); + scores.add(5.052069275347555); + scores.add(6.117465704461799); + scores.add(6.6401649744661055); + scores.add(6.918514609476484); + scores.add(6.928318158276434); + // rcf 3.8 has a number of improvements on thresholder and predictor corrector. // We don't expect the results have the same anomaly grade. for (int i = 0; i < coldStartData.size(); i++) { forest.process(coldStartData.get(i), 0); AnomalyDescriptor descriptor = forest.process(coldStartData.get(i), 0); - assertEquals(descriptor.getRCFScore(), scores.get(i), 1e-9); + assertEquals(scores.get(i), descriptor.getRCFScore(), 1e-9); } } public void testShouldSave() { - assertTrue(!checkpointDao.shouldSave(Instant.MIN, false, null, clock)); - assertTrue(checkpointDao.shouldSave(Instant.ofEpochMilli(Instant.now().toEpochMilli()), true, Duration.ofHours(6), clock)); + ModelState modelState = new ModelState( + null, + modelId, + "123", + ModelManager.ModelType.TRCF.getName(), + clock, + 0.1f, + Optional.empty(), + MLUtil.createQueueSamples(1) + ); + modelState.setLastCheckpointTime(Instant.MIN); + assertTrue(!checkpointDao.shouldSave(modelState, false, null, clock)); + modelState.setLastCheckpointTime(Instant.ofEpochMilli(Instant.now().toEpochMilli())); + assertTrue(checkpointDao.shouldSave(modelState, true, Duration.ofHours(6), clock)); // now + 6 hrs > Instant.now - assertTrue(!checkpointDao.shouldSave(Instant.ofEpochMilli(Instant.now().toEpochMilli()), false, Duration.ofHours(6), clock)); + modelState.setLastCheckpointTime(Instant.ofEpochMilli(Instant.now().toEpochMilli())); + assertTrue(!checkpointDao.shouldSave(modelState, false, Duration.ofHours(6), clock)); // 1658863778000L + 6 hrs < Instant.now - assertTrue(checkpointDao.shouldSave(Instant.ofEpochMilli(1658863778000L), false, Duration.ofHours(6), clock)); + modelState.setLastCheckpointTime(Instant.ofEpochMilli(1658863778000L)); + assertTrue(checkpointDao.shouldSave(modelState, false, Duration.ofHours(6), clock)); } // This test is intended to check if given a checkpoint created by RCF-3.0-rc3 ("rcf_3_0_rc3_single_stream.json") @@ -1133,21 +1148,22 @@ public void testDeserialize_rcf3_rc3_single_stream_model() throws Exception { coldStartData.add(sample4); coldStartData.add(sample5); - // This scores were generated with the sample data but on RCF3.0-rc1 and we are comparing them - // to the scores generated by the imported RCF3.0-rc2.1 + // This scores were generated with the sample data on RCF4.0. RCF4.0 changed implementation + // and we are seeing different rcf scores between 4.0 and 3.8. This is verified by switching + // rcf version between 3.8 and 4.0 while other code in AD unchanged. But we get different scores. List scores = new ArrayList<>(); - scores.add(3.3830441158587066); - scores.add(2.825961659490065); - scores.add(2.4685871670647384); - scores.add(2.3123460886413647); - scores.add(2.1401987653477135); + scores.add(3.678754481587072); + scores.add(3.6809634269790252); + scores.add(3.683659822587799); + scores.add(3.6852688612219646); + scores.add(3.6859330728661064); // rcf 3.8 has a number of improvements on thresholder and predictor corrector. // We don't expect the results have the same anomaly grade. for (int i = 0; i < coldStartData.size(); i++) { forest.process(coldStartData.get(i), 0); AnomalyDescriptor descriptor = forest.process(coldStartData.get(i), 0); - assertEquals(descriptor.getRCFScore(), scores.get(i), 1e-9); + assertEquals(scores.get(i), descriptor.getRCFScore(), 1e-9); } } @@ -1190,21 +1206,22 @@ public void testDeserialize_rcf3_rc3_hc_model() throws Exception { coldStartData.add(sample4); coldStartData.add(sample5); - // This scores were generated with the sample data but on RCF3.0-rc1 and we are comparing them - // to the scores generated by the imported RCF3.0-rc2.1 + // This scores were generated with the sample data but on RCF4.0 that changed implementation + // and we are seeing different rcf scores between 4.0 and 3.8. This is verified by switching + // rcf version between 3.8 and 4.0 while other code in AD unchanged. But we get different scores. List scores = new ArrayList<>(); - scores.add(1.86645896573027); - scores.add(1.8760247712797833); - scores.add(1.6809181763279901); - scores.add(1.7126716645678555); - scores.add(1.323776514074674); + scores.add(2.119532552959117); + scores.add(2.7347456872746325); + scores.add(3.066704948143919); + scores.add(3.2965580521876725); + scores.add(3.1888920146607047); // rcf 3.8 has a number of improvements on thresholder and predictor corrector. // We don't expect the results have the same anomaly grade. for (int i = 0; i < coldStartData.size(); i++) { forest.process(coldStartData.get(i), 0); AnomalyDescriptor descriptor = forest.process(coldStartData.get(i), 0); - assertEquals(descriptor.getRCFScore(), scores.get(i), 1e-9); + assertEquals(scores.get(i), descriptor.getRCFScore(), 1e-9); } } @@ -1234,4 +1251,24 @@ public static String unescapeJavaString(String st) { } return sb.toString(); } + + public void testProcessEmptyCheckpoint() throws IOException { + String modelId = "abc"; + ModelState modelState = checkpointDao + .processHCGetResponse(TestHelpers.createBrokenGetResponse(modelId, "blah"), modelId, "123"); + assertEquals(null, modelState); + } + + public void testNonEmptyCheckpoint() throws IOException { + String modelId = "abc"; + ModelState inputModelState = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + + Map source = checkpointDao.toIndexSource(inputModelState); + ModelState modelState = checkpointDao + .processHCGetResponse(TestHelpers.createGetResponse(source, modelId, "blah"), modelId, "123"); + assertEquals(now, modelState.getLastCheckpointTime()); + assertEquals(inputModelState.getSamples().size(), modelState.getSamples().size()); + assertEquals(now, modelState.getLastUsedTime()); + } } diff --git a/src/test/java/org/opensearch/ad/ml/CheckpointDeleteTests.java b/src/test/java/org/opensearch/ad/ml/CheckpointDeleteTests.java index dcda1ff92..c9578ff50 100644 --- a/src/test/java/org/opensearch/ad/ml/CheckpointDeleteTests.java +++ b/src/test/java/org/opensearch/ad/ml/CheckpointDeleteTests.java @@ -17,6 +17,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import java.time.Clock; import java.util.Arrays; import java.util.Collections; import java.util.Locale; @@ -35,6 +36,7 @@ import org.opensearch.index.reindex.DeleteByQueryAction; import org.opensearch.index.reindex.ScrollableHitSource; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.ml.CheckpointDao; import org.opensearch.timeseries.util.ClientUtil; import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestMapper; @@ -60,7 +62,7 @@ private enum DeleteExecutionMode { PARTIAL_FAILURE } - private CheckpointDao checkpointDao; + private ADCheckpointDao checkpointDao; private Client client; private ClientUtil clientUtil; private Gson gson; @@ -77,12 +79,14 @@ private enum DeleteExecutionMode { double anomalyRate; + private Clock clock; + @SuppressWarnings("unchecked") @Override @Before public void setUp() throws Exception { super.setUp(); - super.setUpLog4jForJUnit(CheckpointDao.class); + super.setUpLog4jForJUnit(ADCheckpointDao.class); client = mock(Client.class); clientUtil = mock(ClientUtil.class); @@ -97,10 +101,10 @@ public void setUp() throws Exception { objectPool = mock(GenericObjectPool.class); int deserializeRCFBufferSize = 512; anomalyRate = 0.005; - checkpointDao = new CheckpointDao( + clock = mock(Clock.class); + checkpointDao = new ADCheckpointDao( client, clientUtil, - ADCommonName.CHECKPOINT_INDEX_NAME, gson, mapper, converter, @@ -111,7 +115,8 @@ public void setUp() throws Exception { maxCheckpointBytes, objectPool, deserializeRCFBufferSize, - anomalyRate + anomalyRate, + clock ); } @@ -157,7 +162,7 @@ public void delete_by_detector_id_template(DeleteExecutionMode mode) { return null; }).when(client).execute(eq(DeleteByQueryAction.INSTANCE), any(), any()); - checkpointDao.deleteModelCheckpointByDetectorId(detectorId); + checkpointDao.deleteModelCheckpointByConfigId(detectorId); } public void testDeleteSingleNormal() throws Exception { @@ -172,7 +177,7 @@ public void testDeleteSingleIndexNotFound() throws Exception { public void testDeleteSingleResultFailure() throws Exception { delete_by_detector_id_template(DeleteExecutionMode.FAILURE); - assertTrue(testAppender.containsMessage(CheckpointDao.NOT_ABLE_TO_DELETE_LOG_MSG)); + assertTrue(testAppender.containsMessage(CheckpointDao.NOT_ABLE_TO_DELETE_CHECKPOINT_MSG)); } public void testDeleteSingleResultPartialFailure() throws Exception { diff --git a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java index 188146f69..ec8fb8c35 100644 --- a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java +++ b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java @@ -29,10 +29,10 @@ import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; +import java.util.Deque; import java.util.List; import java.util.Map.Entry; import java.util.Optional; -import java.util.Queue; import java.util.Random; import java.util.Set; import java.util.concurrent.CountDownLatch; @@ -43,8 +43,6 @@ import org.junit.BeforeClass; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.ml.ModelManager.ModelType; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.service.ClusterService; @@ -59,9 +57,17 @@ import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.ml.ModelColdStart; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.ratelimit.FeatureRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; import org.opensearch.timeseries.settings.TimeSeriesSettings; +import com.amazon.randomcutforest.config.ForestMode; import com.amazon.randomcutforest.config.Precision; import com.amazon.randomcutforest.config.TransformMethod; import com.amazon.randomcutforest.parkservices.AnomalyDescriptor; @@ -105,85 +111,156 @@ public void tearDown() throws Exception { // train using samples directly public void testTrainUsingSamples() throws InterruptedException { - Queue samples = MLUtil.createQueueSamples(numMinSamples); - EntityModel model = new EntityModel(entity, samples, null); - modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); - assertTrue(model.getTrcf().isPresent()); - ThresholdedRandomCutForest ercf = model.getTrcf().get(); + Deque samples = MLUtil.createQueueSamples(numMinSamples); + modelState = new ModelState( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + priority, + Optional.of(entity), + samples + ); + FeatureRequest featureRequest = new FeatureRequest( + Instant.now().toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + samples.peek().getValueList(), + samples.peek().getDataStartTime().toEpochMilli(), + entity, + "123" + ); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); + assertTrue(modelState.getModel().isPresent()); + ThresholdedRandomCutForest ercf = modelState.getModel().get(); assertEquals(numMinSamples, ercf.getForest().getTotalUpdates()); checkSemaphoreRelease(); } public void testColdStart() throws InterruptedException, IOException { - Queue samples = MLUtil.createQueueSamples(1); - EntityModel model = new EntityModel(entity, samples, null); - modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + // By default startNormalization is 10, thus we won't see total updates until the 10th point + numMinSamples = 10; + shingleSize = 8; + entityColdStarter = new ADColdStart( + clock, + threadPool, + stateManager, + TimeSeriesSettings.NUM_SAMPLES_PER_TREE, + TimeSeriesSettings.NUM_TREES, + numMinSamples, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + searchFeatureDao, + TimeSeriesSettings.THRESHOLD_MIN_PVALUE, + featureManager, + // settings, + TimeSeriesSettings.HOURLY_MAINTENANCE, + checkpointWriteQueue, + rcfSeed, + TimeSeriesSettings.MAX_COLD_START_ROUNDS, + 1 + ); + Deque samples = MLUtil.createQueueSamples(1); + modelState = new ModelState<>( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + priority, + Optional.of(entity), + samples + ); + long startTime = 1602269260000L; doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(3); - listener.onResponse(Optional.of(1602269260000L)); + listener.onResponse(Optional.of(startTime)); return null; }).when(searchFeatureDao).getMinDataTime(any(), any(), eq(AnalysisType.AD), any()); List> coldStartSamples = new ArrayList<>(); + for (int i = 0; i < 11; i++) { + if (i == 3) { + coldStartSamples.add(Optional.empty()); + } else { + coldStartSamples.add(Optional.of(new double[] { i })); + } + } - double[] sample1 = new double[] { 57.0 }; - double[] sample2 = new double[] { 1.0 }; - double[] sample3 = new double[] { -19.0 }; - - coldStartSamples.add(Optional.of(sample1)); - coldStartSamples.add(Optional.of(sample2)); - coldStartSamples.add(Optional.of(sample3)); doAnswer(invocation -> { ActionListener>> listener = invocation.getArgument(5); listener.onResponse(coldStartSamples); return null; }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), eq(AnalysisType.AD), any()); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); - checkSemaphoreRelease(); - - assertTrue(model.getTrcf().isPresent()); - ThresholdedRandomCutForest ercf = model.getTrcf().get(); - // 1 round: stride * (samples - 1) + 1 = 60 * 2 + 1 = 121 - // plus 1 existing sample - assertEquals(121, ercf.getForest().getTotalUpdates()); - assertTrue("size: " + model.getSamples().size(), model.getSamples().isEmpty()); - + FeatureRequest featureRequest = new FeatureRequest( + Instant.now().toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + samples.peek().getValueList(), + startTime + coldStartSamples.size() * detector.getIntervalInMilliseconds(), + entity, + "123" + ); + resetListener(); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); checkSemaphoreRelease(); - released.set(false); - // too frequent cold start of the same detector will fail - samples = MLUtil.createQueueSamples(1); - model = new EntityModel(entity, samples, null); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); + assertTrue(modelState.getModel().isPresent()); + ThresholdedRandomCutForest ercf = modelState.getModel().get(); + + assertEquals(coldStartSamples.size(), ercf.getForest().getTotalUpdates()); + assertTrue("size: " + modelState.getSamples().size(), modelState.getSamples().isEmpty()); + + List expectedColdStartData = new ArrayList<>(); + long currentStartTimeMillis = startTime; + for (int i = 0; i < 11; i++) { + if (i != 3) { + expectedColdStartData + .add( + new Sample( + new double[] { i }, + Instant.ofEpochMilli(currentStartTimeMillis), + Instant.ofEpochMilli(currentStartTimeMillis + detector.getIntervalInMilliseconds()) + ) + ); + } + currentStartTimeMillis += detector.getIntervalInMilliseconds(); + } - assertFalse(model.getTrcf().isPresent()); - // the samples is not touched since cold start does not happen - assertEquals("size: " + model.getSamples().size(), 1, model.getSamples().size()); - checkSemaphoreRelease(); + diffTesting(modelState, expectedColdStartData); - List expectedColdStartData = new ArrayList<>(); + for (int i = 0; i <= TimeSeriesSettings.COLD_START_DOOR_KEEPER_COUNT_THRESHOLD; i++) { + resetListener(); + modelState = createStateForCacheRelease(); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); + checkSemaphoreRelease(); + } - // for function interpolate: - // 1st parameter is a matrix of size numFeatures * numSamples - // 2nd parameter is the number of interpolants including two samples - double[][] interval1 = imputer.impute(new double[][] { new double[] { sample1[0], sample2[0] } }, 61); - expectedColdStartData.addAll(convertToFeatures(interval1, 60)); - double[][] interval2 = imputer.impute(new double[][] { new double[] { sample2[0], sample3[0] } }, 61); - expectedColdStartData.addAll(convertToFeatures(interval2, 61)); - assertEquals(121, expectedColdStartData.size()); + // model is not trained as the door keeper remembers it after TimeSeriesSettings.DOOR_KEEPER_COUNT_THRESHOLD retries and won't retry + // training + assertTrue(modelState.getModel().isEmpty()); - diffTesting(modelState, expectedColdStartData); + // the samples is not touched since cold start does not happen + assertEquals("size: " + modelState.getSamples().size(), 1, modelState.getSamples().size()); } // min max: miss one public void testMissMin() throws IOException, InterruptedException { - Queue samples = MLUtil.createQueueSamples(1); - EntityModel model = new EntityModel(entity, samples, null); - modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + Deque samples = MLUtil.createQueueSamples(1); + modelState = new ModelState( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + priority, + Optional.of(entity), + samples + ); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(3); @@ -191,11 +268,20 @@ public void testMissMin() throws IOException, InterruptedException { return null; }).when(searchFeatureDao).getMinDataTime(any(), any(), eq(AnalysisType.AD), any()); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); + FeatureRequest featureRequest = new FeatureRequest( + Instant.now().toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + samples.peek().getValueList(), + samples.peek().getDataStartTime().toEpochMilli(), + entity, + "123" + ); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); verify(searchFeatureDao, never()).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), eq(AnalysisType.AD), any()); - assertTrue(!model.getTrcf().isPresent()); + assertTrue(modelState.getModel().isEmpty()); checkSemaphoreRelease(); } @@ -204,10 +290,10 @@ public void testMissMin() throws IOException, InterruptedException { * @param modelState an initialized model state * @param coldStartData cold start data that initialized the modelState */ - private void diffTesting(ModelState modelState, List coldStartData) { + private void diffTesting(ModelState modelState, List coldStartData) { int inputDimension = detector.getEnabledFeatureIds().size(); - ThresholdedRandomCutForest refTRcf = ThresholdedRandomCutForest + ThresholdedRandomCutForest.Builder refTRcfBuilder = ThresholdedRandomCutForest .builder() .compact(true) .dimensions(inputDimension * detector.getShingleSize()) @@ -216,35 +302,56 @@ private void diffTesting(ModelState modelState, List cold .numberOfTrees(TimeSeriesSettings.NUM_TREES) .shingleSize(detector.getShingleSize()) .boundingBoxCacheFraction(TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO) - .timeDecay(TimeSeriesSettings.TIME_DECAY) - .outputAfter(numMinSamples) - .initialAcceptFraction(0.125d) + .timeDecay(detector.getTimeDecay()) + .transformDecay(detector.getTimeDecay()) + .outputAfter(Math.max(detector.getShingleSize(), numMinSamples)) + .initialAcceptFraction(numMinSamples * 1.0d / TimeSeriesSettings.NUM_SAMPLES_PER_TREE) .parallelExecutionEnabled(false) .sampleSize(TimeSeriesSettings.NUM_SAMPLES_PER_TREE) .internalShinglingEnabled(true) .anomalyRate(1 - TimeSeriesSettings.THRESHOLD_MIN_PVALUE) .transformMethod(TransformMethod.NORMALIZE) .alertOnce(true) - .autoAdjust(true) - .build(); + .autoAdjust(true); + + if (detector.getShingleSize() > 1) { + refTRcfBuilder.forestMode(ForestMode.STREAMING_IMPUTE); + refTRcfBuilder = ModelColdStart.applyImputationMethod(detector, refTRcfBuilder); + } else { + // imputation with shingle size 1 is not meaningful + refTRcfBuilder.forestMode(ForestMode.STANDARD); + } + ThresholdedRandomCutForest refTRcf = refTRcfBuilder.build(); + + long lastSampleEndTime = 0; for (int i = 0; i < coldStartData.size(); i++) { - refTRcf.process(coldStartData.get(i), 0); + Sample sample = coldStartData.get(i); + lastSampleEndTime = sample.getDataEndTime().getEpochSecond(); + refTRcf.process(sample.getValueList(), lastSampleEndTime); } + assertEquals( - "Expect " + coldStartData.size() + " but got " + refTRcf.getForest().getTotalUpdates(), - coldStartData.size(), - refTRcf.getForest().getTotalUpdates() + refTRcf.getForest().getTotalUpdates() + " != " + modelState.getModel().get().getForest().getTotalUpdates(), + refTRcf.getForest().getTotalUpdates(), + modelState.getModel().get().getForest().getTotalUpdates() ); Random r = new Random(); + // ThresholdedRandomCutForest refTRcf3 = modelState.getModel().get(); // make sure we trained the expected models for (int i = 0; i < 100; i++) { + lastSampleEndTime += detector.getIntervalInSeconds(); double[] point = r.ints(inputDimension, 0, 50).asDoubleStream().toArray(); - AnomalyDescriptor descriptor = refTRcf.process(point, 0); - ThresholdingResult result = modelManager - .getAnomalyResultForEntity(point, modelState, modelId, entity, detector.getShingleSize()); + assertEquals(refTRcf.getForest().getTotalUpdates(), modelState.getModel().get().getForest().getTotalUpdates()); + AnomalyDescriptor descriptor = refTRcf.process(point, lastSampleEndTime); + Sample sample = new Sample( + point, + Instant.ofEpochSecond(lastSampleEndTime - detector.getIntervalInSeconds()), + Instant.ofEpochSecond(lastSampleEndTime) + ); + ThresholdingResult result = modelManager.getResult(sample, modelState, modelId, Optional.of(entity), detector, "123"); assertEquals(descriptor.getRCFScore(), result.getRcfScore(), 1e-10); assertEquals(descriptor.getAnomalyGrade(), result.getGrade(), 1e-10); } @@ -268,120 +375,208 @@ private List convertToFeatures(double[][] interval, int numValsToKeep) // two segments of samples, one segment has 3 samples, while another one has only 1 public void testTwoSegmentsWithSingleSample() throws InterruptedException, IOException { - Queue samples = MLUtil.createQueueSamples(1); - double[] savedSample = samples.peek(); - EntityModel model = new EntityModel(entity, samples, null); - modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + // By default startNormalization is 10, thus we won't see total updates until the 10th point + numMinSamples = 10; + shingleSize = 8; + entityColdStarter = new ADColdStart( + clock, + threadPool, + stateManager, + TimeSeriesSettings.NUM_SAMPLES_PER_TREE, + TimeSeriesSettings.NUM_TREES, + numMinSamples, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + searchFeatureDao, + TimeSeriesSettings.THRESHOLD_MIN_PVALUE, + featureManager, + // settings, + TimeSeriesSettings.HOURLY_MAINTENANCE, + checkpointWriteQueue, + rcfSeed, + TimeSeriesSettings.MAX_COLD_START_ROUNDS, + 1 + ); + detector = TestHelpers.AnomalyDetectorBuilder + .newInstance() + .setDetectionInterval(new IntervalTimeConfiguration(detectorInterval, ChronoUnit.MINUTES)) + .setCategoryFields(ImmutableList.of(randomAlphaOfLength(5))) + .setShingleSize(shingleSize) + .build(); + + Deque samples = MLUtil.createQueueSamples(0); + modelState = new ModelState( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + priority, + Optional.of(entity), + samples + ); + long startTime = 1602269260000L; doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(3); - listener.onResponse(Optional.of(1602269260000L)); + listener.onResponse(Optional.of(startTime)); return null; }).when(searchFeatureDao).getMinDataTime(any(), any(), eq(AnalysisType.AD), any()); List> coldStartSamples = new ArrayList<>(); - double[] sample1 = new double[] { 57.0 }; - double[] sample2 = new double[] { 1.0 }; - double[] sample3 = new double[] { -19.0 }; - double[] sample5 = new double[] { -17.0 }; - coldStartSamples.add(Optional.of(sample1)); - coldStartSamples.add(Optional.of(sample2)); - coldStartSamples.add(Optional.of(sample3)); - coldStartSamples.add(Optional.empty()); - coldStartSamples.add(Optional.of(sample5)); + + for (int i = 0; i < 11; i++) { + if (i == 3) { + coldStartSamples.add(Optional.empty()); + } else { + coldStartSamples.add(Optional.of(new double[] { i })); + } + } doAnswer(invocation -> { ActionListener>> listener = invocation.getArgument(5); listener.onResponse(coldStartSamples); return null; }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), eq(AnalysisType.AD), any()); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); - checkSemaphoreRelease(); - assertTrue(model.getTrcf().isPresent()); - - // 1 round: stride * (samples - 1) + 1 = 60 * 4 + 1 = 241 - // if 241 < shingle size + numMinSamples, then another round is performed - assertEquals(241, modelState.getModel().getTrcf().get().getForest().getTotalUpdates()); + FeatureRequest featureRequest = new FeatureRequest( + Instant.now().toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + new double[] { 0 }, + startTime + coldStartSamples.size() * detector.getIntervalInMilliseconds(), + entity, + "123" + ); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); checkSemaphoreRelease(); + assertTrue(modelState.getModel().isPresent()); + + // imputed value are counted as one update. So we have 11 values in total including the one missing value. + assertEquals(11, modelState.getModel().get().getForest().getTotalUpdates()); + + List expectedColdStartData = new ArrayList<>(); + long currentStartTimeMillis = startTime; + for (int i = 0; i < 11; i++) { + if (i != 3) { + expectedColdStartData + .add( + new Sample( + new double[] { i }, + Instant.ofEpochMilli(currentStartTimeMillis), + Instant.ofEpochMilli(currentStartTimeMillis + detector.getIntervalInMilliseconds()) + ) + ); + } + currentStartTimeMillis += detector.getIntervalInMilliseconds(); + } - List expectedColdStartData = new ArrayList<>(); - - // for function interpolate: - // 1st parameter is a matrix of size numFeatures * numSamples - // 2nd parameter is the number of interpolants including two samples - double[][] interval1 = imputer.impute(new double[][] { new double[] { sample1[0], sample2[0] } }, 61); - expectedColdStartData.addAll(convertToFeatures(interval1, 60)); - double[][] interval2 = imputer.impute(new double[][] { new double[] { sample2[0], sample3[0] } }, 61); - expectedColdStartData.addAll(convertToFeatures(interval2, 60)); - double[][] interval3 = imputer.impute(new double[][] { new double[] { sample3[0], sample5[0] } }, 121); - expectedColdStartData.addAll(convertToFeatures(interval3, 121)); - assertTrue("size: " + model.getSamples().size(), model.getSamples().isEmpty()); - assertEquals(241, expectedColdStartData.size()); diffTesting(modelState, expectedColdStartData); } // two segments of samples, one segment has 3 samples, while another one 2 samples public void testTwoSegments() throws InterruptedException, IOException { - Queue samples = MLUtil.createQueueSamples(1); - double[] savedSample = samples.peek(); - EntityModel model = new EntityModel(entity, samples, null); - modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + // By default startNormalization is 10, thus we won't see total updates until the 10th point + numMinSamples = 10; + shingleSize = 8; + entityColdStarter = new ADColdStart( + clock, + threadPool, + stateManager, + TimeSeriesSettings.NUM_SAMPLES_PER_TREE, + TimeSeriesSettings.NUM_TREES, + numMinSamples, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + searchFeatureDao, + TimeSeriesSettings.THRESHOLD_MIN_PVALUE, + featureManager, + // settings, + TimeSeriesSettings.HOURLY_MAINTENANCE, + checkpointWriteQueue, + rcfSeed, + TimeSeriesSettings.MAX_COLD_START_ROUNDS, + 1 + ); + Deque samples = MLUtil.createQueueSamples(1); + modelState = new ModelState( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + priority, + Optional.of(entity), + samples + ); + long startTime = 1602269260000L; doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(3); - listener.onResponse(Optional.of(1602269260000L)); + listener.onResponse(Optional.of(startTime)); return null; }).when(searchFeatureDao).getMinDataTime(any(), any(), eq(AnalysisType.AD), any()); List> coldStartSamples = new ArrayList<>(); - double[] sample1 = new double[] { 57.0 }; - double[] sample2 = new double[] { 1.0 }; - double[] sample3 = new double[] { -19.0 }; - double[] sample5 = new double[] { -17.0 }; - double[] sample6 = new double[] { -38.0 }; - coldStartSamples.add(Optional.of(new double[] { 57.0 })); - coldStartSamples.add(Optional.of(new double[] { 1.0 })); - coldStartSamples.add(Optional.of(new double[] { -19.0 })); - coldStartSamples.add(Optional.empty()); - coldStartSamples.add(Optional.of(new double[] { -17.0 })); - coldStartSamples.add(Optional.of(new double[] { -38.0 })); + for (int i = 0; i < 11; i++) { + if (i == 3) { + coldStartSamples.add(Optional.empty()); + } else { + coldStartSamples.add(Optional.of(new double[] { i })); + } + } doAnswer(invocation -> { ActionListener>> listener = invocation.getArgument(5); listener.onResponse(coldStartSamples); return null; }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), eq(AnalysisType.AD), any()); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); + FeatureRequest featureRequest = new FeatureRequest( + Instant.now().toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + samples.peek().getValueList(), + startTime + coldStartSamples.size() * detector.getIntervalInMilliseconds(), + entity, + "123" + ); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); checkSemaphoreRelease(); - assertTrue(model.getTrcf().isPresent()); - ThresholdedRandomCutForest ercf = model.getTrcf().get(); - // 1 rounds: stride * (samples - 1) + 1 = 60 * 5 + 1 = 301 - assertEquals(301, ercf.getForest().getTotalUpdates()); + assertTrue(modelState.getModel().isPresent()); + ThresholdedRandomCutForest ercf = modelState.getModel().get(); + assertEquals(coldStartSamples.size(), ercf.getForest().getTotalUpdates()); checkSemaphoreRelease(); - List expectedColdStartData = new ArrayList<>(); - - // for function interpolate: - // 1st parameter is a matrix of size numFeatures * numSamples - // 2nd parameter is the number of interpolants including two samples - double[][] interval1 = imputer.impute(new double[][] { new double[] { sample1[0], sample2[0] } }, 61); - expectedColdStartData.addAll(convertToFeatures(interval1, 60)); - double[][] interval2 = imputer.impute(new double[][] { new double[] { sample2[0], sample3[0] } }, 61); - expectedColdStartData.addAll(convertToFeatures(interval2, 60)); - double[][] interval3 = imputer.impute(new double[][] { new double[] { sample3[0], sample5[0] } }, 121); - expectedColdStartData.addAll(convertToFeatures(interval3, 120)); - double[][] interval4 = imputer.impute(new double[][] { new double[] { sample5[0], sample6[0] } }, 61); - expectedColdStartData.addAll(convertToFeatures(interval4, 61)); - assertEquals(301, expectedColdStartData.size()); - assertTrue("size: " + model.getSamples().size(), model.getSamples().isEmpty()); + List expectedColdStartData = new ArrayList<>(); + long currentStartTimeMillis = startTime; + for (int i = 0; i < 11; i++) { + if (i != 3) { + expectedColdStartData + .add( + new Sample( + new double[] { i }, + Instant.ofEpochMilli(currentStartTimeMillis), + Instant.ofEpochMilli(currentStartTimeMillis + detector.getIntervalInMilliseconds()) + ) + ); + } + currentStartTimeMillis += detector.getIntervalInMilliseconds(); + } diffTesting(modelState, expectedColdStartData); } public void testThrottledColdStart() throws InterruptedException { - Queue samples = MLUtil.createQueueSamples(1); - EntityModel model = new EntityModel(entity, samples, null); - modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + Deque samples = MLUtil.createQueueSamples(1); + modelState = new ModelState( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + priority, + Optional.of(entity), + samples + ); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(3); @@ -389,9 +584,18 @@ public void testThrottledColdStart() throws InterruptedException { return null; }).when(searchFeatureDao).getMinDataTime(any(), any(), eq(AnalysisType.AD), any()); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); + FeatureRequest featureRequest = new FeatureRequest( + Instant.now().toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + samples.peek().getValueList(), + samples.peek().getDataStartTime().toEpochMilli(), + entity, + "123" + ); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); - entityColdStarter.trainModel(entity, "456", modelState, listener); + entityColdStarter.trainModel(featureRequest, "456", modelState, listener); // only the first one makes the call verify(searchFeatureDao, times(1)).getMinDataTime(any(), any(), eq(AnalysisType.AD), any()); @@ -399,9 +603,17 @@ public void testThrottledColdStart() throws InterruptedException { } public void testColdStartException() throws InterruptedException { - Queue samples = MLUtil.createQueueSamples(1); - EntityModel model = new EntityModel(entity, samples, null); - modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + Deque samples = MLUtil.createQueueSamples(1); + modelState = new ModelState( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + priority, + Optional.of(entity), + samples + ); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(3); @@ -409,7 +621,16 @@ public void testColdStartException() throws InterruptedException { return null; }).when(searchFeatureDao).getMinDataTime(any(), any(), eq(AnalysisType.AD), any()); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); + FeatureRequest featureRequest = new FeatureRequest( + Instant.now().toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + samples.peek().getValueList(), + samples.peek().getDataStartTime().toEpochMilli(), + entity, + "123" + ); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); assertTrue(stateManager.fetchExceptionAndClear(detectorId).isPresent()); checkSemaphoreRelease(); @@ -417,9 +638,17 @@ public void testColdStartException() throws InterruptedException { @SuppressWarnings("unchecked") public void testNotEnoughSamples() throws InterruptedException, IOException { - Queue samples = MLUtil.createQueueSamples(1); - EntityModel model = new EntityModel(entity, samples, null); - modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + Deque samples = MLUtil.createQueueSamples(1); + modelState = new ModelState( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + priority, + Optional.of(entity), + samples + ); detector = TestHelpers.AnomalyDetectorBuilder .newInstance() @@ -434,9 +663,10 @@ public void testNotEnoughSamples() throws InterruptedException, IOException { return null; }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); + long startTime = 1602269260000L; doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(3); - listener.onResponse(Optional.of(1602269260000L)); + listener.onResponse(Optional.of(startTime)); return null; }).when(searchFeatureDao).getMinDataTime(any(), any(), eq(AnalysisType.AD), any()); @@ -449,17 +679,25 @@ public void testNotEnoughSamples() throws InterruptedException, IOException { return null; }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), eq(AnalysisType.AD), any()); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); + FeatureRequest featureRequest = new FeatureRequest( + Instant.now().toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + samples.peek().getValueList(), + startTime + coldStartSamples.size() * detector.getIntervalInMilliseconds(), + entity, + "123" + ); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); checkSemaphoreRelease(); - assertTrue(!model.getTrcf().isPresent()); - // 1st round we add 57 and 1. - // 2nd round we add 57 and 1. - Queue currentSamples = model.getSamples(); - assertEquals("real sample size is " + currentSamples.size(), 4, currentSamples.size()); + assertTrue(modelState.getModel().isEmpty()); + // not enough smples to train. We keep them in the sample array of model state. + Deque currentSamples = modelState.getSamples(); + assertEquals("real sample size is " + currentSamples.size(), 2, currentSamples.size()); int j = 0; while (!currentSamples.isEmpty()) { - double[] element = currentSamples.poll(); + double[] element = currentSamples.poll().getValueList(); assertEquals(1, element.length); if (j == 0 || j == 2) { assertEquals(57, element[0], 1e-10); @@ -472,51 +710,146 @@ public void testNotEnoughSamples() throws InterruptedException, IOException { @SuppressWarnings("unchecked") public void testEmptyDataRange() throws InterruptedException { - Queue samples = MLUtil.createQueueSamples(1); - EntityModel model = new EntityModel(entity, samples, null); - modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + // the min-max range has 5 samples is too small and thus no model can be initialized as we require at least 32 + numMinSamples = 32; + entityColdStarter = new ADColdStart( + clock, + threadPool, + stateManager, + TimeSeriesSettings.NUM_SAMPLES_PER_TREE, + TimeSeriesSettings.NUM_TREES, + numMinSamples, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + searchFeatureDao, + TimeSeriesSettings.THRESHOLD_MIN_PVALUE, + featureManager, + // settings, + TimeSeriesSettings.HOURLY_MAINTENANCE, + checkpointWriteQueue, + rcfSeed, + TimeSeriesSettings.MAX_COLD_START_ROUNDS, + 1 + ); + + Deque samples = MLUtil.createQueueSamples(1); + modelState = new ModelState( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + priority, + Optional.of(entity), + samples + ); - // the min-max range 894056973000L~894057860000L is too small and thus no data range can be found - when(clock.millis()).thenReturn(894057860000L); + // when(clock.millis()).thenReturn(894057860000L); doAnswer(invocation -> { - GetRequest request = invocation.getArgument(0); ActionListener listener = invocation.getArgument(2); listener.onResponse(TestHelpers.createGetResponse(detector, detector.getId(), CommonName.CONFIG_INDEX)); return null; }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); + long startTime = 894056973000L; doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(3); - listener.onResponse(Optional.of(894056973000L)); + listener.onResponse(Optional.of(startTime)); return null; }).when(searchFeatureDao).getMinDataTime(any(), any(), eq(AnalysisType.AD), any()); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); + List> coldStartSamples = new ArrayList<>(); + + double[] sample1 = new double[] { 57.0 }; + double[] sample2 = new double[] { 1.0 }; + double[] sample3 = new double[] { -19.0 }; + + coldStartSamples.add(Optional.of(sample1)); + coldStartSamples.add(Optional.empty()); + coldStartSamples.add(Optional.of(sample2)); + coldStartSamples.add(Optional.empty()); + coldStartSamples.add(Optional.of(sample3)); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(5); + listener.onResponse(coldStartSamples); + return null; + }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), eq(AnalysisType.AD), any()); + + FeatureRequest featureRequest = new FeatureRequest( + Instant.now().toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + samples.peek().getValueList(), + startTime + coldStartSamples.size() * detector.getIntervalInMilliseconds(), + entity, + "123" + ); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); checkSemaphoreRelease(); - assertTrue(!model.getTrcf().isPresent()); - // the min-max range is too small and thus no data range can be found - assertEquals("real sample size is " + model.getSamples().size(), 1, model.getSamples().size()); + assertTrue(modelState.getModel().isEmpty()); + // the min-max range is too small and thus not enough training data + // 3 from history + assertEquals("real sample size is " + modelState.getSamples().size(), 3, modelState.getSamples().size()); } public void testTrainModelFromExistingSamplesEnoughSamples() { int inputDimension = 2; - int dimensions = inputDimension * detector.getShingleSize(); + // less than 10 will make rcf results undeterministic even though two rcf models have the same rcfSeed + numMinSamples = 10; + + // reinitialize entityColdStarter and modelManager using new numMinSamples + entityColdStarter = new ADColdStart( + clock, + threadPool, + stateManager, + TimeSeriesSettings.NUM_SAMPLES_PER_TREE, + TimeSeriesSettings.NUM_TREES, + numMinSamples, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + searchFeatureDao, + TimeSeriesSettings.THRESHOLD_MIN_PVALUE, + featureManager, + // settings, + TimeSeriesSettings.HOURLY_MAINTENANCE, + checkpointWriteQueue, + rcfSeed, + TimeSeriesSettings.MAX_COLD_START_ROUNDS, + 1 + ); + modelManager = new ADModelManager( + mock(ADCheckpointDao.class), + mock(Clock.class), + TimeSeriesSettings.NUM_TREES, + TimeSeriesSettings.NUM_SAMPLES_PER_TREE, + TimeSeriesSettings.NUM_MIN_SAMPLES, + TimeSeriesSettings.THRESHOLD_MIN_PVALUE, + AnomalyDetectorSettings.MIN_PREVIEW_SIZE, + TimeSeriesSettings.HOURLY_MAINTENANCE, + AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ, + entityColdStarter, + mock(FeatureManager.class), + mock(MemoryTracker.class), + settings, + clusterService + ); - ThresholdedRandomCutForest.Builder rcfConfig = ThresholdedRandomCutForest + ThresholdedRandomCutForest.Builder rcfConfig = ThresholdedRandomCutForest .builder() .compact(true) - .dimensions(dimensions) + .dimensions(inputDimension * detector.getShingleSize()) .precision(Precision.FLOAT_32) .randomSeed(rcfSeed) .numberOfTrees(TimeSeriesSettings.NUM_TREES) .shingleSize(detector.getShingleSize()) .boundingBoxCacheFraction(TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO) - .timeDecay(TimeSeriesSettings.TIME_DECAY) - .outputAfter(numMinSamples) - .initialAcceptFraction(0.125d) + .timeDecay(detector.getTimeDecay()) + .transformDecay(detector.getTimeDecay()) + .outputAfter(Math.max(detector.getShingleSize(), numMinSamples)) + .initialAcceptFraction(numMinSamples * 1.0d / TimeSeriesSettings.NUM_SAMPLES_PER_TREE) .parallelExecutionEnabled(false) .sampleSize(TimeSeriesSettings.NUM_SAMPLES_PER_TREE) .internalShinglingEnabled(true) @@ -524,32 +857,59 @@ public void testTrainModelFromExistingSamplesEnoughSamples() { .transformMethod(TransformMethod.NORMALIZE) .alertOnce(true) .autoAdjust(true); - Tuple, ThresholdedRandomCutForest> models = MLUtil.prepareModel(inputDimension, rcfConfig); - Queue samples = models.v1(); + + if (detector.getShingleSize() > 1) { + rcfConfig.forestMode(ForestMode.STREAMING_IMPUTE); + rcfConfig = ModelColdStart.applyImputationMethod(detector, rcfConfig); + } else { + // imputation with shingle size 1 is not meaningful + rcfConfig.forestMode(ForestMode.STANDARD); + } + Tuple, ThresholdedRandomCutForest> models = MLUtil + .prepareModel(inputDimension, rcfConfig, detector.getIntervalInMilliseconds()); + Deque samples = models.v1(); ThresholdedRandomCutForest rcf = models.v2(); - EntityModel model = new EntityModel(entity, samples, null); - modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + modelState = new ModelState( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + priority, + Optional.of(entity), + samples + ); Random r = new Random(); // make sure we trained the expected models + Instant currentTime = samples.getLast().getDataEndTime().plusMillis(detector.getIntervalInMilliseconds()); for (int i = 0; i < 100; i++) { double[] point = r.ints(inputDimension, 0, 50).asDoubleStream().toArray(); - AnomalyDescriptor descriptor = rcf.process(point, 0); - ThresholdingResult result = modelManager - .getAnomalyResultForEntity(point, modelState, modelId, entity, detector.getShingleSize()); + AnomalyDescriptor descriptor = rcf.process(point, currentTime.getEpochSecond()); + Sample sample = new Sample(point, currentTime.minusMillis(detector.getIntervalInMilliseconds()), currentTime); + ThresholdingResult result = modelManager.getResult(sample, modelState, modelId, Optional.of(entity), detector, "123"); assertEquals(descriptor.getRCFScore(), result.getRcfScore(), 1e-10); assertEquals(descriptor.getAnomalyGrade(), result.getGrade(), 1e-10); + currentTime = currentTime.plusMillis(detector.getIntervalInMilliseconds()); } } public void testTrainModelFromExistingSamplesNotEnoughSamples() { - Queue samples = new ArrayDeque<>(); - EntityModel model = new EntityModel(entity, samples, null); - modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); - entityColdStarter.trainModelFromExistingSamples(modelState, detector.getShingleSize()); - assertTrue(!modelState.getModel().getTrcf().isPresent()); + Deque samples = new ArrayDeque<>(); + modelState = new ModelState( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + priority, + Optional.of(entity), + samples + ); + entityColdStarter.trainModelFromExistingSamples(modelState, Optional.of(entity), detector, "123"); + assertTrue(modelState.getModel().isEmpty()); } @SuppressWarnings("unchecked") @@ -630,8 +990,16 @@ public int compare(Entry p1, Entry p2) { return null; }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), eq(AnalysisType.AD), any()); - EntityModel model = new EntityModel(entity, new ArrayDeque<>(), null); - modelState = new ModelState<>(model, modelId, detector.getId(), ModelType.ENTITY.getName(), clock, priority); + modelState = new ModelState( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + priority, + Optional.of(entity), + new ArrayDeque<>() + ); released = new AtomicBoolean(); @@ -641,19 +1009,28 @@ public int compare(Entry p1, Entry p2) { inProgressLatch.countDown(); }); - entityColdStarter.trainModel(entity, detector.getId(), modelState, listener); + FeatureRequest featureRequest = new FeatureRequest( + Instant.now().toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + new double[] { 1.3 }, + Instant.now().toEpochMilli(), + entity, + "123" + ); + entityColdStarter.trainModel(featureRequest, detector.getId(), modelState, listener); checkSemaphoreRelease(); - assertTrue(model.getTrcf().isPresent()); + assertTrue(modelState.getModel().isPresent()); int tp = 0; int fp = 0; int fn = 0; long[] changeTimestamps = dataWithKeys.changeTimeStampsMs; - for (int j = trainTestSplit; j < data.length; j++) { - ThresholdingResult result = modelManager - .getAnomalyResultForEntity(data[j], modelState, modelId, entity, detector.getShingleSize()); + for (int j = trainTestSplit + 1; j < data.length; j++) { + Sample sample = new Sample(data[j], Instant.ofEpochMilli(timestamps[j] - delta), Instant.ofEpochMilli(timestamps[j])); + ThresholdingResult result = modelManager.getResult(sample, modelState, modelId, Optional.of(entity), detector, "123"); if (result.getGrade() > 0) { if (changeTimestamps[j] == 0) { fp++; @@ -698,68 +1075,36 @@ public void testAccuracyThirteenMinuteInterval() throws Exception { accuracyTemplate(13, 0.5f, 0.5f); } - public void testAccuracyOneMinuteIntervalNoInterpolation() throws Exception { - ADEnabledSetting.getInstance().setSettingValue(ADEnabledSetting.INTERPOLATION_IN_HCAD_COLD_START_ENABLED, false); - // for one minute interval, we need to disable interpolation to achieve good results - entityColdStarter = new EntityColdStarter( - clock, - threadPool, - stateManager, - TimeSeriesSettings.NUM_SAMPLES_PER_TREE, - TimeSeriesSettings.NUM_TREES, - TimeSeriesSettings.TIME_DECAY, - numMinSamples, - AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, - AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, - imputer, - searchFeatureDao, - TimeSeriesSettings.THRESHOLD_MIN_PVALUE, - featureManager, - settings, - TimeSeriesSettings.HOURLY_MAINTENANCE, - checkpointWriteQueue, - rcfSeed, - TimeSeriesSettings.MAX_COLD_START_ROUNDS - ); - - modelManager = new ModelManager( - mock(CheckpointDao.class), - mock(Clock.class), - TimeSeriesSettings.NUM_TREES, - TimeSeriesSettings.NUM_SAMPLES_PER_TREE, - TimeSeriesSettings.TIME_DECAY, - TimeSeriesSettings.NUM_MIN_SAMPLES, - TimeSeriesSettings.THRESHOLD_MIN_PVALUE, - AnomalyDetectorSettings.MIN_PREVIEW_SIZE, - TimeSeriesSettings.HOURLY_MAINTENANCE, - AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ, - entityColdStarter, - mock(FeatureManager.class), - mock(MemoryTracker.class), - settings, - clusterService - ); - - accuracyTemplate(1, 0.6f, 0.6f); + public void testAccuracyOneMinuteInterval() throws Exception { + accuracyTemplate(1, 0.5f, 0.5f); } - private ModelState createStateForCacheRelease() { + private ModelState createStateForCacheRelease() { inProgressLatch = new CountDownLatch(1); releaseSemaphore = () -> { released.set(true); inProgressLatch.countDown(); }; listener = ActionListener.wrap(releaseSemaphore); - Queue samples = MLUtil.createQueueSamples(1); - EntityModel model = new EntityModel(entity, samples, null); - return new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + Deque samples = MLUtil.createQueueSamples(1); + return new ModelState( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + priority, + Optional.of(entity), + samples + ); } public void testCacheReleaseAfterMaintenance() throws IOException, InterruptedException { - ModelState modelState = createStateForCacheRelease(); + ModelState modelState = createStateForCacheRelease(); + long minTime = 1602269260000L; doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(3); - listener.onResponse(Optional.of(1602269260000L)); + listener.onResponse(Optional.of(minTime)); return null; }).when(searchFeatureDao).getMinDataTime(any(), any(), eq(AnalysisType.AD), any()); @@ -770,7 +1115,9 @@ public void testCacheReleaseAfterMaintenance() throws IOException, InterruptedEx double[] sample3 = new double[] { -19.0 }; coldStartSamples.add(Optional.of(sample1)); + coldStartSamples.add(Optional.empty()); coldStartSamples.add(Optional.of(sample2)); + coldStartSamples.add(Optional.empty()); coldStartSamples.add(Optional.of(sample3)); doAnswer(invocation -> { ActionListener>> listener = invocation.getArgument(5); @@ -778,15 +1125,30 @@ public void testCacheReleaseAfterMaintenance() throws IOException, InterruptedEx return null; }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), eq(AnalysisType.AD), any()); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); + FeatureRequest featureRequest = new FeatureRequest( + Instant.now().toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + new double[] { 1.3 }, + // detectorInterval is of minutes, need to convert to milliseconds + minTime + coldStartSamples.size() * detectorInterval * 60000, + entity, + "123" + ); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); checkSemaphoreRelease(); - assertTrue(modelState.getModel().getTrcf().isPresent()); + assertTrue(modelState.getModel().isPresent()); - modelState = createStateForCacheRelease(); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); - checkSemaphoreRelease(); - // model is not trained as the door keeper remembers it and won't retry training - assertTrue(!modelState.getModel().getTrcf().isPresent()); + for (int i = 0; i <= TimeSeriesSettings.COLD_START_DOOR_KEEPER_COUNT_THRESHOLD; i++) { + resetListener(); + modelState = createStateForCacheRelease(); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); + checkSemaphoreRelease(); + } + + // model is not trained as the door keeper remembers it after TimeSeriesSettings.DOOR_KEEPER_COUNT_THRESHOLD retries and won't retry + // training + assertTrue(modelState.getModel().isEmpty()); // make sure when the next maintenance coming, current door keeper gets reset // note our detector interval is 1 minute and the door keeper will expire in 60 intervals, which are 60 minutes @@ -794,17 +1156,31 @@ public void testCacheReleaseAfterMaintenance() throws IOException, InterruptedEx entityColdStarter.maintenance(); modelState = createStateForCacheRelease(); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); + featureRequest = new FeatureRequest( + Instant.now().toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + new double[] { 1.3 }, + // detectorInterval is of minutes, need to convert to milliseconds + // important to let the test pass as we only mocked to return 5 results including empty ones. + // We have to match the start and end time of training data + minTime + coldStartSamples.size() * detectorInterval * 60000, + entity, + "123" + ); + resetListener(); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); checkSemaphoreRelease(); // model is trained as the door keeper gets reset - assertTrue(modelState.getModel().getTrcf().isPresent()); + assertTrue(modelState.getModel().isPresent()); } public void testCacheReleaseAfterClear() throws IOException, InterruptedException { - ModelState modelState = createStateForCacheRelease(); + long startTime = 1602269260000L; + ModelState modelState = createStateForCacheRelease(); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(3); - listener.onResponse(Optional.of(1602269260000L)); + listener.onResponse(Optional.of(startTime)); return null; }).when(searchFeatureDao).getMinDataTime(any(), any(), eq(AnalysisType.AD), any()); @@ -823,16 +1199,87 @@ public void testCacheReleaseAfterClear() throws IOException, InterruptedExceptio return null; }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), eq(AnalysisType.AD), any()); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); + FeatureRequest featureRequest = new FeatureRequest( + Instant.now().toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + new double[] { 1.3 }, + startTime + coldStartSamples.size() * detector.getIntervalInMilliseconds(), + entity, + "123" + ); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); checkSemaphoreRelease(); - assertTrue(modelState.getModel().getTrcf().isPresent()); + assertTrue(modelState.getModel().isPresent()); entityColdStarter.clear(detectorId); modelState = createStateForCacheRelease(); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); checkSemaphoreRelease(); // model is trained as the door keeper is regenerated after clearance - assertTrue(modelState.getModel().getTrcf().isPresent()); + assertTrue(modelState.getModel().isPresent()); + } + + public void testNotEnoughTrainingData() throws IOException, InterruptedException { + // we only have 3 samples and thus it is not enough to initialize the model. + numMinSamples = 4; + + entityColdStarter = new ADColdStart( + clock, + threadPool, + stateManager, + TimeSeriesSettings.NUM_SAMPLES_PER_TREE, + TimeSeriesSettings.NUM_TREES, + numMinSamples, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + searchFeatureDao, + TimeSeriesSettings.THRESHOLD_MIN_PVALUE, + featureManager, + // settings, + TimeSeriesSettings.HOURLY_MAINTENANCE, + checkpointWriteQueue, + rcfSeed, + TimeSeriesSettings.MAX_COLD_START_ROUNDS, + 1 + ); + + ModelState modelState = createStateForCacheRelease(); + long minTime = 1602269260000L; + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + listener.onResponse(Optional.of(minTime)); + return null; + }).when(searchFeatureDao).getMinDataTime(any(), any(), eq(AnalysisType.AD), any()); + + List> coldStartSamples = new ArrayList<>(); + + double[] sample1 = new double[] { 57.0 }; + double[] sample2 = new double[] { 1.0 }; + double[] sample3 = new double[] { -19.0 }; + + coldStartSamples.add(Optional.of(sample1)); + coldStartSamples.add(Optional.of(sample2)); + coldStartSamples.add(Optional.of(sample3)); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(5); + listener.onResponse(coldStartSamples); + return null; + }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), eq(AnalysisType.AD), any()); + + FeatureRequest featureRequest = new FeatureRequest( + Instant.now().toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + new double[] { 1.3 }, + // detectorInterval is of minutes, need to convert to milliseconds + minTime + coldStartSamples.size() * detectorInterval * 60000, + entity, + "123" + ); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); + checkSemaphoreRelease(); + assertTrue(modelState.getModel().isEmpty()); } } diff --git a/src/test/java/org/opensearch/ad/ml/EntityModelTests.java b/src/test/java/org/opensearch/ad/ml/EntityModelTests.java index 1f4afe829..22b3ed5db 100644 --- a/src/test/java/org/opensearch/ad/ml/EntityModelTests.java +++ b/src/test/java/org/opensearch/ad/ml/EntityModelTests.java @@ -11,63 +11,69 @@ package org.opensearch.ad.ml; +import java.time.Clock; +import java.time.Instant; import java.util.ArrayDeque; import org.junit.Before; import org.junit.Test; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; public class EntityModelTests extends OpenSearchTestCase { private ThresholdedRandomCutForest trcf; + private Clock clock; @Before public void setup() { this.trcf = new ThresholdedRandomCutForest(ThresholdedRandomCutForest.builder().dimensions(2).internalShinglingEnabled(true)); + this.clock = Clock.systemUTC(); } public void testNullInternalSampleQueue() { - EntityModel model = new EntityModel(null, null, null); - model.addSample(new double[] { 0.8 }); + ModelState model = new ModelState<>(null, null, null, null, clock, 0, null, null); + model.addSample(new Sample(new double[] { 0.8 }, Instant.now(), Instant.now())); assertEquals(1, model.getSamples().size()); } public void testNullInputSample() { - EntityModel model = new EntityModel(null, null, null); + ModelState model = new ModelState<>(null, null, null, null, clock, 0, null, null); model.addSample(null); assertEquals(0, model.getSamples().size()); } public void testEmptyInputSample() { - EntityModel model = new EntityModel(null, null, null); - model.addSample(new double[] {}); + ModelState model = new ModelState<>(null, null, null, null, clock, 0, null, null); + model.addSample(new Sample(new double[] {}, Instant.now(), Instant.now())); assertEquals(0, model.getSamples().size()); } @Test public void trcf_constructor() { - EntityModel em = new EntityModel(null, new ArrayDeque<>(), trcf); - assertEquals(trcf, em.getTrcf().get()); + ModelState em = new ModelState<>(trcf, null, null, null, clock, 0, null, new ArrayDeque<>()); + assertEquals(trcf, em.getModel().get()); } @Test public void clear() { - EntityModel em = new EntityModel(null, new ArrayDeque<>(), trcf); + ModelState em = new ModelState<>(trcf, null, null, null, clock, 0, null, new ArrayDeque<>()); em.clear(); assertTrue(em.getSamples().isEmpty()); - assertFalse(em.getTrcf().isPresent()); + assertFalse(em.getModel().isPresent()); } @Test public void setTrcf() { - EntityModel em = new EntityModel(null, null, null); - assertFalse(em.getTrcf().isPresent()); + ModelState em = new ModelState<>(null, null, null, null, clock, 0, null, null); + assertFalse(em.getModel().isPresent()); - em.setTrcf(this.trcf); - assertTrue(em.getTrcf().isPresent()); + em.setModel(this.trcf); + assertTrue(em.getModel().isPresent()); } } diff --git a/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java b/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java index bf2732777..a891b08f6 100644 --- a/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java +++ b/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java @@ -20,6 +20,7 @@ import static org.mockito.Mockito.when; import java.time.Clock; +import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.ArrayDeque; import java.util.ArrayList; @@ -34,8 +35,6 @@ import org.apache.lucene.tests.util.TimeUnits; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.ml.ModelManager.ModelType; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; @@ -46,11 +45,18 @@ import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.feature.FeatureManager; import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.ratelimit.FeatureRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; import org.opensearch.timeseries.settings.TimeSeriesSettings; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; import com.carrotsearch.randomizedtesting.annotations.TimeoutSuite; import com.google.common.collect.ImmutableList; @@ -118,10 +124,8 @@ private void averageAccuracyTemplate( searchFeatureDao, imputer, clock, - AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, - AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, - AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, - AnomalyDetectorSettings.MIN_TRAIN_SAMPLES, + TimeSeriesSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, + TimeSeriesSettings.MIN_TRAIN_SAMPLES, AnomalyDetectorSettings.MAX_SHINGLE_PROPORTION_MISSING, AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE, AnomalyDetectorSettings.PREVIEW_SAMPLE_RATE, @@ -131,33 +135,30 @@ private void averageAccuracyTemplate( TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME ); - entityColdStarter = new EntityColdStarter( + entityColdStarter = new ADColdStart( clock, threadPool, stateManager, TimeSeriesSettings.NUM_SAMPLES_PER_TREE, TimeSeriesSettings.NUM_TREES, - TimeSeriesSettings.TIME_DECAY, numMinSamples, AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, - imputer, searchFeatureDao, TimeSeriesSettings.THRESHOLD_MIN_PVALUE, featureManager, - settings, TimeSeriesSettings.HOURLY_MAINTENANCE, checkpointWriteQueue, seed, - TimeSeriesSettings.MAX_COLD_START_ROUNDS + TimeSeriesSettings.MAX_COLD_START_ROUNDS, + 1 ); - modelManager = new ModelManager( - mock(CheckpointDao.class), + modelManager = new ADModelManager( + mock(ADCheckpointDao.class), mock(Clock.class), TimeSeriesSettings.NUM_TREES, TimeSeriesSettings.NUM_SAMPLES_PER_TREE, - TimeSeriesSettings.TIME_DECAY, TimeSeriesSettings.NUM_MIN_SAMPLES, TimeSeriesSettings.THRESHOLD_MIN_PVALUE, AnomalyDetectorSettings.MIN_PREVIEW_SIZE, @@ -218,14 +219,15 @@ public int compare(Entry p1, Entry p2) { }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), eq(AnalysisType.AD), any()); entity = Entity.createSingleAttributeEntity("field", entityName + z); - EntityModel model = new EntityModel(entity, new ArrayDeque<>(), null); - ModelState modelState = new ModelState<>( - model, + ModelState modelState = new ModelState<>( + null, entity.getModelId(detectorId).get(), detector.getId(), - ModelType.ENTITY.getName(), + ModelManager.ModelType.TRCF.getName(), clock, - priority + priority, + Optional.of(entity), + new ArrayDeque<>() ); released = new AtomicBoolean(); @@ -236,10 +238,25 @@ public int compare(Entry p1, Entry p2) { inProgressLatch.countDown(); }); - entityColdStarter.trainModel(entity, detector.getId(), modelState, listener); + long dataStartTimeMs = System.currentTimeMillis(); + entityColdStarter + .trainModel( + new FeatureRequest( + dataStartTimeMs + 60000, + detector.getId(), + RequestPriority.MEDIUM, + new double[] {}, + dataStartTimeMs, + entity, + null + ), + detector.getId(), + modelState, + listener + ); checkSemaphoreRelease(); - assertTrue(model.getTrcf().isPresent()); + assertTrue(modelState.getModel().isPresent()); int tp = 0; int fp = 0; @@ -248,7 +265,7 @@ public int compare(Entry p1, Entry p2) { for (int j = trainTestSplit; j < data.length; j++) { ThresholdingResult result = modelManager - .getAnomalyResultForEntity(data[j], modelState, modelId, entity, detector.getShingleSize()); + .getResult(new Sample(data[j], Instant.now(), Instant.now()), modelState, modelId, Optional.of(entity), detector, null); if (result.getGrade() > 0) { if (changeTimestamps[j] == 0) { fp++; diff --git a/src/test/java/org/opensearch/ad/ml/ModelManagerTests.java b/src/test/java/org/opensearch/ad/ml/ModelManagerTests.java index cbb7b09ba..d9c3ae10b 100644 --- a/src/test/java/org/opensearch/ad/ml/ModelManagerTests.java +++ b/src/test/java/org/opensearch/ad/ml/ModelManagerTests.java @@ -51,11 +51,8 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.opensearch.ad.caching.EntityCache; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.ml.ModelManager.ModelType; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.ratelimit.CheckpointWriteWorker; +import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; @@ -71,7 +68,11 @@ import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.common.exception.ResourceNotFoundException; import org.opensearch.timeseries.dataprocessor.LinearUniformImputer; +import org.opensearch.timeseries.feature.FeatureManager; import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; import org.opensearch.timeseries.ml.SingleStreamModelIdMapper; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.settings.TimeSeriesSettings; @@ -91,7 +92,7 @@ @SuppressWarnings("unchecked") public class ModelManagerTests { - private ModelManager modelManager; + private ADModelManager modelManager; @Mock private AnomalyDetector anomalyDetector; @@ -103,7 +104,7 @@ public class ModelManagerTests { private JvmService jvmService; @Mock - private CheckpointDao checkpointDao; + private ADCheckpointDao checkpointDao; @Mock private Clock clock; @@ -112,16 +113,10 @@ public class ModelManagerTests { private FeatureManager featureManager; @Mock - private EntityColdStarter entityColdStarter; + private ADColdStart entityColdStarter; @Mock - private EntityCache cache; - - @Mock - private ModelState modelState; - - @Mock - private EntityModel entityModel; + private ModelState modelState; @Mock private ThresholdedRandomCutForest trcf; @@ -225,12 +220,11 @@ public void setup() { .build(); modelManager = spy( - new ModelManager( + new ADModelManager( checkpointDao, clock, numTrees, numSamples, - rcfTimeDecay, numMinSamples, thresholdMinPvalue, minPreviewSize, @@ -248,8 +242,7 @@ public void setup() { rcfModelId = "detectorId_model_rcf_1"; thresholdModelId = "detectorId_model_threshold"; - when(this.modelState.getModel()).thenReturn(this.entityModel); - when(this.entityModel.getTrcf()).thenReturn(Optional.of(this.trcf)); + when(this.modelState.getModel()).thenReturn(Optional.of(this.trcf)); when(anomalyDetector.getShingleSize()).thenReturn(shingleSize); } @@ -267,7 +260,7 @@ private Object[] getDetectorIdForModelIdData() { @Test @Parameters(method = "getDetectorIdForModelIdData") public void getDetectorIdForModelId_returnExpectedId(String modelId, String expectedDetectorId) { - assertEquals(expectedDetectorId, SingleStreamModelIdMapper.getDetectorIdForModelId(modelId)); + assertEquals(expectedDetectorId, SingleStreamModelIdMapper.getConfigIdForModelId(modelId)); } private Object[] getDetectorIdForModelIdIllegalArgument() { @@ -277,7 +270,7 @@ private Object[] getDetectorIdForModelIdIllegalArgument() { @Test(expected = IllegalArgumentException.class) @Parameters(method = "getDetectorIdForModelIdIllegalArgument") public void getDetectorIdForModelId_throwIllegalArgument_forInvalidId(String modelId) { - SingleStreamModelIdMapper.getDetectorIdForModelId(modelId); + SingleStreamModelIdMapper.getConfigIdForModelId(modelId); } private Map createDataNodes(int numDataNodes) { @@ -415,12 +408,11 @@ public void getRcfResult_throwToListener_whenHeapLimitExceed() { // use new memoryTracker modelManager = spy( - new ModelManager( + new ADModelManager( checkpointDao, clock, numTrees, numSamples, - rcfTimeDecay, numMinSamples, thresholdMinPvalue, minPreviewSize, @@ -844,7 +836,7 @@ public void getPreviewResults_returnNoAnomalies_forNoAnomalies() { int numPoints = 1000; double[][] points = Stream.generate(() -> new double[] { 0 }).limit(numPoints).toArray(double[][]::new); - List results = modelManager.getPreviewResults(points, shingleSize); + List results = modelManager.getPreviewResults(points, shingleSize, 0.0001); assertEquals(numPoints, results.size()); assertTrue(results.stream().noneMatch(r -> r.getGrade() > 0)); @@ -856,7 +848,7 @@ public void getPreviewResults_returnAnomalies_forLastAnomaly() { double[][] points = Stream.generate(() -> new double[] { 0 }).limit(numPoints).toArray(double[][]::new); points[points.length - 1] = new double[] { 1. }; - List results = modelManager.getPreviewResults(points, shingleSize); + List results = modelManager.getPreviewResults(points, shingleSize, 0.0001); assertEquals(numPoints, results.size()); assertTrue(results.stream().limit(numPoints - 1).noneMatch(r -> r.getGrade() > 0)); @@ -865,37 +857,15 @@ public void getPreviewResults_returnAnomalies_forLastAnomaly() { @Test(expected = IllegalArgumentException.class) public void getPreviewResults_throwIllegalArgument_forInvalidInput() { - modelManager.getPreviewResults(new double[0][0], shingleSize); - } - - @Test - public void processEmptyCheckpoint() { - ModelState modelState = modelManager.processEntityCheckpoint(Optional.empty(), null, "", "", shingleSize); - assertEquals(Instant.MIN, modelState.getLastCheckpointTime()); - } - - @Test - public void processNonEmptyCheckpoint() { - String modelId = "abc"; - String detectorId = "123"; - EntityModel model = MLUtil.createNonEmptyModel(modelId); - Instant checkpointTime = Instant.ofEpochMilli(1000); - ModelState modelState = modelManager - .processEntityCheckpoint( - Optional.of(new SimpleImmutableEntry<>(model, checkpointTime)), - null, - modelId, - detectorId, - shingleSize - ); - assertEquals(checkpointTime, modelState.getLastCheckpointTime()); - assertEquals(model.getSamples().size(), modelState.getModel().getSamples().size()); - assertEquals(now, modelState.getLastUsedTime()); + modelManager.getPreviewResults(new double[0][0], shingleSize, 0.0001); } @Test public void getNullState() { - assertEquals(new ThresholdingResult(0, 0, 0), modelManager.getAnomalyResultForEntity(new double[] {}, null, "", null, shingleSize)); + assertEquals( + new ThresholdingResult(0, 0, 0), + modelManager.getResult(new Sample(new double[] {}, Instant.now(), Instant.now()), null, "", null, anomalyDetector, "") + ); } @Test @@ -909,10 +879,8 @@ public void getEmptyStateFullSamples() { searchFeatureDao, interpolator, clock, - AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, - AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, - AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, - AnomalyDetectorSettings.MIN_TRAIN_SAMPLES, + TimeSeriesSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, + TimeSeriesSettings.MIN_TRAIN_SAMPLES, AnomalyDetectorSettings.MAX_SHINGLE_PROPORTION_MISSING, AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE, AnomalyDetectorSettings.PREVIEW_SAMPLE_RATE, @@ -922,35 +890,32 @@ public void getEmptyStateFullSamples() { TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME ); - CheckpointWriteWorker checkpointWriteQueue = mock(CheckpointWriteWorker.class); + ADCheckpointWriteWorker checkpointWriteQueue = mock(ADCheckpointWriteWorker.class); - entityColdStarter = new EntityColdStarter( + entityColdStarter = new ADColdStart( clock, threadPool, stateManager, TimeSeriesSettings.NUM_SAMPLES_PER_TREE, TimeSeriesSettings.NUM_TREES, - TimeSeriesSettings.TIME_DECAY, numMinSamples, AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, - interpolator, searchFeatureDao, TimeSeriesSettings.THRESHOLD_MIN_PVALUE, featureManager, - settings, TimeSeriesSettings.HOURLY_MAINTENANCE, checkpointWriteQueue, - TimeSeriesSettings.MAX_COLD_START_ROUNDS + TimeSeriesSettings.MAX_COLD_START_ROUNDS, + 1 ); modelManager = spy( - new ModelManager( + new ADModelManager( checkpointDao, clock, numTrees, numSamples, - rcfTimeDecay, numMinSamples, thresholdMinPvalue, minPreviewSize, @@ -964,50 +929,59 @@ public void getEmptyStateFullSamples() { ) ); - ModelState state = MLUtil + ModelState state = MLUtil .randomModelState(new RandomModelStateConfig.Builder().fullModel(false).sampleSize(numMinSamples).build()); - EntityModel model = state.getModel(); - assertTrue(!model.getTrcf().isPresent()); - ThresholdingResult result = modelManager.getAnomalyResultForEntity(new double[] { -1 }, state, "", null, shingleSize); + Optional model = state.getModel(); + assertTrue(model.isEmpty()); + ThresholdingResult result = modelManager + .getResult(new Sample(new double[] { -1 }, Instant.now(), Instant.now()), state, "", Optional.empty(), anomalyDetector, ""); // model outputs scores assertTrue(result.getRcfScore() != 0); // added the sample to score since our model is empty - assertEquals(0, model.getSamples().size()); + assertEquals(0, state.getSamples().size()); } @Test public void getAnomalyResultForEntityNoModel() { - ModelState modelState = new ModelState<>(null, modelId, detectorId, ModelType.ENTITY.getName(), clock, 0); + ModelState modelState = new ModelState<>( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock + ); ThresholdingResult result = modelManager - .getAnomalyResultForEntity( - new double[] { -1 }, + .getResult( + new Sample(new double[] { -1 }, Instant.now(), Instant.now()), modelState, modelId, - Entity.createSingleAttributeEntity("field", "val"), - shingleSize + Optional.of(Entity.createSingleAttributeEntity("field", "val")), + anomalyDetector, + "" ); // model outputs scores assertEquals(new ThresholdingResult(0, 0, 0), result); // added the sample to score since our model is empty - assertEquals(1, modelState.getModel().getSamples().size()); + assertEquals(1, modelState.getSamples().size()); } @Test public void getEmptyStateNotFullSamples() { - ModelState state = MLUtil + ModelState state = MLUtil .randomModelState(new RandomModelStateConfig.Builder().fullModel(false).sampleSize(numMinSamples - 1).build()); assertEquals( new ThresholdingResult(0, 0, 0), - modelManager.getAnomalyResultForEntity(new double[] { -1 }, state, "", null, shingleSize) + modelManager.getResult(new Sample(new double[] { -1 }, Instant.now(), Instant.now()), state, "", null, anomalyDetector, "") ); - assertEquals(numMinSamples, state.getModel().getSamples().size()); + assertEquals(numMinSamples, state.getSamples().size()); } @Test public void scoreSamples() { - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); - modelManager.getAnomalyResultForEntity(new double[] { -1 }, state, "", null, shingleSize); - assertEquals(0, state.getModel().getSamples().size()); + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + modelManager.getResult(new Sample(new double[] { -1 }, Instant.now(), Instant.now()), state, "", null, anomalyDetector, ""); + assertEquals(0, state.getSamples().size()); assertEquals(now, state.getLastUsedTime()); } @@ -1019,7 +993,7 @@ public void getAnomalyResultForEntity_withTrcf() { when(this.trcf.process(this.point, 0)).thenReturn(anomalyDescriptor); ThresholdingResult result = modelManager - .getAnomalyResultForEntity(this.point, this.modelState, this.detectorId, null, this.shingleSize); + .getResult(new Sample(this.point, Instant.now(), Instant.now()), this.modelState, this.detectorId, null, anomalyDetector, ""); assertEquals( new ThresholdingResult( anomalyDescriptor.getAnomalyGrade(), @@ -1042,9 +1016,11 @@ public void score_with_trcf() { when(rcf.getDimensions()).thenReturn(40); when(this.trcf.getForest()).thenReturn(rcf); when(this.trcf.process(this.point, 0)).thenReturn(anomalyDescriptor); - when(this.entityModel.getSamples()).thenReturn(new ArrayDeque<>(Arrays.asList(this.point))); + when(this.modelState.getSamples()) + .thenReturn(new ArrayDeque<>(Arrays.asList(new Sample(this.point, Instant.now(), Instant.now())))); - ThresholdingResult result = modelManager.score(this.point, this.detectorId, this.modelState); + ThresholdingResult result = modelManager + .score(new Sample(this.point, Instant.now(), Instant.now()), this.modelId, this.modelState, anomalyDetector); assertEquals( new ThresholdingResult( anomalyDescriptor.getAnomalyGrade(), @@ -1075,7 +1051,8 @@ public void score_throw() { when(rcf.getDimensions()).thenReturn(40); when(this.trcf.getForest()).thenReturn(rcf); doThrow(new IllegalArgumentException()).when(trcf).process(any(), anyLong()); - when(this.entityModel.getSamples()).thenReturn(new ArrayDeque<>(Arrays.asList(this.point))); - modelManager.score(this.point, this.detectorId, this.modelState); + when(this.modelState.getSamples()) + .thenReturn(new ArrayDeque<>(Arrays.asList(new Sample(this.point, Instant.now(), Instant.now())))); + modelManager.score(new Sample(this.point, Instant.now(), Instant.now()), this.modelId, this.modelState, anomalyDetector); } } diff --git a/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobAction.java b/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobAction.java index a861ec9de..9a58b9c4f 100644 --- a/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobAction.java +++ b/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobAction.java @@ -12,12 +12,12 @@ package org.opensearch.ad.mock.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; import org.opensearch.timeseries.transport.JobResponse; public class MockAnomalyDetectorJobAction extends ActionType { // External Action which used for public facing RestAPIs. - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/mockjobmanagement"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/mockjobmanagement"; public static final MockAnomalyDetectorJobAction INSTANCE = new MockAnomalyDetectorJobAction(); private MockAnomalyDetectorJobAction() { diff --git a/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobTransportActionWithUser.java b/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobTransportActionWithUser.java index 48425a747..3adeead1c 100644 --- a/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobTransportActionWithUser.java +++ b/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobTransportActionWithUser.java @@ -22,9 +22,8 @@ import org.opensearch.ad.ExecuteADResultResponseRecorder; import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.rest.handler.IndexAnomalyDetectorJobActionHandler; +import org.opensearch.ad.rest.handler.ADIndexJobActionHandler; import org.opensearch.ad.task.ADTaskManager; -import org.opensearch.ad.transport.AnomalyDetectorJobRequest; import org.opensearch.ad.transport.AnomalyDetectorJobTransportAction; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -36,12 +35,14 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.tasks.Task; +import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.transport.JobRequest; import org.opensearch.timeseries.transport.JobResponse; import org.opensearch.timeseries.util.RestHandlerUtils; import org.opensearch.transport.TransportService; -public class MockAnomalyDetectorJobTransportActionWithUser extends HandledTransportAction { +public class MockAnomalyDetectorJobTransportActionWithUser extends HandledTransportAction { private final Logger logger = LogManager.getLogger(AnomalyDetectorJobTransportAction.class); private final Client client; @@ -54,6 +55,7 @@ public class MockAnomalyDetectorJobTransportActionWithUser extends HandledTransp private final ADTaskManager adTaskManager; private final TransportService transportService; private final ExecuteADResultResponseRecorder recorder; + private final NodeStateManager nodeStateManager; @Inject public MockAnomalyDetectorJobTransportActionWithUser( @@ -65,9 +67,10 @@ public MockAnomalyDetectorJobTransportActionWithUser( ADIndexManagement anomalyDetectionIndices, NamedXContentRegistry xContentRegistry, ADTaskManager adTaskManager, - ExecuteADResultResponseRecorder recorder + ExecuteADResultResponseRecorder recorder, + NodeStateManager nodeStateManager ) { - super(MockAnomalyDetectorJobAction.NAME, transportService, actionFilters, AnomalyDetectorJobRequest::new); + super(MockAnomalyDetectorJobAction.NAME, transportService, actionFilters, JobRequest::new); this.transportService = transportService; this.client = client; this.clusterService = clusterService; @@ -81,15 +84,14 @@ public MockAnomalyDetectorJobTransportActionWithUser( ThreadContext threadContext = new ThreadContext(settings); context = threadContext.stashContext(); this.recorder = recorder; + this.nodeStateManager = nodeStateManager; } @Override - protected void doExecute(Task task, AnomalyDetectorJobRequest request, ActionListener listener) { - String detectorId = request.getDetectorID(); - DateRange detectionDateRange = request.getDetectionDateRange(); + protected void doExecute(Task task, JobRequest request, ActionListener listener) { + String detectorId = request.getConfigID(); + DateRange detectionDateRange = request.getDateRange(); boolean historical = request.isHistorical(); - long seqNo = request.getSeqNo(); - long primaryTerm = request.getPrimaryTerm(); String rawPath = request.getRawPath(); TimeValue requestTimeout = AD_REQUEST_TIMEOUT.get(settings); String userStr = "user_name|backendrole1,backendrole2|roles1,role2"; @@ -101,17 +103,7 @@ protected void doExecute(Task task, AnomalyDetectorJobRequest request, ActionLis detectorId, filterByEnabled, listener, - (anomalyDetector) -> executeDetector( - listener, - detectorId, - seqNo, - primaryTerm, - rawPath, - requestTimeout, - user, - detectionDateRange, - historical - ), + (anomalyDetector) -> executeDetector(listener, detectorId, rawPath, requestTimeout, user, detectionDateRange, historical), client, clusterService, xContentRegistry, @@ -126,31 +118,26 @@ protected void doExecute(Task task, AnomalyDetectorJobRequest request, ActionLis private void executeDetector( ActionListener listener, String detectorId, - long seqNo, - long primaryTerm, String rawPath, TimeValue requestTimeout, User user, DateRange detectionDateRange, boolean historical ) { - IndexAnomalyDetectorJobActionHandler handler = new IndexAnomalyDetectorJobActionHandler( + ADIndexJobActionHandler handler = new ADIndexJobActionHandler( client, anomalyDetectionIndices, - detectorId, - seqNo, - primaryTerm, - requestTimeout, xContentRegistry, - transportService, adTaskManager, - recorder + recorder, + nodeStateManager, + Settings.EMPTY ); if (rawPath.endsWith(RestHandlerUtils.START_JOB)) { - adTaskManager.startDetector(detectorId, detectionDateRange, handler, user, transportService, context, listener); + handler.startConfig(detectorId, detectionDateRange, user, transportService, context, listener); } else if (rawPath.endsWith(RestHandlerUtils.STOP_JOB)) { // Stop detector - adTaskManager.stopDetector(detectorId, historical, handler, user, transportService, listener); + handler.stopConfig(detectorId, historical, user, transportService, listener); } } } diff --git a/src/test/java/org/opensearch/ad/model/ADEntityTaskProfileTests.java b/src/test/java/org/opensearch/ad/model/ADEntityTaskProfileTests.java index 27456589a..324fc0373 100644 --- a/src/test/java/org/opensearch/ad/model/ADEntityTaskProfileTests.java +++ b/src/test/java/org/opensearch/ad/model/ADEntityTaskProfileTests.java @@ -19,6 +19,7 @@ import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.EntityTaskProfile; public class ADEntityTaskProfileTests extends OpenSearchSingleNodeTestCase { @@ -32,9 +33,9 @@ protected NamedWriteableRegistry writableRegistry() { return getInstanceFromNode(NamedWriteableRegistry.class); } - private ADEntityTaskProfile createADEntityTaskProfile() { + private EntityTaskProfile createADEntityTaskProfile() { Entity entity = createEntityAndAttributes(); - return new ADEntityTaskProfile(1, 23L, false, 1, 2L, "1234", entity, "4321", ADTaskType.HISTORICAL_HC_ENTITY.name()); + return new EntityTaskProfile(1, 23L, false, 1, 2L, "1234", entity, "4321", ADTaskType.HISTORICAL_HC_ENTITY.name()); } private Entity createEntityAndAttributes() { @@ -49,24 +50,24 @@ private Entity createEntityAndAttributes() { } public void testADEntityTaskProfileSerialization() throws IOException { - ADEntityTaskProfile entityTask = createADEntityTaskProfile(); + EntityTaskProfile entityTask = createADEntityTaskProfile(); BytesStreamOutput output = new BytesStreamOutput(); entityTask.writeTo(output); NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); - ADEntityTaskProfile parsedEntityTask = new ADEntityTaskProfile(input); + EntityTaskProfile parsedEntityTask = new EntityTaskProfile(input); assertEquals(entityTask, parsedEntityTask); } public void testParseADEntityTaskProfile() throws IOException { - ADEntityTaskProfile entityTask = createADEntityTaskProfile(); + EntityTaskProfile entityTask = createADEntityTaskProfile(); String adEntityTaskProfileString = TestHelpers .xContentBuilderToString(entityTask.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); - ADEntityTaskProfile parsedEntityTask = ADEntityTaskProfile.parse(TestHelpers.parser(adEntityTaskProfileString)); + EntityTaskProfile parsedEntityTask = EntityTaskProfile.parse(TestHelpers.parser(adEntityTaskProfileString)); assertEquals(entityTask, parsedEntityTask); } public void testParseADEntityTaskProfileWithNullEntity() throws IOException { - ADEntityTaskProfile entityTask = new ADEntityTaskProfile( + EntityTaskProfile entityTask = new EntityTaskProfile( 1, 23L, false, @@ -82,14 +83,14 @@ public void testParseADEntityTaskProfileWithNullEntity() throws IOException { assertNull(entityTask.getEntity()); String adEntityTaskProfileString = TestHelpers .xContentBuilderToString(entityTask.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); - ADEntityTaskProfile parsedEntityTask = ADEntityTaskProfile.parse(TestHelpers.parser(adEntityTaskProfileString)); + EntityTaskProfile parsedEntityTask = EntityTaskProfile.parse(TestHelpers.parser(adEntityTaskProfileString)); assertEquals(entityTask, parsedEntityTask); } public void testADEntityTaskProfileEqual() { - ADEntityTaskProfile entityTaskOne = createADEntityTaskProfile(); - ADEntityTaskProfile entityTaskTwo = createADEntityTaskProfile(); - ADEntityTaskProfile entityTaskThree = new ADEntityTaskProfile( + EntityTaskProfile entityTaskOne = createADEntityTaskProfile(); + EntityTaskProfile entityTaskTwo = createADEntityTaskProfile(); + EntityTaskProfile entityTaskThree = new EntityTaskProfile( null, null, false, @@ -106,7 +107,7 @@ public void testADEntityTaskProfileEqual() { public void testParseADEntityTaskProfileWithMultipleNullFields() throws IOException { Entity entity = createEntityAndAttributes(); - ADEntityTaskProfile entityTask = new ADEntityTaskProfile( + EntityTaskProfile entityTask = new EntityTaskProfile( null, null, false, @@ -119,7 +120,7 @@ public void testParseADEntityTaskProfileWithMultipleNullFields() throws IOExcept ); String adEntityTaskProfileString = TestHelpers .xContentBuilderToString(entityTask.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); - ADEntityTaskProfile parsedEntityTask = ADEntityTaskProfile.parse(TestHelpers.parser(adEntityTaskProfileString)); + EntityTaskProfile parsedEntityTask = EntityTaskProfile.parse(TestHelpers.parser(adEntityTaskProfileString)); assertEquals(entityTask, parsedEntityTask); } } diff --git a/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java b/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java index d3298eae2..a31da61a0 100644 --- a/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java +++ b/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java @@ -11,10 +11,8 @@ package org.opensearch.ad.model; -import static org.opensearch.ad.constant.ADCommonMessages.INVALID_RESULT_INDEX_PREFIX; import static org.opensearch.ad.constant.ADCommonName.CUSTOM_RESULT_INDEX_PREFIX; -import static org.opensearch.ad.model.AnomalyDetector.MAX_RESULT_INDEX_NAME_SIZE; -import static org.opensearch.timeseries.constant.CommonMessages.INVALID_CHAR_IN_RESULT_INDEX_NAME; +import static org.opensearch.timeseries.model.Config.MAX_RESULT_INDEX_NAME_SIZE; import java.io.IOException; import java.time.Instant; @@ -30,6 +28,9 @@ import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Feature; import org.opensearch.timeseries.model.IntervalTimeConfiguration; import org.opensearch.timeseries.settings.TimeSeriesSettings; @@ -295,6 +296,7 @@ public void testParseAnomalyDetectorWithEmptyUiMetadata() throws IOException { } public void testInvalidShingleSize() throws Exception { + Feature feature = TestHelpers.randomFeature(); TestHelpers .assertFailWith( ValidationException.class, @@ -305,7 +307,7 @@ public void testInvalidShingleSize() throws Exception { randomAlphaOfLength(5), randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(5)), - ImmutableList.of(TestHelpers.randomFeature()), + ImmutableList.of(feature), TestHelpers.randomQuery(), TestHelpers.randomIntervalTimeConfiguration(), TestHelpers.randomIntervalTimeConfiguration(), @@ -316,12 +318,17 @@ public void testInvalidShingleSize() throws Exception { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ) ); } public void testNullDetectorName() throws Exception { + Feature feature = TestHelpers.randomFeature(); TestHelpers .assertFailWith( ValidationException.class, @@ -332,7 +339,7 @@ public void testNullDetectorName() throws Exception { randomAlphaOfLength(5), randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(5)), - ImmutableList.of(TestHelpers.randomFeature()), + ImmutableList.of(feature), TestHelpers.randomQuery(), TestHelpers.randomIntervalTimeConfiguration(), TestHelpers.randomIntervalTimeConfiguration(), @@ -343,12 +350,17 @@ public void testNullDetectorName() throws Exception { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ) ); } public void testBlankDetectorName() throws Exception { + Feature feature = TestHelpers.randomFeature(); TestHelpers .assertFailWith( ValidationException.class, @@ -359,7 +371,7 @@ public void testBlankDetectorName() throws Exception { randomAlphaOfLength(5), randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(5)), - ImmutableList.of(TestHelpers.randomFeature()), + ImmutableList.of(feature), TestHelpers.randomQuery(), TestHelpers.randomIntervalTimeConfiguration(), TestHelpers.randomIntervalTimeConfiguration(), @@ -370,12 +382,17 @@ public void testBlankDetectorName() throws Exception { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ) ); } public void testNullTimeField() throws Exception { + Feature feature = TestHelpers.randomFeature(); TestHelpers .assertFailWith( ValidationException.class, @@ -386,7 +403,7 @@ public void testNullTimeField() throws Exception { randomAlphaOfLength(5), null, ImmutableList.of(randomAlphaOfLength(5)), - ImmutableList.of(TestHelpers.randomFeature()), + ImmutableList.of(feature), TestHelpers.randomQuery(), TestHelpers.randomIntervalTimeConfiguration(), TestHelpers.randomIntervalTimeConfiguration(), @@ -397,12 +414,17 @@ public void testNullTimeField() throws Exception { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ) ); } public void testNullIndices() throws Exception { + Feature feature = TestHelpers.randomFeature(); TestHelpers .assertFailWith( ValidationException.class, @@ -413,7 +435,7 @@ public void testNullIndices() throws Exception { randomAlphaOfLength(5), randomAlphaOfLength(5), null, - ImmutableList.of(TestHelpers.randomFeature()), + ImmutableList.of(feature), TestHelpers.randomQuery(), TestHelpers.randomIntervalTimeConfiguration(), TestHelpers.randomIntervalTimeConfiguration(), @@ -424,12 +446,17 @@ public void testNullIndices() throws Exception { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ) ); } public void testEmptyIndices() throws Exception { + Feature feature = TestHelpers.randomFeature(); TestHelpers .assertFailWith( ValidationException.class, @@ -440,7 +467,7 @@ public void testEmptyIndices() throws Exception { randomAlphaOfLength(5), randomAlphaOfLength(5), ImmutableList.of(), - ImmutableList.of(TestHelpers.randomFeature()), + ImmutableList.of(feature), TestHelpers.randomQuery(), TestHelpers.randomIntervalTimeConfiguration(), TestHelpers.randomIntervalTimeConfiguration(), @@ -451,12 +478,17 @@ public void testEmptyIndices() throws Exception { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ) ); } public void testNullDetectionInterval() throws Exception { + Feature feature = TestHelpers.randomFeature(); TestHelpers .assertFailWith( ValidationException.class, @@ -467,7 +499,7 @@ public void testNullDetectionInterval() throws Exception { randomAlphaOfLength(5), randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(5)), - ImmutableList.of(TestHelpers.randomFeature()), + ImmutableList.of(feature), TestHelpers.randomQuery(), null, TestHelpers.randomIntervalTimeConfiguration(), @@ -478,12 +510,17 @@ public void testNullDetectionInterval() throws Exception { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ) ); } public void testInvalidDetectionInterval() { + Feature feature = TestHelpers.randomFeature(); ValidationException exception = expectThrows( ValidationException.class, () -> new AnomalyDetector( @@ -493,7 +530,7 @@ public void testInvalidDetectionInterval() { randomAlphaOfLength(30), randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(10).toLowerCase(Locale.ROOT)), - ImmutableList.of(TestHelpers.randomFeature()), + ImmutableList.of(feature), TestHelpers.randomQuery(), new IntervalTimeConfiguration(0, ChronoUnit.MINUTES), TestHelpers.randomIntervalTimeConfiguration(), @@ -504,13 +541,18 @@ public void testInvalidDetectionInterval() { null, null, null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + randomInt(), + randomInt(), + randomIntBetween(1, 1000), + null ) ); assertEquals("Detection interval must be a positive integer", exception.getMessage()); } public void testInvalidWindowDelay() { + Feature feature = TestHelpers.randomFeature(); IllegalArgumentException exception = expectThrows( IllegalArgumentException.class, () -> new AnomalyDetector( @@ -520,7 +562,7 @@ public void testInvalidWindowDelay() { randomAlphaOfLength(30), randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(10).toLowerCase(Locale.ROOT)), - ImmutableList.of(TestHelpers.randomFeature()), + ImmutableList.of(feature), TestHelpers.randomQuery(), new IntervalTimeConfiguration(1, ChronoUnit.MINUTES), new IntervalTimeConfiguration(-1, ChronoUnit.MINUTES), @@ -531,7 +573,11 @@ public void testInvalidWindowDelay() { null, null, null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + randomInt(), + randomInt(), + randomIntBetween(1, 1000), + null ) ); assertEquals("Interval -1 should be non-negative", exception.getMessage()); @@ -553,14 +599,15 @@ public void testEmptyFeatures() throws IOException { } public void testGetShingleSize() throws IOException { - AnomalyDetector anomalyDetector = new AnomalyDetector( + Feature feature = TestHelpers.randomFeature(); + Config anomalyDetector = new AnomalyDetector( randomAlphaOfLength(5), randomLong(), randomAlphaOfLength(5), randomAlphaOfLength(5), randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(5)), - ImmutableList.of(TestHelpers.randomFeature()), + ImmutableList.of(feature), TestHelpers.randomQuery(), TestHelpers.randomIntervalTimeConfiguration(), TestHelpers.randomIntervalTimeConfiguration(), @@ -571,20 +618,25 @@ public void testGetShingleSize() throws IOException { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ); assertEquals((int) anomalyDetector.getShingleSize(), 5); } public void testGetShingleSizeReturnsDefaultValue() throws IOException { - AnomalyDetector anomalyDetector = new AnomalyDetector( + Feature feature = TestHelpers.randomFeature(); + Config anomalyDetector = new AnomalyDetector( randomAlphaOfLength(5), randomLong(), randomAlphaOfLength(5), randomAlphaOfLength(5), randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(5)), - ImmutableList.of(TestHelpers.randomFeature()), + ImmutableList.of(feature), TestHelpers.randomQuery(), TestHelpers.randomIntervalTimeConfiguration(), TestHelpers.randomIntervalTimeConfiguration(), @@ -595,13 +647,17 @@ public void testGetShingleSizeReturnsDefaultValue() throws IOException { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ); assertEquals((int) anomalyDetector.getShingleSize(), TimeSeriesSettings.DEFAULT_SHINGLE_SIZE); } public void testNullFeatureAttributes() throws IOException { - AnomalyDetector anomalyDetector = new AnomalyDetector( + Config anomalyDetector = new AnomalyDetector( randomAlphaOfLength(5), randomLong(), randomAlphaOfLength(5), @@ -619,21 +675,25 @@ public void testNullFeatureAttributes() throws IOException { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(0), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ); assertNotNull(anomalyDetector.getFeatureAttributes()); assertEquals(0, anomalyDetector.getFeatureAttributes().size()); } public void testValidateResultIndex() throws IOException { - AnomalyDetector anomalyDetector = new AnomalyDetector( + Config anomalyDetector = new AnomalyDetector( randomAlphaOfLength(5), randomLong(), randomAlphaOfLength(5), randomAlphaOfLength(5), randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(5)), - ImmutableList.of(TestHelpers.randomFeature()), + null, TestHelpers.randomQuery(), TestHelpers.randomIntervalTimeConfiguration(), TestHelpers.randomIntervalTimeConfiguration(), @@ -644,11 +704,14 @@ public void testValidateResultIndex() throws IOException { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(0), + randomInt(), + randomInt(), + randomIntBetween(1, 1000), + null ); - String errorMessage = anomalyDetector.validateCustomResultIndex("abc"); - assertEquals(INVALID_RESULT_INDEX_PREFIX, errorMessage); + assertEquals(ADCommonMessages.INVALID_RESULT_INDEX_PREFIX, errorMessage); StringBuilder resultIndexNameBuilder = new StringBuilder(CUSTOM_RESULT_INDEX_PREFIX); for (int i = 0; i < MAX_RESULT_INDEX_NAME_SIZE - CUSTOM_RESULT_INDEX_PREFIX.length(); i++) { @@ -658,10 +721,10 @@ public void testValidateResultIndex() throws IOException { resultIndexNameBuilder.append("a"); errorMessage = anomalyDetector.validateCustomResultIndex(resultIndexNameBuilder.toString()); - assertEquals(AnomalyDetector.INVALID_RESULT_INDEX_NAME_SIZE, errorMessage); + assertEquals(Config.INVALID_RESULT_INDEX_NAME_SIZE, errorMessage); errorMessage = anomalyDetector.validateCustomResultIndex(CUSTOM_RESULT_INDEX_PREFIX + "abc#"); - assertEquals(INVALID_CHAR_IN_RESULT_INDEX_NAME, errorMessage); + assertEquals(CommonMessages.INVALID_CHAR_IN_RESULT_INDEX_NAME, errorMessage); } public void testParseAnomalyDetectorWithNoDescription() throws IOException { diff --git a/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java b/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java index 424de19da..28245aa31 100644 --- a/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java +++ b/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java @@ -11,8 +11,6 @@ package org.opensearch.ad.model; -import static org.opensearch.test.OpenSearchTestCase.randomDouble; - import java.io.IOException; import java.util.Collection; import java.util.Locale; diff --git a/src/test/java/org/opensearch/ad/model/DetectorProfileTests.java b/src/test/java/org/opensearch/ad/model/DetectorProfileTests.java index 9960a5fe2..07c7410b4 100644 --- a/src/test/java/org/opensearch/ad/model/DetectorProfileTests.java +++ b/src/test/java/org/opensearch/ad/model/DetectorProfileTests.java @@ -21,13 +21,19 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.ConfigProfile; +import org.opensearch.timeseries.model.ConfigState; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.ModelProfile; +import org.opensearch.timeseries.model.ModelProfileOnNode; +import org.opensearch.timeseries.model.ProfileName; public class DetectorProfileTests extends OpenSearchTestCase { - private DetectorProfile createRandomDetectorProfile() { + private ConfigProfile createRandomDetectorProfile() { return new DetectorProfile.Builder() - .state(DetectorState.INIT) + .state(ConfigState.INIT) .error(randomAlphaOfLength(5)) .modelProfile( new ModelProfileOnNode[] { @@ -45,7 +51,7 @@ private DetectorProfile createRandomDetectorProfile() { .totalSizeInBytes(-1) .totalEntities(randomLong()) .activeEntities(randomLong()) - .adTaskProfile( + .taskProfile( new ADTaskProfile( randomAlphaOfLength(5), randomInt(), @@ -60,17 +66,17 @@ private DetectorProfile createRandomDetectorProfile() { } public void testParseDetectorProfile() throws IOException { - DetectorProfile detectorProfile = createRandomDetectorProfile(); + ConfigProfile detectorProfile = createRandomDetectorProfile(); BytesStreamOutput output = new BytesStreamOutput(); detectorProfile.writeTo(output); NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); - DetectorProfile parsedDetectorProfile = new DetectorProfile(input); + ConfigProfile parsedDetectorProfile = new DetectorProfile(input); assertEquals("Detector profile serialization doesn't work", detectorProfile, parsedDetectorProfile); } public void testMergeDetectorProfile() { - DetectorProfile detectorProfileOne = createRandomDetectorProfile(); - DetectorProfile detectorProfileTwo = createRandomDetectorProfile(); + ConfigProfile detectorProfileOne = createRandomDetectorProfile(); + ConfigProfile detectorProfileTwo = createRandomDetectorProfile(); String errorPreMerge = detectorProfileOne.getError(); detectorProfileOne.merge(detectorProfileTwo); assertTrue(detectorProfileOne.toString().contains(detectorProfileTwo.getError())); @@ -79,7 +85,7 @@ public void testMergeDetectorProfile() { } public void testDetectorProfileToXContent() throws IOException { - DetectorProfile detectorProfile = createRandomDetectorProfile(); + ConfigProfile detectorProfile = createRandomDetectorProfile(); String detectorProfileString = TestHelpers.xContentBuilderToString(detectorProfile.toXContent(TestHelpers.builder())); XContentParser parser = TestHelpers.parser(detectorProfileString); Map parsedMap = parser.map(); @@ -89,22 +95,22 @@ public void testDetectorProfileToXContent() throws IOException { } public void testDetectorProfileName() throws IllegalArgumentException { - assertEquals("ad_task", DetectorProfileName.getName(ADCommonName.AD_TASK).getName()); - assertEquals("state", DetectorProfileName.getName(ADCommonName.STATE).getName()); - assertEquals("error", DetectorProfileName.getName(ADCommonName.ERROR).getName()); - assertEquals("coordinating_node", DetectorProfileName.getName(ADCommonName.COORDINATING_NODE).getName()); - assertEquals("shingle_size", DetectorProfileName.getName(ADCommonName.SHINGLE_SIZE).getName()); - assertEquals("total_size_in_bytes", DetectorProfileName.getName(ADCommonName.TOTAL_SIZE_IN_BYTES).getName()); - assertEquals("models", DetectorProfileName.getName(ADCommonName.MODELS).getName()); - assertEquals("init_progress", DetectorProfileName.getName(ADCommonName.INIT_PROGRESS).getName()); - assertEquals("total_entities", DetectorProfileName.getName(ADCommonName.TOTAL_ENTITIES).getName()); - assertEquals("active_entities", DetectorProfileName.getName(ADCommonName.ACTIVE_ENTITIES).getName()); - IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> DetectorProfileName.getName("abc")); + assertEquals("ad_task", ProfileName.getName(ADCommonName.AD_TASK).getName()); + assertEquals("state", ProfileName.getName(CommonName.STATE).getName()); + assertEquals("error", ProfileName.getName(CommonName.ERROR).getName()); + assertEquals("coordinating_node", ProfileName.getName(CommonName.COORDINATING_NODE).getName()); + assertEquals("shingle_size", ProfileName.getName(CommonName.SHINGLE_SIZE).getName()); + assertEquals("total_size_in_bytes", ProfileName.getName(CommonName.TOTAL_SIZE_IN_BYTES).getName()); + assertEquals("models", ProfileName.getName(CommonName.MODELS).getName()); + assertEquals("init_progress", ProfileName.getName(CommonName.INIT_PROGRESS).getName()); + assertEquals("total_entities", ProfileName.getName(CommonName.TOTAL_ENTITIES).getName()); + assertEquals("active_entities", ProfileName.getName(CommonName.ACTIVE_ENTITIES).getName()); + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> ProfileName.getName("abc")); assertEquals(exception.getMessage(), ADCommonMessages.UNSUPPORTED_PROFILE_TYPE); } public void testDetectorProfileSet() throws IllegalArgumentException { - DetectorProfile detectorProfileOne = createRandomDetectorProfile(); + ConfigProfile detectorProfileOne = createRandomDetectorProfile(); detectorProfileOne.setShingleSize(20); assertEquals(20, detectorProfileOne.getShingleSize()); detectorProfileOne.setActiveEntities(10L); diff --git a/src/test/java/org/opensearch/ad/model/EntityAnomalyResultTests.java b/src/test/java/org/opensearch/ad/model/EntityAnomalyResultTests.java index 24cb0c879..18be64d54 100644 --- a/src/test/java/org/opensearch/ad/model/EntityAnomalyResultTests.java +++ b/src/test/java/org/opensearch/ad/model/EntityAnomalyResultTests.java @@ -18,8 +18,8 @@ import java.util.List; import org.junit.Test; -import org.opensearch.ad.stats.ADStatsResponse; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.transport.StatsResponse; public class EntityAnomalyResultTests extends OpenSearchTestCase { @@ -90,7 +90,7 @@ public void testMerge_self() { @Test public void testMerge_otherClass() { - ADStatsResponse adStatsResponse = new ADStatsResponse(); + StatsResponse adStatsResponse = new StatsResponse(); AnomalyResult anomalyResult = randomHCADAnomalyDetectResult(0.25, 0.25, "error"); EntityAnomalyResult entityAnomalyResult = new EntityAnomalyResult(new ArrayList() { diff --git a/src/test/java/org/opensearch/ad/model/EntityProfileTests.java b/src/test/java/org/opensearch/ad/model/EntityProfileTests.java index 18e179145..f647da3ac 100644 --- a/src/test/java/org/opensearch/ad/model/EntityProfileTests.java +++ b/src/test/java/org/opensearch/ad/model/EntityProfileTests.java @@ -16,10 +16,12 @@ import java.io.IOException; import org.opensearch.ad.common.exception.JsonPathNotFoundException; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.EntityProfile; +import org.opensearch.timeseries.model.EntityState; import test.org.opensearch.ad.util.JsonDeserializer; @@ -39,7 +41,7 @@ public void testToXContent() throws IOException, JsonPathNotFoundException { profile1.toXContent(builder, ToXContent.EMPTY_PARAMS); String json = builder.toString(); - assertEquals("INIT", JsonDeserializer.getTextValue(json, ADCommonName.STATE)); + assertEquals("INIT", JsonDeserializer.getTextValue(json, CommonName.STATE)); EntityProfile profile2 = new EntityProfile(null, -1, -1, null, null, EntityState.UNKNOWN); @@ -47,7 +49,7 @@ public void testToXContent() throws IOException, JsonPathNotFoundException { profile2.toXContent(builder, ToXContent.EMPTY_PARAMS); json = builder.toString(); - assertTrue(false == JsonDeserializer.hasChildNode(json, ADCommonName.STATE)); + assertTrue(false == JsonDeserializer.hasChildNode(json, CommonName.STATE)); } public void testToXContentTimeStampAboveZero() throws IOException, JsonPathNotFoundException { @@ -57,7 +59,7 @@ public void testToXContentTimeStampAboveZero() throws IOException, JsonPathNotFo profile1.toXContent(builder, ToXContent.EMPTY_PARAMS); String json = builder.toString(); - assertEquals("INIT", JsonDeserializer.getTextValue(json, ADCommonName.STATE)); + assertEquals("INIT", JsonDeserializer.getTextValue(json, CommonName.STATE)); EntityProfile profile2 = new EntityProfile(null, 1, 1, null, null, EntityState.UNKNOWN); @@ -65,6 +67,6 @@ public void testToXContentTimeStampAboveZero() throws IOException, JsonPathNotFo profile2.toXContent(builder, ToXContent.EMPTY_PARAMS); json = builder.toString(); - assertTrue(false == JsonDeserializer.hasChildNode(json, ADCommonName.STATE)); + assertTrue(false == JsonDeserializer.hasChildNode(json, CommonName.STATE)); } } diff --git a/src/test/java/org/opensearch/ad/model/ModelProfileTests.java b/src/test/java/org/opensearch/ad/model/ModelProfileTests.java index c99ff6222..b5c9c852a 100644 --- a/src/test/java/org/opensearch/ad/model/ModelProfileTests.java +++ b/src/test/java/org/opensearch/ad/model/ModelProfileTests.java @@ -20,6 +20,7 @@ import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.ModelProfile; import test.org.opensearch.ad.util.JsonDeserializer; diff --git a/src/test/java/org/opensearch/ad/ratelimit/CheckPointMaintainRequestAdapterTests.java b/src/test/java/org/opensearch/ad/ratelimit/CheckPointMaintainRequestAdapterTests.java index 830ac3f65..eb9d51da7 100644 --- a/src/test/java/org/opensearch/ad/ratelimit/CheckPointMaintainRequestAdapterTests.java +++ b/src/test/java/org/opensearch/ad/ratelimit/CheckPointMaintainRequestAdapterTests.java @@ -27,40 +27,45 @@ import java.util.Optional; import org.opensearch.action.update.UpdateRequest; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.ml.ADCheckpointDao; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ratelimit.CheckPointMaintainRequestAdapter; +import org.opensearch.timeseries.ratelimit.CheckpointMaintainRequest; +import org.opensearch.timeseries.ratelimit.CheckpointWriteRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; import test.org.opensearch.ad.util.MLUtil; import test.org.opensearch.ad.util.RandomModelStateConfig; public class CheckPointMaintainRequestAdapterTests extends AbstractRateLimitingTest { - private CacheProvider cache; - private CheckpointDao checkpointDao; + private ADCacheProvider cache; + private ADCheckpointDao checkpointDao; private String indexName; private Setting checkpointInterval; private CheckPointMaintainRequestAdapter adapter; - private ModelState state; + private ModelState state; private CheckpointMaintainRequest request; private ClusterService clusterService; @Override public void setUp() throws Exception { super.setUp(); - cache = mock(CacheProvider.class); - checkpointDao = mock(CheckpointDao.class); + cache = mock(ADCacheProvider.class); + checkpointDao = mock(ADCheckpointDao.class); indexName = ADCommonName.CHECKPOINT_INDEX_NAME; checkpointInterval = AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ; - EntityCache entityCache = mock(EntityCache.class); + ADPriorityCache entityCache = mock(ADPriorityCache.class); when(cache.get()).thenReturn(entityCache); state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); when(entityCache.getForMaintainance(anyString(), anyString())).thenReturn(Optional.of(state)); @@ -71,13 +76,13 @@ public void setUp() throws Exception { ); when(clusterService.getClusterSettings()).thenReturn(settings); adapter = new CheckPointMaintainRequestAdapter( - cache, checkpointDao, indexName, checkpointInterval, clock, clusterService, - Settings.EMPTY + Settings.EMPTY, + cache ); request = new CheckpointMaintainRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity.getModelId(detectorId).get()); diff --git a/src/test/java/org/opensearch/ad/ratelimit/CheckpointMaintainWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/CheckpointMaintainWorkerTests.java index 0d05259fc..04913cc6c 100644 --- a/src/test/java/org/opensearch/ad/ratelimit/CheckpointMaintainWorkerTests.java +++ b/src/test/java/org/opensearch/ad/ratelimit/CheckpointMaintainWorkerTests.java @@ -32,12 +32,12 @@ import java.util.Optional; import java.util.Random; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ADCheckpointDao; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; @@ -45,19 +45,25 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ratelimit.CheckPointMaintainRequestAdapter; +import org.opensearch.timeseries.ratelimit.CheckpointMaintainRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; import org.opensearch.timeseries.settings.TimeSeriesSettings; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + import test.org.opensearch.ad.util.MLUtil; import test.org.opensearch.ad.util.RandomModelStateConfig; public class CheckpointMaintainWorkerTests extends AbstractRateLimitingTest { ClusterService clusterService; - CheckpointMaintainWorker cpMaintainWorker; - CheckpointWriteWorker writeWorker; + ADCheckpointMaintainWorker cpMaintainWorker; + ADCheckpointWriteWorker writeWorker; CheckpointMaintainRequest request; CheckpointMaintainRequest request2; List requests; - CheckpointDao checkpointDao; + ADCheckpointDao checkpointDao; @Override public void setUp() throws Exception { @@ -81,30 +87,34 @@ public void setUp() throws Exception { ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - writeWorker = mock(CheckpointWriteWorker.class); + writeWorker = mock(ADCheckpointWriteWorker.class); + + ADCacheProvider adCacheProvider = new ADCacheProvider(); - CacheProvider cache = mock(CacheProvider.class); - checkpointDao = mock(CheckpointDao.class); + ADPriorityCache cache = mock(ADPriorityCache.class); + checkpointDao = mock(ADCheckpointDao.class); String indexName = ADCommonName.CHECKPOINT_INDEX_NAME; Setting checkpointInterval = AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ; - EntityCache entityCache = mock(EntityCache.class); - when(cache.get()).thenReturn(entityCache); - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); - when(entityCache.getForMaintainance(anyString(), anyString())).thenReturn(Optional.of(state)); - CheckPointMaintainRequestAdapter adapter = new CheckPointMaintainRequestAdapter( - cache, - checkpointDao, - indexName, - checkpointInterval, - clock, - clusterService, - settings - ); + + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + when(cache.getForMaintainance(anyString(), anyString())).thenReturn(Optional.of(state)); + adCacheProvider.set(cache); + CheckPointMaintainRequestAdapter adapter = + new CheckPointMaintainRequestAdapter<>( + checkpointDao, + indexName, + checkpointInterval, + clock, + clusterService, + settings, + adCacheProvider + ); // Integer.MAX_VALUE makes a huge heap - cpMaintainWorker = new CheckpointMaintainWorker( + cpMaintainWorker = new ADCheckpointMaintainWorker( Integer.MAX_VALUE, - AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, + TimeSeriesSettings.CHECKPOINT_MAINTAIN_REQUEST_SIZE_IN_BYTES, AnomalyDetectorSettings.AD_CHECKPOINT_MAINTAIN_QUEUE_MAX_HEAP_PERCENT, clusterService, new Random(42), @@ -119,7 +129,7 @@ public void setUp() throws Exception { writeWorker, TimeSeriesSettings.HOURLY_MAINTENANCE, nodeStateManager, - adapter + adapter::convert ); request = new CheckpointMaintainRequest(Integer.MAX_VALUE, detectorId, RequestPriority.LOW, entity.getModelId(detectorId).get()); diff --git a/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java index 41b8035b0..ff9690ba9 100644 --- a/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java +++ b/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java @@ -11,10 +11,8 @@ package org.opensearch.ad.ratelimit; -import static java.util.AbstractMap.SimpleImmutableEntry; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; -import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; @@ -45,21 +43,18 @@ import org.opensearch.action.get.GetResponse; import org.opensearch.action.get.MultiGetItemResponse; import org.opensearch.action.get.MultiGetResponse; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.ml.ThresholdingResult; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.stats.ADStat; import org.opensearch.ad.stats.ADStats; -import org.opensearch.ad.stats.suppliers.CounterSupplier; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Provider; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; @@ -75,31 +70,37 @@ import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.breaker.CircuitBreakerService; import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.ml.ModelState; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.ratelimit.FeatureRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.TimeSeriesStat; +import org.opensearch.timeseries.stats.suppliers.CounterSupplier; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; import com.fasterxml.jackson.core.JsonParseException; import test.org.opensearch.ad.util.MLUtil; import test.org.opensearch.ad.util.RandomModelStateConfig; public class CheckpointReadWorkerTests extends AbstractRateLimitingTest { - CheckpointReadWorker worker; + ADCheckpointReadWorker worker; - CheckpointDao checkpoint; + ADCheckpointDao checkpoint; ClusterService clusterService; - ModelState state; + ModelState state; - CheckpointWriteWorker checkpointWriteQueue; - ModelManager modelManager; - EntityColdStartWorker coldstartQueue; - ResultWriteWorker resultWriteQueue; + ADCheckpointWriteWorker checkpointWriteQueue; + ADModelManager modelManager; + ADColdStartWorker coldstartQueue; + ADSaveResultStrategy resultWriteStrategy; ADIndexManagement anomalyDetectionIndices; - CacheProvider cacheProvider; - EntityCache entityCache; - EntityFeatureRequest request, request2, request3; + Provider cacheProvider; + ADPriorityCache entityCache; + FeatureRequest request, request2, request3; ClusterSettings clusterSettings; ADStats adStats; @@ -125,38 +126,36 @@ public void setUp() throws Exception { state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); - checkpoint = mock(CheckpointDao.class); + checkpoint = mock(ADCheckpointDao.class); - Map.Entry entry = new SimpleImmutableEntry(state.getModel(), Instant.now()); - when(checkpoint.processGetResponse(any(), anyString())).thenReturn(Optional.of(entry)); + when(checkpoint.processHCGetResponse(any(), anyString(), anyString())).thenReturn(state); - checkpointWriteQueue = mock(CheckpointWriteWorker.class); + checkpointWriteQueue = mock(ADCheckpointWriteWorker.class); - modelManager = mock(ModelManager.class); - when(modelManager.processEntityCheckpoint(any(), any(), anyString(), anyString(), anyInt())).thenReturn(state); - when(modelManager.score(any(), anyString(), any())).thenReturn(new ThresholdingResult(0, 1, 0.7)); + modelManager = mock(ADModelManager.class); + when(modelManager.getResult(any(), any(), anyString(), any(), any(), anyString())).thenReturn(new ThresholdingResult(0, 1, 0.7)); - coldstartQueue = mock(EntityColdStartWorker.class); - resultWriteQueue = mock(ResultWriteWorker.class); + coldstartQueue = mock(ADColdStartWorker.class); + resultWriteStrategy = mock(ADSaveResultStrategy.class); anomalyDetectionIndices = mock(ADIndexManagement.class); - cacheProvider = mock(CacheProvider.class); - entityCache = mock(EntityCache.class); + cacheProvider = new ADCacheProvider(); + entityCache = mock(ADPriorityCache.class); when(cacheProvider.get()).thenReturn(entityCache); when(entityCache.hostIfPossible(any(), any())).thenReturn(true); - Map> statsMap = new HashMap>() { + Map> statsMap = new HashMap>() { { - put(StatNames.MODEL_CORRUTPION_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); + put(StatNames.AD_MODEL_CORRUTPION_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())); } }; adStats = new ADStats(statsMap); // Integer.MAX_VALUE makes a huge heap - worker = new CheckpointReadWorker( + worker = new ADCheckpointReadWorker( Integer.MAX_VALUE, - AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, + TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES, AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, clusterService, new Random(42), @@ -172,18 +171,18 @@ public void setUp() throws Exception { modelManager, checkpoint, coldstartQueue, - resultWriteQueue, nodeStateManager, anomalyDetectionIndices, cacheProvider, TimeSeriesSettings.HOURLY_MAINTENANCE, checkpointWriteQueue, - adStats + adStats, + resultWriteStrategy ); - request = new EntityFeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity, new double[] { 0 }, 0); - request2 = new EntityFeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity2, new double[] { 0 }, 0); - request3 = new EntityFeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity3, new double[] { 0 }, 0); + request = new FeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, new double[] { 0 }, 0, entity, null); + request2 = new FeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, new double[] { 0 }, 0, entity2, null); + request3 = new FeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, new double[] { 0 }, 0, entity3, null); } static class RegularSetUpConfig { @@ -232,16 +231,15 @@ private void regularTestSetUp(RegularSetUpConfig config) { when(entityCache.hostIfPossible(any(), any())).thenReturn(config.canHostModel); state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(config.fullModel).build()); - when(modelManager.processEntityCheckpoint(any(), any(), anyString(), anyString(), anyInt())).thenReturn(state); + when(checkpoint.processHCGetResponse(any(), anyString(), anyString())).thenReturn(state); + if (config.fullModel) { - when(modelManager.getAnomalyResultForEntity(any(), any(), anyString(), any(), anyInt())) - .thenReturn(new ThresholdingResult(0, 1, 1)); + when(modelManager.getResult(any(), any(), anyString(), any(), any(), anyString())).thenReturn(new ThresholdingResult(0, 1, 1)); } else { - when(modelManager.getAnomalyResultForEntity(any(), any(), anyString(), any(), anyInt())) - .thenReturn(new ThresholdingResult(0, 0, 0)); + when(modelManager.getResult(any(), any(), anyString(), any(), any(), anyString())).thenReturn(new ThresholdingResult(0, 0, 0)); } - List requests = new ArrayList<>(); + List requests = new ArrayList<>(); requests.add(request); worker.putAll(requests); } @@ -249,20 +247,20 @@ private void regularTestSetUp(RegularSetUpConfig config) { public void testRegular() { regularTestSetUp(new RegularSetUpConfig.Builder().build()); - verify(resultWriteQueue, times(1)).put(any()); + verify(resultWriteStrategy, times(1)).saveResult(any(), any(), any(), anyString()); verify(checkpointWriteQueue, never()).write(any(), anyBoolean(), any()); } public void testCannotLoadModel() { regularTestSetUp(new RegularSetUpConfig.Builder().canHostModel(false).build()); - verify(resultWriteQueue, times(1)).put(any()); + verify(resultWriteStrategy, times(1)).saveResult(any(), any(), any(), anyString()); verify(checkpointWriteQueue, times(1)).write(any(), anyBoolean(), any()); } public void testNoFullModel() { regularTestSetUp(new RegularSetUpConfig.Builder().fullModel(false).build()); - verify(resultWriteQueue, never()).put(any()); + verify(resultWriteStrategy, never()).saveResult(any(), any(), any(), anyString()); verify(checkpointWriteQueue, never()).write(any(), anyBoolean(), any()); } @@ -327,7 +325,7 @@ public void testAllDocNotFound() { return null; }).when(checkpoint).batchRead(any(), any()); - List requests = new ArrayList<>(); + List requests = new ArrayList<>(); requests.add(request); requests.add(request2); worker.putAll(requests); @@ -366,7 +364,7 @@ public void testSingleDocNotFound() { return null; }).when(checkpoint).batchRead(any(), any()); - List requests = new ArrayList<>(); + List requests = new ArrayList<>(); requests.add(request); requests.add(request2); worker.putAll(requests); @@ -436,7 +434,7 @@ public void testTimeout() { return null; }).when(checkpoint).batchRead(any(), any()); - List requests = new ArrayList<>(); + List requests = new ArrayList<>(); requests.add(request); requests.add(request2); worker.putAll(requests); @@ -533,9 +531,9 @@ public void testRemoveUnusedQueues() { ExecutorService executorService = mock(ExecutorService.class); when(threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); - worker = new CheckpointReadWorker( + worker = new ADCheckpointReadWorker( Integer.MAX_VALUE, - AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, + 1, AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, clusterService, new Random(42), @@ -551,19 +549,19 @@ public void testRemoveUnusedQueues() { modelManager, checkpoint, coldstartQueue, - resultWriteQueue, nodeStateManager, anomalyDetectionIndices, cacheProvider, TimeSeriesSettings.HOURLY_MAINTENANCE, checkpointWriteQueue, - adStats + adStats, + resultWriteStrategy ); regularTestSetUp(new RegularSetUpConfig.Builder().build()); assertTrue(!worker.isQueueEmpty()); - assertEquals(CheckpointReadWorker.WORKER_NAME, worker.getWorkerName()); + assertEquals(ADCheckpointReadWorker.WORKER_NAME, worker.getWorkerName()); // make RequestQueue.expired return true when(clock.instant()).thenReturn(Instant.now().plusSeconds(TimeSeriesSettings.HOURLY_MAINTENANCE.getSeconds() + 1)); @@ -585,7 +583,7 @@ public void testSettingUpdatable() { maintenanceSetup(); // can host two requests in the queue - worker = new CheckpointReadWorker( + worker = new ADCheckpointReadWorker( 2000, 1, AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, @@ -603,16 +601,16 @@ public void testSettingUpdatable() { modelManager, checkpoint, coldstartQueue, - resultWriteQueue, nodeStateManager, anomalyDetectionIndices, cacheProvider, TimeSeriesSettings.HOURLY_MAINTENANCE, checkpointWriteQueue, - adStats + adStats, + resultWriteStrategy ); - List requests = new ArrayList<>(); + List requests = new ArrayList<>(); requests.add(request); requests.add(request2); worker.putAll(requests); @@ -638,9 +636,9 @@ public void testOpenCircuitBreaker() { CircuitBreakerService breaker = mock(CircuitBreakerService.class); when(breaker.isOpen()).thenReturn(true); - worker = new CheckpointReadWorker( + worker = new ADCheckpointReadWorker( Integer.MAX_VALUE, - AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, + 1, AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, clusterService, new Random(42), @@ -656,16 +654,16 @@ public void testOpenCircuitBreaker() { modelManager, checkpoint, coldstartQueue, - resultWriteQueue, nodeStateManager, anomalyDetectionIndices, cacheProvider, TimeSeriesSettings.HOURLY_MAINTENANCE, checkpointWriteQueue, - adStats + adStats, + resultWriteStrategy ); - List requests = new ArrayList<>(); + List requests = new ArrayList<>(); requests.add(request); requests.add(request2); worker.putAll(requests); @@ -713,23 +711,24 @@ public void testChangePriority() { } public void testDetectorId() { - assertEquals(detectorId, request.getId()); + assertEquals(detectorId, request.getConfigId()); String newDetectorId = "456"; request.setDetectorId(newDetectorId); - assertEquals(newDetectorId, request.getId()); + assertEquals(newDetectorId, request.getConfigId()); } @SuppressWarnings("unchecked") public void testHostException() throws IOException { String detectorId2 = "456"; Entity entity4 = Entity.createSingleAttributeEntity(categoryField, "value4"); - EntityFeatureRequest request4 = new EntityFeatureRequest( + FeatureRequest request4 = new FeatureRequest( Integer.MAX_VALUE, detectorId2, RequestPriority.MEDIUM, - entity4, new double[] { 0 }, - 0 + 0, + entity4, + null ); AnomalyDetector detector2 = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId2, Arrays.asList(categoryField)); @@ -777,7 +776,7 @@ public void testHostException() throws IOException { doThrow(LimitExceededException.class).when(entityCache).hostIfPossible(eq(detector2), any()); - List requests = new ArrayList<>(); + List requests = new ArrayList<>(); requests.add(request); requests.add(request4); worker.putAll(requests); @@ -803,17 +802,17 @@ public void testFailToScore() { }).when(checkpoint).batchRead(any(), any()); state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); - when(modelManager.processEntityCheckpoint(any(), any(), anyString(), anyString(), anyInt())).thenReturn(state); - doThrow(new IllegalArgumentException()).when(modelManager).getAnomalyResultForEntity(any(), any(), anyString(), any(), anyInt()); + when(checkpoint.processHCGetResponse(any(), anyString(), anyString())).thenReturn(state); + doThrow(new IllegalArgumentException()).when(modelManager).getResult(any(), any(), anyString(), any(), any(), anyString()); - List requests = new ArrayList<>(); + List requests = new ArrayList<>(); requests.add(request); worker.putAll(requests); - verify(resultWriteQueue, never()).put(any()); + verify(resultWriteStrategy, never()).saveResult(any(), any(), any(), anyString()); verify(checkpointWriteQueue, never()).write(any(), anyBoolean(), any()); verify(coldstartQueue, times(1)).put(any()); - Object val = adStats.getStat(StatNames.MODEL_CORRUTPION_COUNT.getName()).getValue(); + Object val = adStats.getStat(StatNames.AD_MODEL_CORRUTPION_COUNT.getName()).getValue(); assertEquals(1L, ((Long) val).longValue()); } } diff --git a/src/test/java/org/opensearch/ad/ratelimit/CheckpointWriteWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/CheckpointWriteWorkerTests.java index be83484ee..425b5e973 100644 --- a/src/test/java/org/opensearch/ad/ratelimit/CheckpointWriteWorkerTests.java +++ b/src/test/java/org/opensearch/ad/ratelimit/CheckpointWriteWorkerTests.java @@ -46,9 +46,7 @@ import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.index.IndexResponse; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.ml.ADCheckpointDao; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.service.ClusterService; @@ -65,18 +63,22 @@ import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.breaker.CircuitBreakerService; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ratelimit.RequestPriority; import org.opensearch.timeseries.settings.TimeSeriesSettings; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + import test.org.opensearch.ad.util.MLUtil; import test.org.opensearch.ad.util.RandomModelStateConfig; public class CheckpointWriteWorkerTests extends AbstractRateLimitingTest { - CheckpointWriteWorker worker; + ADCheckpointWriteWorker worker; - CheckpointDao checkpoint; + ADCheckpointDao checkpoint; ClusterService clusterService; - ModelState state; + ModelState state; @Override @SuppressWarnings("unchecked") @@ -99,14 +101,14 @@ public void setUp() throws Exception { ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - checkpoint = mock(CheckpointDao.class); + checkpoint = mock(ADCheckpointDao.class); Map checkpointMap = new HashMap<>(); checkpointMap.put(CommonName.FIELD_MODEL, "a"); when(checkpoint.toIndexSource(any())).thenReturn(checkpointMap); when(checkpoint.shouldSave(any(), anyBoolean(), any(), any())).thenReturn(true); // Integer.MAX_VALUE makes a huge heap - worker = new CheckpointWriteWorker( + worker = new ADCheckpointWriteWorker( Integer.MAX_VALUE, TimeSeriesSettings.CHECKPOINT_WRITE_QUEUE_SIZE_IN_BYTES, AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT, @@ -166,7 +168,7 @@ public void testTriggerSaveAll() { return null; }).when(checkpoint).batchWrite(any(), any()); - List> states = new ArrayList<>(); + List> states = new ArrayList<>(); states.add(state); worker.writeAll(states, detectorId, true, RequestPriority.MEDIUM); @@ -210,7 +212,7 @@ public void testTriggerAutoFlush() throws InterruptedException { // Integer.MAX_VALUE makes a huge heap // create a worker to use mockThreadPool - worker = new CheckpointWriteWorker( + worker = new ADCheckpointWriteWorker( Integer.MAX_VALUE, TimeSeriesSettings.CHECKPOINT_WRITE_QUEUE_SIZE_IN_BYTES, AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT, @@ -239,7 +241,7 @@ public void testTriggerAutoFlush() throws InterruptedException { // CHECKPOINT_WRITE_QUEUE_BATCH_SIZE is the largest batch size int numberOfRequests = 2 * AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE.getDefault(Settings.EMPTY) + 1; for (int i = 0; i < numberOfRequests; i++) { - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().build()); + ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().build()); worker.write(state, true, RequestPriority.MEDIUM); } @@ -268,7 +270,7 @@ public void testOverloaded() { worker.write(state, true, RequestPriority.MEDIUM); verify(checkpoint, times(1)).batchWrite(any(), any()); - verify(nodeStateManager, times(1)).setException(eq(state.getId()), any(OpenSearchRejectedExecutionException.class)); + verify(nodeStateManager, times(1)).setException(eq(state.getConfigId()), any(OpenSearchRejectedExecutionException.class)); } public void testRetryException() { @@ -282,7 +284,7 @@ public void testRetryException() { worker.write(state, true, RequestPriority.MEDIUM); // we don't retry checkpoint write verify(checkpoint, times(1)).batchWrite(any(), any()); - verify(nodeStateManager, times(1)).setException(eq(state.getId()), any(OpenSearchStatusException.class)); + verify(nodeStateManager, times(1)).setException(eq(state.getConfigId()), any(OpenSearchStatusException.class)); } /** @@ -310,7 +312,7 @@ public void testFailedRequest() { @SuppressWarnings("unchecked") public void testEmptyTimeStamp() { - ModelState state = mock(ModelState.class); + ModelState state = mock(ModelState.class); when(state.getLastCheckpointTime()).thenReturn(Instant.MIN); worker.write(state, false, RequestPriority.MEDIUM); @@ -319,7 +321,7 @@ public void testEmptyTimeStamp() { @SuppressWarnings("unchecked") public void testTooSoonToSaveSingleWrite() { - ModelState state = mock(ModelState.class); + ModelState state = mock(ModelState.class); when(state.getLastCheckpointTime()).thenReturn(Instant.now()); worker.write(state, false, RequestPriority.MEDIUM); @@ -328,10 +330,10 @@ public void testTooSoonToSaveSingleWrite() { @SuppressWarnings("unchecked") public void testTooSoonToSaveWriteAll() { - ModelState state = mock(ModelState.class); + ModelState state = mock(ModelState.class); when(state.getLastCheckpointTime()).thenReturn(Instant.now()); - List> states = new ArrayList<>(); + List> states = new ArrayList<>(); states.add(state); worker.writeAll(states, detectorId, false, RequestPriority.MEDIUM); @@ -341,7 +343,7 @@ public void testTooSoonToSaveWriteAll() { @SuppressWarnings("unchecked") public void testEmptyModel() { - ModelState state = mock(ModelState.class); + ModelState state = mock(ModelState.class); when(state.getLastCheckpointTime()).thenReturn(Instant.now()); when(state.getModel()).thenReturn(null); worker.write(state, true, RequestPriority.MEDIUM); @@ -351,11 +353,11 @@ public void testEmptyModel() { @SuppressWarnings("unchecked") public void testEmptyModelId() { - ModelState state = mock(ModelState.class); + ModelState state = mock(ModelState.class); when(state.getLastCheckpointTime()).thenReturn(Instant.now()); - EntityModel model = mock(EntityModel.class); - when(state.getModel()).thenReturn(model); - when(state.getId()).thenReturn("1"); + ThresholdedRandomCutForest model = mock(ThresholdedRandomCutForest.class); + when(state.getModel()).thenReturn(Optional.of(model)); + when(state.getConfigId()).thenReturn("1"); when(state.getModelId()).thenReturn(null); worker.write(state, true, RequestPriority.MEDIUM); @@ -364,11 +366,11 @@ public void testEmptyModelId() { @SuppressWarnings("unchecked") public void testEmptyDetectorId() { - ModelState state = mock(ModelState.class); + ModelState state = mock(ModelState.class); when(state.getLastCheckpointTime()).thenReturn(Instant.now()); - EntityModel model = mock(EntityModel.class); - when(state.getModel()).thenReturn(model); - when(state.getId()).thenReturn(null); + ThresholdedRandomCutForest model = mock(ThresholdedRandomCutForest.class); + when(state.getModel()).thenReturn(Optional.of(model)); + when(state.getConfigId()).thenReturn(null); when(state.getModelId()).thenReturn("a"); worker.write(state, true, RequestPriority.MEDIUM); @@ -395,7 +397,7 @@ public void testDetectorNotAvailableWriteAll() { return null; }).when(nodeStateManager).getConfig(any(String.class), eq(AnalysisType.AD), any(ActionListener.class)); - List> states = new ArrayList<>(); + List> states = new ArrayList<>(); states.add(state); worker.writeAll(states, detectorId, true, RequestPriority.MEDIUM); verify(checkpoint, never()).batchWrite(any(), any()); diff --git a/src/test/java/org/opensearch/ad/ratelimit/ColdEntityWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/ColdEntityWorkerTests.java index d093f20ae..96d176a7f 100644 --- a/src/test/java/org/opensearch/ad/ratelimit/ColdEntityWorkerTests.java +++ b/src/test/java/org/opensearch/ad/ratelimit/ColdEntityWorkerTests.java @@ -33,14 +33,16 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.FeatureRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; import org.opensearch.timeseries.settings.TimeSeriesSettings; public class ColdEntityWorkerTests extends AbstractRateLimitingTest { ClusterService clusterService; - ColdEntityWorker coldWorker; - CheckpointReadWorker readWorker; - EntityFeatureRequest request, request2, invalidRequest; - List requests; + ADColdEntityWorker coldWorker; + ADCheckpointReadWorker readWorker; + FeatureRequest request, request2, invalidRequest; + List requests; @Override public void setUp() throws Exception { @@ -63,12 +65,12 @@ public void setUp() throws Exception { ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - readWorker = mock(CheckpointReadWorker.class); + readWorker = mock(ADCheckpointReadWorker.class); // Integer.MAX_VALUE makes a huge heap - coldWorker = new ColdEntityWorker( + coldWorker = new ADColdEntityWorker( Integer.MAX_VALUE, - AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, + TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES, AnomalyDetectorSettings.AD_COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, clusterService, new Random(42), @@ -85,9 +87,9 @@ public void setUp() throws Exception { nodeStateManager ); - request = new EntityFeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.LOW, entity, new double[] { 0 }, 0); - request2 = new EntityFeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.LOW, entity2, new double[] { 0 }, 0); - invalidRequest = new EntityFeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity2, new double[] { 0 }, 0); + request = new FeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.LOW, new double[] { 0 }, 0, entity, null); + request2 = new FeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.LOW, new double[] { 0 }, 0, entity2, null); + invalidRequest = new FeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, new double[] { 0 }, 0, entity2, null); requests = new ArrayList<>(); requests.add(request); @@ -154,9 +156,9 @@ public void testDelay() { when(clusterService.getClusterSettings()).thenReturn(clusterSettings); // Integer.MAX_VALUE makes a huge heap - coldWorker = new ColdEntityWorker( + coldWorker = new ADColdEntityWorker( Integer.MAX_VALUE, - AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, + TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES, AnomalyDetectorSettings.AD_COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, clusterService, new Random(42), diff --git a/src/test/java/org/opensearch/ad/ratelimit/EntityColdStartWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/EntityColdStartWorkerTests.java index 9fdf5a396..fad5b18f4 100644 --- a/src/test/java/org/opensearch/ad/ratelimit/EntityColdStartWorkerTests.java +++ b/src/test/java/org/opensearch/ad/ratelimit/EntityColdStartWorkerTests.java @@ -24,14 +24,12 @@ import java.util.Arrays; import java.util.Collections; import java.util.HashSet; -import java.util.Optional; import java.util.Random; import org.opensearch.OpenSearchStatusException; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.ml.EntityColdStarter; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.caching.ADPriorityCache; +import org.opensearch.ad.ml.ADColdStart; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; @@ -40,15 +38,20 @@ import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; import org.opensearch.core.rest.RestStatus; import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ratelimit.FeatureRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; import org.opensearch.timeseries.settings.TimeSeriesSettings; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + import test.org.opensearch.ad.util.MLUtil; public class EntityColdStartWorkerTests extends AbstractRateLimitingTest { ClusterService clusterService; - EntityColdStartWorker worker; - EntityColdStarter entityColdStarter; - CacheProvider cacheProvider; + ADColdStartWorker worker; + ADColdStart entityColdStarter; + ADPriorityCache cacheProvider; @Override public void setUp() throws Exception { @@ -69,14 +72,14 @@ public void setUp() throws Exception { ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - entityColdStarter = mock(EntityColdStarter.class); + entityColdStarter = mock(ADColdStart.class); - cacheProvider = mock(CacheProvider.class); + cacheProvider = mock(ADPriorityCache.class); // Integer.MAX_VALUE makes a huge heap - worker = new EntityColdStartWorker( + worker = new ADColdStartWorker( Integer.MAX_VALUE, - AnomalyDetectorSettings.ENTITY_REQUEST_SIZE_IN_BYTES, + TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES, AnomalyDetectorSettings.AD_ENTITY_COLD_START_QUEUE_MAX_HEAP_PERCENT, clusterService, new Random(42), @@ -92,21 +95,31 @@ public void setUp() throws Exception { entityColdStarter, TimeSeriesSettings.HOURLY_MAINTENANCE, nodeStateManager, - cacheProvider + cacheProvider, + mock(ADModelManager.class), + mock(ADSaveResultStrategy.class) ); } public void testEmptyModelId() { - EntityRequest request = mock(EntityRequest.class); + FeatureRequest request = mock(FeatureRequest.class); when(request.getPriority()).thenReturn(RequestPriority.LOW); - when(request.getModelId()).thenReturn(Optional.empty()); + when(request.getModelId()).thenReturn(null); worker.put(request); verify(entityColdStarter, never()).trainModel(any(), anyString(), any(), any()); verify(request, times(1)).getModelId(); } public void testOverloaded() { - EntityRequest request = new EntityRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity); + FeatureRequest request = new FeatureRequest( + Integer.MAX_VALUE, + detectorId, + RequestPriority.MEDIUM, + new double[] { 0 }, + 0, + entity, + null + ); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -126,7 +139,15 @@ public void testOverloaded() { } public void testException() { - EntityRequest request = new EntityRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity); + FeatureRequest request = new FeatureRequest( + Integer.MAX_VALUE, + detectorId, + RequestPriority.MEDIUM, + new double[] { 0 }, + 0, + entity, + null + ); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -147,13 +168,21 @@ public void testException() { } public void testModelHosted() { - EntityRequest request = new EntityRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity); + FeatureRequest request = new FeatureRequest( + Integer.MAX_VALUE, + detectorId, + RequestPriority.MEDIUM, + new double[] { 0 }, + 0, + entity, + null + ); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); - ModelState state = invocation.getArgument(2); - state.setModel(MLUtil.createNonEmptyModel(detectorId)); + ModelState state = invocation.getArgument(2); + state.setModel(MLUtil.createNonEmptyModel(detectorId).getLeft()); listener.onResponse(null); return null; @@ -161,6 +190,6 @@ public void testModelHosted() { worker.put(request); - verify(cacheProvider, times(1)).get(); + verify(cacheProvider, times(1)).get(anyString(), any()); } } diff --git a/src/test/java/org/opensearch/ad/ratelimit/ResultWriteWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/ResultWriteWorkerTests.java index 304a942c7..2d829722b 100644 --- a/src/test/java/org/opensearch/ad/ratelimit/ResultWriteWorkerTests.java +++ b/src/test/java/org/opensearch/ad/ratelimit/ResultWriteWorkerTests.java @@ -37,8 +37,7 @@ import org.opensearch.ad.model.AnomalyResult; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.transport.ADResultBulkRequest; -import org.opensearch.ad.transport.ADResultBulkResponse; -import org.opensearch.ad.transport.handler.MultiEntityResultHandler; +import org.opensearch.ad.transport.handler.ADIndexMemoryPressureAwareResultHandler; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; @@ -49,13 +48,15 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.RequestPriority; import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.transport.ResultBulkResponse; import org.opensearch.timeseries.util.RestHandlerUtils; public class ResultWriteWorkerTests extends AbstractRateLimitingTest { - ResultWriteWorker resultWriteQueue; + ADResultWriteWorker resultWriteQueue; ClusterService clusterService; - MultiEntityResultHandler resultHandler; + ADIndexMemoryPressureAwareResultHandler resultHandler; AnomalyResult detectResult; @Override @@ -82,9 +83,9 @@ public void setUp() throws Exception { threadPool = mock(ThreadPool.class); setUpADThreadPool(threadPool); - resultHandler = mock(MultiEntityResultHandler.class); + resultHandler = mock(ADIndexMemoryPressureAwareResultHandler.class); - resultWriteQueue = new ResultWriteWorker( + resultWriteQueue = new ADResultWriteWorker( Integer.MAX_VALUE, TimeSeriesSettings.RESULT_WRITE_QUEUE_SIZE_IN_BYTES, AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT, @@ -111,10 +112,10 @@ public void setUp() throws Exception { public void testRegular() { List retryRequests = new ArrayList<>(); - ADResultBulkResponse resp = new ADResultBulkResponse(retryRequests); + ResultBulkResponse resp = new ResultBulkResponse(retryRequests); ADResultBulkRequest request = new ADResultBulkRequest(); - ResultWriteRequest resultWriteRequest = new ResultWriteRequest( + ADResultWriteRequest resultWriteRequest = new ADResultWriteRequest( Instant.now().plus(10, ChronoUnit.MINUTES).toEpochMilli(), detectorId, RequestPriority.MEDIUM, @@ -124,12 +125,12 @@ public void testRegular() { request.add(resultWriteRequest); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(1); listener.onResponse(resp); return null; }).when(resultHandler).flush(any(), any()); - resultWriteQueue.put(new ResultWriteRequest(Long.MAX_VALUE, detectorId, RequestPriority.MEDIUM, detectResult, null)); + resultWriteQueue.put(new ADResultWriteRequest(Long.MAX_VALUE, detectorId, RequestPriority.MEDIUM, detectResult, null)); // the request results one flush verify(resultHandler, times(1)).flush(any(), any()); @@ -143,10 +144,10 @@ public void testSingleRetryRequest() throws IOException { retryRequests.add(indexRequest); } - ADResultBulkResponse resp = new ADResultBulkResponse(retryRequests); + ResultBulkResponse resp = new ResultBulkResponse(retryRequests); ADResultBulkRequest request = new ADResultBulkRequest(); - ResultWriteRequest resultWriteRequest = new ResultWriteRequest( + ADResultWriteRequest resultWriteRequest = new ADResultWriteRequest( Instant.now().plus(10, ChronoUnit.MINUTES).toEpochMilli(), detectorId, RequestPriority.MEDIUM, @@ -157,9 +158,9 @@ public void testSingleRetryRequest() throws IOException { final AtomicBoolean retried = new AtomicBoolean(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(1); if (retried.get()) { - listener.onResponse(new ADResultBulkResponse()); + listener.onResponse(new ResultBulkResponse()); } else { retried.set(true); listener.onResponse(resp); @@ -167,7 +168,7 @@ public void testSingleRetryRequest() throws IOException { return null; }).when(resultHandler).flush(any(), any()); - resultWriteQueue.put(new ResultWriteRequest(Long.MAX_VALUE, detectorId, RequestPriority.MEDIUM, detectResult, null)); + resultWriteQueue.put(new ADResultWriteRequest(Long.MAX_VALUE, detectorId, RequestPriority.MEDIUM, detectResult, null)); // one flush from the original request; and one due to retry verify(resultHandler, times(2)).flush(any(), any()); @@ -176,9 +177,9 @@ public void testSingleRetryRequest() throws IOException { public void testRetryException() { final AtomicBoolean retried = new AtomicBoolean(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(1); if (retried.get()) { - listener.onResponse(new ADResultBulkResponse()); + listener.onResponse(new ResultBulkResponse()); } else { retried.set(true); listener.onFailure(new OpenSearchStatusException("blah", RestStatus.REQUEST_TIMEOUT)); @@ -187,7 +188,7 @@ public void testRetryException() { return null; }).when(resultHandler).flush(any(), any()); - resultWriteQueue.put(new ResultWriteRequest(Long.MAX_VALUE, detectorId, RequestPriority.MEDIUM, detectResult, null)); + resultWriteQueue.put(new ADResultWriteRequest(Long.MAX_VALUE, detectorId, RequestPriority.MEDIUM, detectResult, null)); // one flush from the original request; and one due to retry verify(resultHandler, times(2)).flush(any(), any()); verify(nodeStateManager, times(1)).setException(eq(detectorId), any(OpenSearchStatusException.class)); @@ -195,13 +196,13 @@ public void testRetryException() { public void testOverloaded() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(1); listener.onFailure(new OpenSearchRejectedExecutionException("blah", true)); return null; }).when(resultHandler).flush(any(), any()); - resultWriteQueue.put(new ResultWriteRequest(Long.MAX_VALUE, detectorId, RequestPriority.MEDIUM, detectResult, null)); + resultWriteQueue.put(new ADResultWriteRequest(Long.MAX_VALUE, detectorId, RequestPriority.MEDIUM, detectResult, null)); // one flush from the original request; and one due to retry verify(resultHandler, times(1)).flush(any(), any()); verify(nodeStateManager, times(1)).setException(eq(detectorId), any(OpenSearchRejectedExecutionException.class)); diff --git a/src/test/java/org/opensearch/ad/rest/ADRestTestUtils.java b/src/test/java/org/opensearch/ad/rest/ADRestTestUtils.java index 3411f37ac..d7148fb2e 100644 --- a/src/test/java/org/opensearch/ad/rest/ADRestTestUtils.java +++ b/src/test/java/org/opensearch/ad/rest/ADRestTestUtils.java @@ -43,10 +43,14 @@ import org.opensearch.client.Response; import org.opensearch.client.RestClient; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.TaskProfile; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.model.DateRange; import org.opensearch.timeseries.model.IntervalTimeConfiguration; import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.settings.TimeSeriesSettings; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -211,7 +215,11 @@ public static Response createAnomalyDetector( categoryFields, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(1), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ); if (historical) { @@ -260,12 +268,12 @@ public static List searchLatestAdTaskOfDetector(RestClient client, Strin for (Object adTaskResponse : adTaskResponses) { String id = (String) ((Map) adTaskResponse).get("_id"); Map source = (Map) ((Map) adTaskResponse).get("_source"); - String state = (String) source.get(ADTask.STATE_FIELD); + String state = (String) source.get(TimeSeriesTask.STATE_FIELD); String parsedDetectorId = (String) source.get(ADTask.DETECTOR_ID_FIELD); - Double taskProgress = (Double) source.get(ADTask.TASK_PROGRESS_FIELD); - Double initProgress = (Double) source.get(ADTask.INIT_PROGRESS_FIELD); - String parsedTaskType = (String) source.get(ADTask.TASK_TYPE_FIELD); - String coordinatingNode = (String) source.get(ADTask.COORDINATING_NODE_FIELD); + Double taskProgress = (Double) source.get(TimeSeriesTask.TASK_PROGRESS_FIELD); + Double initProgress = (Double) source.get(TimeSeriesTask.INIT_PROGRESS_FIELD); + String parsedTaskType = (String) source.get(TimeSeriesTask.TASK_TYPE_FIELD); + String coordinatingNode = (String) source.get(TimeSeriesTask.COORDINATING_NODE_FIELD); ADTask adTask = ADTask .builder() .taskId(id) @@ -366,7 +374,8 @@ public static Map getDetectorWithJobAndTask(RestClient client, S Instant.ofEpochMilli(lastUpdateTime), null, null, - null + null, + AnalysisType.AD ); results.put(ANOMALY_DETECTOR_JOB, job); } @@ -387,13 +396,13 @@ public static Map getDetectorWithJobAndTask(RestClient client, S } private static ADTask parseAdTask(Map taskMap) { - String id = (String) taskMap.get(ADTask.TASK_ID_FIELD); - String state = (String) taskMap.get(ADTask.STATE_FIELD); + String id = (String) taskMap.get(TimeSeriesTask.TASK_ID_FIELD); + String state = (String) taskMap.get(TimeSeriesTask.STATE_FIELD); String parsedDetectorId = (String) taskMap.get(ADTask.DETECTOR_ID_FIELD); - Double taskProgress = (Double) taskMap.get(ADTask.TASK_PROGRESS_FIELD); - Double initProgress = (Double) taskMap.get(ADTask.INIT_PROGRESS_FIELD); - String parsedTaskType = (String) taskMap.get(ADTask.TASK_TYPE_FIELD); - String coordinatingNode = (String) taskMap.get(ADTask.COORDINATING_NODE_FIELD); + Double taskProgress = (Double) taskMap.get(TimeSeriesTask.TASK_PROGRESS_FIELD); + Double initProgress = (Double) taskMap.get(TimeSeriesTask.INIT_PROGRESS_FIELD); + String parsedTaskType = (String) taskMap.get(TimeSeriesTask.TASK_TYPE_FIELD); + String coordinatingNode = (String) taskMap.get(TimeSeriesTask.COORDINATING_NODE_FIELD); return ADTask .builder() .taskId(id) @@ -465,16 +474,16 @@ public static String startHistoricalAnalysis(RestClient client, String detectorI return taskId; } - public static ADTaskProfile waitUntilTaskDone(RestClient client, String detectorId) throws InterruptedException { + public static TaskProfile waitUntilTaskDone(RestClient client, String detectorId) throws InterruptedException { return waitUntilTaskReachState(client, detectorId, TestHelpers.HISTORICAL_ANALYSIS_DONE_STATS); } - public static ADTaskProfile waitUntilTaskReachState(RestClient client, String detectorId, Set targetStates) + public static TaskProfile waitUntilTaskReachState(RestClient client, String detectorId, Set targetStates) throws InterruptedException { int i = 0; int retryTimes = 200; - ADTaskProfile adTaskProfile = null; - while ((adTaskProfile == null || !targetStates.contains(adTaskProfile.getAdTask().getState())) && i < retryTimes) { + TaskProfile adTaskProfile = null; + while ((adTaskProfile == null || !targetStates.contains(adTaskProfile.getTask().getState())) && i < retryTimes) { try { adTaskProfile = getADTaskProfile(client, detectorId); } catch (Exception e) { @@ -488,7 +497,7 @@ public static ADTaskProfile waitUntilTaskReachState(RestClient client, String de return adTaskProfile; } - public static ADTaskProfile getADTaskProfile(RestClient client, String detectorId) throws IOException, ParseException { + public static TaskProfile getADTaskProfile(RestClient client, String detectorId) throws IOException, ParseException { Response profileResponse = TestHelpers .makeRequest( client, @@ -501,10 +510,10 @@ public static ADTaskProfile getADTaskProfile(RestClient client, String detectorI return parseADTaskProfile(profileResponse); } - public static ADTaskProfile parseADTaskProfile(Response profileResponse) throws IOException, ParseException { + public static TaskProfile parseADTaskProfile(Response profileResponse) throws IOException, ParseException { String profileResult = EntityUtils.toString(profileResponse.getEntity()); XContentParser parser = TestHelpers.parser(profileResult); - ADTaskProfile adTaskProfile = null; + TaskProfile adTaskProfile = null; while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = parser.currentName(); parser.nextToken(); diff --git a/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java b/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java index 084e2d44f..49890fbc0 100644 --- a/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java +++ b/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java @@ -14,7 +14,6 @@ import static org.hamcrest.Matchers.containsString; import static org.opensearch.ad.rest.handler.AbstractAnomalyDetectorActionHandler.DUPLICATE_DETECTOR_MSG; import static org.opensearch.ad.rest.handler.AbstractAnomalyDetectorActionHandler.NO_DOCS_IN_USER_INDEX_MSG; -import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG; import java.io.IOException; import java.time.Instant; @@ -37,7 +36,6 @@ import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyDetectorExecutionInput; import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.ad.rest.handler.AbstractAnomalyDetectorActionHandler; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; @@ -54,6 +52,7 @@ import org.opensearch.timeseries.model.DateRange; import org.opensearch.timeseries.model.Feature; import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.rest.handler.AbstractTimeSeriesActionHandler; import org.opensearch.timeseries.settings.TimeSeriesSettings; import com.google.common.collect.ImmutableList; @@ -127,6 +126,7 @@ private AnomalyDetector createIndexAndGetAnomalyDetector(String indexName, List< public void testCreateAnomalyDetectorWithDuplicateName() throws Exception { AnomalyDetector detector = createIndexAndGetAnomalyDetector(INDEX_NAME); + Feature feature = TestHelpers.randomFeature(); AnomalyDetector detectorDuplicateName = new AnomalyDetector( AnomalyDetector.NO_ID, randomLong(), @@ -134,7 +134,7 @@ public void testCreateAnomalyDetectorWithDuplicateName() throws Exception { randomAlphaOfLength(5), randomAlphaOfLength(5), detector.getIndices(), - ImmutableList.of(TestHelpers.randomFeature()), + ImmutableList.of(feature), TestHelpers.randomQuery(), TestHelpers.randomIntervalTimeConfiguration(), TestHelpers.randomIntervalTimeConfiguration(), @@ -145,7 +145,11 @@ public void testCreateAnomalyDetectorWithDuplicateName() throws Exception { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ); TestHelpers @@ -200,6 +204,8 @@ public void testUpdateAnomalyDetectorCategoryField() throws Exception { assertEquals("Create anomaly detector failed", RestStatus.CREATED, TestHelpers.restStatus(response)); Map responseMap = entityAsMap(response); String id = (String) responseMap.get("_id"); + List features = detector.getFeatureAttributes(); + long expectedFeatures = features.stream().filter(Feature::getEnabled).count(); AnomalyDetector newDetector = new AnomalyDetector( id, null, @@ -207,7 +213,7 @@ public void testUpdateAnomalyDetectorCategoryField() throws Exception { detector.getDescription(), detector.getTimeField(), detector.getIndices(), - detector.getFeatureAttributes(), + features, detector.getFilterQuery(), detector.getInterval(), detector.getWindowDelay(), @@ -218,7 +224,11 @@ public void testUpdateAnomalyDetectorCategoryField() throws Exception { ImmutableList.of(randomAlphaOfLength(5)), detector.getUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption((int) expectedFeatures), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ); Exception ex = expectThrows( ResponseException.class, @@ -257,6 +267,8 @@ public void testGetNotExistingAnomalyDetector() throws Exception { public void testUpdateAnomalyDetector() throws Exception { AnomalyDetector detector = createAnomalyDetector(createIndexAndGetAnomalyDetector(INDEX_NAME), true, client()); String newDescription = randomAlphaOfLength(5); + List features = detector.getFeatureAttributes(); + long expectedFeatures = features.stream().filter(Feature::getEnabled).count(); AnomalyDetector newDetector = new AnomalyDetector( detector.getId(), detector.getVersion(), @@ -264,7 +276,7 @@ public void testUpdateAnomalyDetector() throws Exception { newDescription, detector.getTimeField(), detector.getIndices(), - detector.getFeatureAttributes(), + features, detector.getFilterQuery(), detector.getInterval(), detector.getWindowDelay(), @@ -275,7 +287,11 @@ public void testUpdateAnomalyDetector() throws Exception { null, detector.getUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption((int) expectedFeatures), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ); updateClusterSettings(ADEnabledSetting.AD_ENABLED, false); @@ -319,6 +335,8 @@ public void testUpdateAnomalyDetector() throws Exception { public void testUpdateAnomalyDetectorNameToExisting() throws Exception { AnomalyDetector detector1 = createIndexAndGetAnomalyDetector("index-test-one"); AnomalyDetector detector2 = createIndexAndGetAnomalyDetector("index-test-two"); + List features = detector1.getFeatureAttributes(); + long expectedFeatures = features.stream().filter(Feature::getEnabled).count(); AnomalyDetector newDetector1WithDetector2Name = new AnomalyDetector( detector1.getId(), detector1.getVersion(), @@ -326,7 +344,7 @@ public void testUpdateAnomalyDetectorNameToExisting() throws Exception { detector1.getDescription(), detector1.getTimeField(), detector1.getIndices(), - detector1.getFeatureAttributes(), + features, detector1.getFilterQuery(), detector1.getInterval(), detector1.getWindowDelay(), @@ -337,7 +355,11 @@ public void testUpdateAnomalyDetectorNameToExisting() throws Exception { null, detector1.getUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption((int) expectedFeatures), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ); TestHelpers @@ -358,6 +380,8 @@ public void testUpdateAnomalyDetectorNameToExisting() throws Exception { public void testUpdateAnomalyDetectorNameToNew() throws Exception { AnomalyDetector detector = createAnomalyDetector(createIndexAndGetAnomalyDetector(INDEX_NAME), true, client()); + List features = detector.getFeatureAttributes(); + long expectedFeatures = features.stream().filter(Feature::getEnabled).count(); AnomalyDetector detectorWithNewName = new AnomalyDetector( detector.getId(), detector.getVersion(), @@ -365,7 +389,7 @@ public void testUpdateAnomalyDetectorNameToNew() throws Exception { detector.getDescription(), detector.getTimeField(), detector.getIndices(), - detector.getFeatureAttributes(), + features, detector.getFilterQuery(), detector.getInterval(), detector.getWindowDelay(), @@ -376,7 +400,11 @@ public void testUpdateAnomalyDetectorNameToNew() throws Exception { null, detector.getUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption((int) expectedFeatures), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ); TestHelpers @@ -403,7 +431,8 @@ public void testUpdateAnomalyDetectorWithNotExistingIndex() throws Exception { AnomalyDetector detector = createRandomAnomalyDetector(true, true, client()); String newDescription = randomAlphaOfLength(5); - + List features = detector.getFeatureAttributes(); + long expectedFeatures = features.stream().filter(Feature::getEnabled).count(); AnomalyDetector newDetector = new AnomalyDetector( detector.getId(), detector.getVersion(), @@ -411,7 +440,7 @@ public void testUpdateAnomalyDetectorWithNotExistingIndex() throws Exception { newDescription, detector.getTimeField(), detector.getIndices(), - detector.getFeatureAttributes(), + features, detector.getFilterQuery(), detector.getInterval(), detector.getWindowDelay(), @@ -422,7 +451,11 @@ public void testUpdateAnomalyDetectorWithNotExistingIndex() throws Exception { null, detector.getUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption((int) expectedFeatures), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ); deleteIndexWithAdminClient(CommonName.CONFIG_INDEX); @@ -766,7 +799,8 @@ public void testUpdateAnomalyDetectorWithRunningAdJob() throws Exception { assertEquals("Fail to start AD job", RestStatus.OK, TestHelpers.restStatus(startAdJobResponse)); String newDescription = randomAlphaOfLength(5); - + List features = detector.getFeatureAttributes(); + long expectedFeatures = features.stream().filter(Feature::getEnabled).count(); AnomalyDetector newDetector = new AnomalyDetector( detector.getId(), detector.getVersion(), @@ -774,7 +808,7 @@ public void testUpdateAnomalyDetectorWithRunningAdJob() throws Exception { newDescription, detector.getTimeField(), detector.getIndices(), - detector.getFeatureAttributes(), + features, detector.getFilterQuery(), detector.getInterval(), detector.getWindowDelay(), @@ -785,7 +819,11 @@ public void testUpdateAnomalyDetectorWithRunningAdJob() throws Exception { null, detector.getUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption((int) expectedFeatures), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ); TestHelpers @@ -895,7 +933,7 @@ public void testStartAdJobWithNonexistingDetector() throws Exception { TestHelpers .assertFailWith( ResponseException.class, - FAIL_TO_FIND_CONFIG_MSG, + CommonMessages.FAIL_TO_FIND_CONFIG_MSG, () -> TestHelpers .makeRequest( client(), @@ -997,7 +1035,7 @@ public void testStopNonExistingAdJob() throws Exception { TestHelpers .assertFailWith( ResponseException.class, - FAIL_TO_FIND_CONFIG_MSG, + CommonMessages.FAIL_TO_FIND_CONFIG_MSG, () -> TestHelpers .makeRequest( client(), @@ -1055,7 +1093,7 @@ public void testStartAdjobWithNullFeatures() throws Exception { TestHelpers .assertFailWith( ResponseException.class, - "Can't start detector job as no features configured", + "Can't start job as no features configured", () -> TestHelpers .makeRequest( client(), @@ -1076,7 +1114,7 @@ public void testStartAdjobWithEmptyFeatures() throws Exception { TestHelpers .assertFailWith( ResponseException.class, - "Can't start detector job as no features configured", + "Can't start job as no features configured", () -> TestHelpers .makeRequest( client(), @@ -1161,7 +1199,7 @@ public void testRunDetectorWithNoEnabledFeature() throws Exception { ResponseException.class, () -> startAnomalyDetector(detector.getId(), new DateRange(now.minus(10, ChronoUnit.DAYS), now), client()) ); - assertTrue(e.getMessage().contains("Can't start detector job as no enabled features configured")); + assertTrue(e.getMessage().contains("Can't start job as no enabled features configured")); } public void testDeleteAnomalyDetectorWhileRunning() throws Exception { @@ -1332,7 +1370,7 @@ public void testValidateAnomalyDetectorOnWrongValidationType() throws Exception TestHelpers .assertFailWith( ResponseException.class, - ADCommonMessages.NOT_EXISTENT_VALIDATION_TYPE, + CommonMessages.NOT_EXISTENT_VALIDATION_TYPE, () -> TestHelpers .makeRequest( client(), @@ -1475,7 +1513,7 @@ public void testValidateAnomalyDetectorWithWrongCategoryField() throws Exception .extractValue("detector", responseMap); assertEquals( "non-existing category", - String.format(Locale.ROOT, AbstractAnomalyDetectorActionHandler.CATEGORY_NOT_FOUND_ERR_MSG, "host.keyword"), + String.format(Locale.ROOT, AbstractTimeSeriesActionHandler.CATEGORY_NOT_FOUND_ERR_MSG, "host.keyword"), messageMap.get("category_field").get("message") ); diff --git a/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java b/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java index e3881c968..444115302 100644 --- a/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java +++ b/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java @@ -17,8 +17,8 @@ import static org.opensearch.timeseries.TestHelpers.AD_BASE_STATS_URI; import static org.opensearch.timeseries.TestHelpers.HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS; import static org.opensearch.timeseries.stats.StatNames.AD_TOTAL_BATCH_TASK_EXECUTION_COUNT; -import static org.opensearch.timeseries.stats.StatNames.MULTI_ENTITY_DETECTOR_COUNT; -import static org.opensearch.timeseries.stats.StatNames.SINGLE_ENTITY_DETECTOR_COUNT; +import static org.opensearch.timeseries.stats.StatNames.HC_DETECTOR_COUNT; +import static org.opensearch.timeseries.stats.StatNames.SINGLE_STREAM_DETECTOR_COUNT; import java.io.IOException; import java.util.List; @@ -39,9 +39,11 @@ import org.opensearch.client.ResponseException; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.timeseries.TaskProfile; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.model.Job; import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.settings.TimeSeriesSettings; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -92,9 +94,9 @@ public void testHistoricalAnalysisForMultiCategoryHC() throws Exception { private void checkIfTaskCanFinishCorrectly(String detectorId, String taskId, Set states) throws InterruptedException { List results = waitUntilTaskDone(detectorId); - ADTaskProfile endTaskProfile = (ADTaskProfile) results.get(0); + TaskProfile endTaskProfile = (TaskProfile) results.get(0); Integer retryCount = (Integer) results.get(1); - ADTask stoppedAdTask = endTaskProfile.getAdTask(); + ADTask stoppedAdTask = endTaskProfile.getTask(); assertEquals(taskId, stoppedAdTask.getTaskId()); if (retryCount < MAX_RETRY_TIMES) { // It's possible that historical analysis still running after max retry times @@ -118,14 +120,14 @@ private List startHistoricalAnalysis(int categoryFieldSize, String resul // get task profile ADTaskProfile adTaskProfile = waitUntilGetTaskProfile(detectorId); if (categoryFieldSize > 0) { - if (!TaskState.RUNNING.name().equals(adTaskProfile.getAdTask().getState())) { + if (!TaskState.RUNNING.name().equals(adTaskProfile.getTask().getState())) { adTaskProfile = (ADTaskProfile) waitUntilTaskReachState(detectorId, ImmutableSet.of(TaskState.RUNNING.name())).get(0); } assertEquals((int) Math.pow(categoryFieldDocCount, categoryFieldSize), adTaskProfile.getTotalEntitiesCount().intValue()); assertTrue(adTaskProfile.getPendingEntitiesCount() > 0); assertTrue(adTaskProfile.getRunningEntitiesCount() > 0); } - ADTask adTask = adTaskProfile.getAdTask(); + ADTask adTask = adTaskProfile.getTask(); assertEquals(taskId, adTask.getTaskId()); assertTrue(TestHelpers.HISTORICAL_ANALYSIS_RUNNING_STATS.contains(adTask.getState())); @@ -133,7 +135,7 @@ private List startHistoricalAnalysis(int categoryFieldSize, String resul Response statsResponse = TestHelpers.makeRequest(client(), "GET", AD_BASE_STATS_URI, ImmutableMap.of(), "", null); String statsResult = EntityUtils.toString(statsResponse.getEntity()); Map stringObjectMap = TestHelpers.parseStatsResult(statsResult); - String detectorCountState = categoryFieldSize > 0 ? MULTI_ENTITY_DETECTOR_COUNT.getName() : SINGLE_ENTITY_DETECTOR_COUNT.getName(); + String detectorCountState = categoryFieldSize > 0 ? HC_DETECTOR_COUNT.getName() : SINGLE_STREAM_DETECTOR_COUNT.getName(); assertTrue((long) stringObjectMap.get(detectorCountState) > 0); Map nodes = (Map) stringObjectMap.get("nodes"); long totalBatchTaskExecution = 0; @@ -317,7 +319,11 @@ private AnomalyDetector randomAnomalyDetector(AnomalyDetector detector) { detector.getCategoryFields(), detector.getUser(), detector.getCustomResultIndex(), - detector.getImputationOption() + detector.getImputationOption(), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ); } diff --git a/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java b/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java index 3d1aeab7d..cf62c63c8 100644 --- a/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java +++ b/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java @@ -38,6 +38,7 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.settings.TimeSeriesSettings; import com.google.common.collect.ImmutableList; @@ -84,8 +85,9 @@ public static String generatePassword(String username) { @Before public void setupSecureTests() throws IOException { - if (!isHttps()) + if (!isHttps()) { throw new IllegalArgumentException("Secure Tests are running but HTTPS is not set"); + } createIndexRole(indexAllAccessRole, "*"); createSearchRole(indexSearchAccessRole, "*"); String alicePassword = generatePassword(aliceUser); @@ -266,7 +268,11 @@ public void testUpdateApiFilterByEnabledForAdmin() throws IOException { ImmutableList.of(randomAlphaOfLength(5)) ), null, - aliceDetector.getImputationOption() + aliceDetector.getImputationOption(), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ); // User client has admin all access, and has "opensearch" backend role so client should be able to update detector // But the detector's backend role should not be replaced as client's backend roles (all_access). @@ -313,7 +319,11 @@ public void testUpdateApiFilterByEnabled() throws IOException { ImmutableList.of(randomAlphaOfLength(5)) ), null, - aliceDetector.getImputationOption() + aliceDetector.getImputationOption(), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ); enableFilterBy(); // User Fish has AD full access, and has "odfe" backend role which is one of Alice's backend role, so diff --git a/src/test/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandlerTests.java b/src/test/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandlerTests.java index 3bb0f1fbb..ae18aa6c2 100644 --- a/src/test/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandlerTests.java +++ b/src/test/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandlerTests.java @@ -21,7 +21,6 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.action.DocWriteResponse.Result.CREATED; -import static org.opensearch.timeseries.constant.CommonMessages.CAN_NOT_FIND_LATEST_TASK; import java.io.IOException; import java.util.Arrays; @@ -36,19 +35,18 @@ import org.opensearch.ad.ExecuteADResultResponseRecorder; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.mock.model.MockSimpleLog; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyResult; import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.ADProfileAction; import org.opensearch.ad.transport.AnomalyResultAction; import org.opensearch.ad.transport.AnomalyResultResponse; -import org.opensearch.ad.transport.ProfileAction; -import org.opensearch.ad.transport.ProfileResponse; -import org.opensearch.ad.transport.handler.AnomalyIndexHandler; import org.opensearch.client.Client; -import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.search.aggregations.AggregationBuilder; @@ -58,8 +56,11 @@ import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.common.exception.InternalFailure; import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.model.Feature; import org.opensearch.timeseries.transport.JobResponse; +import org.opensearch.timeseries.transport.ProfileResponse; +import org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import org.opensearch.transport.TransportService; @@ -69,12 +70,8 @@ public class IndexAnomalyDetectorJobActionHandlerTests extends OpenSearchTestCas private static ADIndexManagement anomalyDetectionIndices; private static String detectorId; - private static Long seqNo; - private static Long primaryTerm; private static NamedXContentRegistry xContentRegistry; - private static TransportService transportService; - private static TimeValue requestTimeout; private static DiscoveryNodeFilterer nodeFilter; private static AnomalyDetector detector; @@ -84,22 +81,20 @@ public class IndexAnomalyDetectorJobActionHandlerTests extends OpenSearchTestCas private ExecuteADResultResponseRecorder recorder; private Client client; - private IndexAnomalyDetectorJobActionHandler handler; - private AnomalyIndexHandler anomalyResultHandler; + private ADIndexJobActionHandler handler; + private ResultBulkIndexingHandler anomalyResultHandler; private NodeStateManager nodeStateManager; private ADTaskCacheManager adTaskCacheManager; + private TransportService transportService; @BeforeClass public static void setOnce() throws IOException { detectorId = "123"; - seqNo = 1L; - primaryTerm = 2L; anomalyDetectionIndices = mock(ADIndexManagement.class); xContentRegistry = NamedXContentRegistry.EMPTY; - transportService = mock(TransportService.class); - - requestTimeout = TimeValue.timeValueMinutes(60); when(anomalyDetectionIndices.doesJobIndexExist()).thenReturn(true); + // make sure getAndExecuteOnLatestConfigLevelTask called in startConfig + when(anomalyDetectionIndices.doesStateIndexExist()).thenReturn(true); nodeFilter = mock(DiscoveryNodeFilterer.class); detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList("a")); @@ -137,7 +132,7 @@ public void setUp() throws Exception { Object[] args = invocation.getArguments(); ActionListener listener = (ActionListener) args[2]; - AnomalyResultResponse response = new AnomalyResultResponse(null, "", 0L, 10L, true); + AnomalyResultResponse response = new AnomalyResultResponse(null, "", 0L, 10L, true, null); listener.onResponse(response); return null; @@ -152,11 +147,11 @@ public void setUp() throws Exception { listener.onResponse(response); return null; - }).when(adTaskManager).startDetector(any(), any(), any(), any(), any()); + }).when(adTaskManager).getAndExecuteOnLatestConfigLevelTask(any(), any(), eq(false), any(), any(), any()); threadPool = mock(ThreadPool.class); - anomalyResultHandler = mock(AnomalyIndexHandler.class); + anomalyResultHandler = mock(ResultBulkIndexingHandler.class); nodeStateManager = mock(NodeStateManager.class); @@ -175,18 +170,17 @@ public void setUp() throws Exception { 32 ); - handler = new IndexAnomalyDetectorJobActionHandler( + handler = new ADIndexJobActionHandler( client, anomalyDetectionIndices, - detectorId, - seqNo, - primaryTerm, - requestTimeout, xContentRegistry, - transportService, adTaskManager, - recorder + recorder, + nodeStateManager, + Settings.EMPTY ); + + transportService = mock(TransportService.class); } @SuppressWarnings("unchecked") @@ -195,11 +189,11 @@ public void testDelayHCProfile() { ActionListener listener = mock(ActionListener.class); - handler.startAnomalyDetectorJob(detector, listener); + handler.startJob(detector, transportService, listener); verify(client, times(1)).get(any(), any()); verify(client, times(1)).execute(any(), any(), any()); - verify(adTaskManager, times(1)).startDetector(any(), any(), any(), any(), any()); + verify(adTaskManager, times(1)).getAndExecuteOnLatestConfigLevelTask(any(), any(), eq(false), any(), any(), any()); verify(adTaskManager, times(1)).isHCRealtimeTaskStartInitializing(anyString()); verify(threadPool, times(1)).schedule(any(), any(), any()); verify(listener, times(1)).onResponse(any()); @@ -216,17 +210,17 @@ public void testNoDelayHCProfile() { listener.onResponse(response); return null; - }).when(client).execute(any(ProfileAction.class), any(), any()); + }).when(client).execute(any(ADProfileAction.class), any(), any()); when(adTaskManager.isHCRealtimeTaskStartInitializing(anyString())).thenReturn(true); ActionListener listener = mock(ActionListener.class); - handler.startAnomalyDetectorJob(detector, listener); + handler.startJob(detector, transportService, listener); verify(client, times(1)).get(any(), any()); verify(client, times(2)).execute(any(), any(), any()); - verify(adTaskManager, times(1)).startDetector(any(), any(), any(), any(), any()); + verify(adTaskManager, times(1)).getAndExecuteOnLatestConfigLevelTask(any(), any(), eq(false), any(), any(), any()); verify(adTaskManager, times(1)).isHCRealtimeTaskStartInitializing(anyString()); verify(adTaskManager, times(1)).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); verify(threadPool, never()).schedule(any(), any(), any()); @@ -242,17 +236,17 @@ public void testHCProfileException() { listener.onFailure(new RuntimeException()); return null; - }).when(client).execute(any(ProfileAction.class), any(), any()); + }).when(client).execute(any(ADProfileAction.class), any(), any()); when(adTaskManager.isHCRealtimeTaskStartInitializing(anyString())).thenReturn(true); ActionListener listener = mock(ActionListener.class); - handler.startAnomalyDetectorJob(detector, listener); + handler.startJob(detector, transportService, listener); verify(client, times(1)).get(any(), any()); verify(client, times(2)).execute(any(), any(), any()); - verify(adTaskManager, times(1)).startDetector(any(), any(), any(), any(), any()); + verify(adTaskManager, times(1)).getAndExecuteOnLatestConfigLevelTask(any(), any(), eq(false), any(), any(), any()); verify(adTaskManager, times(1)).isHCRealtimeTaskStartInitializing(anyString()); verify(adTaskManager, never()).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); verify(threadPool, never()).schedule(any(), any(), any()); @@ -270,7 +264,7 @@ public void testUpdateLatestRealtimeTaskOnCoordinatingNodeResourceNotFoundExcept listener.onResponse(response); return null; - }).when(client).execute(any(ProfileAction.class), any(), any()); + }).when(client).execute(any(ADProfileAction.class), any(), any()); when(adTaskManager.isHCRealtimeTaskStartInitializing(anyString())).thenReturn(true); @@ -278,18 +272,18 @@ public void testUpdateLatestRealtimeTaskOnCoordinatingNodeResourceNotFoundExcept Object[] args = invocation.getArguments(); ActionListener listener = (ActionListener) args[5]; - listener.onFailure(new ResourceNotFoundException(CAN_NOT_FIND_LATEST_TASK)); + listener.onFailure(new ResourceNotFoundException(CommonMessages.CAN_NOT_FIND_LATEST_TASK)); return null; }).when(adTaskManager).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); ActionListener listener = mock(ActionListener.class); - handler.startAnomalyDetectorJob(detector, listener); + handler.startJob(detector, transportService, listener); verify(client, times(1)).get(any(), any()); verify(client, times(2)).execute(any(), any(), any()); - verify(adTaskManager, times(1)).startDetector(any(), any(), any(), any(), any()); + verify(adTaskManager, times(1)).getAndExecuteOnLatestConfigLevelTask(any(), any(), eq(false), any(), any(), any()); verify(adTaskManager, times(1)).isHCRealtimeTaskStartInitializing(anyString()); verify(adTaskManager, times(1)).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); verify(adTaskManager, times(1)).removeRealtimeTaskCache(anyString()); @@ -308,7 +302,7 @@ public void testUpdateLatestRealtimeTaskOnCoordinatingException() { listener.onResponse(response); return null; - }).when(client).execute(any(ProfileAction.class), any(), any()); + }).when(client).execute(any(ADProfileAction.class), any(), any()); when(adTaskManager.isHCRealtimeTaskStartInitializing(anyString())).thenReturn(true); @@ -323,15 +317,15 @@ public void testUpdateLatestRealtimeTaskOnCoordinatingException() { ActionListener listener = mock(ActionListener.class); - handler.startAnomalyDetectorJob(detector, listener); + handler.startJob(detector, transportService, listener); verify(client, times(1)).get(any(), any()); verify(client, times(2)).execute(any(), any(), any()); - verify(adTaskManager, times(1)).startDetector(any(), any(), any(), any(), any()); + verify(adTaskManager, times(1)).getAndExecuteOnLatestConfigLevelTask(any(), any(), eq(false), any(), any(), any()); verify(adTaskManager, times(1)).isHCRealtimeTaskStartInitializing(anyString()); verify(adTaskManager, times(1)).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); verify(adTaskManager, never()).removeRealtimeTaskCache(anyString()); - verify(adTaskManager, times(1)).skipUpdateHCRealtimeTask(anyString(), anyString()); + verify(adTaskManager, times(1)).skipUpdateRealtimeTask(anyString(), anyString()); verify(threadPool, never()).schedule(any(), any(), any()); verify(listener, times(1)).onResponse(any()); } @@ -361,7 +355,7 @@ public void testIndexException() throws IOException { ADCommonName.CUSTOM_RESULT_INDEX_PREFIX + "index" ); when(anomalyDetectionIndices.doesIndexExist(anyString())).thenReturn(false); - handler.startAnomalyDetectorJob(detector, listener); + handler.startJob(detector, transportService, listener); verify(anomalyResultHandler, times(1)).index(any(), any(), eq(null)); verify(threadPool, times(1)).schedule(any(), any(), any()); } diff --git a/src/test/java/org/opensearch/ad/settings/ADEnabledSettingTests.java b/src/test/java/org/opensearch/ad/settings/ADEnabledSettingTests.java index 6de90a068..5e574e77d 100644 --- a/src/test/java/org/opensearch/ad/settings/ADEnabledSettingTests.java +++ b/src/test/java/org/opensearch/ad/settings/ADEnabledSettingTests.java @@ -19,6 +19,7 @@ import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.settings.TimeSeriesEnabledSetting; public class ADEnabledSettingTests extends OpenSearchTestCase { @@ -30,9 +31,9 @@ public void testIsADEnabled() { } public void testIsADBreakerEnabled() { - assertTrue(ADEnabledSetting.isADBreakerEnabled()); + assertTrue(TimeSeriesEnabledSetting.isBreakerEnabled()); ADEnabledSetting.getInstance().setSettingValue(ADEnabledSetting.AD_BREAKER_ENABLED, false); - assertTrue(!ADEnabledSetting.isADBreakerEnabled()); + assertTrue(!TimeSeriesEnabledSetting.isBreakerEnabled()); } public void testIsInterpolationInColdStartEnabled() { diff --git a/src/test/java/org/opensearch/ad/settings/AnomalyDetectorSettingsTests.java b/src/test/java/org/opensearch/ad/settings/AnomalyDetectorSettingsTests.java index 085ea5959..46cd0e619 100644 --- a/src/test/java/org/opensearch/ad/settings/AnomalyDetectorSettingsTests.java +++ b/src/test/java/org/opensearch/ad/settings/AnomalyDetectorSettingsTests.java @@ -155,7 +155,7 @@ public void testAllLegacyOpenDistroSettingsFallback() { LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD.get(Settings.EMPTY) ); assertEquals( - TimeSeriesSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE.get(Settings.EMPTY), + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_UNRESPONSIVE_NODE.get(Settings.EMPTY), LegacyOpenDistroAnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE.get(Settings.EMPTY) ); assertEquals( @@ -163,7 +163,7 @@ public void testAllLegacyOpenDistroSettingsFallback() { LegacyOpenDistroAnomalyDetectorSettings.COOLDOWN_MINUTES.get(Settings.EMPTY) ); assertEquals( - TimeSeriesSettings.BACKOFF_MINUTES.get(Settings.EMPTY), + AnomalyDetectorSettings.AD_BACKOFF_MINUTES.get(Settings.EMPTY), LegacyOpenDistroAnomalyDetectorSettings.BACKOFF_MINUTES.get(Settings.EMPTY) ); assertEquals( diff --git a/src/test/java/org/opensearch/ad/stats/ADStatTests.java b/src/test/java/org/opensearch/ad/stats/ADStatTests.java index 1912f92ad..7ec161f1b 100644 --- a/src/test/java/org/opensearch/ad/stats/ADStatTests.java +++ b/src/test/java/org/opensearch/ad/stats/ADStatTests.java @@ -14,32 +14,33 @@ import java.util.function.Supplier; import org.junit.Test; -import org.opensearch.ad.stats.suppliers.CounterSupplier; -import org.opensearch.ad.stats.suppliers.SettableSupplier; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.stats.TimeSeriesStat; +import org.opensearch.timeseries.stats.suppliers.CounterSupplier; +import org.opensearch.timeseries.stats.suppliers.SettableSupplier; public class ADStatTests extends OpenSearchTestCase { @Test public void testIsClusterLevel() { - ADStat stat1 = new ADStat<>(true, new TestSupplier()); + TimeSeriesStat stat1 = new TimeSeriesStat<>(true, new TestSupplier()); assertTrue("isCluster returns the wrong value", stat1.isClusterLevel()); - ADStat stat2 = new ADStat<>(false, new TestSupplier()); + TimeSeriesStat stat2 = new TimeSeriesStat<>(false, new TestSupplier()); assertTrue("isCluster returns the wrong value", !stat2.isClusterLevel()); } @Test public void testGetValue() { - ADStat stat1 = new ADStat<>(false, new CounterSupplier()); + TimeSeriesStat stat1 = new TimeSeriesStat<>(false, new CounterSupplier()); assertEquals("GetValue returns the incorrect value", 0L, (long) (stat1.getValue())); - ADStat stat2 = new ADStat<>(false, new TestSupplier()); + TimeSeriesStat stat2 = new TimeSeriesStat<>(false, new TestSupplier()); assertEquals("GetValue returns the incorrect value", "test", stat2.getValue()); } @Test public void testSetValue() { - ADStat stat = new ADStat<>(false, new SettableSupplier()); + TimeSeriesStat stat = new TimeSeriesStat<>(false, new SettableSupplier()); assertEquals("GetValue returns the incorrect value", 0L, (long) (stat.getValue())); stat.setValue(10L); assertEquals("GetValue returns the incorrect value", 10L, (long) stat.getValue()); @@ -47,7 +48,7 @@ public void testSetValue() { @Test public void testIncrement() { - ADStat incrementStat = new ADStat<>(false, new CounterSupplier()); + TimeSeriesStat incrementStat = new TimeSeriesStat<>(false, new CounterSupplier()); for (Long i = 0L; i < 100; i++) { assertEquals("increment does not work", i, incrementStat.getValue()); @@ -55,7 +56,7 @@ public void testIncrement() { } // Ensure that no problems occur for a stat that cannot be incremented - ADStat nonIncStat = new ADStat<>(false, new TestSupplier()); + TimeSeriesStat nonIncStat = new TimeSeriesStat<>(false, new TestSupplier()); nonIncStat.increment(); } diff --git a/src/test/java/org/opensearch/ad/stats/ADStatsResponseTests.java b/src/test/java/org/opensearch/ad/stats/ADStatsResponseTests.java index 194623bd5..b47bd8a0d 100644 --- a/src/test/java/org/opensearch/ad/stats/ADStatsResponseTests.java +++ b/src/test/java/org/opensearch/ad/stats/ADStatsResponseTests.java @@ -19,18 +19,19 @@ import org.junit.Test; import org.opensearch.action.FailedNodeException; -import org.opensearch.ad.transport.ADStatsNodeResponse; -import org.opensearch.ad.transport.ADStatsNodesResponse; import org.opensearch.cluster.ClusterName; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.transport.StatsNodeResponse; +import org.opensearch.timeseries.transport.StatsNodesResponse; +import org.opensearch.timeseries.transport.StatsResponse; public class ADStatsResponseTests extends OpenSearchTestCase { @Test public void testGetAndSetClusterStats() { - ADStatsResponse adStatsResponse = new ADStatsResponse(); + StatsResponse adStatsResponse = new StatsResponse(); Map testClusterStats = new HashMap<>(); testClusterStats.put("test_stat", 1L); adStatsResponse.setClusterStats(testClusterStats); @@ -39,53 +40,53 @@ public void testGetAndSetClusterStats() { @Test public void testGetAndSetADStatsNodesResponse() { - ADStatsResponse adStatsResponse = new ADStatsResponse(); - List responses = Collections.emptyList(); + StatsResponse adStatsResponse = new StatsResponse(); + List responses = Collections.emptyList(); List failures = Collections.emptyList(); - ADStatsNodesResponse adStatsNodesResponse = new ADStatsNodesResponse(ClusterName.DEFAULT, responses, failures); - adStatsResponse.setADStatsNodesResponse(adStatsNodesResponse); - assertEquals(adStatsNodesResponse, adStatsResponse.getADStatsNodesResponse()); + StatsNodesResponse adStatsNodesResponse = new StatsNodesResponse(ClusterName.DEFAULT, responses, failures); + adStatsResponse.setStatsNodesResponse(adStatsNodesResponse); + assertEquals(adStatsNodesResponse, adStatsResponse.getStatsNodesResponse()); } @Test public void testMerge() { - ADStatsResponse adStatsResponse1 = new ADStatsResponse(); + StatsResponse adStatsResponse1 = new StatsResponse(); Map testClusterStats = new HashMap<>(); testClusterStats.put("test_stat", 1L); adStatsResponse1.setClusterStats(testClusterStats); - ADStatsResponse adStatsResponse2 = new ADStatsResponse(); - List responses = Collections.emptyList(); + StatsResponse adStatsResponse2 = new StatsResponse(); + List responses = Collections.emptyList(); List failures = Collections.emptyList(); - ADStatsNodesResponse adStatsNodesResponse = new ADStatsNodesResponse(ClusterName.DEFAULT, responses, failures); - adStatsResponse2.setADStatsNodesResponse(adStatsNodesResponse); + StatsNodesResponse adStatsNodesResponse = new StatsNodesResponse(ClusterName.DEFAULT, responses, failures); + adStatsResponse2.setStatsNodesResponse(adStatsNodesResponse); adStatsResponse1.merge(adStatsResponse2); assertEquals(testClusterStats, adStatsResponse1.getClusterStats()); - assertEquals(adStatsNodesResponse, adStatsResponse1.getADStatsNodesResponse()); + assertEquals(adStatsNodesResponse, adStatsResponse1.getStatsNodesResponse()); adStatsResponse2.merge(adStatsResponse1); assertEquals(testClusterStats, adStatsResponse2.getClusterStats()); - assertEquals(adStatsNodesResponse, adStatsResponse2.getADStatsNodesResponse()); + assertEquals(adStatsNodesResponse, adStatsResponse2.getStatsNodesResponse()); // Confirm merging with null does nothing adStatsResponse1.merge(null); assertEquals(testClusterStats, adStatsResponse1.getClusterStats()); - assertEquals(adStatsNodesResponse, adStatsResponse1.getADStatsNodesResponse()); + assertEquals(adStatsNodesResponse, adStatsResponse1.getStatsNodesResponse()); // Confirm merging with self does nothing adStatsResponse1.merge(adStatsResponse1); assertEquals(testClusterStats, adStatsResponse1.getClusterStats()); - assertEquals(adStatsNodesResponse, adStatsResponse1.getADStatsNodesResponse()); + assertEquals(adStatsNodesResponse, adStatsResponse1.getStatsNodesResponse()); } @Test public void testEquals() { - ADStatsResponse adStatsResponse1 = new ADStatsResponse(); + StatsResponse adStatsResponse1 = new StatsResponse(); assertEquals(adStatsResponse1, adStatsResponse1); assertNotEquals(null, adStatsResponse1); assertNotEquals(1, adStatsResponse1); - ADStatsResponse adStatsResponse2 = new ADStatsResponse(); + StatsResponse adStatsResponse2 = new StatsResponse(); assertEquals(adStatsResponse1, adStatsResponse2); Map testClusterStats = new HashMap<>(); testClusterStats.put("test_stat", 1L); @@ -95,8 +96,8 @@ public void testEquals() { @Test public void testHashCode() { - ADStatsResponse adStatsResponse1 = new ADStatsResponse(); - ADStatsResponse adStatsResponse2 = new ADStatsResponse(); + StatsResponse adStatsResponse1 = new StatsResponse(); + StatsResponse adStatsResponse2 = new StatsResponse(); assertEquals(adStatsResponse1.hashCode(), adStatsResponse2.hashCode()); Map testClusterStats = new HashMap<>(); testClusterStats.put("test_stat", 1L); @@ -106,14 +107,14 @@ public void testHashCode() { @Test public void testToXContent() throws IOException { - ADStatsResponse adStatsResponse = new ADStatsResponse(); + StatsResponse adStatsResponse = new StatsResponse(); Map testClusterStats = new HashMap<>(); testClusterStats.put("test_stat", 1); adStatsResponse.setClusterStats(testClusterStats); - List responses = Collections.emptyList(); + List responses = Collections.emptyList(); List failures = Collections.emptyList(); - ADStatsNodesResponse adStatsNodesResponse = new ADStatsNodesResponse(ClusterName.DEFAULT, responses, failures); - adStatsResponse.setADStatsNodesResponse(adStatsNodesResponse); + StatsNodesResponse adStatsNodesResponse = new StatsNodesResponse(ClusterName.DEFAULT, responses, failures); + adStatsResponse.setStatsNodesResponse(adStatsNodesResponse); XContentBuilder builder = XContentFactory.jsonBuilder(); adStatsResponse.toXContent(builder); diff --git a/src/test/java/org/opensearch/ad/stats/ADStatsTests.java b/src/test/java/org/opensearch/ad/stats/ADStatsTests.java index 6db1ac5cc..5c85ae606 100644 --- a/src/test/java/org/opensearch/ad/stats/ADStatsTests.java +++ b/src/test/java/org/opensearch/ad/stats/ADStatsTests.java @@ -17,6 +17,7 @@ import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_MODEL_SIZE_PER_NODE; import java.time.Clock; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -24,36 +25,39 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; -import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.ml.HybridThresholdingModel; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.ml.ModelState; -import org.opensearch.ad.stats.suppliers.CounterSupplier; -import org.opensearch.ad.stats.suppliers.IndexStatusSupplier; -import org.opensearch.ad.stats.suppliers.ModelsOnNodeSupplier; -import org.opensearch.ad.util.IndexUtils; +import org.opensearch.ad.stats.suppliers.ADModelsOnNodeSupplier; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.TimeSeriesStat; +import org.opensearch.timeseries.stats.suppliers.CounterSupplier; +import org.opensearch.timeseries.stats.suppliers.IndexStatusSupplier; +import org.opensearch.timeseries.util.IndexUtils; import com.amazon.randomcutforest.RandomCutForest; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; import test.org.opensearch.ad.util.MLUtil; import test.org.opensearch.ad.util.RandomModelStateConfig; public class ADStatsTests extends OpenSearchTestCase { - private Map> statsMap; + private Map> statsMap; private ADStats adStats; private RandomCutForest rcf; private HybridThresholdingModel thresholdingModel; @@ -64,10 +68,10 @@ public class ADStatsTests extends OpenSearchTestCase { private Clock clock; @Mock - private ModelManager modelManager; + private ADModelManager modelManager; @Mock - private CacheProvider cacheProvider; + private ADCacheProvider cacheProvider; @Before public void setup() { @@ -80,20 +84,58 @@ public void setup() { List> modelsInformation = new ArrayList<>( Arrays .asList( - new ModelState<>(rcf, "rcf-model-1", "detector-1", ModelManager.ModelType.RCF.getName(), clock, 0f), - new ModelState<>(thresholdingModel, "thr-model-1", "detector-1", ModelManager.ModelType.RCF.getName(), clock, 0f), - new ModelState<>(rcf, "rcf-model-2", "detector-2", ModelManager.ModelType.THRESHOLD.getName(), clock, 0f), - new ModelState<>(thresholdingModel, "thr-model-2", "detector-2", ModelManager.ModelType.THRESHOLD.getName(), clock, 0f) + new ModelState<>( + rcf, + "rcf-model-1", + "detector-1", + ModelManager.ModelType.RCF.getName(), + clock, + 0f, + Optional.empty(), + new ArrayDeque<>() + ), + new ModelState<>( + thresholdingModel, + "thr-model-1", + "detector-1", + ModelManager.ModelType.RCF.getName(), + clock, + 0f, + Optional.empty(), + new ArrayDeque<>() + ), + new ModelState<>( + rcf, + "rcf-model-2", + "detector-2", + ModelManager.ModelType.THRESHOLD.getName(), + clock, + 0f, + Optional.empty(), + new ArrayDeque<>() + ), + new ModelState<>( + thresholdingModel, + "thr-model-2", + "detector-2", + ModelManager.ModelType.THRESHOLD.getName(), + clock, + 0f, + Optional.empty(), + new ArrayDeque<>() + ) ) ); when(modelManager.getAllModels()).thenReturn(modelsInformation); - ModelState entityModel1 = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); - ModelState entityModel2 = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ModelState entityModel1 = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ModelState entityModel2 = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); - List> entityModelsInformation = new ArrayList<>(Arrays.asList(entityModel1, entityModel2)); - EntityCache cache = mock(EntityCache.class); + List> entityModelsInformation = new ArrayList<>(Arrays.asList(entityModel1, entityModel2)); + ADPriorityCache cache = mock(ADPriorityCache.class); when(cacheProvider.get()).thenReturn(cache); when(cache.getAllModels()).thenReturn(entityModelsInformation); @@ -115,12 +157,15 @@ public void setup() { ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - statsMap = new HashMap>() { + statsMap = new HashMap>() { { - put(nodeStatName1, new ADStat<>(false, new CounterSupplier())); - put(nodeStatName2, new ADStat<>(false, new ModelsOnNodeSupplier(modelManager, cacheProvider, settings, clusterService))); - put(clusterStatName1, new ADStat<>(true, new IndexStatusSupplier(indexUtils, "index1"))); - put(clusterStatName2, new ADStat<>(true, new IndexStatusSupplier(indexUtils, "index2"))); + put(nodeStatName1, new TimeSeriesStat<>(false, new CounterSupplier())); + put( + nodeStatName2, + new TimeSeriesStat<>(false, new ADModelsOnNodeSupplier(modelManager, cacheProvider, settings, clusterService)) + ); + put(clusterStatName1, new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, "index1"))); + put(clusterStatName2, new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, "index2"))); } }; @@ -134,11 +179,11 @@ public void testStatNamesGetNames() { @Test public void testGetStats() { - Map> stats = adStats.getStats(); + Map> stats = adStats.getStats(); assertEquals("getStats returns the incorrect number of stats", stats.size(), statsMap.size()); - for (Map.Entry> stat : stats.entrySet()) { + for (Map.Entry> stat : stats.entrySet()) { assertTrue( "getStats returns incorrect stats", adStats.getStats().containsKey(stat.getKey()) && adStats.getStats().get(stat.getKey()) == stat.getValue() @@ -148,7 +193,7 @@ public void testGetStats() { @Test public void testGetStat() { - ADStat stat = adStats.getStat(clusterStatName1); + TimeSeriesStat stat = adStats.getStat(clusterStatName1); assertTrue( "getStat returns incorrect stat", @@ -158,10 +203,10 @@ public void testGetStat() { @Test public void testGetNodeStats() { - Map> stats = adStats.getStats(); - Set> nodeStats = new HashSet<>(adStats.getNodeStats().values()); + Map> stats = adStats.getStats(); + Set> nodeStats = new HashSet<>(adStats.getNodeStats().values()); - for (ADStat stat : stats.values()) { + for (TimeSeriesStat stat : stats.values()) { assertTrue( "getNodeStats returns incorrect stat", (stat.isClusterLevel() && !nodeStats.contains(stat)) || (!stat.isClusterLevel() && nodeStats.contains(stat)) @@ -171,10 +216,10 @@ public void testGetNodeStats() { @Test public void testGetClusterStats() { - Map> stats = adStats.getStats(); - Set> clusterStats = new HashSet<>(adStats.getClusterStats().values()); + Map> stats = adStats.getStats(); + Set> clusterStats = new HashSet<>(adStats.getClusterStats().values()); - for (ADStat stat : stats.values()) { + for (TimeSeriesStat stat : stats.values()) { assertTrue( "getClusterStats returns incorrect stat", (stat.isClusterLevel() && clusterStats.contains(stat)) || (!stat.isClusterLevel() && !clusterStats.contains(stat)) diff --git a/src/test/java/org/opensearch/ad/stats/suppliers/CounterSupplierTests.java b/src/test/java/org/opensearch/ad/stats/suppliers/CounterSupplierTests.java index 333d50ffe..3490e0318 100644 --- a/src/test/java/org/opensearch/ad/stats/suppliers/CounterSupplierTests.java +++ b/src/test/java/org/opensearch/ad/stats/suppliers/CounterSupplierTests.java @@ -13,6 +13,7 @@ import org.junit.Test; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.stats.suppliers.CounterSupplier; public class CounterSupplierTests extends OpenSearchTestCase { @Test diff --git a/src/test/java/org/opensearch/ad/stats/suppliers/IndexSupplierTests.java b/src/test/java/org/opensearch/ad/stats/suppliers/IndexSupplierTests.java index cfdf71188..409437490 100644 --- a/src/test/java/org/opensearch/ad/stats/suppliers/IndexSupplierTests.java +++ b/src/test/java/org/opensearch/ad/stats/suppliers/IndexSupplierTests.java @@ -16,8 +16,9 @@ import org.junit.Before; import org.junit.Test; -import org.opensearch.ad.util.IndexUtils; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.stats.suppliers.IndexStatusSupplier; +import org.opensearch.timeseries.util.IndexUtils; public class IndexSupplierTests extends OpenSearchTestCase { private IndexUtils indexUtils; diff --git a/src/test/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplierTests.java b/src/test/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplierTests.java index 21a9e4aff..c7e17be3b 100644 --- a/src/test/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplierTests.java +++ b/src/test/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplierTests.java @@ -14,15 +14,17 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_MODEL_SIZE_PER_NODE; -import static org.opensearch.ad.stats.suppliers.ModelsOnNodeSupplier.MODEL_STATE_STAT_KEYS; +import static org.opensearch.ad.stats.suppliers.ADModelsOnNodeSupplier.MODEL_STATE_STAT_KEYS; import java.time.Clock; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -30,18 +32,19 @@ import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; -import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.ml.HybridThresholdingModel; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.ml.ModelState; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; import com.amazon.randomcutforest.RandomCutForest; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; import test.org.opensearch.ad.util.MLUtil; import test.org.opensearch.ad.util.RandomModelStateConfig; @@ -51,13 +54,13 @@ public class ModelsOnNodeSupplierTests extends OpenSearchTestCase { private HybridThresholdingModel thresholdingModel; private List> expectedResults; private Clock clock; - private List> entityModelsInformation; + private List> entityModelsInformation; @Mock - private ModelManager modelManager; + private ADModelManager modelManager; @Mock - private CacheProvider cacheProvider; + private ADCacheProvider cacheProvider; @Before public void setup() { @@ -70,20 +73,58 @@ public void setup() { expectedResults = new ArrayList<>( Arrays .asList( - new ModelState<>(rcf, "rcf-model-1", "detector-1", ModelManager.ModelType.RCF.getName(), clock, 0f), - new ModelState<>(thresholdingModel, "thr-model-1", "detector-1", ModelManager.ModelType.RCF.getName(), clock, 0f), - new ModelState<>(rcf, "rcf-model-2", "detector-2", ModelManager.ModelType.THRESHOLD.getName(), clock, 0f), - new ModelState<>(thresholdingModel, "thr-model-2", "detector-2", ModelManager.ModelType.THRESHOLD.getName(), clock, 0f) + new ModelState<>( + rcf, + "rcf-model-1", + "detector-1", + ModelManager.ModelType.RCF.getName(), + clock, + 0f, + Optional.empty(), + new ArrayDeque<>() + ), + new ModelState<>( + thresholdingModel, + "thr-model-1", + "detector-1", + ModelManager.ModelType.RCF.getName(), + clock, + 0f, + Optional.empty(), + new ArrayDeque<>() + ), + new ModelState<>( + rcf, + "rcf-model-2", + "detector-2", + ModelManager.ModelType.THRESHOLD.getName(), + clock, + 0f, + Optional.empty(), + new ArrayDeque<>() + ), + new ModelState<>( + thresholdingModel, + "thr-model-2", + "detector-2", + ModelManager.ModelType.THRESHOLD.getName(), + clock, + 0f, + Optional.empty(), + new ArrayDeque<>() + ) ) ); when(modelManager.getAllModels()).thenReturn(expectedResults); - ModelState entityModel1 = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); - ModelState entityModel2 = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ModelState entityModel1 = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ModelState entityModel2 = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); entityModelsInformation = new ArrayList<>(Arrays.asList(entityModel1, entityModel2)); - EntityCache cache = mock(EntityCache.class); + ADPriorityCache cache = mock(ADPriorityCache.class); when(cacheProvider.get()).thenReturn(cache); when(cache.getAllModels()).thenReturn(entityModelsInformation); } @@ -98,7 +139,7 @@ public void testGet() { ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - ModelsOnNodeSupplier modelsOnNodeSupplier = new ModelsOnNodeSupplier(modelManager, cacheProvider, settings, clusterService); + ADModelsOnNodeSupplier modelsOnNodeSupplier = new ADModelsOnNodeSupplier(modelManager, cacheProvider, settings, clusterService); List> results = modelsOnNodeSupplier.get(); assertEquals( "get fails to return correct result", @@ -119,7 +160,7 @@ public void testGet() { @Test public void testGetModelCount() { - ModelsOnNodeCountSupplier modelsOnNodeSupplier = new ModelsOnNodeCountSupplier(modelManager, cacheProvider); + ADModelsOnNodeCountSupplier modelsOnNodeSupplier = new ADModelsOnNodeCountSupplier(modelManager, cacheProvider); assertEquals(6L, modelsOnNodeSupplier.get().longValue()); } } diff --git a/src/test/java/org/opensearch/ad/stats/suppliers/SettableSupplierTests.java b/src/test/java/org/opensearch/ad/stats/suppliers/SettableSupplierTests.java index 1cf1c9306..821871984 100644 --- a/src/test/java/org/opensearch/ad/stats/suppliers/SettableSupplierTests.java +++ b/src/test/java/org/opensearch/ad/stats/suppliers/SettableSupplierTests.java @@ -13,6 +13,7 @@ import org.junit.Test; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.stats.suppliers.SettableSupplier; public class SettableSupplierTests extends OpenSearchTestCase { @Test diff --git a/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java b/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java index ad14b49c4..ed43baa73 100644 --- a/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java +++ b/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java @@ -246,7 +246,7 @@ public void testTopEntityInited() throws IOException { assertTrue(adTaskCacheManager.topEntityInited(detectorId)); } - public void testEntityCache() throws IOException { + public void testADPriorityCache() throws IOException { String detectorId = randomAlphaOfLength(10); assertEquals(0, adTaskCacheManager.getPendingEntityCount(detectorId)); assertEquals(0, adTaskCacheManager.getRunningEntityCount(detectorId)); @@ -321,14 +321,14 @@ public void testRealtimeTaskCache() { adTaskCacheManager.updateRealtimeTaskCache(detectorId1, newState, newInitProgress, newError); assertFalse(adTaskCacheManager.isRealtimeTaskChangeNeeded(detectorId1, newState, newInitProgress, newError)); - assertArrayEquals(new String[] { detectorId1 }, adTaskCacheManager.getDetectorIdsInRealtimeTaskCache()); + assertArrayEquals(new String[] { detectorId1 }, adTaskCacheManager.getConfigIdsInRealtimeTaskCache()); String detectorId2 = randomAlphaOfLength(10); adTaskCacheManager.updateRealtimeTaskCache(detectorId2, newState, newInitProgress, newError); - assertEquals(1, adTaskCacheManager.getDetectorIdsInRealtimeTaskCache().length); + assertEquals(1, adTaskCacheManager.getConfigIdsInRealtimeTaskCache().length); adTaskCacheManager.initRealtimeTaskCache(detectorId2, 60_000); adTaskCacheManager.updateRealtimeTaskCache(detectorId2, newState, newInitProgress, newError); - assertEquals(2, adTaskCacheManager.getDetectorIdsInRealtimeTaskCache().length); + assertEquals(2, adTaskCacheManager.getConfigIdsInRealtimeTaskCache().length); newState = TaskState.RUNNING.name(); newInitProgress = 1.0f; @@ -340,10 +340,10 @@ public void testRealtimeTaskCache() { assertEquals(newError, adTaskCacheManager.getRealtimeTaskCache(detectorId1).getError()); adTaskCacheManager.removeRealtimeTaskCache(detectorId1); - assertArrayEquals(new String[] { detectorId2 }, adTaskCacheManager.getDetectorIdsInRealtimeTaskCache()); + assertArrayEquals(new String[] { detectorId2 }, adTaskCacheManager.getConfigIdsInRealtimeTaskCache()); adTaskCacheManager.clearRealtimeTaskCache(); - assertEquals(0, adTaskCacheManager.getDetectorIdsInRealtimeTaskCache().length); + assertEquals(0, adTaskCacheManager.getConfigIdsInRealtimeTaskCache().length); } diff --git a/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java b/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java index f9df58903..184964343 100644 --- a/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java +++ b/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java @@ -45,10 +45,10 @@ import static org.opensearch.timeseries.TestHelpers.randomIntervalSchedule; import static org.opensearch.timeseries.TestHelpers.randomIntervalTimeConfiguration; import static org.opensearch.timeseries.TestHelpers.randomUser; -import static org.opensearch.timeseries.constant.CommonMessages.CREATE_INDEX_NOT_ACKNOWLEDGED; import static org.opensearch.timeseries.model.Entity.createSingleAttributeEntity; import java.io.IOException; +import java.time.Clock; import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.ArrayList; @@ -79,8 +79,9 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.update.UpdateResponse; +import org.opensearch.ad.ADTaskProfileRunner; import org.opensearch.ad.ADUnitTestCase; -import org.opensearch.ad.cluster.HashRing; +import org.opensearch.ad.ExecuteADResultResponseRecorder; import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.mock.model.MockSimpleLog; import org.opensearch.ad.model.ADTask; @@ -88,10 +89,7 @@ import org.opensearch.ad.model.ADTaskProfile; import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.rest.handler.IndexAnomalyDetectorJobActionHandler; -import org.opensearch.ad.stats.InternalStatNames; -import org.opensearch.ad.transport.ADStatsNodeResponse; -import org.opensearch.ad.transport.ADStatsNodesResponse; +import org.opensearch.ad.rest.handler.ADIndexJobActionHandler; import org.opensearch.ad.transport.ADTaskProfileNodeResponse; import org.opensearch.ad.transport.ADTaskProfileResponse; import org.opensearch.ad.transport.ForwardADTaskRequest; @@ -109,6 +107,7 @@ import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.engine.VersionConflictEngineException; @@ -120,16 +119,26 @@ import org.opensearch.search.aggregations.InternalAggregations; import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.common.exception.DuplicateTaskException; +import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.DateRange; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.Job; import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.stats.InternalStatNames; import org.opensearch.timeseries.task.RealtimeTaskCache; import org.opensearch.timeseries.transport.JobResponse; +import org.opensearch.timeseries.transport.StatsNodeResponse; +import org.opensearch.timeseries.transport.StatsNodesResponse; +import org.opensearch.timeseries.util.ClientUtil; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import org.opensearch.transport.TransportResponseHandler; import org.opensearch.transport.TransportService; @@ -154,7 +163,7 @@ public class ADTaskManagerTests extends ADUnitTestCase { private TransportService transportService; private ADTaskManager adTaskManager; private ThreadPool threadPool; - private IndexAnomalyDetectorJobActionHandler indexAnomalyDetectorJobActionHandler; + private ADIndexJobActionHandler indexAnomalyDetectorJobActionHandler; private DateRange detectionDateRange; private ActionListener listener; @@ -202,6 +211,9 @@ public class ADTaskManagerTests extends ADUnitTestCase { @Captor ArgumentCaptor> remoteResponseHandler; + NodeStateManager nodeStateManager; + ADTaskProfileRunner taskProfileRunner; + @Override public void setUp() throws Exception { super.setUp(); @@ -240,7 +252,8 @@ public void setUp() throws Exception { threadContext = new ThreadContext(settings); when(threadPool.getThreadContext()).thenReturn(threadContext); when(client.threadPool()).thenReturn(threadPool); - indexAnomalyDetectorJobActionHandler = mock(IndexAnomalyDetectorJobActionHandler.class); + nodeStateManager = mock(NodeStateManager.class); + taskProfileRunner = mock(ADTaskProfileRunner.class); adTaskManager = spy( new ADTaskManager( settings, @@ -251,9 +264,20 @@ public void setUp() throws Exception { nodeFilter, hashRing, adTaskCacheManager, - threadPool + threadPool, + nodeStateManager, + taskProfileRunner ) ); + indexAnomalyDetectorJobActionHandler = new ADIndexJobActionHandler( + client, + detectionIndices, + mock(NamedXContentRegistry.class), + adTaskManager, + mock(ExecuteADResultResponseRecorder.class), + nodeStateManager, + Settings.EMPTY + ); listener = spy(new ActionListener() { @Override @@ -313,7 +337,7 @@ private void setupHashRingWithSameLocalADVersionNodes() { Consumer function = invocation.getArgument(0); function.accept(new DiscoveryNode[] { node1, node2 }); return null; - }).when(hashRing).getNodesWithSameLocalAdVersion(any(), any()); + }).when(hashRing).getNodesWithSameLocalVersion(any(), any()); } private void setupHashRingWithOwningNode() { @@ -321,7 +345,7 @@ private void setupHashRingWithOwningNode() { Consumer> function = invocation.getArgument(1); function.accept(Optional.of(node1)); return null; - }).when(hashRing).buildAndGetOwningNodeWithSameLocalAdVersion(any(), any(), any()); + }).when(hashRing).buildAndGetOwningNodeWithSameLocalVersion(any(), any(), any()); } public void testCreateTaskIndexNotAcknowledged() throws IOException { @@ -334,9 +358,9 @@ public void testCreateTaskIndexNotAcknowledged() throws IOException { AnomalyDetector detector = randomDetector(ImmutableList.of(randomFeature(true)), randomAlphaOfLength(5), 1, randomAlphaOfLength(5)); setupGetDetector(detector); - adTaskManager.startDetector(detector, detectionDateRange, randomUser(), transportService, listener); + adTaskManager.getAndExecuteOnLatestConfigLevelTask(detector, detectionDateRange, false, randomUser(), transportService, listener); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); - String error = String.format(Locale.ROOT, CREATE_INDEX_NOT_ACKNOWLEDGED, DETECTION_STATE_INDEX); + String error = String.format(Locale.ROOT, CommonMessages.CREATE_INDEX_NOT_ACKNOWLEDGED, DETECTION_STATE_INDEX); assertEquals(error, exceptionCaptor.getValue().getMessage()); } @@ -350,7 +374,7 @@ public void testCreateTaskIndexWithResourceAlreadyExistsException() throws IOExc AnomalyDetector detector = randomDetector(ImmutableList.of(randomFeature(true)), randomAlphaOfLength(5), 1, randomAlphaOfLength(5)); setupGetDetector(detector); - adTaskManager.startDetector(detector, detectionDateRange, randomUser(), transportService, listener); + adTaskManager.getAndExecuteOnLatestConfigLevelTask(detector, detectionDateRange, false, randomUser(), transportService, listener); verify(listener, never()).onFailure(any()); } @@ -365,12 +389,12 @@ public void testCreateTaskIndexWithException() throws IOException { AnomalyDetector detector = randomDetector(ImmutableList.of(randomFeature(true)), randomAlphaOfLength(5), 1, randomAlphaOfLength(5)); setupGetDetector(detector); - adTaskManager.startDetector(detector, detectionDateRange, randomUser(), transportService, listener); + adTaskManager.getAndExecuteOnLatestConfigLevelTask(detector, detectionDateRange, false, randomUser(), transportService, listener); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); assertEquals(error, exceptionCaptor.getValue().getMessage()); } - public void testStartDetectorWithNoEnabledFeature() throws IOException { + public void testgetAndExecuteOnLatestConfigLevelTaskWithNoEnabledFeature() throws IOException { AnomalyDetector detector = randomDetector( ImmutableList.of(randomFeature(false)), randomAlphaOfLength(5), @@ -379,16 +403,7 @@ public void testStartDetectorWithNoEnabledFeature() throws IOException { ); setupGetDetector(detector); - adTaskManager - .startDetector( - detector.getId(), - detectionDateRange, - indexAnomalyDetectorJobActionHandler, - randomUser(), - transportService, - context, - listener - ); + adTaskManager.startHistorical(detector, detectionDateRange, randomUser(), transportService, listener); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); } @@ -398,29 +413,20 @@ public void testStartDetectorForHistoricalAnalysis() throws IOException { setupGetDetector(detector); setupHashRingWithOwningNode(); - adTaskManager - .startDetector( - detector.getId(), - detectionDateRange, - indexAnomalyDetectorJobActionHandler, - randomUser(), - transportService, - context, - listener - ); + adTaskManager.startHistorical(detector, detectionDateRange, randomUser(), transportService, listener); verify(adTaskManager, times(1)).forwardRequestToLeadNode(any(), any(), any()); } private void setupTaskSlots(int node1UsedTaskSlots, int node1AssignedTaskSLots, int node2UsedTaskSlots, int node2AssignedTaskSLots) { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(2); listener .onResponse( - new ADStatsNodesResponse( + new StatsNodesResponse( new ClusterName(randomAlphaOfLength(5)), ImmutableList .of( - new ADStatsNodeResponse( + new StatsNodeResponse( node1, ImmutableMap .of( @@ -430,7 +436,7 @@ private void setupTaskSlots(int node1UsedTaskSlots, int node1AssignedTaskSLots, node1AssignedTaskSLots ) ), - new ADStatsNodeResponse( + new StatsNodeResponse( node2, ImmutableMap .of( @@ -556,7 +562,7 @@ public void testCheckTaskSlotsWithAvailableTaskSlotsForScale() throws IOExceptio public void testDeleteDuplicateTasks() throws IOException { ADTask adTask = randomAdTask(); - adTaskManager.handleADTaskException(adTask, new DuplicateTaskException("test")); + adTaskManager.handleTaskException(adTask, new DuplicateTaskException("test")); verify(client, times(1)).delete(any(), any()); } @@ -595,7 +601,7 @@ public void testDetectorTaskSlotScaleUpDelta() { DiscoveryNode[] eligibleDataNodes = new DiscoveryNode[] { node1, node2 }; // Scale down - when(hashRing.getNodesWithSameLocalAdVersion()).thenReturn(eligibleDataNodes); + when(hashRing.getNodesWithSameLocalVersion()).thenReturn(eligibleDataNodes); when(adTaskCacheManager.getUnfinishedEntityCount(detectorId)).thenReturn(maxRunningEntities * 10); int taskSlots = maxRunningEntities - 1; when(adTaskCacheManager.getDetectorTaskSlots(detectorId)).thenReturn(taskSlots); @@ -608,7 +614,7 @@ public void testDetectorTaskSlotScaleDownDelta() { DiscoveryNode[] eligibleDataNodes = new DiscoveryNode[] { node1, node2 }; // Scale down - when(hashRing.getNodesWithSameLocalAdVersion()).thenReturn(eligibleDataNodes); + when(hashRing.getNodesWithSameLocalVersion()).thenReturn(eligibleDataNodes); when(adTaskCacheManager.getUnfinishedEntityCount(detectorId)).thenReturn(maxRunningEntities * 10); int taskSlots = maxRunningEntities * 5; when(adTaskCacheManager.getDetectorTaskSlots(detectorId)).thenReturn(taskSlots); @@ -727,7 +733,7 @@ public void testUpdateLatestRealtimeTaskOnCoordinatingNode() { ActionListener listener = invocation.getArgument(3); listener.onResponse(new UpdateResponse(ShardId.fromString("[test][1]"), "1", 0L, 1L, 1L, DocWriteResponse.Result.UPDATED)); return null; - }).when(adTaskManager).updateLatestADTask(anyString(), any(), anyMap(), any()); + }).when(adTaskManager).updateLatestTask(anyString(), any(), anyMap(), any()); adTaskManager .updateLatestRealtimeTaskOnCoordinatingNode( detectorId, @@ -825,7 +831,7 @@ public void testResetLatestFlagAsFalse() throws IOException { public void testCleanADResultOfDeletedDetectorWithNoDeletedDetector() { when(adTaskCacheManager.pollDeletedConfig()).thenReturn(null); - adTaskManager.cleanADResultOfDeletedDetector(); + adTaskManager.cleanResultOfDeletedConfig(); verify(client, never()).execute(eq(DeleteByQueryAction.INSTANCE), any(), any()); } @@ -874,57 +880,59 @@ public void testCleanADResultOfDeletedDetectorWithException() { nodeFilter, hashRing, adTaskCacheManager, - threadPool + threadPool, + nodeStateManager, + taskProfileRunner ) ); - adTaskManager.cleanADResultOfDeletedDetector(); + adTaskManager.cleanResultOfDeletedConfig(); verify(client, times(1)).execute(eq(DeleteByQueryAction.INSTANCE), any(), any()); verify(adTaskCacheManager, times(1)).addDeletedConfig(eq(detectorId)); - adTaskManager.cleanADResultOfDeletedDetector(); + adTaskManager.cleanResultOfDeletedConfig(); verify(client, times(2)).execute(eq(DeleteByQueryAction.INSTANCE), any(), any()); verify(adTaskCacheManager, times(1)).addDeletedConfig(eq(detectorId)); } public void testMaintainRunningHistoricalTasksWithOwningNodeIsNotLocalNode() { // Test no owning node - when(hashRing.getOwningNodeWithHighestAdVersion(anyString())).thenReturn(Optional.empty()); + when(hashRing.getOwningNodeWithHighestVersion(anyString())).thenReturn(Optional.empty()); adTaskManager.maintainRunningHistoricalTasks(transportService, 10); verify(client, never()).search(any(), any()); // Test owning node is not local node - when(hashRing.getOwningNodeWithHighestAdVersion(anyString())).thenReturn(Optional.of(node2)); + when(hashRing.getOwningNodeWithHighestVersion(anyString())).thenReturn(Optional.of(node2)); doReturn(node1).when(clusterService).localNode(); adTaskManager.maintainRunningHistoricalTasks(transportService, 10); verify(client, never()).search(any(), any()); } public void testMaintainRunningHistoricalTasksWithNoRunningTask() { - when(hashRing.getOwningNodeWithHighestAdVersion(anyString())).thenReturn(Optional.of(node1)); + when(hashRing.getOwningNodeWithHighestVersion(anyString())).thenReturn(Optional.of(node1)); doReturn(node1).when(clusterService).localNode(); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); SearchHits searchHits = new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), Float.NaN); InternalSearchResponse response = new InternalSearchResponse( - searchHits, - InternalAggregations.EMPTY, - null, - null, - false, - null, - 1 - ); + searchHits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 1 + ); SearchResponse searchResponse = new SearchResponse( - response, - null, - 1, - 1, - 0, - 100, - ShardSearchFailure.EMPTY_ARRAY, - SearchResponse.Clusters.EMPTY - ); + response, + null, + 1, + 1, + 0, + 100, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); listener.onResponse(searchResponse); return null; }).when(client).search(any(), any()); @@ -933,7 +941,7 @@ public void testMaintainRunningHistoricalTasksWithNoRunningTask() { } public void testMaintainRunningHistoricalTasksWithRunningTask() { - when(hashRing.getOwningNodeWithHighestAdVersion(anyString())).thenReturn(Optional.of(node1)); + when(hashRing.getOwningNodeWithHighestVersion(anyString())).thenReturn(Optional.of(node1)); doReturn(node1).when(clusterService).localNode(); doAnswer(invocation -> { Runnable runnable = invocation.getArgument(0); @@ -946,24 +954,24 @@ public void testMaintainRunningHistoricalTasksWithRunningTask() { SearchHit task = SearchHit.fromXContent(TestHelpers.parser(runningHistoricalHCTaskContent)); SearchHits searchHits = new SearchHits(new SearchHit[] { task }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), Float.NaN); InternalSearchResponse response = new InternalSearchResponse( - searchHits, - InternalAggregations.EMPTY, - null, - null, - false, - null, - 1 - ); + searchHits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 1 + ); SearchResponse searchResponse = new SearchResponse( - response, - null, - 1, - 1, - 0, - 100, - ShardSearchFailure.EMPTY_ARRAY, - SearchResponse.Clusters.EMPTY - ); + response, + null, + 1, + 1, + 0, + 100, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); listener.onResponse(searchResponse); return null; }).when(client).search(any(), any()); @@ -972,11 +980,11 @@ public void testMaintainRunningHistoricalTasksWithRunningTask() { } public void testMaintainRunningRealtimeTasksWithNoRealtimeTask() { - when(adTaskCacheManager.getDetectorIdsInRealtimeTaskCache()).thenReturn(null); + when(adTaskCacheManager.getConfigIdsInRealtimeTaskCache()).thenReturn(null); adTaskManager.maintainRunningRealtimeTasks(); verify(adTaskCacheManager, never()).removeRealtimeTaskCache(anyString()); - when(adTaskCacheManager.getDetectorIdsInRealtimeTaskCache()).thenReturn(new String[0]); + when(adTaskCacheManager.getConfigIdsInRealtimeTaskCache()).thenReturn(new String[0]); adTaskManager.maintainRunningRealtimeTasks(); verify(adTaskCacheManager, never()).removeRealtimeTaskCache(anyString()); } @@ -985,7 +993,7 @@ public void testMaintainRunningRealtimeTasks() { String detectorId1 = randomAlphaOfLength(5); String detectorId2 = randomAlphaOfLength(5); String detectorId3 = randomAlphaOfLength(5); - when(adTaskCacheManager.getDetectorIdsInRealtimeTaskCache()).thenReturn(new String[] { detectorId1, detectorId2, detectorId3 }); + when(adTaskCacheManager.getConfigIdsInRealtimeTaskCache()).thenReturn(new String[] { detectorId1, detectorId2, detectorId3 }); when(adTaskCacheManager.getRealtimeTaskCache(detectorId1)).thenReturn(null); RealtimeTaskCache cacheOfDetector2 = mock(RealtimeTaskCache.class); @@ -1011,7 +1019,7 @@ public void testStartHistoricalAnalysisWithNoOwningNode() throws IOException { Consumer> function = invocation.getArgument(1); function.accept(Optional.empty()); return null; - }).when(hashRing).buildAndGetOwningNodeWithSameLocalAdVersion(anyString(), any(), any()); + }).when(hashRing).buildAndGetOwningNodeWithSameLocalVersion(anyString(), any(), any()); adTaskManager.startHistoricalAnalysis(detector, detectionDateRange, user, availableTaskSlots, transportService, listener); verify(listener, times(1)).onFailure(any()); } @@ -1068,7 +1076,7 @@ public void testGetAndExecuteOnLatestADTasksWithRunningRealtimeTaskWithTaskStopp ); setupGetAndExecuteOnLatestADTasks(profile); adTaskManager - .getAndExecuteOnLatestADTasks( + .getAndExecuteOnLatestTasks( detectorId, null, null, @@ -1134,7 +1142,7 @@ public void testGetAndExecuteOnLatestADTasksWithRunningHistoricalTask() throws I ); setupGetAndExecuteOnLatestADTasks(profile); adTaskManager - .getAndExecuteOnLatestADTasks( + .getAndExecuteOnLatestTasks( detectorId, null, null, @@ -1193,7 +1201,7 @@ private void setupGetAndExecuteOnLatestADTasks(ADTaskProfile adTaskProfile) { Consumer getNodeFunction = invocation.getArgument(0); getNodeFunction.accept(new DiscoveryNode[] { node1, node2 }); return null; - }).when(hashRing).getAllEligibleDataNodesWithKnownAdVersion(any(), any()); + }).when(hashRing).getAllEligibleDataNodesWithKnownVersion(any(), any()); doAnswer(invocation -> { ActionListener taskProfileResponseListener = invocation.getArgument(2); @@ -1248,7 +1256,8 @@ private void setupGetAndExecuteOnLatestADTasks(ADTaskProfile adTaskProfile) { Instant.now(), 60L, TestHelpers.randomUser(), - null + null, + AnalysisType.AD ).toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS) ), Collections.emptyMap(), @@ -1267,7 +1276,7 @@ public void testCreateADTaskDirectlyWithException() throws IOException { ActionListener listener = mock(ActionListener.class); doThrow(new RuntimeException("test")).when(client).index(any(), any()); - adTaskManager.createADTaskDirectly(adTask, function, listener); + adTaskManager.createTaskDirectly(adTask, function, listener); verify(listener, times(1)).onFailure(any()); doAnswer(invocation -> { @@ -1275,13 +1284,13 @@ public void testCreateADTaskDirectlyWithException() throws IOException { actionListener.onFailure(new RuntimeException("test")); return null; }).when(client).index(any(), any()); - adTaskManager.createADTaskDirectly(adTask, function, listener); + adTaskManager.createTaskDirectly(adTask, function, listener); verify(listener, times(2)).onFailure(any()); } public void testCleanChildTasksAndADResultsOfDeletedTaskWithNoDeletedDetectorTask() { when(adTaskCacheManager.hasDeletedTask()).thenReturn(false); - adTaskManager.cleanChildTasksAndADResultsOfDeletedTask(); + adTaskManager.cleanChildTasksAndResultsOfDeletedTask(); verify(client, never()).execute(any(), any(), any()); } @@ -1300,7 +1309,7 @@ public void testCleanChildTasksAndADResultsOfDeletedTaskWithNullTask() { return null; }).when(threadPool).schedule(any(), any(), any()); - adTaskManager.cleanChildTasksAndADResultsOfDeletedTask(); + adTaskManager.cleanChildTasksAndResultsOfDeletedTask(); verify(client, never()).execute(any(), any(), any()); } @@ -1319,7 +1328,7 @@ public void testCleanChildTasksAndADResultsOfDeletedTaskWithFailToDeleteADResult return null; }).when(threadPool).schedule(any(), any(), any()); - adTaskManager.cleanChildTasksAndADResultsOfDeletedTask(); + adTaskManager.cleanChildTasksAndResultsOfDeletedTask(); verify(client, times(1)).execute(any(), any(), any()); } @@ -1339,7 +1348,7 @@ public void testCleanChildTasksAndADResultsOfDeletedTask() { return null; }).when(threadPool).schedule(any(), any(), any()); - adTaskManager.cleanChildTasksAndADResultsOfDeletedTask(); + adTaskManager.cleanChildTasksAndResultsOfDeletedTask(); verify(client, times(2)).execute(any(), any(), any()); } @@ -1356,7 +1365,7 @@ public void testDeleteADTasks() { String detectorId = randomAlphaOfLength(5); ExecutorFunction function = mock(ExecutorFunction.class); ActionListener listener = mock(ActionListener.class); - adTaskManager.deleteADTasks(detectorId, function, listener); + adTaskManager.deleteTasks(detectorId, function, listener); verify(function, times(1)).execute(); } @@ -1381,7 +1390,7 @@ public void testDeleteADTasksWithBulkFailures() { String detectorId = randomAlphaOfLength(5); ExecutorFunction function = mock(ExecutorFunction.class); ActionListener listener = mock(ActionListener.class); - adTaskManager.deleteADTasks(detectorId, function, listener); + adTaskManager.deleteTasks(detectorId, function, listener); verify(listener, times(1)).onFailure(any()); } @@ -1401,11 +1410,11 @@ public void testDeleteADTasksWithException() { ExecutorFunction function = mock(ExecutorFunction.class); ActionListener listener = mock(ActionListener.class); - adTaskManager.deleteADTasks(detectorId, function, listener); + adTaskManager.deleteTasks(detectorId, function, listener); verify(function, times(1)).execute(); verify(listener, never()).onFailure(any()); - adTaskManager.deleteADTasks(detectorId, function, listener); + adTaskManager.deleteTasks(detectorId, function, listener); verify(function, times(1)).execute(); verify(listener, times(1)).onFailure(any()); } @@ -1438,7 +1447,7 @@ public void testForwardRequestToLeadNodeWithNotExistingNode() throws IOException Consumer> function = invocation.getArgument(1); function.accept(Optional.empty()); return null; - }).when(hashRing).buildAndGetOwningNodeWithSameLocalAdVersion(any(), any(), any()); + }).when(hashRing).buildAndGetOwningNodeWithSameLocalVersion(any(), any(), any()); adTaskManager.forwardRequestToLeadNode(forwardADTaskRequest, transportService, listener); verify(listener, times(1)).onFailure(any()); @@ -1454,14 +1463,14 @@ public void testScaleTaskLaneOnCoordinatingNode() { } @SuppressWarnings("unchecked") - public void testStartDetectorWithException() throws IOException { + public void testgetAndExecuteOnLatestConfigLevelTaskWithException() throws IOException { AnomalyDetector detector = randomAnomalyDetector(ImmutableList.of(randomFeature(true))); DateRange detectionDateRange = randomDetectionDateRange(); User user = null; ActionListener listener = mock(ActionListener.class); when(detectionIndices.doesStateIndexExist()).thenReturn(false); doThrow(new RuntimeException("test")).when(detectionIndices).initStateIndex(any()); - adTaskManager.startDetector(detector, detectionDateRange, user, transportService, listener); + adTaskManager.getAndExecuteOnLatestConfigLevelTask(detector, detectionDateRange, false, user, transportService, listener); verify(listener, times(1)).onFailure(any()); } @@ -1471,11 +1480,11 @@ public void testStopDetectorWithNonExistingDetector() { boolean historical = true; ActionListener listener = mock(ActionListener.class); doAnswer(invocation -> { - Consumer> function = invocation.getArgument(1); + Consumer> function = invocation.getArgument(2); function.accept(Optional.empty()); return null; - }).when(adTaskManager).getDetector(anyString(), any(), any()); - adTaskManager.stopDetector(detectorId, historical, indexAnomalyDetectorJobActionHandler, null, transportService, listener); + }).when(nodeStateManager).getConfig(anyString(), eq(AnalysisType.AD), any(Consumer.class), any()); + indexAnomalyDetectorJobActionHandler.stopConfig(detectorId, historical, null, transportService, listener); verify(listener, times(1)).onFailure(any()); } @@ -1485,11 +1494,11 @@ public void testStopDetectorWithNonExistingTask() { boolean historical = true; ActionListener listener = mock(ActionListener.class); doAnswer(invocation -> { - Consumer> function = invocation.getArgument(1); + Consumer> function = invocation.getArgument(2); AnomalyDetector detector = randomAnomalyDetector(ImmutableList.of(randomFeature(true))); function.accept(Optional.of(detector)); return null; - }).when(adTaskManager).getDetector(anyString(), any(), any()); + }).when(nodeStateManager).getConfig(anyString(), eq(AnalysisType.AD), any(Consumer.class), any()); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); @@ -1497,7 +1506,7 @@ public void testStopDetectorWithNonExistingTask() { return null; }).when(client).search(any(), any()); - adTaskManager.stopDetector(detectorId, historical, indexAnomalyDetectorJobActionHandler, null, transportService, listener); + indexAnomalyDetectorJobActionHandler.stopConfig(detectorId, historical, null, transportService, listener); verify(listener, times(1)).onFailure(any()); } @@ -1507,11 +1516,11 @@ public void testStopDetectorWithTaskDone() { boolean historical = true; ActionListener listener = mock(ActionListener.class); doAnswer(invocation -> { - Consumer> function = invocation.getArgument(1); + Consumer> function = invocation.getArgument(2); AnomalyDetector detector = randomAnomalyDetector(ImmutableList.of(randomFeature(true))); function.accept(Optional.of(detector)); return null; - }).when(adTaskManager).getDetector(anyString(), any(), any()); + }).when(nodeStateManager).getConfig(anyString(), eq(AnalysisType.AD), any(Consumer.class), any()); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); @@ -1540,14 +1549,14 @@ public void testStopDetectorWithTaskDone() { return null; }).when(client).search(any(), any()); - adTaskManager.stopDetector(detectorId, historical, indexAnomalyDetectorJobActionHandler, null, transportService, listener); + indexAnomalyDetectorJobActionHandler.stopConfig(detectorId, historical, null, transportService, listener); verify(listener, times(1)).onFailure(any()); } @SuppressWarnings("unchecked") public void testGetDetectorWithWrongContent() { String detectorId = randomAlphaOfLength(5); - Consumer> function = mock(Consumer.class); + Consumer> function = mock(Consumer.class); ActionListener listener = mock(ActionListener.class); doAnswer(invocation -> { ActionListener responseListener = invocation.getArgument(1); @@ -1571,7 +1580,18 @@ public void testGetDetectorWithWrongContent() { responseListener.onResponse(response); return null; }).when(client).get(any(), any()); - adTaskManager.getDetector(detectorId, function, listener); + NodeStateManager nodeStateManager = new NodeStateManager( + client, + xContentRegistry(), + Settings.EMPTY, + mock(ClientUtil.class), + mock(Clock.class), + TimeSeriesSettings.HOURLY_MAINTENANCE, + clusterService, + TimeSeriesSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE, + TimeSeriesSettings.BACKOFF_MINUTES + ); + nodeStateManager.getConfig(detectorId, AnalysisType.AD, function, listener); verify(listener, times(1)).onFailure(any()); } diff --git a/src/test/java/org/opensearch/ad/transport/ADBatchAnomalyResultTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/ADBatchAnomalyResultTransportActionTests.java index 8ce30df12..60f68131e 100644 --- a/src/test/java/org/opensearch/ad/transport/ADBatchAnomalyResultTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADBatchAnomalyResultTransportActionTests.java @@ -34,6 +34,7 @@ import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.TimeSeriesTask; import org.opensearch.timeseries.util.ExceptionUtil; import com.google.common.collect.ImmutableList; @@ -102,7 +103,7 @@ public void testHistoricalAnalysisWithValidDateRange() throws IOException, Inter client().execute(ADBatchAnomalyResultAction.INSTANCE, request).actionGet(5000); Thread.sleep(20000); GetResponse doc = getDoc(ADCommonName.DETECTION_STATE_INDEX, request.getAdTask().getTaskId()); - assertTrue(HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS.contains(doc.getSourceAsMap().get(ADTask.STATE_FIELD))); + assertTrue(HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS.contains(doc.getSourceAsMap().get(TimeSeriesTask.STATE_FIELD))); } public void testHistoricalAnalysisWithNonExistingIndex() throws IOException { @@ -140,7 +141,7 @@ public void testDisableADPlugin() throws IOException { ImmutableList.of(NotSerializableExceptionWrapper.class, EndRunException.class), () -> client().execute(ADBatchAnomalyResultAction.INSTANCE, request).actionGet(10000) ); - assertTrue(exception.getMessage(), exception.getMessage().contains("AD functionality is disabled")); + assertTrue(exception.getMessage(), exception.getMessage().contains("AD plugin is disabled")); updateTransientSettings(ImmutableMap.of(AD_ENABLED, false)); } finally { // guarantee reset back to default @@ -162,7 +163,7 @@ public void testMultipleTasks() throws IOException, InterruptedException { client().execute(ADBatchAnomalyResultAction.INSTANCE, request).actionGet(5000); Thread.sleep(25000); GetResponse doc = getDoc(ADCommonName.DETECTION_STATE_INDEX, request.getAdTask().getTaskId()); - assertTrue(HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS.contains(doc.getSourceAsMap().get(ADTask.STATE_FIELD))); + assertTrue(HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS.contains(doc.getSourceAsMap().get(TimeSeriesTask.STATE_FIELD))); updateTransientSettings(ImmutableMap.of(MAX_BATCH_TASK_PER_NODE.getKey(), 1)); } @@ -187,6 +188,6 @@ private void testInvalidDetectionDateRange(DateRange dateRange, String error) th client().execute(ADBatchAnomalyResultAction.INSTANCE, request).actionGet(5000); Thread.sleep(5000); GetResponse doc = getDoc(ADCommonName.DETECTION_STATE_INDEX, request.getAdTask().getTaskId()); - assertEquals(error, doc.getSourceAsMap().get(ADTask.ERROR_FIELD)); + assertEquals(error, doc.getSourceAsMap().get(TimeSeriesTask.ERROR_FIELD)); } } diff --git a/src/test/java/org/opensearch/ad/transport/ADResultBulkResponseTests.java b/src/test/java/org/opensearch/ad/transport/ADResultBulkResponseTests.java index 6946953fc..076c8763c 100644 --- a/src/test/java/org/opensearch/ad/transport/ADResultBulkResponseTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADResultBulkResponseTests.java @@ -20,16 +20,17 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.transport.ResultBulkResponse; public class ADResultBulkResponseTests extends OpenSearchTestCase { public void testSerialization() throws IOException { BytesStreamOutput output = new BytesStreamOutput(); List retryRequests = new ArrayList<>(); retryRequests.add(new IndexRequest("index").id("blah").source(Collections.singletonMap("foo", "bar"))); - ADResultBulkResponse response = new ADResultBulkResponse(retryRequests); + ResultBulkResponse response = new ResultBulkResponse(retryRequests); response.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); - ADResultBulkResponse readResponse = new ADResultBulkResponse(streamInput); + ResultBulkResponse readResponse = new ResultBulkResponse(streamInput); assertTrue(readResponse.hasFailures()); } } diff --git a/src/test/java/org/opensearch/ad/transport/ADStatsITTests.java b/src/test/java/org/opensearch/ad/transport/ADStatsITTests.java index da8f3dce7..b545b97d8 100644 --- a/src/test/java/org/opensearch/ad/transport/ADStatsITTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADStatsITTests.java @@ -18,6 +18,8 @@ import org.opensearch.plugins.Plugin; import org.opensearch.test.OpenSearchIntegTestCase; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.transport.StatsNodesResponse; +import org.opensearch.timeseries.transport.StatsRequest; public class ADStatsITTests extends OpenSearchIntegTestCase { @@ -31,9 +33,9 @@ protected Collection> transportClientPlugins() { } public void testNormalADStats() throws ExecutionException, InterruptedException { - ADStatsRequest adStatsRequest = new ADStatsRequest(new String[0]); + StatsRequest adStatsRequest = new StatsRequest(new String[0]); - ADStatsNodesResponse response = client().execute(ADStatsNodesAction.INSTANCE, adStatsRequest).get(); + StatsNodesResponse response = client().execute(ADStatsNodesAction.INSTANCE, adStatsRequest).get(); assertTrue("getting stats failed", !response.hasFailures()); } } diff --git a/src/test/java/org/opensearch/ad/transport/ADStatsTests.java b/src/test/java/org/opensearch/ad/transport/ADStatsTests.java index 4836825f3..45d8c15a0 100644 --- a/src/test/java/org/opensearch/ad/transport/ADStatsTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADStatsTests.java @@ -20,12 +20,14 @@ import java.io.IOException; import java.time.Clock; import java.time.Instant; +import java.util.ArrayDeque; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.TreeMap; import java.util.stream.Collectors; @@ -34,8 +36,6 @@ import org.opensearch.Version; import org.opensearch.action.FailedNodeException; import org.opensearch.ad.common.exception.JsonPathNotFoundException; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelState; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -45,9 +45,16 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.transport.StatsNodeRequest; +import org.opensearch.timeseries.transport.StatsNodeResponse; +import org.opensearch.timeseries.transport.StatsNodesResponse; +import org.opensearch.timeseries.transport.StatsRequest; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; import com.google.gson.JsonArray; import com.google.gson.JsonElement; @@ -78,18 +85,18 @@ public void setUp() throws Exception { @Test public void testADStatsNodeRequest() throws IOException { - ADStatsNodeRequest adStatsNodeRequest1 = new ADStatsNodeRequest(); + StatsNodeRequest adStatsNodeRequest1 = new StatsNodeRequest(); assertNull("ADStatsNodeRequest default constructor failed", adStatsNodeRequest1.getADStatsRequest()); - ADStatsRequest adStatsRequest = new ADStatsRequest(new String[0]); - ADStatsNodeRequest adStatsNodeRequest2 = new ADStatsNodeRequest(adStatsRequest); + StatsRequest adStatsRequest = new StatsRequest(new String[0]); + StatsNodeRequest adStatsNodeRequest2 = new StatsNodeRequest(adStatsRequest); assertEquals("ADStatsNodeRequest has the wrong ADStatsRequest", adStatsNodeRequest2.getADStatsRequest(), adStatsRequest); // Test serialization BytesStreamOutput output = new BytesStreamOutput(); adStatsNodeRequest2.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); - adStatsNodeRequest1 = new ADStatsNodeRequest(streamInput); + adStatsNodeRequest1 = new StatsNodeRequest(streamInput); assertEquals( "readStats failed", adStatsNodeRequest2.getADStatsRequest().getStatsToBeRetrieved(), @@ -106,11 +113,11 @@ public void testSimpleADStatsNodeResponse() throws IOException, JsonPathNotFound }; // Test serialization - ADStatsNodeResponse adStatsNodeResponse = new ADStatsNodeResponse(discoveryNode1, stats); + StatsNodeResponse adStatsNodeResponse = new StatsNodeResponse(discoveryNode1, stats); BytesStreamOutput output = new BytesStreamOutput(); adStatsNodeResponse.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); - ADStatsNodeResponse readResponse = ADStatsNodeResponse.readStats(streamInput); + StatsNodeResponse readResponse = StatsNodeResponse.readStats(streamInput); assertEquals("readStats failed", readResponse.getStatsMap(), adStatsNodeResponse.getStatsMap()); // Test toXContent @@ -139,25 +146,26 @@ public void testADStatsNodeResponseWithEntity() throws IOException, JsonPathNotF attributes.put(name2, val2); String detectorId = "detectorId"; Entity entity = Entity.createEntityFromOrderedMap(attributes); - EntityModel entityModel = new EntityModel(entity, null, null); Clock clock = mock(Clock.class); when(clock.instant()).thenReturn(Instant.now()); - ModelState state = new ModelState( - entityModel, + ModelState state = new ModelState( + null, entity.getModelId(detectorId).get(), detectorId, - "entity", + ModelManager.ModelType.TRCF.getName(), clock, - 0.1f + 0.1f, + Optional.empty(), + new ArrayDeque<>() ); Map stats = state.getModelStateAsMap(); // Test serialization - ADStatsNodeResponse adStatsNodeResponse = new ADStatsNodeResponse(discoveryNode1, stats); + StatsNodeResponse adStatsNodeResponse = new StatsNodeResponse(discoveryNode1, stats); BytesStreamOutput output = new BytesStreamOutput(); adStatsNodeResponse.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); - ADStatsNodeResponse readResponse = ADStatsNodeResponse.readStats(streamInput); + StatsNodeResponse readResponse = StatsNodeResponse.readStats(streamInput); assertEquals("readStats failed", readResponse.getStatsMap(), adStatsNodeResponse.getStatsMap()); // Test toXContent @@ -192,7 +200,7 @@ public void testADStatsNodeResponseWithEntity() throws IOException, JsonPathNotF @Test public void testADStatsRequest() throws IOException { List allStats = Arrays.stream(StatNames.values()).map(StatNames::getName).collect(Collectors.toList()); - ADStatsRequest adStatsRequest = new ADStatsRequest(new String[0]); + StatsRequest adStatsRequest = new StatsRequest(new String[0]); // Test clear() adStatsRequest.clear(); @@ -215,7 +223,7 @@ public void testADStatsRequest() throws IOException { BytesStreamOutput output = new BytesStreamOutput(); adStatsRequest.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); - ADStatsRequest readRequest = new ADStatsRequest(streamInput); + StatsRequest readRequest = new StatsRequest(streamInput); assertEquals("Serialization fails", readRequest.getStatsToBeRetrieved(), adStatsRequest.getStatsToBeRetrieved()); } @@ -227,10 +235,10 @@ public void testADStatsNodesResponse() throws IOException, JsonPathNotFoundExcep } }; - ADStatsNodeResponse adStatsNodeResponse = new ADStatsNodeResponse(discoveryNode1, nodeStats); - List adStatsNodeResponses = Collections.singletonList(adStatsNodeResponse); + StatsNodeResponse adStatsNodeResponse = new StatsNodeResponse(discoveryNode1, nodeStats); + List adStatsNodeResponses = Collections.singletonList(adStatsNodeResponse); List failures = Collections.emptyList(); - ADStatsNodesResponse adStatsNodesResponse = new ADStatsNodesResponse(new ClusterName(clusterName), adStatsNodeResponses, failures); + StatsNodesResponse adStatsNodesResponse = new StatsNodesResponse(new ClusterName(clusterName), adStatsNodeResponses, failures); // Test toXContent XContentBuilder builder = jsonBuilder(); @@ -256,7 +264,7 @@ public void testADStatsNodesResponse() throws IOException, JsonPathNotFoundExcep adStatsNodesResponse.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); - ADStatsNodesResponse readRequest = new ADStatsNodesResponse(streamInput); + StatsNodesResponse readRequest = new StatsNodesResponse(streamInput); builder = jsonBuilder(); String readJson = readRequest.toXContent(builder.startObject(), ToXContent.EMPTY_PARAMS).endObject().toString(); diff --git a/src/test/java/org/opensearch/ad/transport/ADTaskProfileTests.java b/src/test/java/org/opensearch/ad/transport/ADTaskProfileTests.java index d8e13e9ec..23ae4b954 100644 --- a/src/test/java/org/opensearch/ad/transport/ADTaskProfileTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADTaskProfileTests.java @@ -33,6 +33,7 @@ import org.opensearch.plugins.Plugin; import org.opensearch.test.InternalSettingsPlugin; import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.timeseries.TaskProfile; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; @@ -117,7 +118,7 @@ private void testADTaskProfileResponse(ADTaskProfileNodeResponse response) throw } public void testADTaskProfileParse() throws IOException { - ADTaskProfile adTaskProfile = new ADTaskProfile( + TaskProfile adTaskProfile = new ADTaskProfile( randomAlphaOfLength(5), randomInt(), randomLong(), @@ -128,7 +129,7 @@ public void testADTaskProfileParse() throws IOException { ); String adTaskProfileString = TestHelpers .xContentBuilderToString(adTaskProfile.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); - ADTaskProfile parsedADTaskProfile = ADTaskProfile.parse(TestHelpers.parser(adTaskProfileString)); + TaskProfile parsedADTaskProfile = ADTaskProfile.parse(TestHelpers.parser(adTaskProfileString)); assertEquals(adTaskProfile, parsedADTaskProfile); assertEquals(parsedADTaskProfile.toString(), adTaskProfile.toString()); } @@ -170,7 +171,7 @@ public void testSerializeResponse() throws IOException { } public void testADTaskProfileParseFullConstructor() throws IOException { - ADTaskProfile adTaskProfile = new ADTaskProfile( + TaskProfile adTaskProfile = new ADTaskProfile( TestHelpers.randomAdTask(), randomInt(), randomLong(), @@ -190,7 +191,7 @@ public void testADTaskProfileParseFullConstructor() throws IOException { ); String adTaskProfileString = TestHelpers .xContentBuilderToString(adTaskProfile.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); - ADTaskProfile parsedADTaskProfile = ADTaskProfile.parse(TestHelpers.parser(adTaskProfileString)); + TaskProfile parsedADTaskProfile = ADTaskProfile.parse(TestHelpers.parser(adTaskProfileString)); assertEquals(adTaskProfile, parsedADTaskProfile); } } diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java b/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java index 1b23b6d51..74bb539cd 100644 --- a/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java +++ b/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java @@ -19,17 +19,16 @@ import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.nullValue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyDouble; +import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.anyDouble; -import static org.mockito.Mockito.anyLong; +import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; -import static org.mockito.Mockito.same; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -55,7 +54,6 @@ import org.junit.Assert; import org.junit.Before; import org.junit.BeforeClass; -import org.mockito.ArgumentCaptor; import org.opensearch.OpenSearchTimeoutException; import org.opensearch.Version; import org.opensearch.action.ActionRequestValidationException; @@ -65,20 +63,15 @@ import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.ad.cluster.HashRing; import org.opensearch.ad.common.exception.JsonPathNotFoundException; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.feature.SinglePointFeatures; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.ml.ThresholdingResult; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.DetectorInternalState; import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.stats.ADStat; import org.opensearch.ad.stats.ADStats; -import org.opensearch.ad.stats.suppliers.CounterSupplier; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; @@ -108,6 +101,7 @@ import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.common.exception.InternalFailure; import org.opensearch.timeseries.common.exception.LimitExceededException; @@ -115,9 +109,15 @@ import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.feature.SinglePointFeatures; import org.opensearch.timeseries.ml.SingleStreamModelIdMapper; import org.opensearch.timeseries.model.FeatureData; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.TimeSeriesStat; +import org.opensearch.timeseries.stats.suppliers.CounterSupplier; +import org.opensearch.timeseries.transport.ResultProcessor; +import org.opensearch.timeseries.transport.ResultResponse; import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.NodeNotConnectedException; import org.opensearch.transport.RemoteTransportException; @@ -139,7 +139,7 @@ public class AnomalyResultTests extends AbstractTimeSeriesTest { private ClusterService clusterService; private NodeStateManager stateManager; private FeatureManager featureQuery; - private ModelManager normalModelManager; + private ADModelManager normalModelManager; private Client client; private SecurityClientUtil clientUtil; private AnomalyDetector detector; @@ -203,7 +203,7 @@ public void setUp() throws Exception { hashRing = mock(HashRing.class); Optional localNode = Optional.of(clusterService.state().nodes().getLocalNode()); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))).thenReturn(localNode); + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))).thenReturn(localNode); doReturn(localNode).when(hashRing).getNodeByAddress(any()); featureQuery = mock(FeatureManager.class); @@ -216,7 +216,7 @@ public void setUp() throws Exception { double rcfScore = 0.2; confidence = 0.91; anomalyGrade = 0.5; - normalModelManager = mock(ModelManager.class); + normalModelManager = mock(ADModelManager.class); long totalUpdates = 1440; int relativeIndex = 0; double[] currentTimeAttribution = new double[] { 0.5, 0.5 }; @@ -288,12 +288,12 @@ public void setUp() throws Exception { indexNameResolver = new IndexNameExpressionResolver(new ThreadContext(Settings.EMPTY)); - Map> statsMap = new HashMap>() { + Map> statsMap = new HashMap>() { { - put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); - put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); - put(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); - put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); + put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())); + put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())); + put(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())); + put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())); } }; @@ -309,7 +309,6 @@ public void setUp() throws Exception { DetectorInternalState.Builder result = new DetectorInternalState.Builder().lastUpdateTime(Instant.now()); listener.onResponse(TestHelpers.createGetResponse(result.build(), detector.getId(), ADCommonName.DETECTION_STATE_INDEX)); - } return null; @@ -364,7 +363,6 @@ public void testNormal() throws IOException { clientUtil, stateManager, featureQuery, - normalModelManager, hashRing, clusterService, indexNameResolver, @@ -466,7 +464,7 @@ public void sendRequest( // mock hashing ring response. This has to happen after setting up test nodes with the failure interceptor Optional discoveryNode = Optional.of(testNodes[1].discoveryNode()); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))).thenReturn(discoveryNode); + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))).thenReturn(discoveryNode); when(hashRing.getNodeByAddress(any(TransportAddress.class))).thenReturn(discoveryNode); // register handler on testNodes[1] new RCFResultTransportAction( @@ -489,7 +487,6 @@ public void sendRequest( clientUtil, stateManager, featureQuery, - normalModelManager, hashRing, realClusterService, indexNameResolver, @@ -515,7 +512,7 @@ public void noModelExceptionTemplate(Exception exception, String adID, String er @SuppressWarnings("unchecked") public void testInsufficientCapacityExceptionDuringColdStart() { - ModelManager rcfManager = mock(ModelManager.class); + ADModelManager rcfManager = mock(ADModelManager.class); doThrow(ResourceNotFoundException.class) .when(rcfManager) .getTRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); @@ -542,7 +539,6 @@ public void testInsufficientCapacityExceptionDuringColdStart() { clientUtil, stateManager, featureQuery, - normalModelManager, hashRing, clusterService, indexNameResolver, @@ -563,7 +559,7 @@ public void testInsufficientCapacityExceptionDuringColdStart() { @SuppressWarnings("unchecked") public void testInsufficientCapacityExceptionDuringRestoringModel() { - ModelManager rcfManager = mock(ModelManager.class); + ADModelManager rcfManager = mock(ADModelManager.class); doThrow(new NotSerializableExceptionWrapper(new LimitExceededException(adID, CommonMessages.MEMORY_LIMIT_EXCEEDED_ERR_MSG))) .when(rcfManager) .getTRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); @@ -587,7 +583,6 @@ public void testInsufficientCapacityExceptionDuringRestoringModel() { clientUtil, stateManager, featureQuery, - normalModelManager, hashRing, clusterService, indexNameResolver, @@ -689,7 +684,7 @@ public void sendRequest( // mock hashing ring response. This has to happen after setting up test nodes with the failure interceptor Optional discoveryNode = Optional.of(testNodes[1].discoveryNode()); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))).thenReturn(discoveryNode); + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))).thenReturn(discoveryNode); when(hashRing.getNodeByAddress(any(TransportAddress.class))).thenReturn(discoveryNode); // register handlers on testNodes[1] ActionFilters actionFilters = new ActionFilters(Collections.emptySet()); @@ -714,7 +709,6 @@ public void sendRequest( clientUtil, stateManager, featureQuery, - normalModelManager, hashRing, realClusterService, indexNameResolver, @@ -757,7 +751,6 @@ public void testCircuitBreaker() { clientUtil, stateManager, featureQuery, - normalModelManager, hashRing, clusterService, indexNameResolver, @@ -798,7 +791,7 @@ private void nodeNotConnectedExceptionTemplate(boolean isRCF, boolean temporary, .when(exceptionTransportService) .getConnection(same(rcfNode)); } else { - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(eq(thresholdModelID))).thenReturn(Optional.of(thresholdNode)); + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(eq(thresholdModelID))).thenReturn(Optional.of(thresholdNode)); when(hashRing.getNodeByAddress(any())).thenReturn(Optional.of(thresholdNode)); doThrow(new NodeNotConnectedException(rcfNode, "rcf node not connected")) .when(exceptionTransportService) @@ -827,7 +820,6 @@ private void nodeNotConnectedExceptionTemplate(boolean isRCF, boolean temporary, clientUtil, stateManager, featureQuery, - normalModelManager, hashRing, hackedClusterService, indexNameResolver, @@ -845,10 +837,10 @@ private void nodeNotConnectedExceptionTemplate(boolean isRCF, boolean temporary, assertException(listener, TimeSeriesException.class); if (!temporary) { - verify(hashRing, times(numberOfBuildCall)).buildCirclesForRealtimeAD(); + verify(hashRing, times(numberOfBuildCall)).buildCirclesForRealtime(); verify(stateManager, never()).addPressure(any(String.class), any(String.class)); } else { - verify(hashRing, never()).buildCirclesForRealtimeAD(); + verify(hashRing, never()).buildCirclesForRealtime(); verify(stateManager, times(numberOfBuildCall)).addPressure(any(String.class), any(String.class)); } } @@ -880,7 +872,6 @@ public void testMute() { clientUtil, muteStateManager, featureQuery, - normalModelManager, hashRing, clusterService, indexNameResolver, @@ -895,7 +886,7 @@ public void testMute() { action.doExecute(null, request, listener); Throwable exception = assertException(listener, TimeSeriesException.class); - assertThat(exception.getMessage(), containsString(AnomalyResultTransportAction.NODE_UNRESPONSIVE_ERR_MSG)); + assertThat(exception.getMessage(), containsString(ResultProcessor.NODE_UNRESPONSIVE_ERR_MSG)); } public void alertingRequestTemplate(boolean anomalyResultIndexExists) throws IOException { @@ -910,7 +901,7 @@ public void alertingRequestTemplate(boolean anomalyResultIndexExists) throws IOE ); Optional localNode = Optional.of(clusterService.state().nodes().getLocalNode()); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))).thenReturn(localNode); + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))).thenReturn(localNode); doReturn(localNode).when(hashRing).getNodeByAddress(any()); new ThresholdResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, normalModelManager); @@ -922,7 +913,6 @@ public void alertingRequestTemplate(boolean anomalyResultIndexExists) throws IOE clientUtil, stateManager, featureQuery, - normalModelManager, hashRing, clusterService, indexNameResolver, @@ -971,7 +961,7 @@ public String executor() { } public void testSerialzationResponse() throws IOException { - AnomalyResultResponse response = new AnomalyResultResponse( + ResultResponse response = new AnomalyResultResponse( 4d, 0.993, 1.01, @@ -985,7 +975,8 @@ public void testSerialzationResponse() throws IOException { new double[] { randomDouble(), randomDouble() }, new double[][] { new double[] { randomDouble(), randomDouble() } }, new double[] { randomDouble() }, - randomDoubleBetween(1.1, 10.0, true) + randomDoubleBetween(1.1, 10.0, true), + null ); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); @@ -996,7 +987,7 @@ public void testSerialzationResponse() throws IOException { } public void testJsonResponse() throws IOException, JsonPathNotFoundException { - AnomalyResultResponse response = new AnomalyResultResponse( + ResultResponse response = new AnomalyResultResponse( 4d, 0.993, 1.01, @@ -1010,7 +1001,8 @@ public void testJsonResponse() throws IOException, JsonPathNotFoundException { new double[] { randomDouble(), randomDouble() }, new double[][] { new double[] { randomDouble(), randomDouble() } }, new double[] { randomDouble() }, - randomDoubleBetween(1.1, 10.0, true) + randomDoubleBetween(1.1, 10.0, true), + null ); XContentBuilder builder = jsonBuilder(); response.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -1042,7 +1034,8 @@ public void testJsonResponse() throws IOException, JsonPathNotFoundException { new double[] { randomDouble(), randomDouble() }, new double[][] { new double[] { randomDouble(), randomDouble() } }, new double[] { randomDouble() }, - randomDoubleBetween(1.1, 10.0, true) + randomDoubleBetween(1.1, 10.0, true), + null ); assertAnomalyResultResponse(readResponse, readResponse.getAnomalyGrade(), readResponse.getConfidence(), 0d); } @@ -1054,7 +1047,7 @@ public void testSerialzationRequest() throws IOException { StreamInput streamInput = output.bytes().streamInput(); AnomalyResultRequest readRequest = new AnomalyResultRequest(streamInput); - assertThat(request.getAdID(), equalTo(readRequest.getAdID())); + assertThat(request.getConfigId(), equalTo(readRequest.getConfigId())); assertThat(request.getStart(), equalTo(readRequest.getStart())); assertThat(request.getEnd(), equalTo(readRequest.getEnd())); } @@ -1065,7 +1058,7 @@ public void testJsonRequest() throws IOException, JsonPathNotFoundException { request.toXContent(builder, ToXContent.EMPTY_PARAMS); String json = builder.toString(); - assertEquals(JsonDeserializer.getTextValue(json, ADCommonName.ID_JSON_KEY), request.getAdID()); + assertEquals(JsonDeserializer.getTextValue(json, ADCommonName.ID_JSON_KEY), request.getConfigId()); assertEquals(JsonDeserializer.getLongValue(json, CommonName.START_JSON_KEY), request.getStart()); assertEquals(JsonDeserializer.getLongValue(json, CommonName.END_JSON_KEY), request.getEnd()); } @@ -1090,33 +1083,6 @@ public void testNegativeTime() { assertThat(e.validationErrors(), hasItem(startsWith(CommonMessages.INVALID_TIMESTAMP_ERR_MSG))); } - // no exception should be thrown - @SuppressWarnings("unchecked") - public void testOnFailureNull() throws IOException { - AnomalyResultTransportAction action = new AnomalyResultTransportAction( - new ActionFilters(Collections.emptySet()), - transportService, - settings, - client, - clientUtil, - stateManager, - featureQuery, - normalModelManager, - hashRing, - clusterService, - indexNameResolver, - adCircuitBreakerService, - adStats, - threadPool, - NamedXContentRegistry.EMPTY, - adTaskManager - ); - AnomalyResultTransportAction.RCFActionListener listener = action.new RCFActionListener( - null, null, null, null, mock(ActionListener.class), null, null - ); - listener.onFailure(null); - } - static class ColdStartConfig { boolean coldStartRunning = false; Exception getCheckpointException = null; @@ -1193,7 +1159,6 @@ public void testColdStartNoTrainingData() throws Exception { clientUtil, stateManager, featureQuery, - normalModelManager, hashRing, clusterService, indexNameResolver, @@ -1231,7 +1196,6 @@ public void testConcurrentColdStart() throws Exception { clientUtil, stateManager, featureQuery, - normalModelManager, hashRing, clusterService, indexNameResolver, @@ -1275,7 +1239,6 @@ public void testColdStartTimeoutPutCheckpoint() throws Exception { clientUtil, stateManager, featureQuery, - normalModelManager, hashRing, clusterService, indexNameResolver, @@ -1319,7 +1282,6 @@ public void testColdStartIllegalArgumentException() throws Exception { clientUtil, stateManager, featureQuery, - normalModelManager, hashRing, clusterService, indexNameResolver, @@ -1370,7 +1332,6 @@ public void featureTestTemplate(FeatureTestMode mode) throws IOException { clientUtil, stateManager, featureQuery, - normalModelManager, hashRing, clusterService, indexNameResolver, @@ -1459,7 +1420,6 @@ private void globalBlockTemplate(BlockType type, String errLogMsg, Settings inde clientUtil, stateManager, featureQuery, - normalModelManager, hashRing, hackedClusterService, indexNameResolver, @@ -1482,119 +1442,22 @@ private void globalBlockTemplate(BlockType type, String errLogMsg) { } public void testReadBlock() { - globalBlockTemplate(BlockType.GLOBAL_BLOCK_READ, AnomalyResultTransportAction.READ_WRITE_BLOCKED); + globalBlockTemplate(BlockType.GLOBAL_BLOCK_READ, ResultProcessor.READ_WRITE_BLOCKED); } public void testWriteBlock() { - globalBlockTemplate(BlockType.GLOBAL_BLOCK_WRITE, AnomalyResultTransportAction.READ_WRITE_BLOCKED); + globalBlockTemplate(BlockType.GLOBAL_BLOCK_WRITE, ResultProcessor.READ_WRITE_BLOCKED); } public void testIndexReadBlock() { globalBlockTemplate( BlockType.INDEX_BLOCK, - AnomalyResultTransportAction.INDEX_READ_BLOCKED, + ResultProcessor.INDEX_READ_BLOCKED, Settings.builder().put(IndexMetadata.INDEX_BLOCKS_READ_SETTING.getKey(), true).build(), "test1" ); } - @SuppressWarnings("unchecked") - public void testNullRCFResult() { - AnomalyResultTransportAction action = new AnomalyResultTransportAction( - new ActionFilters(Collections.emptySet()), - transportService, - settings, - client, - clientUtil, - stateManager, - featureQuery, - normalModelManager, - hashRing, - clusterService, - indexNameResolver, - adCircuitBreakerService, - adStats, - threadPool, - NamedXContentRegistry.EMPTY, - adTaskManager - ); - AnomalyResultTransportAction.RCFActionListener listener = action.new RCFActionListener( - "123-rcf-0", null, "123", null, mock(ActionListener.class), null, null - ); - listener.onResponse(null); - assertTrue(testAppender.containsMessage(AnomalyResultTransportAction.NULL_RESPONSE)); - } - - @SuppressWarnings("unchecked") - public void testNormalRCFResult() { - AnomalyResultTransportAction action = new AnomalyResultTransportAction( - new ActionFilters(Collections.emptySet()), - transportService, - settings, - client, - clientUtil, - stateManager, - featureQuery, - normalModelManager, - hashRing, - clusterService, - indexNameResolver, - adCircuitBreakerService, - adStats, - threadPool, - NamedXContentRegistry.EMPTY, - adTaskManager - ); - ActionListener listener = mock(ActionListener.class); - AnomalyResultTransportAction.RCFActionListener rcfListener = action.new RCFActionListener( - "123-rcf-0", null, "nodeID", detector, listener, null, adID - ); - double[] attribution = new double[] { 1. }; - long totalUpdates = 32; - double grade = 0.5; - ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(AnomalyResultResponse.class); - rcfListener - .onResponse(new RCFResultResponse(0.3, 0, 26, attribution, totalUpdates, grade, Version.CURRENT, 0, null, null, null, 1.1)); - verify(listener, times(1)).onResponse(responseCaptor.capture()); - assertEquals(grade, responseCaptor.getValue().getAnomalyGrade(), 1e-10); - } - - @SuppressWarnings("unchecked") - public void testNullPointerRCFResult() { - AnomalyResultTransportAction action = new AnomalyResultTransportAction( - new ActionFilters(Collections.emptySet()), - transportService, - settings, - client, - clientUtil, - stateManager, - featureQuery, - normalModelManager, - hashRing, - clusterService, - indexNameResolver, - adCircuitBreakerService, - adStats, - threadPool, - NamedXContentRegistry.EMPTY, - adTaskManager - ); - ActionListener listener = mock(ActionListener.class); - // detector being null causes NullPointerException - AnomalyResultTransportAction.RCFActionListener rcfListener = action.new RCFActionListener( - "123-rcf-0", null, "nodeID", null, listener, null, adID - ); - double[] attribution = new double[] { 1. }; - long totalUpdates = 32; - double grade = 0.5; - ArgumentCaptor failureCaptor = ArgumentCaptor.forClass(Exception.class); - rcfListener - .onResponse(new RCFResultResponse(0.3, 0, 26, attribution, totalUpdates, grade, Version.CURRENT, 0, null, null, null, 1.1)); - verify(listener, times(1)).onFailure(failureCaptor.capture()); - Exception failure = failureCaptor.getValue(); - assertTrue(failure instanceof InternalFailure); - } - @SuppressWarnings("unchecked") public void testAllFeaturesDisabled() throws IOException { doAnswer(invocation -> { @@ -1611,7 +1474,6 @@ public void testAllFeaturesDisabled() throws IOException { clientUtil, stateManager, featureQuery, - normalModelManager, hashRing, clusterService, indexNameResolver, @@ -1634,7 +1496,7 @@ public void testEndRunDueToNoTrainingData() { ThreadPool mockThreadPool = mock(ThreadPool.class); setUpColdStart(mockThreadPool, new ColdStartConfig.Builder().coldStartRunning(false).build()); - ModelManager rcfManager = mock(ModelManager.class); + ADModelManager rcfManager = mock(ADModelManager.class); doAnswer(invocation -> { Object[] args = invocation.getArguments(); ActionListener listener = (ActionListener) args[3]; @@ -1676,7 +1538,6 @@ public void testEndRunDueToNoTrainingData() { clientUtil, stateManager, featureQuery, - normalModelManager, hashRing, clusterService, indexNameResolver, @@ -1729,7 +1590,6 @@ public void testColdStartEndRunException() { clientUtil, stateManager, featureQuery, - normalModelManager, hashRing, clusterService, indexNameResolver, @@ -1773,7 +1633,6 @@ public void testColdStartEndRunExceptionNow() { clientUtil, stateManager, featureQuery, - normalModelManager, hashRing, clusterService, indexNameResolver, @@ -1814,7 +1673,6 @@ public void testColdStartBecauseFailtoGetCheckpoint() { clientUtil, stateManager, featureQuery, - normalModelManager, hashRing, clusterService, indexNameResolver, @@ -1853,7 +1711,6 @@ public void testNoColdStartDueToUnknownException() { clientUtil, stateManager, featureQuery, - normalModelManager, hashRing, clusterService, indexNameResolver, diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java index 78ffca8dd..0100e07f7 100644 --- a/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java @@ -25,13 +25,12 @@ import org.opensearch.action.get.GetResponse; import org.opensearch.ad.ADIntegTestCase; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.core.common.io.stream.NotSerializableExceptionWrapper; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.test.rest.OpenSearchRestTestCase; import org.opensearch.timeseries.TestHelpers; -import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.model.Feature; import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.util.ExceptionUtil; import com.google.common.collect.ImmutableList; @@ -134,57 +133,57 @@ public void testFeatureWithCardinalityOfTextField() throws IOException { public void testFeatureQueryWithTermsAggregationForHCDetector() throws IOException { String adId = createDetectorWithFeatureAgg("{\"test\":{\"terms\":{\"field\":\"type\"}}}", true); - assertErrorMessage(adId, "Failed to parse aggregation", true); + assertErrorMessage(adId, "Failed to parse aggregation"); } public void testFeatureWithSumOfTextFieldForHCDetector() throws IOException { String adId = createDetectorWithFeatureAgg("{\"test\":{\"sum\":{\"field\":\"message\"}}}", true); - assertErrorMessage(adId, "Text fields are not optimised for operations", true); + assertErrorMessage(adId, "Text fields are not optimised for operations"); } public void testFeatureWithSumOfTypeFieldForHCDetector() throws IOException { String adId = createDetectorWithFeatureAgg("{\"test\":{\"sum\":{\"field\":\"type\"}}}", true); - assertErrorMessage(adId, "Field [type] of type [keyword] is not supported for aggregation [sum]", true); + assertErrorMessage(adId, "Field [type] of type [keyword] is not supported for aggregation [sum]"); } public void testFeatureWithMaxOfTextFieldForHCDetector() throws IOException { String adId = createDetectorWithFeatureAgg("{\"test\":{\"max\":{\"field\":\"message\"}}}", true); - assertErrorMessage(adId, "Text fields are not optimised for operations", true); + assertErrorMessage(adId, "Text fields are not optimised for operations"); } public void testFeatureWithMaxOfTypeFieldForHCDetector() throws IOException { String adId = createDetectorWithFeatureAgg("{\"test\":{\"max\":{\"field\":\"type\"}}}", true); - assertErrorMessage(adId, "Field [type] of type [keyword] is not supported for aggregation [max]", true); + assertErrorMessage(adId, "Field [type] of type [keyword] is not supported for aggregation [max]"); } public void testFeatureWithMinOfTextFieldForHCDetector() throws IOException { - String adId = createDetectorWithFeatureAgg("{\"test\":{\"min\":{\"field\":\"message\"}}}", true); - assertErrorMessage(adId, "Text fields are not optimised for operations", true); + String adId = createDetectorWithFeatureAgg("{\"test\":{\"min\":{\"field\":\"message\"}}}"); + assertErrorMessage(adId, "Text fields are not optimised for operations"); } public void testFeatureWithMinOfTypeFieldForHCDetector() throws IOException { String adId = createDetectorWithFeatureAgg("{\"test\":{\"min\":{\"field\":\"type\"}}}", true); - assertErrorMessage(adId, "Field [type] of type [keyword] is not supported for aggregation [min]", true); + assertErrorMessage(adId, "Field [type] of type [keyword] is not supported for aggregation [min]"); } public void testFeatureWithAvgOfTextFieldForHCDetector() throws IOException { String adId = createDetectorWithFeatureAgg("{\"test\":{\"avg\":{\"field\":\"message\"}}}", true); - assertErrorMessage(adId, "Text fields are not optimised for operations", true); + assertErrorMessage(adId, "Text fields are not optimised for operations"); } public void testFeatureWithAvgOfTypeFieldForHCDetector() throws IOException { String adId = createDetectorWithFeatureAgg("{\"test\":{\"avg\":{\"field\":\"type\"}}}", true); - assertErrorMessage(adId, "Field [type] of type [keyword] is not supported for aggregation [avg]", true); + assertErrorMessage(adId, "Field [type] of type [keyword] is not supported for aggregation [avg]"); } public void testFeatureWithCountOfTextFieldForHCDetector() throws IOException { String adId = createDetectorWithFeatureAgg("{\"test\":{\"value_count\":{\"field\":\"message\"}}}", true); - assertErrorMessage(adId, "Text fields are not optimised for operations", true); + assertErrorMessage(adId, "Text fields are not optimised for operations"); } public void testFeatureWithCardinalityOfTextFieldForHCDetector() throws IOException { String adId = createDetectorWithFeatureAgg("{\"test\":{\"cardinality\":{\"field\":\"message\"}}}", true); - assertErrorMessage(adId, "Text fields are not optimised for operations", true); + assertErrorMessage(adId, "Text fields are not optimised for operations"); } private String createDetectorWithFeatureAgg(String aggQuery) throws IOException { @@ -220,7 +219,11 @@ private AnomalyDetector randomDetector(List indices, List featu null, null, null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption((int) features.stream().filter(Feature::getEnabled).count()), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ); } @@ -243,11 +246,15 @@ private AnomalyDetector randomHCDetector(List indices, List fea ImmutableList.of(categoryField), null, null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption((int) features.stream().filter(Feature::getEnabled).count()), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ); } - private void assertErrorMessage(String adId, String errorMessage, boolean hcDetector) { + private void assertErrorMessage(String adId, String errorMessage) { AnomalyResultRequest resultRequest = new AnomalyResultRequest(adId, start, end); try { Thread.sleep(1000); // sleep some time to build AD version hash ring @@ -257,7 +264,7 @@ private void assertErrorMessage(String adId, String errorMessage, boolean hcDete // wait at most 20 seconds int numberofTries = 40; Exception e = null; - if (hcDetector) { + while (numberofTries-- > 0) { try { // HCAD records failures asynchronously. Before a failure is recorded, HCAD returns immediately without failure. @@ -265,15 +272,12 @@ private void assertErrorMessage(String adId, String errorMessage, boolean hcDete Thread.sleep(500); } catch (Exception exp) { e = exp; + LOG.info(numberofTries); + LOG.error("hello", e); break; } } - } else { - e = expectThrowsAnyOf( - ImmutableList.of(NotSerializableExceptionWrapper.class, TimeSeriesException.class), - () -> client().execute(AnomalyResultAction.INSTANCE, resultRequest).actionGet(30_000) - ); - } + String stackErrorMessage = ExceptionUtil.getErrorMessage(e); assertTrue( "Unexpected error: " + e.getMessage(), @@ -282,8 +286,4 @@ private void assertErrorMessage(String adId, String errorMessage, boolean hcDete || stackErrorMessage.contains("AD memory circuit is broken") ); } - - private void assertErrorMessage(String adId, String errorMessage) { - assertErrorMessage(adId, errorMessage, false); - } } diff --git a/src/test/java/org/opensearch/ad/transport/DelegateADProfileTransportAction.java b/src/test/java/org/opensearch/ad/transport/DelegateADProfileTransportAction.java new file mode 100644 index 000000000..410fe3f9a --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/DelegateADProfileTransportAction.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.transport; + +import java.util.List; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.transport.ProfileNodeRequest; +import org.opensearch.timeseries.transport.ProfileNodeResponse; +import org.opensearch.timeseries.transport.ProfileRequest; +import org.opensearch.timeseries.transport.ProfileResponse; +import org.opensearch.transport.TransportService; + +/** + * This utility class serves as a delegate for testing ProfileTransportAction functionalities. + * It facilitates the invocation of protected methods within the org.opensearch.ad.transport.ADProfileTransportAction + * and org.opensearch.timeseries.transport.BaseProfileTransportAction classes, which are otherwise inaccessible + * due to Java's access control restrictions. This is achieved by extending the target classes or using reflection + * where inheritance is not possible, enabling the testing framework to perform comprehensive tests on protected + * class members across different packages. + */ +public class DelegateADProfileTransportAction extends ADProfileTransportAction { + + public DelegateADProfileTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + ADModelManager modelManager, + FeatureManager featureManager, + ADCacheProvider cacheProvider, + Settings settings + ) { + super(threadPool, clusterService, transportService, actionFilters, modelManager, featureManager, cacheProvider, settings); + } + + @Override + public ProfileResponse newResponse(ProfileRequest request, List responses, List failures) { + return super.newResponse(request, responses, failures); + } + + @Override + public ProfileNodeRequest newNodeRequest(ProfileRequest request) { + return super.newNodeRequest(request); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/DelegateDeleteADModelTransportAction.java b/src/test/java/org/opensearch/ad/transport/DelegateDeleteADModelTransportAction.java new file mode 100644 index 000000000..bc113fe37 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/DelegateDeleteADModelTransportAction.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.transport; + +import java.util.List; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.ml.ADColdStart; +import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.ad.task.ADTaskCacheManager; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.transport.DeleteModelNodeResponse; +import org.opensearch.timeseries.transport.DeleteModelRequest; +import org.opensearch.timeseries.transport.DeleteModelResponse; +import org.opensearch.transport.TransportService; + +/** + * This utility class serves as a delegate for testing ProfileTransportAction functionalities. + * It facilitates the invocation of protected methods within the org.opensearch.ad.transport.DeleteADModelTransportAction + * and org.opensearch.timeseries.transport.BaseDeleteModelTransportAction classes, which are otherwise inaccessible + * due to Java's access control restrictions. This is achieved by extending the target classes or using reflection + * where inheritance is not possible, enabling the testing framework to perform comprehensive tests on protected + * class members across different packages. + */ +public class DelegateDeleteADModelTransportAction extends DeleteADModelTransportAction { + public DelegateDeleteADModelTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + NodeStateManager nodeStateManager, + ADModelManager modelManager, + FeatureManager featureManager, + ADCacheProvider cache, + ADTaskCacheManager adTaskCacheManager, + ADColdStart coldStarter + ) { + super( + threadPool, + clusterService, + transportService, + actionFilters, + nodeStateManager, + modelManager, + featureManager, + cache, + adTaskCacheManager, + coldStarter + ); + } + + @Override + public DeleteModelResponse newResponse( + DeleteModelRequest request, + List responses, + List failures + ) { + return super.newResponse(request, responses, failures); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportActionTests.java index ac81ecf25..c388e7499 100644 --- a/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportActionTests.java @@ -22,6 +22,7 @@ import org.opensearch.test.OpenSearchIntegTestCase; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.model.Feature; +import org.opensearch.timeseries.transport.DeleteConfigRequest; import com.google.common.collect.ImmutableList; @@ -60,7 +61,7 @@ public void testDeleteAnomalyDetectorWithEnabledFeature() throws IOException { private void testDeleteDetector(AnomalyDetector detector) throws IOException { String detectorId = createDetector(detector); - DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest(detectorId); + DeleteConfigRequest request = new DeleteConfigRequest(detectorId); DeleteResponse deleteResponse = client().execute(DeleteAnomalyDetectorAction.INSTANCE, request).actionGet(10000); assertEquals("deleted", deleteResponse.getResult().getLowercase()); } diff --git a/src/test/java/org/opensearch/ad/transport/DeleteITTests.java b/src/test/java/org/opensearch/ad/transport/DeleteITTests.java index 1a57504cc..52678e63f 100644 --- a/src/test/java/org/opensearch/ad/transport/DeleteITTests.java +++ b/src/test/java/org/opensearch/ad/transport/DeleteITTests.java @@ -20,6 +20,10 @@ import org.opensearch.common.action.ActionFuture; import org.opensearch.plugins.Plugin; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.transport.DeleteModelRequest; +import org.opensearch.timeseries.transport.DeleteModelResponse; +import org.opensearch.timeseries.transport.StopConfigRequest; +import org.opensearch.timeseries.transport.StopConfigResponse; public class DeleteITTests extends ADIntegTestCase { @@ -28,23 +32,24 @@ protected Collection> nodePlugins() { return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); } + @Override protected Collection> transportClientPlugins() { return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); } public void testNormalStopDetector() throws ExecutionException, InterruptedException { - StopDetectorRequest request = new StopDetectorRequest().adID("123"); + StopConfigRequest request = new StopConfigRequest().adID("123"); - ActionFuture future = client().execute(StopDetectorAction.INSTANCE, request); + ActionFuture future = client().execute(StopDetectorAction.INSTANCE, request); - StopDetectorResponse response = future.get(); + StopConfigResponse response = future.get(); assertTrue(response.success()); } public void testNormalDeleteModel() throws ExecutionException, InterruptedException { DeleteModelRequest request = new DeleteModelRequest("123"); - ActionFuture future = client().execute(DeleteModelAction.INSTANCE, request); + ActionFuture future = client().execute(DeleteADModelAction.INSTANCE, request); DeleteModelResponse response = future.get(); assertTrue(!response.hasFailures()); @@ -53,15 +58,15 @@ public void testNormalDeleteModel() throws ExecutionException, InterruptedExcept public void testEmptyIDDeleteModel() throws ExecutionException, InterruptedException { DeleteModelRequest request = new DeleteModelRequest(""); - ActionFuture future = client().execute(DeleteModelAction.INSTANCE, request); + ActionFuture future = client().execute(DeleteADModelAction.INSTANCE, request); expectThrows(ActionRequestValidationException.class, () -> future.actionGet()); } public void testEmptyIDStopDetector() throws ExecutionException, InterruptedException { - StopDetectorRequest request = new StopDetectorRequest(); + StopConfigRequest request = new StopConfigRequest(); - ActionFuture future = client().execute(StopDetectorAction.INSTANCE, request); + ActionFuture future = client().execute(StopDetectorAction.INSTANCE, request); expectThrows(ActionRequestValidationException.class, () -> future.actionGet()); } diff --git a/src/test/java/org/opensearch/ad/transport/DeleteModelTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/DeleteModelTransportActionTests.java index b76925492..8d3a23618 100644 --- a/src/test/java/org/opensearch/ad/transport/DeleteModelTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/DeleteModelTransportActionTests.java @@ -27,13 +27,12 @@ import org.opensearch.Version; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.action.support.ActionFilters; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; import org.opensearch.ad.common.exception.JsonPathNotFoundException; import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.ml.EntityColdStarter; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADColdStart; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; @@ -46,6 +45,13 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.transport.CronNodeResponse; +import org.opensearch.timeseries.transport.CronResponse; +import org.opensearch.timeseries.transport.DeleteModelNodeRequest; +import org.opensearch.timeseries.transport.DeleteModelNodeResponse; +import org.opensearch.timeseries.transport.DeleteModelRequest; +import org.opensearch.timeseries.transport.DeleteModelResponse; import org.opensearch.transport.TransportService; import com.google.gson.JsonElement; @@ -53,7 +59,7 @@ import test.org.opensearch.ad.util.JsonDeserializer; public class DeleteModelTransportActionTests extends AbstractTimeSeriesTest { - private DeleteModelTransportAction action; + private DelegateDeleteADModelTransportAction action; private String localNodeID; @Override @@ -70,15 +76,15 @@ public void setUp() throws Exception { TransportService transportService = mock(TransportService.class); ActionFilters actionFilters = mock(ActionFilters.class); NodeStateManager nodeStateManager = mock(NodeStateManager.class); - ModelManager modelManager = mock(ModelManager.class); + ADModelManager modelManager = mock(ADModelManager.class); FeatureManager featureManager = mock(FeatureManager.class); - CacheProvider cacheProvider = mock(CacheProvider.class); - EntityCache entityCache = mock(EntityCache.class); + ADCacheProvider cacheProvider = mock(ADCacheProvider.class); + ADPriorityCache entityCache = mock(ADPriorityCache.class); when(cacheProvider.get()).thenReturn(entityCache); ADTaskCacheManager adTaskCacheManager = mock(ADTaskCacheManager.class); - EntityColdStarter coldStarter = mock(EntityColdStarter.class); + ADColdStart coldStarter = mock(ADColdStart.class); - action = new DeleteModelTransportAction( + action = new DelegateDeleteADModelTransportAction( threadPool, clusterService, transportService, diff --git a/src/test/java/org/opensearch/ad/transport/DeleteTests.java b/src/test/java/org/opensearch/ad/transport/DeleteTests.java index 4821cbfbd..9e5cff5df 100644 --- a/src/test/java/org/opensearch/ad/transport/DeleteTests.java +++ b/src/test/java/org/opensearch/ad/transport/DeleteTests.java @@ -58,6 +58,11 @@ import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.transport.DeleteModelNodeResponse; +import org.opensearch.timeseries.transport.DeleteModelRequest; +import org.opensearch.timeseries.transport.DeleteModelResponse; +import org.opensearch.timeseries.transport.StopConfigRequest; +import org.opensearch.timeseries.transport.StopConfigResponse; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import org.opensearch.transport.TransportService; @@ -139,7 +144,7 @@ public void testSerialzationResponse() throws IOException { response.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); - DeleteModelResponse readResponse = DeleteModelAction.INSTANCE.getResponseReader().read(streamInput); + DeleteModelResponse readResponse = DeleteADModelAction.INSTANCE.getResponseReader().read(streamInput); assertTrue(readResponse.hasFailures()); assertEquals(failures.size(), readResponse.failures().size()); @@ -152,12 +157,12 @@ public void testEmptyIDDeleteModel() { } public void testEmptyIDStopDetector() { - ActionRequestValidationException e = new StopDetectorRequest().validate(); + ActionRequestValidationException e = new StopConfigRequest().validate(); assertThat(e.validationErrors(), hasItem(ADCommonMessages.AD_ID_MISSING_MSG)); } public void testValidIDStopDetector() { - ActionRequestValidationException e = new StopDetectorRequest().adID("foo").validate(); + ActionRequestValidationException e = new StopConfigRequest().adID("foo").validate(); assertThat(e, is(nullValue())); } @@ -171,12 +176,12 @@ public void testSerialzationRequestDeleteModel() throws IOException { } public void testSerialzationRequestStopDetector() throws IOException { - StopDetectorRequest request = new StopDetectorRequest().adID("123"); + StopConfigRequest request = new StopConfigRequest().adID("123"); BytesStreamOutput output = new BytesStreamOutput(); request.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); - StopDetectorRequest readRequest = new StopDetectorRequest(streamInput); - assertThat(request.getAdID(), equalTo(readRequest.getAdID())); + StopConfigRequest readRequest = new StopConfigRequest(streamInput); + assertThat(request.getConfigID(), equalTo(readRequest.getConfigID())); } public void testJsonRequestTemplate(R request, Supplier requestSupplier) throws IOException, @@ -189,8 +194,8 @@ public void testJsonRequestTemplate(R request, Supplier listener = new PlainActionFuture<>(); + StopConfigRequest request = new StopConfigRequest().adID(detectorID); + PlainActionFuture listener = new PlainActionFuture<>(); action.doExecute(task, request, listener); - StopDetectorResponse response = listener.actionGet(); + StopConfigResponse response = listener.actionGet(); assertTrue(!response.success()); } diff --git a/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java index f7eb2c8e9..a702f69a4 100644 --- a/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java @@ -15,7 +15,6 @@ import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.startsWith; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; @@ -50,27 +49,23 @@ import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.support.master.AcknowledgedResponse; import org.opensearch.ad.AnomalyDetectorJobRunnerTests; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; import org.opensearch.ad.common.exception.JsonPathNotFoundException; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.constant.CommonValue; import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityColdStarter; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADColdStart; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.ratelimit.CheckpointReadWorker; -import org.opensearch.ad.ratelimit.ColdEntityWorker; -import org.opensearch.ad.ratelimit.EntityColdStartWorker; -import org.opensearch.ad.ratelimit.ResultWriteWorker; +import org.opensearch.ad.ratelimit.ADCheckpointReadWorker; +import org.opensearch.ad.ratelimit.ADColdEntityWorker; +import org.opensearch.ad.ratelimit.ADColdStartWorker; +import org.opensearch.ad.ratelimit.ADResultWriteWorker; +import org.opensearch.ad.ratelimit.ADSaveResultStrategy; import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.stats.ADStat; import org.opensearch.ad.stats.ADStats; -import org.opensearch.ad.stats.suppliers.CounterSupplier; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; @@ -87,10 +82,16 @@ import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.constant.CommonValue; +import org.opensearch.timeseries.ml.ModelState; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.TimeSeriesStat; +import org.opensearch.timeseries.stats.suppliers.CounterSupplier; +import org.opensearch.timeseries.transport.EntityResultRequest; import org.opensearch.transport.TransportService; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; import com.google.gson.JsonArray; import com.google.gson.JsonElement; @@ -99,14 +100,14 @@ import test.org.opensearch.ad.util.RandomModelStateConfig; public class EntityResultTransportActionTests extends AbstractTimeSeriesTest { - EntityResultTransportAction entityResult; + EntityADResultTransportAction entityResult; ActionFilters actionFilters; TransportService transportService; - ModelManager manager; + ADModelManager manager; CircuitBreakerService adCircuitBreakerService; - CheckpointDao checkpointDao; - CacheProvider provider; - EntityCache entityCache; + ADCheckpointDao checkpointDao; + ADCacheProvider provider; + ADPriorityCache entityCache; NodeStateManager stateManager; Settings settings; Clock clock; @@ -125,13 +126,13 @@ public class EntityResultTransportActionTests extends AbstractTimeSeriesTest { double[] cacheHitData; String tooLongEntity; double[] tooLongData; - ResultWriteWorker resultWriteQueue; - CheckpointReadWorker checkpointReadQueue; + ADResultWriteWorker resultWriteQueue; + ADCheckpointReadWorker checkpointReadQueue; int minSamples; Instant now; - EntityColdStarter coldStarter; - ColdEntityWorker coldEntityQueue; - EntityColdStartWorker entityColdStartQueue; + ADColdStart coldStarter; + ADColdEntityWorker coldEntityQueue; + ADColdStartWorker entityColdStartQueue; ADIndexManagement indexUtil; ClusterService clusterService; ADStats adStats; @@ -157,14 +158,14 @@ public void setUp() throws Exception { adCircuitBreakerService = mock(CircuitBreakerService.class); when(adCircuitBreakerService.isOpen()).thenReturn(false); - checkpointDao = mock(CheckpointDao.class); + checkpointDao = mock(ADCheckpointDao.class); detectorId = "123"; entities = new HashMap<>(); start = 10L; end = 20L; - request = new EntityResultRequest(detectorId, entities, start, end); + request = new EntityResultRequest(detectorId, entities, start, end, AnalysisType.AD, null); clock = mock(Clock.class); now = Instant.now(); @@ -182,7 +183,7 @@ public void setUp() throws Exception { Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ))) ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - manager = new ModelManager( + manager = new ADModelManager( null, clock, 0, @@ -190,18 +191,17 @@ public void setUp() throws Exception { 0, 0, 0, - 0, null, AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ, - mock(EntityColdStarter.class), + mock(ADColdStart.class), null, null, settings, clusterService ); - provider = mock(CacheProvider.class); - entityCache = mock(EntityCache.class); + provider = mock(ADCacheProvider.class); + entityCache = mock(ADPriorityCache.class); when(provider.get()).thenReturn(entityCache); String field = "a"; @@ -225,7 +225,8 @@ public void setUp() throws Exception { tooLongData = new double[] { 0.3 }; entities.put(Entity.createSingleAttributeEntity(detector.getCategoryFields().get(0), tooLongEntity), tooLongData); - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); when(entityCache.get(eq(cacheMissEntityObj.getModelId(detectorId).get()), any())).thenReturn(null); when(entityCache.get(eq(cacheHitEntityObj.getModelId(detectorId).get()), any())).thenReturn(state); @@ -236,31 +237,31 @@ public void setUp() throws Exception { indexUtil = mock(ADIndexManagement.class); when(indexUtil.getSchemaVersion(any())).thenReturn(CommonValue.NO_SCHEMA_VERSION); - resultWriteQueue = mock(ResultWriteWorker.class); - checkpointReadQueue = mock(CheckpointReadWorker.class); + resultWriteQueue = mock(ADResultWriteWorker.class); + checkpointReadQueue = mock(ADCheckpointReadWorker.class); minSamples = 1; - coldStarter = mock(EntityColdStarter.class); + coldStarter = mock(ADColdStart.class); doAnswer(invocation -> { - ModelState modelState = invocation.getArgument(0); - modelState.getModel().clear(); + ModelState modelState = invocation.getArgument(0); + modelState.clear(); return null; - }).when(coldStarter).trainModelFromExistingSamples(any(), anyInt()); + }).when(coldStarter).trainModelFromExistingSamples(any(), any(), any(), any()); - coldEntityQueue = mock(ColdEntityWorker.class); - entityColdStartQueue = mock(EntityColdStartWorker.class); + coldEntityQueue = mock(ADColdEntityWorker.class); + entityColdStartQueue = mock(ADColdStartWorker.class); - Map> statsMap = new HashMap>() { + Map> statsMap = new HashMap>() { { - put(StatNames.MODEL_CORRUTPION_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); + put(StatNames.AD_MODEL_CORRUTPION_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())); } }; adStats = new ADStats(statsMap); - entityResult = new EntityResultTransportAction( + entityResult = new EntityADResultTransportAction( actionFilters, transportService, manager, @@ -273,7 +274,8 @@ public void setUp() throws Exception { coldEntityQueue, threadPool, entityColdStartQueue, - adStats + adStats, + mock(ADSaveResultStrategy.class) ); // timeout in 60 seconds @@ -317,7 +319,8 @@ public void testFailtoGetDetector() { // test rcf score is 0 public void testNoResultsToSave() { - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(false).build()); + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(false).build()); when(entityCache.get(eq(cacheHitEntityObj.getModelId(detectorId).get()), any())).thenReturn(state); PlainActionFuture future = PlainActionFuture.newFuture(); @@ -335,19 +338,19 @@ public void testValidRequest() { } public void testEmptyId() { - request = new EntityResultRequest("", entities, start, end); + request = new EntityResultRequest("", entities, start, end, AnalysisType.AD, null); ActionRequestValidationException e = request.validate(); assertThat(e.validationErrors(), hasItem(ADCommonMessages.AD_ID_MISSING_MSG)); } public void testReverseTime() { - request = new EntityResultRequest(detectorId, entities, end, start); + request = new EntityResultRequest(detectorId, entities, end, start, AnalysisType.AD, null); ActionRequestValidationException e = request.validate(); assertThat(e.validationErrors(), hasItem(startsWith(CommonMessages.INVALID_TIMESTAMP_ERR_MSG))); } public void testNegativeTime() { - request = new EntityResultRequest(detectorId, entities, start, -end); + request = new EntityResultRequest(detectorId, entities, start, -end, AnalysisType.AD, null); ActionRequestValidationException e = request.validate(); assertThat(e.validationErrors(), hasItem(startsWith(CommonMessages.INVALID_TIMESTAMP_ERR_MSG))); } @@ -384,9 +387,9 @@ public void testJsonResponse() throws IOException, JsonPathNotFoundException { } public void testFailToScore() { - ModelManager spyModelManager = spy(manager); - doThrow(new IllegalArgumentException()).when(spyModelManager).getAnomalyResultForEntity(any(), any(), anyString(), any(), anyInt()); - entityResult = new EntityResultTransportAction( + ADModelManager spyModelManager = spy(manager); + doThrow(new IllegalArgumentException()).when(spyModelManager).getResult(any(), any(), anyString(), any(), any(), any()); + entityResult = new EntityADResultTransportAction( actionFilters, transportService, spyModelManager, @@ -399,7 +402,8 @@ public void testFailToScore() { coldEntityQueue, threadPool, entityColdStartQueue, - adStats + adStats, + mock(ADSaveResultStrategy.class) ); PlainActionFuture future = PlainActionFuture.newFuture(); @@ -409,9 +413,9 @@ public void testFailToScore() { future.actionGet(timeoutMs); verify(resultWriteQueue, never()).put(any()); - verify(entityCache, times(1)).removeEntityModel(anyString(), anyString()); + verify(entityCache, times(1)).removeModel(anyString(), anyString()); verify(entityColdStartQueue, times(1)).put(any()); - Object val = adStats.getStat(StatNames.MODEL_CORRUTPION_COUNT.getName()).getValue(); + Object val = adStats.getStat(StatNames.AD_MODEL_CORRUTPION_COUNT.getName()).getValue(); assertEquals(1L, ((Long) val).longValue()); } } diff --git a/src/test/java/org/opensearch/ad/transport/ForwardADTaskRequestTests.java b/src/test/java/org/opensearch/ad/transport/ForwardADTaskRequestTests.java index 633a9a4fe..7a28bfa7c 100644 --- a/src/test/java/org/opensearch/ad/transport/ForwardADTaskRequestTests.java +++ b/src/test/java/org/opensearch/ad/transport/ForwardADTaskRequestTests.java @@ -78,7 +78,11 @@ public void testNullDetectorIdAndTaskAction() throws IOException { null, randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(0), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ); ForwardADTaskRequest request = new ForwardADTaskRequest(detector, null, null, null, null, Version.V_2_1_0); ActionRequestValidationException validate = request.validate(); diff --git a/src/test/java/org/opensearch/ad/transport/ForwardADTaskTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/ForwardADTaskTransportActionTests.java index ac83b5c8e..fdbbbae49 100644 --- a/src/test/java/org/opensearch/ad/transport/ForwardADTaskTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/ForwardADTaskTransportActionTests.java @@ -31,15 +31,16 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.ad.ADUnitTestCase; -import org.opensearch.ad.feature.FeatureManager; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.ADTaskType; +import org.opensearch.ad.rest.handler.ADIndexJobActionHandler; import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.core.action.ActionListener; import org.opensearch.tasks.Task; import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.feature.FeatureManager; import org.opensearch.timeseries.transport.JobResponse; import org.opensearch.transport.TransportService; @@ -72,7 +73,8 @@ public void setUp() throws Exception { adTaskManager, adTaskCacheManager, featureManager, - stateManager + stateManager, + mock(ADIndexJobActionHandler.class) ); task = mock(Task.class); diff --git a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorActionTests.java b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorActionTests.java index 2a0b677ed..d41f255f4 100644 --- a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorActionTests.java @@ -31,6 +31,7 @@ import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.model.Feature; import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.transport.GetConfigRequest; import com.google.common.collect.ImmutableList; @@ -53,11 +54,11 @@ protected NamedWriteableRegistry writableRegistry() { public void testGetRequest() throws IOException { BytesStreamOutput out = new BytesStreamOutput(); - GetAnomalyDetectorRequest request = new GetAnomalyDetectorRequest("1234", 4321, false, false, "nonempty", "", false, null); + GetConfigRequest request = new GetConfigRequest("1234", 4321, false, false, "nonempty", "", false, null); request.writeTo(out); StreamInput input = out.bytes().streamInput(); - GetAnomalyDetectorRequest newRequest = new GetAnomalyDetectorRequest(input); - Assert.assertEquals(request.getDetectorID(), newRequest.getDetectorID()); + GetConfigRequest newRequest = new GetConfigRequest(input); + Assert.assertEquals(request.getConfigID(), newRequest.getConfigID()); } diff --git a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java index ac5702edc..7aed2eae0 100644 --- a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java +++ b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java @@ -21,7 +21,6 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import java.io.IOException; import java.nio.ByteBuffer; import java.util.Arrays; import java.util.Collections; @@ -38,6 +37,7 @@ import org.opensearch.action.get.MultiGetResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.ad.ADTaskProfileRunner; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.settings.AnomalyDetectorSettings; @@ -55,6 +55,8 @@ import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.transport.EntityProfileTests; +import org.opensearch.timeseries.transport.GetConfigRequest; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.Transport; @@ -67,7 +69,7 @@ public class GetAnomalyDetectorTests extends AbstractTimeSeriesTest { private ActionFilters actionFilters; private Client client; private SecurityClientUtil clientUtil; - private GetAnomalyDetectorRequest request; + private GetConfigRequest request; private String detectorId = "yecrdnUBqurvo9uKU_d8"; private String entityValue = "app_0"; private String categoryField = "categoryField"; @@ -120,6 +122,8 @@ public void setUp() throws Exception { adTaskManager = mock(ADTaskManager.class); + ADTaskProfileRunner adTaskProfileRunner = mock(ADTaskProfileRunner.class); + action = new GetAnomalyDetectorTransportAction( transportService, nodeFilter, @@ -129,18 +133,19 @@ public void setUp() throws Exception { clientUtil, Settings.EMPTY, xContentRegistry(), - adTaskManager + adTaskManager, + adTaskProfileRunner ); entity = Entity.createSingleAttributeEntity(categoryField, entityValue); } - public void testInvalidRequest() throws IOException { + public void testInvalidRequest() { typeStr = "entity_info2,init_progress2"; rawPath = "_opendistro/_anomaly_detection/detectors/T4c3dXUBj-2IZN7itix_/_profile"; - request = new GetAnomalyDetectorRequest(detectorId, 0L, false, false, typeStr, rawPath, false, entity); + request = new GetConfigRequest(detectorId, 0L, false, false, typeStr, rawPath, false, entity); future = new PlainActionFuture<>(); action.doExecute(null, request, future); @@ -148,7 +153,7 @@ public void testInvalidRequest() throws IOException { } @SuppressWarnings("unchecked") - public void testValidRequest() throws IOException { + public void testValidRequest() { doAnswer(invocation -> { Object[] args = invocation.getArguments(); GetRequest request = (GetRequest) args[0]; @@ -165,7 +170,7 @@ public void testValidRequest() throws IOException { rawPath = "_opendistro/_anomaly_detection/detectors/T4c3dXUBj-2IZN7itix_/_profile"; - request = new GetAnomalyDetectorRequest(detectorId, 0L, false, false, typeStr, rawPath, false, entity); + request = new GetConfigRequest(detectorId, 0L, false, false, typeStr, rawPath, false, entity); future = new PlainActionFuture<>(); action.doExecute(null, request, future); @@ -181,17 +186,7 @@ public void testGetTransportActionWithReturnTask() { return null; }) .when(adTaskManager) - .getAndExecuteOnLatestADTasks( - anyString(), - eq(null), - eq(null), - anyList(), - any(), - eq(transportService), - eq(true), - anyInt(), - any() - ); + .getAndExecuteOnLatestTasks(anyString(), eq(null), eq(null), anyList(), any(), eq(transportService), eq(true), anyInt(), any()); doAnswer(invocation -> { Object[] args = invocation.getArguments(); @@ -203,7 +198,7 @@ public void testGetTransportActionWithReturnTask() { rawPath = "_opendistro/_anomaly_detection/detectors/T4c3dXUBj-2IZN7itix_"; - request = new GetAnomalyDetectorRequest(detectorId, 0L, false, true, typeStr, rawPath, false, entity); + request = new GetConfigRequest(detectorId, 0L, false, true, typeStr, rawPath, false, entity); future = new PlainActionFuture<>(); action.getExecute(request, future); diff --git a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java index 35f6ba36f..1c12aeb08 100644 --- a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java @@ -25,14 +25,11 @@ import org.junit.*; import org.mockito.Mockito; import org.opensearch.action.support.ActionFilters; -import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.ADTaskProfileRunner; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.EntityProfile; -import org.opensearch.ad.model.InitProgressProfile; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.task.ADTaskManager; -import org.opensearch.ad.util.*; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.ClusterSettings; @@ -50,8 +47,12 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.EntityProfile; +import org.opensearch.timeseries.model.InitProgressProfile; import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.transport.GetConfigRequest; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import org.opensearch.timeseries.util.RestHandlerUtils; import org.opensearch.timeseries.util.SecurityClientUtil; @@ -101,7 +102,8 @@ public void setUp() throws Exception { clientUtil, Settings.EMPTY, xContentRegistry(), - adTaskManager + adTaskManager, + mock(ADTaskProfileRunner.class) ); task = Mockito.mock(Task.class); response = new ActionListener() { @@ -126,14 +128,14 @@ protected NamedWriteableRegistry writableRegistry() { @Test public void testGetTransportAction() throws IOException { - GetAnomalyDetectorRequest getConfigRequest = new GetAnomalyDetectorRequest("1234", 4321, false, false, "nonempty", "", false, null); - action.doExecute(task, getConfigRequest, response); + GetConfigRequest getAnomalyDetectorRequest = new GetConfigRequest("1234", 4321, false, false, "nonempty", "", false, null); + action.doExecute(task, getAnomalyDetectorRequest, response); } @Test public void testGetTransportActionWithReturnJob() throws IOException { - GetAnomalyDetectorRequest getConfigRequest = new GetAnomalyDetectorRequest("1234", 4321, true, false, "", "abcd", false, null); - action.doExecute(task, getConfigRequest, response); + GetConfigRequest getAnomalyDetectorRequest = new GetConfigRequest("1234", 4321, true, false, "", "abcd", false, null); + action.doExecute(task, getAnomalyDetectorRequest, response); } @Test @@ -144,23 +146,23 @@ public void testGetAction() { @Test public void testGetAnomalyDetectorRequest() throws IOException { - GetAnomalyDetectorRequest request = new GetAnomalyDetectorRequest("1234", 4321, true, false, "", "abcd", false, entity); + GetConfigRequest request = new GetConfigRequest("1234", 4321, true, false, "", "abcd", false, entity); BytesStreamOutput out = new BytesStreamOutput(); request.writeTo(out); StreamInput input = out.bytes().streamInput(); - GetAnomalyDetectorRequest newRequest = new GetAnomalyDetectorRequest(input); - Assert.assertEquals(request.getDetectorID(), newRequest.getDetectorID()); + GetConfigRequest newRequest = new GetConfigRequest(input); + Assert.assertEquals(request.getConfigID(), newRequest.getConfigID()); Assert.assertEquals(request.getRawPath(), newRequest.getRawPath()); Assert.assertNull(newRequest.validate()); } @Test public void testGetAnomalyDetectorRequestNoEntityValue() throws IOException { - GetAnomalyDetectorRequest request = new GetAnomalyDetectorRequest("1234", 4321, true, false, "", "abcd", false, null); + GetConfigRequest request = new GetConfigRequest("1234", 4321, true, false, "", "abcd", false, null); BytesStreamOutput out = new BytesStreamOutput(); request.writeTo(out); StreamInput input = out.bytes().streamInput(); - GetAnomalyDetectorRequest newRequest = new GetAnomalyDetectorRequest(input); + GetConfigRequest newRequest = new GetConfigRequest(input); Assert.assertNull(newRequest.getEntity()); } @@ -230,7 +232,7 @@ public void testGetAnomalyDetectorProfileResponse() throws IOException { // {init_progress={percentage=99%, estimated_minutes_left=2, needed_shingles=2}} Map map = TestHelpers.XContentBuilderToMap(builder); - Map parsedInitProgress = (Map) (map.get(ADCommonName.INIT_PROGRESS)); + Map parsedInitProgress = (Map) (map.get(CommonName.INIT_PROGRESS)); Assert.assertEquals(initProgress.getPercentage(), parsedInitProgress.get(InitProgressProfile.PERCENTAGE).toString()); assertTrue(initProgress.toString().contains("[percentage=99%,estimated_minutes_left=2,needed_shingles=2]")); Assert diff --git a/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorActionTests.java b/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorActionTests.java index f29030912..e4f160aa1 100644 --- a/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorActionTests.java @@ -34,6 +34,7 @@ import com.google.common.collect.ImmutableMap; public class IndexAnomalyDetectorActionTests extends OpenSearchSingleNodeTestCase { + @Override @Before public void setUp() throws Exception { super.setUp(); @@ -58,7 +59,8 @@ public void testIndexRequest() throws Exception { TimeValue.timeValueSeconds(60), 1000, 10, - 5 + 5, + 10 ); request.writeTo(out); NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), writableRegistry()); diff --git a/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java index d370fa703..c59108c17 100644 --- a/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java @@ -176,7 +176,8 @@ public void setUp() throws Exception { TimeValue.timeValueSeconds(60), 1000, 10, - 5 + 5, + 10 ); response = new ActionListener() { @Override diff --git a/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java b/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java index 94e07fe3c..673e9af92 100644 --- a/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java +++ b/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java @@ -13,7 +13,6 @@ import static org.hamcrest.Matchers.containsString; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; @@ -66,24 +65,19 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.support.master.AcknowledgedResponse; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; -import org.opensearch.ad.cluster.HashRing; -import org.opensearch.ad.feature.CompositeRetriever; -import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.ml.ThresholdingResult; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.ratelimit.CheckpointReadWorker; -import org.opensearch.ad.ratelimit.ColdEntityWorker; -import org.opensearch.ad.ratelimit.EntityColdStartWorker; -import org.opensearch.ad.ratelimit.EntityFeatureRequest; -import org.opensearch.ad.ratelimit.ResultWriteWorker; +import org.opensearch.ad.ratelimit.ADCheckpointReadWorker; +import org.opensearch.ad.ratelimit.ADColdEntityWorker; +import org.opensearch.ad.ratelimit.ADColdStartWorker; +import org.opensearch.ad.ratelimit.ADResultWriteWorker; +import org.opensearch.ad.ratelimit.ADSaveResultStrategy; import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.stats.ADStat; import org.opensearch.ad.stats.ADStats; -import org.opensearch.ad.stats.suppliers.CounterSupplier; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; @@ -113,15 +107,22 @@ import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.common.exception.InternalFailure; import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.feature.CompositeRetriever; +import org.opensearch.timeseries.feature.FeatureManager; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.ratelimit.FeatureRequest; import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.TimeSeriesStat; +import org.opensearch.timeseries.stats.suppliers.CounterSupplier; +import org.opensearch.timeseries.transport.ResultProcessor; import org.opensearch.timeseries.util.ClientUtil; import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.Transport; @@ -149,7 +150,7 @@ public class MultiEntityResultTests extends AbstractTimeSeriesTest { private Client client; private SecurityClientUtil clientUtil; private FeatureManager featureQuery; - private ModelManager normalModelManager; + private ADModelManager normalModelManager; private HashRing hashRing; private ClusterService clusterService; private IndexNameExpressionResolver indexNameResolver; @@ -158,12 +159,12 @@ public class MultiEntityResultTests extends AbstractTimeSeriesTest { private ThreadPool mockThreadPool; private String detectorId; private Instant now; - private CacheProvider provider; + private ADCacheProvider provider; private ADIndexManagement indexUtil; - private ResultWriteWorker resultWriteQueue; - private CheckpointReadWorker checkpointReadQueue; - private EntityColdStartWorker entityColdStartQueue; - private ColdEntityWorker coldEntityQueue; + private ADResultWriteWorker resultWriteQueue; + private ADCheckpointReadWorker checkpointReadQueue; + private ADColdStartWorker entityColdStartQueue; + private ADColdEntityWorker coldEntityQueue; private String app0 = "app_0"; private String server1 = "server_1"; private String server2 = "server_2"; @@ -171,7 +172,7 @@ public class MultiEntityResultTests extends AbstractTimeSeriesTest { private String serviceField = "service"; private String hostField = "host"; private Map attrs1, attrs2, attrs3; - private EntityCache entityCache; + private ADPriorityCache entityCache; private ADTaskManager adTaskManager; @BeforeClass @@ -222,7 +223,7 @@ public void setUp() throws Exception { featureQuery = mock(FeatureManager.class); - normalModelManager = mock(ModelManager.class); + normalModelManager = mock(ADModelManager.class); hashRing = mock(HashRing.class); @@ -248,13 +249,13 @@ public void setUp() throws Exception { adCircuitBreakerService = mock(CircuitBreakerService.class); when(adCircuitBreakerService.isOpen()).thenReturn(false); - Map> statsMap = new HashMap>() { + Map> statsMap = new HashMap>() { { - put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); - put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); - put(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); - put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); - put(StatNames.MODEL_CORRUTPION_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); + put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())); + put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())); + put(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())); + put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())); + put(StatNames.AD_MODEL_CORRUTPION_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())); } }; adStats = new ADStats(statsMap); @@ -281,7 +282,6 @@ public void setUp() throws Exception { clientUtil, stateManager, featureQuery, - normalModelManager, hashRing, clusterService, indexNameResolver, @@ -292,19 +292,19 @@ public void setUp() throws Exception { adTaskManager ); - provider = mock(CacheProvider.class); - entityCache = mock(EntityCache.class); + provider = mock(ADCacheProvider.class); + entityCache = mock(ADPriorityCache.class); when(provider.get()).thenReturn(entityCache); when(entityCache.get(any(), any())) .thenReturn(MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build())); when(entityCache.selectUpdateCandidate(any(), any(), any())).thenReturn(Pair.of(new ArrayList(), new ArrayList())); indexUtil = mock(ADIndexManagement.class); - resultWriteQueue = mock(ResultWriteWorker.class); - checkpointReadQueue = mock(CheckpointReadWorker.class); - entityColdStartQueue = mock(EntityColdStartWorker.class); + resultWriteQueue = mock(ADResultWriteWorker.class); + checkpointReadQueue = mock(ADCheckpointReadWorker.class); + entityColdStartQueue = mock(ADColdStartWorker.class); - coldEntityQueue = mock(ColdEntityWorker.class); + coldEntityQueue = mock(ADColdEntityWorker.class); attrs1 = new HashMap<>(); attrs1.put(serviceField, app0); @@ -328,17 +328,17 @@ public final void tearDown() throws Exception { public void testColdStartEndRunException() { when(stateManager.fetchExceptionAndClear(anyString())) - .thenReturn( + .thenReturn( Optional - .of( + .of( new EndRunException( - detectorId, - CommonMessages.INVALID_SEARCH_QUERY_MSG, - new NoSuchElementException("No value present"), - false + detectorId, + CommonMessages.INVALID_SEARCH_QUERY_MSG, + new NoSuchElementException("No value present"), + false + ) ) - ) - ); + ); PlainActionFuture listener = new PlainActionFuture<>(); action.doExecute(null, request, listener); assertException(listener, EndRunException.class, CommonMessages.INVALID_SEARCH_QUERY_MSG); @@ -397,7 +397,7 @@ public String executor() { private void setUpEntityResult(int nodeIndex, NodeStateManager nodeStateManager) { // register entity result action - new EntityResultTransportAction( + new EntityADResultTransportAction( new ActionFilters(Collections.emptySet()), // since we send requests to testNodes[1] testNodes[nodeIndex].transportService, @@ -411,11 +411,11 @@ private void setUpEntityResult(int nodeIndex, NodeStateManager nodeStateManager) coldEntityQueue, threadPool, entityColdStartQueue, - adStats + adStats, + mock(ADSaveResultStrategy.class) ); - when(normalModelManager.getAnomalyResultForEntity(any(), any(), any(), any(), anyInt())) - .thenReturn(new ThresholdingResult(0, 1, 1)); + when(normalModelManager.getResult(any(), any(), any(), any(), any(), any())).thenReturn(new ThresholdingResult(0, 1, 1)); } private void setUpEntityResult(int nodeIndex) { @@ -430,7 +430,7 @@ public void setUpNormlaStateManager() throws IOException { .setCategoryFields(ImmutableList.of(randomAlphaOfLength(5))) .build(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(TestHelpers.createGetResponse(detector, detectorId, CommonName.CONFIG_INDEX)); return null; }).when(client).get(any(GetRequest.class), any(ActionListener.class)); @@ -457,7 +457,6 @@ public void setUpNormlaStateManager() throws IOException { clientUtil, stateManager, featureQuery, - normalModelManager, hashRing, clusterService, indexNameResolver, @@ -538,11 +537,7 @@ public void testIndexNotFound() throws InterruptedException, IOException { PlainActionFuture listener2 = new PlainActionFuture<>(); action.doExecute(null, request, listener2); Exception e = expectThrows(EndRunException.class, () -> listener2.actionGet(10000L)); - assertThat( - "actual message: " + e.getMessage(), - e.getMessage(), - containsString(AnomalyResultTransportAction.TROUBLE_QUERYING_ERR_MSG) - ); + assertThat("actual message: " + e.getMessage(), e.getMessage(), containsString(ResultProcessor.TROUBLE_QUERYING_ERR_MSG)); assertTrue(!((EndRunException) e).isEndNow()); } @@ -661,7 +656,7 @@ public void sendRequest( TransportRequestOptions options, TransportResponseHandler handler ) { - if (action.equals(EntityResultAction.NAME)) { + if (action.equals(EntityADResultAction.NAME)) { sender .sendRequest( connection, @@ -693,7 +688,6 @@ public void sendRequest( clientUtil, nodeStateManager, featureQuery, - normalModelManager, hashRing, realClusterService, indexNameResolver, @@ -715,7 +709,7 @@ public void testNonEmptyFeatures() throws InterruptedException, IOException { setUpSearchResponse(); setUpTransportInterceptor(this::entityResultHandler); // mock hashing ring response. This has to happen after setting up test nodes with the failure interceptor - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))) + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))) .thenReturn(Optional.of(testNodes[1].discoveryNode())); setUpEntityResult(1); @@ -766,14 +760,14 @@ public void testCircuitBreakerOpen() throws InterruptedException, IOException { setUpSearchResponse(); setUpTransportInterceptor(this::entityResultHandler, spyStateManager); // mock hashing ring response. This has to happen after setting up test nodes with the failure interceptor - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))) + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))) .thenReturn(Optional.of(testNodes[1].discoveryNode())); CircuitBreakerService openBreaker = mock(CircuitBreakerService.class); when(openBreaker.isOpen()).thenReturn(true); // register entity result action - new EntityResultTransportAction( + new EntityADResultTransportAction( new ActionFilters(Collections.emptySet()), // since we send requests to testNodes[1] testNodes[1].transportService, @@ -787,7 +781,8 @@ public void testCircuitBreakerOpen() throws InterruptedException, IOException { coldEntityQueue, threadPool, entityColdStartQueue, - adStats + adStats, + mock(ADSaveResultStrategy.class) ); CountDownLatch inProgress = new CountDownLatch(1); @@ -816,7 +811,7 @@ public void testNotAck() throws InterruptedException, IOException { setUpSearchResponse(); setUpTransportInterceptor(this::unackEntityResultHandler); // mock hashing ring response. This has to happen after setting up test nodes with the failure interceptor - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))) + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))) .thenReturn(Optional.of(testNodes[1].discoveryNode())); setUpEntityResult(1); @@ -847,13 +842,13 @@ public void testMultipleNode() throws InterruptedException, IOException { Entity entity3 = Entity.createEntityByReordering(attrs3); // we use ordered attributes values as the key to hashring - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(eq(entity1.toString()))) + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(eq(entity1.toString()))) .thenReturn(Optional.of(testNodes[2].discoveryNode())); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(eq(entity2.toString()))) + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(eq(entity2.toString()))) .thenReturn(Optional.of(testNodes[3].discoveryNode())); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(eq(entity3.toString()))) + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(eq(entity3.toString()))) .thenReturn(Optional.of(testNodes[4].discoveryNode())); for (int i = 2; i <= 4; i++) { @@ -883,7 +878,7 @@ public void testCacheSelectionError() throws IOException, InterruptedException { setUpSearchResponse(); setUpTransportInterceptor(this::entityResultHandler); setUpEntityResult(1); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))) + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))) .thenReturn(Optional.of(testNodes[1].discoveryNode())); List hotEntities = new ArrayList<>(); @@ -916,21 +911,21 @@ public void testCacheSelectionError() throws IOException, InterruptedException { assertTrue(modelNodeInProgress.await(10000L, TimeUnit.MILLISECONDS)); // size 0 because cacheMissEntities has no record of these entities - verify(checkpointReadQueue).putAll(argThat(new ArgumentMatcher>() { + verify(checkpointReadQueue).putAll(argThat(new ArgumentMatcher>() { @Override - public boolean matches(List argument) { - List arg = (argument); + public boolean matches(List argument) { + List arg = (argument); LOG.info("size: " + arg.size()); return arg.size() == 0; } })); - verify(coldEntityQueue).putAll(argThat(new ArgumentMatcher>() { + verify(coldEntityQueue).putAll(argThat(new ArgumentMatcher>() { @Override - public boolean matches(List argument) { - List arg = (argument); + public boolean matches(List argument) { + List arg = (argument); LOG.info("size: " + arg.size()); return arg.size() == 0; } @@ -940,7 +935,7 @@ public boolean matches(List argument) { public void testCacheSelection() throws IOException, InterruptedException { setUpSearchResponse(); setUpTransportInterceptor(this::entityResultHandler); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))) + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))) .thenReturn(Optional.of(testNodes[1].discoveryNode())); List hotEntities = new ArrayList<>(); @@ -951,13 +946,13 @@ public void testCacheSelection() throws IOException, InterruptedException { Entity entity2 = Entity.createEntityByReordering(attrs2); coldEntities.add(entity2); - provider = mock(CacheProvider.class); - entityCache = mock(EntityCache.class); + provider = mock(ADCacheProvider.class); + entityCache = mock(ADPriorityCache.class); when(provider.get()).thenReturn(entityCache); when(entityCache.selectUpdateCandidate(any(), any(), any())).thenReturn(Pair.of(hotEntities, coldEntities)); when(entityCache.get(any(), any())).thenReturn(null); - new EntityResultTransportAction( + new EntityADResultTransportAction( new ActionFilters(Collections.emptySet()), // since we send requests to testNodes[1] testNodes[1].transportService, @@ -971,7 +966,8 @@ public void testCacheSelection() throws IOException, InterruptedException { coldEntityQueue, threadPool, entityColdStartQueue, - adStats + adStats, + mock(ADSaveResultStrategy.class) ); CountDownLatch modelNodeInProgress = new CountDownLatch(1); @@ -987,21 +983,21 @@ public void testCacheSelection() throws IOException, InterruptedException { action.doExecute(null, request, listener); assertTrue(modelNodeInProgress.await(10000L, TimeUnit.MILLISECONDS)); - verify(checkpointReadQueue).putAll(argThat(new ArgumentMatcher>() { + verify(checkpointReadQueue).putAll(argThat(new ArgumentMatcher>() { @Override - public boolean matches(List argument) { - List arg = (argument); + public boolean matches(List argument) { + List arg = (argument); LOG.info("size: " + arg.size() + " ; element: " + arg.get(0)); return arg.size() == 1 && arg.get(0).getEntity().equals(entity1); } })); - verify(coldEntityQueue).putAll(argThat(new ArgumentMatcher>() { + verify(coldEntityQueue).putAll(argThat(new ArgumentMatcher>() { @Override - public boolean matches(List argument) { - List arg = (argument); + public boolean matches(List argument) { + List arg = (argument); LOG.info("size: " + arg.size() + " ; element: " + arg.get(0)); return arg.size() == 1 && arg.get(0).getEntity().equals(entity2); } @@ -1131,7 +1127,7 @@ public void testRetry() throws IOException, InterruptedException { }).when(coldEntityQueue).putAll(any()); setUpTransportInterceptor(this::entityResultHandler); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))) + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))) .thenReturn(Optional.of(testNodes[1].discoveryNode())); setUpEntityResult(1); @@ -1166,7 +1162,8 @@ public void testPageToString() { 10000, 1000, indexNameResolver, - clusterService + clusterService, + AnalysisType.AD ); Map results = new HashMap<>(); Entity entity1 = Entity.createEntityByReordering(attrs1); @@ -1193,7 +1190,8 @@ public void testEmptyPageToString() { 10000, 1000, indexNameResolver, - clusterService + clusterService, + AnalysisType.AD ); CompositeRetriever.Page page = retriever.new Page(null); @@ -1207,7 +1205,7 @@ private NodeStateManager setUpTestExceptionTestingInModelNode() throws IOExcepti setUpSearchResponse(); setUpTransportInterceptor(this::entityResultHandler); // mock hashing ring response. This has to happen after setting up test nodes with the failure interceptor - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))) + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))) .thenReturn(Optional.of(testNodes[1].discoveryNode())); NodeStateManager modelNodeStateManager = mock(NodeStateManager.class); diff --git a/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportActionTests.java index 38cdce966..eee68c4ae 100644 --- a/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportActionTests.java @@ -45,10 +45,8 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.WriteRequest; import org.opensearch.ad.AnomalyDetectorRunner; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.feature.Features; import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyResult; import org.opensearch.ad.settings.AnomalyDetectorSettings; @@ -74,6 +72,8 @@ import org.opensearch.timeseries.breaker.CircuitBreakerService; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.feature.Features; import org.opensearch.timeseries.util.RestHandlerUtils; import org.opensearch.transport.TransportService; @@ -85,7 +85,7 @@ public class PreviewAnomalyDetectorTransportActionTests extends OpenSearchSingle private AnomalyDetectorRunner runner; private ClusterService clusterService; private FeatureManager featureManager; - private ModelManager modelManager; + private ADModelManager modelManager; private Task task; private CircuitBreakerService circuitBreaker; @@ -127,7 +127,7 @@ public void setUp() throws Exception { when(clusterService.state()).thenReturn(clusterState); featureManager = mock(FeatureManager.class); - modelManager = mock(ModelManager.class); + modelManager = mock(ADModelManager.class); runner = new AnomalyDetectorRunner(modelManager, featureManager, AnomalyDetectorSettings.MAX_PREVIEW_RESULTS); circuitBreaker = mock(CircuitBreakerService.class); when(circuitBreaker.isOpen()).thenReturn(false); @@ -173,7 +173,7 @@ public void onFailure(Exception e) { } }; - doReturn(TestHelpers.randomThresholdingResults()).when(modelManager).getPreviewResults(any(), anyInt()); + doReturn(TestHelpers.randomThresholdingResults()).when(modelManager).getPreviewResults(any(), anyInt(), anyInt()); doAnswer(responseMock -> { Long startTime = responseMock.getArgument(1); @@ -373,7 +373,7 @@ public void onFailure(Exception e) { Assert.assertTrue(false); } }; - doReturn(TestHelpers.randomThresholdingResults()).when(modelManager).getPreviewResults(any(), anyInt()); + doReturn(TestHelpers.randomThresholdingResults()).when(modelManager).getPreviewResults(any(), anyInt(), anyInt()); doAnswer(responseMock -> { Long startTime = responseMock.getArgument(1); diff --git a/src/test/java/org/opensearch/ad/transport/ProfileITTests.java b/src/test/java/org/opensearch/ad/transport/ProfileITTests.java index 013f00097..6bd8aa326 100644 --- a/src/test/java/org/opensearch/ad/transport/ProfileITTests.java +++ b/src/test/java/org/opensearch/ad/transport/ProfileITTests.java @@ -16,10 +16,12 @@ import java.util.HashSet; import java.util.concurrent.ExecutionException; -import org.opensearch.ad.model.DetectorProfileName; import org.opensearch.plugins.Plugin; import org.opensearch.test.OpenSearchIntegTestCase; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.model.ProfileName; +import org.opensearch.timeseries.transport.ProfileRequest; +import org.opensearch.timeseries.transport.ProfileResponse; public class ProfileITTests extends OpenSearchIntegTestCase { @@ -33,9 +35,9 @@ protected Collection> transportClientPlugins() { } public void testNormalProfile() throws ExecutionException, InterruptedException { - ProfileRequest profileRequest = new ProfileRequest("123", new HashSet(), false); + ProfileRequest profileRequest = new ProfileRequest("123", new HashSet(), false); - ProfileResponse response = client().execute(ProfileAction.INSTANCE, profileRequest).get(); + ProfileResponse response = client().execute(ADProfileAction.INSTANCE, profileRequest).get(); assertTrue("getting profile failed", !response.hasFailures()); } } diff --git a/src/test/java/org/opensearch/ad/transport/ProfileTests.java b/src/test/java/org/opensearch/ad/transport/ProfileTests.java index 7df0d5e02..46728f79e 100644 --- a/src/test/java/org/opensearch/ad/transport/ProfileTests.java +++ b/src/test/java/org/opensearch/ad/transport/ProfileTests.java @@ -30,9 +30,6 @@ import org.opensearch.Version; import org.opensearch.action.FailedNodeException; import org.opensearch.ad.common.exception.JsonPathNotFoundException; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.model.DetectorProfileName; -import org.opensearch.ad.model.ModelProfileOnNode; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -42,6 +39,12 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.ModelProfileOnNode; +import org.opensearch.timeseries.model.ProfileName; +import org.opensearch.timeseries.transport.ProfileNodeRequest; +import org.opensearch.timeseries.transport.ProfileNodeResponse; +import org.opensearch.timeseries.transport.ProfileRequest; +import org.opensearch.timeseries.transport.ProfileResponse; import com.google.gson.JsonArray; import com.google.gson.JsonElement; @@ -112,11 +115,11 @@ public void setUp() throws Exception { @Test public void testProfileNodeRequest() throws IOException { - Set profilesToRetrieve = new HashSet(); - profilesToRetrieve.add(DetectorProfileName.COORDINATING_NODE); + Set profilesToRetrieve = new HashSet(); + profilesToRetrieve.add(ProfileName.COORDINATING_NODE); ProfileRequest ProfileRequest = new ProfileRequest(detectorId, profilesToRetrieve, false); ProfileNodeRequest ProfileNodeRequest = new ProfileNodeRequest(ProfileRequest); - assertEquals("ProfileNodeRequest has the wrong detector id", ProfileNodeRequest.getId(), detectorId); + assertEquals("ProfileNodeRequest has the wrong detector id", ProfileNodeRequest.getConfigId(), detectorId); assertEquals("ProfileNodeRequest has the wrong ProfileRequest", ProfileNodeRequest.getProfilesToBeRetrieved(), profilesToRetrieve); // Test serialization @@ -124,7 +127,7 @@ public void testProfileNodeRequest() throws IOException { ProfileNodeRequest.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); ProfileNodeRequest nodeRequest = new ProfileNodeRequest(streamInput); - assertEquals("serialization has the wrong detector id", nodeRequest.getId(), detectorId); + assertEquals("serialization has the wrong detector id", nodeRequest.getConfigId(), detectorId); assertEquals("serialization has the wrong ProfileRequest", nodeRequest.getProfilesToBeRetrieved(), profilesToRetrieve); } @@ -162,14 +165,14 @@ public void testProfileNodeResponse() throws IOException, JsonPathNotFoundExcept ); } - assertEquals("toXContent has the wrong shingle size", JsonDeserializer.getIntValue(json, ADCommonName.SHINGLE_SIZE), shingleSize); + assertEquals("toXContent has the wrong shingle size", JsonDeserializer.getIntValue(json, CommonName.SHINGLE_SIZE), shingleSize); } @Test public void testProfileRequest() throws IOException { String detectorId = "123"; - Set profilesToRetrieve = new HashSet(); - profilesToRetrieve.add(DetectorProfileName.COORDINATING_NODE); + Set profilesToRetrieve = new HashSet(); + profilesToRetrieve.add(ProfileName.COORDINATING_NODE); ProfileRequest profileRequest = new ProfileRequest(detectorId, profilesToRetrieve, false); // Test Serialization @@ -182,7 +185,7 @@ public void testProfileRequest() throws IOException { readRequest.getProfilesToBeRetrieved(), profileRequest.getProfilesToBeRetrieved() ); - assertEquals("Serialization has the wrong detector id", readRequest.getId(), profileRequest.getId()); + assertEquals("Serialization has the wrong detector id", readRequest.getConfigId(), profileRequest.getConfigId()); } @Test diff --git a/src/test/java/org/opensearch/ad/transport/ProfileTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/ProfileTransportActionTests.java index bccd385bb..36d2b29f8 100644 --- a/src/test/java/org/opensearch/ad/transport/ProfileTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/ProfileTransportActionTests.java @@ -29,35 +29,39 @@ import org.junit.Test; import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.ActionFilters; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.model.DetectorProfileName; -import org.opensearch.ad.model.ModelProfile; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.settings.Settings; import org.opensearch.plugins.Plugin; import org.opensearch.test.OpenSearchIntegTestCase; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.feature.FeatureManager; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.ModelProfile; +import org.opensearch.timeseries.model.ProfileName; +import org.opensearch.timeseries.transport.ProfileNodeRequest; +import org.opensearch.timeseries.transport.ProfileNodeResponse; +import org.opensearch.timeseries.transport.ProfileRequest; +import org.opensearch.timeseries.transport.ProfileResponse; import org.opensearch.transport.TransportService; public class ProfileTransportActionTests extends OpenSearchIntegTestCase { - private ProfileTransportAction action; + private DelegateADProfileTransportAction action; private String detectorId = "Pl536HEBnXkDrah03glg"; String node1, nodeName1; DiscoveryNode discoveryNode1; - Set profilesToRetrieve = new HashSet(); + Set profilesToRetrieve = new HashSet(); private int shingleSize = 6; private long modelSize = 4456448L; private String modelId = "Pl536HEBnXkDrah03glg_model_rcf_1"; - private CacheProvider cacheProvider; + private ADCacheProvider cacheProvider; private int activeEntities = 10; private long totalUpdates = 127; private long multiEntityModelSize = 712480L; - private ModelManager modelManager; + private ADModelManager modelManager; private FeatureManager featureManager; @Override @@ -65,13 +69,13 @@ public class ProfileTransportActionTests extends OpenSearchIntegTestCase { public void setUp() throws Exception { super.setUp(); - modelManager = mock(ModelManager.class); + modelManager = mock(ADModelManager.class); featureManager = mock(FeatureManager.class); when(featureManager.getShingleSize(any(String.class))).thenReturn(shingleSize); - EntityCache cache = mock(EntityCache.class); - cacheProvider = mock(CacheProvider.class); + ADPriorityCache cache = mock(ADPriorityCache.class); + cacheProvider = mock(ADCacheProvider.class); when(cacheProvider.get()).thenReturn(cache); when(cache.getActiveEntities(anyString())).thenReturn(activeEntities); when(cache.getTotalUpdates(anyString())).thenReturn(totalUpdates); @@ -98,7 +102,7 @@ public void setUp() throws Exception { Settings settings = Settings.builder().put("plugins.anomaly_detection.max_model_size_per_node", 100).build(); - action = new ProfileTransportAction( + action = new DelegateADProfileTransportAction( client().threadPool(), clusterService(), mock(TransportService.class), @@ -109,8 +113,8 @@ public void setUp() throws Exception { settings ); - profilesToRetrieve = new HashSet(); - profilesToRetrieve.add(DetectorProfileName.COORDINATING_NODE); + profilesToRetrieve = new HashSet(); + profilesToRetrieve.add(ProfileName.COORDINATING_NODE); } private void setUpModelSize(int maxModel) { @@ -145,7 +149,7 @@ public void testNewNodeRequest() { ProfileNodeRequest profileNodeRequest1 = new ProfileNodeRequest(profileRequest); ProfileNodeRequest profileNodeRequest2 = action.newNodeRequest(profileRequest); - assertEquals(profileNodeRequest1.getId(), profileNodeRequest2.getId()); + assertEquals(profileNodeRequest1.getConfigId(), profileNodeRequest2.getConfigId()); assertEquals(profileNodeRequest2.getProfilesToBeRetrieved(), profileNodeRequest2.getProfilesToBeRetrieved()); } @@ -160,8 +164,8 @@ public void testNodeOperation() { assertEquals(shingleSize, response.getShingleSize()); assertEquals(null, response.getModelSize()); - profilesToRetrieve = new HashSet(); - profilesToRetrieve.add(DetectorProfileName.TOTAL_SIZE_IN_BYTES); + profilesToRetrieve = new HashSet(); + profilesToRetrieve.add(ProfileName.TOTAL_SIZE_IN_BYTES); profileRequest = new ProfileRequest(detectorId, profilesToRetrieve, false, nodeId); response = action.nodeOperation(new ProfileNodeRequest(profileRequest)); @@ -175,8 +179,8 @@ public void testNodeOperation() { public void testMultiEntityNodeOperation() { setUpModelSize(100); DiscoveryNode nodeId = clusterService().localNode(); - profilesToRetrieve = new HashSet(); - profilesToRetrieve.add(DetectorProfileName.ACTIVE_ENTITIES); + profilesToRetrieve = new HashSet(); + profilesToRetrieve.add(ProfileName.ACTIVE_ENTITIES); ProfileRequest profileRequest = new ProfileRequest(detectorId, profilesToRetrieve, true, nodeId); ProfileNodeResponse response = action.nodeOperation(new ProfileNodeRequest(profileRequest)); @@ -184,7 +188,7 @@ public void testMultiEntityNodeOperation() { assertEquals(activeEntities, response.getActiveEntities()); assertEquals(null, response.getModelSize()); - profilesToRetrieve.add(DetectorProfileName.INIT_PROGRESS); + profilesToRetrieve.add(ProfileName.INIT_PROGRESS); profileRequest = new ProfileRequest(detectorId, profilesToRetrieve, true, nodeId); response = action.nodeOperation(new ProfileNodeRequest(profileRequest)); @@ -193,7 +197,7 @@ public void testMultiEntityNodeOperation() { assertEquals(null, response.getModelSize()); assertEquals(totalUpdates, response.getTotalUpdates()); - profilesToRetrieve.add(DetectorProfileName.MODELS); + profilesToRetrieve.add(ProfileName.MODELS); profileRequest = new ProfileRequest(detectorId, profilesToRetrieve, true, nodeId); response = action.nodeOperation(new ProfileNodeRequest(profileRequest)); @@ -210,7 +214,7 @@ public void testModelCount() { Settings settings = Settings.builder().put("plugins.anomaly_detection.max_model_size_per_node", 1).build(); - action = new ProfileTransportAction( + action = new DelegateADProfileTransportAction( client().threadPool(), clusterService(), mock(TransportService.class), @@ -222,8 +226,8 @@ public void testModelCount() { ); DiscoveryNode nodeId = clusterService().localNode(); - profilesToRetrieve = new HashSet(); - profilesToRetrieve.add(DetectorProfileName.MODELS); + profilesToRetrieve = new HashSet(); + profilesToRetrieve.add(ProfileName.MODELS); ProfileRequest profileRequest = new ProfileRequest(detectorId, profilesToRetrieve, true, nodeId); ProfileNodeResponse response = action.nodeOperation(new ProfileNodeRequest(profileRequest)); assertEquals(2, response.getModelCount()); diff --git a/src/test/java/org/opensearch/ad/transport/RCFPollingTests.java b/src/test/java/org/opensearch/ad/transport/RCFPollingTests.java index 7a91fc5f9..13443b596 100644 --- a/src/test/java/org/opensearch/ad/transport/RCFPollingTests.java +++ b/src/test/java/org/opensearch/ad/transport/RCFPollingTests.java @@ -28,10 +28,9 @@ import org.opensearch.Version; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.ad.cluster.HashRing; import org.opensearch.ad.common.exception.JsonPathNotFoundException; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; @@ -44,6 +43,7 @@ import org.opensearch.telemetry.tracing.noop.NoopTracer; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.ml.SingleStreamModelIdMapper; import org.opensearch.transport.ConnectTransportException; @@ -70,7 +70,7 @@ public class RCFPollingTests extends AbstractTimeSeriesTest { private ClusterService clusterService; private HashRing hashRing; private TransportAddress transportAddress1; - private ModelManager manager; + private ADModelManager manager; private TransportService transportService; private PlainActionFuture future; private RCFPollingTransportAction action; @@ -105,7 +105,7 @@ public void setUp() throws Exception { clusterService = mock(ClusterService.class); hashRing = mock(HashRing.class); transportAddress1 = new TransportAddress(new InetSocketAddress(InetAddress.getByName("1.2.3.4"), 9300)); - manager = mock(ModelManager.class); + manager = mock(ADModelManager.class); transportService = new TransportService( Settings.EMPTY, mock(Transport.class), @@ -191,7 +191,7 @@ public void testDoubleNaN() { public void testNormal() { DiscoveryNode localNode = new DiscoveryNode(nodeId, transportAddress1, Version.CURRENT.minimumCompatibilityVersion()); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))).thenReturn(Optional.of(localNode)); + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))).thenReturn(Optional.of(localNode)); when(clusterService.localNode()).thenReturn(localNode); @@ -210,15 +210,15 @@ public void testNormal() { } public void testNoNodeFoundForModel() { - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))).thenReturn(Optional.empty()); + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))).thenReturn(Optional.empty()); action = new RCFPollingTransportAction( - mock(ActionFilters.class), - transportService, - Settings.EMPTY, - manager, - hashRing, - clusterService - ); + mock(ActionFilters.class), + transportService, + Settings.EMPTY, + manager, + hashRing, + clusterService + ); action.doExecute(mock(Task.class), request, future); assertException(future, TimeSeriesException.class, RCFPollingTransportAction.NO_NODE_FOUND_MSG); } @@ -307,7 +307,7 @@ public void testGetRemoteNormalResponse() { clusterService ); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))) + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))) .thenReturn(Optional.of(testNodes[1].discoveryNode())); registerHandler(testNodes[1]); @@ -335,7 +335,7 @@ public void testGetRemoteFailureResponse() { clusterService ); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))) + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))) .thenReturn(Optional.of(testNodes[1].discoveryNode())); registerHandler(testNodes[1]); diff --git a/src/test/java/org/opensearch/ad/transport/RCFResultTests.java b/src/test/java/org/opensearch/ad/transport/RCFResultTests.java index fad6c9ab0..a3ed6ee7c 100644 --- a/src/test/java/org/opensearch/ad/transport/RCFResultTests.java +++ b/src/test/java/org/opensearch/ad/transport/RCFResultTests.java @@ -37,15 +37,12 @@ import org.opensearch.action.ActionRequestValidationException; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.ad.cluster.HashRing; import org.opensearch.ad.common.exception.JsonPathNotFoundException; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.ml.ThresholdingResult; -import org.opensearch.ad.stats.ADStat; import org.opensearch.ad.stats.ADStats; -import org.opensearch.ad.stats.suppliers.CounterSupplier; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; @@ -57,8 +54,11 @@ import org.opensearch.telemetry.tracing.noop.NoopTracer; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.TimeSeriesStat; +import org.opensearch.timeseries.stats.suppliers.CounterSupplier; import org.opensearch.transport.Transport; import org.opensearch.transport.TransportService; @@ -88,10 +88,10 @@ public void setUp() throws Exception { hashRing = mock(HashRing.class); node = mock(DiscoveryNode.class); doReturn(Optional.of(node)).when(hashRing).getNodeByAddress(any()); - Map> statsMap = new HashMap>() { + Map> statsMap = new HashMap>() { { - put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); - put(StatNames.MODEL_CORRUTPION_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); + put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())); + put(StatNames.AD_MODEL_CORRUTPION_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())); } }; @@ -111,7 +111,7 @@ public void testNormal() { NoopTracer.INSTANCE ); - ModelManager manager = mock(ModelManager.class); + ADModelManager manager = mock(ADModelManager.class); CircuitBreakerService adCircuitBreakerService = mock(CircuitBreakerService.class); RCFResultTransportAction action = new RCFResultTransportAction( mock(ActionFilters.class), @@ -170,7 +170,7 @@ public void testExecutionException() { NoopTracer.INSTANCE ); - ModelManager manager = mock(ModelManager.class); + ADModelManager manager = mock(ADModelManager.class); CircuitBreakerService adCircuitBreakerService = mock(CircuitBreakerService.class); RCFResultTransportAction action = new RCFResultTransportAction( mock(ActionFilters.class), @@ -287,7 +287,7 @@ public void testCircuitBreaker() { NoopTracer.INSTANCE ); - ModelManager manager = mock(ModelManager.class); + ADModelManager manager = mock(ADModelManager.class); CircuitBreakerService breakerService = mock(CircuitBreakerService.class); RCFResultTransportAction action = new RCFResultTransportAction( mock(ActionFilters.class), @@ -339,7 +339,7 @@ public void testCorruptModel() { NoopTracer.INSTANCE ); - ModelManager manager = mock(ModelManager.class); + ADModelManager manager = mock(ADModelManager.class); CircuitBreakerService adCircuitBreakerService = mock(CircuitBreakerService.class); RCFResultTransportAction action = new RCFResultTransportAction( mock(ActionFilters.class), @@ -363,7 +363,7 @@ public void testCorruptModel() { action.doExecute(mock(Task.class), request, future); expectThrows(IllegalArgumentException.class, () -> future.actionGet()); - Object val = adStats.getStat(StatNames.MODEL_CORRUTPION_COUNT.getName()).getValue(); + Object val = adStats.getStat(StatNames.AD_MODEL_CORRUTPION_COUNT.getName()).getValue(); assertEquals(1L, ((Long) val).longValue()); verify(manager, times(1)).clear(eq(detectorId), any()); } diff --git a/src/test/java/org/opensearch/ad/transport/RCFResultUnitTests.java b/src/test/java/org/opensearch/ad/transport/RCFResultUnitTests.java new file mode 100644 index 000000000..6bd2f8d53 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/RCFResultUnitTests.java @@ -0,0 +1,153 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.transport; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.io.IOException; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; + +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.mockito.ArgumentCaptor; +import org.opensearch.Version; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.core.action.ActionListener; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.common.exception.InternalFailure; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.transport.ResultProcessor; + +public class RCFResultUnitTests extends AbstractTimeSeriesTest { + private String adID = "123"; + + // For single stream detector + class RCFActionListener implements ActionListener { + private String modelID; + private AtomicReference failure; + private String rcfNodeID; + private Config detector; + private ActionListener listener; + private List featureInResponse; + private final String adID; + + RCFActionListener( + String modelID, + AtomicReference failure, + String rcfNodeID, + Config detector, + ActionListener listener, + List features, + String adID + ) { + this.modelID = modelID; + this.failure = failure; + this.rcfNodeID = rcfNodeID; + this.detector = detector; + this.listener = listener; + this.featureInResponse = features; + this.adID = adID; + } + + @Override + public void onResponse(RCFResultResponse response) { + try { + if (response != null) { + listener + .onResponse( + new AnomalyResultResponse( + response.getAnomalyGrade(), + response.getConfidence(), + response.getRCFScore(), + featureInResponse, + null, + response.getTotalUpdates(), + detector.getIntervalInMinutes(), + false, + response.getRelativeIndex(), + response.getAttribution(), + response.getPastValues(), + response.getExpectedValuesList(), + response.getLikelihoodOfValues(), + response.getThreshold(), + null + ) + ); + } else { + LOG.warn(ResultProcessor.NULL_RESPONSE + " {} for {}", modelID, rcfNodeID); + listener.onFailure(new InternalFailure(adID, ADCommonMessages.NO_MODEL_ERR_MSG)); + } + } catch (Exception ex) { + LOG.error(new ParameterizedMessage("Unexpected exception for [{}]", adID), ex); + ResultProcessor.handleExecuteException(ex, listener, adID); + } + } + + @Override + public void onFailure(Exception exception) { + try { + if (exception != null) { + listener.onFailure(exception); + } else { + listener.onFailure(new InternalFailure(adID, "Node connection problem or unexpected exception")); + } + } catch (Exception ex) { + LOG.error(new ParameterizedMessage("Unexpected exception for [{}]", adID), ex); + ResultProcessor.handleExecuteException(ex, listener, adID); + } + } + } + + // no exception should be thrown + @SuppressWarnings("unchecked") + public void testOnFailureNull() throws IOException { + RCFActionListener listener = new RCFActionListener(null, null, null, null, mock(ActionListener.class), null, null); + listener.onFailure(null); + } + + @SuppressWarnings("unchecked") + public void testNullRCFResult() { + RCFActionListener listener = new RCFActionListener("123-rcf-0", null, "123", null, mock(ActionListener.class), null, null); + listener.onResponse(null); + assertTrue(testAppender.containsMessage(ResultProcessor.NULL_RESPONSE)); + } + + @SuppressWarnings("unchecked") + public void testNormalRCFResult() { + ActionListener listener = mock(ActionListener.class); + AnomalyDetector detector = mock(AnomalyDetector.class); + RCFActionListener rcfListener = new RCFActionListener("123-rcf-0", null, "nodeID", detector, listener, null, adID); + double[] attribution = new double[] { 1. }; + long totalUpdates = 32; + double grade = 0.5; + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(AnomalyResultResponse.class); + rcfListener + .onResponse(new RCFResultResponse(0.3, 0, 26, attribution, totalUpdates, grade, Version.CURRENT, 0, null, null, null, 1.1)); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals(grade, responseCaptor.getValue().getAnomalyGrade(), 1e-10); + } + + @SuppressWarnings("unchecked") + public void testNullPointerRCFResult() { + ActionListener listener = mock(ActionListener.class); + // detector being null causes NullPointerException + RCFActionListener rcfListener = new RCFActionListener("123-rcf-0", null, "nodeID", null, listener, null, adID); + double[] attribution = new double[] { 1. }; + long totalUpdates = 32; + double grade = 0.5; + ArgumentCaptor failureCaptor = ArgumentCaptor.forClass(Exception.class); + rcfListener + .onResponse(new RCFResultResponse(0.3, 0, 26, attribution, totalUpdates, grade, Version.CURRENT, 0, null, null, null, 1.1)); + verify(listener, times(1)).onFailure(failureCaptor.capture()); + Exception failure = failureCaptor.getValue(); + assertTrue(failure instanceof InternalFailure); + } + +} diff --git a/src/test/java/org/opensearch/ad/transport/SearchADTasksTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/SearchADTasksTransportActionTests.java index bc87faf13..65b0ee95d 100644 --- a/src/test/java/org/opensearch/ad/transport/SearchADTasksTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/SearchADTasksTransportActionTests.java @@ -24,13 +24,13 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.ad.HistoricalAnalysisIntegTestCase; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.model.ADTask; import org.opensearch.common.settings.Settings; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.timeseries.model.TimeSeriesTask; @OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST, numDataNodes = 2) public class SearchADTasksTransportActionTests extends HistoricalAnalysisIntegTestCase { @@ -81,7 +81,7 @@ public void testSearchWithExistingTask() throws IOException { private SearchRequest searchRequest(boolean isLatest) { SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); BoolQueryBuilder query = new BoolQueryBuilder(); - query.filter(new TermQueryBuilder(ADTask.IS_LATEST_FIELD, isLatest)); + query.filter(new TermQueryBuilder(TimeSeriesTask.IS_LATEST_FIELD, isLatest)); sourceBuilder.query(query); SearchRequest request = new SearchRequest().source(sourceBuilder).indices(ADCommonName.DETECTION_STATE_INDEX); return request; diff --git a/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorActionTests.java b/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorActionTests.java index 796d492e1..6a0dd234b 100644 --- a/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorActionTests.java @@ -21,7 +21,6 @@ import org.junit.Before; import org.junit.Test; import org.opensearch.action.FailedNodeException; -import org.opensearch.ad.stats.ADStatsResponse; import org.opensearch.cluster.ClusterName; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentFactory; @@ -30,6 +29,10 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.transport.StatsNodeResponse; +import org.opensearch.timeseries.transport.StatsNodesResponse; +import org.opensearch.timeseries.transport.StatsResponse; +import org.opensearch.timeseries.transport.StatsTimeSeriesResponse; public class StatsAnomalyDetectorActionTests extends OpenSearchTestCase { @@ -47,20 +50,20 @@ public void testStatsAction() { @Test public void testStatsResponse() throws IOException { - ADStatsResponse adStatsResponse = new ADStatsResponse(); + StatsResponse adStatsResponse = new StatsResponse(); Map testClusterStats = new HashMap<>(); testClusterStats.put("test_response", 1); adStatsResponse.setClusterStats(testClusterStats); - List responses = Collections.emptyList(); + List responses = Collections.emptyList(); List failures = Collections.emptyList(); - ADStatsNodesResponse adStatsNodesResponse = new ADStatsNodesResponse(ClusterName.DEFAULT, responses, failures); - adStatsResponse.setADStatsNodesResponse(adStatsNodesResponse); + StatsNodesResponse adStatsNodesResponse = new StatsNodesResponse(ClusterName.DEFAULT, responses, failures); + adStatsResponse.setStatsNodesResponse(adStatsNodesResponse); - StatsAnomalyDetectorResponse response = new StatsAnomalyDetectorResponse(adStatsResponse); + StatsTimeSeriesResponse response = new StatsTimeSeriesResponse(adStatsResponse); BytesStreamOutput out = new BytesStreamOutput(); response.writeTo(out); StreamInput input = out.bytes().streamInput(); - StatsAnomalyDetectorResponse newResponse = new StatsAnomalyDetectorResponse(input); + StatsTimeSeriesResponse newResponse = new StatsTimeSeriesResponse(input); assertNotNull(newResponse); XContentBuilder builder = XContentFactory.jsonBuilder(); diff --git a/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportActionTests.java index 7c877c086..ceff494ed 100644 --- a/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportActionTests.java @@ -16,9 +16,11 @@ import org.junit.Before; import org.opensearch.ad.ADIntegTestCase; -import org.opensearch.ad.stats.InternalStatNames; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.stats.InternalStatNames; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.transport.StatsRequest; +import org.opensearch.timeseries.transport.StatsTimeSeriesResponse; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -47,14 +49,14 @@ public void setUp() throws Exception { } public void testStatsAnomalyDetectorWithNodeLevelStats() { - ADStatsRequest adStatsRequest = new ADStatsRequest(clusterService().localNode()); + StatsRequest adStatsRequest = new StatsRequest(clusterService().localNode()); adStatsRequest.addStat(InternalStatNames.JVM_HEAP_USAGE.getName()); - StatsAnomalyDetectorResponse response = client().execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest).actionGet(5_000); - assertEquals(1, response.getAdStatsResponse().getADStatsNodesResponse().getNodes().size()); + StatsTimeSeriesResponse response = client().execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest).actionGet(5_000); + assertEquals(1, response.getAdStatsResponse().getStatsNodesResponse().getNodes().size()); assertTrue( response .getAdStatsResponse() - .getADStatsNodesResponse() + .getStatsNodesResponse() .getNodes() .get(0) .getStatsMap() @@ -63,39 +65,39 @@ public void testStatsAnomalyDetectorWithNodeLevelStats() { } public void testStatsAnomalyDetectorWithClusterLevelStats() { - ADStatsRequest adStatsRequest = new ADStatsRequest(clusterService().localNode()); + StatsRequest adStatsRequest = new StatsRequest(clusterService().localNode()); adStatsRequest.addStat(StatNames.DETECTOR_COUNT.getName()); - adStatsRequest.addStat(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName()); - StatsAnomalyDetectorResponse response = client().execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest).actionGet(5_000); - assertEquals(1, response.getAdStatsResponse().getADStatsNodesResponse().getNodes().size()); - Map statsMap = response.getAdStatsResponse().getADStatsNodesResponse().getNodes().get(0).getStatsMap(); + adStatsRequest.addStat(StatNames.SINGLE_STREAM_DETECTOR_COUNT.getName()); + StatsTimeSeriesResponse response = client().execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest).actionGet(5_000); + assertEquals(1, response.getAdStatsResponse().getStatsNodesResponse().getNodes().size()); + Map statsMap = response.getAdStatsResponse().getStatsNodesResponse().getNodes().get(0).getStatsMap(); Map clusterStats = response.getAdStatsResponse().getClusterStats(); assertEquals(0, statsMap.size()); assertEquals(2L, clusterStats.get(StatNames.DETECTOR_COUNT.getName())); - assertEquals(1L, clusterStats.get(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName())); + assertEquals(1L, clusterStats.get(StatNames.SINGLE_STREAM_DETECTOR_COUNT.getName())); } public void testStatsAnomalyDetectorWithDetectorCount() { - ADStatsRequest adStatsRequest = new ADStatsRequest(clusterService().localNode()); + StatsRequest adStatsRequest = new StatsRequest(clusterService().localNode()); adStatsRequest.addStat(StatNames.DETECTOR_COUNT.getName()); - StatsAnomalyDetectorResponse response = client().execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest).actionGet(5_000); - assertEquals(1, response.getAdStatsResponse().getADStatsNodesResponse().getNodes().size()); - Map statsMap = response.getAdStatsResponse().getADStatsNodesResponse().getNodes().get(0).getStatsMap(); + StatsTimeSeriesResponse response = client().execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest).actionGet(5_000); + assertEquals(1, response.getAdStatsResponse().getStatsNodesResponse().getNodes().size()); + Map statsMap = response.getAdStatsResponse().getStatsNodesResponse().getNodes().get(0).getStatsMap(); Map clusterStats = response.getAdStatsResponse().getClusterStats(); assertEquals(0, statsMap.size()); assertEquals(2L, clusterStats.get(StatNames.DETECTOR_COUNT.getName())); - assertFalse(clusterStats.containsKey(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName())); + assertFalse(clusterStats.containsKey(StatNames.SINGLE_STREAM_DETECTOR_COUNT.getName())); } public void testStatsAnomalyDetectorWithSingleEntityDetectorCount() { - ADStatsRequest adStatsRequest = new ADStatsRequest(clusterService().localNode()); - adStatsRequest.addStat(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName()); - StatsAnomalyDetectorResponse response = client().execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest).actionGet(5_000); - assertEquals(1, response.getAdStatsResponse().getADStatsNodesResponse().getNodes().size()); - Map statsMap = response.getAdStatsResponse().getADStatsNodesResponse().getNodes().get(0).getStatsMap(); + StatsRequest adStatsRequest = new StatsRequest(clusterService().localNode()); + adStatsRequest.addStat(StatNames.SINGLE_STREAM_DETECTOR_COUNT.getName()); + StatsTimeSeriesResponse response = client().execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest).actionGet(5_000); + assertEquals(1, response.getAdStatsResponse().getStatsNodesResponse().getNodes().size()); + Map statsMap = response.getAdStatsResponse().getStatsNodesResponse().getNodes().get(0).getStatsMap(); Map clusterStats = response.getAdStatsResponse().getClusterStats(); assertEquals(0, statsMap.size()); - assertEquals(1L, clusterStats.get(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName())); + assertEquals(1L, clusterStats.get(StatNames.SINGLE_STREAM_DETECTOR_COUNT.getName())); assertFalse(clusterStats.containsKey(StatNames.DETECTOR_COUNT.getName())); } diff --git a/src/test/java/org/opensearch/ad/transport/StopDetectorActionTests.java b/src/test/java/org/opensearch/ad/transport/StopDetectorActionTests.java index a2e98bf88..786d34cea 100644 --- a/src/test/java/org/opensearch/ad/transport/StopDetectorActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/StopDetectorActionTests.java @@ -26,6 +26,8 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.timeseries.transport.StopConfigRequest; +import org.opensearch.timeseries.transport.StopConfigResponse; public class StopDetectorActionTests extends OpenSearchIntegTestCase { @@ -43,7 +45,7 @@ public void testStopDetectorAction() { @Test public void fromActionRequest_Success() { - StopDetectorRequest stopDetectorRequest = new StopDetectorRequest("adID"); + StopConfigRequest stopDetectorRequest = new StopConfigRequest("adID"); ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { @@ -55,41 +57,41 @@ public void writeTo(StreamOutput out) throws IOException { stopDetectorRequest.writeTo(out); } }; - StopDetectorRequest result = StopDetectorRequest.fromActionRequest(actionRequest); + StopConfigRequest result = StopConfigRequest.fromActionRequest(actionRequest); assertNotSame(result, stopDetectorRequest); - assertEquals(result.getAdID(), stopDetectorRequest.getAdID()); + assertEquals(result.getConfigID(), stopDetectorRequest.getConfigID()); } @Test public void writeTo_Success() throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - StopDetectorResponse response = new StopDetectorResponse(true); + StopConfigResponse response = new StopConfigResponse(true); response.writeTo(bytesStreamOutput); - StopDetectorResponse parsedResponse = new StopDetectorResponse(bytesStreamOutput.bytes().streamInput()); + StopConfigResponse parsedResponse = new StopConfigResponse(bytesStreamOutput.bytes().streamInput()); assertNotEquals(response, parsedResponse); assertEquals(response.success(), parsedResponse.success()); } @Test public void fromActionResponse_Success() throws IOException { - StopDetectorResponse stopDetectorResponse = new StopDetectorResponse(true); + StopConfigResponse stopDetectorResponse = new StopConfigResponse(true); ActionResponse actionResponse = new ActionResponse() { @Override public void writeTo(StreamOutput streamOutput) throws IOException { stopDetectorResponse.writeTo(streamOutput); } }; - StopDetectorResponse result = stopDetectorResponse.fromActionResponse(actionResponse); + StopConfigResponse result = stopDetectorResponse.fromActionResponse(actionResponse); assertNotSame(result, stopDetectorResponse); assertEquals(result.success(), stopDetectorResponse.success()); - StopDetectorResponse parsedStopDetectorResponse = stopDetectorResponse.fromActionResponse(stopDetectorResponse); + StopConfigResponse parsedStopDetectorResponse = stopDetectorResponse.fromActionResponse(stopDetectorResponse); assertEquals(parsedStopDetectorResponse, stopDetectorResponse); } @Test public void toXContentTest() throws IOException { - StopDetectorResponse stopDetectorResponse = new StopDetectorResponse(true); + StopConfigResponse stopDetectorResponse = new StopConfigResponse(true); XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); stopDetectorResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); assertNotNull(builder); diff --git a/src/test/java/org/opensearch/ad/transport/ThresholdResultTests.java b/src/test/java/org/opensearch/ad/transport/ThresholdResultTests.java index 4457f0fb7..20c1b06ef 100644 --- a/src/test/java/org/opensearch/ad/transport/ThresholdResultTests.java +++ b/src/test/java/org/opensearch/ad/transport/ThresholdResultTests.java @@ -29,7 +29,7 @@ import org.opensearch.ad.common.exception.JsonPathNotFoundException; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.ml.ThresholdingResult; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; @@ -60,7 +60,7 @@ public void testNormal() { NoopTracer.INSTANCE ); - ModelManager manager = mock(ModelManager.class); + ADModelManager manager = mock(ADModelManager.class); ThresholdResultTransportAction action = new ThresholdResultTransportAction(mock(ActionFilters.class), transportService, manager); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -90,7 +90,7 @@ public void testExecutionException() { NoopTracer.INSTANCE ); - ModelManager manager = mock(ModelManager.class); + ADModelManager manager = mock(ADModelManager.class); ThresholdResultTransportAction action = new ThresholdResultTransportAction(mock(ActionFilters.class), transportService, manager); doThrow(NullPointerException.class) .when(manager) diff --git a/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequestTests.java b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequestTests.java index 4a1fae9cb..5f7a4ede4 100644 --- a/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequestTests.java +++ b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequestTests.java @@ -21,7 +21,9 @@ import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.transport.ValidateConfigRequest; import com.google.common.collect.ImmutableMap; @@ -38,14 +40,14 @@ public void testValidateAnomalyDetectorRequestSerialization() throws IOException TimeValue requestTimeout = new TimeValue(1000L); String typeStr = "type"; - ValidateAnomalyDetectorRequest request1 = new ValidateAnomalyDetectorRequest(detector, typeStr, 1, 1, 1, requestTimeout); + ValidateConfigRequest request1 = new ValidateConfigRequest(AnalysisType.AD, detector, typeStr, 1, 1, 1, requestTimeout, 10); // Test serialization BytesStreamOutput output = new BytesStreamOutput(); request1.writeTo(output); NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); - ValidateAnomalyDetectorRequest request2 = new ValidateAnomalyDetectorRequest(input); - assertEquals("serialization has the wrong detector", request2.getDetector(), detector); + ValidateConfigRequest request2 = new ValidateConfigRequest(input); + assertEquals("serialization has the wrong detector", request2.getConfig(), detector); assertEquals("serialization has the wrong typeStr", request2.getValidationType(), typeStr); assertEquals("serialization has the wrong requestTimeout", request2.getRequestTimeout(), requestTimeout); } diff --git a/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorResponseTests.java b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorResponseTests.java index 510ed2683..0b28b67f7 100644 --- a/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorResponseTests.java +++ b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorResponseTests.java @@ -16,12 +16,13 @@ import java.util.Map; import org.junit.Test; -import org.opensearch.ad.model.DetectorValidationIssue; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.model.ConfigValidationIssue; +import org.opensearch.timeseries.transport.ValidateConfigResponse; public class ValidateAnomalyDetectorResponseTests extends AbstractTimeSeriesTest { @@ -30,22 +31,22 @@ public void testResponseSerialization() throws IOException { Map subIssues = new HashMap<>(); subIssues.put("a", "b"); subIssues.put("c", "d"); - DetectorValidationIssue issue = TestHelpers.randomDetectorValidationIssueWithSubIssues(subIssues); - ValidateAnomalyDetectorResponse response = new ValidateAnomalyDetectorResponse(issue); + ConfigValidationIssue issue = TestHelpers.randomDetectorValidationIssueWithSubIssues(subIssues); + ValidateConfigResponse response = new ValidateConfigResponse(issue); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); - ValidateAnomalyDetectorResponse readResponse = ValidateAnomalyDetectorAction.INSTANCE.getResponseReader().read(streamInput); + ValidateConfigResponse readResponse = ValidateAnomalyDetectorAction.INSTANCE.getResponseReader().read(streamInput); assertEquals("serialization has the wrong issue", issue, readResponse.getIssue()); } @Test public void testResponseSerializationWithEmptyIssue() throws IOException { - ValidateAnomalyDetectorResponse response = new ValidateAnomalyDetectorResponse((DetectorValidationIssue) null); + ValidateConfigResponse response = new ValidateConfigResponse((ConfigValidationIssue) null); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); - ValidateAnomalyDetectorResponse readResponse = ValidateAnomalyDetectorAction.INSTANCE.getResponseReader().read(streamInput); + ValidateConfigResponse readResponse = ValidateAnomalyDetectorAction.INSTANCE.getResponseReader().read(streamInput); assertNull("serialization should have empty issue", readResponse.getIssue()); } @@ -53,8 +54,8 @@ public void testResponseToXContentWithSubIssues() throws IOException { Map subIssues = new HashMap<>(); subIssues.put("a", "b"); subIssues.put("c", "d"); - DetectorValidationIssue issue = TestHelpers.randomDetectorValidationIssueWithSubIssues(subIssues); - ValidateAnomalyDetectorResponse response = new ValidateAnomalyDetectorResponse(issue); + ConfigValidationIssue issue = TestHelpers.randomDetectorValidationIssueWithSubIssues(subIssues); + ValidateConfigResponse response = new ValidateConfigResponse(issue); String validationResponse = TestHelpers.xContentBuilderToString(response.toXContent(TestHelpers.builder())); String message = issue.getMessage(); assertEquals( @@ -64,23 +65,23 @@ public void testResponseToXContentWithSubIssues() throws IOException { } public void testResponseToXContent() throws IOException { - DetectorValidationIssue issue = TestHelpers.randomDetectorValidationIssue(); - ValidateAnomalyDetectorResponse response = new ValidateAnomalyDetectorResponse(issue); + ConfigValidationIssue issue = TestHelpers.randomDetectorValidationIssue(); + ValidateConfigResponse response = new ValidateConfigResponse(issue); String validationResponse = TestHelpers.xContentBuilderToString(response.toXContent(TestHelpers.builder())); String message = issue.getMessage(); assertEquals("{\"detector\":{\"name\":{\"message\":\"" + message + "\"}}}", validationResponse); } public void testResponseToXContentNull() throws IOException { - ValidateAnomalyDetectorResponse response = new ValidateAnomalyDetectorResponse((DetectorValidationIssue) null); + ValidateConfigResponse response = new ValidateConfigResponse((ConfigValidationIssue) null); String validationResponse = TestHelpers.xContentBuilderToString(response.toXContent(TestHelpers.builder())); assertEquals("{}", validationResponse); } public void testResponseToXContentWithIntervalRec() throws IOException { long intervalRec = 5; - DetectorValidationIssue issue = TestHelpers.randomDetectorValidationIssueWithDetectorIntervalRec(intervalRec); - ValidateAnomalyDetectorResponse response = new ValidateAnomalyDetectorResponse(issue); + ConfigValidationIssue issue = TestHelpers.randomDetectorValidationIssueWithDetectorIntervalRec(intervalRec); + ValidateConfigResponse response = new ValidateConfigResponse(issue); String validationResponse = TestHelpers.xContentBuilderToString(response.toXContent(TestHelpers.builder())); assertEquals( "{\"model\":{\"detection_interval\":{\"message\":\"" @@ -94,12 +95,12 @@ public void testResponseToXContentWithIntervalRec() throws IOException { @Test public void testResponseSerializationWithIntervalRec() throws IOException { long intervalRec = 5; - DetectorValidationIssue issue = TestHelpers.randomDetectorValidationIssueWithDetectorIntervalRec(intervalRec); - ValidateAnomalyDetectorResponse response = new ValidateAnomalyDetectorResponse(issue); + ConfigValidationIssue issue = TestHelpers.randomDetectorValidationIssueWithDetectorIntervalRec(intervalRec); + ValidateConfigResponse response = new ValidateConfigResponse(issue); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); - ValidateAnomalyDetectorResponse readResponse = ValidateAnomalyDetectorAction.INSTANCE.getResponseReader().read(streamInput); + ValidateConfigResponse readResponse = ValidateAnomalyDetectorAction.INSTANCE.getResponseReader().read(streamInput); assertEquals(issue, readResponse.getIssue()); } } diff --git a/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportActionTests.java index 604fc2c46..d9f6cb041 100644 --- a/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportActionTests.java @@ -25,12 +25,15 @@ import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.common.unit.TimeValue; import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.model.Feature; import org.opensearch.timeseries.model.ValidationAspect; import org.opensearch.timeseries.model.ValidationIssueType; import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.transport.ValidateConfigRequest; +import org.opensearch.timeseries.transport.ValidateConfigResponse; import com.google.common.base.Charsets; import com.google.common.collect.ImmutableList; @@ -44,30 +47,34 @@ public void testValidateAnomalyDetectorWithNoIssue() throws IOException { AnomalyDetector anomalyDetector = TestHelpers .randomAnomalyDetector(timeField, "test-index", ImmutableList.of(sumValueFeature(nameField, ipField + ".is_error", "test-2"))); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertNull(response.getIssue()); } @Test public void testValidateAnomalyDetectorWithNoIndexFound() throws IOException { AnomalyDetector anomalyDetector = TestHelpers.randomAnomalyDetector(ImmutableMap.of(), Instant.now()); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertNotNull(response.getIssue()); assertEquals(ValidationIssueType.INDICES, response.getIssue().getType()); assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); @@ -80,15 +87,17 @@ public void testValidateAnomalyDetectorWithDuplicateName() throws IOException { ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); createDetectorIndex(); createDetector(anomalyDetector); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertNotNull(response.getIssue()); assertEquals(ValidationIssueType.NAME, response.getIssue().getType()); assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); @@ -99,15 +108,17 @@ public void testValidateAnomalyDetectorWithNonExistingFeatureField() throws IOEx Feature maxFeature = maxValueFeature(nameField, "non_existing_field", nameField); AnomalyDetector anomalyDetector = TestHelpers.randomAnomalyDetector(timeField, "test-index", ImmutableList.of(maxFeature)); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertNotNull(response.getIssue()); assertEquals(ValidationIssueType.FEATURE_ATTRIBUTES, response.getIssue().getType()); assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); @@ -123,15 +134,17 @@ public void testValidateAnomalyDetectorWithDuplicateFeatureAggregationNames() th AnomalyDetector anomalyDetector = TestHelpers .randomAnomalyDetector(timeField, "test-index", ImmutableList.of(maxFeature, maxFeatureTwo)); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertNotNull(response.getIssue()); assertTrue(response.getIssue().getMessage().contains("Config has duplicate feature aggregation query names:")); assertEquals(ValidationIssueType.FEATURE_ATTRIBUTES, response.getIssue().getType()); @@ -145,15 +158,17 @@ public void testValidateAnomalyDetectorWithDuplicateFeatureNamesAndDuplicateAggr AnomalyDetector anomalyDetector = TestHelpers .randomAnomalyDetector(timeField, "test-index", ImmutableList.of(maxFeature, maxFeatureTwo)); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertNotNull(response.getIssue()); assertTrue(response.getIssue().getMessage().contains("Config has duplicate feature aggregation query names:")); assertTrue(response.getIssue().getMessage().contains("There are duplicate feature names:")); @@ -168,15 +183,17 @@ public void testValidateAnomalyDetectorWithDuplicateFeatureNames() throws IOExce AnomalyDetector anomalyDetector = TestHelpers .randomAnomalyDetector(timeField, "test-index", ImmutableList.of(maxFeature, maxFeatureTwo)); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertNotNull(response.getIssue()); assertTrue( "actual: " + response.getIssue().getMessage(), @@ -191,15 +208,17 @@ public void testValidateAnomalyDetectorWithInvalidFeatureField() throws IOExcept Feature maxFeature = maxValueFeature(nameField, categoryField, nameField); AnomalyDetector anomalyDetector = TestHelpers.randomAnomalyDetector(timeField, "test-index", ImmutableList.of(maxFeature)); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertNotNull(response.getIssue()); assertEquals(ValidationIssueType.FEATURE_ATTRIBUTES, response.getIssue().getType()); assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); @@ -218,15 +237,17 @@ public void testValidateAnomalyDetectorWithUnknownFeatureField() throws IOExcept ImmutableList.of(new Feature(randomAlphaOfLength(5), nameField, true, aggregationBuilder)) ); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertNotNull(response.getIssue()); assertEquals(ValidationIssueType.FEATURE_ATTRIBUTES, response.getIssue().getType()); assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); @@ -241,15 +262,17 @@ public void testValidateAnomalyDetectorWithMultipleInvalidFeatureField() throws AnomalyDetector anomalyDetector = TestHelpers .randomAnomalyDetector(timeField, "test-index", ImmutableList.of(maxFeature, maxFeatureTwo)); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertNotNull(response.getIssue()); assertEquals(response.getIssue().getSubIssues().keySet().size(), 2); assertEquals(ValidationIssueType.FEATURE_ATTRIBUTES, response.getIssue().getType()); @@ -273,15 +296,17 @@ public void testValidateAnomalyDetectorWithCustomResultIndex() throws IOExceptio resultIndex ); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertNull(response.getIssue()); } @@ -311,15 +336,17 @@ public void testValidateAnomalyDetectorWithCustomResultIndexWithInvalidMapping() resultIndex ); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertEquals(ValidationIssueType.RESULT_INDEX, response.getIssue().getType()); assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); assertTrue(response.getIssue().getMessage().contains(CommonMessages.INVALID_RESULT_INDEX_MAPPING)); @@ -340,20 +367,23 @@ private void testValidateAnomalyDetectorWithCustomResultIndex(boolean resultInde resultIndex ); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertNull(response.getIssue()); } @Test public void testValidateAnomalyDetectorWithInvalidDetectorName() throws IOException { + Feature feature = TestHelpers.randomFeature(); AnomalyDetector anomalyDetector = new AnomalyDetector( randomAlphaOfLength(5), randomLong(), @@ -361,7 +391,7 @@ public void testValidateAnomalyDetectorWithInvalidDetectorName() throws IOExcept randomAlphaOfLength(5), timeField, ImmutableList.of(randomAlphaOfLength(5).toLowerCase(Locale.ROOT)), - ImmutableList.of(TestHelpers.randomFeature()), + ImmutableList.of(feature), TestHelpers.randomQuery(), TestHelpers.randomIntervalTimeConfiguration(), TestHelpers.randomIntervalTimeConfiguration(), @@ -372,18 +402,24 @@ public void testValidateAnomalyDetectorWithInvalidDetectorName() throws IOExcept null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertEquals(ValidationIssueType.NAME, response.getIssue().getType()); assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); assertEquals(CommonMessages.INVALID_NAME, response.getIssue().getMessage()); @@ -391,6 +427,7 @@ public void testValidateAnomalyDetectorWithInvalidDetectorName() throws IOExcept @Test public void testValidateAnomalyDetectorWithDetectorNameTooLong() throws IOException { + Feature feature = TestHelpers.randomFeature(); AnomalyDetector anomalyDetector = new AnomalyDetector( randomAlphaOfLength(5), randomLong(), @@ -398,7 +435,7 @@ public void testValidateAnomalyDetectorWithDetectorNameTooLong() throws IOExcept randomAlphaOfLength(5), timeField, ImmutableList.of(randomAlphaOfLength(5).toLowerCase(Locale.ROOT)), - ImmutableList.of(TestHelpers.randomFeature()), + ImmutableList.of(feature), TestHelpers.randomQuery(), TestHelpers.randomIntervalTimeConfiguration(), TestHelpers.randomIntervalTimeConfiguration(), @@ -409,18 +446,24 @@ public void testValidateAnomalyDetectorWithDetectorNameTooLong() throws IOExcept null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertEquals(ValidationIssueType.NAME, response.getIssue().getType()); assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); assertTrue(response.getIssue().getMessage().contains("Name should be shortened. The maximum limit is")); @@ -430,15 +473,17 @@ public void testValidateAnomalyDetectorWithDetectorNameTooLong() throws IOExcept public void testValidateAnomalyDetectorWithNonExistentTimefield() throws IOException { AnomalyDetector anomalyDetector = TestHelpers.randomAnomalyDetector(ImmutableMap.of(), Instant.now()); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertEquals(ValidationIssueType.TIMEFIELD_FIELD, response.getIssue().getType()); assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); assertEquals( @@ -451,15 +496,17 @@ public void testValidateAnomalyDetectorWithNonExistentTimefield() throws IOExcep public void testValidateAnomalyDetectorWithNonDateTimeField() throws IOException { AnomalyDetector anomalyDetector = TestHelpers.randomAnomalyDetector(categoryField, "index-test"); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertEquals(ValidationIssueType.TIMEFIELD_FIELD, response.getIssue().getType()); assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); assertEquals( diff --git a/src/test/java/org/opensearch/ad/transport/handler/AbstractIndexHandlerTest.java b/src/test/java/org/opensearch/ad/transport/handler/AbstractIndexHandlerTest.java index 12f966ffe..353611121 100644 --- a/src/test/java/org/opensearch/ad/transport/handler/AbstractIndexHandlerTest.java +++ b/src/test/java/org/opensearch/ad/transport/handler/AbstractIndexHandlerTest.java @@ -29,7 +29,6 @@ import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.transport.AnomalyResultTests; -import org.opensearch.ad.util.IndexUtils; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.metadata.IndexMetadata; @@ -42,6 +41,7 @@ import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.util.ClientUtil; +import org.opensearch.timeseries.util.IndexUtils; public abstract class AbstractIndexHandlerTest extends AbstractTimeSeriesTest { enum IndexCreation { @@ -92,7 +92,7 @@ public void setUp() throws Exception { setWriteBlockAdResultIndex(false); context = TestHelpers.createThreadPool(); clientUtil = new ClientUtil(client); - indexUtil = new IndexUtils(client, clientUtil, clusterService, indexNameResolver); + indexUtil = new IndexUtils(clusterService, indexNameResolver); } protected void setWriteBlockAdResultIndex(boolean blocked) { diff --git a/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandlerTests.java b/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandlerTests.java index 68699b74e..af3442433 100644 --- a/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandlerTests.java +++ b/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandlerTests.java @@ -34,9 +34,10 @@ import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.index.IndexResponse; import org.opensearch.ad.ADUnitTestCase; +import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.ad.util.IndexUtils; +import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; @@ -46,17 +47,20 @@ import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler; import org.opensearch.timeseries.util.ClientUtil; +import org.opensearch.timeseries.util.IndexUtils; import com.google.common.collect.ImmutableList; public class AnomalyResultBulkIndexHandlerTests extends ADUnitTestCase { - private AnomalyResultBulkIndexHandler bulkIndexHandler; + private ResultBulkIndexingHandler bulkIndexHandler; private Client client; private IndexUtils indexUtils; private ActionListener listener; private ADIndexManagement anomalyDetectionIndices; + private String configId; @Override public void setUp() throws Exception { @@ -70,14 +74,17 @@ public void setUp() throws Exception { indexUtils = mock(IndexUtils.class); ClusterService clusterService = mock(ClusterService.class); ThreadPool threadPool = mock(ThreadPool.class); - bulkIndexHandler = new AnomalyResultBulkIndexHandler( + bulkIndexHandler = new ResultBulkIndexingHandler( client, settings, threadPool, + ANOMALY_RESULT_INDEX_ALIAS, + anomalyDetectionIndices, clientUtil, indexUtils, clusterService, - anomalyDetectionIndices + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF ); listener = spy(new ActionListener() { @Override @@ -86,10 +93,11 @@ public void onResponse(BulkResponse bulkItemResponses) {} @Override public void onFailure(Exception e) {} }); + configId = "testId"; } public void testNullAnomalyResults() { - bulkIndexHandler.bulkIndexAnomalyResult(null, null, listener); + bulkIndexHandler.bulk(null, null, null, listener); verify(listener, times(1)).onResponse(null); verify(anomalyDetectionIndices, never()).doesConfigIndexExist(); } @@ -97,9 +105,9 @@ public void testNullAnomalyResults() { public void testAnomalyResultBulkIndexHandler_IndexNotExist() { when(anomalyDetectionIndices.doesIndexExist("testIndex")).thenReturn(false); AnomalyResult anomalyResult = mock(AnomalyResult.class); - when(anomalyResult.getConfigId()).thenReturn("testId"); + when(anomalyResult.getConfigId()).thenReturn(configId); - bulkIndexHandler.bulkIndexAnomalyResult("testIndex", ImmutableList.of(anomalyResult), listener); + bulkIndexHandler.bulk("testIndex", ImmutableList.of(anomalyResult), configId, listener); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); assertEquals("Can't find result index testIndex", exceptionCaptor.getValue().getMessage()); } @@ -108,9 +116,10 @@ public void testAnomalyResultBulkIndexHandler_InValidResultIndexMapping() { when(anomalyDetectionIndices.doesIndexExist("testIndex")).thenReturn(true); when(anomalyDetectionIndices.isValidResultIndexMapping("testIndex")).thenReturn(false); AnomalyResult anomalyResult = mock(AnomalyResult.class); - when(anomalyResult.getConfigId()).thenReturn("testId"); - bulkIndexHandler.bulkIndexAnomalyResult("testIndex", ImmutableList.of(anomalyResult), listener); + when(anomalyResult.getConfigId()).thenReturn(configId); + + bulkIndexHandler.bulk("testIndex", ImmutableList.of(anomalyResult), configId, listener); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); assertEquals("wrong index mapping of custom AD result index", exceptionCaptor.getValue().getMessage()); } @@ -119,10 +128,10 @@ public void testAnomalyResultBulkIndexHandler_FailBulkIndexAnomaly() throws IOEx when(anomalyDetectionIndices.doesIndexExist("testIndex")).thenReturn(true); when(anomalyDetectionIndices.isValidResultIndexMapping("testIndex")).thenReturn(true); AnomalyResult anomalyResult = mock(AnomalyResult.class); - when(anomalyResult.getConfigId()).thenReturn("testId"); + when(anomalyResult.getConfigId()).thenReturn(configId); when(anomalyResult.toXContent(any(), any())).thenThrow(new RuntimeException()); - bulkIndexHandler.bulkIndexAnomalyResult("testIndex", ImmutableList.of(anomalyResult), listener); + bulkIndexHandler.bulk("testIndex", ImmutableList.of(anomalyResult), configId, listener); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); assertEquals("Failed to prepare request to bulk index anomaly results", exceptionCaptor.getValue().getMessage()); } @@ -133,7 +142,7 @@ public void testCreateADResultIndexNotAcknowledged() throws IOException { listener.onResponse(new CreateIndexResponse(false, false, ANOMALY_RESULT_INDEX_ALIAS)); return null; }).when(anomalyDetectionIndices).initDefaultResultIndexDirectly(any()); - bulkIndexHandler.bulkIndexAnomalyResult(null, ImmutableList.of(mock(AnomalyResult.class)), listener); + bulkIndexHandler.bulk(null, ImmutableList.of(mock(AnomalyResult.class)), configId, listener); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); assertEquals("Creating anomaly result index with mappings call not acknowledged", exceptionCaptor.getValue().getMessage()); } @@ -166,8 +175,7 @@ public void testWrongAnomalyResult() { listener.onResponse(bulkResponse); return null; }).when(client).bulk(any(), any()); - bulkIndexHandler - .bulkIndexAnomalyResult(null, ImmutableList.of(wrongAnomalyResult(), TestHelpers.randomAnomalyDetectResult()), listener); + bulkIndexHandler.bulk(null, ImmutableList.of(wrongAnomalyResult(), TestHelpers.randomAnomalyDetectResult()), configId, listener); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); assertTrue(exceptionCaptor.getValue().getMessage().contains("VersionConflictEngineException")); } @@ -184,7 +192,7 @@ public void testBulkSaveException() { return null; }).when(client).bulk(any(), any()); - bulkIndexHandler.bulkIndexAnomalyResult(null, ImmutableList.of(TestHelpers.randomAnomalyDetectResult()), listener); + bulkIndexHandler.bulk(null, ImmutableList.of(TestHelpers.randomAnomalyDetectResult()), configId, listener); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); assertEquals(testError, exceptionCaptor.getValue().getMessage()); } diff --git a/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultHandlerTests.java b/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultHandlerTests.java index b17008e1d..616fc0a51 100644 --- a/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultHandlerTests.java +++ b/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultHandlerTests.java @@ -34,7 +34,10 @@ import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; @@ -42,6 +45,7 @@ import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.transport.handler.ResultIndexingHandler; public class AnomalyResultHandlerTests extends AbstractIndexHandlerTest { @Mock @@ -54,7 +58,7 @@ public class AnomalyResultHandlerTests extends AbstractIndexHandlerTest { @Before public void setUp() throws Exception { super.setUp(); - super.setUpLog4jForJUnit(AnomalyIndexHandler.class); + super.setUpLog4jForJUnit(ResultIndexingHandler.class); } @Override @@ -81,7 +85,7 @@ public void testSavingAdResult() throws IOException { listener.onResponse(mock(IndexResponse.class)); return null; }).when(client).index(any(IndexRequest.class), ArgumentMatchers.>any()); - AnomalyIndexHandler handler = new AnomalyIndexHandler( + ResultIndexingHandler handler = new ResultIndexingHandler<>( client, settings, threadPool, @@ -89,19 +93,21 @@ public void testSavingAdResult() throws IOException { anomalyDetectionIndices, clientUtil, indexUtil, - clusterService + clusterService, + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF ); handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, null); - assertEquals(1, testAppender.countMessage(AnomalyIndexHandler.SUCCESS_SAVING_MSG, true)); + assertEquals(1, testAppender.countMessage(ResultIndexingHandler.SUCCESS_SAVING_MSG, true)); } @Test public void testSavingFailureNotRetry() throws InterruptedException, IOException { savingFailureTemplate(false, 1, true); - assertEquals(1, testAppender.countMessage(AnomalyIndexHandler.FAIL_TO_SAVE_ERR_MSG, true)); - assertTrue(!testAppender.containsMessage(AnomalyIndexHandler.SUCCESS_SAVING_MSG, true)); - assertTrue(!testAppender.containsMessage(AnomalyIndexHandler.RETRY_SAVING_ERR_MSG, true)); + assertEquals(1, testAppender.countMessage(ResultIndexingHandler.FAIL_TO_SAVE_ERR_MSG, true)); + assertTrue(!testAppender.containsMessage(ResultIndexingHandler.SUCCESS_SAVING_MSG, true)); + assertTrue(!testAppender.containsMessage(ResultIndexingHandler.RETRY_SAVING_ERR_MSG, true)); } @Test @@ -109,15 +115,15 @@ public void testSavingFailureRetry() throws InterruptedException, IOException { setWriteBlockAdResultIndex(false); savingFailureTemplate(true, 3, true); - assertEquals(2, testAppender.countMessage(AnomalyIndexHandler.RETRY_SAVING_ERR_MSG, true)); - assertEquals(1, testAppender.countMessage(AnomalyIndexHandler.FAIL_TO_SAVE_ERR_MSG, true)); - assertTrue(!testAppender.containsMessage(AnomalyIndexHandler.SUCCESS_SAVING_MSG, true)); + assertEquals(2, testAppender.countMessage(ResultIndexingHandler.RETRY_SAVING_ERR_MSG, true)); + assertEquals(1, testAppender.countMessage(ResultIndexingHandler.FAIL_TO_SAVE_ERR_MSG, true)); + assertTrue(!testAppender.containsMessage(ResultIndexingHandler.SUCCESS_SAVING_MSG, true)); } @Test public void testIndexWriteBlock() { setWriteBlockAdResultIndex(true); - AnomalyIndexHandler handler = new AnomalyIndexHandler( + ResultIndexingHandler handler = new ResultIndexingHandler<>( client, settings, threadPool, @@ -125,17 +131,19 @@ public void testIndexWriteBlock() { anomalyDetectionIndices, clientUtil, indexUtil, - clusterService + clusterService, + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF ); handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, null); - assertTrue(testAppender.containsMessage(AnomalyIndexHandler.CANNOT_SAVE_ERR_MSG, true)); + assertTrue(testAppender.containsMessage(ResultIndexingHandler.CANNOT_SAVE_ERR_MSG, true)); } @Test public void testAdResultIndexExist() throws IOException { setUpSavingAnomalyResultIndex(false, IndexCreation.RESOURCE_EXISTS_EXCEPTION); - AnomalyIndexHandler handler = new AnomalyIndexHandler( + ResultIndexingHandler handler = new ResultIndexingHandler<>( client, settings, threadPool, @@ -143,7 +151,9 @@ public void testAdResultIndexExist() throws IOException { anomalyDetectionIndices, clientUtil, indexUtil, - clusterService + clusterService, + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF ); handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, null); verify(client, times(1)).index(any(), any()); @@ -155,7 +165,7 @@ public void testAdResultIndexOtherException() throws IOException { expectedEx.expectMessage("Error in saving .opendistro-anomaly-results for detector " + detectorId); setUpSavingAnomalyResultIndex(false, IndexCreation.RUNTIME_EXCEPTION); - AnomalyIndexHandler handler = new AnomalyIndexHandler( + ResultIndexingHandler handler = new ResultIndexingHandler<>( client, settings, threadPool, @@ -163,7 +173,9 @@ public void testAdResultIndexOtherException() throws IOException { anomalyDetectionIndices, clientUtil, indexUtil, - clusterService + clusterService, + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF ); handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, null); verify(client, never()).index(any(), any()); @@ -213,7 +225,7 @@ private void savingFailureTemplate(boolean throwOpenSearchRejectedExecutionExcep .put("plugins.anomaly_detection.backoff_initial_delay", TimeValue.timeValueMillis(1)) .build(); - AnomalyIndexHandler handler = new AnomalyIndexHandler( + ResultIndexingHandler handler = new ResultIndexingHandler<>( client, backoffSettings, threadPool, @@ -221,7 +233,9 @@ private void savingFailureTemplate(boolean throwOpenSearchRejectedExecutionExcep anomalyDetectionIndices, clientUtil, indexUtil, - clusterService + clusterService, + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF ); handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, null); diff --git a/src/test/java/org/opensearch/ad/transport/handler/MultiEntityResultHandlerTests.java b/src/test/java/org/opensearch/ad/transport/handler/MultiEntityResultHandlerTests.java index f6483c8b7..fa7e3acdb 100644 --- a/src/test/java/org/opensearch/ad/transport/handler/MultiEntityResultHandlerTests.java +++ b/src/test/java/org/opensearch/ad/transport/handler/MultiEntityResultHandlerTests.java @@ -23,36 +23,29 @@ import org.junit.Test; import org.mockito.ArgumentMatchers; -import org.opensearch.ad.ratelimit.RequestPriority; -import org.opensearch.ad.ratelimit.ResultWriteRequest; +import org.opensearch.ad.ratelimit.ADResultWriteRequest; import org.opensearch.ad.transport.ADResultBulkAction; import org.opensearch.ad.transport.ADResultBulkRequest; -import org.opensearch.ad.transport.ADResultBulkResponse; import org.opensearch.core.action.ActionListener; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.transport.ResultBulkResponse; public class MultiEntityResultHandlerTests extends AbstractIndexHandlerTest { - private MultiEntityResultHandler handler; + private ADIndexMemoryPressureAwareResultHandler handler; private ADResultBulkRequest request; - private ADResultBulkResponse response; + private ResultBulkResponse response; @Override public void setUp() throws Exception { super.setUp(); - handler = new MultiEntityResultHandler( - client, - settings, - threadPool, - anomalyDetectionIndices, - clientUtil, - indexUtil, - clusterService - ); + handler = new ADIndexMemoryPressureAwareResultHandler(client, anomalyDetectionIndices); request = new ADResultBulkRequest(); - ResultWriteRequest resultWriteRequest = new ResultWriteRequest( + ADResultWriteRequest resultWriteRequest = new ADResultWriteRequest( Instant.now().plus(10, ChronoUnit.MINUTES).toEpochMilli(), detectorId, RequestPriority.MEDIUM, @@ -61,15 +54,15 @@ public void setUp() throws Exception { ); request.add(resultWriteRequest); - response = new ADResultBulkResponse(); + response = new ResultBulkResponse(); - super.setUpLog4jForJUnit(MultiEntityResultHandler.class); + super.setUpLog4jForJUnit(ADIndexMemoryPressureAwareResultHandler.class); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(2); listener.onResponse(response); return null; - }).when(client).execute(eq(ADResultBulkAction.INSTANCE), any(), ArgumentMatchers.>any()); + }).when(client).execute(eq(ADResultBulkAction.INSTANCE), any(), ArgumentMatchers.>any()); } @Override @@ -89,10 +82,7 @@ public void testIndexWriteBlock() throws InterruptedException { verified.countDown(); }, exception -> { assertTrue(exception instanceof TimeSeriesException); - assertTrue( - "actual: " + exception.getMessage(), - exception.getMessage().contains(MultiEntityResultHandler.CANNOT_SAVE_RESULT_ERR_MSG) - ); + assertTrue("actual: " + exception.getMessage(), exception.getMessage().contains(CommonMessages.CANNOT_SAVE_RESULT_ERR_MSG)); verified.countDown(); })); @@ -109,17 +99,17 @@ public void testSavingAdResult() throws IOException, InterruptedException { verified.countDown(); })); assertTrue(verified.await(100, TimeUnit.SECONDS)); - assertEquals(1, testAppender.countMessage(MultiEntityResultHandler.SUCCESS_SAVING_RESULT_MSG, false)); + assertEquals(1, testAppender.countMessage(CommonMessages.SUCCESS_SAVING_RESULT_MSG, false)); } @Test public void testSavingFailure() throws IOException, InterruptedException { setUpSavingAnomalyResultIndex(false); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(2); listener.onFailure(new RuntimeException()); return null; - }).when(client).execute(eq(ADResultBulkAction.INSTANCE), any(), ArgumentMatchers.>any()); + }).when(client).execute(eq(ADResultBulkAction.INSTANCE), any(), ArgumentMatchers.>any()); CountDownLatch verified = new CountDownLatch(1); handler.flush(request, ActionListener.wrap(response -> { @@ -142,7 +132,7 @@ public void testAdResultIndexExists() throws IOException, InterruptedException { verified.countDown(); })); assertTrue(verified.await(100, TimeUnit.SECONDS)); - assertEquals(1, testAppender.countMessage(MultiEntityResultHandler.SUCCESS_SAVING_RESULT_MSG, false)); + assertEquals(1, testAppender.countMessage(CommonMessages.SUCCESS_SAVING_RESULT_MSG, false)); } @Test @@ -200,6 +190,6 @@ public void testCreateResourcExistsException() throws IOException, InterruptedEx verified.countDown(); })); assertTrue(verified.await(100, TimeUnit.SECONDS)); - assertEquals(1, testAppender.countMessage(MultiEntityResultHandler.SUCCESS_SAVING_RESULT_MSG, false)); + assertEquals(1, testAppender.countMessage(CommonMessages.SUCCESS_SAVING_RESULT_MSG, false)); } } diff --git a/src/test/java/org/opensearch/ad/util/BulkUtilTests.java b/src/test/java/org/opensearch/ad/util/BulkUtilTests.java index aadc2d999..5a5e35e81 100644 --- a/src/test/java/org/opensearch/ad/util/BulkUtilTests.java +++ b/src/test/java/org/opensearch/ad/util/BulkUtilTests.java @@ -25,6 +25,7 @@ import org.opensearch.core.index.shard.ShardId; import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.util.BulkUtil; public class BulkUtilTests extends OpenSearchTestCase { public void testGetFailedIndexRequest() { diff --git a/src/test/java/org/opensearch/ad/util/DateUtilsTests.java b/src/test/java/org/opensearch/ad/util/DateUtilsTests.java index 593445b01..0a5a1fb40 100644 --- a/src/test/java/org/opensearch/ad/util/DateUtilsTests.java +++ b/src/test/java/org/opensearch/ad/util/DateUtilsTests.java @@ -15,6 +15,7 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.util.DateUtils; public class DateUtilsTests extends OpenSearchTestCase { public void testDuration() { diff --git a/src/test/java/org/opensearch/ad/util/IndexUtilsTests.java b/src/test/java/org/opensearch/ad/util/IndexUtilsTests.java index 7234f6feb..cbbffb869 100644 --- a/src/test/java/org/opensearch/ad/util/IndexUtilsTests.java +++ b/src/test/java/org/opensearch/ad/util/IndexUtilsTests.java @@ -20,6 +20,7 @@ import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.test.OpenSearchIntegTestCase; import org.opensearch.timeseries.util.ClientUtil; +import org.opensearch.timeseries.util.IndexUtils; public class IndexUtilsTests extends OpenSearchIntegTestCase { @@ -36,7 +37,7 @@ public void setup() { @Test public void testGetIndexHealth_NoIndex() { - IndexUtils indexUtils = new IndexUtils(client(), clientUtil, clusterService(), indexNameResolver); + IndexUtils indexUtils = new IndexUtils(clusterService(), indexNameResolver); String output = indexUtils.getIndexHealthStatus("test"); assertEquals(IndexUtils.NONEXISTENT_INDEX_STATUS, output); } @@ -46,7 +47,7 @@ public void testGetIndexHealth_Index() { String indexName = "test-2"; createIndex(indexName); flush(); - IndexUtils indexUtils = new IndexUtils(client(), clientUtil, clusterService(), indexNameResolver); + IndexUtils indexUtils = new IndexUtils(clusterService(), indexNameResolver); String status = indexUtils.getIndexHealthStatus(indexName); assertTrue(status.equals("green") || status.equals("yellow")); } @@ -59,7 +60,7 @@ public void testGetIndexHealth_Alias() { flush(); AcknowledgedResponse response = client().admin().indices().prepareAliases().addAlias(indexName, aliasName).execute().actionGet(); assertTrue(response.isAcknowledged()); - IndexUtils indexUtils = new IndexUtils(client(), clientUtil, clusterService(), indexNameResolver); + IndexUtils indexUtils = new IndexUtils(clusterService(), indexNameResolver); String status = indexUtils.getIndexHealthStatus(aliasName); assertTrue(status.equals("green") || status.equals("yellow")); } diff --git a/src/test/java/org/opensearch/ad/util/ParseUtilsTests.java b/src/test/java/org/opensearch/ad/util/ParseUtilsTests.java index c2dd673b4..af919c1cd 100644 --- a/src/test/java/org/opensearch/ad/util/ParseUtilsTests.java +++ b/src/test/java/org/opensearch/ad/util/ParseUtilsTests.java @@ -11,7 +11,6 @@ package org.opensearch.ad.util; -import static org.opensearch.timeseries.util.ParseUtils.addUserBackendRolesFilter; import static org.opensearch.timeseries.util.ParseUtils.isAdmin; import java.io.IOException; @@ -127,16 +126,17 @@ public void testGenerateInternalFeatureQuery() throws IOException { public void testAddUserRoleFilterWithNullUser() { SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - addUserBackendRolesFilter(null, searchSourceBuilder); + ParseUtils.addUserBackendRolesFilter(null, searchSourceBuilder); assertEquals("{}", searchSourceBuilder.toString()); } public void testAddUserRoleFilterWithNullUserBackendRole() { SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - addUserBackendRolesFilter( - new User(randomAlphaOfLength(5), null, ImmutableList.of(randomAlphaOfLength(5)), ImmutableList.of(randomAlphaOfLength(5))), - searchSourceBuilder - ); + ParseUtils + .addUserBackendRolesFilter( + new User(randomAlphaOfLength(5), null, ImmutableList.of(randomAlphaOfLength(5)), ImmutableList.of(randomAlphaOfLength(5))), + searchSourceBuilder + ); assertEquals( "{\"query\":{\"bool\":{\"must\":[{\"nested\":{\"query\":{\"terms\":{\"user.backend_roles.keyword\":[]," + "\"boost\":1.0}},\"path\":\"user\",\"ignore_unmapped\":false,\"score_mode\":\"none\",\"boost\":1.0}}]," @@ -147,15 +147,16 @@ public void testAddUserRoleFilterWithNullUserBackendRole() { public void testAddUserRoleFilterWithEmptyUserBackendRole() { SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - addUserBackendRolesFilter( - new User( - randomAlphaOfLength(5), - ImmutableList.of(), - ImmutableList.of(randomAlphaOfLength(5)), - ImmutableList.of(randomAlphaOfLength(5)) - ), - searchSourceBuilder - ); + ParseUtils + .addUserBackendRolesFilter( + new User( + randomAlphaOfLength(5), + ImmutableList.of(), + ImmutableList.of(randomAlphaOfLength(5)), + ImmutableList.of(randomAlphaOfLength(5)) + ), + searchSourceBuilder + ); assertEquals( "{\"query\":{\"bool\":{\"must\":[{\"nested\":{\"query\":{\"terms\":{\"user.backend_roles.keyword\":[]," + "\"boost\":1.0}},\"path\":\"user\",\"ignore_unmapped\":false,\"score_mode\":\"none\",\"boost\":1.0}}]," @@ -168,15 +169,16 @@ public void testAddUserRoleFilterWithNormalUserBackendRole() { SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); String backendRole1 = randomAlphaOfLength(5); String backendRole2 = randomAlphaOfLength(5); - addUserBackendRolesFilter( - new User( - randomAlphaOfLength(5), - ImmutableList.of(backendRole1, backendRole2), - ImmutableList.of(randomAlphaOfLength(5)), - ImmutableList.of(randomAlphaOfLength(5)) - ), - searchSourceBuilder - ); + ParseUtils + .addUserBackendRolesFilter( + new User( + randomAlphaOfLength(5), + ImmutableList.of(backendRole1, backendRole2), + ImmutableList.of(randomAlphaOfLength(5)), + ImmutableList.of(randomAlphaOfLength(5)) + ), + searchSourceBuilder + ); assertEquals( "{\"query\":{\"bool\":{\"must\":[{\"nested\":{\"query\":{\"terms\":{\"user.backend_roles.keyword\":" + "[\"" diff --git a/src/test/java/org/opensearch/forecast/model/ForecastTaskSerializationTests.java b/src/test/java/org/opensearch/forecast/model/ForecastTaskSerializationTests.java index 28ec18bab..b097d67d8 100644 --- a/src/test/java/org/opensearch/forecast/model/ForecastTaskSerializationTests.java +++ b/src/test/java/org/opensearch/forecast/model/ForecastTaskSerializationTests.java @@ -49,7 +49,7 @@ public void testConstructor_allFieldsPresent() throws IOException { assertEquals("task123", readTask.getTaskId()); assertEquals("FORECAST_HISTORICAL_HC_ENTITY", readTask.getTaskType()); - assertTrue(readTask.isEntityTask()); + assertTrue(readTask.isHistoricalEntityTask()); assertEquals("config123", readTask.getConfigId()); assertEquals(originalTask.getForecaster(), readTask.getForecaster()); assertEquals("Running", readTask.getState()); @@ -93,7 +93,7 @@ public void testConstructor_missingOptionalFields() throws IOException { assertEquals("task123", readTask.getTaskId()); assertEquals("FORECAST_HISTORICAL_HC_ENTITY", readTask.getTaskType()); - assertTrue(readTask.isEntityTask()); + assertTrue(readTask.isHistoricalEntityTask()); assertEquals("config123", readTask.getConfigId()); assertEquals(null, readTask.getForecaster()); assertEquals("Running", readTask.getState()); diff --git a/src/test/java/org/opensearch/forecast/model/ForecastTaskTypeTests.java b/src/test/java/org/opensearch/forecast/model/ForecastTaskTypeTests.java index 4ee403a0e..db309f886 100644 --- a/src/test/java/org/opensearch/forecast/model/ForecastTaskTypeTests.java +++ b/src/test/java/org/opensearch/forecast/model/ForecastTaskTypeTests.java @@ -11,28 +11,16 @@ public class ForecastTaskTypeTests extends OpenSearchTestCase { - public void testHistoricalForecasterTaskTypes() { + public void testRunOnceForecasterTaskTypes() { assertEquals( - Arrays.asList(ForecastTaskType.FORECAST_HISTORICAL_HC_FORECASTER, ForecastTaskType.FORECAST_HISTORICAL_SINGLE_STREAM), - ForecastTaskType.HISTORICAL_FORECASTER_TASK_TYPES - ); - } - - public void testAllHistoricalTaskTypes() { - assertEquals( - Arrays - .asList( - ForecastTaskType.FORECAST_HISTORICAL_HC_FORECASTER, - ForecastTaskType.FORECAST_HISTORICAL_SINGLE_STREAM, - ForecastTaskType.FORECAST_HISTORICAL_HC_ENTITY - ), - ForecastTaskType.ALL_HISTORICAL_TASK_TYPES + Arrays.asList(ForecastTaskType.RUN_ONCE_FORECAST_HC_FORECASTER, ForecastTaskType.RUN_ONCE_FORECAST_SINGLE_STREAM), + ForecastTaskType.RUN_ONCE_TASK_TYPES ); } public void testRealtimeTaskTypes() { assertEquals( - Arrays.asList(ForecastTaskType.FORECAST_REALTIME_SINGLE_STREAM, ForecastTaskType.FORECAST_REALTIME_HC_FORECASTER), + Arrays.asList(ForecastTaskType.REALTIME_FORECAST_SINGLE_STREAM, ForecastTaskType.REALTIME_FORECAST_HC_FORECASTER), ForecastTaskType.REALTIME_TASK_TYPES ); } @@ -41,11 +29,10 @@ public void testAllForecastTaskTypes() { assertEquals( Arrays .asList( - ForecastTaskType.FORECAST_REALTIME_SINGLE_STREAM, - ForecastTaskType.FORECAST_REALTIME_HC_FORECASTER, - ForecastTaskType.FORECAST_HISTORICAL_SINGLE_STREAM, - ForecastTaskType.FORECAST_HISTORICAL_HC_FORECASTER, - ForecastTaskType.FORECAST_HISTORICAL_HC_ENTITY + ForecastTaskType.REALTIME_FORECAST_SINGLE_STREAM, + ForecastTaskType.REALTIME_FORECAST_HC_FORECASTER, + ForecastTaskType.RUN_ONCE_FORECAST_HC_FORECASTER, + ForecastTaskType.RUN_ONCE_FORECAST_SINGLE_STREAM ), ForecastTaskType.ALL_FORECAST_TASK_TYPES ); diff --git a/src/test/java/org/opensearch/forecast/model/ForecasterTests.java b/src/test/java/org/opensearch/forecast/model/ForecasterTests.java index 0b64912bf..b47110d82 100644 --- a/src/test/java/org/opensearch/forecast/model/ForecasterTests.java +++ b/src/test/java/org/opensearch/forecast/model/ForecasterTests.java @@ -54,9 +54,11 @@ public class ForecasterTests extends AbstractTimeSeriesTest { User user = new User("testUser", Collections.emptyList(), Collections.emptyList(), Collections.emptyList()); String resultIndex = null; Integer horizon = 1; + int recencyEmphasis = 20; + int seasonality = 20; public void testForecasterConstructor() { - ImputationOption imputationOption = TestHelpers.randomImputationOption(); + ImputationOption imputationOption = TestHelpers.randomImputationOption(0); Forecaster forecaster = new Forecaster( forecasterId, @@ -77,7 +79,10 @@ public void testForecasterConstructor() { user, resultIndex, horizon, - imputationOption + imputationOption, + recencyEmphasis, + seasonality, + randomIntBetween(1, 1000) ); assertEquals(forecasterId, forecaster.getId()); @@ -124,7 +129,10 @@ public void testForecasterConstructorWithNullForecastInterval() { user, resultIndex, horizon, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(0), + recencyEmphasis, + seasonality, + randomIntBetween(1, 1000) ); }); @@ -156,7 +164,10 @@ public void testNegativeInterval() { user, resultIndex, horizon, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(0), + recencyEmphasis, + seasonality, + randomIntBetween(1, 1000) ); }); @@ -188,7 +199,10 @@ public void testMaxCategoryFieldsLimits() { user, resultIndex, horizon, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(0), + recencyEmphasis, + seasonality, + randomIntBetween(1, 1000) ); }); @@ -220,7 +234,10 @@ public void testBlankName() { user, resultIndex, horizon, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(0), + recencyEmphasis, + seasonality, + randomIntBetween(1, 1000) ); }); @@ -252,7 +269,10 @@ public void testInvalidCustomResultIndex() { user, resultIndex, horizon, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(0), + recencyEmphasis, + seasonality, + randomIntBetween(1, 1000) ); }); @@ -283,7 +303,10 @@ public void testValidCustomResultIndex() { user, resultIndex, horizon, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(0), + recencyEmphasis, + seasonality, + randomIntBetween(1, 1000) ); assertEquals(resultIndex, forecaster.getCustomResultIndex()); @@ -312,7 +335,10 @@ public void testInvalidHorizon() { user, resultIndex, horizon, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(0), + recencyEmphasis, + seasonality, + randomIntBetween(1, 1000) ); }); @@ -383,14 +409,4 @@ public void testParseNullImpute() throws IOException { Forecaster parsedForecaster = Forecaster.parse(TestHelpers.parser(forecasterString)); assertEquals("Parsing forecaster doesn't work", forecaster, parsedForecaster); } - - public void testGetImputer() throws IOException { - Forecaster forecaster = TestHelpers.randomForecaster(); - assertTrue(null != forecaster.getImputer()); - } - - public void testGetImputerNullImputer() throws IOException { - Forecaster forecaster = TestHelpers.ForecasterBuilder.newInstance().setNullImputationOption().build(); - assertTrue(null != forecaster.getImputer()); - } } diff --git a/src/test/java/org/opensearch/forecast/settings/ForecastEnabledSettingTests.java b/src/test/java/org/opensearch/forecast/settings/ForecastEnabledSettingTests.java index dda3a8761..28dd54e51 100644 --- a/src/test/java/org/opensearch/forecast/settings/ForecastEnabledSettingTests.java +++ b/src/test/java/org/opensearch/forecast/settings/ForecastEnabledSettingTests.java @@ -15,16 +15,4 @@ public void testIsForecastEnabled() { assertTrue(!ForecastEnabledSetting.isForecastEnabled()); } - public void testIsForecastBreakerEnabled() { - assertTrue(ForecastEnabledSetting.isForecastBreakerEnabled()); - ForecastEnabledSetting.getInstance().setSettingValue(ForecastEnabledSetting.FORECAST_BREAKER_ENABLED, false); - assertTrue(!ForecastEnabledSetting.isForecastBreakerEnabled()); - } - - public void testIsDoorKeeperInCacheEnabled() { - assertTrue(!ForecastEnabledSetting.isDoorKeeperInCacheEnabled()); - ForecastEnabledSetting.getInstance().setSettingValue(ForecastEnabledSetting.FORECAST_DOOR_KEEPER_IN_CACHE_ENABLED, true); - assertTrue(ForecastEnabledSetting.isDoorKeeperInCacheEnabled()); - } - } diff --git a/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityProfileTests.java b/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityProfileTests.java index 7f024ef6d..b3e3b7150 100644 --- a/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityProfileTests.java +++ b/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityProfileTests.java @@ -28,13 +28,12 @@ import org.opensearch.action.get.GetResponse; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.ADTaskProfileRunner; import org.opensearch.ad.AbstractProfileRunnerTests; import org.opensearch.ad.AnomalyDetectorProfileRunner; import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.transport.ProfileAction; -import org.opensearch.ad.transport.ProfileNodeResponse; -import org.opensearch.ad.transport.ProfileResponse; +import org.opensearch.ad.transport.ADProfileAction; import org.opensearch.cluster.ClusterName; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.BigArrays; @@ -48,6 +47,8 @@ import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.IntervalTimeConfiguration; import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.transport.ProfileNodeResponse; +import org.opensearch.timeseries.transport.ProfileResponse; import org.opensearch.timeseries.util.SecurityClientUtil; /** @@ -85,7 +86,8 @@ private void setUpMultiEntityClientGet(DetectorStatus detectorStatus, JobStatus nodeFilter, requiredSamples, transportService, - adTaskManager + adTaskManager, + mock(ADTaskProfileRunner.class) ); doAnswer(invocation -> { @@ -169,7 +171,7 @@ private void setUpMultiEntityClientSearch(ADResultStatus resultStatus, Cardinali for (int i = 0; i < 100; i++) { hyperLogLog.collect(0, BitMixer.mix64(randomIntBetween(1, 100))); } - aggs.add(new InternalCardinality(ADCommonName.TOTAL_ENTITIES, hyperLogLog, new HashMap<>())); + aggs.add(new InternalCardinality(CommonName.TOTAL_ENTITIES, hyperLogLog, new HashMap<>())); when(response.getAggregations()).thenReturn(InternalAggregations.from(aggs)); listener.onResponse(response); break; @@ -204,7 +206,7 @@ private void setUpProfileAction() { listener.onResponse(new ProfileResponse(new ClusterName(clusterName), profileNodeResponses, Collections.emptyList())); return null; - }).when(client).execute(eq(ProfileAction.INSTANCE), any(), any()); + }).when(client).execute(eq(ADProfileAction.INSTANCE), any(), any()); } public void testFailGetEntityStats() throws IOException, InterruptedException { diff --git a/src/test/java/org/opensearch/timeseries/NodeStateManagerTests.java b/src/test/java/org/opensearch/timeseries/NodeStateManagerTests.java index e52255818..7b196a9af 100644 --- a/src/test/java/org/opensearch/timeseries/NodeStateManagerTests.java +++ b/src/test/java/org/opensearch/timeseries/NodeStateManagerTests.java @@ -192,9 +192,9 @@ private void setupCheckpoint(boolean responseExists) throws IOException { doAnswer(invocation -> { Object[] args = invocation.getArguments(); assertTrue( - String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), - args.length >= 2 - ); + String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), + args.length >= 2 + ); GetRequest request = null; ActionListener listener = null; diff --git a/src/test/java/org/opensearch/timeseries/TestHelpers.java b/src/test/java/org/opensearch/timeseries/TestHelpers.java index 685f3a07e..f0c034ec4 100644 --- a/src/test/java/org/opensearch/timeseries/TestHelpers.java +++ b/src/test/java/org/opensearch/timeseries/TestHelpers.java @@ -18,7 +18,14 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.index.query.AbstractQueryBuilder.parseInnerQueryBuilder; import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; -import static org.opensearch.test.OpenSearchTestCase.*; +import static org.opensearch.test.OpenSearchTestCase.buildNewFakeTransportAddress; +import static org.opensearch.test.OpenSearchTestCase.randomAlphaOfLength; +import static org.opensearch.test.OpenSearchTestCase.randomBoolean; +import static org.opensearch.test.OpenSearchTestCase.randomDouble; +import static org.opensearch.test.OpenSearchTestCase.randomDoubleBetween; +import static org.opensearch.test.OpenSearchTestCase.randomInt; +import static org.opensearch.test.OpenSearchTestCase.randomIntBetween; +import static org.opensearch.test.OpenSearchTestCase.randomLong; import java.io.IOException; import java.nio.ByteBuffer; @@ -59,8 +66,6 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.constant.CommonValue; -import org.opensearch.ad.feature.Features; import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.ml.ThresholdingResult; import org.opensearch.ad.mock.model.MockSimpleLog; @@ -71,10 +76,8 @@ import org.opensearch.ad.model.AnomalyResult; import org.opensearch.ad.model.AnomalyResultBucket; import org.opensearch.ad.model.DetectorInternalState; -import org.opensearch.ad.model.DetectorValidationIssue; import org.opensearch.ad.model.ExpectedValueList; -import org.opensearch.ad.ratelimit.RequestPriority; -import org.opensearch.ad.ratelimit.ResultWriteRequest; +import org.opensearch.ad.ratelimit.ADResultWriteRequest; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; import org.opensearch.client.Request; @@ -134,9 +137,12 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.constant.CommonValue; import org.opensearch.timeseries.dataprocessor.ImputationMethod; import org.opensearch.timeseries.dataprocessor.ImputationOption; +import org.opensearch.timeseries.feature.Features; import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.ConfigValidationIssue; import org.opensearch.timeseries.model.DataByFeatureId; import org.opensearch.timeseries.model.DateRange; import org.opensearch.timeseries.model.Entity; @@ -148,6 +154,7 @@ import org.opensearch.timeseries.model.TimeConfiguration; import org.opensearch.timeseries.model.ValidationAspect; import org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.timeseries.ratelimit.RequestPriority; import org.opensearch.timeseries.settings.TimeSeriesSettings; import com.google.common.collect.ImmutableList; @@ -321,7 +328,11 @@ public static AnomalyDetector randomAnomalyDetector( categoryFields, user, null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption((int) features.stream().filter(Feature::getEnabled).count()), + randomIntBetween(1, 10000), + randomIntBetween(1, TimeSeriesSettings.MAX_SHINGLE_SIZE * 2), + randomIntBetween(1, 1000), + null ); } @@ -366,7 +377,11 @@ public static AnomalyDetector randomDetector( categoryFields, null, resultIndex, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption((int) features.stream().filter(Feature::getEnabled).count()), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ); } @@ -421,7 +436,11 @@ public static AnomalyDetector randomAnomalyDetectorUsingCategoryFields( categoryFields, randomUser(), resultIndex, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(1), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ); } @@ -452,7 +471,11 @@ public static AnomalyDetector randomAnomalyDetector(String timefield, String ind null, randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption((int) features.stream().filter(Feature::getEnabled).count()), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ); } @@ -475,7 +498,11 @@ public static AnomalyDetector randomAnomalyDetectorWithEmptyFeature() throws IOE null, randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(0), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ); } @@ -485,6 +512,7 @@ public static AnomalyDetector randomAnomalyDetectorWithInterval(TimeConfiguratio public static AnomalyDetector randomAnomalyDetectorWithInterval(TimeConfiguration interval, boolean hcDetector) throws IOException { List categoryField = hcDetector ? ImmutableList.of(randomAlphaOfLength(5)) : null; + Feature feature = randomFeature(); return new AnomalyDetector( randomAlphaOfLength(10), randomLong(), @@ -492,7 +520,7 @@ public static AnomalyDetector randomAnomalyDetectorWithInterval(TimeConfiguratio randomAlphaOfLength(30), randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(10).toLowerCase(Locale.ROOT)), - ImmutableList.of(randomFeature()), + ImmutableList.of(feature), randomQuery(), interval, randomIntervalTimeConfiguration(), @@ -503,7 +531,11 @@ public static AnomalyDetector randomAnomalyDetectorWithInterval(TimeConfiguratio categoryField, randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ); } @@ -650,7 +682,11 @@ public AnomalyDetector build() { categoryFields, user, resultIndex, - imputationOption + imputationOption, + randomIntBetween(1, 10000), + randomIntBetween(1, TimeSeriesSettings.MAX_SHINGLE_SIZE * 2), + randomIntBetween(1, 1000), + null ); } } @@ -676,7 +712,11 @@ public static AnomalyDetector randomAnomalyDetectorWithInterval(TimeConfiguratio categoryField, randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(featureEnabled ? 1 : 0), + randomIntBetween(1, 10000), + randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), + randomIntBetween(1, 1000), + null ); } @@ -888,8 +928,8 @@ public static AnomalyResult randomHCADAnomalyDetectResult(double score, double g return randomHCADAnomalyDetectResult(score, grade, null); } - public static ResultWriteRequest randomResultWriteRequest(String detectorId, double score, double grade) { - ResultWriteRequest resultWriteRequest = new ResultWriteRequest( + public static ADResultWriteRequest randomADResultWriteRequest(String detectorId, double score, double grade) { + ADResultWriteRequest resultWriteRequest = new ADResultWriteRequest( Instant.now().plus(10, ChronoUnit.MINUTES).toEpochMilli(), detectorId, RequestPriority.MEDIUM, @@ -980,7 +1020,8 @@ public static Job randomAnomalyDetectorJob(boolean enabled, Instant enabledTime, Instant.now().truncatedTo(ChronoUnit.SECONDS), 60L, randomUser(), - null + null, + AnalysisType.AD ); } @@ -1187,6 +1228,15 @@ public static GetResponse createBrokenGetResponse(String id, String indexName) t ); } + public static GetResponse createGetResponse(Map source, String id, String indexName) throws IOException { + XContentBuilder xContent = XContentFactory.jsonBuilder(); + xContent.map(source); + BytesReference documentSource = BytesReference.bytes(xContent); + return new GetResponse( + new GetResult(indexName, id, UNASSIGNED_SEQ_NO, 0, -1, true, documentSource, Collections.emptyMap(), Collections.emptyMap()) + ); + } + public static SearchResponse createSearchResponse(ToXContentObject o) throws IOException { XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); @@ -1517,8 +1567,8 @@ public static Map parseStatsResult(String statsResult) throws IO return adStats; } - public static DetectorValidationIssue randomDetectorValidationIssue() { - DetectorValidationIssue issue = new DetectorValidationIssue( + public static ConfigValidationIssue randomDetectorValidationIssue() { + ConfigValidationIssue issue = new ConfigValidationIssue( ValidationAspect.DETECTOR, ValidationIssueType.NAME, randomAlphaOfLength(5) @@ -1526,8 +1576,8 @@ public static DetectorValidationIssue randomDetectorValidationIssue() { return issue; } - public static DetectorValidationIssue randomDetectorValidationIssueWithSubIssues(Map subIssues) { - DetectorValidationIssue issue = new DetectorValidationIssue( + public static ConfigValidationIssue randomDetectorValidationIssueWithSubIssues(Map subIssues) { + ConfigValidationIssue issue = new ConfigValidationIssue( ValidationAspect.DETECTOR, ValidationIssueType.NAME, randomAlphaOfLength(5), @@ -1537,8 +1587,8 @@ public static DetectorValidationIssue randomDetectorValidationIssueWithSubIssues return issue; } - public static DetectorValidationIssue randomDetectorValidationIssueWithDetectorIntervalRec(long intervalRec) { - DetectorValidationIssue issue = new DetectorValidationIssue( + public static ConfigValidationIssue randomDetectorValidationIssueWithDetectorIntervalRec(long intervalRec) { + ConfigValidationIssue issue = new ConfigValidationIssue( ValidationAspect.MODEL, ValidationIssueType.DETECTION_INTERVAL, CommonMessages.INTERVAL_REC + intervalRec, @@ -1585,8 +1635,8 @@ public static ClusterState createClusterState() { return clusterState; } - public static ImputationOption randomImputationOption() { - double[] defaultFill = DoubleStream.generate(OpenSearchTestCase::randomDouble).limit(10).toArray(); + public static ImputationOption randomImputationOption(int featureSize) { + double[] defaultFill = DoubleStream.generate(OpenSearchTestCase::randomDouble).limit(featureSize).toArray(); ImputationOption fixedValue = new ImputationOption(ImputationMethod.FIXED_VALUES, Optional.of(defaultFill), false); ImputationOption linear = new ImputationOption(ImputationMethod.LINEAR, Optional.of(defaultFill), false); ImputationOption linearIntSensitive = new ImputationOption(ImputationMethod.LINEAR, Optional.of(defaultFill), true); @@ -1640,7 +1690,7 @@ public static class ForecasterBuilder { user = randomUser(); resultIndex = null; horizon = randomIntBetween(1, 20); - imputationOption = randomImputationOption(); + imputationOption = randomImputationOption((int) features.stream().filter(Feature::getEnabled).count()); } public static ForecasterBuilder newInstance() throws IOException { @@ -1757,7 +1807,10 @@ public Forecaster build() { user, resultIndex, horizon, - imputationOption + imputationOption, + randomInt(), + randomInt(), + randomIntBetween(1, 1000) ); } } @@ -1782,7 +1835,10 @@ public static Forecaster randomForecaster() throws IOException { randomUser(), null, randomIntBetween(1, 20), - randomImputationOption() + randomImputationOption(1), + randomInt(), + randomInt(), + randomIntBetween(1, 1000) ); } diff --git a/src/test/java/org/opensearch/timeseries/dataprocessor/FixedValueImputerTests.java b/src/test/java/org/opensearch/timeseries/dataprocessor/FixedValueImputerTests.java deleted file mode 100644 index 81b9b5bfb..000000000 --- a/src/test/java/org/opensearch/timeseries/dataprocessor/FixedValueImputerTests.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.timeseries.dataprocessor; - -import static org.junit.Assert.assertArrayEquals; - -import org.junit.Test; - -public class FixedValueImputerTests { - - @Test - public void testImpute() { - // Initialize the FixedValueImputer with some fixed values - double[] fixedValues = { 2.0, 3.0 }; - FixedValueImputer imputer = new FixedValueImputer(fixedValues); - - // Create a sample array with some missing values (Double.NaN) - double[][] samples = { { 1.0, Double.NaN, 3.0 }, { Double.NaN, 2.0, 3.0 } }; - - // Call the impute method - double[][] imputed = imputer.impute(samples, 3); - - // Check the results - double[][] expected = { { 1.0, 2.0, 3.0 }, { 3.0, 2.0, 3.0 } }; - double delta = 0.0001; - - for (int i = 0; i < expected.length; i++) { - assertArrayEquals("The arrays are not equal", expected[i], imputed[i], delta); - } - } -} diff --git a/src/test/java/org/opensearch/timeseries/dataprocessor/PreviousValueImputerTests.java b/src/test/java/org/opensearch/timeseries/dataprocessor/PreviousValueImputerTests.java deleted file mode 100644 index fb39d83f2..000000000 --- a/src/test/java/org/opensearch/timeseries/dataprocessor/PreviousValueImputerTests.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.timeseries.dataprocessor; - -import java.util.Arrays; - -import org.opensearch.test.OpenSearchTestCase; - -public class PreviousValueImputerTests extends OpenSearchTestCase { - public void testSingleFeatureImpute() { - PreviousValueImputer imputer = new PreviousValueImputer(); - - double[] samples = { 1.0, Double.NaN, 3.0, Double.NaN, 5.0 }; - double[] expected = { 1.0, 1.0, 3.0, 3.0, 5.0 }; - - assertTrue("Imputation failed", Arrays.equals(expected, imputer.singleFeatureImpute(samples, 0))); - - // The second test checks whether the method removes leading Double.NaN values from the array - samples = new double[] { Double.NaN, 2.0, Double.NaN, 4.0 }; - expected = new double[] { Double.NaN, 2.0, 2.0, 4.0 }; - - assertTrue("Imputation failed with leading NaN", Arrays.equals(expected, imputer.singleFeatureImpute(samples, 0))); - } -} diff --git a/src/test/java/org/opensearch/timeseries/dataprocessor/ZeroImputerTests.java b/src/test/java/org/opensearch/timeseries/dataprocessor/ZeroImputerTests.java deleted file mode 100644 index 8e03821e2..000000000 --- a/src/test/java/org/opensearch/timeseries/dataprocessor/ZeroImputerTests.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.timeseries.dataprocessor; - -import static org.junit.Assert.assertArrayEquals; - -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; - -import junitparams.JUnitParamsRunner; -import junitparams.Parameters; - -@RunWith(JUnitParamsRunner.class) -public class ZeroImputerTests { - - private Imputer imputer; - - @Before - public void setup() { - imputer = new ZeroImputer(); - } - - private Object[] imputeData() { - return new Object[] { - new Object[] { new double[] { 25.25, Double.NaN, 25.75 }, 3, new double[] { 25.25, 0, 25.75 } }, - new Object[] { new double[] { Double.NaN, 25, 75 }, 3, new double[] { 0, 25, 75 } }, - new Object[] { new double[] { 25, 75.5, Double.NaN }, 3, new double[] { 25, 75.5, 0 } }, }; - } - - @Test - @Parameters(method = "imputeData") - public void impute_returnExpected(double[] samples, int num, double[] expected) { - assertArrayEquals("The arrays are not equal", expected, imputer.singleFeatureImpute(samples, num), 0.001); - } -} diff --git a/src/test/java/org/opensearch/timeseries/feature/SearchFeatureDaoTests.java b/src/test/java/org/opensearch/timeseries/feature/SearchFeatureDaoTests.java index 9731d31b5..ad907c448 100644 --- a/src/test/java/org/opensearch/timeseries/feature/SearchFeatureDaoTests.java +++ b/src/test/java/org/opensearch/timeseries/feature/SearchFeatureDaoTests.java @@ -263,7 +263,7 @@ public void getLatestDataTime_returnExpectedToListener() { when(searchResponse.getAggregations()).thenReturn(internalAggregations); ActionListener> listener = mock(ActionListener.class); - searchFeatureDao.getLatestDataTime(detector, listener); + searchFeatureDao.getLatestDataTime(detector, Optional.empty(), AnalysisType.AD, listener); ArgumentCaptor> captor = ArgumentCaptor.forClass(Optional.class); verify(listener).onResponse(captor.capture()); diff --git a/src/test/java/org/opensearch/timeseries/settings/TimeSeriesEnabledSettingTests.java b/src/test/java/org/opensearch/timeseries/settings/TimeSeriesEnabledSettingTests.java new file mode 100644 index 000000000..ae107c7e9 --- /dev/null +++ b/src/test/java/org/opensearch/timeseries/settings/TimeSeriesEnabledSettingTests.java @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.settings; + +import org.opensearch.test.OpenSearchTestCase; + +public class TimeSeriesEnabledSettingTests extends OpenSearchTestCase { + public void testIsForecastBreakerEnabled() { + assertTrue(TimeSeriesEnabledSetting.isBreakerEnabled()); + TimeSeriesEnabledSetting.getInstance().setSettingValue(TimeSeriesEnabledSetting.BREAKER_ENABLED, false); + assertTrue(!TimeSeriesEnabledSetting.isBreakerEnabled()); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/ADResultBulkTransportActionTests.java b/src/test/java/org/opensearch/timeseries/transport/ADResultBulkTransportActionTests.java similarity index 85% rename from src/test/java/org/opensearch/ad/transport/ADResultBulkTransportActionTests.java rename to src/test/java/org/opensearch/timeseries/transport/ADResultBulkTransportActionTests.java index 9887f1aff..e497988fc 100644 --- a/src/test/java/org/opensearch/ad/transport/ADResultBulkTransportActionTests.java +++ b/src/test/java/org/opensearch/timeseries/transport/ADResultBulkTransportActionTests.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasItem; @@ -32,6 +32,9 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.transport.ADResultBulkRequest; +import org.opensearch.ad.transport.ADResultBulkTransportAction; +import org.opensearch.ad.transport.AnomalyResultTests; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -41,6 +44,7 @@ import org.opensearch.index.IndexingPressure; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.transport.TransportService; public class ADResultBulkTransportActionTests extends AbstractTimeSeriesTest { @@ -98,8 +102,8 @@ public void testSendAll() { when(indexingPressure.getCurrentReplicaBytes()).thenReturn(0L); ADResultBulkRequest originalRequest = new ADResultBulkRequest(); - originalRequest.add(TestHelpers.randomResultWriteRequest(detectorId, 0.8d, 0d)); - originalRequest.add(TestHelpers.randomResultWriteRequest(detectorId, 8d, 0.2d)); + originalRequest.add(TestHelpers.randomADResultWriteRequest(detectorId, 0.8d, 0d)); + originalRequest.add(TestHelpers.randomADResultWriteRequest(detectorId, 8d, 0.2d)); doAnswer(invocation -> { Object[] args = invocation.getArguments(); @@ -118,7 +122,7 @@ public void testSendAll() { return null; }).when(client).execute(any(), any(), any()); - PlainActionFuture future = PlainActionFuture.newFuture(); + PlainActionFuture future = PlainActionFuture.newFuture(); resultBulk.doExecute(null, originalRequest, future); future.actionGet(); @@ -131,8 +135,8 @@ public void testSendPartial() { when(indexingPressure.getCurrentReplicaBytes()).thenReturn(24L); ADResultBulkRequest originalRequest = new ADResultBulkRequest(); - originalRequest.add(TestHelpers.randomResultWriteRequest(detectorId, 0.8d, 0d)); - originalRequest.add(TestHelpers.randomResultWriteRequest(detectorId, 8d, 0.2d)); + originalRequest.add(TestHelpers.randomADResultWriteRequest(detectorId, 0.8d, 0d)); + originalRequest.add(TestHelpers.randomADResultWriteRequest(detectorId, 8d, 0.2d)); doAnswer(invocation -> { Object[] args = invocation.getArguments(); @@ -151,7 +155,7 @@ public void testSendPartial() { return null; }).when(client).execute(any(), any(), any()); - PlainActionFuture future = PlainActionFuture.newFuture(); + PlainActionFuture future = PlainActionFuture.newFuture(); resultBulk.doExecute(null, originalRequest, future); future.actionGet(); @@ -165,10 +169,10 @@ public void testSendRandomPartial() { ADResultBulkRequest originalRequest = new ADResultBulkRequest(); for (int i = 0; i < 1000; i++) { - originalRequest.add(TestHelpers.randomResultWriteRequest(detectorId, 0.8d, 0d)); + originalRequest.add(TestHelpers.randomADResultWriteRequest(detectorId, 0.8d, 0d)); } - originalRequest.add(TestHelpers.randomResultWriteRequest(detectorId, 8d, 0.2d)); + originalRequest.add(TestHelpers.randomADResultWriteRequest(detectorId, 8d, 0.2d)); doAnswer(invocation -> { Object[] args = invocation.getArguments(); @@ -190,7 +194,7 @@ public void testSendRandomPartial() { return null; }).when(client).execute(any(), any(), any()); - PlainActionFuture future = PlainActionFuture.newFuture(); + PlainActionFuture future = PlainActionFuture.newFuture(); resultBulk.doExecute(null, originalRequest, future); future.actionGet(); @@ -198,8 +202,8 @@ public void testSendRandomPartial() { public void testSerialzationRequest() throws IOException { ADResultBulkRequest request = new ADResultBulkRequest(); - request.add(TestHelpers.randomResultWriteRequest(detectorId, 0.8d, 0d)); - request.add(TestHelpers.randomResultWriteRequest(detectorId, 8d, 0.2d)); + request.add(TestHelpers.randomADResultWriteRequest(detectorId, 0.8d, 0d)); + request.add(TestHelpers.randomADResultWriteRequest(detectorId, 8d, 0.2d)); BytesStreamOutput output = new BytesStreamOutput(); request.writeTo(output); @@ -210,6 +214,6 @@ public void testSerialzationRequest() throws IOException { public void testValidateRequest() { ActionRequestValidationException e = new ADResultBulkRequest().validate(); - assertThat(e.validationErrors(), hasItem(ADResultBulkRequest.NO_REQUESTS_ADDED_ERR)); + assertThat(e.validationErrors(), hasItem(CommonMessages.NO_REQUESTS_ADDED_ERR)); } } diff --git a/src/test/java/org/opensearch/ad/transport/ADStatsNodesTransportActionTests.java b/src/test/java/org/opensearch/timeseries/transport/ADStatsNodesTransportActionTests.java similarity index 68% rename from src/test/java/org/opensearch/ad/transport/ADStatsNodesTransportActionTests.java rename to src/test/java/org/opensearch/timeseries/transport/ADStatsNodesTransportActionTests.java index 2284c311e..8bcc0163d 100644 --- a/src/test/java/org/opensearch/ad/transport/ADStatsNodesTransportActionTests.java +++ b/src/test/java/org/opensearch/timeseries/transport/ADStatsNodesTransportActionTests.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -26,18 +26,13 @@ import org.junit.Before; import org.junit.Test; import org.opensearch.action.support.ActionFilters; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.stats.ADStat; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.stats.ADStats; -import org.opensearch.ad.stats.InternalStatNames; -import org.opensearch.ad.stats.suppliers.CounterSupplier; -import org.opensearch.ad.stats.suppliers.IndexStatusSupplier; -import org.opensearch.ad.stats.suppliers.ModelsOnNodeSupplier; -import org.opensearch.ad.stats.suppliers.SettableSupplier; +import org.opensearch.ad.stats.suppliers.ADModelsOnNodeSupplier; import org.opensearch.ad.task.ADTaskManager; -import org.opensearch.ad.util.IndexUtils; +import org.opensearch.ad.transport.ADStatsNodesTransportAction; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; @@ -47,14 +42,19 @@ import org.opensearch.monitor.jvm.JvmStats; import org.opensearch.test.OpenSearchIntegTestCase; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.timeseries.util.ClientUtil; +import org.opensearch.timeseries.stats.InternalStatNames; +import org.opensearch.timeseries.stats.TimeSeriesStat; +import org.opensearch.timeseries.stats.suppliers.CounterSupplier; +import org.opensearch.timeseries.stats.suppliers.IndexStatusSupplier; +import org.opensearch.timeseries.stats.suppliers.SettableSupplier; +import org.opensearch.timeseries.util.IndexUtils; import org.opensearch.transport.TransportService; public class ADStatsNodesTransportActionTests extends OpenSearchIntegTestCase { private ADStatsNodesTransportAction action; private ADStats adStats; - private Map> statsMap; + private Map> statsMap; private String clusterStatName1, clusterStatName2; private String nodeStatName1, nodeStatName2; private ADTaskManager adTaskManager; @@ -68,10 +68,10 @@ public void setUp() throws Exception { Clock clock = mock(Clock.class); ThreadPool threadPool = mock(ThreadPool.class); IndexNameExpressionResolver indexNameResolver = mock(IndexNameExpressionResolver.class); - IndexUtils indexUtils = new IndexUtils(client, new ClientUtil(client), clusterService(), indexNameResolver); - ModelManager modelManager = mock(ModelManager.class); - CacheProvider cacheProvider = mock(CacheProvider.class); - EntityCache cache = mock(EntityCache.class); + IndexUtils indexUtils = new IndexUtils(clusterService(), indexNameResolver); + ADModelManager modelManager = mock(ADModelManager.class); + ADCacheProvider cacheProvider = mock(ADCacheProvider.class); + ADPriorityCache cache = mock(ADPriorityCache.class); when(cacheProvider.get()).thenReturn(cache); clusterStatName1 = "clusterStat1"; @@ -87,13 +87,16 @@ public void setUp() throws Exception { ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - statsMap = new HashMap>() { + statsMap = new HashMap>() { { - put(nodeStatName1, new ADStat<>(false, new CounterSupplier())); - put(nodeStatName2, new ADStat<>(false, new ModelsOnNodeSupplier(modelManager, cacheProvider, settings, clusterService))); - put(clusterStatName1, new ADStat<>(true, new IndexStatusSupplier(indexUtils, "index1"))); - put(clusterStatName2, new ADStat<>(true, new IndexStatusSupplier(indexUtils, "index2"))); - put(InternalStatNames.JVM_HEAP_USAGE.getName(), new ADStat<>(true, new SettableSupplier())); + put(nodeStatName1, new TimeSeriesStat<>(false, new CounterSupplier())); + put( + nodeStatName2, + new TimeSeriesStat<>(false, new ADModelsOnNodeSupplier(modelManager, cacheProvider, settings, clusterService)) + ); + put(clusterStatName1, new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, "index1"))); + put(clusterStatName2, new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, "index2"))); + put(InternalStatNames.JVM_HEAP_USAGE.getName(), new TimeSeriesStat<>(true, new SettableSupplier())); } }; @@ -121,10 +124,10 @@ public void setUp() throws Exception { @Test public void testNewNodeRequest() { String nodeId = "nodeId1"; - ADStatsRequest adStatsRequest = new ADStatsRequest(nodeId); + StatsRequest adStatsRequest = new StatsRequest(nodeId); - ADStatsNodeRequest adStatsNodeRequest1 = new ADStatsNodeRequest(adStatsRequest); - ADStatsNodeRequest adStatsNodeRequest2 = action.newNodeRequest(adStatsRequest); + StatsNodeRequest adStatsNodeRequest1 = new StatsNodeRequest(adStatsRequest); + StatsNodeRequest adStatsNodeRequest2 = action.newNodeRequest(adStatsRequest); assertEquals(adStatsNodeRequest1.getADStatsRequest(), adStatsNodeRequest2.getADStatsRequest()); } @@ -132,7 +135,7 @@ public void testNewNodeRequest() { @Test public void testNodeOperation() { String nodeId = clusterService().localNode().getId(); - ADStatsRequest adStatsRequest = new ADStatsRequest((nodeId)); + StatsRequest adStatsRequest = new StatsRequest((nodeId)); adStatsRequest.clear(); Set statsToBeRetrieved = new HashSet<>(Arrays.asList(nodeStatName1, nodeStatName2)); @@ -141,7 +144,7 @@ public void testNodeOperation() { adStatsRequest.addStat(stat); } - ADStatsNodeResponse response = action.nodeOperation(new ADStatsNodeRequest(adStatsRequest)); + StatsNodeResponse response = action.nodeOperation(new StatsNodeRequest(adStatsRequest)); Map stats = response.getStatsMap(); @@ -154,7 +157,7 @@ public void testNodeOperation() { @Test public void testNodeOperationWithJvmHeapUsage() { String nodeId = clusterService().localNode().getId(); - ADStatsRequest adStatsRequest = new ADStatsRequest((nodeId)); + StatsRequest adStatsRequest = new StatsRequest((nodeId)); adStatsRequest.clear(); Set statsToBeRetrieved = new HashSet<>(Arrays.asList(nodeStatName1, InternalStatNames.JVM_HEAP_USAGE.getName())); @@ -163,7 +166,7 @@ public void testNodeOperationWithJvmHeapUsage() { adStatsRequest.addStat(stat); } - ADStatsNodeResponse response = action.nodeOperation(new ADStatsNodeRequest(adStatsRequest)); + StatsNodeResponse response = action.nodeOperation(new StatsNodeRequest(adStatsRequest)); Map stats = response.getStatsMap(); diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobActionTests.java b/src/test/java/org/opensearch/timeseries/transport/AnomalyDetectorJobActionTests.java similarity index 81% rename from src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobActionTests.java rename to src/test/java/org/opensearch/timeseries/transport/AnomalyDetectorJobActionTests.java index 79bc66527..67fd5b8cf 100644 --- a/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobActionTests.java +++ b/src/test/java/org/opensearch/timeseries/transport/AnomalyDetectorJobActionTests.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -24,10 +24,10 @@ import org.junit.Before; import org.junit.Test; import org.opensearch.action.support.ActionFilters; -import org.opensearch.ad.ExecuteADResultResponseRecorder; -import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.rest.handler.ADIndexJobActionHandler; import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.AnomalyDetectorJobAction; +import org.opensearch.ad.transport.AnomalyDetectorJobTransportAction; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -41,13 +41,12 @@ import org.opensearch.test.OpenSearchIntegTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.model.DateRange; -import org.opensearch.timeseries.transport.JobResponse; import org.opensearch.transport.TransportService; public class AnomalyDetectorJobActionTests extends OpenSearchIntegTestCase { private AnomalyDetectorJobTransportAction action; private Task task; - private AnomalyDetectorJobRequest request; + private JobRequest request; private ActionListener response; @Override @@ -75,13 +74,11 @@ public void setUp() throws Exception { client, clusterService, indexSettings(), - mock(ADIndexManagement.class), xContentRegistry(), - mock(ADTaskManager.class), - mock(ExecuteADResultResponseRecorder.class) + mock(ADIndexJobActionHandler.class) ); task = mock(Task.class); - request = new AnomalyDetectorJobRequest("1234", 4567, 7890, "_start"); + request = new JobRequest("1234", new DateRange(Instant.ofEpochMilli(4567), Instant.ofEpochMilli(7890)), true, "_start"); response = new ActionListener() { @Override public void onResponse(JobResponse adResponse) { @@ -104,7 +101,12 @@ public void testStartAdJobTransportAction() { @Test public void testStopAdJobTransportAction() { - AnomalyDetectorJobRequest stopRequest = new AnomalyDetectorJobRequest("1234", 4567, 7890, "_stop"); + JobRequest stopRequest = new JobRequest( + "1234", + new DateRange(Instant.ofEpochMilli(4567), Instant.ofEpochMilli(7890)), + true, + "_stop" + ); action.doExecute(task, stopRequest, response); } @@ -117,13 +119,13 @@ public void testAdJobAction() { @Test public void testAdJobRequest() throws IOException { DateRange detectionDateRange = new DateRange(Instant.MIN, Instant.now()); - request = new AnomalyDetectorJobRequest("1234", detectionDateRange, false, 4567, 7890, "_start"); + request = new JobRequest("1234", detectionDateRange, false, "_start"); BytesStreamOutput out = new BytesStreamOutput(); request.writeTo(out); StreamInput input = out.bytes().streamInput(); - AnomalyDetectorJobRequest newRequest = new AnomalyDetectorJobRequest(input); - Assert.assertEquals(request.getDetectorID(), newRequest.getDetectorID()); + JobRequest newRequest = new JobRequest(input); + Assert.assertEquals(request.getConfigID(), newRequest.getConfigID()); } @Test @@ -131,8 +133,8 @@ public void testAdJobRequest_NullDetectionDateRange() throws IOException { BytesStreamOutput out = new BytesStreamOutput(); request.writeTo(out); StreamInput input = out.bytes().streamInput(); - AnomalyDetectorJobRequest newRequest = new AnomalyDetectorJobRequest(input); - Assert.assertEquals(request.getDetectorID(), newRequest.getDetectorID()); + JobRequest newRequest = new JobRequest(input); + Assert.assertEquals(request.getConfigID(), newRequest.getConfigID()); } @Test diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportActionTests.java b/src/test/java/org/opensearch/timeseries/transport/AnomalyDetectorJobTransportActionTests.java similarity index 82% rename from src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportActionTests.java rename to src/test/java/org/opensearch/timeseries/transport/AnomalyDetectorJobTransportActionTests.java index 6f7629039..df82df9a7 100644 --- a/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportActionTests.java +++ b/src/test/java/org/opensearch/timeseries/transport/AnomalyDetectorJobTransportActionTests.java @@ -9,16 +9,13 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; -import static org.opensearch.ad.constant.ADCommonMessages.DETECTOR_IS_RUNNING; import static org.opensearch.ad.settings.AnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS_PER_DETECTOR; -import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM; -import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; import static org.opensearch.timeseries.TestHelpers.HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS; -import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG; +import static org.opensearch.timeseries.constant.CommonMessages.CONFIG_IS_RUNNING; import static org.opensearch.timeseries.util.RestHandlerUtils.PROFILE; import static org.opensearch.timeseries.util.RestHandlerUtils.START_JOB; import static org.opensearch.timeseries.util.RestHandlerUtils.STOP_JOB; @@ -44,22 +41,26 @@ import org.opensearch.ad.mock.model.MockSimpleLog; import org.opensearch.ad.mock.transport.MockAnomalyDetectorJobAction; import org.opensearch.ad.model.ADTask; -import org.opensearch.ad.model.ADTaskProfile; import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.transport.AnomalyDetectorJobAction; +import org.opensearch.ad.transport.GetAnomalyDetectorAction; +import org.opensearch.ad.transport.GetAnomalyDetectorResponse; +import org.opensearch.ad.transport.StatsAnomalyDetectorAction; import org.opensearch.client.Client; import org.opensearch.common.lucene.uid.Versions; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.index.IndexNotFoundException; import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.timeseries.TaskProfile; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.DateRange; import org.opensearch.timeseries.model.Job; import org.opensearch.timeseries.model.TaskState; import org.opensearch.timeseries.stats.StatNames; -import org.opensearch.timeseries.transport.JobResponse; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -98,7 +99,7 @@ protected Settings nodeSettings(int nodeOrdinal) { public void testDetectorIndexNotFound() { deleteDetectorIndex(); String detectorId = randomAlphaOfLength(5); - AnomalyDetectorJobRequest request = startDetectorJobRequest(detectorId, dateRange); + JobRequest request = startDetectorJobRequest(detectorId, dateRange); IndexNotFoundException exception = expectThrows( IndexNotFoundException.class, () -> client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(3000) @@ -108,12 +109,12 @@ public void testDetectorIndexNotFound() { public void testDetectorNotFound() { String detectorId = randomAlphaOfLength(5); - AnomalyDetectorJobRequest request = startDetectorJobRequest(detectorId, dateRange); + JobRequest request = startDetectorJobRequest(detectorId, dateRange); OpenSearchStatusException exception = expectThrows( OpenSearchStatusException.class, () -> client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000) ); - assertTrue(exception.getMessage().contains(FAIL_TO_FIND_CONFIG_MSG)); + assertTrue(exception.getMessage().contains(CommonMessages.FAIL_TO_FIND_CONFIG_MSG)); } public void testValidHistoricalAnalysis() throws IOException, InterruptedException { @@ -127,14 +128,7 @@ public void testStartHistoricalAnalysisWithUser() throws IOException { AnomalyDetector detector = TestHelpers .randomDetector(ImmutableList.of(maxValueFeature()), testIndex, detectionIntervalInMinutes, timeField); String detectorId = createDetector(detector); - AnomalyDetectorJobRequest request = new AnomalyDetectorJobRequest( - detectorId, - dateRange, - true, - UNASSIGNED_SEQ_NO, - UNASSIGNED_PRIMARY_TERM, - START_JOB - ); + JobRequest request = new JobRequest(detectorId, dateRange, true, START_JOB); Client nodeClient = getDataNodeClient(); if (nodeClient != null) { JobResponse response = nodeClient.execute(MockAnomalyDetectorJobAction.INSTANCE, request).actionGet(100000); @@ -156,14 +150,7 @@ public void testStartHistoricalAnalysisForSingleCategoryHCWithUser() throws IOEx ImmutableList.of(categoryField) ); String detectorId = createDetector(detector); - AnomalyDetectorJobRequest request = new AnomalyDetectorJobRequest( - detectorId, - dateRange, - true, - UNASSIGNED_SEQ_NO, - UNASSIGNED_PRIMARY_TERM, - START_JOB - ); + JobRequest request = new JobRequest(detectorId, dateRange, true, START_JOB); Client nodeClient = getDataNodeClient(); if (nodeClient != null) { @@ -208,14 +195,7 @@ public void testStartHistoricalAnalysisForMultiCategoryHCWithUser() throws IOExc ImmutableList.of(categoryField, ipField) ); String detectorId = createDetector(detector); - AnomalyDetectorJobRequest request = new AnomalyDetectorJobRequest( - detectorId, - dateRange, - true, - UNASSIGNED_SEQ_NO, - UNASSIGNED_PRIMARY_TERM, - START_JOB - ); + JobRequest request = new JobRequest(detectorId, dateRange, true, START_JOB); Client nodeClient = getDataNodeClient(); if (nodeClient != null) { @@ -253,7 +233,7 @@ public void testRunMultipleTasksForHistoricalAnalysis() throws IOException, Inte AnomalyDetector detector = TestHelpers .randomDetector(ImmutableList.of(maxValueFeature()), testIndex, detectionIntervalInMinutes, timeField); String detectorId = createDetector(detector); - AnomalyDetectorJobRequest request = startDetectorJobRequest(detectorId, dateRange); + JobRequest request = startDetectorJobRequest(detectorId, dateRange); JobResponse response = client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000); assertNotNull(response.getId()); OpenSearchStatusException exception = null; @@ -263,7 +243,7 @@ public void testRunMultipleTasksForHistoricalAnalysis() throws IOException, Inte OpenSearchStatusException.class, () -> client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000) ); - if (exception.getMessage().contains(DETECTOR_IS_RUNNING)) { + if (exception.getMessage().contains(CONFIG_IS_RUNNING)) { break; } else { logger.error("Unexpected error happened when rerun detector", exception); @@ -271,8 +251,8 @@ public void testRunMultipleTasksForHistoricalAnalysis() throws IOException, Inte Thread.sleep(1000); } assertNotNull(exception); - assertTrue(exception.getMessage().contains(DETECTOR_IS_RUNNING)); - assertEquals(DETECTOR_IS_RUNNING, exception.getMessage()); + assertTrue(exception.getMessage().contains(CONFIG_IS_RUNNING)); + assertEquals(CONFIG_IS_RUNNING, exception.getMessage()); Thread.sleep(20000); List adTasks = searchADTasks(detectorId, null, 100); assertEquals(1, adTasks.size()); @@ -283,14 +263,7 @@ public void testRaceConditionByStartingMultipleTasks() throws IOException, Inter AnomalyDetector detector = TestHelpers .randomDetector(ImmutableList.of(maxValueFeature()), testIndex, detectionIntervalInMinutes, timeField); String detectorId = createDetector(detector); - AnomalyDetectorJobRequest request = new AnomalyDetectorJobRequest( - detectorId, - dateRange, - true, - UNASSIGNED_SEQ_NO, - UNASSIGNED_PRIMARY_TERM, - START_JOB - ); + JobRequest request = new JobRequest(detectorId, dateRange, true, START_JOB); client().execute(AnomalyDetectorJobAction.INSTANCE, request); client().execute(AnomalyDetectorJobAction.INSTANCE, request); @@ -318,14 +291,7 @@ public void testCleanOldTaskDocs() throws InterruptedException, IOException { long count = countDocs(ADCommonName.DETECTION_STATE_INDEX); assertEquals(states.size(), count); - AnomalyDetectorJobRequest request = new AnomalyDetectorJobRequest( - detectorId, - dateRange, - true, - randomLong(), - randomLong(), - START_JOB - ); + JobRequest request = new JobRequest(detectorId, dateRange, true, START_JOB); AtomicReference response = new AtomicReference<>(); CountDownLatch latch = new CountDownLatch(1); @@ -369,7 +335,7 @@ private List startRealtimeDetector() throws IOException { AnomalyDetector detector = TestHelpers .randomDetector(ImmutableList.of(maxValueFeature()), testIndex, detectionIntervalInMinutes, timeField); String detectorId = createDetector(detector); - AnomalyDetectorJobRequest request = startDetectorJobRequest(detectorId, null); + JobRequest request = startDetectorJobRequest(detectorId, null); JobResponse response = client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000); String jobId = response.getId(); assertEquals(detectorId, jobId); @@ -378,29 +344,29 @@ private List startRealtimeDetector() throws IOException { public void testRealtimeDetectorWithoutFeature() throws IOException { AnomalyDetector detector = TestHelpers.randomDetector(ImmutableList.of(), testIndex, detectionIntervalInMinutes, timeField); - testInvalidDetector(detector, "Can't start detector job as no features configured"); + testInvalidDetector(detector, "Can't start job as no features configured"); } public void testHistoricalDetectorWithoutFeature() throws IOException { AnomalyDetector detector = TestHelpers.randomDetector(ImmutableList.of(), testIndex, detectionIntervalInMinutes, timeField); - testInvalidDetector(detector, "Can't start detector job as no features configured"); + testInvalidDetector(detector, "Can't start job as no features configured"); } public void testRealtimeDetectorWithoutEnabledFeature() throws IOException { AnomalyDetector detector = TestHelpers .randomDetector(ImmutableList.of(TestHelpers.randomFeature(false)), testIndex, detectionIntervalInMinutes, timeField); - testInvalidDetector(detector, "Can't start detector job as no enabled features configured"); + testInvalidDetector(detector, "Can't start job as no enabled features configured"); } public void testHistoricalDetectorWithoutEnabledFeature() throws IOException { AnomalyDetector detector = TestHelpers .randomDetector(ImmutableList.of(TestHelpers.randomFeature(false)), testIndex, detectionIntervalInMinutes, timeField); - testInvalidDetector(detector, "Can't start detector job as no enabled features configured"); + testInvalidDetector(detector, "Can't start job as no enabled features configured"); } private void testInvalidDetector(AnomalyDetector detector, String error) throws IOException { String detectorId = createDetector(detector); - AnomalyDetectorJobRequest request = startDetectorJobRequest(detectorId, dateRange); + JobRequest request = startDetectorJobRequest(detectorId, dateRange); OpenSearchStatusException exception = expectThrows( OpenSearchStatusException.class, () -> client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000) @@ -408,12 +374,12 @@ private void testInvalidDetector(AnomalyDetector detector, String error) throws assertEquals(error, exception.getMessage()); } - private AnomalyDetectorJobRequest startDetectorJobRequest(String detectorId, DateRange dateRange) { - return new AnomalyDetectorJobRequest(detectorId, dateRange, false, UNASSIGNED_SEQ_NO, UNASSIGNED_PRIMARY_TERM, START_JOB); + private JobRequest startDetectorJobRequest(String detectorId, DateRange dateRange) { + return new JobRequest(detectorId, dateRange, false, START_JOB); } - private AnomalyDetectorJobRequest stopDetectorJobRequest(String detectorId, boolean historical) { - return new AnomalyDetectorJobRequest(detectorId, null, historical, UNASSIGNED_SEQ_NO, UNASSIGNED_PRIMARY_TERM, STOP_JOB); + private JobRequest stopDetectorJobRequest(String detectorId, boolean historical) { + return new JobRequest(detectorId, null, historical, STOP_JOB); } public void testStopRealtimeDetector() throws IOException { @@ -421,7 +387,7 @@ public void testStopRealtimeDetector() throws IOException { String detectorId = realtimeResult.get(0); String jobId = realtimeResult.get(1); - AnomalyDetectorJobRequest request = stopDetectorJobRequest(detectorId, false); + JobRequest request = stopDetectorJobRequest(detectorId, false); client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000); GetResponse doc = getDoc(CommonName.JOB_INDEX, detectorId); Job job = toADJob(doc); @@ -448,7 +414,7 @@ public void testStopHistoricalDetector() throws IOException, InterruptedExceptio if (taskRunning) { // It's possible that the task not started on worker node yet. Recancel it to make sure // task cancelled. - AnomalyDetectorJobRequest request = stopDetectorJobRequest(adTask.getConfigId(), true); + JobRequest request = stopDetectorJobRequest(adTask.getConfigId(), true); client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000); } return !taskRunning; @@ -463,9 +429,9 @@ public void testStopHistoricalDetector() throws IOException, InterruptedExceptio public void testProfileHistoricalDetector() throws IOException, InterruptedException { ADTask adTask = startHistoricalAnalysis(startTime, endTime); - GetAnomalyDetectorRequest request = taskProfileRequest(adTask.getConfigId()); + GetConfigRequest request = taskProfileRequest(adTask.getConfigId()); GetAnomalyDetectorResponse response = client().execute(GetAnomalyDetectorAction.INSTANCE, request).actionGet(10000); - assertTrue(response.getDetectorProfile().getAdTaskProfile() != null); + assertTrue(response.getDetectorProfile().getTaskProfile() != null); ADTask finishedTask = getADTask(adTask.getTaskId()); int i = 0; @@ -477,8 +443,8 @@ public void testProfileHistoricalDetector() throws IOException, InterruptedExcep assertTrue(HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS.contains(finishedTask.getState())); response = client().execute(GetAnomalyDetectorAction.INSTANCE, request).actionGet(10000); - assertNull(response.getDetectorProfile().getAdTaskProfile().getNodeId()); - ADTask profileAdTask = response.getDetectorProfile().getAdTaskProfile().getAdTask(); + assertNull(response.getDetectorProfile().getTaskProfile().getNodeId()); + ADTask profileAdTask = response.getDetectorProfile().getTaskProfile().getTask(); assertEquals(finishedTask.getTaskId(), profileAdTask.getTaskId()); assertEquals(finishedTask.getConfigId(), profileAdTask.getConfigId()); assertEquals(finishedTask.getDetector(), profileAdTask.getDetector()); @@ -489,28 +455,28 @@ public void testProfileWithMultipleRunningTask() throws IOException { ADTask adTask1 = startHistoricalAnalysis(startTime, endTime); ADTask adTask2 = startHistoricalAnalysis(startTime, endTime); - GetAnomalyDetectorRequest request1 = taskProfileRequest(adTask1.getConfigId()); - GetAnomalyDetectorRequest request2 = taskProfileRequest(adTask2.getConfigId()); + GetConfigRequest request1 = taskProfileRequest(adTask1.getConfigId()); + GetConfigRequest request2 = taskProfileRequest(adTask2.getConfigId()); GetAnomalyDetectorResponse response1 = client().execute(GetAnomalyDetectorAction.INSTANCE, request1).actionGet(10000); GetAnomalyDetectorResponse response2 = client().execute(GetAnomalyDetectorAction.INSTANCE, request2).actionGet(10000); - ADTaskProfile taskProfile1 = response1.getDetectorProfile().getAdTaskProfile(); - ADTaskProfile taskProfile2 = response2.getDetectorProfile().getAdTaskProfile(); + TaskProfile taskProfile1 = response1.getDetectorProfile().getTaskProfile(); + TaskProfile taskProfile2 = response2.getDetectorProfile().getTaskProfile(); assertNotNull(taskProfile1.getNodeId()); assertNotNull(taskProfile2.getNodeId()); assertNotEquals(taskProfile1.getNodeId(), taskProfile2.getNodeId()); } - private GetAnomalyDetectorRequest taskProfileRequest(String detectorId) throws IOException { - return new GetAnomalyDetectorRequest(detectorId, Versions.MATCH_ANY, false, false, "", PROFILE, true, null); + private GetConfigRequest taskProfileRequest(String detectorId) throws IOException { + return new GetConfigRequest(detectorId, Versions.MATCH_ANY, false, false, "", PROFILE, true, null); } private long getExecutingADTask() { - ADStatsRequest adStatsRequest = new ADStatsRequest(getDataNodesArray()); + StatsRequest adStatsRequest = new StatsRequest(getDataNodesArray()); Set validStats = ImmutableSet.of(StatNames.AD_EXECUTING_BATCH_TASK_COUNT.getName()); adStatsRequest.addAll(validStats); - StatsAnomalyDetectorResponse statsResponse = client().execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest).actionGet(5000); + StatsTimeSeriesResponse statsResponse = client().execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest).actionGet(5000); AtomicLong totalExecutingTask = new AtomicLong(0); - statsResponse.getAdStatsResponse().getADStatsNodesResponse().getNodes().forEach(node -> { + statsResponse.getAdStatsResponse().getStatsNodesResponse().getNodes().forEach(node -> { totalExecutingTask.getAndAdd((Long) node.getStatsMap().get(StatNames.AD_EXECUTING_BATCH_TASK_COUNT.getName())); }); return totalExecutingTask.get(); diff --git a/src/test/java/org/opensearch/ad/transport/CronTransportActionTests.java b/src/test/java/org/opensearch/timeseries/transport/CronTransportActionTests.java similarity index 77% rename from src/test/java/org/opensearch/ad/transport/CronTransportActionTests.java rename to src/test/java/org/opensearch/timeseries/transport/CronTransportActionTests.java index 7c3de7ed2..dd6e66ab9 100644 --- a/src/test/java/org/opensearch/ad/transport/CronTransportActionTests.java +++ b/src/test/java/org/opensearch/timeseries/transport/CronTransportActionTests.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -23,12 +23,11 @@ import org.junit.Before; import org.opensearch.Version; import org.opensearch.action.support.ActionFilters; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; import org.opensearch.ad.common.exception.JsonPathNotFoundException; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.ml.EntityColdStarter; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADColdStart; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; @@ -38,9 +37,14 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.forecast.caching.ForecastCacheProvider; +import org.opensearch.forecast.caching.ForecastPriorityCache; +import org.opensearch.forecast.ml.ForecastColdStart; +import org.opensearch.forecast.task.ForecastTaskManager; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.feature.FeatureManager; import org.opensearch.transport.TransportService; import com.google.gson.JsonElement; @@ -65,14 +69,20 @@ public void setUp() throws Exception { TransportService transportService = mock(TransportService.class); ActionFilters actionFilters = mock(ActionFilters.class); NodeStateManager tarnsportStatemanager = mock(NodeStateManager.class); - ModelManager modelManager = mock(ModelManager.class); + ADModelManager modelManager = mock(ADModelManager.class); FeatureManager featureManager = mock(FeatureManager.class); - CacheProvider cacheProvider = mock(CacheProvider.class); - EntityCache entityCache = mock(EntityCache.class); - EntityColdStarter entityColdStarter = mock(EntityColdStarter.class); + ADCacheProvider cacheProvider = mock(ADCacheProvider.class); + ADPriorityCache entityCache = mock(ADPriorityCache.class); + ADColdStart entityColdStarter = mock(ADColdStart.class); when(cacheProvider.get()).thenReturn(entityCache); ADTaskManager adTaskManager = mock(ADTaskManager.class); + ForecastCacheProvider forecastCacheProvider = mock(ForecastCacheProvider.class); + ForecastPriorityCache forecastCache = mock(ForecastPriorityCache.class); + ForecastColdStart forecastColdStarter = mock(ForecastColdStart.class); + when(forecastCacheProvider.get()).thenReturn(forecastCache); + ForecastTaskManager forecastTaskManager = mock(ForecastTaskManager.class); + action = new CronTransportAction( threadPool, clusterService, @@ -82,8 +92,11 @@ public void setUp() throws Exception { modelManager, featureManager, cacheProvider, + forecastCacheProvider, entityColdStarter, - adTaskManager + forecastColdStarter, + adTaskManager, + forecastTaskManager ); } diff --git a/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorActionTests.java b/src/test/java/org/opensearch/timeseries/transport/DeleteAnomalyDetectorActionTests.java similarity index 84% rename from src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorActionTests.java rename to src/test/java/org/opensearch/timeseries/transport/DeleteAnomalyDetectorActionTests.java index 93e291325..00c667b86 100644 --- a/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorActionTests.java +++ b/src/test/java/org/opensearch/timeseries/transport/DeleteAnomalyDetectorActionTests.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -27,6 +27,8 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.DeleteAnomalyDetectorAction; +import org.opensearch.ad.transport.DeleteAnomalyDetectorTransportAction; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.ClusterSettings; @@ -35,6 +37,7 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.timeseries.NodeStateManager; import org.opensearch.transport.TransportService; public class DeleteAnomalyDetectorActionTests extends OpenSearchIntegTestCase { @@ -60,6 +63,7 @@ public void setUp() throws Exception { clusterService, Settings.EMPTY, xContentRegistry(), + mock(NodeStateManager.class), adTaskManager ); response = new ActionListener() { @@ -83,18 +87,18 @@ public void testStatsAction() { @Test public void testDeleteRequest() throws IOException { - DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest("1234"); + DeleteConfigRequest request = new DeleteConfigRequest("1234"); BytesStreamOutput out = new BytesStreamOutput(); request.writeTo(out); StreamInput input = out.bytes().streamInput(); - DeleteAnomalyDetectorRequest newRequest = new DeleteAnomalyDetectorRequest(input); - Assert.assertEquals(request.getDetectorID(), newRequest.getDetectorID()); + DeleteConfigRequest newRequest = new DeleteConfigRequest(input); + Assert.assertEquals(request.getConfigID(), newRequest.getConfigID()); Assert.assertNull(newRequest.validate()); } @Test public void testEmptyDeleteRequest() { - DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest(""); + DeleteConfigRequest request = new DeleteConfigRequest(""); ActionRequestValidationException exception = request.validate(); Assert.assertNotNull(exception); } @@ -103,14 +107,14 @@ public void testEmptyDeleteRequest() { public void testTransportActionWithAdIndex() { // DeleteResponse is not called because detector ID will not exist createIndex(".opendistro-anomaly-detector-jobs"); - DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest("1234"); + DeleteConfigRequest request = new DeleteConfigRequest("1234"); action.doExecute(mock(Task.class), request, response); } @Test public void testTransportActionWithoutAdIndex() throws IOException { // DeleteResponse is not called because detector ID will not exist - DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest("1234"); + DeleteConfigRequest request = new DeleteConfigRequest("1234"); action.doExecute(mock(Task.class), request, response); } } diff --git a/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTests.java b/src/test/java/org/opensearch/timeseries/transport/DeleteAnomalyDetectorTests.java similarity index 89% rename from src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTests.java rename to src/test/java/org/opensearch/timeseries/transport/DeleteAnomalyDetectorTests.java index 03608048f..8092368e3 100644 --- a/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTests.java +++ b/src/test/java/org/opensearch/timeseries/transport/DeleteAnomalyDetectorTests.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -38,6 +38,7 @@ import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.DeleteAnomalyDetectorTransportAction; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.ClusterState; @@ -54,6 +55,8 @@ import org.opensearch.tasks.Task; import org.opensearch.telemetry.tracing.noop.NoopTracer; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.function.ExecutorFunction; @@ -73,6 +76,7 @@ public class DeleteAnomalyDetectorTests extends AbstractTimeSeriesTest { private GetResponse getResponse; ClusterService clusterService; private Job jobParameter; + private NodeStateManager nodeStatemanager; @BeforeClass public static void setUpBeforeClass() { @@ -109,6 +113,7 @@ public void setUp() throws Exception { actionFilters = mock(ActionFilters.class); adTaskManager = mock(ADTaskManager.class); + nodeStatemanager = mock(NodeStateManager.class); action = new DeleteAnomalyDetectorTransportAction( transportService, actionFilters, @@ -116,6 +121,7 @@ public void setUp() throws Exception { clusterService, Settings.EMPTY, xContentRegistry(), + nodeStatemanager, adTaskManager ); @@ -128,32 +134,32 @@ public void setUp() throws Exception { public void testDeleteADTransportAction_FailDeleteResponse() { future = mock(PlainActionFuture.class); - DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest("1234"); + DeleteConfigRequest request = new DeleteConfigRequest("1234"); setupMocks(true, true, false, false); action.doExecute(mock(Task.class), request, future); - verify(adTaskManager).deleteADTasks(eq("1234"), any(), any()); + verify(adTaskManager).deleteTasks(eq("1234"), any(), any()); verify(client, times(1)).delete(any(), any()); verify(future).onFailure(any(OpenSearchStatusException.class)); } public void testDeleteADTransportAction_NullAnomalyDetector() { future = mock(PlainActionFuture.class); - DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest("1234"); + DeleteConfigRequest request = new DeleteConfigRequest("1234"); setupMocks(true, false, false, false); action.doExecute(mock(Task.class), request, future); - verify(adTaskManager).deleteADTasks(eq("1234"), any(), any()); + verify(adTaskManager).deleteTasks(eq("1234"), any(), any()); verify(client, times(3)).delete(any(), any()); } public void testDeleteADTransportAction_DeleteResponseException() { future = mock(PlainActionFuture.class); - DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest("1234"); + DeleteConfigRequest request = new DeleteConfigRequest("1234"); setupMocks(true, false, true, false); action.doExecute(mock(Task.class), request, future); - verify(adTaskManager).deleteADTasks(eq("1234"), any(), any()); + verify(adTaskManager).deleteTasks(eq("1234"), any(), any()); verify(client, times(1)).delete(any(), any()); verify(future).onFailure(any(RuntimeException.class)); } @@ -167,10 +173,10 @@ public void testDeleteADTransportAction_LatestDetectorLevelTask() { ADTask adTask = ADTask.builder().state("RUNNING").build(); consumer.accept(Optional.of(adTask)); return null; - }).when(adTaskManager).getAndExecuteOnLatestDetectorLevelTask(eq("1234"), any(), any(), eq(transportService), eq(true), any()); + }).when(adTaskManager).getAndExecuteOnLatestConfigLevelTask(eq("1234"), any(), any(), eq(transportService), eq(true), any()); future = mock(PlainActionFuture.class); - DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest("1234"); + DeleteConfigRequest request = new DeleteConfigRequest("1234"); setupMocks(false, false, false, false); action.doExecute(mock(Task.class), request, future); @@ -180,7 +186,7 @@ public void testDeleteADTransportAction_LatestDetectorLevelTask() { public void testDeleteADTransportAction_JobRunning() { when(clusterService.state()).thenReturn(createClusterState()); future = mock(PlainActionFuture.class); - DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest("1234"); + DeleteConfigRequest request = new DeleteConfigRequest("1234"); setupMocks(false, false, false, false); action.doExecute(mock(Task.class), request, future); @@ -190,7 +196,7 @@ public void testDeleteADTransportAction_JobRunning() { public void testDeleteADTransportAction_GetResponseException() { when(clusterService.state()).thenReturn(createClusterState()); future = mock(PlainActionFuture.class); - DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest("1234"); + DeleteConfigRequest request = new DeleteConfigRequest("1234"); setupMocks(false, false, false, true); action.doExecute(mock(Task.class), request, future); @@ -246,7 +252,7 @@ private void setupMocks( consumer.accept(Optional.of(ad)); } return null; - }).when(adTaskManager).getDetector(any(), any(), any()); + }).when(nodeStatemanager).getConfig(any(), any(), any()); doAnswer(invocation -> { Object[] args = invocation.getArguments(); @@ -254,7 +260,7 @@ private void setupMocks( function.execute(); return null; - }).when(adTaskManager).deleteADTasks(eq("1234"), any(), any()); + }).when(adTaskManager).deleteTasks(eq("1234"), any(), any()); doAnswer(invocation -> { Object[] args = invocation.getArguments(); @@ -300,7 +306,8 @@ private void setupMocks( Instant.now(), 60L, TestHelpers.randomUser(), - jobParameter.getCustomResultIndex() + jobParameter.getCustomResultIndex(), + AnalysisType.AD ).toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS) ), Collections.emptyMap(), diff --git a/src/test/java/org/opensearch/ad/transport/EntityProfileTests.java b/src/test/java/org/opensearch/timeseries/transport/EntityProfileTests.java similarity index 86% rename from src/test/java/org/opensearch/ad/transport/EntityProfileTests.java rename to src/test/java/org/opensearch/timeseries/transport/EntityProfileTests.java index bbb436c36..d4e6cf8bf 100644 --- a/src/test/java/org/opensearch/ad/transport/EntityProfileTests.java +++ b/src/test/java/org/opensearch/timeseries/transport/EntityProfileTests.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.anyString; @@ -31,15 +31,12 @@ import org.opensearch.Version; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; -import org.opensearch.ad.cluster.HashRing; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; import org.opensearch.ad.common.exception.JsonPathNotFoundException; import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.model.EntityProfileName; -import org.opensearch.ad.model.ModelProfile; -import org.opensearch.ad.model.ModelProfileOnNode; +import org.opensearch.ad.transport.ADEntityProfileAction; +import org.opensearch.ad.transport.ADEntityProfileTransportAction; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; @@ -51,9 +48,13 @@ import org.opensearch.telemetry.tracing.noop.NoopTracer; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.EntityProfileName; +import org.opensearch.timeseries.model.ModelProfile; +import org.opensearch.timeseries.model.ModelProfileOnNode; import org.opensearch.transport.ConnectTransportException; import org.opensearch.transport.Transport; import org.opensearch.transport.TransportException; @@ -78,8 +79,8 @@ public class EntityProfileTests extends AbstractTimeSeriesTest { private TransportService transportService; private Settings settings; private ClusterService clusterService; - private CacheProvider cacheProvider; - private EntityProfileTransportAction action; + private ADCacheProvider cacheProvider; + private ADEntityProfileTransportAction action; private Task task; private PlainActionFuture future; private TransportAddress transportAddress1; @@ -135,18 +136,18 @@ public void setUp() throws Exception { clusterService = mock(ClusterService.class); - cacheProvider = mock(CacheProvider.class); - EntityCache cache = mock(EntityCache.class); + cacheProvider = mock(ADCacheProvider.class); + ADPriorityCache cache = mock(ADPriorityCache.class); updates = 1L; when(cache.getTotalUpdates(anyString(), anyString())).thenReturn(updates); when(cache.isActive(anyString(), anyString())).thenReturn(isActive); - when(cache.getLastActiveMs(anyString(), anyString())).thenReturn(lastActiveTimestamp); + when(cache.getLastActiveTime(anyString(), anyString())).thenReturn(lastActiveTimestamp); Map modelSizeMap = new HashMap<>(); modelSizeMap.put(modelId, modelSize); when(cache.getModelSize(anyString())).thenReturn(modelSizeMap); when(cacheProvider.get()).thenReturn(cache); - action = new EntityProfileTransportAction(actionFilters, transportService, settings, hashRing, clusterService, cacheProvider); + action = new ADEntityProfileTransportAction(actionFilters, transportService, settings, hashRing, clusterService, cacheProvider); future = new PlainActionFuture<>(); transportAddress1 = new TransportAddress(new InetSocketAddress(InetAddress.getByName("1.2.3.4"), 9300)); @@ -167,7 +168,7 @@ public void sendRequest( TransportRequestOptions options, TransportResponseHandler handler ) { - if (EntityProfileAction.NAME.equals(action)) { + if (ADEntityProfileAction.NAME.equals(action)) { sender.sendRequest(connection, action, request, options, entityProfileHandler(handler)); } else { sender.sendRequest(connection, action, request, options, handler); @@ -189,7 +190,7 @@ public void sendRequest( TransportRequestOptions options, TransportResponseHandler handler ) { - if (EntityProfileAction.NAME.equals(action)) { + if (ADEntityProfileAction.NAME.equals(action)) { sender.sendRequest(connection, action, request, options, entityFailureProfileandler(handler)); } else { sender.sendRequest(connection, action, request, options, handler); @@ -238,7 +239,7 @@ public void handleResponse(T response) { .handleException( new ConnectTransportException( new DiscoveryNode(nodeId, transportAddress1, Version.CURRENT.minimumCompatibilityVersion()), - EntityProfileAction.NAME + ADEntityProfileAction.NAME ) ); } @@ -256,7 +257,7 @@ public String executor() { } private void registerHandler(FakeNode node) { - new EntityProfileTransportAction( + new ADEntityProfileTransportAction( new ActionFilters(Collections.emptySet()), node.transportService, Settings.EMPTY, @@ -267,15 +268,15 @@ private void registerHandler(FakeNode node) { } public void testInvalidRequest() { - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(anyString())).thenReturn(Optional.empty()); + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(anyString())).thenReturn(Optional.empty()); action.doExecute(task, request, future); - assertException(future, TimeSeriesException.class, EntityProfileTransportAction.NO_NODE_FOUND_MSG); + assertException(future, TimeSeriesException.class, ADEntityProfileTransportAction.NO_NODE_FOUND_MSG); } public void testLocalNodeHit() { DiscoveryNode localNode = new DiscoveryNode(nodeId, transportAddress1, Version.CURRENT.minimumCompatibilityVersion()); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(anyString())).thenReturn(Optional.of(localNode)); + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(anyString())).thenReturn(Optional.of(localNode)); when(clusterService.localNode()).thenReturn(localNode); action.doExecute(task, request, future); @@ -285,7 +286,7 @@ public void testLocalNodeHit() { public void testAllHit() { DiscoveryNode localNode = new DiscoveryNode(nodeId, transportAddress1, Version.CURRENT.minimumCompatibilityVersion()); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(anyString())).thenReturn(Optional.of(localNode)); + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(anyString())).thenReturn(Optional.of(localNode)); when(clusterService.localNode()).thenReturn(localNode); request = new EntityProfileRequest(detectorId, entity, all); @@ -302,7 +303,7 @@ public void testGetRemoteUpdateResponse() { TransportService realTransportService = testNodes[0].transportService; clusterService = testNodes[0].clusterService; - action = new EntityProfileTransportAction( + action = new ADEntityProfileTransportAction( actionFilters, realTransportService, settings, @@ -311,7 +312,7 @@ public void testGetRemoteUpdateResponse() { cacheProvider ); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))) + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))) .thenReturn(Optional.of(testNodes[1].discoveryNode())); registerHandler(testNodes[1]); @@ -332,7 +333,7 @@ public void testGetRemoteFailureResponse() { TransportService realTransportService = testNodes[0].transportService; clusterService = testNodes[0].clusterService; - action = new EntityProfileTransportAction( + action = new ADEntityProfileTransportAction( actionFilters, realTransportService, settings, @@ -341,7 +342,7 @@ public void testGetRemoteFailureResponse() { cacheProvider ); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))) + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))) .thenReturn(Optional.of(testNodes[1].discoveryNode())); registerHandler(testNodes[1]); @@ -361,7 +362,7 @@ public void testResponseToXContent() throws IOException, JsonPathNotFoundExcepti EntityProfileResponse response = builder.build(); String json = TestHelpers.xContentBuilderToString(response.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); assertEquals(lastActiveTimestamp, JsonDeserializer.getLongValue(json, EntityProfileResponse.LAST_ACTIVE_TS)); - assertEquals(modelSize, JsonDeserializer.getChildNode(json, ADCommonName.MODEL, CommonName.MODEL_SIZE_IN_BYTES).getAsLong()); + assertEquals(modelSize, JsonDeserializer.getChildNode(json, CommonName.MODEL, CommonName.MODEL_SIZE_IN_BYTES).getAsLong()); } public void testResponseHashCodeEquals() { @@ -378,8 +379,8 @@ public void testResponseHashCodeEquals() { } public void testEntityProfileName() { - assertEquals("state", EntityProfileName.getName(ADCommonName.STATE).getName()); - assertEquals("models", EntityProfileName.getName(ADCommonName.MODELS).getName()); + assertEquals("state", EntityProfileName.getName(CommonName.STATE).getName()); + assertEquals("models", EntityProfileName.getName(CommonName.MODELS).getName()); IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> EntityProfileName.getName("abc")); assertEquals(exception.getMessage(), ADCommonMessages.UNSUPPORTED_PROFILE_TYPE); } diff --git a/src/test/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoActionTests.java b/src/test/java/org/opensearch/timeseries/transport/SearchAnomalyDetectorInfoActionTests.java similarity index 82% rename from src/test/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoActionTests.java rename to src/test/java/org/opensearch/timeseries/transport/SearchAnomalyDetectorInfoActionTests.java index f06761bb6..47a5d0877 100644 --- a/src/test/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoActionTests.java +++ b/src/test/java/org/opensearch/timeseries/transport/SearchAnomalyDetectorInfoActionTests.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; @@ -30,6 +30,8 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.transport.SearchAnomalyDetectorInfoAction; +import org.opensearch.ad.transport.SearchAnomalyDetectorInfoTransportAction; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -46,15 +48,15 @@ import org.opensearch.transport.TransportService; public class SearchAnomalyDetectorInfoActionTests extends OpenSearchIntegTestCase { - private SearchAnomalyDetectorInfoRequest request; - private ActionListener response; + private SearchConfigInfoRequest request; + private ActionListener response; private SearchAnomalyDetectorInfoTransportAction action; private Task task; private ClusterService clusterService; private Client client; private ThreadPool threadPool; ThreadContext threadContext; - private PlainActionFuture future; + private PlainActionFuture future; @Override @Before @@ -67,9 +69,9 @@ public void setUp() throws Exception { clusterService() ); task = mock(Task.class); - response = new ActionListener() { + response = new ActionListener() { @Override - public void onResponse(SearchAnomalyDetectorInfoResponse response) { + public void onResponse(SearchConfigInfoResponse response) { Assert.assertEquals(response.getCount(), 0); Assert.assertEquals(response.isNameExists(), false); } @@ -100,14 +102,14 @@ public void onFailure(Exception e) { @Test public void testSearchCount() throws IOException { // Anomaly Detectors index will not exist, onResponse will be called - SearchAnomalyDetectorInfoRequest request = new SearchAnomalyDetectorInfoRequest(null, "count"); + SearchConfigInfoRequest request = new SearchConfigInfoRequest(null, "count"); action.doExecute(task, request, response); } @Test public void testSearchMatch() throws IOException { // Anomaly Detectors index will not exist, onResponse will be called - SearchAnomalyDetectorInfoRequest request = new SearchAnomalyDetectorInfoRequest("testDetector", "match"); + SearchConfigInfoRequest request = new SearchConfigInfoRequest("testDetector", "match"); action.doExecute(task, request, response); } @@ -119,11 +121,11 @@ public void testSearchInfoAction() { @Test public void testSearchInfoRequest() throws IOException { - SearchAnomalyDetectorInfoRequest request = new SearchAnomalyDetectorInfoRequest("testDetector", "match"); + SearchConfigInfoRequest request = new SearchConfigInfoRequest("testDetector", "match"); BytesStreamOutput out = new BytesStreamOutput(); request.writeTo(out); StreamInput input = out.bytes().streamInput(); - SearchAnomalyDetectorInfoRequest newRequest = new SearchAnomalyDetectorInfoRequest(input); + SearchConfigInfoRequest newRequest = new SearchConfigInfoRequest(input); Assert.assertEquals(request.getName(), newRequest.getName()); Assert.assertEquals(request.getRawPath(), newRequest.getRawPath()); Assert.assertNull(newRequest.validate()); @@ -131,11 +133,11 @@ public void testSearchInfoRequest() throws IOException { @Test public void testSearchInfoResponse() throws IOException { - SearchAnomalyDetectorInfoResponse response = new SearchAnomalyDetectorInfoResponse(1, true); + SearchConfigInfoResponse response = new SearchConfigInfoResponse(1, true); BytesStreamOutput out = new BytesStreamOutput(); response.writeTo(out); StreamInput input = out.bytes().streamInput(); - SearchAnomalyDetectorInfoResponse newResponse = new SearchAnomalyDetectorInfoResponse(input); + SearchConfigInfoResponse newResponse = new SearchConfigInfoResponse(input); Assert.assertEquals(response.getCount(), newResponse.getCount()); Assert.assertEquals(response.isNameExists(), newResponse.isNameExists()); Assert.assertNotNull(response.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); @@ -156,9 +158,9 @@ public void testSearchInfoResponse_CountSuccessWithEmptyResponse() throws IOExce client, clusterService ); - SearchAnomalyDetectorInfoRequest request = new SearchAnomalyDetectorInfoRequest("testDetector", "count"); + SearchConfigInfoRequest request = new SearchConfigInfoRequest("testDetector", "count"); action.doExecute(task, request, future); - verify(future).onResponse(any(SearchAnomalyDetectorInfoResponse.class)); + verify(future).onResponse(any(SearchConfigInfoResponse.class)); } public void testSearchInfoResponse_MatchSuccessWithEmptyResponse() throws IOException { @@ -176,9 +178,9 @@ public void testSearchInfoResponse_MatchSuccessWithEmptyResponse() throws IOExce client, clusterService ); - SearchAnomalyDetectorInfoRequest request = new SearchAnomalyDetectorInfoRequest("testDetector", "match"); + SearchConfigInfoRequest request = new SearchConfigInfoRequest("testDetector", "match"); action.doExecute(task, request, future); - verify(future).onResponse(any(SearchAnomalyDetectorInfoResponse.class)); + verify(future).onResponse(any(SearchConfigInfoResponse.class)); } public void testSearchInfoResponse_CountRuntimeException() throws IOException { @@ -194,7 +196,7 @@ public void testSearchInfoResponse_CountRuntimeException() throws IOException { client, clusterService ); - SearchAnomalyDetectorInfoRequest request = new SearchAnomalyDetectorInfoRequest("testDetector", "count"); + SearchConfigInfoRequest request = new SearchConfigInfoRequest("testDetector", "count"); action.doExecute(task, request, future); verify(future).onFailure(any(RuntimeException.class)); } @@ -212,7 +214,7 @@ public void testSearchInfoResponse_MatchRuntimeException() throws IOException { client, clusterService ); - SearchAnomalyDetectorInfoRequest request = new SearchAnomalyDetectorInfoRequest("testDetector", "match"); + SearchConfigInfoRequest request = new SearchConfigInfoRequest("testDetector", "match"); action.doExecute(task, request, future); verify(future).onFailure(any(RuntimeException.class)); } diff --git a/src/test/java/org/opensearch/timeseries/util/ClientUtilTests.java b/src/test/java/org/opensearch/timeseries/util/ClientUtilTests.java index d4241fc4f..031c234c8 100644 --- a/src/test/java/org/opensearch/timeseries/util/ClientUtilTests.java +++ b/src/test/java/org/opensearch/timeseries/util/ClientUtilTests.java @@ -78,7 +78,8 @@ public void testAsyncRequestOnSuccess() throws InterruptedException { new double[] { randomDouble(), randomDouble() }, new double[][] { new double[] { randomDouble(), randomDouble() } }, new double[] { randomDouble() }, - randomDoubleBetween(1.1, 10.0, true) + randomDoubleBetween(1.1, 10.0, true), + null ); BiConsumer> consumer = (request, actionListener) -> { // simulate successful operation @@ -122,7 +123,8 @@ public void testExecuteOnSuccess() throws InterruptedException { new double[] { randomDouble(), randomDouble() }, new double[][] { new double[] { randomDouble(), randomDouble() } }, new double[] { randomDouble() }, - randomDoubleBetween(1.1, 10.0, true) + randomDoubleBetween(1.1, 10.0, true), + null ); doAnswer(invocationOnMock -> { ((ActionListener) invocationOnMock.getArguments()[2]).onResponse(expected); diff --git a/src/test/java/test/org/opensearch/ad/util/MLUtil.java b/src/test/java/test/org/opensearch/ad/util/MLUtil.java index 6b6bb39af..87dbfb63b 100644 --- a/src/test/java/test/org/opensearch/ad/util/MLUtil.java +++ b/src/test/java/test/org/opensearch/ad/util/MLUtil.java @@ -14,17 +14,20 @@ import static java.lang.Math.PI; import java.time.Clock; +import java.time.Instant; import java.util.ArrayDeque; +import java.util.Deque; import java.util.HashMap; import java.util.Map; -import java.util.Queue; +import java.util.Optional; import java.util.Random; import java.util.stream.IntStream; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelManager.ModelType; -import org.opensearch.ad.ml.ModelState; +import org.apache.commons.lang3.tuple.Pair; import org.opensearch.common.collect.Tuple; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.settings.TimeSeriesSettings; @@ -53,13 +56,13 @@ private static String randomString(int targetStringLength) { .toString(); } - public static Queue createQueueSamples(int size) { - Queue res = new ArrayDeque<>(); - IntStream.range(0, size).forEach(i -> res.offer(new double[] { random.nextDouble() })); + public static Deque createQueueSamples(int size) { + Deque res = new ArrayDeque<>(); + IntStream.range(0, size).forEach(i -> res.offer(new Sample(new double[] { random.nextDouble() }, Instant.now(), Instant.now()))); return res; } - public static ModelState randomModelState(RandomModelStateConfig config) { + public static ModelState randomModelState(RandomModelStateConfig config) { boolean fullModel = config.getFullModel() != null && config.getFullModel().booleanValue() ? true : false; float priority = config.getPriority() != null ? config.getPriority() : random.nextFloat(); String detectorId = config.getId() != null ? config.getId() : randomString(15); @@ -75,27 +78,36 @@ public static ModelState randomModelState(RandomModelStateConfig co } else { entity = Entity.createSingleAttributeEntity("", ""); } - EntityModel model = null; + Pair> model = null; if (fullModel) { model = createNonEmptyModel(detectorId, sampleSize, entity); } else { model = createEmptyModel(entity, sampleSize); } - return new ModelState<>(model, detectorId, detectorId, ModelType.ENTITY.getName(), clock, priority); + return new ModelState( + model.getLeft(), + detectorId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + priority, + Optional.of(entity), + model.getRight() + ); } - public static EntityModel createEmptyModel(Entity entity, int sampleSize) { - Queue samples = createQueueSamples(sampleSize); - return new EntityModel(entity, samples, null); + public static Pair> createEmptyModel(Entity entity, int sampleSize) { + Deque samples = createQueueSamples(sampleSize); + return Pair.of(null, samples); } - public static EntityModel createEmptyModel(Entity entity) { + public static Pair> createEmptyModel(Entity entity) { return createEmptyModel(entity, random.nextInt(minSampleSize)); } - public static EntityModel createNonEmptyModel(String detectorId, int sampleSize, Entity entity) { - Queue samples = createQueueSamples(sampleSize); + public static Pair> createNonEmptyModel(String detectorId, int sampleSize, Entity entity) { + Deque samples = createQueueSamples(sampleSize); int numDataPoints = random.nextInt(1000) + TimeSeriesSettings.NUM_MIN_SAMPLES; ThresholdedRandomCutForest trcf = new ThresholdedRandomCutForest( ThresholdedRandomCutForest @@ -103,7 +115,7 @@ public static EntityModel createNonEmptyModel(String detectorId, int sampleSize, .dimensions(1) .sampleSize(TimeSeriesSettings.NUM_SAMPLES_PER_TREE) .numberOfTrees(TimeSeriesSettings.NUM_TREES) - .timeDecay(TimeSeriesSettings.TIME_DECAY) + .timeDecay(0.0001) .outputAfter(TimeSeriesSettings.NUM_MIN_SAMPLES) .initialAcceptFraction(0.125d) .parallelExecutionEnabled(false) @@ -116,11 +128,10 @@ public static EntityModel createNonEmptyModel(String detectorId, int sampleSize, for (int i = 0; i < numDataPoints; i++) { trcf.process(new double[] { random.nextDouble() }, i); } - EntityModel entityModel = new EntityModel(entity, samples, trcf); - return entityModel; + return Pair.of(trcf, samples); } - public static EntityModel createNonEmptyModel(String detectorId) { + public static Pair> createNonEmptyModel(String detectorId) { return createNonEmptyModel(detectorId, random.nextInt(minSampleSize), Entity.createSingleAttributeEntity("", "")); } @@ -177,23 +188,27 @@ static double[] getDataD(int num, double amplitude, double noise, long seed) { * Prepare models and return training samples * @param inputDimension Input dimension * @param rcfConfig RCF config + * @param intervalMillis detector interval in milliseconds * @return models and return training samples */ - public static Tuple, ThresholdedRandomCutForest> prepareModel( + public static Tuple, ThresholdedRandomCutForest> prepareModel( int inputDimension, - ThresholdedRandomCutForest.Builder rcfConfig + ThresholdedRandomCutForest.Builder rcfConfig, + long intervalMillis ) { - Queue samples = new ArrayDeque<>(); + Deque samples = new ArrayDeque<>(); Random r = new Random(); ThresholdedRandomCutForest rcf = new ThresholdedRandomCutForest(rcfConfig); int trainDataNum = 1000; + Instant currentTime = Instant.now(); for (int i = 0; i < trainDataNum; i++) { double[] point = r.ints(inputDimension, 0, 50).asDoubleStream().toArray(); - samples.add(point); - rcf.process(point, 0); + samples.add(new Sample(point, currentTime.minusMillis(intervalMillis), currentTime)); + rcf.process(point, currentTime.getEpochSecond()); + currentTime = currentTime.plusMillis(intervalMillis); } return Tuple.tuple(samples, rcf);