Skip to content

Commit

Permalink
Pass an error message to the failure node (#6181)
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Wu <[email protected]>
  • Loading branch information
popojk authored Feb 3, 2025
1 parent 5ec1a60 commit 9dbce43
Show file tree
Hide file tree
Showing 8 changed files with 196 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ package executors
import (
"context"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
)

type FailureNodeLookup struct {
NodeLookup
FailureNode v1alpha1.ExecutableNode
FailureNodeStatus v1alpha1.ExecutableNodeStatus
OriginalError *core.ExecutionError
}

func (f FailureNodeLookup) GetNode(nodeID v1alpha1.NodeID) (v1alpha1.ExecutableNode, bool) {
Expand All @@ -35,10 +37,15 @@ func (f FailureNodeLookup) FromNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, erro
return nil, nil
}

func NewFailureNodeLookup(nodeLookup NodeLookup, failureNode v1alpha1.ExecutableNode, failureNodeStatus v1alpha1.ExecutableNodeStatus) NodeLookup {
func (f FailureNodeLookup) GetOriginalError() (*core.ExecutionError, error) {
return f.OriginalError, nil
}

func NewFailureNodeLookup(nodeLookup NodeLookup, failureNode v1alpha1.ExecutableNode, failureNodeStatus v1alpha1.ExecutableNodeStatus, originalError *core.ExecutionError) NodeLookup {
return FailureNodeLookup{
NodeLookup: nodeLookup,
FailureNode: failureNode,
FailureNodeStatus: failureNodeStatus,
OriginalError: originalError,
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/stretchr/testify/assert"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
"github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks"
)
Expand All @@ -26,26 +27,33 @@ func TestNewFailureNodeLookup(t *testing.T) {
nl := nl{}
en := en{}
ns := ns{}
nodeLoopUp := NewFailureNodeLookup(nl, en, ns)
execErr := &core.ExecutionError{
Message: "node failure",
}
nodeLoopUp := NewFailureNodeLookup(nl, en, ns, execErr)
assert.NotNil(t, nl)
typed := nodeLoopUp.(FailureNodeLookup)
assert.Equal(t, nl, typed.NodeLookup)
assert.Equal(t, en, typed.FailureNode)
assert.Equal(t, ns, typed.FailureNodeStatus)
assert.Equal(t, execErr, typed.OriginalError)
}

func TestNewTestFailureNodeLookup(t *testing.T) {
n := &mocks.ExecutableNode{}
ns := &mocks.ExecutableNodeStatus{}
failureNodeID := "fn1"
originalErr := &core.ExecutionError{
Message: "node failure",
}
nl := NewTestNodeLookup(
map[string]v1alpha1.ExecutableNode{v1alpha1.StartNodeID: n, failureNodeID: n},
map[string]v1alpha1.ExecutableNodeStatus{v1alpha1.StartNodeID: ns, failureNodeID: ns},
)

assert.NotNil(t, nl)

failureNodeLookup := NewFailureNodeLookup(nl, n, ns)
failureNodeLookup := NewFailureNodeLookup(nl, n, ns, originalErr).(FailureNodeLookup)
r, ok := failureNodeLookup.GetNode(v1alpha1.StartNodeID)
assert.True(t, ok)
assert.Equal(t, n, r)
Expand All @@ -64,4 +72,9 @@ func TestNewTestFailureNodeLookup(t *testing.T) {
nodeIDs, err = failureNodeLookup.FromNode(failureNodeID)
assert.Nil(t, nodeIDs)
assert.Nil(t, err)

oe, err := failureNodeLookup.GetOriginalError()
assert.NotNil(t, oe)
assert.Equal(t, originalErr, oe)
assert.Nil(t, err)
}
13 changes: 13 additions & 0 deletions flytepropeller/pkg/controller/nodes/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,19 @@ func (c *nodeExecutor) preExecute(ctx context.Context, dag executors.DAGStructur
}

if nodeInputs != nil {
// Resolve error input if current node is an on failure node
failureNodeLookup, ok := nCtx.ContextualNodeLookup().(executors.FailureNodeLookup)
if ok {
originalErr, err := failureNodeLookup.GetOriginalError()
if err != nil {
return handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, "FailureNodeError", err.Error(), nil), nil
} else if originalErr != nil {
err = ResolveOnFailureNodeInput(ctx, nodeInputs, node.GetID(), originalErr)
if err != nil {
return handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, "FailureNodeInputResolvingError", err.Error(), nil), nil
}
}
}
p := common.CheckOffloadingCompat(ctx, nCtx, nodeInputs.GetLiterals(), node, c.literalOffloadingConfig)
if p != nil {
return *p, nil
Expand Down
45 changes: 45 additions & 0 deletions flytepropeller/pkg/controller/nodes/resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,48 @@ func Resolve(ctx context.Context, outputResolver OutputResolver, nl executors.No
Literals: literalMap,
}, nil
}

func ResolveErrorInputLiteralData(ctx context.Context, literals map[string]*core.Literal, nodeID v1alpha1.NodeID, execErr *core.ExecutionError) {
if literal, exists := literals["err"]; exists {
// make new Scalar for literal map
errorUnion := &core.Scalar_Union{
Union: &core.Union{
Value: &core.Literal{
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Error{
Error: &core.Error{
Message: execErr.GetMessage(),
FailedNodeId: nodeID,
},
},
},
},
},
Type: &core.LiteralType{
Type: &core.LiteralType_Simple{
Simple: core.SimpleType_ERROR,
},
Structure: &core.TypeStructure{
Tag: "FlyteError",
},
},
},
}
literal.Value = &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: errorUnion,
},
}
}
}

func ResolveOnFailureNodeInput(ctx context.Context, nodeInputs *core.LiteralMap, nodeID v1alpha1.NodeID, execErr *core.ExecutionError) error {
literals := nodeInputs.GetLiterals()
if literals != nil {
ResolveErrorInputLiteralData(ctx, literals, nodeID, execErr)
} else {
return errors.Errorf(errors.BindingResolutionError, "id", nodeID, "Node inputs are empty")
}
return nil
}
71 changes: 71 additions & 0 deletions flytepropeller/pkg/controller/nodes/resolve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,3 +467,74 @@ func TestResolve(t *testing.T) {
})

}

func TestResolveErrorInputLiteralData(t *testing.T) {
ctx := context.Background()
t.Run("ResolveErrorInputsLiteralData", func(t *testing.T) {
noneLiteral, _ := coreutils.MakeLiteral(nil)
inputLiterals := make(map[string]*core.Literal, 1)
inputLiterals["err"] = &core.Literal{
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Union{
Union: &core.Union{
Value: noneLiteral,
Type: &core.LiteralType{
Type: &core.LiteralType_Simple{
Simple: core.SimpleType_NONE,
},
Structure: &core.TypeStructure{
Tag: "none",
},
},
},
},
},
},
}
nID := "fn"
execErr := &core.ExecutionError{
Message: "node failure",
}
expectedLiterals := make(map[string]*core.Literal, 1)
errorLiteral, err := coreutils.MakeLiteral(&core.Error{Message: execErr.GetMessage(), FailedNodeId: nID})
assert.NoError(t, err)
expectedLiterals["err"] = &core.Literal{
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Union{
Union: &core.Union{
Value: errorLiteral,
Type: &core.LiteralType{
Type: &core.LiteralType_Simple{
Simple: core.SimpleType_ERROR,
},
Structure: &core.TypeStructure{
Tag: "FlyteError",
},
},
},
},
},
},
}
// Execute resolve
ResolveErrorInputLiteralData(ctx, inputLiterals, nID, execErr)
flyteassert.EqualLiterals(t, inputLiterals["err"], expectedLiterals["err"])
})
}

func TestResolveOnFailureNodeInput(t *testing.T) {
ctx := context.Background()
t.Run("ResolveWithNilInputs", func(t *testing.T) {
nID := "fn"
execErr := &core.ExecutionError{
Message: "node failure",
}
nilLiteralMap := &core.LiteralMap{
Literals: nil,
}
err := ResolveOnFailureNodeInput(ctx, nilLiteralMap, nID, execErr)
assert.Error(t, err)
})
}
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ func (s *subworkflowHandler) HandleFailureNodeOfSubWorkflow(ctx context.Context,
status := nCtx.NodeStatus()
subworkflowNodeLookup := executors.NewNodeLookup(subworkflow, status, subworkflow)
failureNodeStatus := status.GetNodeExecutionStatus(ctx, failureNode.GetID())
failureNodeLookup := executors.NewFailureNodeLookup(subworkflowNodeLookup, failureNode, failureNodeStatus)
failureNodeLookup := executors.NewFailureNodeLookup(subworkflowNodeLookup, failureNode, failureNodeStatus, originalError)

state, err := s.nodeExecutor.RecursiveNodeHandler(ctx, execContext, failureNodeLookup, failureNodeLookup, failureNode)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion flytepropeller/pkg/controller/workflow/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ func (c *workflowExecutor) handleFailureNode(ctx context.Context, w *v1alpha1.Fl
execcontext := executors.NewExecutionContext(w, w, w, nil, executors.InitializeControlFlow())

failureNodeStatus := w.GetExecutionStatus().GetNodeExecutionStatus(ctx, failureNode.GetID())
failureNodeLookup := executors.NewFailureNodeLookup(w, failureNode, failureNodeStatus)
failureNodeLookup := executors.NewFailureNodeLookup(w, failureNode, failureNodeStatus, execErr)
state, handlerErr := c.nodeExecutor.RecursiveNodeHandler(ctx, execcontext, failureNodeLookup, failureNodeLookup, failureNode)
c.updateExecutionStats(ctx, execcontext)

Expand Down
42 changes: 42 additions & 0 deletions flytepropeller/pkg/utils/assert/literals.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,27 @@ import (
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
)

func EqualLiteralType(t *testing.T, lt1 *core.LiteralType, lt2 *core.LiteralType) {
if !assert.Equal(t, lt1 == nil, lt2 == nil) {
assert.FailNow(t, "One of the values is nil")
}
assert.Equal(t, reflect.TypeOf(lt1.GetType()), reflect.TypeOf(lt2.GetType()))
switch lt1.GetType().(type) {
case *core.LiteralType_Simple:
assert.Equal(t, lt1.GetType().(*core.LiteralType_Simple).Simple, lt2.GetType().(*core.LiteralType_Simple).Simple)
default:
assert.FailNow(t, "Not yet implemented for types %v", reflect.TypeOf(lt1.GetType()))
}
structure1 := lt1.GetStructure()
structure2 := lt2.GetStructure()
if (structure1 == nil && structure2 != nil) || (structure1 != nil && structure2 == nil) {
assert.FailNow(t, "One of the structures is nil while the other is not")
}
if structure1 != nil && structure2 != nil {
assert.Equal(t, structure1.GetTag(), structure2.GetTag())
}
}

func EqualPrimitive(t *testing.T, p1 *core.Primitive, p2 *core.Primitive) {
if !assert.Equal(t, p1 == nil, p2 == nil) {
assert.FailNow(t, "One of the values is nil")
Expand All @@ -27,6 +48,23 @@ func EqualPrimitive(t *testing.T, p1 *core.Primitive, p2 *core.Primitive) {
}
}

func EqualError(t *testing.T, e1 *core.Error, e2 *core.Error) {
if !assert.Equal(t, e1 == nil, e2 == nil) {
assert.FailNow(t, "One of the values is nil")
}
assert.Equal(t, e1.GetMessage(), e2.GetMessage())
assert.Equal(t, e1.GetFailedNodeId(), e2.GetFailedNodeId())
}

func EqualUnion(t *testing.T, u1 *core.Union, u2 *core.Union) {
if !assert.Equal(t, u1 == nil, u2 == nil) {
assert.FailNow(t, "One of the values is nil")
}
assert.Equal(t, reflect.TypeOf(u1.GetValue()), reflect.TypeOf(u2.GetValue()))
EqualLiterals(t, u1.GetValue(), u2.GetValue())
EqualLiteralType(t, u1.GetType(), u2.GetType())
}

func EqualScalar(t *testing.T, p1 *core.Scalar, p2 *core.Scalar) {
if !assert.Equal(t, p1 == nil, p2 == nil) {
assert.FailNow(t, "One of the values is nil")
Expand All @@ -38,6 +76,10 @@ func EqualScalar(t *testing.T, p1 *core.Scalar, p2 *core.Scalar) {
switch p1.GetValue().(type) {
case *core.Scalar_Primitive:
EqualPrimitive(t, p1.GetPrimitive(), p2.GetPrimitive())
case *core.Scalar_Error:
EqualError(t, p1.GetError(), p2.GetError())
case *core.Scalar_Union:
EqualUnion(t, p1.GetUnion(), p2.GetUnion())
default:
assert.FailNow(t, "Not yet implemented for types %v", reflect.TypeOf(p1.GetValue()))
}
Expand Down

0 comments on commit 9dbce43

Please sign in to comment.