From 90808120eaae925d6a3e4506c18e4c4869c2b2f6 Mon Sep 17 00:00:00 2001 From: Furkan Pehlivan <65170388+pehlicd@users.noreply.github.com> Date: Tue, 9 Apr 2024 10:01:26 +0200 Subject: [PATCH] fix: type error in throttle mechanism (#1066) Co-authored-by: Tal Co-authored-by: Shahar Glazner --- keep/step/step.py | 2 +- keep/throttles/base_throttle.py | 2 +- keep/throttles/one_until_resolved_throttle.py | 5 +++-- keep/throttles/throttle_factory.py | 4 ++-- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/keep/step/step.py b/keep/step/step.py index 7653d98e9..7e37d98ba 100644 --- a/keep/step/step.py +++ b/keep/step/step.py @@ -73,7 +73,7 @@ def _check_throttling(self, action_name): throttling_type = throttling.get("type") throttling_config = throttling.get("with") - throttle = ThrottleFactory.get_instance(throttling_type, throttling_config) + throttle = ThrottleFactory.get_instance(self.context_manager, throttling_type, throttling_config) alert_id = self.context_manager.get_workflow_id() return throttle.check_throttling(action_name, alert_id) diff --git a/keep/throttles/base_throttle.py b/keep/throttles/base_throttle.py index 5ecc85bc6..1a6d6379b 100644 --- a/keep/throttles/base_throttle.py +++ b/keep/throttles/base_throttle.py @@ -17,7 +17,7 @@ def __init__( Args: **kwargs: Provider configuration loaded from the provider yaml file. """ - # Initalize logger for every provider + # Initialize logger for every provider self.logger = logging.getLogger(self.__class__.__name__) self.throttle_type = throttle_type self.throttle_config = throttle_config diff --git a/keep/throttles/one_until_resolved_throttle.py b/keep/throttles/one_until_resolved_throttle.py index f2a0830bb..e9c9afcc4 100644 --- a/keep/throttles/one_until_resolved_throttle.py +++ b/keep/throttles/one_until_resolved_throttle.py @@ -1,4 +1,5 @@ from keep.throttles.base_throttle import BaseThrottle +from keep.contextmanager.contextmanager import ContextManager class OneUntilResolvedThrottle(BaseThrottle): @@ -8,8 +9,8 @@ class OneUntilResolvedThrottle(BaseThrottle): BaseThrottle (_type_): _description_ """ - def __init__(self, throttle_type, throttle_config): - super().__init__(throttle_type, throttle_config) + def __init__(self, context_manager: ContextManager, throttle_type, throttle_config): + super().__init__(context_manager=context_manager, throttle_type=throttle_type, throttle_config=throttle_config) def check_throttling(self, action_name, alert_id, **kwargs) -> bool: last_alert_run = self.context_manager.get_last_workflow_run(alert_id) diff --git a/keep/throttles/throttle_factory.py b/keep/throttles/throttle_factory.py index 8f8bb4828..bb05a51f3 100644 --- a/keep/throttles/throttle_factory.py +++ b/keep/throttles/throttle_factory.py @@ -5,9 +5,9 @@ class ThrottleFactory: @staticmethod - def get_instance(throttle_type, throttle_config) -> BaseThrottle: + def get_instance(context_manager, throttle_type, throttle_config) -> BaseThrottle: module = importlib.import_module(f"keep.throttles.{throttle_type}_throttle") throttle_class = getattr( module, throttle_type.title().replace("_", "") + "Throttle" ) - return throttle_class(throttle_type, throttle_config) + return throttle_class(context_manager, throttle_type, throttle_config)