diff --git a/tools/jni/LibHelper.cpp b/tools/jni/LibHelper.cpp index 12340578..7e1e1fef 100644 --- a/tools/jni/LibHelper.cpp +++ b/tools/jni/LibHelper.cpp @@ -98,7 +98,7 @@ bool LibHelper::setUp(const std::string &base_path, std::string weights_path, st return true; }); Module::isFirstChunk = false; - static_cast(Backend::global_backends[MLLM_CPU])->setSequenceLength(0); + static_cast(Backend::global_backends[MLLM_CPU])->setCurSequenceLength(0); static_cast(Backend::global_backends[MLLM_CPU])->setExecutionType(PROMPT); static_cast(Backend::global_backends[MLLM_CPU])->toggleSwitching(); Module::isMultiChunkPrefilling = true; @@ -146,7 +146,7 @@ bool LibHelper::setUp(const std::string &base_path, std::string weights_path, st return true; }); Module::isFirstChunk = false; - static_cast(Backend::global_backends[MLLM_CPU])->setSequenceLength(0); + static_cast(Backend::global_backends[MLLM_CPU])->setCurSequenceLength(0); static_cast(Backend::global_backends[MLLM_CPU])->setExecutionType(PROMPT); static_cast(Backend::global_backends[MLLM_CPU])->toggleSwitching(); Module::isMultiChunkPrefilling = true; @@ -186,6 +186,11 @@ void LibHelper::run(std::string &input_str, uint8_t *image, unsigned max_step, u const int chunk_num = seq_length_padding / chunk_size; bool isSwitched = false; + // set total seq length for HeadLinear execute, which can not get the real seq length from Opts + static_cast(Backend::global_backends[MLLM_CPU])->setTotalSequenceLength(real_seq_length); + // set chunk size for the HeadLinear execute, which can not get the chunk size from Opts + static_cast(Backend::global_backends[MLLM_CPU])->setChunkSize(chunk_size); + LlmTextGeneratorOpts opt{ .max_new_tokens = 1, .do_sample = false, @@ -226,7 +231,7 @@ void LibHelper::run(std::string &input_str, uint8_t *image, unsigned max_step, u }); Module::isFirstChunk = false; } - static_cast(Backend::global_backends[MLLM_CPU])->setSequenceLength(real_seq_length); + static_cast(Backend::global_backends[MLLM_CPU])->setCurSequenceLength(real_seq_length); static_cast(Backend::global_backends[MLLM_CPU])->setExecutionType(AUTOREGRESSIVE); static_cast(Backend::global_backends[MLLM_CPU])->toggleSwitching(); @@ -259,7 +264,7 @@ void LibHelper::run(std::string &input_str, uint8_t *image, unsigned max_step, u if (!not_end) { return false; } return true; }); - static_cast(Backend::global_backends[MLLM_CPU])->setSequenceLength(0); + static_cast(Backend::global_backends[MLLM_CPU])->setCurSequenceLength(0); static_cast(Backend::global_backends[MLLM_CPU])->setExecutionType(PROMPT); static_cast(Backend::global_backends[MLLM_CPU])->toggleSwitching(); } else { // CPU @@ -318,6 +323,11 @@ void LibHelper::run(std::string &input_str, uint8_t *image, unsigned max_step, u const int chunk_num = seq_length_padding / chunk_size; bool isSwitched = false; + // set total seq length for HeadLinear execute, which can not get the real seq length from Opts + static_cast(Backend::global_backends[MLLM_CPU])->setTotalSequenceLength(real_seq_length); + // set chunk size for the HeadLinear execute, which can not get the chunk size from Opts + static_cast(Backend::global_backends[MLLM_CPU])->setChunkSize(chunk_size); + LlmTextGeneratorOpts opt{ .max_new_tokens = 1, .do_sample = false, @@ -360,7 +370,7 @@ void LibHelper::run(std::string &input_str, uint8_t *image, unsigned max_step, u }); Module::isFirstChunk = false; } - static_cast(Backend::global_backends[MLLM_CPU])->setSequenceLength(real_seq_length); + static_cast(Backend::global_backends[MLLM_CPU])->setCurSequenceLength(real_seq_length); static_cast(Backend::global_backends[MLLM_CPU])->setExecutionType(AUTOREGRESSIVE); static_cast(Backend::global_backends[MLLM_CPU])->toggleSwitching(); @@ -393,7 +403,7 @@ void LibHelper::run(std::string &input_str, uint8_t *image, unsigned max_step, u if (!not_end) { return false; } return true; }); - static_cast(Backend::global_backends[MLLM_CPU])->setSequenceLength(0); + static_cast(Backend::global_backends[MLLM_CPU])->setCurSequenceLength(0); static_cast(Backend::global_backends[MLLM_CPU])->setExecutionType(PROMPT); static_cast(Backend::global_backends[MLLM_CPU])->toggleSwitching(); } else { // CPU