
Address review comments
easyeasydev committed Jan 23, 2025
1 parent 0e2e07f commit 4001926
Showing 5 changed files with 38 additions and 22 deletions.
1 change: 1 addition & 0 deletions lib/models/include/models/dlrm/dlrm.h
@@ -12,6 +12,7 @@
 #ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_DLRM_H
 #define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_DLRM_H
 
+#include "models/dlrm/dlrm_arch_interaction_op.dtg.h"
 #include "models/dlrm/dlrm_config.dtg.h"
 #include "pcg/computation_graph_builder.h"
 
10 changes: 10 additions & 0 deletions lib/models/include/models/dlrm/dlrm_arch_interaction_op.enum.toml
@@ -0,0 +1,10 @@
+namespace = "FlexFlow"
+name = "DLRMArchInteractionOp"
+
+features = ["hash", "json", "rapidcheck", "fmt"]
+
+[[values]]
+name = "DOT"
+
+[[values]]
+name = "CAT"
3 changes: 2 additions & 1 deletion lib/models/include/models/dlrm/dlrm_config.struct.toml
@@ -13,6 +13,7 @@ features = [
 includes = [
   "<vector>",
   "<string>",
+  "models/dlrm/dlrm_arch_interaction_op.dtg.h",
 ]
 
 src_includes = [
@@ -42,7 +43,7 @@ type = "std::vector<int>"
 
 [[fields]]
 name = "arch_interaction_op"
-type = "std::string"
+type = "FlexFlow::DLRMArchInteractionOp"
 
 [[fields]]
 name = "batch_size"
44 changes: 24 additions & 20 deletions lib/models/src/models/dlrm/dlrm.cc
@@ -1,6 +1,7 @@
#include "models/dlrm/dlrm.h"
#include "pcg/computation_graph.h"
#include "utils/containers/concat_vectors.h"
#include "utils/containers/repeat.h"
#include "utils/containers/transform.h"
#include "utils/containers/zip.h"

@@ -29,7 +30,7 @@ DLRMConfig get_default_dlrm_config() {
           64,
           2,
       },
-      /*arch_interaction_op=*/"cat",
+      /*arch_interaction_op=*/DLRMArchInteractionOp::CAT,
       /*batch_size=*/64,
       /*seed=*/std::rand(),
   };
@@ -97,12 +98,13 @@ tensor_guid_t create_dlrm_interact_features(
     DLRMConfig const &config,
     tensor_guid_t const &bottom_mlp_output,
     std::vector<tensor_guid_t> const &emb_outputs) {
-  if (config.arch_interaction_op != "cat") {
+  if (config.arch_interaction_op != DLRMArchInteractionOp::CAT) {
     throw mk_runtime_error(fmt::format(
-        "Currently only arch_interaction_op=cat is supported, but found "
-        "arch_interaction_op={}. If you need support for additional "
+        "Currently only arch_interaction_op=DLRMArchInteractionOp::CAT is "
+        "supported, but found arch_interaction_op={}. If you need support for "
+        "additional "
         "arch_interaction_op value, please create an issue.",
-        config.arch_interaction_op));
+        format_as(config.arch_interaction_op)));
   }
 
   return cgb.concat(
@@ -123,11 +125,13 @@ ComputationGraph get_dlrm_computation_graph(DLRMConfig const &config) {
   };
 
   // Create input tensors
-  std::vector<tensor_guid_t> sparse_inputs(
-      config.embedding_size.size(),
-      create_input_tensor({static_cast<size_t>(config.batch_size),
-                           static_cast<size_t>(config.embedding_bag_size)},
-                          DataType::INT64));
+  std::vector<tensor_guid_t> sparse_inputs =
+      repeat(config.embedding_size.size(), [&]() {
+        return create_input_tensor(
+            {static_cast<size_t>(config.batch_size),
+             static_cast<size_t>(config.embedding_bag_size)},
+            DataType::INT64);
+      });
 
   tensor_guid_t dense_input = create_input_tensor(
       {static_cast<size_t>(config.batch_size),
@@ -141,16 +145,16 @@ ComputationGraph get_dlrm_computation_graph(DLRMConfig const &config) {
       /*input=*/dense_input,
       /*mlp_layers=*/config.dense_arch_layer_sizes);
 
-  std::vector<tensor_guid_t> emb_outputs;
-  for (size_t i = 0; i < config.embedding_size.size(); i++) {
-    int input_dim = config.embedding_size.at(i);
-    emb_outputs.emplace_back(create_dlrm_sparse_embedding_network(
-        /*cgb=*/cgb,
-        /*config=*/config,
-        /*input=*/sparse_inputs.at(i),
-        /*input_dim=*/input_dim,
-        /*output_dim=*/config.embedding_dim));
-  }
+  std::vector<tensor_guid_t> emb_outputs = transform(
+      zip(config.embedding_size, sparse_inputs),
+      [&](std::pair<int, tensor_guid_t> const &combined_pair) -> tensor_guid_t {
+        return create_dlrm_sparse_embedding_network(
+            /*cgb=*/cgb,
+            /*config=*/config,
+            /*input=*/combined_pair.second,
+            /*input_dim=*/combined_pair.first,
+            /*output_dim=*/config.embedding_dim);
+      });
 
   tensor_guid_t interacted_features = create_dlrm_interact_features(
       /*cgb=*/cgb,
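
A note on the two container-helper changes above, hedged since the real implementations live in utils/containers and are not part of this diff: the old std::vector<tensor_guid_t> sparse_inputs(n, create_input_tensor(...)) used the vector fill constructor, which evaluates create_input_tensor(...) once and copies the same tensor guid n times, whereas repeat(n, f) as used here calls the lambda once per element, so every sparse input gets its own input tensor. Likewise, the indexed for-loop over embedding_size becomes transform over zip(embedding_size, sparse_inputs). The sketch below uses stand-in implementations with the semantics assumed by these call sites; the actual FlexFlow helpers may differ in signature and generality.

// Stand-in implementations of repeat / zip / transform with the semantics
// assumed by the call sites above. The real helpers in utils/containers may
// differ; this is only an illustration.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

// repeat(n, f): call f() n times and collect the results. Unlike
// std::vector<T>(n, f()), which evaluates f() once and copies the value,
// this yields n independently produced elements.
template <typename F>
auto repeat(size_t n, F &&f) -> std::vector<decltype(f())> {
  std::vector<decltype(f())> result;
  result.reserve(n);
  for (size_t i = 0; i < n; i++) {
    result.push_back(f());
  }
  return result;
}

// zip(a, b): pair up elements index by index, truncating to the shorter input.
template <typename A, typename B>
std::vector<std::pair<A, B>> zip(std::vector<A> const &a,
                                 std::vector<B> const &b) {
  std::vector<std::pair<A, B>> result;
  size_t n = std::min(a.size(), b.size());
  for (size_t i = 0; i < n; i++) {
    result.push_back({a.at(i), b.at(i)});
  }
  return result;
}

// transform(v, f): apply f to every element and collect the results.
template <typename T, typename F>
auto transform(std::vector<T> const &v, F &&f)
    -> std::vector<decltype(f(v.front()))> {
  std::vector<decltype(f(v.front()))> result;
  result.reserve(v.size());
  for (T const &x : v) {
    result.push_back(f(x));
  }
  return result;
}

int main() {
  // Each call to the lambda yields a fresh value, mirroring how each sparse
  // input now becomes its own input tensor rather than one shared tensor guid.
  int next_id = 0;
  std::vector<int> ids = repeat(4, [&]() { return next_id++; });
  assert((ids == std::vector<int>{0, 1, 2, 3}));

  // zip + transform replaces the indexed loop over two parallel vectors.
  std::vector<int> dims = {10, 20, 30, 40};
  std::vector<int> combined =
      transform(zip(dims, ids),
                [](std::pair<int, int> const &p) { return p.first + p.second; });
  for (int c : combined) {
    std::cout << c << " "; // prints: 10 21 32 43
  }
  std::cout << "\n";
}

If the default configuration contains four embedding tables, the switch from one shared input tensor to four distinct ones would add three layers to the computation graph, which would account for the expected layer count in the test below rising from 27 to 30.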
2 changes: 1 addition & 1 deletion lib/models/test/src/models/dlrm/dlrm.cc
@@ -12,7 +12,7 @@ TEST_SUITE(FF_TEST_SUITE) {
 
     SUBCASE("num layers") {
       int result_num_layers = get_layers(result).size();
-      int correct_num_layers = 27;
+      int correct_num_layers = 30;
       CHECK(result_num_layers == correct_num_layers);
     }
   }
