Merge pull request #36 from LamaAni/add_resources_to_env
Add resources to env
LamaAni authored Feb 17, 2021
2 parents 6128981 + ea7cc7f commit 6cb58cd
Showing 5 changed files with 106 additions and 41 deletions.
10 changes: 10 additions & 0 deletions README.md
@@ -85,6 +85,7 @@ legacy_job_task = KubernetesLegacyJobOperator(
NOTE: the success/failure of the task is tracked only on the `first` resource, regardless of its kind. Currently, native support exists only for Pods and Jobs, though you can always add a [custom resource](docs/custom_kinds.md).

```yaml
# First resource: this resource will be tracked by the operator. Other resources will not be tracked.
apiVersion: batch/v1
kind: Job
metadata:
@@ -121,6 +122,15 @@ spec:
targetPort: 8080
```
### YAML field augmentation
The operator automatically augments the following fields on the containers of the first resource:
1. envs are added to all containers.
1. special envs are added (see the sketch below):
   1. KUBERNETES_JOB_OPERATOR_RESOURCES - a space-separated list of the names of all resources in the body.
All other changes are applied only to the main container (the first container) of the first resource.
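For illustration, here is a minimal sketch of the special env entry the operator prepends to each container's `env`, assuming a body with two resources named `thejob` and `theservice` (the names used in this commit's test template); the value is simply the resource names joined by spaces:

```python
# Sketch of the injected env entry; the resource names "thejob" and
# "theservice" are assumed here - in practice the names come from the body.
resource_names = ["thejob", "theservice"]

injected_env = [
    {
        "name": "KUBERNETES_JOB_OPERATOR_RESOURCES",
        # Space-separated list of all resource names found in the body.
        "value": " ".join(resource_names),
    }
]

print(injected_env)
# [{'name': 'KUBERNETES_JOB_OPERATOR_RESOURCES', 'value': 'thejob theservice'}]
```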
# Configuration
Airflow config extra sections,
56 changes: 40 additions & 16 deletions airflow_kubernetes_job_operator/kubernetes_job_operator.py
@@ -189,28 +189,52 @@ def _create_job_name(cls, name):
            max_length=DEFAULT_KUBERNETES_MAX_RESOURCE_NAME_LENGTH,
        )

    @classmethod
    def _to_kubernetes_env_list(cls, envs: dict):
        return [{"name": k, "value": f"{envs[k]}"} for k in envs.keys()]

    def _get_kubernetes_env_list(self):
        return [{"name": k, "value": f"{self.envs[k]}"} for k in self.envs.keys()]
        return self._to_kubernetes_env_list(self.envs or {})

    def _get_kubernetes_job_operator_envs(self):
        body = self.job_runner.body
        names = [r.get("metadata", {}).get("name", None) for r in body]
        names = [n for n in names if n is not None]
        return self._to_kubernetes_env_list(
            {
                "KUBERNETES_JOB_OPERATOR_RESOURCES": " ".join(names),
            }
        )

    def update_override_params(self, o: dict):
    def _update_container_yaml(self, container):
        container["env"] = [
            *self._get_kubernetes_job_operator_envs(),
            *container.get("env", []),
            *self._get_kubernetes_env_list(),
        ]

    def _update_main_container_yaml(self, container: dict):
        if self.command:
            container["command"] = self.command
        if self.arguments:
            container["args"] = self.arguments
        if self.image:
            container["image"] = self.image
        if self.image_pull_policy:
            container["imagePullPolicy"] = self.image_pull_policy

    def _update_override_params(self, o: dict):
        if "spec" in o and "containers" in o.get("spec", {}):
            containers: List[dict] = o["spec"]["containers"]
            if isinstance(containers, list) and len(containers) > 0:
                main_container = containers[0]
                if self.command:
                    main_container["command"] = self.command
                if self.arguments:
                    main_container["args"] = self.arguments
                if self.envs:
                    env_list = [*main_container.get("env", []), *self._get_kubernetes_env_list()]
                    main_container["env"] = env_list
                if self.image:
                    main_container["image"] = self.image
                if self.image_pull_policy:
                    main_container["imagePullPolicy"] = self.image_pull_policy
                for container in containers:
                    self._update_container_yaml(container=container)
                self._update_main_container_yaml(containers[0])

        ## adding env resources
        for c in o.values():
            if isinstance(c, dict):
                self.update_override_params(c)
                self._update_override_params(c)

    def _validate_job_runner(self):
        if self._job_runner is not None:
@@ -270,7 +294,7 @@ def pre_execute(self, context):
        self.prepare_and_update_body()

        # write override params
        self.update_override_params(self.job_runner.body[0])
        self._update_override_params(self.job_runner.body[0])

        # call parent.
        return super().pre_execute(context)
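Inside the job's containers, the injected variable can be read like any other environment variable. A minimal usage sketch (only the variable name comes from this change; the consumer script itself is hypothetical):

```python
# Hypothetical script running inside the tracked job's container.
import os

# The operator sets this to a space-separated list of the body's resource names.
resource_names = os.environ.get("KUBERNETES_JOB_OPERATOR_RESOURCES", "").split()
print(f"Resources created alongside this task: {resource_names}")
```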
5 changes: 3 additions & 2 deletions tests/dags/templates/test_job_with_service.yaml
@@ -1,6 +1,7 @@
apiVersion: batch/v1
kind: Job
metadata: {}
metadata:
  name: thejob
spec:
  template:
    spec:
@@ -33,7 +34,7 @@ spec:
apiVersion: v1
kind: Service
metadata:
  name: test-service
  name: theservice
spec:
  selector:
    app: myapp
46 changes: 23 additions & 23 deletions tests/dags/test_job_operator_jinja.py
@@ -40,30 +40,30 @@
jinja_job_args={"test": "lama"},
)

bash_script = """
#/usr/bin/env bash
echo "Legacy start for taskid {{ti.task_id}} {{job.test}}"
cur_count=0
while true; do
cur_count=$((cur_count + 1))
if [ "$cur_count" -ge "$TIC_COUNT" ]; then
break
fi
date
sleep 1
done
# bash_script = """
# #/usr/bin/env bash
# echo "Legacy start for taskid {{ti.task_id}} {{job.test}}"
# cur_count=0
# while true; do
# cur_count=$((cur_count + 1))
# if [ "$cur_count" -ge "$TIC_COUNT" ]; then
# break
# fi
# date
# sleep 1
# done

echo "Complete"
"""
KubernetesLegacyJobOperator(
task_id="legacy-test-job-success",
image="{{default_image}}",
cmds=["bash", "-c", bash_script],
dag=dag,
is_delete_operator_pod=True,
env_vars=envs,
delete_policy=default_delete_policy,
)
# echo "Complete"
# """
# KubernetesLegacyJobOperator(
# task_id="legacy-test-job-success",
# image="{{default_image}}",
# cmds=["bash", "-c", bash_script],
# dag=dag,
# is_delete_operator_pod=True,
# env_vars=envs,
# delete_policy=default_delete_policy,
# )

if __name__ == "__main__":
dag.clear(reset_dag_runs=True)
30 changes: 30 additions & 0 deletions tests/dags/test_job_operator_with_service.py
@@ -0,0 +1,30 @@
from utils import default_args
from datetime import timedelta
from airflow import DAG
from airflow_kubernetes_job_operator import (
    KubernetesJobOperator,
    JobRunnerDeletePolicy,
)

dag = DAG(
    "kub-job-op-test-jinja",
    default_args=default_args,
    description="Test base job operator",
    schedule_interval=None,
    catchup=False,
)

with dag:
    namespace = None
    default_delete_policy = JobRunnerDeletePolicy.Never

    KubernetesJobOperator(
        task_id="test-job-success",
        namespace=namespace,
        body_filepath="./templates/test_job_with_service.yaml",
        delete_policy=default_delete_policy,
    )

if __name__ == "__main__":
    dag.clear(reset_dag_runs=True)
    dag.run()
