diff --git a/README.md b/README.md
index b737655..d7dd241 100755
--- a/README.md
+++ b/README.md
@@ -85,6 +85,7 @@ legacy_job_task = KubernetesLegacyJobOperator(
 NOTE: the success/failure of the task is tracked only on the `first` resource, regardless of its kind. Currently native support exists for Pods and Jobs only, though you can always add a [custom resource](docs/custom_kinds.md).
 
 ```yaml
+# First resource: this resource will be tracked by the operator. Other resources will not be tracked.
 apiVersion: batch/v1
 kind: Job
 metadata:
@@ -121,6 +122,15 @@ spec:
       targetPort: 8080
 ```
 
+### YAML field augmentation
+
+The operator automatically augments the following fields on the containers of the first resource:
+1. envs are added to all containers.
+1. special envs are added:
+   1. KUBERNETES_JOB_OPERATOR_RESOURCES - a space-delimited list of the names of all resources in the template.
+
+All other overrides (command, args, image, imagePullPolicy) are applied to the main container (the first container) of the first resource only.
+
 # Configuration
 
 Airflow config extra sections,
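The env merge order that the README section above describes, and that `_update_container_yaml` implements in the change below, can be seen in isolation. A minimal, standalone sketch; the helper names here are illustrative only, not part of the package:

```python
from typing import Dict, List


def to_kubernetes_env_list(envs: Dict[str, str]) -> List[dict]:
    # {"A": "1"} -> [{"name": "A", "value": "1"}]
    return [{"name": k, "value": str(v)} for k, v in envs.items()]


def augment_container_env(
    container: dict,
    resource_names: List[str],
    user_envs: Dict[str, str],
) -> None:
    # Same ordering as the operator change below: injected envs first,
    # the template's own envs next, user-supplied envs last.
    container["env"] = [
        *to_kubernetes_env_list(
            {"KUBERNETES_JOB_OPERATOR_RESOURCES": " ".join(resource_names)}
        ),
        *container.get("env", []),
        *to_kubernetes_env_list(user_envs),
    ]


container = {"name": "main", "env": [{"name": "TIC_COUNT", "value": "5"}]}
augment_container_env(container, ["thejob", "theservice"], {"MY_ENV": "x"})
print([e["name"] for e in container["env"]])
# ['KUBERNETES_JOB_OPERATOR_RESOURCES', 'TIC_COUNT', 'MY_ENV']
```

User-supplied envs are appended last, so they appear after any template-defined entry of the same name in the rendered container spec.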
diff --git a/airflow_kubernetes_job_operator/kubernetes_job_operator.py b/airflow_kubernetes_job_operator/kubernetes_job_operator.py
index 3a64ff2..0c6c577 100755
--- a/airflow_kubernetes_job_operator/kubernetes_job_operator.py
+++ b/airflow_kubernetes_job_operator/kubernetes_job_operator.py
@@ -189,28 +189,52 @@ def _create_job_name(cls, name):
             max_length=DEFAULT_KUBERNETES_MAX_RESOURCE_NAME_LENGTH,
         )
 
+    @classmethod
+    def _to_kubernetes_env_list(cls, envs: dict):
+        return [{"name": k, "value": str(v)} for k, v in envs.items()]
+
     def _get_kubernetes_env_list(self):
-        return [{"name": k, "value": f"{self.envs[k]}"} for k in self.envs.keys()]
+        return self._to_kubernetes_env_list(self.envs or {})
+
+    def _get_kubernetes_job_operator_envs(self):
+        body = self.job_runner.body
+        names = [r.get("metadata", {}).get("name", None) for r in body]
+        names = [n for n in names if n is not None]
+        return self._to_kubernetes_env_list(
+            {
+                "KUBERNETES_JOB_OPERATOR_RESOURCES": " ".join(names),
+            }
+        )
 
-    def update_override_params(self, o: dict):
+    def _update_container_yaml(self, container: dict):
+        container["env"] = [
+            *self._get_kubernetes_job_operator_envs(),
+            *container.get("env", []),
+            *self._get_kubernetes_env_list(),
+        ]
+
+    def _update_main_container_yaml(self, container: dict):
+        if self.command:
+            container["command"] = self.command
+        if self.arguments:
+            container["args"] = self.arguments
+        if self.image:
+            container["image"] = self.image
+        if self.image_pull_policy:
+            container["imagePullPolicy"] = self.image_pull_policy
+
+    def _update_override_params(self, o: dict):
         if "spec" in o and "containers" in o.get("spec", {}):
             containers: List[dict] = o["spec"]["containers"]
             if isinstance(containers, list) and len(containers) > 0:
-                main_container = containers[0]
-                if self.command:
-                    main_container["command"] = self.command
-                if self.arguments:
-                    main_container["args"] = self.arguments
-                if self.envs:
-                    env_list = [*main_container.get("env", []), *self._get_kubernetes_env_list()]
-                    main_container["env"] = env_list
-                if self.image:
-                    main_container["image"] = self.image
-                if self.image_pull_policy:
-                    main_container["imagePullPolicy"] = self.image_pull_policy
+                for container in containers:
+                    self._update_container_yaml(container=container)
+                self._update_main_container_yaml(containers[0])
 
+        # recurse into nested objects to find more container specs
         for c in o.values():
             if isinstance(c, dict):
-                self.update_override_params(c)
+                self._update_override_params(c)
 
     def _validate_job_runner(self):
         if self._job_runner is not None:
@@ -270,7 +294,7 @@ def pre_execute(self, context):
         self.prepare_and_update_body()
 
         # write override params
-        self.update_override_params(self.job_runner.body[0])
+        self._update_override_params(self.job_runner.body[0])
 
         # call parent.
         return super().pre_execute(context)
diff --git a/tests/dags/templates/test_job_with_service.yaml b/tests/dags/templates/test_job_with_service.yaml
index f371839..f580f96 100755
--- a/tests/dags/templates/test_job_with_service.yaml
+++ b/tests/dags/templates/test_job_with_service.yaml
@@ -1,6 +1,7 @@
 apiVersion: batch/v1
 kind: Job
-metadata: {}
+metadata:
+  name: thejob
 spec:
   template:
     spec:
@@ -33,7 +34,7 @@
 apiVersion: v1
 kind: Service
 metadata:
-  name: test-service
+  name: theservice
 spec:
   selector:
     app: myapp
diff --git a/tests/dags/test_job_operator_jinja.py b/tests/dags/test_job_operator_jinja.py
index c2f5053..6fb582e 100755
--- a/tests/dags/test_job_operator_jinja.py
+++ b/tests/dags/test_job_operator_jinja.py
@@ -40,30 +40,30 @@
     jinja_job_args={"test": "lama"},
 )
 
-bash_script = """
-#/usr/bin/env bash
-echo "Legacy start for taskid {{ti.task_id}} {{job.test}}"
-cur_count=0
-while true; do
-    cur_count=$((cur_count + 1))
-    if [ "$cur_count" -ge "$TIC_COUNT" ]; then
-        break
-    fi
-    date
-    sleep 1
-done
+# bash_script = """
+# #/usr/bin/env bash
+# echo "Legacy start for taskid {{ti.task_id}} {{job.test}}"
+# cur_count=0
+# while true; do
+#     cur_count=$((cur_count + 1))
+#     if [ "$cur_count" -ge "$TIC_COUNT" ]; then
+#         break
+#     fi
+#     date
+#     sleep 1
+# done
 
-echo "Complete"
-"""
-KubernetesLegacyJobOperator(
-    task_id="legacy-test-job-success",
-    image="{{default_image}}",
-    cmds=["bash", "-c", bash_script],
-    dag=dag,
-    is_delete_operator_pod=True,
-    env_vars=envs,
-    delete_policy=default_delete_policy,
-)
+# echo "Complete"
+# """
+# KubernetesLegacyJobOperator(
+#     task_id="legacy-test-job-success",
+#     image="{{default_image}}",
+#     cmds=["bash", "-c", bash_script],
+#     dag=dag,
+#     is_delete_operator_pod=True,
+#     env_vars=envs,
+#     delete_policy=default_delete_policy,
+# )
 
 if __name__ == "__main__":
     dag.clear(reset_dag_runs=True)
diff --git a/tests/dags/test_job_operator_with_service.py b/tests/dags/test_job_operator_with_service.py
new file mode 100755
index 0000000..24f3859
--- /dev/null
+++ b/tests/dags/test_job_operator_with_service.py
@@ -0,0 +1,29 @@
+from utils import default_args
+from airflow import DAG
+from airflow_kubernetes_job_operator import (
+    KubernetesJobOperator,
+    JobRunnerDeletePolicy,
+)
+
+dag = DAG(
+    "kub-job-op-test-with-service",
+    default_args=default_args,
+    description="Test job operator with a service resource",
+    schedule_interval=None,
+    catchup=False,
+)
+
+with dag:
+    namespace = None
+    default_delete_policy = JobRunnerDeletePolicy.Never
+
+    KubernetesJobOperator(
+        task_id="test-job-success",
+        namespace=namespace,
+        body_filepath="./templates/test_job_with_service.yaml",
+        delete_policy=default_delete_policy,
+    )
+
+if __name__ == "__main__":
+    dag.clear(reset_dag_runs=True)
+    dag.run()
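For completeness, a quick sanity check of what the injected `KUBERNETES_JOB_OPERATOR_RESOURCES` value contains for the renamed two-resource template above. This is a sketch, assuming PyYAML is available and that the template's Job and Service documents are `---`-separated:

```python
import yaml  # PyYAML

# Mirrors _get_kubernetes_job_operator_envs: collect metadata.name from
# every document in the template and join the names with spaces.
with open("tests/dags/templates/test_job_with_service.yaml") as f:
    body = [doc for doc in yaml.safe_load_all(f) if doc]

names = [r.get("metadata", {}).get("name") for r in body]
names = [n for n in names if n is not None]
print(" ".join(names))  # thejob theservice
```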