Commit: Include llama2.c as a submodule and just add header file to example instead of .c file
Showing 6 changed files with 124 additions and 867 deletions.
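The third-party code comes in as a git submodule rather than a vendored copy, presumably added with something like "git submodule add https://github.com/karpathy/llama2.c cpp/third-party/llama2.c" (the destination path is inferred from the relative source path in the CMakeLists.txt change below; karpathy/llama2.c is the upstream repository). The example then compiles the submodule's run.c in its own build and keeps only a small header of declarations.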
CMakeLists.txt (babyllama example):

@@ -1,5 +1,6 @@
-add_library(babyllama_handler SHARED src/baby_llama_handler.cc)
+add_library(llama2_c STATIC ../../../cpp/third-party/llama2.c/run.c)
+target_compile_options(llama2_c PRIVATE -Wall -Wextra -Ofast -fPIC)
 
-target_link_libraries(babyllama_handler PRIVATE ts_backends_core ts_utils ${TORCH_LIBRARIES})
-target_compile_options(babyllama_handler PRIVATE -Wall -Wextra -Ofast)
+add_library(babyllama_handler SHARED src/baby_llama_handler.cc)
+target_link_libraries(babyllama_handler PRIVATE llama2_c ts_backends_core ts_utils ${TORCH_LIBRARIES})
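In other words, run.c from the submodule is now compiled once into a standalone static library, llama2_c, and the handler links against it instead of compiling the .c file into itself. Because that archive is linked into the shared babyllama_handler library, its objects must be position-independent, which is why the llama2_c target picks up the extra -fPIC flag.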
New file (113 additions): a header declaring the llama2.c types and entry points used by the example (the file path isn't captured on this page):
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h> // int8_t, used by encode() below (not guaranteed to arrive transitively)
#include <ctype.h>
#include <time.h>
#include <math.h>
#include <string.h>
#include <fcntl.h>
#include <unistd.h>
#include <sys/mman.h>
// ----------------------------------------------------------------------------
// Transformer model

typedef struct {
    int dim; // transformer dimension
    int hidden_dim; // for ffn layers
    int n_layers; // number of layers
    int n_heads; // number of query heads
    int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
    int vocab_size; // vocabulary size, usually 256 (byte-level)
    int seq_len; // max sequence length
} Config;

typedef struct {
    // token embedding table
    float* token_embedding_table; // (vocab_size, dim)
    // weights for rmsnorms
    float* rms_att_weight; // (layer, dim) rmsnorm weights
    float* rms_ffn_weight; // (layer, dim)
    // weights for matmuls. note dim == n_heads * head_size
    float* wq; // (layer, dim, n_heads * head_size)
    float* wk; // (layer, dim, n_kv_heads * head_size)
    float* wv; // (layer, dim, n_kv_heads * head_size)
    float* wo; // (layer, n_heads * head_size, dim)
    // weights for ffn
    float* w1; // (layer, hidden_dim, dim)
    float* w2; // (layer, dim, hidden_dim)
    float* w3; // (layer, hidden_dim, dim)
    // final rmsnorm
    float* rms_final_weight; // (dim,)
    // (optional) classifier weights for the logits, on the last layer
    float* wcls;
} TransformerWeights;

typedef struct {
    // current wave of activations
    float *x; // activation at current time stamp (dim,)
    float *xb; // same, but inside a residual branch (dim,)
    float *xb2; // an additional buffer just for convenience (dim,)
    float *hb; // buffer for hidden dimension in the ffn (hidden_dim,)
    float *hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
    float *q; // query (dim,)
    float *k; // key (dim,)
    float *v; // value (dim,)
    float *att; // buffer for scores/attention values (n_heads, seq_len)
    float *logits; // output logits
    // kv cache
    float* key_cache; // (layer, seq_len, dim)
    float* value_cache; // (layer, seq_len, dim)
} RunState;

typedef struct {
    Config config; // the hyperparameters of the architecture (the blueprint)
    TransformerWeights weights; // the weights of the model
    RunState state; // buffers for the "wave" of activations in the forward pass
    // some more state needed to properly clean up the memory mapping (sigh)
    int fd; // file descriptor for memory mapping
    float* data; // memory mapped data pointer
    ssize_t file_size; // size of the checkpoint file in bytes
} Transformer;

// ----------------------------------------------------------------------------
// The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens

typedef struct {
    char *str;
    int id;
} TokenIndex;

typedef struct {
    char** vocab;
    float* vocab_scores;
    TokenIndex *sorted_vocab;
    int vocab_size;
    unsigned int max_token_length;
    unsigned char byte_pieces[512]; // stores all single-byte strings
} Tokenizer;

// ----------------------------------------------------------------------------
// The Sampler, which takes logits and returns a sampled token
// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling

typedef struct {
    float prob;
    int index;
} ProbIndex; // struct used when sorting probabilities during top-p sampling

typedef struct {
    int vocab_size;
    ProbIndex* probindex; // buffer used in top-p sampling
    float temperature;
    float topp;
    unsigned long long rng_state;
} Sampler;

// Entry points implemented in llama2.c's run.c:
void build_transformer(Transformer *t, char* checkpoint_path); // mmap a checkpoint and allocate run state
void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size); // load the tokenizer file
void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed);
void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens); // text -> token ids
float* forward(Transformer* transformer, int token, int pos); // one transformer step; returns logits over the vocab
int sample(Sampler* sampler, float* logits); // pick the next token (argmax / sampling / top-p)
long time_in_ms(); // wall-clock helper used for tok/s reporting
char* decode(Tokenizer* t, int prev_token, int token); // token id -> string piece
void free_sampler(Sampler* sampler);
void free_tokenizer(Tokenizer* t);
void free_transformer(Transformer* t);
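For orientation, a minimal driver against these declarations would look roughly like the sketch below, linked against the llama2_c static library. It mirrors the generate() loop in llama2.c's run.c; the header name "llama2.h" and the checkpoint/tokenizer file names are placeholders, not names taken from this commit.

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "llama2.h" // hypothetical name for the header shown above

int main(void) {
    Transformer transformer;
    Tokenizer tokenizer;
    Sampler sampler;

    // Load a llama2.c-format checkpoint and its tokenizer (paths are placeholders).
    build_transformer(&transformer, "stories15M.bin");
    build_tokenizer(&tokenizer, "tokenizer.bin", transformer.config.vocab_size);
    build_sampler(&sampler, transformer.config.vocab_size,
                  /*temperature=*/1.0f, /*topp=*/0.9f, /*rng_seed=*/1234ULL);

    // Encode the prompt with a leading BOS (=1) token and no EOS token.
    char prompt[] = "Once upon a time";
    int *tokens = malloc((strlen(prompt) + 3) * sizeof(int)); // +3 headroom, as in run.c
    int n_tokens = 0;
    encode(&tokenizer, prompt, 1, 0, tokens, &n_tokens);

    // Feed the prompt tokens first, then sample the continuation.
    int token = tokens[0];
    for (int pos = 0; pos + 1 < transformer.config.seq_len; pos++) {
        float *logits = forward(&transformer, token, pos);
        int next = (pos + 1 < n_tokens) ? tokens[pos + 1] : sample(&sampler, logits);
        if (next == 1) break; // BOS doubles as the stop token in run.c's generate()
        printf("%s", decode(&tokenizer, token, next)); // run.c filters pieces via safe_printf
        fflush(stdout);
        token = next;
    }
    printf("\n");

    free(tokens);
    free_sampler(&sampler);
    free_tokenizer(&tokenizer);
    free_transformer(&transformer);
    return 0;
}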