From 8564b8b979419ced5a67774c89777d380061e300 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Wed, 27 Nov 2024 11:38:54 -0800 Subject: [PATCH] get_series_parallel_decomposition fix --- lib/utils/include/utils/containers/find.h | 7 +++ .../series_parallel/parallel_reduction.h | 7 +++ .../series_parallel_decomposition.h | 19 ++++++ .../get_series_parallel_decomposition.cc | 61 +++++++++--------- .../series_parallel/parallel_reduction.cc | 63 +++++++++++++------ .../series_parallel_decomposition.cc | 61 ++++++++++++++++++ .../graph/series_parallel/series_reduction.cc | 33 ++++------ .../test/src/utils/containers/contains.cc | 15 ++++- 8 files changed, 192 insertions(+), 74 deletions(-) diff --git a/lib/utils/include/utils/containers/find.h b/lib/utils/include/utils/containers/find.h index eed5f8453c..7b103fed16 100644 --- a/lib/utils/include/utils/containers/find.h +++ b/lib/utils/include/utils/containers/find.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FIND_H #include +#include namespace FlexFlow { @@ -11,6 +12,12 @@ typename Container::const_iterator return std::find(c.cbegin(), c.cend(), e); } +template +typename std::unordered_set::const_iterator + find(std::unordered_set const &c, V const &e) { + return c.find(e); +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h b/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h index 3fc1347ee5..0b3c7f3619 100644 --- a/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h +++ b/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h @@ -12,8 +12,15 @@ ParallelReduction make_parallel_reduction(MultiDiEdge const &, std::optional find_parallel_reduction(MultiDiGraphView const &); +std::unordered_map> + find_all_extended_parallel_reductions(MultiDiGraphView const &); + MultiDiEdge apply_parallel_reduction(MultiDiGraph &, ParallelReduction const &); +MultiDiEdge + 
apply_extended_parallel_reduction(MultiDiGraph &, + std::unordered_set const &); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h index 52d2cb7236..d56d4a55f7 100644 --- a/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h @@ -17,6 +17,25 @@ std::unordered_multiset get_nodes(SeriesSplit const &); std::unordered_multiset get_nodes(ParallelSplit const &); std::unordered_multiset get_nodes(Node const &); +bool is_empty(Node const &node); +bool is_empty(SeriesSplit const &serial); +bool is_empty(ParallelSplit const &parallel); +bool is_empty(SeriesParallelDecomposition const &sp); + +bool has_no_duplicate_nodes(SeriesParallelDecomposition const &sp); + +SeriesParallelDecomposition delete_node(SeriesParallelDecomposition sp, + Node const &node); + +// duplicate nodes within `sp` are counted multiple times +size_t num_nodes(SeriesParallelDecomposition const &sp); + +SeriesParallelDecomposition serial_composition( + std::vector const &sp_compositions); +SeriesParallelDecomposition parallel_composition( + std::unordered_multiset const + &sp_compositions); + } // namespace FlexFlow #endif diff --git a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc index cd29af59a0..7a5cb1ea82 100644 --- a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc @@ -2,14 +2,18 @@ #include "utils/containers/get_only.h" #include "utils/containers/map_values.h" #include "utils/containers/transform.h" +#include "utils/containers/unordered_multiset_of.h" #include 
"utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.h" #include "utils/graph/digraph/algorithms/transitive_reduction.h" #include "utils/graph/instances/adjacency_multidigraph.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" +#include "utils/graph/multidigraph/multidiedge.dtg.h" #include "utils/graph/node/algorithms.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" #include "utils/graph/series_parallel/parallel_reduction.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" #include "utils/graph/series_parallel/series_parallel_decomposition.h" #include "utils/graph/series_parallel/series_reduction.h" @@ -26,39 +30,18 @@ std::optional if (!maybe_line_graph.has_value()) { return std::nullopt; } - maybe_line_graph.value(); }); MultiDiGraph ttsp = MultiDiGraph::materialize_copy_of( inverse_line_graph_result.graph); - std::unordered_map + std::unordered_map ttsp_edge_to_sp_tree = map_values( inverse_line_graph_result.inverse_edge_to_line_node_bidict .as_unordered_map(), - [](Node const &n) { return BinarySPDecompositionTree{n}; }); + [](Node const &n) { return SeriesParallelDecomposition{n}; }); while (true) { - assert(ttsp_edge_to_sp_tree.size() == get_edges(ttsp).size()); - std::optional maybe_parallel_reduction = - find_parallel_reduction(ttsp); - if (maybe_parallel_reduction.has_value()) { - ParallelReduction parallel_reduction = maybe_parallel_reduction.value(); - auto [e1, e2] = parallel_reduction.edges.ordered(); - MultiDiEdge merged = apply_parallel_reduction(ttsp, parallel_reduction); - BinarySPDecompositionTree new_tree = BinarySPDecompositionTree{ - BinaryParallelSplit{ - ttsp_edge_to_sp_tree.at(e1), - ttsp_edge_to_sp_tree.at(e2), - }, - }; - 
ttsp_edge_to_sp_tree.erase(e1); - ttsp_edge_to_sp_tree.erase(e2); - ttsp_edge_to_sp_tree.insert({merged, new_tree}); - - continue; - } - std::optional maybe_series_reduction = find_series_reduction(ttsp); if (maybe_series_reduction.has_value()) { @@ -66,15 +49,33 @@ std::optional MultiDiEdge e1 = series_reduction.first; MultiDiEdge e2 = series_reduction.second; MultiDiEdge merged = apply_series_reduction(ttsp, series_reduction); - BinarySPDecompositionTree new_tree = BinarySPDecompositionTree{ - BinarySeriesSplit{ - ttsp_edge_to_sp_tree.at(e1), - ttsp_edge_to_sp_tree.at(e2), - }, - }; + + SeriesParallelDecomposition new_tree = serial_composition({ + ttsp_edge_to_sp_tree.at(e1), + ttsp_edge_to_sp_tree.at(e2), + }); + ttsp_edge_to_sp_tree.erase(e1); ttsp_edge_to_sp_tree.erase(e2); ttsp_edge_to_sp_tree.insert({merged, new_tree}); + + continue; + } + std::unordered_map> + parallel_reductions = find_all_extended_parallel_reductions(ttsp); + if (!parallel_reductions.empty()) { + for (auto const &[_, parallel_reduction] : parallel_reductions) { + MultiDiEdge merged = + apply_extended_parallel_reduction(ttsp, parallel_reduction); + + SeriesParallelDecomposition new_tree = parallel_composition(transform( + unordered_multiset_of(parallel_reduction), + [&](MultiDiEdge const &e) { return ttsp_edge_to_sp_tree.at(e); })); + for (MultiDiEdge const &e : parallel_reduction) { + ttsp_edge_to_sp_tree.erase(e); + } + ttsp_edge_to_sp_tree.insert({merged, new_tree}); + } continue; } @@ -87,7 +88,7 @@ std::optional MultiDiEdge e = get_only(get_edges(ttsp)); if (ttsp.get_multidiedge_src(e) != ttsp.get_multidiedge_dst(e)) { - return nary_sp_tree_from_binary(ttsp_edge_to_sp_tree.at(e)); + return ttsp_edge_to_sp_tree.at(e); } } } diff --git a/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc index 78265f6856..c7eb866b62 100644 --- a/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc +++ 
b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc @@ -1,9 +1,17 @@ #include "utils/graph/series_parallel/parallel_reduction.h" -#include "utils/graph/multidigraph/algorithms/get_edge_counts.h" +#include "utils/containers/get_one_of.h" +#include "utils/containers/group_by.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/values.h" +#include "utils/graph/digraph/directed_edge.dtg.h" +#include "utils/graph/multidigraph/algorithms/get_directed_edge.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" -#include "utils/graph/multidigraph/algorithms/get_incoming_edges.h" -#include "utils/graph/multidigraph/algorithms/get_outgoing_edges.h" +#include "utils/graph/multidigraph/multidiedge.dtg.h" +#include "utils/graph/multidigraph/multidigraph.h" #include "utils/graph/node/algorithms.h" +#include "utils/hash/unordered_set.h" +#include +#include namespace FlexFlow { @@ -15,31 +23,48 @@ ParallelReduction make_parallel_reduction(MultiDiEdge const &e1, std::optional find_parallel_reduction(MultiDiGraphView const &g) { - for (auto const &[directed_edge, count] : get_edge_counts(g)) { - - if (count <= 1) { - continue; - } - - std::unordered_set const &outgoing_edges = - get_outgoing_edges(g, directed_edge.src); - for (MultiDiEdge const &e1 : outgoing_edges) { - for (MultiDiEdge const &e2 : outgoing_edges) { - if (e1 != e2 && - g.get_multidiedge_dst(e1) == g.get_multidiedge_dst(e2)) { - return make_parallel_reduction(e1, e2); - } - } + std::unordered_map seen; + for (MultiDiEdge const &edge : get_edges(g)) { + DirectedEdge diedge = get_directed_edge(g, edge); + if (seen.find(diedge) != seen.end()) { + return make_parallel_reduction(seen.at(diedge), edge); } + seen.emplace(diedge, edge); } - return std::nullopt; } +std::unordered_map> + find_all_extended_parallel_reductions(MultiDiGraphView const &g) { + std::unordered_map> + parallel_groups = group_by(get_edges(g), [&](MultiDiEdge const &edge) { + return get_directed_edge(g, 
edge); + }); + + return filter( + parallel_groups, + [](std::pair> const + &group) { return group.second.size() > 1; }); +} + MultiDiEdge apply_parallel_reduction(MultiDiGraph &g, ParallelReduction const &r) { g.remove_edge(r.edges.max()); return r.edges.min(); } +MultiDiEdge apply_extended_parallel_reduction( + MultiDiGraph &g, std::unordered_set const &parallel_edges) { + + MultiDiEdge keep_edge = get_one_of(parallel_edges); + + for (MultiDiEdge const &parallel_edge : parallel_edges) { + if (parallel_edge != keep_edge) { + g.remove_edge(parallel_edge); + } + } + + return keep_edge; +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc index b7a84b871a..dc99ef6c5a 100644 --- a/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc @@ -1,12 +1,17 @@ #include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include "utils/containers/all_of.h" +#include "utils/containers/extend.h" #include "utils/containers/multiset_union.h" #include "utils/containers/set_union.h" +#include "utils/containers/sum.h" #include "utils/containers/transform.h" #include "utils/containers/unordered_multiset_of.h" +#include "utils/containers/values.h" #include "utils/containers/vector_of.h" #include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.h" #include "utils/hash/unordered_set.h" #include "utils/variant.h" +#include namespace FlexFlow { @@ -74,4 +79,60 @@ std::unordered_multiset get_nodes(Node const &node) { return {node}; } +bool is_empty(Node const &node) { + return false; +} + +bool is_empty(SeriesSplit const &serial) { + return all_of(serial.children, [](auto const &child) { + return is_empty(widen(child)); + }); +} + +bool is_empty(ParallelSplit const &parallel) { + return all_of(parallel.get_children(), [](auto const &child) 
{ + return is_empty(widen(child)); + }); +} + +bool is_empty(SeriesParallelDecomposition const &sp) { + return sp.visit([](auto const &t) { return is_empty(t); }); +} + +SeriesParallelDecomposition serial_composition( + std::vector const &sp_compositions) { + std::vector> composition{}; + for (SeriesParallelDecomposition const &sp_comp : sp_compositions) { + if (sp_comp.has()) { + extend(composition, sp_comp.get().children); + } else if (sp_comp.has()) { + composition.push_back(sp_comp.get()); + } else { + assert(sp_comp.has()); + composition.push_back(sp_comp.get()); + } + } + return SeriesParallelDecomposition{SeriesSplit{composition}}; +} + +SeriesParallelDecomposition parallel_composition( + std::unordered_multiset const + &sp_compositions) { + std::unordered_multiset< + std::variant<::FlexFlow::SeriesSplit, ::FlexFlow::Node>> + composition{}; + for (SeriesParallelDecomposition const &sp_comp : sp_compositions) { + if (sp_comp.has()) { + composition = multiset_union(composition, + sp_comp.get().get_children()); + } else if (sp_comp.has()) { + composition.insert(sp_comp.get()); + } else { + assert(sp_comp.has()); + composition.insert(sp_comp.get()); + } + } + return SeriesParallelDecomposition(ParallelSplit{composition}); +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/series_reduction.cc b/lib/utils/src/utils/graph/series_parallel/series_reduction.cc index 7300c93fb0..c312bb4a6b 100644 --- a/lib/utils/src/utils/graph/series_parallel/series_reduction.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_reduction.cc @@ -1,8 +1,14 @@ #include "utils/graph/series_parallel/series_reduction.h" +#include "utils/containers/contains.h" +#include "utils/containers/get_only.h" #include "utils/containers/require_same.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" #include "utils/graph/multidigraph/algorithms/get_incoming_edges.h" #include "utils/graph/multidigraph/algorithms/get_outgoing_edges.h" +#include 
"utils/graph/multidigraph/multidigraph.h" +#include "utils/graph/multidigraph/multidigraph_view.h" +#include "utils/graph/node/algorithms.h" +#include namespace FlexFlow { @@ -26,30 +32,13 @@ SeriesReduction make_series_reduction(MultiDiEdge const &e1, std::optional find_series_reduction(MultiDiGraphView const &g) { - std::unordered_set edges = get_edges(g); - - for (MultiDiEdge const &e1 : edges) { - for (MultiDiEdge const &e2 : edges) { - if (e1 == e2) { - continue; - } - Node e1_dst = g.get_multidiedge_dst(e1); - Node e2_src = g.get_multidiedge_src(e2); - if (e1_dst != e2_src) { - continue; - } - - std::unordered_set outgoing = get_outgoing_edges(g, e1_dst); - std::unordered_set incoming = get_incoming_edges(g, e1_dst); - - if (outgoing.size() > 1 || incoming.size() > 1) { - continue; - } - - return SeriesReduction{e1, e2}; + for (Node const &node : get_nodes(g)) { + if (get_incoming_edges(g, node).size() == 1 && + get_outgoing_edges(g, node).size() == 1) { + return make_series_reduction(get_only(get_incoming_edges(g, node)), + get_only(get_outgoing_edges(g, node))); } } - return std::nullopt; } diff --git a/lib/utils/test/src/utils/containers/contains.cc b/lib/utils/test/src/utils/containers/contains.cc index 6e0a84c7ab..fc42d25eea 100644 --- a/lib/utils/test/src/utils/containers/contains.cc +++ b/lib/utils/test/src/utils/containers/contains.cc @@ -1,13 +1,22 @@ #include "utils/containers/contains.h" #include +#include #include using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("contains") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(contains(v, 3)); - CHECK(!contains(v, 6)); + SUBCASE("std::vector") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(contains(v, 3)); + CHECK(!contains(v, 6)); + } + + SUBCASE("std::unordered_set") { + std::unordered_set s = {1, 2, 3, 4, 5}; + CHECK(contains(s, 3)); + CHECK(!contains(s, 6)); + } } }