Skip to content

Commit

Permalink
support ET dump for llama3 runner (#7507)
Browse files Browse the repository at this point in the history
Summary:

Support ET dump for llama3 runner to easy understand the regression performance issue

Differential Revision: D67656207
  • Loading branch information
billmguo authored and facebook-github-bot committed Jan 6, 2025
1 parent 68c0208 commit 83e01d8
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 5 deletions.
4 changes: 2 additions & 2 deletions examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

# model sharding with custom op
set(CUSTOM_OP_SRCS_FILE
set(CUSTOM_OP_SRCS_FILE
"${EXECUTORCH_SOURCE_DIR}/extension/llm/custom_ops/op_fallback.cpp"
)
add_library(custom_ops ${CUSTOM_OP_SRCS_FILE})
Expand Down Expand Up @@ -45,7 +45,7 @@ list(
# build qnn llama3.2 1b runner
add_executable(qnn_llama3_2_runner ${_llama3_2_runner__srcs})
target_include_directories(
qnn_llama3_2_runner PUBLIC ${_common_include_directories}
qnn_llama3_2_runner PUBLIC ${_common_include_directories} ${EXECUTORCH_SOURCE_DIR}/devtools/etdump
)

target_link_libraries(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ DEFINE_int32(
0,
"0: PromptProcessor(prefill) / 1: TokenGenerator(kv) / 2: HybridMode (prefill+kv)");

DEFINE_bool(
gen_etdump,
false,
"false: Disable ET dump/ True: Enable ET dump (default: false)");


int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);

Expand All @@ -57,7 +63,8 @@ int main(int argc, char** argv) {
{FLAGS_model_path},
FLAGS_tokenizer_path.c_str(),
FLAGS_temperature,
FLAGS_eval_mode);
FLAGS_eval_mode,
FLAGS_gen_etdump);
std::vector<char> buf;
buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char
std::ofstream fout(FLAGS_output_path.c_str());
Expand Down
49 changes: 48 additions & 1 deletion examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ Runner::Runner(
const std::vector<std::string>& models_path,
const std::string& tokenizer_path,
const float temperature,
const int eval_mode)
const int eval_mode,
const bool gen_etdump)
: n_bos_(1),
n_eos_(1),
tokenizer_path_(tokenizer_path),
Expand All @@ -54,6 +55,26 @@ Runner::Runner(
}
ET_LOG(Info, "creating runner: tokenizer_path=%s", tokenizer_path_.c_str());
ET_LOG(Info, "eval mode=%d", eval_mode);
if (gen_etdump) {
gen_etdump_ = true;
switch(eval_mode) {
case EvalMode::kPrefill:
prefill_dump_ = new torch::executor::ETDumpGen();
break;
case EvalMode::kKVCached:
decode_dump_ = new torch::executor::ETDumpGen();
break;
case EvalMode::kHybrid:
prefill_dump_ = new torch::executor::ETDumpGen();
decode_dump_ = new torch::executor::ETDumpGen();
break;
default:
ET_CHECK_MSG(false, "Unsupported eval mode");
break;
}
prefill_etdump_path_ = "prefill_etdump.etdp";
decode_etdump_path_ = "decode_etdump.etdp";
}
}

bool Runner::is_loaded() const {
Expand Down Expand Up @@ -91,9 +112,15 @@ Error Runner::load() {

for (std::shared_ptr<Module>& module : modules_) {
if (!prefill_forward_name_.empty()) {
if (gen_etdump_) {
ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(prefill_forward_name_, prefill_dump_));
}
ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(prefill_forward_name_));
}
if (!kv_forward_name_.empty()) {
if (gen_etdump_) {
ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(kv_forward_name_, decode_dump_));
}
ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(kv_forward_name_));
}
}
Expand Down Expand Up @@ -395,6 +422,8 @@ Error Runner::generate(

stats_.num_prompt_tokens = num_prompt_tokens;
stats_.num_generated_tokens = pos - num_prompt_tokens;
if (gen_etdump_)
gen_etdump_data();
printReport(stats_);
if (stats_callback) {
stats_callback(stats_);
Expand All @@ -403,6 +432,24 @@ Error Runner::generate(
return Error::Ok;
}

void Runner::gen_etdump_data(){
//dump the prefill and decode etdump data
if (prefill_dump_ != nullptr) {
torch::executor::etdump_result result = prefill_dump_->get_etdump_data();
FILE* ptr = fopen(prefill_etdump_path_.c_str(), "w+");
fwrite(result.buf, 1, result.size, ptr);
fclose(ptr);
prefill_dump_->reset();
}
if (decode_dump_ != nullptr) {
torch::executor::etdump_result result = decode_dump_->get_etdump_data();
FILE* ptr = fopen(decode_etdump_path_.c_str(), "w+");
fwrite(result.buf, 1, result.size, ptr);
fclose(ptr);
decode_dump_->reset();
}
}

namespace {
void printReport(const Runner::Stats& stats) {
printf("PyTorchObserver %s\n", statsToJsonString(stats).c_str());
Expand Down
11 changes: 10 additions & 1 deletion examples/qualcomm/oss_scripts/llama3_2/runner/runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#include <executorch/extension/llm/sampler/sampler.h>
#include <executorch/extension/llm/tokenizer/tokenizer.h>
#include <executorch/extension/module/module.h>
#include <executorch/devtools/etdump/etdump_flatcc.h>


namespace example {

Expand All @@ -30,7 +32,8 @@ class Runner {
const std::vector<std::string>& models_path,
const std::string& tokenizer_path,
const float temperature,
const int eval_mode);
const int eval_mode,
const bool gen_etdump);

struct Stats {
// Scaling factor for timestamps - in this case, we use ms.
Expand Down Expand Up @@ -69,6 +72,7 @@ class Runner {
void stop();
std::vector<executorch::runtime::Result<executorch::runtime::MethodMeta>>
get_methods_meta(std::string& method_name);
void gen_etdump_data();

private:
template <typename T>
Expand All @@ -93,6 +97,11 @@ class Runner {
float temperature_;
std::unique_ptr<executorch::extension::llm::Tokenizer> tokenizer_;
std::unique_ptr<executorch::extension::llm::Sampler> sampler_;
torch::executor::ETDumpGen* prefill_dump_ = nullptr;
torch::executor::ETDumpGen* decode_dump_ = nullptr;
bool gen_etdump_ = false;
std::string prefill_etdump_path_;
std::string decode_etdump_path_;
Stats stats_;
std::unique_ptr<Memory> io_mem_;
EvalMode eval_mode_;
Expand Down
1 change: 1 addition & 0 deletions examples/qualcomm/oss_scripts/llama3_2/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def define_common_targets():
"//executorch/extension/llm/tokenizer:bpe_tokenizer",
"//executorch/extension/evalue_util:print_evalue",
"//executorch/backends/qualcomm/runtime:runtime",
"//executorch/devtools/etdump:etdump_flatcc",
],
external_deps = [
"gflags",
Expand Down

0 comments on commit 83e01d8

Please sign in to comment.