Skip to content

Commit

Permalink
fix: boto3 import
Browse files Browse the repository at this point in the history
Signed-off-by: Avik Basu <[email protected]>
  • Loading branch information
ab93 committed May 9, 2024
1 parent b20f343 commit 83633ed
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 42 deletions.
6 changes: 5 additions & 1 deletion numalogic/connectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
RDSConf,
RDSFetcherConf,
)
from numalogic.connectors.rds import RDSFetcher
from numalogic.connectors.prometheus import PrometheusFetcher

__all__ = [
Expand All @@ -26,6 +25,11 @@
"RDSFetcherConf",
]

if find_spec("boto3"):
from numalogic.connectors.rds import RDSFetcher # noqa: F401

__all__.append("RDSFetcher")

if find_spec("pydruid"):
from numalogic.connectors.druid import DruidFetcher # noqa: F401

Expand Down
10 changes: 5 additions & 5 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "numalogic"
version = "0.10.0a0"
version = "0.10.0a1"
description = "Collection of operational Machine Learning models and tools."
authors = ["Numalogic Developers"]
packages = [{ include = "numalogic" }]
Expand Down
65 changes: 30 additions & 35 deletions tests/udfs/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
import logging
import os
import unittest
from datetime import datetime
from fakeredis import FakeServer, FakeStrictRedis
from omegaconf import OmegaConf
from orjson import orjson
import pytest

from numalogic._constants import TESTS_DIR
from numalogic.udfs._config import PipelineConf
from numalogic.udfs import PipelineConf
from numalogic.udfs.payloadtx import PayloadTransformer
from tests.udfs.utility import input_json_from_file

logging.basicConfig(level=logging.DEBUG)
REDIS_CLIENT = FakeStrictRedis(server=FakeServer())
KEYS = ["service-mesh", "1", "2"]
DATUM = input_json_from_file(os.path.join(TESTS_DIR, "udfs", "resources", "data", "stream.json"))

Expand All @@ -22,34 +20,31 @@
}


class TestPipelineUDF(unittest.TestCase):
def setUp(self) -> None:
_given_conf = OmegaConf.load(os.path.join(TESTS_DIR, "udfs", "resources", "_config.yaml"))
_given_conf_2 = OmegaConf.load(
os.path.join(TESTS_DIR, "udfs", "resources", "_config2.yaml")
)
schema = OmegaConf.structured(PipelineConf)
pl_conf = PipelineConf(**OmegaConf.merge(schema, _given_conf))
pl_conf_2 = PipelineConf(**OmegaConf.merge(schema, _given_conf_2))
self.udf1 = PayloadTransformer(pl_conf=pl_conf)
self.udf2 = PayloadTransformer(pl_conf=pl_conf_2)
self.udf1.register_conf("druid-config", pl_conf.stream_confs["druid-config"])
self.udf2.register_conf("druid-config", pl_conf_2.stream_confs["druid-config"])

def test_pipeline_1(self):
msgs = self.udf1(KEYS, DATUM)
self.assertEqual(2, len(msgs))
for msg in msgs:
data_payload = orjson.loads(msg.value)
self.assertTrue(data_payload["pipeline_id"])

def test_pipeline_2(self):
msgs = self.udf2(KEYS, DATUM)
self.assertEqual(1, len(msgs))
for msg in msgs:
data_payload = orjson.loads(msg.value)
self.assertTrue(data_payload["pipeline_id"])


if __name__ == "__main__":
unittest.main()
@pytest.fixture
def setup():
_given_conf = OmegaConf.load(os.path.join(TESTS_DIR, "udfs", "resources", "_config.yaml"))
_given_conf_2 = OmegaConf.load(os.path.join(TESTS_DIR, "udfs", "resources", "_config2.yaml"))
schema = OmegaConf.structured(PipelineConf)
pl_conf = PipelineConf(**OmegaConf.merge(schema, _given_conf))
pl_conf_2 = PipelineConf(**OmegaConf.merge(schema, _given_conf_2))
udf1 = PayloadTransformer(pl_conf=pl_conf)
udf2 = PayloadTransformer(pl_conf=pl_conf_2)
udf1.register_conf("druid-config", pl_conf.stream_confs["druid-config"])
udf2.register_conf("druid-config", pl_conf_2.stream_confs["druid-config"])
return udf1, udf2


def test_pipeline_1(setup):
msgs = setup[0](KEYS, DATUM)
assert 2 == len(msgs)
for msg in msgs:
data_payload = orjson.loads(msg.value)
assert data_payload["pipeline_id"]


def test_pipeline_2(setup):
msgs = setup[1](KEYS, DATUM)
assert 1 == len(msgs)
for msg in msgs:
data_payload = orjson.loads(msg.value)
assert data_payload["pipeline_id"]

0 comments on commit 83633ed

Please sign in to comment.