From 3cc3407dc6d05a84ef845888cffe12adba60cb30 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Sat, 11 Jan 2025 11:33:52 -0800 Subject: [PATCH] PR fixes --- lib/utils/include/utils/graph/algorithms.h | 4 +- .../include/utils/graph/digraph/algorithms.h | 4 +- .../algorithms/get_incoming_edges.h | 3 +- .../algorithms/get_outgoing_edges.h | 4 +- .../include/utils/graph/node/node.struct.toml | 1 - .../extended_parallel_reduction.struct.toml | 21 ++ .../extended_series_reduction.struct.toml | 21 ++ .../get_series_parallel_decomposition.h | 2 - .../series_parallel/parallel_reduction.h | 19 +- .../series_parallel_decomposition.h | 2 +- .../graph/series_parallel/series_reduction.h | 44 ++++- .../undirected/undirected_edge.struct.toml | 1 - .../src/utils/graph/digraph/algorithms.cc | 6 +- .../get_cbc_decomposition.cc | 4 +- .../is_complete_bipartite_digraph.cc | 2 +- .../digraph/algorithms/get_dominators_map.cc | 4 +- .../algorithms/get_topological_ordering.cc | 2 +- .../get_inverse_line_graph.cc | 8 +- .../graph/digraph/algorithms/is_acyclic.cc | 6 +- .../algorithms/get_incoming_edges.cc | 10 +- .../algorithms/get_outgoing_edges.cc | 9 +- .../algorithms/find_isomorphisms.cc | 8 +- .../get_series_parallel_decomposition.cc | 21 +- .../series_parallel/parallel_reduction.cc | 29 +-- .../series_parallel_decomposition.cc | 2 +- .../graph/series_parallel/series_reduction.cc | 40 ++-- lib/utils/src/utils/graph/views/views.cc | 7 +- .../test/src/utils/containers/contains.cc | 4 +- .../graph/digraph/algorithms/algorithms.cc | 184 ++++++++++-------- .../digraph/algorithms/directed_edge_query.cc | 98 +++++----- .../digraph/algorithms/get_dominators.cc | 41 ++-- .../graph/digraph/algorithms/traversal.cc | 111 ++++++----- .../algorithms/get_incoming_edges.cc | 34 +++- .../algorithms/get_outgoing_edges.cc | 48 +++-- .../graph/series_parallel/series_reduction.cc | 32 +-- .../algorithms/get_connected_components.cc | 117 +++++++---- .../src/utils/graph/undirected/undirected.cc | 32 --- 37 files changed, 587 insertions(+), 398 deletions(-) create mode 100644 lib/utils/include/utils/graph/series_parallel/extended_parallel_reduction.struct.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/extended_series_reduction.struct.toml diff --git a/lib/utils/include/utils/graph/algorithms.h b/lib/utils/include/utils/graph/algorithms.h index 3f170b5652..ff7a7dcad2 100644 --- a/lib/utils/include/utils/graph/algorithms.h +++ b/lib/utils/include/utils/graph/algorithms.h @@ -139,10 +139,10 @@ std::unordered_set get_neighbors(DiGraphView const &, Node const &); // &); // return the set of nodes without incoming edges -std::unordered_set get_sources(DiGraphView const &); +std::unordered_set get_initial_nodes(DiGraphView const &); // return the set of nodes without outgoing edges -std::unordered_set get_sinks(DiGraphView const &); +std::unordered_set get_terminal_nodes(DiGraphView const &); // std::unordered_set get_closed_sources(OpenMultiDiGraphView const &g); // std::unordered_set get_closed_sinks(OpenMultiDiGraphView const &g); diff --git a/lib/utils/include/utils/graph/digraph/algorithms.h b/lib/utils/include/utils/graph/digraph/algorithms.h index 370f181c37..fdced8a05c 100644 --- a/lib/utils/include/utils/graph/digraph/algorithms.h +++ b/lib/utils/include/utils/graph/digraph/algorithms.h @@ -6,8 +6,8 @@ namespace FlexFlow { std::unordered_set get_edges(DiGraphView const &); -std::unordered_set get_sources(DiGraphView const &); -std::unordered_set get_sinks(DiGraphView const &); +std::unordered_set get_initial_nodes(DiGraphView const &); +std::unordered_set get_terminal_nodes(DiGraphView const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/multidigraph/algorithms/get_incoming_edges.h b/lib/utils/include/utils/graph/multidigraph/algorithms/get_incoming_edges.h index 76be999b54..471a12a44b 100644 --- a/lib/utils/include/utils/graph/multidigraph/algorithms/get_incoming_edges.h +++ b/lib/utils/include/utils/graph/multidigraph/algorithms/get_incoming_edges.h @@ -9,7 +9,8 @@ std::unordered_set get_incoming_edges(MultiDiGraphView const &, Node const &); std::unordered_map> - get_incoming_edges(MultiDiGraphView const &g); + get_incoming_edges(MultiDiGraphView const &g, + std::unordered_set const &nodes); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/multidigraph/algorithms/get_outgoing_edges.h b/lib/utils/include/utils/graph/multidigraph/algorithms/get_outgoing_edges.h index 6a8474673e..bd8c364f7e 100644 --- a/lib/utils/include/utils/graph/multidigraph/algorithms/get_outgoing_edges.h +++ b/lib/utils/include/utils/graph/multidigraph/algorithms/get_outgoing_edges.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_MULTIDIGRAPH_ALGORITHMS_GET_OUTGOING_EDGES_H #include "utils/graph/multidigraph/multidigraph_view.h" +#include namespace FlexFlow { @@ -9,7 +10,8 @@ std::unordered_set get_outgoing_edges(MultiDiGraphView const &, Node const &); std::unordered_map> - get_outgoing_edges(MultiDiGraphView const &g); + get_outgoing_edges(MultiDiGraphView const &g, + std::unordered_set const &ns); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/node/node.struct.toml b/lib/utils/include/utils/graph/node/node.struct.toml index 46e0255de3..d5c22e5d3d 100644 --- a/lib/utils/include/utils/graph/node/node.struct.toml +++ b/lib/utils/include/utils/graph/node/node.struct.toml @@ -6,7 +6,6 @@ features = [ "hash", "fmt", "json", - "rapidcheck", ] includes = [ diff --git a/lib/utils/include/utils/graph/series_parallel/extended_parallel_reduction.struct.toml b/lib/utils/include/utils/graph/series_parallel/extended_parallel_reduction.struct.toml new file mode 100644 index 0000000000..9c1ed68730 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/extended_parallel_reduction.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "ExtendedParallelReduction" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/multidigraph/multidiedge.dtg.h", + "" +] + +src_includes = [ + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "edges" +type = "std::unordered_set<::FlexFlow::MultiDiEdge>" diff --git a/lib/utils/include/utils/graph/series_parallel/extended_series_reduction.struct.toml b/lib/utils/include/utils/graph/series_parallel/extended_series_reduction.struct.toml new file mode 100644 index 0000000000..f1cf0ccde3 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/extended_series_reduction.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "ExtendedSeriesReduction" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/multidigraph/multidiedge.dtg.h", + "" +] + +src_includes = [ + "utils/hash/vector.h", + "utils/fmt/vector.h", +] + +[[fields]] +name = "edges" +type = "std::vector<::FlexFlow::MultiDiEdge>" diff --git a/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h b/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h index f2a006d899..5f492c1aeb 100644 --- a/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h +++ b/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h @@ -4,8 +4,6 @@ #include "utils/graph/digraph/digraph.h" #include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" #include "utils/optional.h" -#include -#include namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h b/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h index 0b3c7f3619..7a3a7a021c 100644 --- a/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h +++ b/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h @@ -2,24 +2,39 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_PARALLEL_REDUCTION_H #include "utils/graph/multidigraph/multidigraph.h" +#include "utils/graph/series_parallel/extended_parallel_reduction.dtg.h" #include "utils/graph/series_parallel/parallel_reduction.dtg.h" #include +#include namespace FlexFlow { ParallelReduction make_parallel_reduction(MultiDiEdge const &, MultiDiEdge const &); + std::optional find_parallel_reduction(MultiDiGraphView const &); -std::unordered_map> +/** + * @brief Finds all ExtendedParallelReduction for a given MultiDiGraph + * @details An ExtendedParallelReduction is a unordered collection of + * `MultiDiEdge`s such that they share a common source and destination node. + */ +std::unordered_set find_all_extended_parallel_reductions(MultiDiGraphView const &); MultiDiEdge apply_parallel_reduction(MultiDiGraph &, ParallelReduction const &); +/** + * @brief Applies a given ExtendedParallelReduction in place to a given + * MultiDiGraph + * @details The reduction removes all but one `MultiDiEdge`, so that the source, + * destination nodes associated with the reduction become connected by a single + * edge. + */ MultiDiEdge apply_extended_parallel_reduction(MultiDiGraph &, - std::unordered_set const &); + ExtendedParallelReduction const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h index d56d4a55f7..b3fc201ca5 100644 --- a/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h @@ -30,7 +30,7 @@ SeriesParallelDecomposition delete_node(SeriesParallelDecomposition sp, // duplicate nodes within `sp` are counted multiple times size_t num_nodes(SeriesParallelDecomposition const &sp); -SeriesParallelDecomposition serial_composition( +SeriesParallelDecomposition series_composition( std::vector const &sp_compositions); SeriesParallelDecomposition parallel_composition( std::unordered_multiset const diff --git a/lib/utils/include/utils/graph/series_parallel/series_reduction.h b/lib/utils/include/utils/graph/series_parallel/series_reduction.h index 0de8aecc19..3e281066d4 100644 --- a/lib/utils/include/utils/graph/series_parallel/series_reduction.h +++ b/lib/utils/include/utils/graph/series_parallel/series_reduction.h @@ -3,6 +3,7 @@ #include "utils/graph/multidigraph/multidiedge.dtg.h" #include "utils/graph/multidigraph/multidigraph.h" +#include "utils/graph/series_parallel/extended_series_reduction.dtg.h" #include "utils/graph/series_parallel/series_reduction.dtg.h" #include "utils/hash/vector.h" @@ -15,13 +16,50 @@ Node get_center_node(MultiDiGraphView const &, SeriesReduction const &); SeriesReduction make_series_reduction(MultiDiEdge const &, MultiDiEdge const &); std::optional find_series_reduction(MultiDiGraphView const &); -std::unordered_set> +/** + * @brief Finds all the ExtendedSeriesReduction structures in a given graph. + * + * @details An `ExtendedSeriesReduction` is an ordered collection of + * `MultiDiEdges` such that: + * - The destination node of the nth edge is the same as the source node of the + * (n+1)th edge. + * - Such a node (intermediate node) has exactly two edges: one incoming (nth + * edge) and one outgoing ((n+1)th edge). + * + * For example, in the following graph: + * + * A -> B -> D -> E + * \ / + * -> C -> + * + * We have that [(A,B), (B,D), (D,E)] and [(A,C), (C,E)] both constitute + * `ExtendedSeriesReduction`. + */ +std::unordered_set find_all_extended_series_reductions(MultiDiGraphView const &g); MultiDiEdge apply_series_reduction(MultiDiGraph &, SeriesReduction const &); -MultiDiEdge apply_extended_series_reduction( - MultiDiGraph &g, std::vector const &series_edges); +/** + * @brief Applies a given ExtendedSeriesReduction in-place to a given graph. + * + * For example, in the following graph: + * + * A -> B -> D -> E + * \ / + * -> C -> + * + * Given the ExtendedSeriesReduction [(A,B), (B,D), (D,E)], the intermediate + *nodes B, D, will be deleted, and the resulting graph will be: + * + * A ----> E + * \ / + * -> C -> + * + **/ +MultiDiEdge + apply_extended_series_reduction(MultiDiGraph &g, + ExtendedSeriesReduction const &reduction); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml b/lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml index f5258b0bfd..0ad8232339 100644 --- a/lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml +++ b/lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml @@ -5,7 +5,6 @@ features = [ "ord", "hash", "fmt", - "rapidcheck" ] includes = [ diff --git a/lib/utils/src/utils/graph/digraph/algorithms.cc b/lib/utils/src/utils/graph/digraph/algorithms.cc index 8cd685e5c6..84798b2f62 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms.cc @@ -15,11 +15,11 @@ std::unordered_set get_edges(DiGraphView const &g) { return g.query_edges(directed_edge_query_all()); } -std::unordered_set get_sinks(DiGraphView const &g) { - return get_sources(flipped(g)); +std::unordered_set get_terminal_nodes(DiGraphView const &g) { + return get_initial_nodes(flipped(g)); } -std::unordered_set get_sources(DiGraphView const &g) { +std::unordered_set get_initial_nodes(DiGraphView const &g) { std::unordered_set all_nodes = get_nodes(g); std::unordered_set with_incoming_edge = transform(get_edges(g), [](DirectedEdge const &e) { return e.dst; }); diff --git a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc index 92bd1e32ca..9a2f9cb019 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc @@ -71,8 +71,8 @@ std::optional extend(already_in_a_tail, tail); } - assert(already_in_a_head == set_minus(get_nodes(g), get_sinks(g))); - assert(already_in_a_tail == set_minus(get_nodes(g), get_sources(g))); + assert(already_in_a_head == set_minus(get_nodes(g), get_terminal_nodes(g))); + assert(already_in_a_tail == set_minus(get_nodes(g), get_initial_nodes(g))); return result; } diff --git a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc index ccd2808603..bf428ed26b 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc @@ -6,7 +6,7 @@ namespace FlexFlow { bool is_complete_bipartite_digraph(DiGraphView const &g) { - return is_complete_bipartite_digraph(g, get_sources(g)); + return is_complete_bipartite_digraph(g, get_initial_nodes(g)); } bool is_complete_bipartite_digraph(DiGraphView const &g, diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_dominators_map.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_dominators_map.cc index 3dd9de73f0..1d909150cc 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_dominators_map.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_dominators_map.cc @@ -15,11 +15,11 @@ namespace FlexFlow { std::unordered_map> get_dominators_map(DiGraphView const &g) { - std::unordered_set sources = get_sources(g); + std::unordered_set initial_nodes = get_initial_nodes(g); std::queue queue; - for (Node src : get_sources(g)) { + for (Node src : get_initial_nodes(g)) { queue.push(src); } diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_topological_ordering.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_topological_ordering.cc index 41fe3b67d5..fea799b3e9 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_topological_ordering.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_topological_ordering.cc @@ -9,7 +9,7 @@ namespace FlexFlow { static std::vector get_unchecked_topological_ordering(DiGraphView const &g) { - auto dfs_view = unchecked_dfs(g, get_sources(g)); + auto dfs_view = unchecked_dfs(g, get_initial_nodes(g)); std::vector order; std::unordered_set seen; std::unordered_map> predecessors = diff --git a/lib/utils/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc b/lib/utils/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc index 77f04f2efd..ccf943c4d3 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc @@ -42,11 +42,11 @@ std::optional return get_component_containing_node_in_tail(cbc_decomposition, n).value(); }; - std::unordered_set sources = get_sources(view); - std::unordered_set sinks = get_sinks(view); + std::unordered_set initial_nodes = get_initial_nodes(view); + std::unordered_set terminal_nodes = get_terminal_nodes(view); auto src_for_node = [&](Node const &v) -> Node { - if (contains(sources, v)) { + if (contains(initial_nodes, v)) { return alpha; } else { return component_nodes.at_l(t(v)); @@ -54,7 +54,7 @@ std::optional }; auto dst_for_node = [&](Node const &v) -> Node { - if (contains(sinks, v)) { + if (contains(terminal_nodes, v)) { return omega; } else { return component_nodes.at_l(h(v)); diff --git a/lib/utils/src/utils/graph/digraph/algorithms/is_acyclic.cc b/lib/utils/src/utils/graph/digraph/algorithms/is_acyclic.cc index dd660f193d..018d07163d 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/is_acyclic.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/is_acyclic.cc @@ -9,11 +9,11 @@ std::optional is_acyclic(DiGraphView const &g) { if (num_nodes(g) == 0) { return std::nullopt; } - std::unordered_set sources = get_sources(g); - if (sources.size() == 0) { + std::unordered_set initial_nodes = get_initial_nodes(g); + if (initial_nodes.size() == 0) { return false; } - auto dfs_view = unchecked_dfs(g, sources); + auto dfs_view = unchecked_dfs(g, initial_nodes); std::unordered_set seen; for (unchecked_dfs_iterator it = dfs_view.begin(); it != dfs_view.end(); it++) { diff --git a/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc b/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc index 50818dea2f..7a5ba695f9 100644 --- a/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc +++ b/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc @@ -2,7 +2,9 @@ #include "utils/containers/group_by.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" #include "utils/graph/multidigraph/multidiedge.dtg.h" +#include "utils/graph/multidigraph/multidiedge_query.dtg.h" #include "utils/graph/node/algorithms.h" +#include "utils/graph/query_set.h" namespace FlexFlow { @@ -12,12 +14,14 @@ std::unordered_set get_incoming_edges(MultiDiGraphView const &g, } std::unordered_map> - get_incoming_edges(MultiDiGraphView const &g) { + get_incoming_edges(MultiDiGraphView const &g, + std::unordered_set const &ns) { std::unordered_map> result = - group_by(get_edges(g), + group_by(g.query_edges(MultiDiEdgeQuery{query_set::matchall(), + query_set{ns}}), [&](MultiDiEdge const &e) { return g.get_multidiedge_dst(e); }); - for (Node const &n : get_nodes(g)) { + for (Node const &n : ns) { result[n]; } diff --git a/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc b/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc index 55847cf2af..d183b44137 100644 --- a/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc +++ b/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc @@ -2,6 +2,7 @@ #include "utils/containers/group_by.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" #include "utils/graph/node/algorithms.h" +#include namespace FlexFlow { std::unordered_set get_outgoing_edges(MultiDiGraphView const &g, @@ -10,12 +11,14 @@ std::unordered_set get_outgoing_edges(MultiDiGraphView const &g, } std::unordered_map> - get_outgoing_edges(MultiDiGraphView const &g) { + get_outgoing_edges(MultiDiGraphView const &g, + std::unordered_set const &ns) { std::unordered_map> result = - group_by(get_edges(g), + group_by(g.query_edges(MultiDiEdgeQuery{query_set{ns}, + query_set::matchall()}), [&](MultiDiEdge const &e) { return g.get_multidiedge_src(e); }); - for (Node const &n : get_nodes(g)) { + for (Node const &n : ns) { result[n]; } diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc index fa17678943..375aaa3762 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc @@ -34,14 +34,14 @@ static std::optional { std::unordered_set already_mapped_src_nodes = left_entries(sink_node_mapping); - std::unordered_set src_g_sink_nodes = get_sinks(src_g); + std::unordered_set src_g_sink_nodes = get_terminal_nodes(src_g); assert(already_mapped_src_nodes == src_g_sink_nodes); } { std::unordered_set already_mapped_dst_nodes = right_entries(sink_node_mapping); - std::unordered_set dst_g_sink_nodes = get_sinks(dst_g); + std::unordered_set dst_g_sink_nodes = get_terminal_nodes(dst_g); assert(already_mapped_dst_nodes == dst_g_sink_nodes); } @@ -201,8 +201,8 @@ std::unordered_set OpenDataflowGraphView const &dst) { std::unordered_set result; - std::vector src_sink_nodes = vector_of(get_sinks(src)); - std::unordered_set dst_sink_nodes = get_sinks(dst); + std::vector src_sink_nodes = vector_of(get_terminal_nodes(src)); + std::unordered_set dst_sink_nodes = get_terminal_nodes(dst); if (src_sink_nodes.size() != dst_sink_nodes.size()) { return {}; diff --git a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc index 908743fae1..b45e62eae7 100644 --- a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc @@ -12,6 +12,7 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" +#include "utils/graph/series_parallel/extended_series_reduction.dtg.h" #include "utils/graph/series_parallel/parallel_reduction.h" #include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" #include "utils/graph/series_parallel/series_parallel_decomposition.h" @@ -45,19 +46,19 @@ std::optional while (true) { int reductions = 0; - std::unordered_map> - parallel_reductions = find_all_extended_parallel_reductions(ttsp); + std::unordered_set parallel_reductions = + find_all_extended_parallel_reductions(ttsp); if (!parallel_reductions.empty()) { - for (auto const &[_, parallel_reduction] : parallel_reductions) { + for (ExtendedParallelReduction parallel_reduction : parallel_reductions) { MultiDiEdge merged = apply_extended_parallel_reduction(ttsp, parallel_reduction); SeriesParallelDecomposition new_tree = parallel_composition(transform( - unordered_multiset_of(parallel_reduction), + unordered_multiset_of(parallel_reduction.edges), [&](MultiDiEdge const &e) { return ttsp_edge_to_sp_tree.at(e); })); - for (MultiDiEdge const &e : parallel_reduction) { + for (MultiDiEdge const &e : parallel_reduction.edges) { ttsp_edge_to_sp_tree.erase(e); } ttsp_edge_to_sp_tree.insert({merged, new_tree}); @@ -65,19 +66,19 @@ std::optional reductions++; } - std::unordered_set> series_reductions = + std::unordered_set series_reductions = find_all_extended_series_reductions(ttsp); if (!series_reductions.empty()) { - for (std::vector series_reduction : series_reductions) { + for (ExtendedSeriesReduction series_reduction : series_reductions) { MultiDiEdge merged = apply_extended_series_reduction(ttsp, series_reduction); - SeriesParallelDecomposition new_tree = serial_composition( - transform(series_reduction, [&](MultiDiEdge const &e) { + SeriesParallelDecomposition new_tree = series_composition( + transform(series_reduction.edges, [&](MultiDiEdge const &e) { return ttsp_edge_to_sp_tree.at(e); })); - for (MultiDiEdge const &e : series_reduction) { + for (MultiDiEdge const &e : series_reduction.edges) { ttsp_edge_to_sp_tree.erase(e); } ttsp_edge_to_sp_tree.insert({merged, new_tree}); diff --git a/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc index c7eb866b62..3aa677a2f7 100644 --- a/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc +++ b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc @@ -1,6 +1,7 @@ #include "utils/graph/series_parallel/parallel_reduction.h" #include "utils/containers/get_one_of.h" #include "utils/containers/group_by.h" +#include "utils/containers/transform.h" #include "utils/containers/unordered_set_of.h" #include "utils/containers/values.h" #include "utils/graph/digraph/directed_edge.dtg.h" @@ -9,6 +10,7 @@ #include "utils/graph/multidigraph/multidiedge.dtg.h" #include "utils/graph/multidigraph/multidigraph.h" #include "utils/graph/node/algorithms.h" +#include "utils/graph/series_parallel/extended_parallel_reduction.dtg.h" #include "utils/hash/unordered_set.h" #include #include @@ -34,17 +36,22 @@ std::optional return std::nullopt; } -std::unordered_map> +std::unordered_set find_all_extended_parallel_reductions(MultiDiGraphView const &g) { std::unordered_map> - parallel_groups = group_by(get_edges(g), [&](MultiDiEdge const &edge) { - return get_directed_edge(g, edge); - }); + reduction_groups; + for (MultiDiEdge const &edge : get_edges(g)) { + reduction_groups[get_directed_edge(g, edge)].insert(edge); + } + + std::unordered_set> reductions = filter( + unordered_set_of(values(reduction_groups)), + [](std::unordered_set const &s) { return s.size() > 1; }); - return filter( - parallel_groups, - [](std::pair> const - &group) { return group.second.size() > 1; }); + return transform(reductions, + [&](std::unordered_set const &edges) { + return ExtendedParallelReduction{edges}; + }); } MultiDiEdge apply_parallel_reduction(MultiDiGraph &g, @@ -54,11 +61,11 @@ MultiDiEdge apply_parallel_reduction(MultiDiGraph &g, } MultiDiEdge apply_extended_parallel_reduction( - MultiDiGraph &g, std::unordered_set const ¶llel_edges) { + MultiDiGraph &g, ExtendedParallelReduction const &reduction) { - MultiDiEdge keep_edge = get_one_of(parallel_edges); + MultiDiEdge keep_edge = get_one_of(reduction.edges); - for (MultiDiEdge const ¶llel_edge : parallel_edges) { + for (MultiDiEdge const ¶llel_edge : reduction.edges) { if (parallel_edge != keep_edge) { g.remove_edge(parallel_edge); } diff --git a/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc index dc99ef6c5a..937fc1254e 100644 --- a/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc @@ -99,7 +99,7 @@ bool is_empty(SeriesParallelDecomposition const &sp) { return sp.visit([](auto const &t) { return is_empty(t); }); } -SeriesParallelDecomposition serial_composition( +SeriesParallelDecomposition series_composition( std::vector const &sp_compositions) { std::vector> composition{}; for (SeriesParallelDecomposition const &sp_comp : sp_compositions) { diff --git a/lib/utils/src/utils/graph/series_parallel/series_reduction.cc b/lib/utils/src/utils/graph/series_parallel/series_reduction.cc index 26fabe593c..5b9b592444 100644 --- a/lib/utils/src/utils/graph/series_parallel/series_reduction.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_reduction.cc @@ -15,6 +15,7 @@ #include "utils/graph/multidigraph/multidigraph.h" #include "utils/graph/multidigraph/multidigraph_view.h" #include "utils/graph/node/algorithms.h" +#include "utils/graph/series_parallel/extended_series_reduction.dtg.h" #include "utils/hash/unordered_set.h" #include @@ -50,24 +51,28 @@ std::optional return std::nullopt; } -std::unordered_set> +std::unordered_set find_all_extended_series_reductions(MultiDiGraphView const &g) { - std::unordered_map> incoming_edges = - get_incoming_edges(g); - std::unordered_map> outgoing_edges = - get_outgoing_edges(g); + + auto incoming_edges_map = get_incoming_edges(g, get_nodes(g)); + auto outgoing_edges_map = get_outgoing_edges(g, get_nodes(g)); + std::unordered_map> strands; std::unordered_map node_to_head_of_strand; + for (Node const &n : get_topological_ordering(g)) { - if ((incoming_edges.at(n).size() == 1) && - (outgoing_edges.at(n).size() == 1)) { - MultiDiEdge incoming = get_only(incoming_edges.at(n)); - MultiDiEdge outgoing = get_only(outgoing_edges.at(n)); + if ((incoming_edges_map.at(n).size() == 1) && + (outgoing_edges_map.at(n).size() == 1)) { + + MultiDiEdge incoming = get_only(incoming_edges_map.at(n)); + MultiDiEdge outgoing = get_only(outgoing_edges_map.at(n)); Node pre = g.get_multidiedge_src(incoming); + if (contains_key(node_to_head_of_strand, pre)) { Node head = node_to_head_of_strand.at(pre); node_to_head_of_strand.emplace(n, head); strands.at(head).push_back(outgoing); + } else { node_to_head_of_strand.emplace(n, n); strands[n].push_back(incoming); @@ -75,7 +80,10 @@ std::unordered_set> } } } - return unordered_set_of(values(strands)); + + return transform(unordered_set_of(values(strands)), [&](auto const &edges) { + return ExtendedSeriesReduction{edges}; + }); } MultiDiEdge apply_series_reduction(MultiDiGraph &g, SeriesReduction const &r) { @@ -87,16 +95,18 @@ MultiDiEdge apply_series_reduction(MultiDiGraph &g, SeriesReduction const &r) { return g.add_edge(pre_node, post_node); } -MultiDiEdge apply_extended_series_reduction( - MultiDiGraph &g, std::vector const &series_edges) { +MultiDiEdge + apply_extended_series_reduction(MultiDiGraph &g, + ExtendedSeriesReduction const &reduction) { - Node first = g.get_multidiedge_src(series_edges.at(0)); - Node last = g.get_multidiedge_dst(series_edges.at(series_edges.size() - 1)); + Node first = g.get_multidiedge_src(reduction.edges.at(0)); + Node last = g.get_multidiedge_dst(reduction.edges.back()); std::vector internal_nodes; - for (MultiDiEdge const &e : subvec(series_edges, std::nullopt, -1)) { + for (MultiDiEdge const &e : subvec(reduction.edges, std::nullopt, -1)) { internal_nodes.push_back(g.get_multidiedge_dst(e)); } + for (Node const &n : internal_nodes) { g.remove_node(n); } diff --git a/lib/utils/src/utils/graph/views/views.cc b/lib/utils/src/utils/graph/views/views.cc index 7bb039d314..e8f0a443c4 100644 --- a/lib/utils/src/utils/graph/views/views.cc +++ b/lib/utils/src/utils/graph/views/views.cc @@ -1,12 +1,8 @@ #include "utils/graph/views/views.h" -#include "utils/bidict/algorithms/right_entries.h" #include "utils/containers/flatmap.h" #include "utils/containers/transform.h" -#include "utils/disjoint_set.h" -#include "utils/exception.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph/directed_edge_query.h" -#include "utils/graph/node/algorithms.h" #include "utils/graph/node/node_query.h" #include "utils/graph/query_set.h" #include "utils/graph/undirected/undirected_edge_query.h" @@ -118,8 +114,7 @@ ViewUndirectedGraphAsDiGraph *ViewUndirectedGraphAsDiGraph::clone() const { std::unordered_set ViewUndirectedGraphAsDiGraph::query_edges( DirectedEdgeQuery const &q) const { std::unordered_set undirected_edges = - set_union(g.query_edges(UndirectedEdgeQuery{q.srcs}), - g.query_edges(UndirectedEdgeQuery{q.dsts})); + g.query_edges(UndirectedEdgeQuery{query_union(q.srcs, q.dsts)}); std::unordered_set directed_edges = flatmap(undirected_edges, [](UndirectedEdge const &e) { return to_directed_edges(e); }); diff --git a/lib/utils/test/src/utils/containers/contains.cc b/lib/utils/test/src/utils/containers/contains.cc index fc42d25eea..9d686ab814 100644 --- a/lib/utils/test/src/utils/containers/contains.cc +++ b/lib/utils/test/src/utils/containers/contains.cc @@ -10,13 +10,13 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("std::vector") { std::vector v = {1, 2, 3, 4, 5}; CHECK(contains(v, 3)); - CHECK(!contains(v, 6)); + CHECK_FALSE(contains(v, 6)); } SUBCASE("std::unordered_set") { std::unordered_set s = {1, 2, 3, 4, 5}; CHECK(contains(s, 3)); - CHECK(!contains(s, 6)); + CHECK_FALSE(contains(s, 6)); } } } diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/algorithms.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/algorithms.cc index 0817c69e06..fd39449c2c 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/algorithms.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/algorithms.cc @@ -8,99 +8,119 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("DiGraph - algorithms.cc") { + TEST_CASE("get_edges(DiGraph)") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 4); std::vector e = { - DirectedEdge{n[0], n[1]}, - DirectedEdge{n[0], n[2]}, - DirectedEdge{n[0], n[3]}, - DirectedEdge{n[1], n[2]}, + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(2)}, }; add_edges(g, e); - SUBCASE("get_edges") { - SUBCASE("Base") { - std::unordered_set correct = unordered_set_of(e); - std::unordered_set result = get_edges(g); - CHECK(result == correct); - } - - SUBCASE("Adding an edge") { - g.add_edge(DirectedEdge{n[3], n[1]}); - std::unordered_set correct = { - DirectedEdge{n[0], n[1]}, - DirectedEdge{n[0], n[2]}, - DirectedEdge{n[0], n[3]}, - DirectedEdge{n[1], n[2]}, - DirectedEdge{n[3], n[1]}, - }; - std::unordered_set result = get_edges(g); - CHECK(result == correct); - } - - SUBCASE("Removing an edge") { - g.remove_edge(DirectedEdge{n[0], n[3]}); - std::unordered_set correct = { - DirectedEdge{n[0], n[1]}, - DirectedEdge{n[0], n[2]}, - DirectedEdge{n[1], n[2]}, - }; - std::unordered_set result = get_edges(g); - CHECK(result == correct); - } + SUBCASE("Base") { + std::unordered_set correct = unordered_set_of(e); + std::unordered_set result = get_edges(g); + CHECK(result == correct); } - SUBCASE("get_sinks") { - SUBCASE("Base") { - std::unordered_set correct = {n[2], n[3]}; - std::unordered_set result = get_sinks(g); - CHECK(result == correct); - } - - SUBCASE("Adding an edge to remove a sink") { - g.add_edge(DirectedEdge{n[3], n[2]}); - std::unordered_set correct = {n[2]}; - std::unordered_set result = get_sinks(g); - CHECK(result == correct); - } - - SUBCASE("Creating a cycle") { - g.add_edge(DirectedEdge{n[2], n[0]}); - std::unordered_set result = get_sinks(g); - std::unordered_set correct = {n[3]}; - CHECK(result == correct); - } + SUBCASE("Adding an edge") { + g.add_edge(DirectedEdge{n.at(3), n.at(1)}); + std::unordered_set correct = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(3), n.at(1)}, + }; + std::unordered_set result = get_edges(g); + CHECK(result == correct); } - SUBCASE("get_sources") { - SUBCASE("Base") { - std::unordered_set correct = {n[0]}; - std::unordered_set result = get_sources(g); - CHECK(result == correct); - } - - SUBCASE("Adding an edge to remove a source") { - g.add_edge(DirectedEdge{n[2], n[0]}); - std::unordered_set correct = {}; - std::unordered_set result = get_sources(g); - CHECK(result == correct); - } - - SUBCASE("Removing an edge to create a new source") { - g.remove_edge(DirectedEdge{n[0], n[1]}); - std::unordered_set correct = {n[0], n[1]}; - std::unordered_set result = get_sources(g); - CHECK(result == correct); - } - - SUBCASE("Creating a cycle") { - g.add_edge(DirectedEdge{n[2], n[0]}); - std::unordered_set result = get_sources(g); - std::unordered_set correct = {}; - CHECK(result.empty()); - } + SUBCASE("Removing an edge") { + g.remove_edge(DirectedEdge{n.at(0), n.at(3)}); + std::unordered_set correct = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + }; + std::unordered_set result = get_edges(g); + CHECK(result == correct); + } + } + + TEST_CASE("get_terminal_nodes(DiGraph)") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 4); + std::vector e = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(2)}, + }; + add_edges(g, e); + + SUBCASE("Base") { + std::unordered_set correct = {n.at(2), n.at(3)}; + std::unordered_set result = get_terminal_nodes(g); + CHECK(result == correct); + } + + SUBCASE("Adding an edge to remove a terminal node") { + g.add_edge(DirectedEdge{n.at(3), n.at(2)}); + std::unordered_set correct = {n.at(2)}; + std::unordered_set result = get_terminal_nodes(g); + CHECK(result == correct); + } + + SUBCASE("Creating a cycle") { + g.add_edge(DirectedEdge{n.at(2), n.at(0)}); + std::unordered_set result = get_terminal_nodes(g); + std::unordered_set correct = {n.at(3)}; + CHECK(result == correct); + } + } + + TEST_CASE("get_initial_nodes(DiGraph)") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 4); + std::vector e = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(2)}, + }; + add_edges(g, e); + + SUBCASE("Base") { + std::unordered_set correct = {n.at(0)}; + std::unordered_set result = get_initial_nodes(g); + CHECK(result == correct); + } + + SUBCASE("Adding an edge to remove a source") { + g.add_edge(DirectedEdge{n.at(2), n.at(0)}); + std::unordered_set correct = {}; + std::unordered_set result = get_initial_nodes(g); + CHECK(result == correct); + } + + SUBCASE("Removing an edge to create a new source") { + g.remove_edge(DirectedEdge{n.at(0), n.at(1)}); + std::unordered_set correct = {n.at(0), n.at(1)}; + std::unordered_set result = get_initial_nodes(g); + CHECK(result == correct); + } + + SUBCASE("Creating a cycle") { + g.add_edge(DirectedEdge{n.at(2), n.at(0)}); + std::unordered_set result = get_initial_nodes(g); + std::unordered_set correct = {}; + CHECK(result.empty()); } } } diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/directed_edge_query.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/directed_edge_query.cc index 1dde5c8f69..ee7ead009e 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/directed_edge_query.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/directed_edge_query.cc @@ -1,70 +1,68 @@ #include "utils/graph/digraph/directed_edge_query.h" #include "utils/graph/algorithms.h" -#include "utils/graph/digraph/algorithms/get_successors.h" -#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/digraph/directed_edge.dtg.h" +#include "utils/graph/digraph/directed_edge_query.dtg.h" #include using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("directed_edge_query") { - DiGraph g = DiGraph::create(); + TEST_CASE("directed_edge_query_all") { + Node n1{0}, n2{1}, n3{2}; + DirectedEdge e1 = DirectedEdge{n1, n2}; + DirectedEdge e2 = DirectedEdge{n2, n3}; - std::vector n = add_nodes(g, 5); + DirectedEdgeQuery result = directed_edge_query_all(); - add_edges(g, - {DirectedEdge{n.at(0), n.at(1)}, - DirectedEdge{n.at(0), n.at(2)}, - DirectedEdge{n.at(1), n.at(2)}, - DirectedEdge{n.at(2), n.at(4)}, - DirectedEdge{n.at(1), n.at(3)}}); + CHECK(matches_edge(result, e1)); + CHECK(matches_edge(result, e2)); + } - SUBCASE("directed_edge_query_all") { + TEST_CASE("matches_edge") { + Node n1{0}, n2{1}, n3{2}; + DirectedEdge e1 = DirectedEdge{n1, n2}; + DirectedEdge e2 = DirectedEdge{n2, n3}; - DirectedEdgeQuery result = directed_edge_query_all(); + DirectedEdgeQuery query = DirectedEdgeQuery{query_set{n1}, query_set{n2}}; - CHECK(matches_edge(result, DirectedEdge{n.at(0), n.at(1)})); - CHECK(matches_edge(result, DirectedEdge{n.at(0), n.at(2)})); - CHECK(matches_edge(result, DirectedEdge{n.at(1), n.at(2)})); - CHECK(matches_edge(result, DirectedEdge{n.at(2), n.at(4)})); - CHECK(matches_edge(result, DirectedEdge{n.at(1), n.at(3)})); - } + CHECK(matches_edge(query, e1)); + CHECK_FALSE(matches_edge(query, e2)); + + DirectedEdge flipped_edge = DirectedEdge{n2, n1}; + CHECK_FALSE(matches_edge(query, flipped_edge)); + } - SUBCASE("matches_edge") { - DirectedEdgeQuery q = - DirectedEdgeQuery{query_set{n.at(0)}, query_set{n.at(1)}}; + TEST_CASE("query_intersection") { + Node n1{0}, n2{1}, n3{2}, n4{3}; + DirectedEdge e1 = DirectedEdge{n1, n2}; + DirectedEdge e2 = DirectedEdge{n2, n3}; + DirectedEdge e3 = DirectedEdge{n3, n4}; - CHECK(matches_edge(q, DirectedEdge{n.at(0), n.at(1)})); - CHECK_FALSE(matches_edge(q, DirectedEdge{n.at(1), n.at(2)})); + SUBCASE("standard intersection") { + DirectedEdgeQuery q1 = + DirectedEdgeQuery{query_set{n1, n2}, query_set{n2, n3}}; + DirectedEdgeQuery q2 = + DirectedEdgeQuery{query_set{n2, n3}, query_set{n3, n4}}; + + DirectedEdgeQuery result = query_intersection(q1, q2); + DirectedEdgeQuery expected = + DirectedEdgeQuery{query_set{n2}, query_set{n3}}; + + CHECK(result == expected); } - SUBCASE("query_intersection") { - SUBCASE("standard intersection") { - DirectedEdgeQuery q1 = DirectedEdgeQuery{ - query_set{n.at(0), n.at(1)}, query_set{n.at(1), n.at(2), n.at(4)}}; - DirectedEdgeQuery q2 = DirectedEdgeQuery{query_set{n.at(1), n.at(2)}, - query_set{n.at(2), n.at(3)}}; - - DirectedEdgeQuery result = query_intersection(q1, q2); - DirectedEdgeQuery correct = DirectedEdgeQuery{ - query_set{n.at(1)}, - query_set{n.at(2)}, - }; - - CHECK(result == correct); - } - SUBCASE("intersection with std::nullopt") { - DirectedEdgeQuery q1 = - DirectedEdgeQuery{query_set{n.at(1), n.at(2)}, matchall()}; - DirectedEdgeQuery q2 = - DirectedEdgeQuery{matchall(), query_set{n.at(3), n.at(4)}}; - - DirectedEdgeQuery result = query_intersection(q1, q2); - DirectedEdgeQuery correct = DirectedEdgeQuery{ - query_set{n.at(1), n.at(2)}, query_set{n.at(3), n.at(4)}}; - CHECK(result == correct); - } + SUBCASE("intersection with matchall") { + DirectedEdgeQuery q1 = + DirectedEdgeQuery{query_set{n1, n2}, matchall()}; + DirectedEdgeQuery q2 = + DirectedEdgeQuery{matchall(), query_set{n3, n4}}; + + DirectedEdgeQuery result = query_intersection(q1, q2); + DirectedEdgeQuery expected = + DirectedEdgeQuery{query_set{n1, n2}, query_set{n3, n4}}; + + CHECK(result == expected); } } } diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc index e9151b53e5..17bea2210f 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc @@ -9,28 +9,29 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_dominators") { DiGraph g = DiGraph::create(); + SUBCASE("acyclic graph") { + std::vector n = add_nodes(g, 4); + std::vector e = { + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + }; + add_edges(g, e); - std::vector n = add_nodes(g, 4); - std::vector e = { - DirectedEdge{n.at(0), n.at(3)}, - DirectedEdge{n.at(0), n.at(1)}, - DirectedEdge{n.at(0), n.at(2)}, - DirectedEdge{n.at(1), n.at(2)}, - }; - add_edges(g, e); - - SUBCASE("single node") { - Node node = n.at(2); - std::unordered_set correct = {n.at(0), n.at(2)}; - std::unordered_set result = get_dominators(g, node); - CHECK(correct == result); - } + SUBCASE("get_dominators(DiGraph, Node)") { + Node node = n.at(2); + std::unordered_set correct = {n.at(0), n.at(2)}; + std::unordered_set result = get_dominators(g, node); + CHECK(correct == result); + } - SUBCASE("multiple nodes") { - std::unordered_set nodes = {n.at(1), n.at(3)}; - std::unordered_set result = get_dominators(g, nodes); - std::unordered_set correct = {n.at(0)}; - CHECK(correct == result); + SUBCASE("get_dominators(DiGraph, std::unordered_set)") { + std::unordered_set nodes = {n.at(1), n.at(3)}; + std::unordered_set result = get_dominators(g, nodes); + std::unordered_set correct = {n.at(0)}; + CHECK(correct == result); + } } SUBCASE("graph with cycles") { diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/traversal.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/traversal.cc index 0d8e7ca53a..f778cfbd22 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/traversal.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/traversal.cc @@ -1,4 +1,5 @@ #include "utils/graph/traversal.h" +#include "utils/containers/contains.h" #include "utils/fmt/vector.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph/digraph.h" @@ -12,40 +13,55 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_unchecked_dfs_ordering") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 4); - add_edges(g, - {DirectedEdge{n[0], n[1]}, - DirectedEdge{n[1], n[2]}, - DirectedEdge{n[2], n[3]}}); - SUBCASE("simple path") { - std::vector correct = {n[0], n[1], n[2], n[3]}; - std::vector result = get_unchecked_dfs_ordering(g, {n[0]}); + SUBCASE("linear path") { + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}}); + + std::vector correct = {n.at(0), n.at(1), n.at(2), n.at(3)}; + std::vector result = get_unchecked_dfs_ordering(g, {n.at(0)}); CHECK(correct == result); } + + SUBCASE("diamond path") { + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}}); + + std::unordered_set> corrects = { + {n.at(0), n.at(1), n.at(3), n.at(2), n.at(3)}, + {n.at(0), n.at(2), n.at(3), n.at(1), n.at(3)}}; + std::vector result = get_unchecked_dfs_ordering(g, {n.at(0)}); + CHECK(contains(corrects, result)); + } } TEST_CASE("get_bfs_ordering") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 6); add_edges(g, - {DirectedEdge{n[0], n[1]}, - DirectedEdge{n[0], n[2]}, - DirectedEdge{n[1], n[3]}, - DirectedEdge{n[2], n[3]}, - DirectedEdge{n[3], n[4]}, - DirectedEdge{n[4], n[5]}}); + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(4)}, + DirectedEdge{n.at(4), n.at(5)}}); SUBCASE("branching path") { std::unordered_set> corrects = { - {n[0], n[1], n[2], n[3], n[4], n[5]}, - {n[0], n[2], n[1], n[3], n[4], n[5]}}; - std::vector result = get_bfs_ordering(g, {n[0]}); + {n.at(0), n.at(1), n.at(2), n.at(3), n.at(4), n.at(5)}, + {n.at(0), n.at(2), n.at(1), n.at(3), n.at(4), n.at(5)}}; + std::vector result = get_bfs_ordering(g, {n.at(0)}); CHECK(contains(corrects, result)); } SUBCASE("isolated node") { - std::vector correct = {n[5]}; - std::vector result = get_bfs_ordering(g, {n[5]}); + std::vector correct = {n.at(5)}; + std::vector result = get_bfs_ordering(g, {n.at(5)}); CHECK(correct == result); } @@ -53,15 +69,15 @@ TEST_SUITE(FF_TEST_SUITE) { g = DiGraph::create(); n = add_nodes(g, 3); add_edges(g, - {DirectedEdge{n[0], n[1]}, - DirectedEdge{n[0], n[2]}, - DirectedEdge{n[1], n[0]}, - DirectedEdge{n[1], n[2]}, - DirectedEdge{n[2], n[0]}, - DirectedEdge{n[2], n[1]}}); - std::unordered_set> corrects = {{n[0], n[1], n[2]}, - {n[0], n[2], n[1]}}; - std::vector result = get_bfs_ordering(g, {n[0]}); + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(0)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(0)}, + DirectedEdge{n.at(2), n.at(1)}}); + std::unordered_set> corrects = { + {n.at(0), n.at(1), n.at(2)}, {n.at(0), n.at(2), n.at(1)}}; + std::vector result = get_bfs_ordering(g, {n.at(0)}); CHECK(contains(corrects, result)); } } @@ -70,42 +86,49 @@ TEST_SUITE(FF_TEST_SUITE) { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 4); add_edges(g, - {DirectedEdge{n[0], n[1]}, - DirectedEdge{n[1], n[2]}, - DirectedEdge{n[2], n[3]}}); + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}}); SUBCASE("simple path") { - std::vector correct = {n[0], n[1], n[2], n[3]}; - std::vector result = get_dfs_ordering(g, {n[0]}); + std::vector correct = {n.at(0), n.at(1), n.at(2), n.at(3)}; + std::vector result = get_dfs_ordering(g, {n.at(0)}); + CHECK(correct == result); + } + + SUBCASE("start from non-initial node") { + std::vector correct = {n.at(1), n.at(2), n.at(3)}; + std::vector result = get_unchecked_dfs_ordering(g, {n.at(1)}); CHECK(correct == result); } SUBCASE("with cycle") { - g.add_edge(DirectedEdge{n[3], n[1]}); - std::vector correct = {n[0], n[1], n[2], n[3]}; - std::vector result = get_dfs_ordering(g, {n[0]}); + g.add_edge(DirectedEdge{n.at(3), n.at(1)}); + std::vector correct = {n.at(0), n.at(1), n.at(2), n.at(3)}; + std::vector result = get_dfs_ordering(g, {n.at(0)}); CHECK(correct == result); } SUBCASE("branching") { - g.add_edge(DirectedEdge{n[1], n[3]}); + g.add_edge(DirectedEdge{n.at(1), n.at(3)}); std::unordered_set> corrects = { - {n[0], n[1], n[2], n[3]}, {n[0], n[1], n[3], n[2]}}; - std::vector result = get_dfs_ordering(g, {n[0]}); + {n.at(0), n.at(1), n.at(2), n.at(3)}, + {n.at(0), n.at(1), n.at(3), n.at(2)}}; + std::vector result = get_dfs_ordering(g, {n.at(0)}); CHECK(contains(corrects, result)); } SUBCASE("disconnected") { - g.remove_edge(DirectedEdge{n[2], n[3]}); - std::vector correct = {n[0], n[1], n[2]}; - std::vector result = get_dfs_ordering(g, {n[0]}); + g.remove_edge(DirectedEdge{n.at(2), n.at(3)}); + std::vector correct = {n.at(0), n.at(1), n.at(2)}; + std::vector result = get_dfs_ordering(g, {n.at(0)}); CHECK(correct == result); } SUBCASE("isolated node") { - g.remove_edge(DirectedEdge{n[2], n[3]}); - std::vector correct = {n[3]}; - std::vector result = get_dfs_ordering(g, {n[3]}); + g.remove_edge(DirectedEdge{n.at(2), n.at(3)}); + std::vector correct = {n.at(3)}; + std::vector result = get_dfs_ordering(g, {n.at(3)}); CHECK(correct == result); } } diff --git a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc index b5943cd99f..b15b8a9d7d 100644 --- a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc +++ b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc @@ -11,7 +11,7 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_incoming_edges(MultiDiGraphView, Node)") { + TEST_CASE("get_incoming_edges") { MultiDiGraph g = MultiDiGraph::create(); std::vector n = add_nodes(g, 3); @@ -19,17 +19,33 @@ TEST_SUITE(FF_TEST_SUITE) { {{n.at(0), n.at(0)}, {n.at(0), n.at(1)}, {n.at(0), n.at(1)}, - {n.at(1), n.at(0)}}); + {n.at(1), n.at(0)}, + {n.at(2), n.at(0)}}); - SUBCASE("node has incoming edges") { - std::unordered_set result = get_incoming_edges(g, n.at(1)); - std::unordered_set correct = {edges.at(1), edges.at(2)}; - CHECK(result == correct); + SUBCASE("get_incoming_edges(MultiDiGraphView, Node)") { + + SUBCASE("node has incoming edges") { + std::unordered_set result = get_incoming_edges(g, n.at(1)); + std::unordered_set correct = {edges.at(1), edges.at(2)}; + CHECK(result == correct); + } + + SUBCASE("node has no incoming edges") { + std::unordered_set result = get_incoming_edges(g, n.at(2)); + std::unordered_set correct = {}; + CHECK(result == correct); + } } - SUBCASE("node has no incoming edges") { - std::unordered_set result = get_incoming_edges(g, n.at(2)); - std::unordered_set correct = {}; + SUBCASE("get_incoming_edges(MultiDiGraphView, std::unordered_set)") { + + std::unordered_set ns = {n.at(0), n.at(2)}; + std::unordered_map> result = + get_incoming_edges(g, ns); + + std::unordered_map> correct = { + {n.at(0), {edges.at(0), edges.at(3), edges.at(4)}}, {n.at(2), {}}}; + CHECK(result == correct); } } diff --git a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc index d4748e8422..69b38090d3 100644 --- a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc +++ b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc @@ -11,29 +11,45 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_outgoing_edges(MultiDiGraph, Node)") { + TEST_CASE("get_outgoing_edges") { MultiDiGraph g = MultiDiGraph::create(); std::vector n = add_nodes(g, 3); - std::vector> input = { - {n.at(0), n.at(0)}, - {n.at(0), n.at(1)}, - {n.at(0), n.at(1)}, - {n.at(1), n.at(0)}, - }; + std::vector edges = add_edges(g, + { + {n.at(0), n.at(0)}, + {n.at(0), n.at(1)}, + {n.at(0), n.at(1)}, + {n.at(0), n.at(2)}, + {n.at(1), n.at(0)}, + }); - std::vector edges = add_edges(g, input); + SUBCASE("get_outgoing_edges(MultiDiGraphView, Node)") { - SUBCASE("node has outgoing edges") { - std::unordered_set result = get_outgoing_edges(g, n.at(0)); - std::unordered_set correct = { - edges.at(0), edges.at(1), edges.at(2)}; - CHECK(result == correct); + SUBCASE("node has outgoing edges") { + std::unordered_set result = get_outgoing_edges(g, n.at(0)); + std::unordered_set correct = { + edges.at(0), edges.at(1), edges.at(2), edges.at(3)}; + CHECK(result == correct); + } + + SUBCASE("node has no outgoing edges") { + std::unordered_set result = get_outgoing_edges(g, n.at(2)); + std::unordered_set correct = {}; + CHECK(result == correct); + } } - SUBCASE("node has no outgoing edges") { - std::unordered_set result = get_outgoing_edges(g, n.at(2)); - std::unordered_set correct = {}; + SUBCASE("get_outgoing_edges(MultiDiGraphView, std::unordered_set)") { + + std::unordered_set ns = {n.at(0), n.at(1)}; + std::unordered_map> result = + get_outgoing_edges(g, ns); + + std::unordered_map> correct = { + {n.at(0), {edges.at(0), edges.at(1), edges.at(2), edges.at(3)}}, + {n.at(1), {edges.at(4)}}}; + CHECK(result == correct); } } diff --git a/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc b/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc index 3a8a5e9a60..51606bc9d6 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc @@ -245,6 +245,7 @@ TEST_SUITE(FF_TEST_SUITE) { } } } + TEST_CASE("find_all_extended_series_reductions") { MultiDiGraph g = MultiDiGraph::create(); @@ -257,14 +258,14 @@ TEST_SUITE(FF_TEST_SUITE) { {n.at(2), n.at(3)}, }); - std::unordered_set> result = + std::unordered_set result = find_all_extended_series_reductions(g); - std::unordered_set> correct = { - {e[0], e[1], e[2]}}; + std::unordered_set correct = { + ExtendedSeriesReduction({e.at(0), e.at(1), e.at(2)})}; CHECK(result == correct); } - SUBCASE("2 linear strands") { + SUBCASE("2 linear strands with a common terminal node") { std::vector n = add_nodes(g, 4); std::vector e = add_edges(g, {{n.at(0), n.at(1)}, @@ -272,10 +273,11 @@ TEST_SUITE(FF_TEST_SUITE) { {n.at(1), n.at(3)}, {n.at(2), n.at(3)}}); - std::unordered_set> result = + std::unordered_set result = find_all_extended_series_reductions(g); - std::unordered_set> correct = {{e[0], e[2]}, - {e[1], e[3]}}; + std::unordered_set correct = { + ExtendedSeriesReduction({e.at(0), e.at(2)}), + ExtendedSeriesReduction({e.at(1), e.at(3)})}; CHECK(result == correct); } @@ -294,10 +296,12 @@ TEST_SUITE(FF_TEST_SUITE) { {n.at(6), n.at(8)}, {n.at(7), n.at(8)}}); - std::unordered_set> result = + std::unordered_set result = find_all_extended_series_reductions(g); - std::unordered_set> correct = { - {e[0], e[2], e[7]}, {e[3], e[6]}, {e[5], e[9]}}; + std::unordered_set correct = { + ExtendedSeriesReduction({e.at(0), e.at(2), e.at(7)}), + ExtendedSeriesReduction({e.at(3), e.at(6)}), + ExtendedSeriesReduction({e.at(5), e.at(9)})}; CHECK(result == correct); } } @@ -310,7 +314,7 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector e = add_edges( g, {{n.at(0), n.at(1)}, {n.at(1), n.at(2)}, {n.at(2), n.at(3)}}); - std::vector reduction = {e.at(0), e.at(1), e.at(2)}; + ExtendedSeriesReduction reduction({e.at(0), e.at(1), e.at(2)}); MultiDiEdge returned_edge = apply_extended_series_reduction(g, reduction); @@ -355,7 +359,7 @@ TEST_SUITE(FF_TEST_SUITE) { {n.at(5), n.at(7)}, }); - std::vector reduction = {e.at(3), e.at(4), e.at(5)}; + ExtendedSeriesReduction reduction({e.at(3), e.at(4), e.at(5)}); MultiDiEdge returned_edge = apply_extended_series_reduction(g, reduction); @@ -370,9 +374,7 @@ TEST_SUITE(FF_TEST_SUITE) { std::unordered_set result_edges = get_edges(g); std::unordered_set correct_edges = [&] { std::unordered_set new_edges = unordered_set_of(e); - new_edges.erase(e.at(3)); - new_edges.erase(e.at(4)); - new_edges.erase(e.at(5)); + new_edges = set_minus(new_edges, {e.at(3), e.at(4), e.at(5)}); new_edges.insert(returned_edge); return new_edges; }(); diff --git a/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc b/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc index 179cce7db7..e6b0575ff5 100644 --- a/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc +++ b/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc @@ -7,55 +7,86 @@ using namespace FlexFlow; -TEST_CASE("get_connected_components") { - UndirectedGraph g = UndirectedGraph::create(); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_connected_components") { + UndirectedGraph g = UndirectedGraph::create(); - SUBCASE("disjoint nodes") { - std::vector n = add_nodes(g, 3); + SUBCASE("disjoint nodes") { + std::vector n = add_nodes(g, 3); - std::unordered_set> correct = { - {n[0]}, - {n[1]}, - {n[2]}, - }; - std::unordered_set> result = - get_connected_components(g); + std::unordered_set> correct = { + {n.at(0)}, + {n.at(1)}, + {n.at(2)}, + }; + std::unordered_set> result = + get_connected_components(g); - CHECK(correct == result); - } + CHECK(correct == result); + } - SUBCASE("2 components") { - std::vector n = add_nodes(g, 4); - add_edges(g, {UndirectedEdge{{n[0], n[1]}}, UndirectedEdge{{n[2], n[1]}}}); + SUBCASE("1 component") { + std::vector n = add_nodes(g, 4); + add_edges(g, + { + UndirectedEdge{{n.at(0), n.at(1)}}, + UndirectedEdge{{n.at(1), n.at(2)}}, + UndirectedEdge{{n.at(2), n.at(3)}}, + UndirectedEdge{{n.at(3), n.at(0)}}, + }); - std::unordered_set> correct = { - {n[0], n[1], n[2]}, - {n[3]}, - }; - std::unordered_set> result = - get_connected_components(g); + std::unordered_set> correct = { + {n.at(0), n.at(1), n.at(2), n.at(3)}, + }; + std::unordered_set> result = + get_connected_components(g); - CHECK(correct == result); - } + CHECK(correct == result); + } + + SUBCASE("2 components") { + std::vector n = add_nodes(g, 4); + add_edges(g, + {UndirectedEdge{{n.at(0), n.at(1)}}, + UndirectedEdge{{n.at(2), n.at(1)}}}); + + std::unordered_set> correct = { + {n.at(0), n.at(1), n.at(2)}, + {n.at(3)}, + }; + std::unordered_set> result = + get_connected_components(g); + + CHECK(correct == result); + } + + SUBCASE("3 components") { + std::vector n = add_nodes(g, 6); + add_edges(g, + { + UndirectedEdge{{n.at(0), n.at(1)}}, + UndirectedEdge{{n.at(0), n.at(2)}}, + UndirectedEdge{{n.at(1), n.at(2)}}, + UndirectedEdge{{n.at(3), n.at(4)}}, + }); + + std::unordered_set> correct = { + {n.at(0), n.at(1), n.at(2)}, + {n.at(3), n.at(4)}, + {n.at(5)}, + }; + std::unordered_set> result = + get_connected_components(g); + + CHECK(correct == result); + } + + SUBCASE("empty graph") { + std::unordered_set> correct = {}; + std::unordered_set> result = + get_connected_components(g); - SUBCASE("3 components") { - std::vector n = add_nodes(g, 6); - add_edges(g, - { - UndirectedEdge{{n[0], n[1]}}, - UndirectedEdge{{n[0], n[2]}}, - UndirectedEdge{{n[1], n[2]}}, - UndirectedEdge{{n[3], n[4]}}, - }); - - std::unordered_set> correct = { - {n[0], n[1], n[2]}, - {n[3], n[4]}, - {n[5]}, - }; - std::unordered_set> result = - get_connected_components(g); - - CHECK(correct == result); + CHECK(correct == result); + } } } diff --git a/lib/utils/test/src/utils/graph/undirected/undirected.cc b/lib/utils/test/src/utils/graph/undirected/undirected.cc index 7973cf8af5..6454379118 100644 --- a/lib/utils/test/src/utils/graph/undirected/undirected.cc +++ b/lib/utils/test/src/utils/graph/undirected/undirected.cc @@ -41,35 +41,3 @@ TEST_SUITE(FF_TEST_SUITE) { }); } } -/* static_assert(is_fmtable::value, ""); */ - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE_TEMPLATE( - "UndirectedGraph implementations", T, HashmapUndirectedGraph) { - - RC_SUBCASE("Full", [&]() { - UndirectedGraph g = UndirectedGraph::create(); - int num_nodes = *gen::inRange(1, 10); - std::vector n = repeat(num_nodes, [&] { return g.add_node(); }); - int num_edges = *gen::inRange(0, num_nodes); - std::vector e; - if (num_nodes > 0) { - e = *gen::unique>( - num_edges, - gen::construct( - gen::construct>(gen::elementOf(n), - gen::elementOf(n)))); - } - for (UndirectedEdge const &edge : e) { - g.add_edge(edge); - } - - CHECK(g.query_nodes(node_query_all()) == unordered_set_of(n)); - - auto subset = *rc::subset_of(n); - CHECK(g.query_nodes(NodeQuery{query_set{subset}}) == subset); - - CHECK(g.query_edges(undirected_edge_query_all()) == unordered_set_of(e)); - }); - } -}