From 9ad6d8ddd168093a771a41a13432867efff99360 Mon Sep 17 00:00:00 2001
From: Avik Basu
Date: Tue, 23 Apr 2024 13:25:55 -0700
Subject: [PATCH] feat: fallback to stddev if threshold is too low

Signed-off-by: Avik Basu
---
Notes (not part of the commit message):

* MaxPercentileThreshold gains an opt-in `adjust_threshold` flag
  (default False), stored in the new `_adjust_threshold` slot.
* With the flag set, fit() inspects each column's percentile threshold;
  when it is below 1e-2 times `min_threshold`, that column's threshold
  is replaced by mean + 3 * std of the column. The pre-existing clamp to
  `min_threshold` still runs afterwards.
* Review fixes in this revision: the log message had the comparison
  reversed and said "mean" although mean + 3 * std is assigned; the
  check is written as a product so that `min_threshold=0` does not
  divide by zero. The post-image blob id on the `index` line of
  _median.py predates these fixes.
* pyproject.toml: version 0.9.1a3 -> 0.9.1a4.

 numalogic/models/threshold/_median.py | 18 +++++++++++++++++-
 pyproject.toml                        |  2 +-
 2 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/numalogic/models/threshold/_median.py b/numalogic/models/threshold/_median.py
index ed7b806e..6caf49fe 100644
--- a/numalogic/models/threshold/_median.py
+++ b/numalogic/models/threshold/_median.py
@@ -5,6 +5,9 @@
 from numalogic.base import BaseThresholdModel
 from numalogic.tools.exceptions import InvalidDataShapeError, ModelInitializationError
 
+import logging
+
+LOGGER = logging.getLogger(__name__)
 _INLIER: Final[int] = 0
 _OUTLIER: Final[int] = 1
 _INPUT_DIMS: Final[int] = 2
@@ -19,18 +22,20 @@ class MaxPercentileThreshold(BaseThresholdModel):
         min_threshold: Value to be used if threshold is less than this
     """
 
-    __slots__ = ("_max_percentile", "_min_thresh", "_thresh", "_is_fitted")
+    __slots__ = ("_max_percentile", "_min_thresh", "_thresh", "_is_fitted", "_adjust_threshold")
 
     def __init__(
         self,
         max_inlier_percentile: float = 96.0,
         min_threshold: float = 1e-4,
+        adjust_threshold: bool = False,
     ):
         super().__init__()
         self._max_percentile = max_inlier_percentile
         self._min_thresh = min_threshold
         self._thresh = None
         self._is_fitted = False
+        self._adjust_threshold = adjust_threshold
 
     @property
     def threshold(self):
@@ -45,6 +50,17 @@ def _validate_input(x: npt.NDArray[float]) -> None:
     def fit(self, x: npt.NDArray[float]) -> Self:
         self._validate_input(x)
         self._thresh = np.percentile(x, self._max_percentile, axis=0)
+
+        if self._adjust_threshold:
+            for idx, _ in enumerate(self._thresh):
+                if self._thresh[idx] < 1e-2 * self._min_thresh:
+                    LOGGER.info(
+                        "Threshold is less than 1e-2 times the min "
+                        "threshold for column %s; Using mean + 3 * std instead.",
+                        idx,
+                    )
+                    self._thresh[idx] = np.mean(x[:, idx]) + (3 * np.std(x[:, idx]))
+
         self._thresh[self._thresh < self._min_thresh] = self._min_thresh
         self._is_fitted = True
         return self
diff --git a/pyproject.toml b/pyproject.toml
index c7bd5479..d80a3874 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "numalogic"
-version = "0.9.1a3"
+version = "0.9.1a4"
 description = "Collection of operational Machine Learning models and tools."
 authors = ["Numalogic Developers"]
 packages = [{ include = "numalogic" }]