diff --git a/api/turing/cluster/spark.go b/api/turing/cluster/spark.go index 35fbb9999..83d5fcafd 100644 --- a/api/turing/cluster/spark.go +++ b/api/turing/cluster/spark.go @@ -2,6 +2,7 @@ package cluster import ( "fmt" + "os" "strconv" apisparkv1beta2 "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2" @@ -192,7 +193,7 @@ func createSparkExecutor(request *CreateSparkRequest) (*apisparkv1beta2.Executor Path: serviceAccountMount, }, }, - Env: append(defaultEnvVars, getEnvVarFromRequest(request)...), + Env: getEnvVars(request), Labels: request.JobLabels, }, } @@ -242,7 +243,7 @@ func createSparkDriver(request *CreateSparkRequest) (*apisparkv1beta2.DriverSpec Path: serviceAccountMount, }, }, - Env: append(defaultEnvVars, getEnvVarFromRequest(request)...), + Env: getEnvVars(request), Labels: request.JobLabels, ServiceAccount: &request.ServiceAccountName, }, @@ -295,3 +296,13 @@ func toMegabyte(request string) (*string, error) { strVal := fmt.Sprintf("%sm", strconv.Itoa(int(inMegaBytes))) return &strVal, nil } + +func getEnvVars(request *CreateSparkRequest) []apicorev1.EnvVar { + envVars := append([]apicorev1.EnvVar{}, defaultEnvVars...) + + for _, ev := range request.SparkInfraConfig.APIServerEnvVars { + envVars = append(envVars, apicorev1.EnvVar{Name: ev, Value: os.Getenv(ev)}) + } + + return append(envVars, getEnvVarFromRequest(request)...) 
+} diff --git a/api/turing/cluster/spark_test.go b/api/turing/cluster/spark_test.go index 629dc73e3..859f42166 100644 --- a/api/turing/cluster/spark_test.go +++ b/api/turing/cluster/spark_test.go @@ -1,6 +1,7 @@ package cluster import ( + "os" "testing" apisparkv1beta2 "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2" @@ -124,6 +125,90 @@ func TestGetCPURequestAndLimit(t *testing.T) { } } +func TestGetEnvVars(t *testing.T) { + request := &CreateSparkRequest{ + JobName: jobName, + JobLabels: jobLabels, + JobImageRef: jobImageRef, + JobApplicationPath: jobApplicationPath, + JobArguments: jobArguments, + JobConfigMount: batch.JobConfigMount, + DriverCPURequest: cpuValue, + DriverMemoryRequest: memoryValue, + ExecutorCPURequest: cpuValue, + ExecutorMemoryRequest: memoryValue, + ExecutorReplica: executorReplica, + ServiceAccountName: serviceAccountName, + SparkInfraConfig: sparkInfraConfig, + EnvVars: &envVars, + } + tests := map[string]struct { + sparkInfraConfigAPIServerEnvVars []string + apiServerEnvVars []apicorev1.EnvVar + expectedEnvVars []apicorev1.EnvVar + }{ + "api server env vars specified": { + []string{"TEST_ENV_VAR_1"}, + []apicorev1.EnvVar{ + { + Name: "TEST_ENV_VAR_1", + Value: "TEST_VALUE_1", + }, + }, + []apicorev1.EnvVar{ + { + Name: envServiceAccountPathKey, + Value: envServiceAccountPath, + }, + { + Name: "TEST_ENV_VAR_1", + Value: "TEST_VALUE_1", + }, + { + Name: "foo", + Value: barString, + }, + }, + }, + "no api server env vars specified": { + []string{}, + []apicorev1.EnvVar{ + { + Name: "TEST_ENV_VAR_1", + Value: "TEST_VALUE_1", + }, + }, + []apicorev1.EnvVar{ + { + Name: envServiceAccountPathKey, + Value: envServiceAccountPath, + }, + { + Name: "foo", + Value: barString, + }, + }, + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + for _, ev := range tt.apiServerEnvVars { + err := os.Setenv(ev.Name, ev.Value) + assert.NoError(t, err) + } + + 
request.SparkInfraConfig.APIServerEnvVars = tt.sparkInfraConfigAPIServerEnvVars + envVars := getEnvVars(request) + assert.Equal(t, tt.expectedEnvVars, envVars) + + for _, ev := range tt.apiServerEnvVars { + err := os.Unsetenv(ev.Name) + assert.NoError(t, err) + } + }) + } +} + var ( jobName = "jobname" jobImageRef = "gojek/nosuchimage" diff --git a/api/turing/config/config.go b/api/turing/config/config.go index 5477b332e..06820e769 100644 --- a/api/turing/config/config.go +++ b/api/turing/config/config.go @@ -223,6 +223,7 @@ type KanikoConfig struct { // SparkAppConfig contains the infra configurations that is unique to the user's Kubernetes type SparkAppConfig struct { NodeSelector map[string]string + APIServerEnvVars []string CorePerCPURequest float64 `validate:"required"` CPURequestToCPULimit float64 `validate:"required"` SparkVersion string `validate:"required"`