
feat: QNN Multi Chunk Execution in New Frontend #191

Merged (6 commits) on Nov 14, 2024

Conversation

oreomaker
Collaborator

Multi Chunk Execution

Tokenizer: Add a new tokenizePaddingByChunk method to SmolLMTokenizer, which tokenizes the input and pads it to the nearest multiple of chunk_size

auto [real_seq_length, input_tensor] = tokenizer.tokenizePaddingByChunk(input_str, chunk_size, config.vocab_size);
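The padding arithmetic can be sketched as a small standalone function. This is a hypothetical illustration of the rounding described above, not mllm's actual tokenizer code; `padToChunk` and the `pad_id` parameter are assumptions for the example (the real API returns a tensor and also takes `vocab_size`).

```cpp
#include <utility>
#include <vector>

// Hypothetical sketch: pad a token-id sequence up to the nearest multiple
// of chunk_size, returning the original (unpadded) length alongside the
// padded sequence, as tokenizePaddingByChunk is described to do.
std::pair<int, std::vector<int>> padToChunk(std::vector<int> ids,
                                            int chunk_size,
                                            int pad_id) {
    const int real_seq_length = static_cast<int>(ids.size());
    // Round the length up to the next multiple of chunk_size.
    const int padded_len =
        ((real_seq_length + chunk_size - 1) / chunk_size) * chunk_size;
    ids.resize(padded_len, pad_id);
    return {real_seq_length, std::move(ids)};
}
```

Keeping `real_seq_length` separate lets later stages mask out or ignore the pad positions when decoding.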

Module Static States: Add Module::isMultiChunkPrefilling and Module::isFirstChunk to track the multi-chunk execution state

Module Execution: Add a new tensor_status value, TENSOR_UNDEFINED, used in QNN chunk execution. If Module::isMultiChunkPrefilling is true, QNN modules skip reshape & setUp for all chunks after the first, while CPU modules still reshape & setUp on every chunk

if (Tensor::tensor_status == TENSOR_STATIC_INIT && device_ != MLLM_CPU) { // backend-specific module reshape & setup
    if (Module::isMultiChunkPrefilling && !Module::isFirstChunk) { // set to TENSOR_UNDEFINED and skip executing QNN layers
        Tensor::tensor_status = TENSOR_UNDEFINED;
        auto outputs = Forward(inputs, anyArgs);
        Tensor::tensor_status = TENSOR_STATIC_INIT;
        return outputs;
    }
...
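The gating logic above can be exercised in isolation with a minimal simulation. This is a self-contained sketch that mirrors the flags and status values named in the PR (the `Sim` struct, `forwardQnnModule`, and the log strings are inventions for illustration, not mllm's real types):

```cpp
#include <string>
#include <vector>

// Standalone simulation of the multi-chunk gating described above:
// the first chunk performs reshape & setUp, later chunks temporarily
// flip tensor_status to TENSOR_UNDEFINED so setup work is skipped.
enum TensorStatus { TENSOR_STATIC_INIT, TENSOR_UNDEFINED };

struct Sim {
    TensorStatus tensor_status = TENSOR_STATIC_INIT;
    bool isMultiChunkPrefilling = false;
    bool isFirstChunk = true;
    std::vector<std::string> log; // records which phase actually ran

    void forwardQnnModule() {
        if (tensor_status == TENSOR_STATIC_INIT) {
            if (isMultiChunkPrefilling && !isFirstChunk) {
                // Later chunks: mark tensors undefined so layers skip
                // reshape & setUp and reuse the first chunk's graph.
                tensor_status = TENSOR_UNDEFINED;
                log.push_back("skip-setup");
                tensor_status = TENSOR_STATIC_INIT; // restore for next call
                return;
            }
            log.push_back("reshape+setup");
        }
    }
};
```

A driver would set `isMultiChunkPrefilling = true`, run the first chunk with `isFirstChunk = true`, then clear it for the remaining chunks.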

TODO

Multi-round input still produces incorrect outputs, which may be caused by stateful ops such as KVCache, RoPE, and CausalMask

@yirongjie yirongjie changed the title QNN Multi Chunk Execution in New Frontend feat: QNN Multi Chunk Execution in New Frontend Nov 14, 2024
@yirongjie yirongjie merged commit 11a2fb2 into UbiquitousLearning:main Nov 14, 2024
1 check passed