diff --git a/examples/demo_bert.cpp b/examples/demo_bert.cpp
index b21281df..46c3eb03 100644
--- a/examples/demo_bert.cpp
+++ b/examples/demo_bert.cpp
@@ -5,6 +5,7 @@
 #include "models/bert/modeling_bert.hpp"
 #include "models/bert/tokenization_bert.hpp"
 #include "cmdline.h"
+#include <vector>
 
 /*
  * an intent to support gte-small BertModel to do text embedding
@@ -24,15 +25,17 @@ int main(int argc, char *argv[]) {
     CPUBackend::cpu_threads = cmdParser.get<int>("thread");
 
     BertTokenizer tokenizer(vocab_path, true);
-    string text = "Help me set an alarm at 21:30";
-    auto inputs = tokenizer.tokenizes(text);
     auto config = BertConfig();
     auto model = BertModel(config);
     model.load(model_path);
 
-    auto res = model({inputs[0], inputs[1], inputs[2]})[0];
-
-    res.printData<float>();
+    string text = "Help me set an alarm at 21:30";
+    vector<string> texts = {text, text};
+    for (auto &text : texts) {
+        auto inputs = tokenizer.tokenizes(text);
+        auto res = model({inputs[0], inputs[1], inputs[2]})[0];
+        res.printData<float>();
+    }
 
     return 0;
 }
diff --git a/examples/demo_gemma.cpp b/examples/demo_gemma.cpp
index 18e0a885..bdd9c664 100644
--- a/examples/demo_gemma.cpp
+++ b/examples/demo_gemma.cpp
@@ -54,6 +54,7 @@ int main(int argc, char **argv) {
             chatPostProcessing(out_token, input_tensor, {});
         }
         printf("\n");
+        model.clear_kvcache();
     }
 
     return 0;
diff --git a/examples/demo_imagebind_1mod.cpp b/examples/demo_imagebind_1mod.cpp
index 4c4d782d..210c030f 100644
--- a/examples/demo_imagebind_1mod.cpp
+++ b/examples/demo_imagebind_1mod.cpp
@@ -13,7 +13,7 @@ int main(int argc, char **argv) {
     cmdParser.add<string>("model", 'm', "specify mllm model path", false, "../models/imagebind_huge-q4_k.mllm");
     cmdParser.add<string>("merges", 'f', "specify mllm tokenizer merges.txt path", false, "../vocab/clip_merges.txt");
     cmdParser.add<int>("thread", 't', "num of threads", false, 4);
-    cmdParser.add<int>("loop_times", 'l', "number of inference loops", false, 10);
+    cmdParser.add<int>("loop_times", 'l', "number of inference loops", false, 2);
     cmdParser.add<string>("modality", 'o', "inference modality (text/vision/audio/all)", false, "all");
     cmdParser.parse_check(argc, argv);
 
diff --git a/examples/demo_openelm.cpp b/examples/demo_openelm.cpp
index 2f23a2f0..5f1c19ab 100644
--- a/examples/demo_openelm.cpp
+++ b/examples/demo_openelm.cpp
@@ -48,10 +48,10 @@ int main(int argc, char **argv) {
 
         LlmTextGeneratorOpts opt{
             .max_new_tokens = 100,
-            .do_sample = true,
-            .temperature = 0.3F,
-            .top_k = 50,
-            .top_p = 0.F,
+            .do_sample = false,
+            // .temperature = 0.3F,
+            // .top_k = 50,
+            // .top_p = 0.F,
         };
         model.generate(input_tensor, opt, [&](unsigned int out_token) -> bool {
             auto out_string = tokenizer.detokenize({out_token});
@@ -61,5 +61,6 @@ int main(int argc, char **argv) {
             return true;
         });
         std::cout << "\n";
+        model.clear_kvcache();
     }
 }
\ No newline at end of file
diff --git a/examples/demo_phi3v.cpp b/examples/demo_phi3v.cpp
index 89c3cd67..43a4b26e 100644
--- a/examples/demo_phi3v.cpp
+++ b/examples/demo_phi3v.cpp
@@ -49,7 +49,7 @@ int main(int argc, char **argv) {
             auto [not_end, output_string] = processor.tokenizer->postprocess(out_string);
             if (!not_end) { break; }
             std::cout << output_string << std::flush;
-            chatPostProcessing(out_token, input_tensor[0], {});
+            chatPostProcessing(out_token, input_tensor[0], {&input_tensor[1], &input_tensor[2]});
         }
         printf("\n");
     }
diff --git a/examples/demo_stablelm.cpp b/examples/demo_stablelm.cpp
index 372c967c..077fb2f0 100644
--- a/examples/demo_stablelm.cpp
+++ b/examples/demo_stablelm.cpp
@@ -11,7 +11,7 @@ int main(int argc, char **argv) {
    cmdParser.add<string>("vocab", 'v', "specify mllm tokenizer model path", false, "../vocab/stablelm_vocab.mllm");
"specify mllm tokenizer model path", false, "../vocab/stablelm_vocab.mllm"); cmdParser.add("merge", 'e', "specify mllm merge path", false, "../vocab/stablelm_merges.txt"); cmdParser.add("model", 'm', "specify mllm model path", false, "../models/stablelm-2-1.6b-chat-q4_k.mllm"); - cmdParser.add("limits", 'l', "max KV cache size", false, 400); + cmdParser.add("limits", 'l', "max KV cache size", false, 600); cmdParser.add("thread", 't', "num of threads", false, 4); cmdParser.parse_check(argc, argv); diff --git a/examples/demo_vit.cpp b/examples/demo_vit.cpp index f7632658..404c7dcf 100644 --- a/examples/demo_vit.cpp +++ b/examples/demo_vit.cpp @@ -1,4 +1,5 @@ #include +#include #include "cmdline.h" #include "models/vit/modeling_vit.hpp" #include "models/vit/labels_vit.hpp" @@ -21,8 +22,15 @@ int main(int argc, char **argv) { auto model = ViTModel(config); model.load(model_path); - auto input_tensor = processor.process("../assets/cat.jpg", 224); - auto result = model({input_tensor}); - auto token_idx = processor.postProcess(result[0]); - std::cout << imagenet_id2label[token_idx] << std::endl; + vector imgs = {"../assets/cat.jpg", + "../assets/dog_image.jpg", + "../assets/bird_image.jpg", + "../assets/car_image.jpg", + "../assets/bus.png"}; + for (auto &img : imgs) { + auto input_tensor = processor.process(img, 224); + auto result = model({input_tensor}); + auto token_idx = processor.postProcess(result[0]); + std::cout << imagenet_id2label[token_idx] << std::endl; + } } \ No newline at end of file diff --git a/examples/demo_yi.cpp b/examples/demo_yi.cpp index ab6457a7..b7f0ca43 100644 --- a/examples/demo_yi.cpp +++ b/examples/demo_yi.cpp @@ -20,7 +20,7 @@ int main(int argc, char **argv) { cmdline::parser cmdParser; cmdParser.add("vocab", 'v', "specify mllm tokenizer model path", false, "../vocab/yi_vocab.mllm"); cmdParser.add("model", 'm', "specify mllm model path", false, "../models/yi-1.5-6b-chat-q4_k.mllm"); - cmdParser.add("limits", 'l', "max KV cache size", false, 400); + cmdParser.add("limits", 'l', "max KV cache size", false, 600); cmdParser.add("thread", 't', "num of threads", false, 4); cmdParser.parse_check(argc, argv); diff --git a/src/Layer.hpp b/src/Layer.hpp index 5babf22f..8792c432 100644 --- a/src/Layer.hpp +++ b/src/Layer.hpp @@ -117,6 +117,7 @@ class Layer { module = Module::llm_model_ptr; } map> &activation_tensors = module->activation_tensors; + auto &activation_tensors_num = module->activation_tensors_num; Module::runlistIdx = saved_list_idx; bool do_init = false; // set backend to current module device and try to create op @@ -182,6 +183,7 @@ class Layer { activation_tensors[next_name] = std::make_shared(backend_); activation_tensors[next_name]->setName(next_name); activation_tensors[next_name]->setModule(module); + activation_tensors_num[next_name] = 0; } } if (module->doLoad) { @@ -237,6 +239,28 @@ class Layer { break; } } + if (Backend::global_backends.size() == 1) { + for (auto input_tensor : input_tensors) { + if ((activation_tensors_num.find(input_tensor->name()) != activation_tensors_num.end())) { + switch (Tensor::tensor_status) { + case TENSOR_STATIC_INIT: { + activation_tensors_num[input_tensor->name()] += 1; + break; + } + case TENSOR_STATIC_READY: { + activation_tensors_num[input_tensor->name()] -= 1; + break; + } + default: { + } + } + if (activation_tensors_num[input_tensor->name()] == 0 && activation_tensors[input_tensor->name()]->sequence() > 1) { + activation_tensors[input_tensor->name()]->dealloc(); + // std::cout << input_tensor->name() << "|" << 
std::endl; + } + } + } + } #ifdef DEBUGOPTIME if (Tensor::tensor_status == TENSOR_STATIC_READY) { auto end_t = mllm_time_us(); diff --git a/src/Module.cpp b/src/Module.cpp index 0a856c5e..95479179 100644 --- a/src/Module.cpp +++ b/src/Module.cpp @@ -25,33 +25,33 @@ std::unordered_map> Module::tensor_func_ops; vector Module::profiling(string name) { vector output; // printf("\n"); - MLLM_LOG_INFO_STREAM << "===========================================" << std::endl; + std::cout << "===========================================" << std::endl; if (!name.empty()) { - MLLM_LOG_INFO_STREAM << " " << name << std::endl; - MLLM_LOG_INFO_STREAM << "-------------------------------------------" << std::endl; + std::cout << " " << name << std::endl; + std::cout << "-------------------------------------------" << std::endl; } double load_time_s = load_time_ / 1000.0F; - MLLM_LOG_INFO_STREAM << " Load time: " << load_time_ / 1000.0F << " s" << std::endl; + std::cout << " Load time: " << load_time_ / 1000.0F << " s" << std::endl; if (inference_times_.size() > 1 && decoding_token_size_ != prefilling_token_size_) { double prefile_speed = 1000 * prefilling_token_size_ / inference_times_[0]; - MLLM_LOG_INFO_STREAM << " Prefilling speed: " << prefile_speed << " tokens/s" << std::endl; + std::cout << " Prefilling speed: " << prefile_speed << " tokens/s" << std::endl; double sum_decoding_time = std::accumulate(std::begin(inference_times_) + 1, std::end(inference_times_), 0.0); double mean_decoding_time = sum_decoding_time / (inference_times_.size() - 1); double decoding_speed = 1000 / mean_decoding_time; - MLLM_LOG_INFO_STREAM << " Decoding speed: " << decoding_speed << " tokens/s" << std::endl; + std::cout << " Decoding speed: " << decoding_speed << " tokens/s" << std::endl; output = {load_time_s, prefile_speed, decoding_speed}; } else { double sum_time = std::accumulate(std::begin(inference_times_), std::end(inference_times_), 0.0); double mean_time = sum_time / (inference_times_.size()); double inference_time_s = mean_time / 1000.0F; - MLLM_LOG_INFO_STREAM << " Inference latency: " << mean_time / 1000.0F << " s" << std::endl; + std::cout << " Inference latency: " << mean_time / 1000.0F << " s" << std::endl; output = {load_time_s, inference_time_s}; } // double sum_time = std::accumulate(std::begin(inference_times_), std::end(inference_times_), 0.0); - // MLLM_LOG_INFO_STREAM<> activation_tensors; + map activation_tensors_num; AbstructLoader *loader; bool doLoad = false; + bool op_transposed_flag = false; static Module *llm_model_ptr; // tag to indicate the multi-chunk prefilling @@ -183,7 +185,6 @@ class Module { } else if (decoding_token_size_ == 0) { decoding_token_size_ = inputs[0].sequence(); } - bool need_setup = true; for (int i = 0; i < inputs.size(); i++) { auto &input = inputs[i]; input.setName("input" + std::to_string(i)); @@ -191,25 +192,12 @@ class Module { activation_tensors[input.name()] = std::shared_ptr(&input, [](Tensor *) {}); activation_tensors[input.name()]->setName(input.name()); activation_tensors[input.name()]->setModule(this); - llm_model_ptr = this; - if (inputs[0].sequence() != 1 && !last_shape_bshd_.empty()) { - // if LLM/VLLM model, the `need_setup` should be `true` - if (input.batch() == last_shape_bshd_[i][0] & input.sequence() == last_shape_bshd_[i][1] & input.head() == last_shape_bshd_[i][2] & input.dimension() == last_shape_bshd_[i][3]) { - // if it is the QNN multi-chunk prefilling, the `need_setup` should be `true` to reshape & setUp CPU Ops - if 
-                        need_setup = true;
-                        break;
-                    }
-                    need_setup = false;
-                }
-            }
         }
+        llm_model_ptr = this;
         Tensor::tensor_status = TENSOR_STATIC_INIT;
 
         uint64_t time_start = mllm_time_us();
-        if (need_setup) {
-            Forward(inputs, anyArgs);
-        }
+        Forward(inputs, anyArgs);
         Tensor::tensor_status = TENSOR_STATIC_READY;
         // uint64_t time_start = mllm_time_us();
         auto output = Forward(inputs, anyArgs);
@@ -222,7 +210,7 @@ class Module {
             last_shape_bshd_.push_back({input.batch(), input.sequence(), input.head(), input.dimension()});
         }
-
+        llm_model_ptr->op_transposed_flag = true;
         return output;
     } else { // inner Modules
         // offload according to the backends' info inited during loading
diff --git a/src/Tensor.cpp b/src/Tensor.cpp
index 29f45838..db6b26ae 100644
--- a/src/Tensor.cpp
+++ b/src/Tensor.cpp
@@ -10,6 +10,7 @@
 #include "Types.hpp"
 #include
 #include
+#include <regex>
 #include
 #include
 
@@ -91,10 +92,12 @@ void Tensor::alloc() {
 
 void Tensor::dealloc() {
     if (aggregated_) { return; }
     assert(backend_ != nullptr);
-    if (masterTensor() != nullptr) { return; }
-    if (!shape_offset_.empty() && !shape_master_.empty()) { return; }
-    backend_->free(host_ptr_);
-    host_ptr_ = nullptr;
+    // if (masterTensor() != nullptr) { return; }
+    // if (!shape_offset_.empty() && !shape_master_.empty()) { return; }
+    if (masterTensor() == nullptr) {
+        backend_->free(host_ptr_);
+        host_ptr_ = nullptr;
+    }
     allocated_ = 0;
     count_ = 0;
 }
@@ -169,6 +172,7 @@ Tensor &Tensor::getFunc(const std::string &suffix, const TensorFuncType type,
                         vector<float> float_args, vector<Tensor *> other_tensors) {
     assert(module() != nullptr);
     auto &module_tensors = module()->activation_tensors;
+    auto &activation_tensors_num = module()->activation_tensors_num;
     const std::string next_name = name_ + "-" + suffix;
     // if (module_tensors.find(name_) == module_tensors.end()) {
     //     module_tensors[name_] = std::shared_ptr<Tensor>(this, [](Tensor *) {});
@@ -177,6 +181,7 @@
         module_tensors[next_name] = std::make_shared<Tensor>(backend_);
         module_tensors[next_name]->setName(next_name);
         module_tensors[next_name]->setModule(module());
+        activation_tensors_num[next_name] = 0;
     }
     if (module()->doLoad) { return *module_tensors[next_name]; }
     TensorFunction *func = backend_->funcCreate(type);
@@ -197,6 +202,28 @@ Tensor &Tensor::getFunc(const std::string &suffix, const TensorFuncType type,
     default: {
     }
     }
+    if (Backend::global_backends.size() == 1) {
+        for (auto input_tensor : tensorPtrs) {
+            if (activation_tensors_num.find(input_tensor->name()) != activation_tensors_num.end()) {
+                switch (Tensor::tensor_status) {
+                case TENSOR_STATIC_INIT: {
+                    activation_tensors_num[input_tensor->name()] += 1;
+                    break;
+                }
+                case TENSOR_STATIC_READY: {
+                    activation_tensors_num[input_tensor->name()] -= 1;
+                    break;
+                }
+                default: {
+                }
+                }
+                if (activation_tensors_num[input_tensor->name()] == 0 && module_tensors[input_tensor->name()]->sequence() > 1) {
+                    module_tensors[input_tensor->name()]->dealloc();
+                    // std::cout << input_tensor->name() << " |F" << std::endl;
+                }
+            }
+        }
+    }
 #ifdef DEBUGOPTIME
     if (Tensor::tensor_status == TENSOR_STATIC_READY) {
         auto end_t = mllm_time_us();
@@ -227,6 +254,7 @@ std::vector<std::reference_wrapper<Tensor>> Tensor::getStaticFunc(vector<std::string> out_names,
     auto &module_tensors = module->activation_tensors;
+    auto &activation_tensors_num = module->activation_tensors_num;
     auto *backend_h = Backend::global_backends[MLLM_CPU];
     if (!input_tensors.empty() && input_tensors[0]->backend_ != nullptr) {
         backend_h = input_tensors[0]->backend();
@@ -236,6 +264,7 @@ std::vector<std::reference_wrapper<Tensor>> Tensor::getStaticFunc(vector<std::string> out_names,
         module_tensors[out_name] = std::make_shared<Tensor>(backend_h);
         module_tensors[out_name]->setName(out_name);
         module_tensors[out_name]->setModule(module);
+        activation_tensors_num[out_name] = 0;
     }
 }
 if (module->doLoad) {
@@ -263,6 +292,28 @@ std::vector<std::reference_wrapper<Tensor>> Tensor::getStaticFunc(vector<std::string> out_names,
     default: {
     }
     }
+    if (Backend::global_backends.size() == 1) {
+        for (auto input_tensor : input_tensors) {
+            if (activation_tensors_num.find(input_tensor->name()) != activation_tensors_num.end()) {
+                switch (Tensor::tensor_status) {
+                case TENSOR_STATIC_INIT: {
+                    activation_tensors_num[input_tensor->name()] += 1;
+                    break;
+                }
+                case TENSOR_STATIC_READY: {
+                    activation_tensors_num[input_tensor->name()] -= 1;
+                    break;
+                }
+                default: {
+                }
+                }
+                if (activation_tensors_num[input_tensor->name()] == 0 && module_tensors[input_tensor->name()]->sequence() > 1) {
+                    module_tensors[input_tensor->name()]->dealloc();
+                    // std::cout << input_tensor->name() << " |S "<< std::endl;// << out_names[0] << std::endl;
+                }
+            }
+        }
+    }
 #ifdef DEBUGOPTIME
     if (Tensor::tensor_status == TENSOR_STATIC_READY) {
         auto end_t = mllm_time_us();
@@ -391,10 +442,20 @@ Tensor &Tensor::cat(vector<Tensor> input_tensors, Chl axis) {
     return getStaticFunc({input_tensors[0].name() + "-cat"}, FUNC_CAT, {(float)axis}, inputs)[0].get();
 }
 
+std::string _name_num_to_X(const std::string &input_string) {
+    std::regex pattern(R"(\.\d{1,3}\.)"); // Matches any 1-3 digit number between two dots
+    std::string replacement = ".X.";      // The string to replace the matched pattern with
+    std::string output_string = std::regex_replace(input_string, pattern, replacement);
+    return output_string;
+}
+
 Tensor &Tensor::mm(Tensor &input0, Tensor &input1) {
     Module *module = input0.module();
+    string nname = input0.name() + "-mm-" + input1.name();
+    if (nname.find(".X.") != string::npos)
+        nname = _name_num_to_X(nname);
     return getStaticFunc(
-        {input0.name() + "-mm-" + input1.name()}, FUNC_MM, {},
+        {nname}, FUNC_MM, {},
         {module->activation_tensors[input0.name()].get(), module->activation_tensors[input1.name()].get()})[0]
         .get();
 }
diff --git a/src/backends/cpu/compute/Convolution.cpp b/src/backends/cpu/compute/Convolution.cpp
index 842de5d9..b5d5ef2f 100644
--- a/src/backends/cpu/compute/Convolution.cpp
+++ b/src/backends/cpu/compute/Convolution.cpp
@@ -62,7 +62,7 @@ void conv2d_fp32_VALID(Tensor *input, Tensor *output, float **k_new, int kernel_
         }
     }
 
-#pragma omp parallel for num_threads(thread_count)
+#pragma omp parallel for collapse(3) num_threads(thread_count)
     for (int out_ch = 0; out_ch < out_channel; ++out_ch) {
         for (int out_h = 0; out_h < out_height; ++out_h) {
             for (int out_w = 0; out_w < out_width; ++out_w) {
@@ -73,7 +73,8 @@ void conv2d_fp32_VALID(Tensor *input, Tensor *output, float **k_new, int kernel_
                 }
                 if (support_bias) {
                     value += *bias->ptrAt<float>(0, 0, 0, out_ch);
                 }
-                *output->ptrAt<float>(b, out_h, out_ch, out_w) = value;
+                // *output->ptrAt<float>(b, out_h, out_ch, out_w) = value;
+                output->setDataAt<float>(b, out_h, out_ch, out_w, value);
             }
         }
     }
@@ -127,7 +128,7 @@ void conv2d_fp32_SAME(Tensor *input, Tensor *output, float **k_new, int kernel_h
             }
         }
     }
-#pragma omp parallel for num_threads(thread_count)
+#pragma omp parallel for collapse(3) num_threads(thread_count)
     for (int out_ch = 0; out_ch < out_channel; ++out_ch) {
         for (int out_h = 0; out_h < out_height; ++out_h) {
             for (int out_w = 0; out_w < out_width; ++out_w) {
@@ -137,7 +138,8 @@ void conv2d_fp32_SAME(Tensor *input, Tensor *output, float **k_new, int kernel_h
                 if (support_bias) {
                     value += *bias->ptrAt<float>(0, 0, 0, out_ch);
                 }
-                *output->ptrAt<float>(b, out_h, out_ch, out_w) = value;
+                // *output->ptrAt<float>(b, out_h, out_ch, out_w) = value;
+                output->setDataAt<float>(b, out_h, out_ch, out_w, value);
             }
         }
     }
@@ -212,7 +214,7 @@ void conv3d_fp32_VALID(Tensor *input, Tensor *output, float **k_new, int kernel_
         }
     }
 
-#pragma omp parallel for num_threads(thread_count)
+#pragma omp parallel for collapse(4) num_threads(thread_count)
     for (int out_ch = 0; out_ch < out_channel; ++out_ch) {
         for (int out_t = 0; out_t < out_time; ++out_t) {
             for (int out_h = 0; out_h < out_height; ++out_h) {
@@ -223,7 +225,8 @@ void conv3d_fp32_VALID(Tensor *input, Tensor *output, float **k_new, int kernel_
                     }
                     if (support_bias) {
                         value += *bias->ptrAt<float>(0, 0, 0, 0, out_ch);
                     }
-                    *output->ptrAt<float>(b, out_ch, out_t, out_h, out_w) = value;
+                    // *output->ptrAt<float>(b, out_ch, out_t, out_h, out_w) = value;
+                    output->setDataAt<float>(b, out_ch, out_t, out_h, out_w, value);
                 }
             }
         }
diff --git a/src/backends/cpu/function/CPUFlattenFunc.hpp b/src/backends/cpu/function/CPUFlattenFunc.hpp
index e19a0960..a971e638 100644
--- a/src/backends/cpu/function/CPUFlattenFunc.hpp
+++ b/src/backends/cpu/function/CPUFlattenFunc.hpp
@@ -6,6 +6,7 @@
 #define CPUFLATTENFUNC_HPP
 #include "Tensor.hpp"
 #include "Types.hpp"
+#include "Module.hpp"
 
 namespace mllm {
 class Tensor;
@@ -64,6 +65,14 @@ class CPUflattenFunction : public TensorFunction {
             outputs[0]->setDtype(inputs[0]->dtype());
             outputs[0]->alloc();
             inputs[0]->deepCopyFrom(outputs[0], false);
+        } else if (Module::llm_model_ptr->op_transposed_flag) {
+            if (inputs[0]->masterTensor() == nullptr) {
+                inputs[0]->free();
+            }
+            outputs[0]->setDtype(inputs[0]->dtype());
+            outputs[0]->alloc();
+            inputs[0]->deepCopyFrom(outputs[0], false);
+            return;
         } else {
             std::cout << "[TODO]Tensor.Flatten not support!!!!" << std::endl;
         }
diff --git a/src/backends/cpu/function/CPUTransposeFunc.hpp b/src/backends/cpu/function/CPUTransposeFunc.hpp
index 9bc8d41b..85ec6c73 100644
--- a/src/backends/cpu/function/CPUTransposeFunc.hpp
+++ b/src/backends/cpu/function/CPUTransposeFunc.hpp
@@ -6,6 +6,7 @@
 #define CPUTRANSPOSEFUNC_HPP
 #include "Tensor.hpp"
 #include "Types.hpp"
+#include "Module.hpp"
 
 namespace mllm {
 class Tensor;
@@ -17,8 +18,10 @@ class CPUtransposeFunction : public TensorFunction {
         for (int i = 0; i < args.size(); i += 2) {
             axiss.push_back({(Chl)args[i], (Chl)args[i + 1]});
         }
-        if (outputs[0]->count() <= 0 || outputs[0]->shape() != inputs[0]->shape()) {
-            outputs[0]->transCopyShape(inputs[0]->shape());
+        // if (outputs[0]->count() <= 0 || outputs[0]->shape() != inputs[0]->shape())
+        // {
+        outputs[0]->transCopyShape(inputs[0]->shape());
+        if (!Module::llm_model_ptr->op_transposed_flag) {
             std::map<Chl, int> origin_chls = {{BATCH, 0}, {SEQUENCE, 1}, {HEAD, 2}, {DIMENSION, 3}, {CHANNLE, 1}, {TIME, 2}, {HEIGHT, 3}, {WIDTH, 4}};
             if (std::equal(outputs[0]->chls().begin(), outputs[0]->chls().end(), origin_chls.begin())) {
                 outputs[0]->chls() = inputs[0]->chls();
@@ -33,24 +36,25 @@ class CPUtransposeFunction : public TensorFunction {
             outputs[0]->changeCtype(inputs[0]->shape().size());
             outputs[0]->undiffusion() = true;
         }
-        // if (inputs[0]->masterTensor() != nullptr) {
-        if (inputs[0]->masterTensor() != nullptr && (inputs[0]->masterTensor()->name().find("Cache") != std::string::npos || inputs[0]->masterTensor()->name().find("weight") != std::string::npos)) {
-            if (outputs[0]->masterTensor() == nullptr) {
-                outputs[0]->setDtype(inputs[0]->dtype());
-                outputs[0]->deepCopyFrom(inputs[0], false);
-            }
-        } else {
-            if (inputs[0]->masterTensor() == nullptr) {
-                inputs[0]->free();
-            }
+        }
+        // if (inputs[0]->masterTensor() != nullptr) {
+        if (inputs[0]->masterTensor() != nullptr && (inputs[0]->masterTensor()->name().find("Cache") != std::string::npos || inputs[0]->masterTensor()->name().find("weight") != std::string::npos)) {
+            if (outputs[0]->masterTensor() == nullptr) {
                 outputs[0]->setDtype(inputs[0]->dtype());
-                outputs[0]->alloc();
-                // inputs[0]->undiffusion() = true;
-                inputs[0]->setUndiffusion(true);
-                inputs[0]->deepCopyFrom(outputs[0], false);
-                outputs[0]->transFrom() = axiss;
+                outputs[0]->deepCopyFrom(inputs[0], false);
+            }
+        } else {
+            if (inputs[0]->masterTensor() == nullptr) {
+                inputs[0]->free();
             }
+            outputs[0]->setDtype(inputs[0]->dtype());
+            outputs[0]->alloc();
+            // inputs[0]->undiffusion() = true;
+            inputs[0]->setUndiffusion(true);
+            inputs[0]->deepCopyFrom(outputs[0], false);
+            outputs[0]->transFrom() = axiss;
         }
+        // }
     }
     void execute(vector<Tensor *> outputs, vector<Tensor *> inputs, vector<float> args) override {
     }
diff --git a/src/backends/cpu/op/CPUIRoPE.cpp b/src/backends/cpu/op/CPUIRoPE.cpp
index 7ef9150e..39e5b77f 100644
--- a/src/backends/cpu/op/CPUIRoPE.cpp
+++ b/src/backends/cpu/op/CPUIRoPE.cpp
@@ -416,11 +416,6 @@ ErrorCode CPUIRoPE::doExecute(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
             }
         }
     }
-    h_cnt_ += input->sequence();
-    if (h_cnt_ >= pos_max_) {
-        h_cnt_ = 0;
-    }
-
 #pragma omp parallel for collapse(4) num_threads(thread_count)
     for (int n = 0; n < input->batch(); ++n) {
         for (int h = 0; h < input->head(); ++h) {
@@ -435,6 +430,10 @@ ErrorCode CPUIRoPE::doExecute(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
                 }
             }
         }
     }
+    h_cnt_ += input->sequence();
+    if (h_cnt_ >= pos_max_) {
+        h_cnt_ = 0;
+    }
     return Op::execute(inputs, outputs);
 }
diff --git a/src/backends/cpu/op/CPURoPE.cpp b/src/backends/cpu/op/CPURoPE.cpp
index a721a71d..e71c7aa8 100644
--- a/src/backends/cpu/op/CPURoPE.cpp
+++ b/src/backends/cpu/op/CPURoPE.cpp
@@ -378,10 +378,6 @@ ErrorCode CPURoPE::doExecute(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
             }
         }
     }
-    h_cnt_ += input->sequence();
-    if (h_cnt_ >= pos_max_) {
-        h_cnt_ = 0;
-    }
 #pragma omp parallel for collapse(4) num_threads(thread_count)
     for (int n = 0; n < input->batch(); ++n) {
         for (int h = 0; h < input->head(); ++h) {
@@ -396,6 +392,10 @@ ErrorCode CPURoPE::doExecute(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
                 }
             }
         }
     }
+    h_cnt_ += input->sequence();
+    if (h_cnt_ >= pos_max_) {
+        h_cnt_ = 0;
+    }
     return Op::execute(inputs, outputs);
 }
diff --git a/src/models/dclm/modeling_dclm.hpp b/src/models/dclm/modeling_dclm.hpp
index 266e3871..ab0e285f 100644
--- a/src/models/dclm/modeling_dclm.hpp
+++ b/src/models/dclm/modeling_dclm.hpp
@@ -57,7 +57,7 @@ class DCLMAttention final : public Module {
    RoPE k_rope;
    KVCache k_cache;
    KVCache v_cache;
-    Layer softmax;
+    Softmax softmax;
 
    int attn_hidden_dim_;
    int head_dim_;
@@ -106,7 +106,7 @@ class DCLMAttention final : public Module {
         auto qk = Tensor::mm(q, k);
         qk = qk / std::sqrt(head_dim_);
 
-        qk = softmax(qk);
+        qk = softmax(qk, k_cache.getCacheSeqLen());
 
         auto o = Tensor::mm(qk, v);
         o = o.view(-1, 1, -1, n_heads_ * head_dim_);
diff --git a/src/models/gemma/modeling_gemma.hpp b/src/models/gemma/modeling_gemma.hpp
index 0dcede29..a0dcfcba 100644
--- a/src/models/gemma/modeling_gemma.hpp
+++ b/src/models/gemma/modeling_gemma.hpp
@@ -53,9 +53,9 @@ class GemmaDecoder final : public Module {
 public:
     GemmaDecoder() = default;
     GemmaDecoder(const GemmaConfig &config, const GemmaNameConfig &names, const string &base_name) {
-        self_atten = MultiHeadAttention(config.hidden_size, config.num_attention_heads, config.num_key_value_heads,
-                                        config.RoPE_type, config.rope_theta, config.max_position_embeddings, config.cache_limit, true, false, names, base_name + names._attn_base_name);
+        self_atten = MultiHeadAttention(config.hidden_size, config.num_attention_heads, config.num_key_value_heads, config.hidden_size / config.num_attention_heads, SPLIT_NONE, false, false,
+                                        config.RoPE_type, config.rope_theta, config.max_position_embeddings, config.cache_limit, true, false, names, base_name + names._attn_base_name);
         mlp = GemmaMLP(config.hidden_size, config.intermediate_size, names, base_name + names._ffn_base_name);
         input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, true, base_name + names._attn_norm_name);
         post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, true, base_name + names._ffn_norm_name);
@@ -71,6 +71,10 @@ class GemmaDecoder final : public Module {
         return {x};
     }
 
+    MultiHeadAttention &get_attention() {
+        return self_atten;
+    }
+
 private:
     MultiHeadAttention self_atten;
     GemmaMLP mlp;
@@ -95,6 +99,15 @@ class GemmaModel final : public Module {
         return {x};
     }
 
+    void clear_kvcache() override {
+        for (auto &block : blocks) {
+            auto kvcache = block.get_attention().get_cache();
+            for (auto &cache : kvcache) { cache->clearCache(); }
+            auto ropes = block.get_attention().get_rope();
+            for (auto &rope : ropes) { rope->clearCache(); }
+        }
+    }
+
 private:
     std::vector<GemmaDecoder> blocks;
     Layer norm;
@@ -124,6 +137,9 @@ class GemmaForCausalLM final : public Module {
         outputs = Tensor::mm(outputs, lm_head().transpose(Chl::SEQUENCE, Chl::DIMENSION));
         return {outputs};
     }
+    void clear_kvcache() override {
+        model.clear_kvcache();
+    }
 
 private:
     int hidden_size;
diff --git a/src/models/openelm/modeling_openelm.hpp b/src/models/openelm/modeling_openelm.hpp
index 2cc1a476..87cf3fab 100644
--- a/src/models/openelm/modeling_openelm.hpp
+++ b/src/models/openelm/modeling_openelm.hpp
@@ -48,7 +48,7 @@ class OpenELMMultiHeadCausalAttention final : public Module {
     KVCache k_cache;
     KVCache v_cache;
 
-    Layer softmax;
+    Softmax softmax;
 
     int iter = 0;
 
@@ -103,13 +103,19 @@ class OpenELMMultiHeadCausalAttention final : public Module {
 
         qk = qk / std::sqrt(head_dim_);
 
-        qk = softmax(qk);
+        qk = softmax(qk, k_cache.getCacheSeqLen());
 
         auto o = Tensor::mm(qk, v);
         o = o.view(-1, 1, -1, q_heads_ * head_dim_);
         o = out_proj(o);
         return {o};
     }
+    vector<KVCache *> get_cache() {
+        return {&k_cache, &v_cache};
+    }
+    vector<RoPE *> get_rope() {
+        return {&q_rope, &k_rope};
+    }
 };
 
 class OpenELMFeedForwardNetwork final : public Module {
@@ -171,6 +177,9 @@ class OpenELMDecoderLayer final : public Module {
         x = x + tmp;
         return {x};
     }
+    OpenELMMultiHeadCausalAttention &get_attention() {
+        return attn;
+    }
 };
 
 class OpenElMModel final : public Module {
@@ -217,4 +226,12 @@ class OpenElMModel final : public Module {
 
         return {logits};
     }
+    void clear_kvcache() override {
+        for (auto &block : decode_layers) {
+            auto kvcache = block.get_attention().get_cache();
+            for (auto &cache : kvcache) { cache->clearCache(); }
+            auto ropes = block.get_attention().get_rope();
+            for (auto &rope : ropes) { rope->clearCache(); }
+        }
+    }
 };
\ No newline at end of file
diff --git a/src/models/phi3v/modeling_phi3v.hpp b/src/models/phi3v/modeling_phi3v.hpp
index ea0d2766..9efb8c22 100644
--- a/src/models/phi3v/modeling_phi3v.hpp
+++ b/src/models/phi3v/modeling_phi3v.hpp
@@ -105,7 +105,7 @@ class Phi3Embedding final : public Module {
     }
 
     vector<Tensor> Forward(vector<Tensor> inputs, vector<std::any> args) override {
-        bool have_img = inputs.size() > 1;
+        bool have_img = inputs[1].batch() > 0;
         auto text_features = embed_tokens({inputs[0]});
         if (have_img) {
             auto image_features = img_processor({inputs[1]})[0];