diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/providers/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
index c0b90ebacb9a6..583388c6d14d6 100644
--- a/providers/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
+++ b/providers/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
@@ -19,7 +19,7 @@
 
 from functools import cached_property
 from pathlib import Path
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast
 
 from kubernetes.client import CoreV1Api, CustomObjectsApi, models as k8s
 
@@ -177,12 +177,7 @@ def create_job_name(self):
         return self._set_name(updated_name)
 
     @staticmethod
-    def _get_pod_identifying_label_string(labels) -> str:
-        filtered_labels = {label_id: label for label_id, label in labels.items() if label_id != "try_number"}
-        return ",".join([label_id + "=" + label for label_id, label in sorted(filtered_labels.items())])
-
-    @staticmethod
-    def create_labels_for_pod(context: dict | None = None, include_try_number: bool = True) -> dict:
+    def _get_ti_pod_labels(context: Context | None = None, include_try_number: bool = True) -> dict[str, str]:
         """
         Generate labels for the pod to track the pod in case of Operator crash.
 
@@ -193,8 +188,9 @@ def create_labels_for_pod(context: dict | None = None, include_try_number: bool
         if not context:
             return {}
 
-        ti = context["ti"]
-        run_id = context["run_id"]
+        context_dict = cast(dict, context)
+        ti = context_dict["ti"]
+        run_id = context_dict["run_id"]
 
         labels = {
             "dag_id": ti.dag_id,
@@ -213,8 +209,8 @@ def create_labels_for_pod(context: dict | None = None, include_try_number: bool
 
         # In the case of sub dags this is just useful
         # TODO: Remove this when the minimum version of Airflow is bumped to 3.0
-        if getattr(context["dag"], "is_subdag", False):
-            labels["parent_dag_id"] = context["dag"].parent_dag.dag_id
+        if getattr(context_dict["dag"], "is_subdag", False):
+            labels["parent_dag_id"] = context_dict["dag"].parent_dag.dag_id
         # Ensure that label is valid for Kube,
         # and if not truncate/remove invalid chars and replace with short hash.
         for label_id, label in labels.items():
@@ -235,9 +231,11 @@ def template_body(self):
         """Templated body for CustomObjectLauncher."""
         return self.manage_template_specs()
 
-    def find_spark_job(self, context):
-        labels = self.create_labels_for_pod(context, include_try_number=False)
-        label_selector = self._get_pod_identifying_label_string(labels) + ",spark-role=driver"
+    def find_spark_job(self, context, exclude_checked: bool = True):
+        label_selector = (
+            self._build_find_pod_label_selector(context, exclude_checked=exclude_checked)
+            + ",spark-role=driver"
+        )
         pod_list = self.client.list_namespaced_pod(self.namespace, label_selector=label_selector).items
 
         pod = None
diff --git a/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py b/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py
index 3c4b31893698c..9df12ce4fbdd8 100644
--- a/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py
+++ b/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py
@@ -699,6 +699,35 @@ def test_get_logs_from_driver(
             follow_logs=True,
         )
 
+    def test_find_custom_pod_labels(
+        self,
+        mock_create_namespaced_crd,
+        mock_get_namespaced_custom_object_status,
+        mock_cleanup,
+        mock_create_job_name,
+        mock_get_kube_client,
+        mock_create_pod,
+        mock_await_pod_start,
+        mock_await_pod_completion,
+        mock_fetch_requested_container_logs,
+        data_file,
+    ):
+        task_name = "test_find_custom_pod_labels"
+        job_spec = yaml.safe_load(data_file("spark/application_template.yaml").read_text())
+
+        mock_create_job_name.return_value = task_name
+        op = SparkKubernetesOperator(
+            template_spec=job_spec,
+            kubernetes_conn_id="kubernetes_default_kube_config",
+            task_id=task_name,
+            get_logs=True,
+        )
+        context = create_context(op)
+        op.execute(context)
+        label_selector = op._build_find_pod_label_selector(context) + ",spark-role=driver"
+        op.find_spark_job(context)
+        mock_get_kube_client.list_namespaced_pod.assert_called_with("default", label_selector=label_selector)
+
 
 @pytest.mark.db_test
 def test_template_body_templating(create_task_instance_of_operator, session):
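
Note for reviewers: `_build_find_pod_label_selector` and the renamed `_get_ti_pod_labels` are not defined in this diff; they appear to be shared with the base pod operator. The snippet below is only a rough sketch of the selector-building behaviour the new `find_spark_job` relies on. The function name, the `already_checked` key, and the exact filtering are assumptions for illustration, not code from this patch.

```python
# Illustrative sketch only -- not part of this diff. Assumes the selector is built by
# joining the sorted task-instance labels and optionally excluding pods already marked
# as checked; "already_checked" is a hypothetical key here.
def build_find_pod_label_selector(labels: dict[str, str], exclude_checked: bool = True) -> str:
    selector = ",".join(f"{label_id}={label}" for label_id, label in sorted(labels.items()))
    if exclude_checked:
        selector += ",already_checked!=True"
    return selector


# For the driver pod, find_spark_job appends the Spark role, e.g.:
# build_find_pod_label_selector({"dag_id": "d", "task_id": "t", "run_id": "r"}, exclude_checked=False) + ",spark-role=driver"
# -> "dag_id=d,run_id=r,task_id=t,spark-role=driver"
```

The new test builds the expected selector the same way and asserts that `list_namespaced_pod` is called with it in the `default` namespace.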