Skip to content

Commit

Permalink
PR fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Pietro Max Marsella committed Jan 11, 2025
1 parent 7d262a8 commit 3cc3407
Show file tree
Hide file tree
Showing 37 changed files with 587 additions and 398 deletions.
4 changes: 2 additions & 2 deletions lib/utils/include/utils/graph/algorithms.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,10 @@ std::unordered_set<Node> get_neighbors(DiGraphView const &, Node const &);
// &);

// return the set of nodes without incoming edges
std::unordered_set<Node> get_sources(DiGraphView const &);
std::unordered_set<Node> get_initial_nodes(DiGraphView const &);

// return the set of nodes without outgoing edges
std::unordered_set<Node> get_sinks(DiGraphView const &);
std::unordered_set<Node> get_terminal_nodes(DiGraphView const &);

// std::unordered_set<Node> get_closed_sources(OpenMultiDiGraphView const &g);
// std::unordered_set<Node> get_closed_sinks(OpenMultiDiGraphView const &g);
Expand Down
4 changes: 2 additions & 2 deletions lib/utils/include/utils/graph/digraph/algorithms.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
namespace FlexFlow {

std::unordered_set<DirectedEdge> get_edges(DiGraphView const &);
std::unordered_set<Node> get_sources(DiGraphView const &);
std::unordered_set<Node> get_sinks(DiGraphView const &);
std::unordered_set<Node> get_initial_nodes(DiGraphView const &);
std::unordered_set<Node> get_terminal_nodes(DiGraphView const &);

} // namespace FlexFlow

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ std::unordered_set<MultiDiEdge> get_incoming_edges(MultiDiGraphView const &,
Node const &);

std::unordered_map<Node, std::unordered_set<MultiDiEdge>>
get_incoming_edges(MultiDiGraphView const &g);
get_incoming_edges(MultiDiGraphView const &g,
std::unordered_set<Node> const &nodes);

} // namespace FlexFlow

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_MULTIDIGRAPH_ALGORITHMS_GET_OUTGOING_EDGES_H

#include "utils/graph/multidigraph/multidigraph_view.h"
#include <unordered_set>

namespace FlexFlow {

std::unordered_set<MultiDiEdge> get_outgoing_edges(MultiDiGraphView const &,
Node const &);

std::unordered_map<Node, std::unordered_set<MultiDiEdge>>
get_outgoing_edges(MultiDiGraphView const &g);
get_outgoing_edges(MultiDiGraphView const &g,
std::unordered_set<Node> const &ns);

} // namespace FlexFlow

Expand Down
1 change: 0 additions & 1 deletion lib/utils/include/utils/graph/node/node.struct.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ features = [
"hash",
"fmt",
"json",
"rapidcheck",
]

includes = [
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
namespace = "FlexFlow"
name = "ExtendedParallelReduction"
features = [
"eq",
"hash",
"fmt",
]

includes = [
"utils/graph/multidigraph/multidiedge.dtg.h",
"<unordered_set>"
]

src_includes = [
"utils/hash/unordered_set.h",
"utils/fmt/unordered_set.h",
]

[[fields]]
name = "edges"
type = "std::unordered_set<::FlexFlow::MultiDiEdge>"
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
namespace = "FlexFlow"
name = "ExtendedSeriesReduction"
features = [
"eq",
"hash",
"fmt",
]

includes = [
"utils/graph/multidigraph/multidiedge.dtg.h",
"<vector>"
]

src_includes = [
"utils/hash/vector.h",
"utils/fmt/vector.h",
]

[[fields]]
name = "edges"
type = "std::vector<::FlexFlow::MultiDiEdge>"
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
#include "utils/graph/digraph/digraph.h"
#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h"
#include "utils/optional.h"
#include <variant>
#include <vector>

namespace FlexFlow {

Expand Down
19 changes: 17 additions & 2 deletions lib/utils/include/utils/graph/series_parallel/parallel_reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,39 @@
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_PARALLEL_REDUCTION_H

#include "utils/graph/multidigraph/multidigraph.h"
#include "utils/graph/series_parallel/extended_parallel_reduction.dtg.h"
#include "utils/graph/series_parallel/parallel_reduction.dtg.h"
#include <optional>
#include <unordered_set>

namespace FlexFlow {

ParallelReduction make_parallel_reduction(MultiDiEdge const &,
MultiDiEdge const &);

std::optional<ParallelReduction>
find_parallel_reduction(MultiDiGraphView const &);

std::unordered_map<DirectedEdge, std::unordered_set<MultiDiEdge>>
/**
* @brief Finds all ExtendedParallelReduction for a given MultiDiGraph
* @details An ExtendedParallelReduction is a unordered collection of
* `MultiDiEdge`s such that they share a common source and destination node.
*/
std::unordered_set<ExtendedParallelReduction>
find_all_extended_parallel_reductions(MultiDiGraphView const &);

MultiDiEdge apply_parallel_reduction(MultiDiGraph &, ParallelReduction const &);

/**
* @brief Applies a given ExtendedParallelReduction in place to a given
* MultiDiGraph
* @details The reduction removes all but one `MultiDiEdge`, so that the source,
* destination nodes associated with the reduction become connected by a single
* edge.
*/
MultiDiEdge
apply_extended_parallel_reduction(MultiDiGraph &,
std::unordered_set<MultiDiEdge> const &);
ExtendedParallelReduction const &);

} // namespace FlexFlow

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ SeriesParallelDecomposition delete_node(SeriesParallelDecomposition sp,
// duplicate nodes within `sp` are counted multiple times
size_t num_nodes(SeriesParallelDecomposition const &sp);

SeriesParallelDecomposition serial_composition(
SeriesParallelDecomposition series_composition(
std::vector<SeriesParallelDecomposition> const &sp_compositions);
SeriesParallelDecomposition parallel_composition(
std::unordered_multiset<SeriesParallelDecomposition> const
Expand Down
44 changes: 41 additions & 3 deletions lib/utils/include/utils/graph/series_parallel/series_reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "utils/graph/multidigraph/multidiedge.dtg.h"
#include "utils/graph/multidigraph/multidigraph.h"
#include "utils/graph/series_parallel/extended_series_reduction.dtg.h"
#include "utils/graph/series_parallel/series_reduction.dtg.h"
#include "utils/hash/vector.h"

Expand All @@ -15,13 +16,50 @@ Node get_center_node(MultiDiGraphView const &, SeriesReduction const &);
SeriesReduction make_series_reduction(MultiDiEdge const &, MultiDiEdge const &);
std::optional<SeriesReduction> find_series_reduction(MultiDiGraphView const &);

std::unordered_set<std::vector<MultiDiEdge>>
/**
* @brief Finds all the ExtendedSeriesReduction structures in a given graph.
*
* @details An `ExtendedSeriesReduction` is an ordered collection of
* `MultiDiEdges` such that:
* - The destination node of the nth edge is the same as the source node of the
* (n+1)th edge.
* - Such a node (intermediate node) has exactly two edges: one incoming (nth
* edge) and one outgoing ((n+1)th edge).
*
* For example, in the following graph:
*
* A -> B -> D -> E
* \ /
* -> C ->
*
* We have that [(A,B), (B,D), (D,E)] and [(A,C), (C,E)] both constitute
* `ExtendedSeriesReduction`.
*/
std::unordered_set<ExtendedSeriesReduction>
find_all_extended_series_reductions(MultiDiGraphView const &g);

MultiDiEdge apply_series_reduction(MultiDiGraph &, SeriesReduction const &);

MultiDiEdge apply_extended_series_reduction(
MultiDiGraph &g, std::vector<MultiDiEdge> const &series_edges);
/**
* @brief Applies a given ExtendedSeriesReduction in-place to a given graph.
*
* For example, in the following graph:
*
* A -> B -> D -> E
* \ /
* -> C ->
*
* Given the ExtendedSeriesReduction [(A,B), (B,D), (D,E)], the intermediate
*nodes B, D, will be deleted, and the resulting graph will be:
*
* A ----> E
* \ /
* -> C ->
*
**/
MultiDiEdge
apply_extended_series_reduction(MultiDiGraph &g,
ExtendedSeriesReduction const &reduction);

} // namespace FlexFlow

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ features = [
"ord",
"hash",
"fmt",
"rapidcheck"
]

includes = [
Expand Down
6 changes: 3 additions & 3 deletions lib/utils/src/utils/graph/digraph/algorithms.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ std::unordered_set<DirectedEdge> get_edges(DiGraphView const &g) {
return g.query_edges(directed_edge_query_all());
}

std::unordered_set<Node> get_sinks(DiGraphView const &g) {
return get_sources(flipped(g));
std::unordered_set<Node> get_terminal_nodes(DiGraphView const &g) {
return get_initial_nodes(flipped(g));
}

std::unordered_set<Node> get_sources(DiGraphView const &g) {
std::unordered_set<Node> get_initial_nodes(DiGraphView const &g) {
std::unordered_set<Node> all_nodes = get_nodes(g);
std::unordered_set<Node> with_incoming_edge =
transform(get_edges(g), [](DirectedEdge const &e) { return e.dst; });
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ std::optional<CompleteBipartiteCompositeDecomposition>
extend(already_in_a_tail, tail);
}

assert(already_in_a_head == set_minus(get_nodes(g), get_sinks(g)));
assert(already_in_a_tail == set_minus(get_nodes(g), get_sources(g)));
assert(already_in_a_head == set_minus(get_nodes(g), get_terminal_nodes(g)));
assert(already_in_a_tail == set_minus(get_nodes(g), get_initial_nodes(g)));

return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
namespace FlexFlow {

bool is_complete_bipartite_digraph(DiGraphView const &g) {
return is_complete_bipartite_digraph(g, get_sources(g));
return is_complete_bipartite_digraph(g, get_initial_nodes(g));
}

bool is_complete_bipartite_digraph(DiGraphView const &g,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ namespace FlexFlow {

std::unordered_map<Node, std::unordered_set<Node>>
get_dominators_map(DiGraphView const &g) {
std::unordered_set<Node> sources = get_sources(g);
std::unordered_set<Node> initial_nodes = get_initial_nodes(g);

std::queue<Node> queue;

for (Node src : get_sources(g)) {
for (Node src : get_initial_nodes(g)) {
queue.push(src);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace FlexFlow {

static std::vector<Node>
get_unchecked_topological_ordering(DiGraphView const &g) {
auto dfs_view = unchecked_dfs(g, get_sources(g));
auto dfs_view = unchecked_dfs(g, get_initial_nodes(g));
std::vector<Node> order;
std::unordered_set<Node> seen;
std::unordered_map<Node, std::unordered_set<Node>> predecessors =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,19 @@ std::optional<InverseLineGraphResult>
return get_component_containing_node_in_tail(cbc_decomposition, n).value();
};

std::unordered_set<Node> sources = get_sources(view);
std::unordered_set<Node> sinks = get_sinks(view);
std::unordered_set<Node> initial_nodes = get_initial_nodes(view);
std::unordered_set<Node> terminal_nodes = get_terminal_nodes(view);

auto src_for_node = [&](Node const &v) -> Node {
if (contains(sources, v)) {
if (contains(initial_nodes, v)) {
return alpha;
} else {
return component_nodes.at_l(t(v));
}
};

auto dst_for_node = [&](Node const &v) -> Node {
if (contains(sinks, v)) {
if (contains(terminal_nodes, v)) {
return omega;
} else {
return component_nodes.at_l(h(v));
Expand Down
6 changes: 3 additions & 3 deletions lib/utils/src/utils/graph/digraph/algorithms/is_acyclic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ std::optional<bool> is_acyclic(DiGraphView const &g) {
if (num_nodes(g) == 0) {
return std::nullopt;
}
std::unordered_set<Node> sources = get_sources(g);
if (sources.size() == 0) {
std::unordered_set<Node> initial_nodes = get_initial_nodes(g);
if (initial_nodes.size() == 0) {
return false;
}
auto dfs_view = unchecked_dfs(g, sources);
auto dfs_view = unchecked_dfs(g, initial_nodes);
std::unordered_set<Node> seen;
for (unchecked_dfs_iterator it = dfs_view.begin(); it != dfs_view.end();
it++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
#include "utils/containers/group_by.h"
#include "utils/graph/multidigraph/algorithms/get_edges.h"
#include "utils/graph/multidigraph/multidiedge.dtg.h"
#include "utils/graph/multidigraph/multidiedge_query.dtg.h"
#include "utils/graph/node/algorithms.h"
#include "utils/graph/query_set.h"

namespace FlexFlow {

Expand All @@ -12,12 +14,14 @@ std::unordered_set<MultiDiEdge> get_incoming_edges(MultiDiGraphView const &g,
}

std::unordered_map<Node, std::unordered_set<MultiDiEdge>>
get_incoming_edges(MultiDiGraphView const &g) {
get_incoming_edges(MultiDiGraphView const &g,
std::unordered_set<Node> const &ns) {
std::unordered_map<Node, std::unordered_set<MultiDiEdge>> result =
group_by(get_edges(g),
group_by(g.query_edges(MultiDiEdgeQuery{query_set<Node>::matchall(),
query_set<Node>{ns}}),
[&](MultiDiEdge const &e) { return g.get_multidiedge_dst(e); });

for (Node const &n : get_nodes(g)) {
for (Node const &n : ns) {
result[n];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "utils/containers/group_by.h"
#include "utils/graph/multidigraph/algorithms/get_edges.h"
#include "utils/graph/node/algorithms.h"
#include <unordered_set>
namespace FlexFlow {

std::unordered_set<MultiDiEdge> get_outgoing_edges(MultiDiGraphView const &g,
Expand All @@ -10,12 +11,14 @@ std::unordered_set<MultiDiEdge> get_outgoing_edges(MultiDiGraphView const &g,
}

std::unordered_map<Node, std::unordered_set<MultiDiEdge>>
get_outgoing_edges(MultiDiGraphView const &g) {
get_outgoing_edges(MultiDiGraphView const &g,
std::unordered_set<Node> const &ns) {
std::unordered_map<Node, std::unordered_set<MultiDiEdge>> result =
group_by(get_edges(g),
group_by(g.query_edges(MultiDiEdgeQuery{query_set<Node>{ns},
query_set<Node>::matchall()}),
[&](MultiDiEdge const &e) { return g.get_multidiedge_src(e); });

for (Node const &n : get_nodes(g)) {
for (Node const &n : ns) {
result[n];
}

Expand Down
Loading

0 comments on commit 3cc3407

Please sign in to comment.