feat: add elastic llama #98

Merged · 10 commits · Jul 18, 2024
Changes from all commits
14 changes: 14 additions & 0 deletions CMakeLists.txt
@@ -433,6 +433,20 @@ else ()
target_link_libraries(demo_sparse_llama MLLM_CPU)
endif ()

add_executable(demo_elastic_llama ${PROJECT_SOURCE_DIR}/examples/demo_elastic_llama.cpp ${DIR_SRC_CPU} ${DIR_SRC_MEM_MANAGER} ${DIR_SRC_EXP} ${DIR_SRC} # ${DIR_SRC_QUANT}
src/tokenizers/Tokenizer.cpp
src/tokenizers/Tokenizer.hpp
src/tokenizers/BPE/Bpe.cpp
src/tokenizers/BPE/Bpe.hpp
)
# target_compile_definitions(demo_elastic_llama PRIVATE MLLM_QKK_64)
if (ARM AND NOT APK)
target_compile_options(demo_elastic_llama PRIVATE -fopenmp)
target_link_libraries(demo_elastic_llama PUBLIC MLLM_CPU -fopenmp -static-openmp)
else ()
target_link_libraries(demo_elastic_llama MLLM_CPU)
endif ()

add_executable(demo_llava ${PROJECT_SOURCE_DIR}/examples/demo_llava.cpp ${DIR_SRC_CPU} ${DIR_SRC_MEM_MANAGER} ${DIR_SRC_EXP} ${DIR_SRC}
src/tokenizers/Tokenizer.cpp
src/tokenizers/BPE/Bpe.cpp
94 changes: 94 additions & 0 deletions examples/demo_elastic_llama.cpp
@@ -0,0 +1,94 @@
//
// Created by Rongjie Yi on 2024/1/26.
//

#include <iostream>
#include "cmdline.h"
#include "models/llama/modeling_elastic_llama.hpp"
#include "models/llama/tokenization_llama.hpp"
#include "processor/PostProcess.hpp"


using namespace mllm;

int main(int argc, char **argv) {
cmdline::parser cmdParser;
cmdParser.add<string>("vocab", 'v', "specify mllm tokenizer model path", false, "../vocab/llama_vocab.mllm");
cmdParser.add<string>("model", 'm', "specify mllm model path", false, "../models/llama-2-7b-chat-q4_k.mllm");
cmdParser.add<int>("limits", 'l', "max KV cache size", false, 400);
cmdParser.add<int>("thread", 't', "num of threads", false, 4);
cmdParser.parse_check(argc, argv);

string vocab_path = cmdParser.get<string>("vocab");
string model_path = cmdParser.get<string>("model");
int tokens_limit = cmdParser.get<int>("limits");
CPUBackend::cpu_threads = cmdParser.get<int>("thread");

auto tokenizer = LLaMATokenizer(vocab_path);

LLaMAConfig config(tokens_limit, "7B", LLAMAROPE);
auto model = ElasticLLaMAModel(config);
model.load(model_path);

vector<string> in_strs = {
" Hello, who are you?",
" What can you do?",
"Please introduce Beijing University of Posts and Telecommunications."};

for (int i = 0; i < in_strs.size(); ++i) {
auto in_str = in_strs[i];
auto input_tensor = tokenizer.tokenize(in_str, i);
std::cout << "[Q] " << in_str << std::endl;
std::cout << "[A] " << std::flush;
for (int step = 0; step < 100; step++) {
// e.g. vector<vector<int>> activate_dims = {{32 * 8, 256}};
// 32*8 is attn_head * attn_hidden_dim (e.g. llama: 32*128); 256 is ffn_hidden_dim (e.g. llama: 11008)
vector<vector<int>> activate_dims(32, {-1, -1}); // one {attn_dim, ffn_dim} entry per decoder layer; -1 keeps the full width
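// Illustrative only (not part of this PR): to run a single layer at reduced
// width, overwrite its entry before the forward call, e.g.
//   activate_dims[5] = {16 * 128, 2048}; // half the attention width, 2048-wide FFN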
auto result = model({input_tensor}, activate_dims);
auto outputs = tokenizer.detokenize(result[0]);
auto out_string = outputs.first;
auto out_token = outputs.second;
if (out_token == 2) {
break;
}
std::cout << out_string << std::flush;
chatPostProcessing(out_token, input_tensor, {});
}
printf("\n");
}

return 0;
}
4 changes: 4 additions & 0 deletions include/OpDefined.hpp
@@ -48,6 +48,7 @@ enum OpType {
PREDICTOR,
SPARSELINEAR,
SPARSEIDLINEAR,
ELASTICLINEAR,
OP_NUM
};

@@ -89,6 +90,9 @@ static const vector<string> OpNames = {
"Range",
"Where",
"Replace",
"SparseLinear",
"SparseIdLinear",
"ElasticLinear",
"OP_NUM"};

enum TensorFuncType {
77 changes: 77 additions & 0 deletions src/Layer.hpp
@@ -5,6 +5,8 @@
#ifndef OPERATION_H
#define OPERATION_H

#include <cstdlib>
#include <iostream>
#include <utility>

#include "Tensor.hpp"
@@ -49,6 +51,13 @@ class Layer {
return _3I1O_OP(input0, input1, input2);
}

Tensor &operator()(Tensor &input0, int activate_input_dim, int activate_output_dim) {
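// Wrap the two requested widths in 1x1x1x1 float tensors so they can be passed
// through the three-input op path below (only input0 is a real activation).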
auto activate_input_dim_tensor = Tensor(1, 1, 1, 1, backend_, true);
activate_input_dim_tensor.setDataAt<float>(0,0,0,0,(float)activate_input_dim);
auto activate_output_dim_tensor = Tensor(1, 1, 1, 1, backend_, true);
activate_output_dim_tensor.setDataAt<float>(0,0,0,0,(float)activate_output_dim);
return _3I1O_only1map_OP(input0, activate_input_dim_tensor, activate_output_dim_tensor);
}

private:
std::string name_num_to_X(const std::string &input_string) {
@@ -344,6 +353,62 @@ class Layer {
return Tensor::gph_[next_name];
}
}
Tensor &_3I1O_only1map_OP(Tensor &input0, Tensor &input1, Tensor &input2) {
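// Variant of _3I1O_OP in which only input0 is mapped through the global tensor
// graph (Tensor::gph_); input1/input2 (the dim tensors) are passed through as-is.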
Module::runlistIdx = saved_list_idx;
if (INIT_OP()) {
return input0;
} else {
string layer_next_name = "out-" + op_->name();
if (Tensor::gph_.find(input0.name()) != Tensor::gph_.end()) {
Tensor::gph_[input0.name()].status() = input0.status();
}
switch (input0.status()) {
case TENSOR_STATIC_INIT: {
if (Tensor::gph_.find(input0.name()) == Tensor::gph_.end() || input0.count() != Tensor::gph_[input0.name()].count()) {
Tensor::gph_[input0.name()] = input0;
Tensor::gph_[input0.name()].setName(input0.name());
}
if (layername_2_tensorname.find(layer_next_name) == layername_2_tensorname.end()) {
layername_2_tensorname[layer_next_name] = name_num_to_X(layer_next_name);
}
auto next_name = layername_2_tensorname[layer_next_name];
if (Tensor::gph_.find(next_name) == Tensor::gph_.end()) {
Tensor::gph_[next_name] = Tensor(backend_);
Tensor::gph_[next_name].setName(next_name);
}
vector<shared_ptr<Tensor>> shared_inputs{
std::shared_ptr<Tensor>(&Tensor::gph_[input0.name()], [](Tensor *) {}),
std::shared_ptr<Tensor>(&input1, [](Tensor *) {}),
std::shared_ptr<Tensor>(&input2, [](Tensor *) {})};
vector<shared_ptr<Tensor>> shared_outputs{std::shared_ptr<Tensor>(&Tensor::gph_[next_name], [](Tensor *) {})};
op_->reshape(shared_inputs, shared_outputs);
op_->setUp(shared_inputs, shared_outputs);
assert(Tensor::gph_[next_name].hostPtr<float>() != nullptr);
break;
}
case TENSOR_STATIC_READY: {
auto next_name = layername_2_tensorname[layer_next_name];
vector<shared_ptr<Tensor>> shared_inputs{
std::shared_ptr<Tensor>(&Tensor::gph_[input0.name()], [](Tensor *) {}),
std::shared_ptr<Tensor>(&input1, [](Tensor *) {}),
std::shared_ptr<Tensor>(&input2, [](Tensor *) {})};
vector<shared_ptr<Tensor>> shared_outputs{std::shared_ptr<Tensor>(&Tensor::gph_[next_name], [](Tensor *) {})};
op_->execute(shared_inputs, shared_outputs);
assert(Tensor::gph_[next_name].hostPtr<float>() != nullptr);
break;
}
default: {
break;
}
}
auto next_name = layername_2_tensorname[layer_next_name];
Tensor::gph_[next_name].status() = Tensor::gph_[input0.name()].status();
if(saveNDataFlag){
Tensor::gph_[next_name].saveNData<float>(layer_next_name);
}
return Tensor::gph_[next_name];
}
}
Tensor &_0I1O_OP() {
Module::runlistIdx = saved_list_idx;
if (INIT_OP()) {
@@ -525,6 +590,18 @@ class Predictor final : public Layer {
// no need to define a new operator() function, just use the default one
};

class ElasticLinear final : public Layer {
public:
explicit ElasticLinear(int in_features, int out_features, bool bias, std::string name) {
param_["in_features"] = in_features;
param_["out_features"] = out_features;
param_["bias"] = (float)bias;
init(std::move(name), OpType::ELASTICLINEAR);
}
// Uses the elastic Tensor &operator()(Tensor &input0, int activate_input_dim, int activate_output_dim) defined above; no new operator() needed here.
};


class SiLU final : public Layer {
public:
SiLU() = default;
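A minimal usage sketch for the new layer (illustrative names and sizes, assumed rather than taken from this PR): ElasticLinear is constructed like an ordinary Linear, and the two extra operator() arguments select the active slice of the input/output dimensions per call, with -1 meaning the full width.

// Sketch only: hypothetical layer name and sizes; x is an existing input Tensor.
ElasticLinear q_proj(4096, 32 * 128, false, "q_proj");
Tensor &q_full = q_proj(x, -1, -1);        // full 4096 -> 4096 projection
Tensor &q_half = q_proj(x, -1, 16 * 128);  // same input width, half the output width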
25 changes: 22 additions & 3 deletions src/Module.hpp
@@ -12,6 +12,7 @@
#include "backends/cpu/CPUBackend.hpp"

#include <any>
#include <iostream>
#include <memory/SystemMemoryManager.hpp>
#include <numeric>
#include <utility>
@@ -70,9 +71,27 @@ class Module {
Tensor::gph_[std::to_string(i)] = Tensor(Module::backends[MLLM_CPU]);
tmps.push_back(Tensor::gph_[std::to_string(i)]);
}
vector<int> tmpt = {0, 0};
uint64_t time_start = mllm_time_us();
operator()(tmps, tmpt);
vector<std::any> alternate_args={
{},
vector<int>{0, 0},
std::vector<std::vector<int>>(32, std::vector<int>(2))
};
uint64_t time_start = 0;
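// Probe the candidate payloads in order: if a downstream any_cast expects a
// different type, the resulting "bad any_cast" is swallowed and the next
// candidate is tried; any other exception is fatal.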
for (auto args : alternate_args) {
time_start = mllm_time_us();
try {
operator()(tmps, args);
break;
} catch (const std::exception& e) {
if("bad any_cast" != e.what()) {
std::cerr << e.what() << std::endl;
exit(0);
}
} catch (...) {
std::cerr << "load error" << std::endl;
exit(0);
}
}
uint64_t time_end = mllm_time_us();
load_time_ = (time_end - time_start) / 1000.0F;//ms
Module::doLoad = false;
6 changes: 5 additions & 1 deletion src/ParamLoader.cpp
@@ -2,6 +2,7 @@
#include "Types.hpp"
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <string>
@@ -123,7 +124,10 @@ std::tuple<uint8_t *, uint64_t> ParamLoader::load(string name) {
}
DataType ParamLoader::getDataType(string name) {
if (data_type_.count(name) != 1) {
if (this->fp_ != nullptr) {
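// A non-empty model path that failed to open is fatal; a missing tensor name
// in an opened file only prints a warning and falls through to MLLM_TYPE_COUNT.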
if(this->path_ != "" && this->fp_ == nullptr){
std::cerr<<this->path_<<" not found"<<std::endl;
exit(0);
}else if (this->fp_ != nullptr && this->path_ != "") {
std::cerr<<name<<" not found"<<std::endl;
}
return DataType::MLLM_TYPE_COUNT;
2 changes: 2 additions & 0 deletions src/backends/cpu/CPUBackend.cpp
@@ -40,6 +40,7 @@
#include "CPUPredictor.hpp"
#include "CPUSparseIdLinear.hpp"
#include "CPUSparseLinear.hpp"
#include "CPUElasticLinear.hpp"
#include "CPUTensorFunction.hpp"

namespace mllm {
@@ -99,6 +100,7 @@ void CPUBackend::registerOps() {
addCreator(PREDICTOR, (CPUBackend::Creator *)(new CPUPredictorCreator()));
addCreator(SPARSELINEAR, (CPUBackend::Creator *)(new CPUSparseLinearCreator()));
addCreator(SPARSEIDLINEAR, (CPUBackend::Creator *)(new CPUSparseIdLinearCreator()));
addCreator(ELASTICLINEAR, (CPUBackend::Creator *)(new CPUElasticLinearCreator()));
}

TensorFunction *CPUBackend::funcCreate(const TensorFuncType type) {