Skip to content

Commit

Permalink
get_series_parallel_decomposition fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Pietro Max Marsella committed Nov 27, 2024
1 parent e7055ad commit 8564b8b
Show file tree
Hide file tree
Showing 8 changed files with 192 additions and 74 deletions.
7 changes: 7 additions & 0 deletions lib/utils/include/utils/containers/find.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FIND_H

#include <algorithm>
#include <unordered_set>

namespace FlexFlow {

Expand All @@ -11,6 +12,12 @@ typename Container::const_iterator
return std::find(c.cbegin(), c.cend(), e);
}

template <typename V>
typename std::unordered_set<V>::const_iterator
find(std::unordered_set<V> const &c, V const &e) {
return c.find(e);
}

} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,15 @@ ParallelReduction make_parallel_reduction(MultiDiEdge const &,
std::optional<ParallelReduction>
find_parallel_reduction(MultiDiGraphView const &);

std::unordered_map<DirectedEdge, std::unordered_set<MultiDiEdge>>
find_all_extended_parallel_reductions(MultiDiGraphView const &);

MultiDiEdge apply_parallel_reduction(MultiDiGraph &, ParallelReduction const &);

MultiDiEdge
apply_extended_parallel_reduction(MultiDiGraph &,
std::unordered_set<MultiDiEdge> const &);

} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,25 @@ std::unordered_multiset<Node> get_nodes(SeriesSplit const &);
std::unordered_multiset<Node> get_nodes(ParallelSplit const &);
std::unordered_multiset<Node> get_nodes(Node const &);

bool is_empty(Node const &node);
bool is_empty(SeriesSplit const &serial);
bool is_empty(ParallelSplit const &parallel);
bool is_empty(SeriesParallelDecomposition const &sp);

bool has_no_duplicate_nodes(SeriesParallelDecomposition const &sp);

SeriesParallelDecomposition delete_node(SeriesParallelDecomposition sp,
Node const &node);

// duplicate nodes within `sp` are counted multiple times
size_t num_nodes(SeriesParallelDecomposition const &sp);

SeriesParallelDecomposition serial_composition(
std::vector<SeriesParallelDecomposition> const &sp_compositions);
SeriesParallelDecomposition parallel_composition(
std::unordered_multiset<SeriesParallelDecomposition> const
&sp_compositions);

} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@
#include "utils/containers/get_only.h"
#include "utils/containers/map_values.h"
#include "utils/containers/transform.h"
#include "utils/containers/unordered_multiset_of.h"
#include "utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.h"
#include "utils/graph/digraph/algorithms/transitive_reduction.h"
#include "utils/graph/instances/adjacency_multidigraph.h"
#include "utils/graph/multidigraph/algorithms/get_edges.h"
#include "utils/graph/multidigraph/multidiedge.dtg.h"
#include "utils/graph/node/algorithms.h"
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h"
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h"
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h"
#include "utils/graph/series_parallel/parallel_reduction.h"
#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h"
#include "utils/graph/series_parallel/series_parallel_decomposition.h"
#include "utils/graph/series_parallel/series_reduction.h"

Expand All @@ -26,55 +30,52 @@ std::optional<SeriesParallelDecomposition>
if (!maybe_line_graph.has_value()) {
return std::nullopt;
}

maybe_line_graph.value();
});

MultiDiGraph ttsp = MultiDiGraph::materialize_copy_of<AdjacencyMultiDiGraph>(
inverse_line_graph_result.graph);
std::unordered_map<MultiDiEdge, BinarySPDecompositionTree>
std::unordered_map<MultiDiEdge, SeriesParallelDecomposition>
ttsp_edge_to_sp_tree = map_values(
inverse_line_graph_result.inverse_edge_to_line_node_bidict
.as_unordered_map(),
[](Node const &n) { return BinarySPDecompositionTree{n}; });
[](Node const &n) { return SeriesParallelDecomposition{n}; });

while (true) {
assert(ttsp_edge_to_sp_tree.size() == get_edges(ttsp).size());
std::optional<ParallelReduction> maybe_parallel_reduction =
find_parallel_reduction(ttsp);
if (maybe_parallel_reduction.has_value()) {
ParallelReduction parallel_reduction = maybe_parallel_reduction.value();
auto [e1, e2] = parallel_reduction.edges.ordered();
MultiDiEdge merged = apply_parallel_reduction(ttsp, parallel_reduction);
BinarySPDecompositionTree new_tree = BinarySPDecompositionTree{
BinaryParallelSplit{
ttsp_edge_to_sp_tree.at(e1),
ttsp_edge_to_sp_tree.at(e2),
},
};
ttsp_edge_to_sp_tree.erase(e1);
ttsp_edge_to_sp_tree.erase(e2);
ttsp_edge_to_sp_tree.insert({merged, new_tree});

continue;
}

std::optional<SeriesReduction> maybe_series_reduction =
find_series_reduction(ttsp);
if (maybe_series_reduction.has_value()) {
SeriesReduction series_reduction = maybe_series_reduction.value();
MultiDiEdge e1 = series_reduction.first;
MultiDiEdge e2 = series_reduction.second;
MultiDiEdge merged = apply_series_reduction(ttsp, series_reduction);
BinarySPDecompositionTree new_tree = BinarySPDecompositionTree{
BinarySeriesSplit{
ttsp_edge_to_sp_tree.at(e1),
ttsp_edge_to_sp_tree.at(e2),
},
};

SeriesParallelDecomposition new_tree = serial_composition({
ttsp_edge_to_sp_tree.at(e1),
ttsp_edge_to_sp_tree.at(e2),
});

ttsp_edge_to_sp_tree.erase(e1);
ttsp_edge_to_sp_tree.erase(e2);
ttsp_edge_to_sp_tree.insert({merged, new_tree});

continue;
}
std::unordered_map<DirectedEdge, std::unordered_set<MultiDiEdge>>
parallel_reductions = find_all_extended_parallel_reductions(ttsp);
if (!parallel_reductions.empty()) {
for (auto const &[_, parallel_reduction] : parallel_reductions) {
MultiDiEdge merged =
apply_extended_parallel_reduction(ttsp, parallel_reduction);

SeriesParallelDecomposition new_tree = parallel_composition(transform(
unordered_multiset_of(parallel_reduction),
[&](MultiDiEdge const &e) { return ttsp_edge_to_sp_tree.at(e); }));
for (MultiDiEdge const &e : parallel_reduction) {
ttsp_edge_to_sp_tree.erase(e);
}
ttsp_edge_to_sp_tree.insert({merged, new_tree});
}
continue;
}

Expand All @@ -87,7 +88,7 @@ std::optional<SeriesParallelDecomposition>

MultiDiEdge e = get_only(get_edges(ttsp));
if (ttsp.get_multidiedge_src(e) != ttsp.get_multidiedge_dst(e)) {
return nary_sp_tree_from_binary(ttsp_edge_to_sp_tree.at(e));
return ttsp_edge_to_sp_tree.at(e);
}
}
}
Expand Down
63 changes: 44 additions & 19 deletions lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
#include "utils/graph/series_parallel/parallel_reduction.h"
#include "utils/graph/multidigraph/algorithms/get_edge_counts.h"
#include "utils/containers/get_one_of.h"
#include "utils/containers/group_by.h"
#include "utils/containers/unordered_set_of.h"
#include "utils/containers/values.h"
#include "utils/graph/digraph/directed_edge.dtg.h"
#include "utils/graph/multidigraph/algorithms/get_directed_edge.h"
#include "utils/graph/multidigraph/algorithms/get_edges.h"
#include "utils/graph/multidigraph/algorithms/get_incoming_edges.h"
#include "utils/graph/multidigraph/algorithms/get_outgoing_edges.h"
#include "utils/graph/multidigraph/multidiedge.dtg.h"
#include "utils/graph/multidigraph/multidigraph.h"
#include "utils/graph/node/algorithms.h"
#include "utils/hash/unordered_set.h"
#include <unordered_map>
#include <unordered_set>

namespace FlexFlow {

Expand All @@ -15,31 +23,48 @@ ParallelReduction make_parallel_reduction(MultiDiEdge const &e1,
std::optional<ParallelReduction>
find_parallel_reduction(MultiDiGraphView const &g) {

for (auto const &[directed_edge, count] : get_edge_counts(g)) {

if (count <= 1) {
continue;
}

std::unordered_set<MultiDiEdge> const &outgoing_edges =
get_outgoing_edges(g, directed_edge.src);
for (MultiDiEdge const &e1 : outgoing_edges) {
for (MultiDiEdge const &e2 : outgoing_edges) {
if (e1 != e2 &&
g.get_multidiedge_dst(e1) == g.get_multidiedge_dst(e2)) {
return make_parallel_reduction(e1, e2);
}
}
std::unordered_map<DirectedEdge, MultiDiEdge> seen;
for (MultiDiEdge const &edge : get_edges(g)) {
DirectedEdge diedge = get_directed_edge(g, edge);
if (seen.find(diedge) != seen.end()) {
return make_parallel_reduction(seen.at(diedge), edge);
}
seen.emplace(diedge, edge);
}

return std::nullopt;
}

std::unordered_map<DirectedEdge, std::unordered_set<MultiDiEdge>>
find_all_extended_parallel_reductions(MultiDiGraphView const &g) {
std::unordered_map<DirectedEdge, std::unordered_set<MultiDiEdge>>
parallel_groups = group_by(get_edges(g), [&](MultiDiEdge const &edge) {
return get_directed_edge(g, edge);
});

return filter(
parallel_groups,
[](std::pair<DirectedEdge, std::unordered_set<MultiDiEdge>> const
&group) { return group.second.size() > 1; });
}

MultiDiEdge apply_parallel_reduction(MultiDiGraph &g,
ParallelReduction const &r) {
g.remove_edge(r.edges.max());
return r.edges.min();
}

MultiDiEdge apply_extended_parallel_reduction(
MultiDiGraph &g, std::unordered_set<MultiDiEdge> const &parallel_edges) {

MultiDiEdge keep_edge = get_one_of(parallel_edges);

for (MultiDiEdge const &parallel_edge : parallel_edges) {
if (parallel_edge != keep_edge) {
g.remove_edge(parallel_edge);
}
}

return keep_edge;
}

} // namespace FlexFlow
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
#include "utils/graph/series_parallel/series_parallel_decomposition.h"
#include "utils/containers/all_of.h"
#include "utils/containers/extend.h"
#include "utils/containers/multiset_union.h"
#include "utils/containers/set_union.h"
#include "utils/containers/sum.h"
#include "utils/containers/transform.h"
#include "utils/containers/unordered_multiset_of.h"
#include "utils/containers/values.h"
#include "utils/containers/vector_of.h"
#include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.h"
#include "utils/hash/unordered_set.h"
#include "utils/variant.h"
#include <unordered_set>

namespace FlexFlow {

Expand Down Expand Up @@ -74,4 +79,60 @@ std::unordered_multiset<Node> get_nodes(Node const &node) {
return {node};
}

bool is_empty(Node const &node) {
return false;

Check warning on line 83 in lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc

View check run for this annotation

Codecov / codecov/patch

lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc#L82-L83

Added lines #L82 - L83 were not covered by tests
}

bool is_empty(SeriesSplit const &serial) {
return all_of(serial.children, [](auto const &child) {
return is_empty(widen<SeriesParallelDecomposition>(child));
});

Check warning on line 89 in lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc

View check run for this annotation

Codecov / codecov/patch

lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc#L86-L89

Added lines #L86 - L89 were not covered by tests
}

bool is_empty(ParallelSplit const &parallel) {
return all_of(parallel.get_children(), [](auto const &child) {
return is_empty(widen<SeriesParallelDecomposition>(child));
});

Check warning on line 95 in lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc

View check run for this annotation

Codecov / codecov/patch

lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc#L92-L95

Added lines #L92 - L95 were not covered by tests
}

bool is_empty(SeriesParallelDecomposition const &sp) {
return sp.visit<bool>([](auto const &t) { return is_empty(t); });

Check warning on line 99 in lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc

View check run for this annotation

Codecov / codecov/patch

lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc#L98-L99

Added lines #L98 - L99 were not covered by tests
}

SeriesParallelDecomposition serial_composition(
std::vector<SeriesParallelDecomposition> const &sp_compositions) {
std::vector<std::variant<ParallelSplit, Node>> composition{};
for (SeriesParallelDecomposition const &sp_comp : sp_compositions) {
if (sp_comp.has<SeriesSplit>()) {
extend(composition, sp_comp.get<SeriesSplit>().children);
} else if (sp_comp.has<ParallelSplit>()) {
composition.push_back(sp_comp.get<ParallelSplit>());
} else {
assert(sp_comp.has<Node>());
composition.push_back(sp_comp.get<Node>());
}
}
return SeriesParallelDecomposition{SeriesSplit{composition}};
}

SeriesParallelDecomposition parallel_composition(
std::unordered_multiset<SeriesParallelDecomposition> const
&sp_compositions) {
std::unordered_multiset<
std::variant<::FlexFlow::SeriesSplit, ::FlexFlow::Node>>
composition{};
for (SeriesParallelDecomposition const &sp_comp : sp_compositions) {
if (sp_comp.has<ParallelSplit>()) {
composition = multiset_union(composition,
sp_comp.get<ParallelSplit>().get_children());
} else if (sp_comp.has<SeriesSplit>()) {
composition.insert(sp_comp.get<SeriesSplit>());
} else {
assert(sp_comp.has<Node>());
composition.insert(sp_comp.get<Node>());
}
}
return SeriesParallelDecomposition(ParallelSplit{composition});
}

} // namespace FlexFlow
33 changes: 11 additions & 22 deletions lib/utils/src/utils/graph/series_parallel/series_reduction.cc
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
#include "utils/graph/series_parallel/series_reduction.h"
#include "utils/containers/contains.h"
#include "utils/containers/get_only.h"
#include "utils/containers/require_same.h"
#include "utils/graph/multidigraph/algorithms/get_edges.h"
#include "utils/graph/multidigraph/algorithms/get_incoming_edges.h"
#include "utils/graph/multidigraph/algorithms/get_outgoing_edges.h"
#include "utils/graph/multidigraph/multidigraph.h"
#include "utils/graph/multidigraph/multidigraph_view.h"
#include "utils/graph/node/algorithms.h"
#include <unordered_set>

namespace FlexFlow {

Expand All @@ -26,30 +32,13 @@ SeriesReduction make_series_reduction(MultiDiEdge const &e1,

std::optional<SeriesReduction>
find_series_reduction(MultiDiGraphView const &g) {
std::unordered_set<MultiDiEdge> edges = get_edges(g);

for (MultiDiEdge const &e1 : edges) {
for (MultiDiEdge const &e2 : edges) {
if (e1 == e2) {
continue;
}
Node e1_dst = g.get_multidiedge_dst(e1);
Node e2_src = g.get_multidiedge_src(e2);
if (e1_dst != e2_src) {
continue;
}

std::unordered_set<MultiDiEdge> outgoing = get_outgoing_edges(g, e1_dst);
std::unordered_set<MultiDiEdge> incoming = get_incoming_edges(g, e1_dst);

if (outgoing.size() > 1 || incoming.size() > 1) {
continue;
}

return SeriesReduction{e1, e2};
for (Node const &node : get_nodes(g)) {
if (get_incoming_edges(g, node).size() == 1 &&
get_outgoing_edges(g, node).size() == 1) {
return make_series_reduction(get_only(get_incoming_edges(g, node)),
get_only(get_outgoing_edges(g, node)));
}
}

return std::nullopt;
}

Expand Down
Loading

0 comments on commit 8564b8b

Please sign in to comment.