From 40019265b96bfa3f1bbc3cffded483f58d0abd73 Mon Sep 17 00:00:00 2001 From: easyeasydev Date: Wed, 22 Jan 2025 21:07:53 -0500 Subject: [PATCH] Address review comments --- lib/models/include/models/dlrm/dlrm.h | 1 + .../dlrm/dlrm_arch_interaction_op.enum.toml | 10 +++++ .../models/dlrm/dlrm_config.struct.toml | 3 +- lib/models/src/models/dlrm/dlrm.cc | 44 ++++++++++--------- lib/models/test/src/models/dlrm/dlrm.cc | 2 +- 5 files changed, 38 insertions(+), 22 deletions(-) create mode 100644 lib/models/include/models/dlrm/dlrm_arch_interaction_op.enum.toml diff --git a/lib/models/include/models/dlrm/dlrm.h b/lib/models/include/models/dlrm/dlrm.h index f89603ec47..c3443f3b9b 100644 --- a/lib/models/include/models/dlrm/dlrm.h +++ b/lib/models/include/models/dlrm/dlrm.h @@ -12,6 +12,7 @@ #ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_DLRM_H #define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_DLRM_H +#include "models/dlrm/dlrm_arch_interaction_op.dtg.h" #include "models/dlrm/dlrm_config.dtg.h" #include "pcg/computation_graph_builder.h" diff --git a/lib/models/include/models/dlrm/dlrm_arch_interaction_op.enum.toml b/lib/models/include/models/dlrm/dlrm_arch_interaction_op.enum.toml new file mode 100644 index 0000000000..e847537a82 --- /dev/null +++ b/lib/models/include/models/dlrm/dlrm_arch_interaction_op.enum.toml @@ -0,0 +1,10 @@ +namespace = "FlexFlow" +name = "DLRMArchInteractionOp" + +features = ["hash", "json", "rapidcheck", "fmt"] + +[[values]] +name = "DOT" + +[[values]] +name = "CAT" diff --git a/lib/models/include/models/dlrm/dlrm_config.struct.toml b/lib/models/include/models/dlrm/dlrm_config.struct.toml index 52488b3229..cecdd8f8a0 100644 --- a/lib/models/include/models/dlrm/dlrm_config.struct.toml +++ b/lib/models/include/models/dlrm/dlrm_config.struct.toml @@ -13,6 +13,7 @@ features = [ includes = [ "", "", + "models/dlrm/dlrm_arch_interaction_op.dtg.h", ] src_includes = [ @@ -42,7 +43,7 @@ type = "std::vector" [[fields]] name = "arch_interaction_op" -type = "std::string" +type = "FlexFlow::DLRMArchInteractionOp" [[fields]] name = "batch_size" diff --git a/lib/models/src/models/dlrm/dlrm.cc b/lib/models/src/models/dlrm/dlrm.cc index 74cf486e2a..5454e517e9 100644 --- a/lib/models/src/models/dlrm/dlrm.cc +++ b/lib/models/src/models/dlrm/dlrm.cc @@ -1,6 +1,7 @@ #include "models/dlrm/dlrm.h" #include "pcg/computation_graph.h" #include "utils/containers/concat_vectors.h" +#include "utils/containers/repeat.h" #include "utils/containers/transform.h" #include "utils/containers/zip.h" @@ -29,7 +30,7 @@ DLRMConfig get_default_dlrm_config() { 64, 2, }, - /*arch_interaction_op=*/"cat", + /*arch_interaction_op=*/DLRMArchInteractionOp::CAT, /*batch_size=*/64, /*seed=*/std::rand(), }; @@ -97,12 +98,13 @@ tensor_guid_t create_dlrm_interact_features( DLRMConfig const &config, tensor_guid_t const &bottom_mlp_output, std::vector const &emb_outputs) { - if (config.arch_interaction_op != "cat") { + if (config.arch_interaction_op != DLRMArchInteractionOp::CAT) { throw mk_runtime_error(fmt::format( - "Currently only arch_interaction_op=cat is supported, but found " - "arch_interaction_op={}. If you need support for additional " + "Currently only arch_interaction_op=DLRMArchInteractionOp::CAT is " + "supported, but found arch_interaction_op={}. If you need support for " + "additional " "arch_interaction_op value, please create an issue.", - config.arch_interaction_op)); + format_as(config.arch_interaction_op))); } return cgb.concat( @@ -123,11 +125,13 @@ ComputationGraph get_dlrm_computation_graph(DLRMConfig const &config) { }; // Create input tensors - std::vector sparse_inputs( - config.embedding_size.size(), - create_input_tensor({static_cast(config.batch_size), - static_cast(config.embedding_bag_size)}, - DataType::INT64)); + std::vector sparse_inputs = + repeat(config.embedding_size.size(), [&]() { + return create_input_tensor( + {static_cast(config.batch_size), + static_cast(config.embedding_bag_size)}, + DataType::INT64); + }); tensor_guid_t dense_input = create_input_tensor( {static_cast(config.batch_size), @@ -141,16 +145,16 @@ ComputationGraph get_dlrm_computation_graph(DLRMConfig const &config) { /*input=*/dense_input, /*mlp_layers=*/config.dense_arch_layer_sizes); - std::vector emb_outputs; - for (size_t i = 0; i < config.embedding_size.size(); i++) { - int input_dim = config.embedding_size.at(i); - emb_outputs.emplace_back(create_dlrm_sparse_embedding_network( - /*cgb=*/cgb, - /*config=*/config, - /*input=*/sparse_inputs.at(i), - /*input_dim=*/input_dim, - /*output_dim=*/config.embedding_dim)); - } + std::vector emb_outputs = transform( + zip(config.embedding_size, sparse_inputs), + [&](std::pair const &combined_pair) -> tensor_guid_t { + return create_dlrm_sparse_embedding_network( + /*cgb=*/cgb, + /*config=*/config, + /*input=*/combined_pair.second, + /*input_dim=*/combined_pair.first, + /*output_dim=*/config.embedding_dim); + }); tensor_guid_t interacted_features = create_dlrm_interact_features( /*cgb=*/cgb, diff --git a/lib/models/test/src/models/dlrm/dlrm.cc b/lib/models/test/src/models/dlrm/dlrm.cc index 97c528254f..b01a383b36 100644 --- a/lib/models/test/src/models/dlrm/dlrm.cc +++ b/lib/models/test/src/models/dlrm/dlrm.cc @@ -12,7 +12,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("num layers") { int result_num_layers = get_layers(result).size(); - int correct_num_layers = 27; + int correct_num_layers = 30; CHECK(result_num_layers == correct_num_layers); } }