Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: spark operator label #45353

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from collections.abc import Mapping
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast

from kubernetes.client import CoreV1Api, CustomObjectsApi, models as k8s

Expand Down Expand Up @@ -174,12 +174,7 @@ def create_job_name(self):
return self._set_name(updated_name)

@staticmethod
def _get_pod_identifying_label_string(labels) -> str:
filtered_labels = {label_id: label for label_id, label in labels.items() if label_id != "try_number"}
return ",".join([label_id + "=" + label for label_id, label in sorted(filtered_labels.items())])

@staticmethod
def create_labels_for_pod(context: dict | None = None, include_try_number: bool = True) -> dict:
def _get_ti_pod_labels(context: Context | None = None, include_try_number: bool = True) -> dict[str, str]:
"""
Generate labels for the pod to track the pod in case of Operator crash.

Expand All @@ -190,8 +185,9 @@ def create_labels_for_pod(context: dict | None = None, include_try_number: bool
if not context:
return {}

ti = context["ti"]
run_id = context["run_id"]
context_dict = cast(dict, context)
ti = context_dict["ti"]
run_id = context_dict["run_id"]

labels = {
"dag_id": ti.dag_id,
Expand All @@ -210,8 +206,8 @@ def create_labels_for_pod(context: dict | None = None, include_try_number: bool

# In the case of sub dags this is just useful
# TODO: Remove this when the minimum version of Airflow is bumped to 3.0
if getattr(context["dag"], "is_subdag", False):
labels["parent_dag_id"] = context["dag"].parent_dag.dag_id
if getattr(context_dict["dag"], "is_subdag", False):
labels["parent_dag_id"] = context_dict["dag"].parent_dag.dag_id
# Ensure that label is valid for Kube,
# and if not truncate/remove invalid chars and replace with short hash.
for label_id, label in labels.items():
Expand All @@ -232,9 +228,11 @@ def template_body(self):
"""Templated body for CustomObjectLauncher."""
return self.manage_template_specs()

def find_spark_job(self, context):
labels = self.create_labels_for_pod(context, include_try_number=False)
label_selector = self._get_pod_identifying_label_string(labels) + ",spark-role=driver"
def find_spark_job(self, context, exclude_checked: bool = True):
label_selector = (
self._build_find_pod_label_selector(context, exclude_checked=exclude_checked)
+ ",spark-role=driver"
)
pod_list = self.client.list_namespaced_pod(self.namespace, label_selector=label_selector).items

pod = None
Expand Down
30 changes: 30 additions & 0 deletions providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,36 @@ def test_get_logs_from_driver(
)


def test_find_custom_pod_labels(
    self,
    mock_create_namespaced_crd,
    mock_get_namespaced_custom_object_status,
    mock_cleanup,
    mock_create_job_name,
    mock_get_kube_client,
    mock_create_pod,
    mock_await_pod_start,
    mock_await_pod_completion,
    mock_fetch_requested_container_logs,
    data_file,
):
    """Verify find_spark_job queries pods using the operator's label selector.

    Builds a SparkKubernetesOperator from the spark application YAML
    template, executes it (with the Kubernetes client and pod lifecycle
    fully mocked out), then checks that ``find_spark_job`` lists pods in
    the ``default`` namespace filtered by the selector returned from
    ``_build_find_pod_label_selector`` plus the ``spark-role=driver``
    label that identifies the Spark driver pod.
    """
    task_name = "test_find_custom_pod_labels"
    # Load the spark application CRD spec used as the operator's template.
    job_spec = yaml.safe_load(data_file("spark/application_template.yaml").read_text())

    # Pin the generated job name so labels/selectors are deterministic.
    mock_create_job_name.return_value = task_name
    op = SparkKubernetesOperator(
        template_spec=job_spec,
        kubernetes_conn_id="kubernetes_default_kube_config",
        task_id=task_name,
        get_logs=True,
    )
    context = create_context(op)
    op.execute(context)
    # Expected selector: task-identifying labels + the driver-role label.
    label_selector = op._build_find_pod_label_selector(context) + ",spark-role=driver"
    op.find_spark_job(context)
    # NOTE(review): this asserts on mock_get_kube_client.list_namespaced_pod
    # directly; if the hook's client is mock_get_kube_client.return_value,
    # the call would be recorded there instead — confirm which object the
    # operator actually uses as its client.
    mock_get_kube_client.list_namespaced_pod.assert_called_with("default", label_selector=label_selector)


@pytest.mark.db_test
def test_template_body_templating(create_task_instance_of_operator, session):
ti = create_task_instance_of_operator(
Expand Down