Skip to content

Commit

Permalink
Merge pull request #553 from armosec/fix_verify_pods
Browse files Browse the repository at this point in the history
Fix verify pods and jira rate limit handling
  • Loading branch information
kooomix authored Jan 1, 2025
2 parents fa67fab + 70b8144 commit ef23e89
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 49 deletions.
71 changes: 57 additions & 14 deletions infrastructure/backend_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1965,13 +1965,63 @@ def post(self, url, **args):
if not url.startswith("http://") and not url.startswith("https://"):
url = self.server + url
return requests.post(url, **args)

def post_with_ratelimit(self, url, **args):
    """POST through self.post, retrying when the backend signals rate limiting.

    Keyword overrides (popped from **args before the request is sent):
        rate_limit_retries: total number of attempts to make (default 1).
        rate_limit_sleep: seconds to wait between attempts (default 45).

    Returns the last requests.Response received (the first non-rate-limited
    one, or the final attempt's response once retries are exhausted).
    """
    retries = args.pop("rate_limit_retries", 1)
    sleep_seconds = args.pop("rate_limit_sleep", 45)

    for attempt in range(1, retries + 1):
        response = self.post(url, **args)

        # Rate limiting shows up either as HTTP 429 or as a "RetryAfter"
        # marker in the response body.
        rate_limited = response.status_code == 429 or "retryafter" in response.text.lower()
        if not rate_limited:
            return response

        Logger.logger.debug(
            f"Rate limit reached for URL: {url}. Attempt {attempt} of {retries}. "
            f"Retrying in {sleep_seconds} seconds."
        )
        if attempt < retries:
            time.sleep(sleep_seconds)
        else:
            Logger.logger.warning(
                f"Rate limit retries exhausted for URL: {url}. Returning last response."
            )

    # Loop finished without a non-rate-limited response: hand back the last one.
    return response

@deco_cookie
def get(self, url, **args):
    """GET *url*; a relative path is resolved against self.server first."""
    if not (url.startswith("http://") or url.startswith("https://")):
        url = self.server + url
    return requests.get(url, **args)

def get_with_rate_limit(self, url, **args):
    """GET through self.get, retrying when the backend signals rate limiting.

    Keyword overrides (popped from **args before the request is sent):
        rate_limit_retries: total number of attempts to make (default 1).
        rate_limit_sleep: seconds to wait between attempts (default 45).

    Returns the last requests.Response received (the first non-rate-limited
    one, or the final attempt's response once retries are exhausted).
    """
    retries = args.pop("rate_limit_retries", 1)
    sleep_seconds = args.pop("rate_limit_sleep", 45)

    for attempt in range(1, retries + 1):
        response = self.get(url, **args)

        # Rate limiting shows up either as HTTP 429 or as a "RetryAfter"
        # marker in the response body.
        rate_limited = response.status_code == 429 or "retryafter" in response.text.lower()
        if not rate_limited:
            return response

        Logger.logger.debug(
            f"Rate limit reached for URL: {url}. Attempt {attempt} of {retries}. "
            f"Retrying in {sleep_seconds} seconds."
        )
        if attempt < retries:
            time.sleep(sleep_seconds)
        else:
            Logger.logger.warning(
                f"Rate limit retries exhausted for URL: {url}. Returning last response."
            )

    # Loop finished without a non-rate-limited response: hand back the last one.
    return response

@deco_cookie
def put(self, url, **args):
    """PUT to *url*.

    Fix: unlike the sibling get()/post() methods, this always prefixed
    self.server, which mangles callers that pass an absolute URL. Now a
    relative path is resolved against self.server, matching get/post;
    behavior for relative URLs (the existing callers) is unchanged.
    """
    if not url.startswith("http://") and not url.startswith("https://"):
        url = self.server + url
    return requests.put(url, **args)
Expand Down Expand Up @@ -2886,7 +2936,7 @@ def get_integration_status(self, provider: str):

def get_jira_config(self):
    """Fetch the tenant's Jira configuration (configV2) and return it as JSON.

    Uses the rate-limit-aware GET wrapper; the leftover pre-commit
    `self.get(...)` line from the pasted diff is removed so the request
    is issued exactly once.
    """
    url = API_INTEGRATIONS + "/jira/configV2"
    r = self.get_with_rate_limit(url, params={"customerGUID": self.selected_tenant_id})
    assert 200 <= r.status_code < 300, f"{inspect.currentframe().f_code.co_name}, url: '{url}', customer: '{self.customer}' code: {r.status_code}, message: '{r.text}'"
    return r.json()

Expand All @@ -2909,18 +2959,15 @@ def get_jira_collaboration_guid_by_site_name(self, site_name: str):

def update_jira_config(self, body: dict):
    """POST an updated Jira configuration (configV2) with rate-limit retries.

    Fixes: (1) removes the duplicated pre-commit `self.post(...)` request
    lines left over from the pasted diff so only one request is issued;
    (2) corrects the error message, which was copy-pasted from a smart
    remediation / posture test and did not describe this endpoint.

    Raises Exception when the backend answers outside the 2xx range.
    """
    url = API_INTEGRATIONS + "/jira/configV2"
    r = self.post_with_ratelimit(url, params={"customerGUID": self.selected_tenant_id}, json=body)
    if not 200 <= r.status_code < 300:
        raise Exception(
            'Error updating jira config for customer "%s" (code: %d, message: %s)' % (
                self.customer, r.status_code, r.text))

def search_jira_projects(self, body: dict):
url = API_INTEGRATIONS + "/jira/projectsV2/search"
r = self.post(url, params={"customerGUID": self.customer_guid},
json=body)
r = self.post_with_ratelimit(url, params={"customerGUID": self.selected_tenant_id}, json=body)
if not 200 <= r.status_code < 300:
raise Exception(
'Error accessing dashboard. Request to: %s "%s" (code: %d, message: %s)' % (
Expand All @@ -2929,8 +2976,7 @@ def search_jira_projects(self, body: dict):

def search_jira_issue_types(self, body: dict):
url = API_INTEGRATIONS + "/jira/issueTypesV2/search"
r = self.post(url, params={"customerGUID": self.customer_guid},
json=body)
r = self.post_with_ratelimit(url, params={"customerGUID": self.selected_tenant_id}, json=body)
if not 200 <= r.status_code < 300:
raise Exception(
'Error accessing dashboard. Request to: %s "%s" (code: %d, message: %s)' % (
Expand All @@ -2939,8 +2985,7 @@ def search_jira_issue_types(self, body: dict):

def search_jira_schema(self, body: dict):
url = API_INTEGRATIONS + "/jira/issueTypesV2/schema/search"
r = self.post(url, params={"customerGUID": self.customer_guid},
json=body)
r = self.post_with_ratelimit(url, params={"customerGUID": self.selected_tenant_id}, json=body)
if not 200 <= r.status_code < 300:
raise Exception(
'Error accessing dashboard. Request to: %s "%s" (code: %d, message: %s)' % (
Expand All @@ -2949,8 +2994,7 @@ def search_jira_schema(self, body: dict):

def search_jira_issue_field(self, body: dict):
url = API_INTEGRATIONS + "jira/issueTypes/fields/search"
r = self.post(url, params={"customerGUID": self.customer_guid},
json=body)
r = self.post_with_ratelimit(url, params={"customerGUID": self.selected_tenant_id}, json=body)
if not 200 <= r.status_code < 300:
raise Exception(
'Error accessing dashboard. Request to: %s "%s" (code: %d, message: %s)' % (
Expand All @@ -2959,8 +3003,7 @@ def search_jira_issue_field(self, body: dict):

def create_jira_issue(self, body: dict):
url = API_INTEGRATIONS + "/jira/issueV2"
r = self.post(url, params={"customerGUID": self.customer_guid},
json=body)
r = self.post_with_ratelimit(url, params={"customerGUID": self.selected_tenant_id}, json=body)
if not 200 <= r.status_code < 300:
raise Exception(
'Error accessing dashboard. Request to: %s "%s" (code: %d, message: %s)' % (
Expand Down
38 changes: 22 additions & 16 deletions tests_scripts/helm/jira_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,25 +144,23 @@ def setup_cluster_and_run_posture_scan(self):
cluster_name=cluster, wait_to_result=True, framework_name="AllControls"
)
assert report_guid != "", "report guid is empty"
self.report_guid = report_guid


# to make sure kubernetes resources are created
time.sleep(20)
Logger.logger.info(f"Trigger posture scan")
self.backend.trigger_posture_scan(cluster)

report_guid_new = self.get_report_guid(
cluster_name=cluster, wait_to_result=True, framework_name="AllControls", old_report_guid=report_guid
)
self.report_guid = report_guid_new
self.namespace = namespace
self.cluster = cluster

def create_jira_issue(self, issue, retries=3, sleep=45):
    """Create a Jira issue through the backend client.

    The local RetryAfter retry loop was removed in this commit — rate-limit
    handling now lives in the backend's post_with_ratelimit — but the pasted
    diff left both the old loop and the new one-line delegation in place.
    This keeps only the committed version. *retries* and *sleep* are retained
    solely for signature backward compatibility; they are no longer used.
    """
    return self.backend.create_jira_issue(issue)


def create_jira_issue_for_posture(self):
resource = self.get_posture_resource()
Expand Down Expand Up @@ -209,7 +207,15 @@ def get_posture_resource(self):

def create_jira_issue_for_security_risks(self):
security_risk_id = "R_0011"
resource = self.get_security_risks_resource(security_risk_id)

resource, t = self.wait_for_report(
report_type=self.get_security_risks_resource,
timeout=220,
sleep_interval=10,
security_risk_id=security_risk_id,
)

# resource = self.get_security_risks_resource(security_risk_id)
resourceHash = resource['k8sResourceHash']

Logger.logger.info(f"Create Jira issue for resource {resourceHash} and security_risk_id {security_risk_id}")
Expand Down
58 changes: 39 additions & 19 deletions tests_scripts/kubernetes/base_k8s.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,23 @@ def get_nodes(self):
def get_all_pods(self):
    """Return the pod list for every namespace from the CoreV1 API client."""
    core_api = self.kubernetes_obj.client_CoreV1Api
    return core_api.list_pod_for_all_namespaces()

def get_all_pods_printable_details(self):
    """Build a one-line-per-pod summary (name, namespace, phase) for logging."""
    lines = [
        "Pod name: {0}, namespace: {1}, status: {2}\n".format(
            pod.metadata.name, pod.metadata.namespace, pod.status.phase)
        for pod in self.get_all_pods().items
    ]
    return "".join(lines)

def get_all_not_running_pods_describe_details(self):
    """Dump full details (including the pod object) for every pod whose phase is not Running."""
    parts = []
    for pod in self.get_all_pods().items:
        if pod.status.phase == "Running":
            continue  # only non-running pods are of interest here
        parts.append(
            "Pod name: {0}, namespace: {1}, status: {2}, pod: {3}\n".format(
                pod.metadata.name, pod.metadata.namespace, pod.status.phase, pod))
    return "".join(parts)

def get_pods(self, namespace: str = None, name: str = None, include_terminating: bool = True, wlid: str = None):
"""
:return: list of running pods
Expand All @@ -634,6 +651,9 @@ def get_pods(self, namespace: str = None, name: str = None, include_terminating:
if pods is None:
return []

# Safeguard: Explicit namespace filter
pods = [pod for pod in pods if pod.metadata.namespace == namespace]

if name:
if isinstance(name, str):
pods = [pod for pod in pods if name in pod.metadata.name]
def get_ready_pods(self, namespace, name: str = None):
    """
    :return: list of running pods with all containers ready

    The pasted diff contained both the old filter/lambda implementation and
    the committed list-comprehension one; this keeps only the committed
    version (a pod with no container_statuses — None or [] — counts as
    ready, matching the old `not any(ready is False)` semantics for the
    empty case).
    """
    pods = self.get_pods(namespace=namespace, name=name)

    # Safeguard: drop any pod that leaked in from another namespace.
    pods = [pod for pod in pods if pod.metadata.namespace == namespace]

    ready_pods = [
        pod for pod in pods
        if all(container.ready for container in (pod.status.container_statuses or []))
    ]
    return ready_pods

def restart_pods(self, wlid=None, namespace: str = None, name: str = None):
Expand Down Expand Up @@ -750,7 +777,7 @@ def verify_all_pods_are_running(self, workload, namespace: str, timeout=180):
timeout=timeout)
return replicas

def verify_running_pods(self, namespace: str, replicas: int = None, name: str = None, timeout=180,
def verify_running_pods(self, namespace: str, replicas: int = None, name: str = None, timeout=220,
comp_operator=operator.eq):
"""
compare number of expected running pods with actually running pods
Expand All @@ -767,13 +794,8 @@ def verify_running_pods(self, namespace: str, replicas: int = None, name: str =
running_pods = self.get_ready_pods(namespace=namespace, name=name)
if comp_operator(len(running_pods), replicas): # and len(running_pods) == len(total_pods):
Logger.logger.info(f"all pods are running after {delta_t} seconds")
result = subprocess.run("kubectl get pods -A", timeout=300, shell=True, text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
result = " ".join(result.stdout.splitlines())
Logger.logger.info(
"cluster state\n"
f"{result}"
)

all_pods_message = self.get_all_pods_printable_details()
Logger.logger.info(f"cluster states:\n{all_pods_message}")
return
delta_t = (datetime.now() - start).total_seconds()
time.sleep(10)
Expand All @@ -783,14 +805,12 @@ def verify_running_pods(self, namespace: str, replicas: int = None, name: str =
format(timeout,
KubectlWrapper.convert_workload_to_dict(non_running_pods, f_json=True, indent=2)))

result = subprocess.run("kubectl get pods -A", timeout=300, shell=True, text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
result = " ".join(result.stdout.splitlines())
Logger.logger.info(
"cluster state\n"
f"{result}"
)
raise Exception("wrong number of pods are running after {} seconds. expected: {}, running: {}, pods:{}"
.format(delta_t, replicas, len(running_pods), running_pods)) # , len(total_pods)))
all_pods_message = self.get_all_pods_printable_details()
Logger.logger.info(f"cluster states:\n{all_pods_message}")
not_running_pods_message = self.get_all_not_running_pods_describe_details()
Logger.logger.info(f"not running pods details:\n{not_running_pods_message}")
raise Exception("wrong number of pods are running after {} seconds. expected: {}, running: {}"
.format(delta_t, replicas, len(running_pods))) # , len(total_pods)))

def is_namespace_running(self, namespace):
for ns in self.kubernetes_obj.client_CoreV1Api.list_namespace().items:
Expand Down

0 comments on commit ef23e89

Please sign in to comment.