From 7b04ccbbc47a18b838fa04ba0cb5edb22c5d1acb Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Fri, 27 Dec 2024 14:34:03 -0800 Subject: [PATCH] task_simulator_forward_pass --- .../compiler/cost_estimator/task_simulator.h | 17 ++ .../timed_dependency.struct.toml | 21 ++ .../cost_estimator/timed_layer.struct.toml | 21 ++ .../machine_mapping/machine_mapping.h | 9 + .../unmapped_op_cost_estimate_key.h | 5 + .../src/compiler/allowed_machine_views.cc | 4 + .../compiler/cost_estimator/cost_estimator.cc | 1 + .../compiler/cost_estimator/task_simulator.cc | 174 +++++++++++++ .../machine_mapping/machine_mapping.cc | 29 ++- .../unmapped_op_cost_estimate_key.cc | 8 + .../compiler/cost_estimator/task_simulator.cc | 234 ++++++++++++++++++ .../cost_estimator_for_test.cc | 7 + .../cost_estimator_for_test.h | 3 + .../get_optimal_machine_mapping.cc | 2 +- .../get_tensor_set_movement_across_split.cc | 2 +- .../machine_mapping/machine_mapping.cc | 1 - lib/pcg/include/pcg/machine_specification.h | 5 + lib/pcg/include/pcg/machine_view.h | 8 + lib/pcg/include/pcg/operator_task_space.h | 3 + .../parallel_computation_graph.h | 15 ++ lib/pcg/src/pcg/machine_specification.cc | 9 + lib/pcg/src/pcg/machine_view.cc | 33 ++- lib/pcg/src/pcg/operator_task_space.cc | 12 +- .../parallel_computation_graph.cc | 33 +++ .../parallel_computation_graph.cc | 83 ++++++- lib/runtime/src/parallel_compuation_graph.cc | 7 - .../algorithms/get_outgoing_edges.h | 16 ++ .../algorithms/get_outgoing_edges.cc | 28 +++ 28 files changed, 773 insertions(+), 17 deletions(-) create mode 100644 lib/compiler/include/compiler/cost_estimator/task_simulator.h create mode 100644 lib/compiler/include/compiler/cost_estimator/timed_dependency.struct.toml create mode 100644 lib/compiler/include/compiler/cost_estimator/timed_layer.struct.toml create mode 100644 lib/compiler/src/compiler/cost_estimator/task_simulator.cc create mode 100644 lib/compiler/test/src/compiler/cost_estimator/task_simulator.cc rename lib/compiler/test/src/compiler/{machine_mapping => }/cost_estimator_for_test.cc (83%) rename lib/compiler/test/src/compiler/{machine_mapping => }/cost_estimator_for_test.h (91%) delete mode 100644 lib/runtime/src/parallel_compuation_graph.cc create mode 100644 lib/utils/include/utils/graph/dataflow_graph/algorithms/get_outgoing_edges.h create mode 100644 lib/utils/src/utils/graph/dataflow_graph/algorithms/get_outgoing_edges.cc diff --git a/lib/compiler/include/compiler/cost_estimator/task_simulator.h b/lib/compiler/include/compiler/cost_estimator/task_simulator.h new file mode 100644 index 0000000000..51549ba3f3 --- /dev/null +++ b/lib/compiler/include/compiler/cost_estimator/task_simulator.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_COST_ESTIMATOR_TASK_SIMULATOR_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_COST_ESTIMATOR_TASK_SIMULATOR_H + +#include "compiler/cost_estimator/cost_estimator.h" +#include "compiler/machine_mapping/machine_mapping.dtg.h" +#include "pcg/machine_specification.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" + +namespace FlexFlow { +float task_simulator_forward_pass(ParallelComputationGraph const &pcg, + CostEstimator const &estimator, + MachineMapping const &machine_mapping, + MachineSpecification const &machine_spec); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/cost_estimator/timed_dependency.struct.toml b/lib/compiler/include/compiler/cost_estimator/timed_dependency.struct.toml new file mode 100644 index 0000000000..3725a23bb3 --- /dev/null +++ b/lib/compiler/include/compiler/cost_estimator/timed_dependency.struct.toml @@ -0,0 +1,21 @@ + +namespace = "FlexFlow" +name = "TimedDependency" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h", +] + +[[fields]] +name = "raw_edge" +type = "::FlexFlow::ParallelComputationGraphEdge" + +[[fields]] +name = "endtime" +type = "float" diff --git a/lib/compiler/include/compiler/cost_estimator/timed_layer.struct.toml b/lib/compiler/include/compiler/cost_estimator/timed_layer.struct.toml new file mode 100644 index 0000000000..d4564840c9 --- /dev/null +++ b/lib/compiler/include/compiler/cost_estimator/timed_layer.struct.toml @@ -0,0 +1,21 @@ + +namespace = "FlexFlow" +name = "TimedLayer" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", +] + +[[fields]] +name = "layer" +type = "::FlexFlow::parallel_layer_guid_t" + +[[fields]] +name = "endtime" +type = "float" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping.h index 06cbbf942d..c20e28bc9c 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping.h @@ -2,6 +2,10 @@ #define _FLEXFLOW_COMPILER_MACHINE_MAPPING_H #include "compiler/machine_mapping/machine_mapping.dtg.h" +#include "pcg/device_id_t.dtg.h" +#include "pcg/machine_specification.dtg.h" +#include "pcg/operator_task_space.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" namespace FlexFlow { @@ -10,6 +14,11 @@ MachineMapping combine_disjoint_mappings(MachineMapping const &, bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2); +std::unordered_map> + get_device_mapping(MachineMapping const &machine_mapping, + MachineSpecification const &machine_spec, + ParallelComputationGraph const &pcg); + } // namespace FlexFlow #endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h index 9fbad4a1d0..63b6ed97f4 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h @@ -15,6 +15,11 @@ OpCostEstimateKey map_unmapped_op_cost_estimate_key(UnmappedOpCostEstimateKey const &unmapped, MachineView const &machine_view); +OpCostEstimateKey get_mapped_op_cost_estimate_key_for_layer( + ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &layer, + MachineView const &machine_view); + } // namespace FlexFlow #endif diff --git a/lib/compiler/src/compiler/allowed_machine_views.cc b/lib/compiler/src/compiler/allowed_machine_views.cc index 1c226f79b0..db7477b460 100644 --- a/lib/compiler/src/compiler/allowed_machine_views.cc +++ b/lib/compiler/src/compiler/allowed_machine_views.cc @@ -24,6 +24,10 @@ namespace FlexFlow { bool is_valid_machine_view(MachineView const &mv, OperatorTaskSpace const &task, MachineSpecification const &ms) { + if (num_dims(mv) != num_dims(task)) { + return false; + } + std::optional maximum_device_coord = get_machine_space_coordinate( task, mv, get_task_space_maximum_coordinate(task), ms); diff --git a/lib/compiler/src/compiler/cost_estimator/cost_estimator.cc b/lib/compiler/src/compiler/cost_estimator/cost_estimator.cc index 051ffcd190..2033bc1ca3 100644 --- a/lib/compiler/src/compiler/cost_estimator/cost_estimator.cc +++ b/lib/compiler/src/compiler/cost_estimator/cost_estimator.cc @@ -1,3 +1,4 @@ + #include "compiler/cost_estimator/cost_estimator.h" namespace FlexFlow { diff --git a/lib/compiler/src/compiler/cost_estimator/task_simulator.cc b/lib/compiler/src/compiler/cost_estimator/task_simulator.cc new file mode 100644 index 0000000000..aa168df049 --- /dev/null +++ b/lib/compiler/src/compiler/cost_estimator/task_simulator.cc @@ -0,0 +1,174 @@ +#include "compiler/cost_estimator/task_simulator.h" +#include "compiler/cost_estimator/cost_estimator.h" +#include "compiler/cost_estimator/single_tensor_movement.dtg.h" +#include "compiler/cost_estimator/tensor_set_movement.dtg.h" +#include "compiler/cost_estimator/timed_dependency.dtg.h" +#include "compiler/cost_estimator/timed_layer.dtg.h" +#include "compiler/machine_mapping/machine_mapping.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "pcg/device_id.h" +#include "pcg/device_id_t.dtg.h" +#include "pcg/machine_specification.h" +#include "pcg/machine_view.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" +#include "substitutions/sub_parallel_computation_graph.h" +#include "substitutions/sub_parallel_computation_graph_edge.dtg.h" +#include "utils/containers/all_of.h" +#include "utils/containers/contains.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/get_one_of.h" +#include "utils/containers/is_subseteq_of.h" +#include "utils/containers/keys.h" +#include "utils/containers/set_union.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/values.h" +#include "utils/deduplicated_priority_queue.h" +#include "utils/graph/dataflow_graph/algorithms/get_outgoing_edges.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_source_nodes.h" +#include "utils/hash/unordered_set.h" +#include +#include + +namespace FlexFlow { + +static float + single_parallel_layer_cost_estimator(parallel_layer_guid_t const &layer, + ParallelComputationGraph const &pcg, + CostEstimator const &estimator, + MachineView const &mv) { + return estimator.estimate_cost( + get_mapped_op_cost_estimate_key_for_layer(pcg, layer, mv)); +} + +static float single_dependency_cost_estimator( + ParallelComputationGraphEdge const &dependency, + ParallelComputationGraph const &pcg, + MachineMapping const &machine_mapping, + CostEstimator const &estimator) { + parallel_layer_guid_t incoming = get_src_layer(dependency); + parallel_layer_guid_t outgoing = get_dst_layer(dependency); + MachineView src_mv = machine_mapping.machine_views.at(incoming); + MachineView dst_mv = machine_mapping.machine_views.at(outgoing); + ParallelTensorShape tensor_shape = get_parallel_tensor_shape( + pcg, parallel_tensor_guid_t{dependency.raw_edge.src}); + TensorSetMovement movement = TensorSetMovement{ + {SingleTensorMovement{tensor_shape, {src_mv}, {dst_mv}}}}; + return estimator.estimate_cost(movement); +} + +float task_simulator_forward_pass(ParallelComputationGraph const &pcg, + CostEstimator const &estimator, + MachineMapping const &machine_mapping, + MachineSpecification const &machine_spec) { + + float current_time = 0.0f; + + std::unordered_set layer_frontier; + DeduplicatedPriorityQueue> + layer_processing; + std::unordered_set processed_layers; + + DeduplicatedPriorityQueue> + dependency_processing; + std::unordered_set processed_dependencies; + + std::unordered_map> + device_mapping = get_device_mapping(machine_mapping, machine_spec, pcg); + + std::unordered_map devices = + generate_map(set_union(values(device_mapping)), + [](device_id_t const &d) { return false; }); + + auto start_layer_processing = [&](parallel_layer_guid_t const &layer) { + float cost = single_parallel_layer_cost_estimator( + layer, pcg, estimator, machine_mapping.machine_views.at(layer)); + layer_processing.push(TimedLayer{layer, current_time + cost}); + for (device_id_t d : device_mapping.at(layer)) { + devices[d] = true; + } + layer_frontier.erase(layer); + }; + + auto start_dependency_processing = [&](ParallelComputationGraphEdge const + &dependency, + float start_time) { + float cost = single_dependency_cost_estimator( + dependency, pcg, machine_mapping, estimator); + dependency_processing.push(TimedDependency{dependency, start_time + cost}); + }; + + auto finish_layer_processing = [&](TimedLayer const &timed_layer) { + for (device_id_t d : device_mapping.at(timed_layer.layer)) { + devices[d] = false; + } + processed_layers.insert(timed_layer); + current_time = timed_layer.endtime; + std::unordered_set outgoing_dependencies = + get_outgoing_edges(pcg, timed_layer.layer); + for (ParallelComputationGraphEdge const &dep : outgoing_dependencies) { + start_dependency_processing(dep, timed_layer.endtime); + } + }; + + auto finish_dependency_processing = + [&](TimedDependency const &timed_dependency) { + processed_dependencies.insert(timed_dependency); + parallel_layer_guid_t destination_layer = + get_dst_layer(timed_dependency.raw_edge); + std::unordered_set incoming_dependencies = + get_incoming_edges(pcg, destination_layer); + std::unordered_set + non_timed_processed_dependencies = transform( + processed_dependencies, + [](TimedDependency const &dep) { return dep.raw_edge; }); + // start processing a new node if all dependencies have been processed + // already + if (is_subseteq_of(incoming_dependencies, + non_timed_processed_dependencies)) { + layer_frontier.insert(destination_layer); + } + current_time = timed_dependency.endtime; + }; + + for (parallel_layer_guid_t const &layer : get_source_layers(pcg)) { + layer_frontier.insert(layer); + } + + while (!layer_frontier.empty() || !layer_processing.empty() || + !dependency_processing.empty()) { + + auto frontier_copy = layer_frontier; + for (parallel_layer_guid_t const &layer : frontier_copy) { + auto layer_devices = device_mapping.at(layer); + if (all_of(layer_devices, + [&](device_id_t d) { return devices.at(d) == false; })) { + start_layer_processing(layer); + } + } + + while (!dependency_processing.empty()) { + TimedDependency dep = dependency_processing.top(); + dependency_processing.pop(); + finish_dependency_processing(dep); + } + + if (!layer_processing.empty()) { + TimedLayer layer = layer_processing.top(); + layer_processing.pop(); + finish_layer_processing(layer); + } + } + + return current_time; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc index 57e82684e9..47fe672a06 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc @@ -1,17 +1,40 @@ #include "compiler/machine_mapping/machine_mapping.h" +#include "pcg/machine_specification.h" +#include "pcg/machine_view.h" +#include "pcg/operator_task_space.dtg.h" +#include "pcg/operator_task_space.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "utils/containers/are_disjoint.h" +#include "utils/containers/get_one_of.h" #include "utils/containers/keys.h" +#include "utils/containers/map_values.h" #include "utils/containers/merge_maps.h" namespace FlexFlow { -MachineMapping combine_disjoint_mappings(MachineMapping const &s1, - MachineMapping const &s2) { - return MachineMapping{merge_maps(s1.machine_views, s2.machine_views)}; +MachineMapping combine_disjoint_mappings(MachineMapping const &m1, + MachineMapping const &m2) { + return MachineMapping{merge_maps(m1.machine_views, m2.machine_views)}; } bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2) { return are_disjoint(keys(m1.machine_views), keys(m2.machine_views)); } +std::unordered_map> + get_device_mapping(MachineMapping const &machine_mapping, + MachineSpecification const &machine_spec, + ParallelComputationGraph const &pcg) { + std::unordered_map> + device_mapping; + for (auto const &[layer, machine_view] : machine_mapping.machine_views) { + parallel_tensor_guid_t out_tensor = get_layer_outputs(pcg, layer).at(0); + OperatorTaskSpace op = + get_operator_task_space(get_parallel_tensor_shape(pcg, out_tensor)); + device_mapping.insert( + {layer, get_device_ids(op, machine_view, machine_spec)}); + } + return device_mapping; +} + } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.cc index 990b287f8b..e52a3e6456 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.cc @@ -33,4 +33,12 @@ OpCostEstimateKey }; } +OpCostEstimateKey get_mapped_op_cost_estimate_key_for_layer( + ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &layer, + MachineView const &machine_view) { + return map_unmapped_op_cost_estimate_key( + get_unmapped_op_cost_estimate_key_for_layer(pcg, layer), machine_view); +} + } // namespace FlexFlow diff --git a/lib/compiler/test/src/compiler/cost_estimator/task_simulator.cc b/lib/compiler/test/src/compiler/cost_estimator/task_simulator.cc new file mode 100644 index 0000000000..64cec763ea --- /dev/null +++ b/lib/compiler/test/src/compiler/cost_estimator/task_simulator.cc @@ -0,0 +1,234 @@ +#include "compiler/cost_estimator/task_simulator.h" +#include "../cost_estimator_for_test.h" +#include "compiler/cost_estimator/cost_estimator.h" +#include "compiler/cost_estimator/timed_layer.dtg.h" +#include "compiler/machine_mapping/machine_mapping.dtg.h" +#include "compiler/machine_mapping/machine_mapping.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" +#include "op-attrs/parallel_tensor_dims.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "pcg/device_id.h" +#include "pcg/device_type.dtg.h" +#include "pcg/machine_space_coordinate.dtg.h" +#include "pcg/machine_specification.h" +#include "pcg/machine_specification_dimension.dtg.h" +#include "pcg/machine_view.dtg.h" +#include "pcg/machine_view.h" +#include "pcg/machine_view_dimension.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" +#include "pcg/stride_t.dtg.h" +#include "substitutions/sub_parallel_computation_graph.dtg.h" +#include "substitutions/sub_parallel_computation_graph.h" +#include "utils/containers/get_only.h" +#include "utils/deduplicated_priority_queue.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_source_nodes.h" +#include +#include +#include +#include + +namespace FlexFlow { + +TEST_SUITE(FF_TEST_SUITE) { + CostEstimator estimator = + make_fake_constant_cost_estimator(/*op_cost*/ 10.0f, /*comm_cost*/ 1.0f); + MachineSpecification machine_spec = MachineSpecification{3, 3, 3, 1, 1}; + + TEST_CASE("task_simulator: linear graph") { + ParallelComputationGraphBuilder b; + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ShardParallelDim{10, 1}}, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + parallel_tensor_guid_t tensor0 = b.create_input_tensor(input_shape); + parallel_tensor_guid_t tensor1 = b.relu(tensor0); + + parallel_layer_guid_t layer0 = get_source_layer(tensor0); + parallel_layer_guid_t layer1 = get_source_layer(tensor1); + + ParallelComputationGraph pcg = b.pcg; + + std::unordered_set layers = {layer0, layer1}; + CHECK(get_parallel_layers(pcg) == layers); + + MachineView mv1 = MachineView{ + MachineSpaceCoordinate{0, 0, DeviceType::GPU}, + { + MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTER_NODE}, + }}; + MachineView mv2 = MachineView{ + MachineSpaceCoordinate{0, 1, DeviceType::GPU}, + { + MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTER_NODE}, + }}; + + MachineMapping device_mapping = MachineMapping{{ + {layer0, mv1}, + {layer1, mv2}, + }}; + + float result = task_simulator_forward_pass( + pcg, estimator, device_mapping, machine_spec); + float correct = 10 + 1 + 10; + CHECK(result == correct); + } + + TEST_CASE("task_simulator: rhombus graph, all separate devices") { + ParallelComputationGraphBuilder b; + + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ShardParallelDim{10, 1}}, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + parallel_tensor_guid_t tensor0 = b.create_input_tensor(input_shape); + parallel_tensor_guid_t tensor1 = b.relu(tensor0); + parallel_tensor_guid_t tensor2 = b.relu(tensor0); + parallel_tensor_guid_t tensor3 = b.add(tensor1, tensor2); + + parallel_layer_guid_t layer0 = get_source_layer(tensor0); + parallel_layer_guid_t layer1 = get_source_layer(tensor1); + parallel_layer_guid_t layer2 = get_source_layer(tensor2); + parallel_layer_guid_t layer3 = get_source_layer(tensor3); + + ParallelComputationGraph pcg = b.pcg; + + std::unordered_set layers = { + layer0, layer1, layer2, layer3}; + CHECK(get_parallel_layers(pcg) == layers); + + MachineView mv0 = MachineView{ + MachineSpaceCoordinate{0, 0, DeviceType::GPU}, + { + MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTER_NODE}, + }}; + MachineView mv1 = MachineView{ + MachineSpaceCoordinate{0, 1, DeviceType::GPU}, + { + MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTER_NODE}, + }}; + MachineView mv2 = MachineView{ + MachineSpaceCoordinate{1, 0, DeviceType::GPU}, + { + MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTER_NODE}, + }}; + MachineView mv3 = MachineView{ + MachineSpaceCoordinate{1, 1, DeviceType::GPU}, + { + MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTER_NODE}, + }}; + + MachineMapping device_mapping = MachineMapping{{ + {layer0, mv0}, + {layer1, mv1}, + {layer2, mv2}, + {layer3, mv3}, + }}; + + float result = task_simulator_forward_pass( + pcg, estimator, device_mapping, machine_spec); + float correct = 10 + 1 + 10 + 1 + 10; + CHECK(result == correct); + } + + TEST_CASE("task_simulator: rhombus graph, all same device") { + ParallelComputationGraphBuilder b; + + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ShardParallelDim{10, 1}}, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + parallel_tensor_guid_t tensor0 = b.create_input_tensor(input_shape); + parallel_tensor_guid_t tensor1 = b.relu(tensor0); + parallel_tensor_guid_t tensor2 = b.relu(tensor0); + parallel_tensor_guid_t tensor3 = b.add(tensor1, tensor2); + + parallel_layer_guid_t layer0 = get_source_layer(tensor0); + parallel_layer_guid_t layer1 = get_source_layer(tensor1); + parallel_layer_guid_t layer2 = get_source_layer(tensor2); + parallel_layer_guid_t layer3 = get_source_layer(tensor3); + + ParallelComputationGraph pcg = b.pcg; + + std::unordered_set layers = { + layer0, layer1, layer2, layer3}; + CHECK(get_parallel_layers(pcg) == layers); + + MachineView mv = MachineView{ + MachineSpaceCoordinate{0, 0, DeviceType::GPU}, + { + MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTER_NODE}, + }}; + MachineMapping device_mapping = MachineMapping{{ + {layer0, mv}, + {layer1, mv}, + {layer2, mv}, + {layer3, mv}, + }}; + + float result = task_simulator_forward_pass( + pcg, estimator, device_mapping, machine_spec); + float correct = 10 + 10 + 10 + 10 + 1 + 1; + CHECK(result == correct); + } +} +} // namespace FlexFlow diff --git a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc b/lib/compiler/test/src/compiler/cost_estimator_for_test.cc similarity index 83% rename from lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc rename to lib/compiler/test/src/compiler/cost_estimator_for_test.cc index 9ee596af3e..4804ac0c11 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc +++ b/lib/compiler/test/src/compiler/cost_estimator_for_test.cc @@ -38,4 +38,11 @@ CostEstimator make_fake_cost_estimator( }); } +CostEstimator make_fake_constant_cost_estimator(float const &op_cost, + float const &comm_cost) { + return make_fake_cost_estimator( + [=](OpCostEstimateKey const &op) { return op_cost; }, + [=](TensorSetMovement const &op) { return comm_cost; }); +} + } // namespace FlexFlow diff --git a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h b/lib/compiler/test/src/compiler/cost_estimator_for_test.h similarity index 91% rename from lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h rename to lib/compiler/test/src/compiler/cost_estimator_for_test.h index 7c1d06207a..b9e5fb809b 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h +++ b/lib/compiler/test/src/compiler/cost_estimator_for_test.h @@ -33,6 +33,9 @@ CostEstimator make_fake_cost_estimator( std::unordered_map const &op_cost_map, std::unordered_map const &comm_cost_map); +CostEstimator make_fake_constant_cost_estimator(float const &op_cost, + float const &comm_cost); + } // namespace FlexFlow #endif diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index a0d06fe930..353e0f3160 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -1,5 +1,5 @@ #include "compiler/machine_mapping/get_optimal_machine_mapping.h" -#include "./cost_estimator_for_test.h" +#include "../cost_estimator_for_test.h" #include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" #include "compiler/machine_mapping/machine_mapping_cache.h" #include "compiler/machine_mapping/machine_mapping_constraints.h" diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc b/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc index e22f715d82..52ad82595d 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc @@ -1,5 +1,5 @@ #include "compiler/machine_mapping/get_tensor_set_movement_across_split.h" -#include "./cost_estimator_for_test.h" +#include "../cost_estimator_for_test.h" #include "compiler/machine_mapping/transitive_reduced_pcg.h" #include "pcg/machine_view.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc index 221cca3ae1..304034f9be 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc @@ -1,5 +1,4 @@ #include "compiler/machine_mapping/machine_mapping.h" -#include "cost_estimator_for_test.h" #include "doctest/doctest.h" #include "pcg/machine_view.h" diff --git a/lib/pcg/include/pcg/machine_specification.h b/lib/pcg/include/pcg/machine_specification.h index 6ffa9900c2..404cafd5d7 100644 --- a/lib/pcg/include/pcg/machine_specification.h +++ b/lib/pcg/include/pcg/machine_specification.h @@ -20,6 +20,11 @@ bool is_valid_machine_space_coordinate(MachineSpecification const &ms, device_id_t get_device_id(MachineSpecification const &ms, MachineSpaceCoordinate const &coord); + +std::unordered_set + get_device_ids(MachineSpecification const &ms, + std::unordered_set const &coords); + } // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/machine_view.h b/lib/pcg/include/pcg/machine_view.h index 293227b7a1..f72b2359dc 100644 --- a/lib/pcg/include/pcg/machine_view.h +++ b/lib/pcg/include/pcg/machine_view.h @@ -37,6 +37,14 @@ std::unordered_set MachineView const &mv, MachineSpecification const &ms); +std::unordered_set get_device_ids(OperatorTaskSpace const &task, + MachineView const &mv, + MachineSpecification const &ms); + +MachineView make_1d_machine_view(MachineSpaceCoordinate const &start, + MachineSpecificationDimension const &dim, + stride_t stride); + } // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/operator_task_space.h b/lib/pcg/include/pcg/operator_task_space.h index 61cab4eff1..9916750d88 100644 --- a/lib/pcg/include/pcg/operator_task_space.h +++ b/lib/pcg/include/pcg/operator_task_space.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_OPERATOR_TASK_SPACE_H #define _FLEXFLOW_PCG_INCLUDE_OPERATOR_TASK_SPACE_H +#include "op-attrs/parallel_tensor_shape.dtg.h" #include "pcg/operator_task_space.dtg.h" #include "pcg/task_space_coordinate.dtg.h" #include @@ -17,6 +18,8 @@ TaskSpaceCoordinate size_t num_dims(OperatorTaskSpace const &task); size_t num_tasks(OperatorTaskSpace const &task); +OperatorTaskSpace get_operator_task_space(ParallelTensorShape const &shape); + } // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index c740e1ffd2..7f1a8c9750 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -6,6 +6,7 @@ #include "pcg/parallel_computation_graph/parallel_layer_added_result.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include namespace FlexFlow { @@ -31,6 +32,17 @@ std::unordered_set parallel_layer_guid_t const &, parallel_layer_guid_t const &); +std::unordered_set + get_outgoing_edges(ParallelComputationGraph const &, + parallel_layer_guid_t const &); + +std::unordered_set + get_incoming_edges(ParallelComputationGraph const &, + parallel_layer_guid_t const &); + +std::unordered_set + get_source_layers(ParallelComputationGraph const &); + std::vector get_incoming_tensors(ParallelComputationGraph const &, parallel_layer_guid_t const &); @@ -45,6 +57,9 @@ std::vector get_incoming_weights(ParallelComputationGraph const &, parallel_layer_guid_t const &); +parallel_layer_guid_t get_source_layer(ParallelComputationGraph const &g, + parallel_tensor_guid_t const &t); + ParallelLayerAttrs get_parallel_layer_attrs(ParallelComputationGraph const &, parallel_layer_guid_t const &); PCGOperatorAttrs pcg_get_op_attrs(ParallelComputationGraph const &, diff --git a/lib/pcg/src/pcg/machine_specification.cc b/lib/pcg/src/pcg/machine_specification.cc index ca5b8ba047..9c1d1db09c 100644 --- a/lib/pcg/src/pcg/machine_specification.cc +++ b/lib/pcg/src/pcg/machine_specification.cc @@ -1,5 +1,6 @@ #include "pcg/machine_specification.h" #include "pcg/device_id.h" +#include "utils/containers/transform.h" #include "utils/exception.h" namespace FlexFlow { @@ -50,4 +51,12 @@ device_id_t get_device_id(MachineSpecification const &ms, return device_id_from_index(raw_idx, coord.device_type); } +std::unordered_set + get_device_ids(MachineSpecification const &ms, + std::unordered_set const &coords) { + return transform(coords, [&](MachineSpaceCoordinate const &coord) { + return get_device_id(ms, coord); + }); +} + } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/machine_view.cc b/lib/pcg/src/pcg/machine_view.cc index 18f6cacb7e..8438b58df1 100644 --- a/lib/pcg/src/pcg/machine_view.cc +++ b/lib/pcg/src/pcg/machine_view.cc @@ -1,14 +1,20 @@ #include "pcg/machine_view.h" +#include "pcg/machine_specification.dtg.h" #include "pcg/machine_specification.h" +#include "pcg/machine_specification_dimension.dtg.h" +#include "pcg/machine_view_dimension.dtg.h" +#include "pcg/operator_task_space.dtg.h" #include "pcg/operator_task_space.h" +#include "pcg/stride_t.dtg.h" #include "utils/containers/contains.h" #include "utils/containers/count.h" #include "utils/containers/filter.h" +#include "utils/containers/get_only.h" #include "utils/containers/scanl.h" #include "utils/containers/sum.h" #include "utils/containers/transform.h" #include "utils/containers/zip.h" - +#include "utils/exception.h" namespace FlexFlow { size_t num_dims(MachineView const &mv) { @@ -35,6 +41,13 @@ MachineView machine_view_from_strides_and_machine_spec_dimensions( MachineSpaceCoordinate const &start, std::vector const &strides, std::vector const &dims) { + if (strides.size() != dims.size()) { + throw mk_runtime_error( + fmt::format("Dimensions of {} and {} must match when calling " + "machine_view_from_strides_and_machine_spec_dimensions", + start, + strides)); + } std::vector dimensions = transform(zip(strides, dims), [&](auto const &p) { return MachineViewDimension{p.first, p.second}; @@ -48,6 +61,7 @@ std::optional get_machine_space_coordinate( TaskSpaceCoordinate const &coord, MachineSpecification const &machine_specification) { + assert(num_dims(machine_view) == task.degrees.size()); auto get_dimension_indices_for_dimension = [&](MachineSpecificationDimension dimension) { std::vector mv_dimensions = @@ -112,4 +126,21 @@ std::unordered_set get_machine_space_coordinates( }); } +std::unordered_set get_device_ids(OperatorTaskSpace const &task, + MachineView const &mv, + MachineSpecification const &ms) { + return transform(get_machine_space_coordinates(task, mv, ms), + [&](MachineSpaceCoordinate const &coord) { + return get_device_id(ms, coord); + }); +} + +MachineView make_1d_machine_view(MachineSpaceCoordinate const &start, + MachineSpecificationDimension const &dim, + stride_t stride) { + + return machine_view_from_strides_and_machine_spec_dimensions( + start, {stride}, {dim}); +} + } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/operator_task_space.cc b/lib/pcg/src/pcg/operator_task_space.cc index 2538cb4ea0..165a4d7c24 100644 --- a/lib/pcg/src/pcg/operator_task_space.cc +++ b/lib/pcg/src/pcg/operator_task_space.cc @@ -1,12 +1,14 @@ #include "pcg/operator_task_space.h" +#include "op-attrs/parallel_tensor_shape.h" #include "utils/containers/cartesian_product.h" +#include "utils/containers/extend.h" #include "utils/containers/maximum.h" #include "utils/containers/product.h" #include "utils/containers/range.h" #include "utils/containers/transform.h" #include "utils/containers/unordered_set_of.h" +#include "utils/containers/vector_of.h" #include "utils/fmt/unordered_set.h" - namespace FlexFlow { std::unordered_set @@ -36,4 +38,12 @@ size_t num_tasks(OperatorTaskSpace const &task) { return product(task.degrees); } +OperatorTaskSpace get_operator_task_space(ParallelTensorShape const &shape) { + std::vector degrees; + extend(degrees, vector_of(ff_ordered_shard_degrees(shape))); + degrees.push_back(get_sum_degree(shape)); + degrees.push_back(get_discard_copy_degree(shape)); + return OperatorTaskSpace{degrees}; +} + } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index 781c44640c..2d08044cf2 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -1,15 +1,21 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "op-attrs/get_incoming_tensor_roles.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" #include "utils/containers/filtrans.h" #include "utils/containers/get_only.h" #include "utils/containers/transform.h" +#include "utils/containers/unordered_set_of.h" #include "utils/graph/dataflow_graph/algorithms.h" #include "utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h" +#include "utils/graph/dataflow_graph/algorithms/get_incoming_edges.h" +#include "utils/graph/dataflow_graph/algorithms/get_outgoing_edges.h" +#include "utils/graph/digraph/algorithms.h" #include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" #include "utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.h" #include "utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h" #include "utils/graph/node/algorithms.h" +#include "utils/graph/node/node.dtg.h" namespace FlexFlow { @@ -78,6 +84,33 @@ std::unordered_set }); } +std::unordered_set + get_outgoing_edges(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &l) { + std::unordered_set raw_edges = + get_outgoing_edges(pcg.raw_graph, l.raw_graph_node); + return transform(raw_edges, [](DataflowEdge const &e) { + return ParallelComputationGraphEdge{e}; + }); +} + +std::unordered_set + get_incoming_edges(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &l) { + std::unordered_set raw_edges = + unordered_set_of(get_incoming_edges(pcg.raw_graph, l.raw_graph_node)); + return transform(raw_edges, [](DataflowEdge const &e) { + return ParallelComputationGraphEdge{e}; + }); +} + +std::unordered_set + get_source_layers(ParallelComputationGraph const &pcg) { + std::unordered_set raw_sources = get_sources(pcg.raw_graph); + return transform(raw_sources, + [](Node const &n) { return parallel_layer_guid_t{n}; }); +} + std::vector get_incoming_tensors(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &l) { diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index fc07edf5b3..a36d3bc42a 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -36,8 +36,8 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t tensor3 = get_only(layer3_added.outputs); std::vector result = topological_ordering(pcg); - // std::vector correct = {layer1, layer2, layer3}; - // CHECK(result == correct); + std::vector correct = {layer1, layer2, layer3}; + CHECK(result == correct); } TEST_CASE( @@ -105,6 +105,85 @@ TEST_SUITE(FF_TEST_SUITE) { } } + TEST_CASE("get_source_layer") { + ParallelTensorShape tensor_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 2}, + ShardParallelDim{12, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAttrs layer_label = some(); + ParallelTensorAttrs tensor_label = some(); + + SUBCASE("single layer") { + ParallelLayerAddedResult layer1_added = + add_parallel_layer(pcg, layer_label, {}, {tensor_label}); + parallel_layer_guid_t layer1 = layer1_added.parallel_layer; + parallel_tensor_guid_t tensor1 = get_only(layer1_added.outputs); + + SUBCASE("get_source_layer") { + parallel_layer_guid_t result = get_source_layer(pcg, tensor1); + parallel_layer_guid_t correct = layer1; + CHECK(result == correct); + } + } + + SUBCASE("two connected layers") { + ParallelLayerAddedResult layer1_added = + add_parallel_layer(pcg, layer_label, {}, {tensor_label}); + parallel_layer_guid_t layer1 = layer1_added.parallel_layer; + parallel_tensor_guid_t tensor1 = get_only(layer1_added.outputs); + + ParallelLayerAddedResult layer2_added = + add_parallel_layer(pcg, layer_label, {tensor1}, {tensor_label}); + parallel_layer_guid_t layer2 = layer2_added.parallel_layer; + + SUBCASE("get_source_layer") { + parallel_layer_guid_t result = get_source_layer(pcg, tensor1); + parallel_layer_guid_t correct = layer1; + CHECK(result == correct); + } + } + + SUBCASE("three layers in serial") { + ParallelLayerAddedResult layer1_added = + add_parallel_layer(pcg, layer_label, {}, {tensor_label}); + parallel_layer_guid_t layer1 = layer1_added.parallel_layer; + parallel_tensor_guid_t tensor1 = get_only(layer1_added.outputs); + + ParallelLayerAddedResult layer2_added = + add_parallel_layer(pcg, layer_label, {tensor1}, {tensor_label}); + parallel_layer_guid_t layer2 = layer2_added.parallel_layer; + parallel_tensor_guid_t tensor2 = get_only(layer2_added.outputs); + + ParallelLayerAddedResult layer3_added = + add_parallel_layer(pcg, layer_label, {tensor2}, {tensor_label}); + parallel_layer_guid_t layer3 = layer3_added.parallel_layer; + + SUBCASE("get_source_layer - tensor 1") { + parallel_layer_guid_t result = get_source_layer(pcg, tensor1); + parallel_layer_guid_t correct = layer1; + CHECK(result == correct); + } + + SUBCASE("get_source_layer - tensor 2") { + parallel_layer_guid_t result = get_source_layer(pcg, tensor2); + parallel_layer_guid_t correct = layer2; + CHECK(result == correct); + } + } + } + TEST_CASE( "get_incoming_weights(ParallelComputationGraph, parallel_layer_guid_t)") { ParallelTensorShape input_shape = ParallelTensorShape{ diff --git a/lib/runtime/src/parallel_compuation_graph.cc b/lib/runtime/src/parallel_compuation_graph.cc deleted file mode 100644 index ebc5ac1e8e..0000000000 --- a/lib/runtime/src/parallel_compuation_graph.cc +++ /dev/null @@ -1,7 +0,0 @@ -#include "parallel_computation_graph.h" - -namespace FlexFlow { - -ParallelTensor ParallelComputationGraph::{} - -} // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_outgoing_edges.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_outgoing_edges.h new file mode 100644 index 0000000000..a8b5efe66e --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_outgoing_edges.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_GET_OUTGOING_EDGES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_GET_OUTGOING_EDGES_H + +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" + +namespace FlexFlow { + +std::unordered_set get_outgoing_edges(DataflowGraphView const &, + Node const &); +std::unordered_set + get_outgoing_edges(DataflowGraphView const &, + std::unordered_set const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_outgoing_edges.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_outgoing_edges.cc new file mode 100644 index 0000000000..2376e4897f --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_outgoing_edges.cc @@ -0,0 +1,28 @@ +#include "utils/graph/dataflow_graph/algorithms/get_outgoing_edges.h" +#include "utils/containers/sorted_by.h" + +namespace FlexFlow { + +std::unordered_set get_outgoing_edges(DataflowGraphView const &g, + Node const &n) { + return g.query_edges(DataflowEdgeQuery{ + {n}, + query_set::matchall(), + query_set::matchall(), + query_set::matchall(), + }); +} + +std::unordered_set + get_outgoing_edges(DataflowGraphView const &g, + std::unordered_set const &ns) { + DataflowEdgeQuery query = DataflowEdgeQuery{ + query_set{ns}, + query_set::matchall(), + query_set::matchall(), + query_set::matchall(), + }; + return g.query_edges(query); +} + +} // namespace FlexFlow