Skip to content

Commit

Permalink
fix: GetDownstreamEdges is not cycle safe (#1447)
Browse files Browse the repository at this point in the history
Signed-off-by: Nishchith Shetty <[email protected]>
(cherry picked from commit 4306357)
  • Loading branch information
inishchith authored and whynowy committed Jan 13, 2024
1 parent 1d83b51 commit 907949b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
13 changes: 9 additions & 4 deletions pkg/apis/numaflow/v1alpha1/pipeline_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,17 +161,22 @@ func (p Pipeline) GetAllBuckets() []string {

// GetDownstreamEdges returns all the downstream edges of a vertex
func (p Pipeline) GetDownstreamEdges(vertexName string) []Edge {
var f func(vertexName string, edges *[]Edge)
f = func(vertexName string, edges *[]Edge) {
var f func(vertexName string, edges *[]Edge, visited map[string]bool)
f = func(vertexName string, edges *[]Edge, visited map[string]bool) {
if visited[vertexName] {
return
}
visited[vertexName] = true
for _, b := range p.ListAllEdges() {
if b.From == vertexName {
*edges = append(*edges, b)
f(b.To, edges)
f(b.To, edges, visited)
}
}
}
result := []Edge{}
f(vertexName, &result)
visited := make(map[string]bool)
f(vertexName, &result, visited)
return result
}

Expand Down
14 changes: 11 additions & 3 deletions pkg/apis/numaflow/v1alpha1/pipeline_types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,18 +301,26 @@ func Test_GetDownstreamEdges(t *testing.T) {
Edges: []Edge{
{From: "input", To: "p1"},
{From: "p1", To: "p11"},
{From: "p1", To: "p1"},
{From: "p1", To: "p2"},
{From: "p2", To: "output"},
},
},
}
edges := pl.GetDownstreamEdges("input")
assert.Equal(t, 4, len(edges))
assert.Equal(t, 5, len(edges))
assert.Equal(t, edges, pl.ListAllEdges())
assert.Equal(t, edges[2], Edge{From: "p1", To: "p2"})
assert.Equal(t, edges[2], Edge{From: "p1", To: "p1"})
assert.Equal(t, edges[3], Edge{From: "p1", To: "p2"})

edges = pl.GetDownstreamEdges("p1")
assert.Equal(t, 3, len(edges))
assert.Equal(t, 4, len(edges))

edges = pl.GetDownstreamEdges("p2")
assert.Equal(t, 1, len(edges))

edges = pl.GetDownstreamEdges("p11")
assert.Equal(t, 0, len(edges))

edges = pl.GetDownstreamEdges("output")
assert.Equal(t, 0, len(edges))
Expand Down

0 comments on commit 907949b

Please sign in to comment.