support ET dump for llama3 runner #7507

Open · wants to merge 1 commit into base: main
6 changes: 4 additions & 2 deletions examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 # model sharding with custom op
-set(CUSTOM_OP_SRCS_FILE
+set(CUSTOM_OP_SRCS_FILE
   "${EXECUTORCH_SOURCE_DIR}/extension/llm/custom_ops/op_fallback.cpp"
 )
 add_library(custom_ops ${CUSTOM_OP_SRCS_FILE})
@@ -45,7 +45,7 @@ list(
 # build qnn llama3.2 1b runner
 add_executable(qnn_llama3_2_runner ${_llama3_2_runner__srcs})
 target_include_directories(
-  qnn_llama3_2_runner PUBLIC ${_common_include_directories}
+  qnn_llama3_2_runner PUBLIC ${_common_include_directories} ${EXECUTORCH_SOURCE_DIR}/devtools/etdump
 )
 
 target_link_libraries(
@@ -58,6 +58,8 @@ target_link_libraries(
   gflags
   re2::re2
   custom_ops
+  etdump
+  ${FLATCCRT_LIB}
 )
 target_compile_options(
   qnn_llama3_2_runner PUBLIC ${_common_compile_options}
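The two new link dependencies travel together: etdump serializes trace data as flatbuffers via flatcc, so the runner links the etdump library plus ${FLATCCRT_LIB}, the flatcc runtime exposed by the ExecuTorch build, and the added include path makes the etdump_flatcc.h header visible to the runner sources.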
@@ -51,6 +51,11 @@ DEFINE_int32(
 DEFINE_double(logits_scale, 0.0, "Logits scale");
 DEFINE_int32(logits_offset, 0, "Logits offset");
 
+DEFINE_bool(
+    gen_etdump,
+    false,
+    "false: Disable ET dump/ True: Enable ET dump (default: false)");
+
 int main(int argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
 
@@ -61,7 +66,8 @@ int main(int argc, char** argv) {
       FLAGS_logits_scale,
       FLAGS_logits_offset,
       FLAGS_temperature,
-      FLAGS_eval_mode);
+      FLAGS_eval_mode,
+      FLAGS_gen_etdump);
   std::vector<char> buf;
   buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char
   std::ofstream fout(FLAGS_output_path.c_str());
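Wiring-wise, the new flag is passed straight through to the Runner constructor and defaults to false, so existing command lines behave as before; adding --gen_etdump=true to a qnn_llama3_2_runner invocation is the only change needed to start collecting dumps.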
51 changes: 50 additions & 1 deletion examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp
@@ -43,7 +43,8 @@ Runner::Runner(
     const float logits_scale,
     const int32_t logits_offset,
     const float temperature,
-    const int eval_mode)
+    const int eval_mode,
+    const bool gen_etdump)
     : n_bos_(1),
       n_eos_(1),
       tokenizer_path_(tokenizer_path),
@@ -58,6 +59,28 @@
   }
   ET_LOG(Info, "creating runner: tokenizer_path=%s", tokenizer_path_.c_str());
   ET_LOG(Info, "eval mode=%d", eval_mode);
+  if (gen_etdump) {
+    gen_etdump_ = true;
+    switch (eval_mode) {
+      case EvalMode::kPrefill:
+        prefill_dump_ = std::make_unique<torch::executor::ETDumpGen>();
+        break;
+      case EvalMode::kKVCached:
+        decode_dump_ = std::make_unique<torch::executor::ETDumpGen>();
+        break;
+      case EvalMode::kHybrid:
+        prefill_dump_ = std::make_unique<torch::executor::ETDumpGen>();
+        decode_dump_ = std::make_unique<torch::executor::ETDumpGen>();
+        break;
+      default:
+        ET_CHECK_MSG(false, "Unsupported eval mode");
+        break;
+    }
+    std::string etdump_dir =
+        models_path[0].substr(0, models_path[0].find_last_of("/\\") + 1);
+    prefill_etdump_path_ = etdump_dir + "prefill_etdump.etdp";
+    decode_etdump_path_ = etdump_dir + "decode_etdump.etdp";
+  }
 }
 
 bool Runner::is_loaded() const {
@@ -95,9 +118,17 @@ Error Runner::load() {
 
   for (std::shared_ptr<Module>& module : modules_) {
     if (!prefill_forward_name_.empty()) {
+      if (gen_etdump_) {
+        ET_CHECK_OK_OR_RETURN_ERROR(
+            module->load_method(prefill_forward_name_, prefill_dump_.get()));
+      }
       ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(prefill_forward_name_));
     }
     if (!kv_forward_name_.empty()) {
+      if (gen_etdump_) {
+        ET_CHECK_OK_OR_RETURN_ERROR(
+            module->load_method(kv_forward_name_, decode_dump_.get()));
+      }
       ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(kv_forward_name_));
     }
   }
@@ -424,6 +455,8 @@ Error Runner::generate(
 
   stats_.num_prompt_tokens = num_prompt_tokens;
   stats_.num_generated_tokens = pos - num_prompt_tokens;
+  if (gen_etdump_)
+    gen_etdump_data();
   printReport(stats_);
   if (stats_callback) {
     stats_callback(stats_);
@@ -432,6 +465,22 @@
   return Error::Ok;
 }
 
+void Runner::gen_etdump_data() {
+  // dump the prefill and decode etdump data
+  if (prefill_dump_.get() != nullptr) {
+    torch::executor::etdump_result result = prefill_dump_->get_etdump_data();
+    FILE* ptr = fopen(prefill_etdump_path_.c_str(), "w+");
+    fwrite(result.buf, 1, result.size, ptr);
+    fclose(ptr);
+  }
+  if (decode_dump_.get() != nullptr) {
+    torch::executor::etdump_result result = decode_dump_->get_etdump_data();
+    FILE* ptr = fopen(decode_etdump_path_.c_str(), "w+");
+    fwrite(result.buf, 1, result.size, ptr);
+    fclose(ptr);
+  }
+}
+
 namespace {
 void printReport(const Runner::Stats& stats) {
   printf("PyTorchObserver %s\n", statsToJsonString(stats).c_str());
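Taken together, the lifecycle this diff implements is: construct an ETDumpGen per method, register it as that method's event tracer at load_method() time, run inference, then serialize get_etdump_data() to a .etdp file. A minimal self-contained sketch of the same flow outside the runner (the model path, method name, and omitted inputs are illustrative assumptions, not part of this PR, and unlike the PR it checks the fopen result):

#include <cstdio>
#include <cstdlib>
#include <memory>

#include <executorch/devtools/etdump/etdump_flatcc.h>
#include <executorch/extension/module/module.h>

int main() {
  auto etdump_gen = std::make_unique<torch::executor::ETDumpGen>();
  executorch::extension::Module module("model.pte");

  // Attaching the generator at load time registers it as the method's
  // EventTracer; subsequent execute() calls record events into it.
  if (module.load_method("forward", etdump_gen.get()) !=
      executorch::runtime::Error::Ok) {
    return 1;
  }
  auto outputs = module.execute("forward");  // inputs elided for brevity
  (void)outputs;

  // Serialize whatever was recorded; guard the fopen result.
  torch::executor::etdump_result result = etdump_gen->get_etdump_data();
  if (result.buf != nullptr && result.size > 0) {
    FILE* f = fopen("forward_etdump.etdp", "w+");
    if (f != nullptr) {
      fwrite(result.buf, 1, result.size, f);
      fclose(f);
    }
    free(result.buf);  // the caller owns the buffer from get_etdump_data()
  }
  return 0;
}

On the host, the resulting .etdp files are the input to the ExecuTorch devtools Inspector, which correlates the recorded events back to per-operator timings.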
10 changes: 9 additions & 1 deletion examples/qualcomm/oss_scripts/llama3_2/runner/runner.h
@@ -17,6 +17,7 @@
 #include <string>
 #include <unordered_map>
 
+#include <executorch/devtools/etdump/etdump_flatcc.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.h>
 #include <executorch/extension/llm/sampler/sampler.h>
 #include <executorch/extension/llm/tokenizer/tokenizer.h>
@@ -32,7 +33,8 @@ class Runner {
       const float logits_scale,
       const int32_t logits_offset,
       const float temperature,
-      const int eval_mode);
+      const int eval_mode,
+      const bool gen_etdump);
 
   struct Stats {
     // Scaling factor for timestamps - in this case, we use ms.
@@ -71,6 +73,7 @@
   void stop();
   std::vector<executorch::runtime::Result<executorch::runtime::MethodMeta>>
   get_methods_meta(std::string& method_name);
+  void gen_etdump_data();
 
  private:
   template <typename T>
@@ -98,6 +101,11 @@
   float temperature_;
   std::unique_ptr<executorch::extension::llm::Tokenizer> tokenizer_;
   std::unique_ptr<executorch::extension::llm::Sampler> sampler_;
+  std::unique_ptr<torch::executor::ETDumpGen> prefill_dump_;
+  std::unique_ptr<torch::executor::ETDumpGen> decode_dump_;
+  bool gen_etdump_ = false;
+  std::string prefill_etdump_path_;
+  std::string decode_etdump_path_;
   Stats stats_;
   std::unique_ptr<Memory> io_mem_;
   EvalMode eval_mode_;
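One design point worth noting in the header: prefill and decode get separate ETDumpGen instances and separate output paths. An event tracer is bound per loaded method, and hybrid mode loads both the prefill and KV-cached methods, so sharing one generator would mix the two methods' events; two generators keep prefill_etdump.etdp and decode_etdump.etdp independently analyzable.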
1 change: 1 addition & 0 deletions examples/qualcomm/oss_scripts/llama3_2/targets.bzl
@@ -28,6 +28,7 @@ def define_common_targets():
"//executorch/extension/llm/tokenizer:bpe_tokenizer",
"//executorch/extension/evalue_util:print_evalue",
"//executorch/backends/qualcomm/runtime:runtime",
"//executorch/devtools/etdump:etdump_flatcc",
],
external_deps = [
"gflags",