diff --git a/numalogic/config/_config.py b/numalogic/config/_config.py index 67f37eb9..a37469e7 100644 --- a/numalogic/config/_config.py +++ b/numalogic/config/_config.py @@ -157,7 +157,7 @@ class NumalogicConf: trainer: TrainerConf = field(default_factory=TrainerConf) preprocess: list[ModelInfo] = field(default_factory=list) threshold: ModelInfo = field(default_factory=lambda: ModelInfo(name="StdDevThreshold")) - postprocess: ModelInfo = field( + postprocess: Optional[ModelInfo] = field( default_factory=lambda: ModelInfo(name="TanhNorm", stateful=False) ) score: ScoreConf = field(default_factory=lambda: ScoreConf()) diff --git a/numalogic/udfs/postprocess.py b/numalogic/udfs/postprocess.py index a8342a76..f891efa9 100644 --- a/numalogic/udfs/postprocess.py +++ b/numalogic/udfs/postprocess.py @@ -122,7 +122,11 @@ def exec(self, keys: list[str], datum: Datum) -> Messages: load_latest=LOAD_LATEST, vertex=self._vtx, ) - postproc_tx = self.postproc_factory.get_instance(postprocess_cfg) + postproc_tx = ( + self.postproc_factory.get_instance(postprocess_cfg) if postprocess_cfg else None + ) + if not postproc_tx: + logger.info("Postprocess model is absent!") if thresh_artifact is None: payload = replace( diff --git a/pyproject.toml b/pyproject.toml index 82af7bec..58978327 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "numalogic" -version = "0.9.1a8" +version = "0.9.1a9" description = "Collection of operational Machine Learning models and tools." authors = ["Numalogic Developers"] packages = [{ include = "numalogic" }] diff --git a/tests/connectors/test_druid.py b/tests/connectors/test_druid.py index 74311f0d..471a4a78 100644 --- a/tests/connectors/test_druid.py +++ b/tests/connectors/test_druid.py @@ -109,7 +109,7 @@ def group_by(*_, **__): "env": "prod", "status": 200, "http_status": "2xx", - "count": 20 + "count": 20, }, "timestamp": "2023-09-06T07:50:00.000Z", "version": "v1", @@ -120,8 +120,7 @@ def group_by(*_, **__): "env": "prod", "status": 500, "http_status": "5xx", - "count": 10 - + "count": 10, }, "timestamp": "2023-09-06T07:53:00.000Z", "version": "v1", diff --git a/tests/udfs/test_postprocess.py b/tests/udfs/test_postprocess.py index e6979f12..e0b5e853 100644 --- a/tests/udfs/test_postprocess.py +++ b/tests/udfs/test_postprocess.py @@ -14,6 +14,7 @@ from numalogic._constants import TESTS_DIR from numalogic.models.threshold import StdDevThreshold from numalogic.registry import RedisRegistry, ArtifactData +from numalogic.transforms import TanhNorm from numalogic.udfs import PipelineConf from numalogic.udfs.entities import Header, TrainerPayload, Status, OutputPayload from numalogic.udfs.postprocess import PostprocessUDF @@ -172,7 +173,15 @@ def test_postprocess_runtime_err_02(udf, mocker, bad_artifact): assert msgs[1].tags == ["staticthresh"] -def test_compute(udf, artifact): +def test_compute_without_postproc(udf, artifact): y_unified, x_inferred = udf.compute(artifact.artifact, np.asarray(DATA["data"])) assert isinstance(y_unified, float) assert x_inferred.shape == (2,) + + +def test_compute_with_postproc(udf, artifact): + y_unified, x_inferred = udf.compute( + artifact.artifact, np.asarray(DATA["data"]), postproc_tx=TanhNorm() + ) + assert isinstance(y_unified, float) + assert x_inferred.shape == (2,)