Skip to content

Commit

Permalink
feat(backend): Add Parallelism Limit to ParallelFor tasks. Fixes #8718 (
Browse files Browse the repository at this point in the history
#10798)

* feat(backend): Add Parallelism Limit to ParallelFor tasks

Signed-off-by: Giulio Frasca <[email protected]>

* feat(backend): Add intermediate Template for iterator Tasks in ArgoCompiler

Signed-off-by: Giulio Frasca <[email protected]>

* test: Add argoCompiler test case to validate individual parallel-limited tasks

Signed-off-by: Giulio Frasca <[email protected]>

* test: Update tests for ParallelFor loop update

Signed-off-by: Giulio Frasca <[email protected]>

* fix(backend): Fix broken dependantTasks in ParallelFor

Signed-off-by: Giulio Frasca <[email protected]>

* fix(backend): pass correct ParentDagID to iterator DAG

- Passthrough ParentDagID rather than DriverExecutionID to iterator such
  that iteration item correctly detects dependentTasks.
- Remove depends from iterator DAG as it is already handled by
  root-level task
- Update Iterator template names/nomenclature for clarity
- Update tests accordingly

Signed-off-by: Giulio Frasca <[email protected]>

* fix(backend): Remove DAG Driver from Iterator abstraction template

- Removes the Driver pod from the Iterator abstraction-layer template
  as it confuses MLMD and is purley an Argo implementation
- Drivers still used on the Component and Iteration-item templates

Signed-off-by: Giulio Frasca <[email protected]>

---------

Signed-off-by: Giulio Frasca <[email protected]>
  • Loading branch information
gmfrasca authored Dec 3, 2024
1 parent 6ebf4aa commit b7d8c97
Show file tree
Hide file tree
Showing 4 changed files with 901 additions and 4 deletions.
5 changes: 5 additions & 0 deletions backend/src/v2/compiler/argocompiler/argo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ func Test_argo_compiler(t *testing.T) {
platformSpecPath: "",
argoYAMLPath: "testdata/importer.yaml",
},
{
jobPath: "../testdata/multiple_parallel_loops.json",
platformSpecPath: "",
argoYAMLPath: "testdata/multiple_parallel_loops.yaml",
},
{
jobPath: "../testdata/create_mount_delete_dynamic_pvc.json",
platformSpecPath: "../testdata/create_mount_delete_dynamic_pvc_platform.json",
Expand Down
61 changes: 57 additions & 4 deletions backend/src/v2/compiler/argocompiler/dag.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,62 @@ func (c *workflowCompiler) iteratorTask(name string, task *pipelinespec.Pipeline
}
}()
componentName := task.GetComponentRef().GetName()
// Set up Loop Control Template
iteratorTasks, err := c.iterationItemTask("iteration", task, taskJson, parentDagID)
if err != nil {
return nil, err
}
loopTmpl := &wfapi.Template{
Inputs: wfapi.Inputs{
Parameters: []wfapi.Parameter{
{Name: paramParentDagID},
},
},
DAG: &wfapi.DAGTemplate{
Tasks: iteratorTasks,
},
}
parallelismLimit := int64(task.GetIteratorPolicy().GetParallelismLimit())
if parallelismLimit > 0 {
loopTmpl.Parallelism = &parallelismLimit
}

loopTmplName, err := c.addTemplate(loopTmpl, fmt.Sprintf("%s-%s-iterator", componentName, name))
if err != nil {
return nil, err
}

tasks = []wfapi.DAGTask{
{
Name: name + "-loop",
Template: loopTmplName,
Depends: depends(task.GetDependentTasks()),
Arguments: wfapi.Arguments{
Parameters: []wfapi.Parameter{
{
Name: paramParentDagID,
Value: wfapi.AnyStringPtr(parentDagID),
},
},
},
},
}
return tasks, nil
}

func (c *workflowCompiler) iterationItemTask(name string, task *pipelinespec.PipelineTaskSpec, taskJson string, parentDagID string) (tasks []wfapi.DAGTask, err error) {
defer func() {
if err != nil {
err = fmt.Errorf("iterationItem task: %w", err)
}
}()
componentName := task.GetComponentRef().GetName()
componentSpecPlaceholder, err := c.useComponentSpec(componentName)
if err != nil {
return nil, err
}

// Set up Iteration (Single Task) Template
driverArgoName := name + "-driver"
driverInputs := dagDriverInputs{
component: componentSpecPlaceholder,
Expand All @@ -279,10 +331,10 @@ func (c *workflowCompiler) iteratorTask(name string, task *pipelinespec.Pipeline
if err != nil {
return nil, err
}
driver.Depends = depends(task.GetDependentTasks())

iterationCount := intstr.FromString(driverOutputs.iterationCount)
iterationTasks, err := c.task(
"iteration",
"iteration-item",
task,
taskInputs{
parentDagID: inputParameter(paramParentDagID),
Expand Down Expand Up @@ -311,7 +363,8 @@ func (c *workflowCompiler) iteratorTask(name string, task *pipelinespec.Pipeline
if task.GetTriggerPolicy().GetCondition() != "" {
when = driverOutputs.condition + " != false"
}
tasks = []wfapi.DAGTask{

iteratorTasks := []wfapi.DAGTask{
*driver,
{
Name: name + "-iterations",
Expand All @@ -330,7 +383,7 @@ func (c *workflowCompiler) iteratorTask(name string, task *pipelinespec.Pipeline
WithSequence: &wfapi.Sequence{Count: &iterationCount},
},
}
return tasks, nil
return iteratorTasks, nil
}

type dagDriverOutputs struct {
Expand Down
Loading

0 comments on commit b7d8c97

Please sign in to comment.