Skip to content

Commit

Permalink
feat: force retrain
Browse files Browse the repository at this point in the history
Signed-off-by: Kushal Batra <[email protected]>
  • Loading branch information
s0nicboOm committed Jul 18, 2024
1 parent eb9ceb1 commit 090c490
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 5 deletions.
1 change: 1 addition & 0 deletions numalogic/udfs/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class TrainerPayload(_BasePayload):

metrics: list[str]
header: Header = Header.TRAIN_REQUEST
force_train_req: bool = False

def to_json(self):
return orjson.dumps(self)
Expand Down
2 changes: 1 addition & 1 deletion numalogic/udfs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
metrics=payload.metrics,
)
# Send training request if inference fails
msgs = Messages(get_trainer_message(keys, _stream_conf, payload))
msgs = Messages(get_trainer_message(keys, _stream_conf, payload, _force_train=True))
if _conf.numalogic_conf.score.adjust:
msgs.append(get_static_thresh_message(keys, payload))
return msgs
Expand Down
2 changes: 1 addition & 1 deletion numalogic/udfs/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
payload_metrics=payload.metrics,
)
# Send training request if postprocess fails
msgs = Messages(get_trainer_message(keys, _stream_conf, payload))
msgs = Messages(get_trainer_message(keys, _stream_conf, payload, _force_train=True))
if _conf.numalogic_conf.score.adjust:
msgs.append(get_static_thresh_message(keys, payload))
return msgs
Expand Down
8 changes: 7 additions & 1 deletion numalogic/udfs/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,13 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
status=Status.RUNTIME_ERROR,
)
msgs = Messages(
get_trainer_message(keys, _stream_conf, payload, **_metric_label_values),
get_trainer_message(
keys=keys,
stream_conf=_stream_conf,
payload=payload,
_force_train=True,
**_metric_label_values
),
)
if _conf.numalogic_conf.score.adjust:
msgs.append(get_static_thresh_message(keys, payload))
Expand Down
9 changes: 8 additions & 1 deletion numalogic/udfs/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from pandas import DataFrame
from pynumaflow.mapper import Message


from numalogic.registry import ArtifactManager, ArtifactData
from numalogic.tools.exceptions import RedisRegistryError
from numalogic.tools.types import KEYS, redis_client_t
Expand Down Expand Up @@ -285,6 +284,7 @@ def ack_read(
retry: int = 600,
min_train_records: int = 180,
data_freq: int = 60,
_force_train: bool = False,
) -> bool:
"""
Acknowledge the read message. Return True when the msg has to be trained.
Expand All @@ -295,6 +295,7 @@ def ack_read(
retry: Time difference(in secs) between triggering retraining and msg read_ack.
min_train_records: minimum number of records required for training.
data_freq: data granularity/frequency in secs.
_force_train: force training for the key.
Returns
-------
Expand Down Expand Up @@ -332,6 +333,10 @@ def ack_read(
logger.debug("Model with key is being trained by another process")
return False

if _force_train:
logger.debug("Forcing training for the key")
return True

Check warning on line 338 in numalogic/udfs/tools.py

View check run for this annotation

Codecov / codecov/patch

numalogic/udfs/tools.py#L337-L338

Added lines #L337 - L338 were not covered by tests

# This check is needed if there is backpressure in the pipeline
if _msg_train_ts and time.time() - float(_msg_train_ts) < retrain_freq * 60 * 60:
logger.debug(
Expand Down Expand Up @@ -374,6 +379,7 @@ def get_trainer_message(
keys: list[str],
stream_conf: StreamConf,
payload: StreamPayload,
_force_train: bool = False,
**metric_values: dict,
) -> Message:
"""
Expand All @@ -397,6 +403,7 @@ def get_trainer_message(
metrics=payload.metrics,
config_id=payload.config_id,
pipeline_id=payload.pipeline_id,
force_train_req=_force_train,
)
if metric_values:
_increment_counter(
Expand Down
1 change: 1 addition & 0 deletions numalogic/udfs/trainer/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
retry=retry_ts,
min_train_records=_conf.numalogic_conf.trainer.min_train_size,
data_freq=_conf.numalogic_conf.trainer.data_freq_sec,
_force_train=payload.force_train_req,
):
_increment_counter(
counter="MSG_DROPPED_COUNTER",
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "numalogic"
version = "0.12.4"
version = "0.12.5"
description = "Collection of operational Machine Learning models and tools."
authors = ["Numalogic Developers"]
packages = [{ include = "numalogic" }]
Expand Down

0 comments on commit 090c490

Please sign in to comment.