Commit d886bff
Include llama2.c as a submodule and just add header file to example instead of .c file
mreso committed Jan 27, 2024
1 parent a07b7d9 commit d886bff
Showing 6 changed files with 124 additions and 867 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
@@ -4,3 +4,6 @@
 [submodule "cpp/third-party/llama.cpp"]
 	path = cpp/third-party/llama.cpp
 	url = https://github.com/ggerganov/llama.cpp.git
+[submodule "cpp/third-party/llama2.c"]
+	path = cpp/third-party/llama2.c
+	url = https://github.com/karpathy/llama2.c
1 change: 1 addition & 0 deletions cpp/third-party/llama2.c
Submodule llama2.c added at d98620
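
With llama2.c now tracked as a submodule, a fresh checkout must fetch it before the example builds, e.g. with git submodule update --init --recursive from the repository root, or by cloning with the --recursive flag in the first place.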
7 changes: 4 additions & 3 deletions examples/cpp/babyllama/CMakeLists.txt
@@ -1,5 +1,6 @@

-add_library(babyllama_handler SHARED src/baby_llama_handler.cc)
+add_library(llama2_c STATIC ../../../cpp/third-party/llama2.c/run.c)
+target_compile_options(llama2_c PRIVATE -Wall -Wextra -Ofast -fPIC)
 
-target_link_libraries(babyllama_handler PRIVATE ts_backends_core ts_utils ${TORCH_LIBRARIES})
-target_compile_options(babyllama_handler PRIVATE -Wall -Wextra -Ofast)
+add_library(babyllama_handler SHARED src/baby_llama_handler.cc)
+target_link_libraries(babyllama_handler PRIVATE llama2_c ts_backends_core ts_utils ${TORCH_LIBRARIES})
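
Note the restructuring: run.c is now compiled once into its own static library (llama2_c) and linked into the handler, instead of being textually included in the handler source. The -fPIC flag matters here because the objects in the static llama2_c archive end up inside the shared babyllama_handler library, which requires position-independent code.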
4 changes: 3 additions & 1 deletion examples/cpp/babyllama/src/baby_llama_handler.cc
@@ -5,7 +5,9 @@

 #include <typeinfo>
 
-#include "llama2.c/run.c"
+extern "C" {
+#include "llama2.c/llama2.h"
+}
 
 namespace llm {
 
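Since run.c is compiled as C while the handler is C++, the header include is wrapped in extern "C" so the declarations get C linkage; without it the C++ compiler would mangle the function names and the linker could not resolve symbols such as build_transformer from the llama2_c archive.
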
113 changes: 113 additions & 0 deletions examples/cpp/babyllama/src/llama2.c/llama2.h
@@ -0,0 +1,113 @@
#include <stdio.h>
#include <stdlib.h>
#include <ctype.h>
#include <time.h>
#include <math.h>
#include <string.h>
#include <fcntl.h>
#include <unistd.h>
#include <sys/mman.h>
// ----------------------------------------------------------------------------
// Transformer model

typedef struct {
int dim; // transformer dimension
int hidden_dim; // for ffn layers
int n_layers; // number of layers
int n_heads; // number of query heads
int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
int vocab_size; // vocabulary size, usually 256 (byte-level)
int seq_len; // max sequence length
} Config;

typedef struct {
// token embedding table
float* token_embedding_table; // (vocab_size, dim)
// weights for rmsnorms
float* rms_att_weight; // (layer, dim) rmsnorm weights
float* rms_ffn_weight; // (layer, dim)
// weights for matmuls. note dim == n_heads * head_size
float* wq; // (layer, dim, n_heads * head_size)
float* wk; // (layer, dim, n_kv_heads * head_size)
float* wv; // (layer, dim, n_kv_heads * head_size)
float* wo; // (layer, n_heads * head_size, dim)
// weights for ffn
float* w1; // (layer, hidden_dim, dim)
float* w2; // (layer, dim, hidden_dim)
float* w3; // (layer, hidden_dim, dim)
// final rmsnorm
float* rms_final_weight; // (dim,)
// (optional) classifier weights for the logits, on the last layer
float* wcls;
} TransformerWeights;

typedef struct {
// current wave of activations
float *x; // activation at current time stamp (dim,)
float *xb; // same, but inside a residual branch (dim,)
float *xb2; // an additional buffer just for convenience (dim,)
float *hb; // buffer for hidden dimension in the ffn (hidden_dim,)
float *hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
float *q; // query (dim,)
float *k; // key (dim,)
float *v; // value (dim,)
float *att; // buffer for scores/attention values (n_heads, seq_len)
float *logits; // output logits
// kv cache
float* key_cache; // (layer, seq_len, dim)
float* value_cache; // (layer, seq_len, dim)
} RunState;

typedef struct {
Config config; // the hyperparameters of the architecture (the blueprint)
TransformerWeights weights; // the weights of the model
RunState state; // buffers for the "wave" of activations in the forward pass
// some more state needed to properly clean up the memory mapping (sigh)
int fd; // file descriptor for memory mapping
float* data; // memory mapped data pointer
ssize_t file_size; // size of the checkpoint file in bytes
} Transformer;
// ----------------------------------------------------------------------------
// The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens

typedef struct {
char *str;
int id;
} TokenIndex;

typedef struct {
char** vocab;
float* vocab_scores;
TokenIndex *sorted_vocab;
int vocab_size;
unsigned int max_token_length;
unsigned char byte_pieces[512]; // stores all single-byte strings
} Tokenizer;

// ----------------------------------------------------------------------------
// The Sampler, which takes logits and returns a sampled token
// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling

typedef struct {
float prob;
int index;
} ProbIndex; // struct used when sorting probabilities during top-p sampling

typedef struct {
int vocab_size;
ProbIndex* probindex; // buffer used in top-p sampling
float temperature;
float topp;
unsigned long long rng_state;
} Sampler;
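
// ----------------------------------------------------------------------------
// Declarations of the run.c functions the handler calls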
void build_transformer(Transformer *t, char* checkpoint_path);
void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size);
void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed);
void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens);
float* forward(Transformer* transformer, int token, int pos);
int sample(Sampler* sampler, float* logits);
long time_in_ms();
char* decode(Tokenizer* t, int prev_token, int token);
void free_sampler(Sampler* sampler);
void free_tokenizer(Tokenizer* t);
void free_transformer(Transformer* t);
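
For orientation, here is a minimal sketch of how the API declared above fits together, loosely following the generation loop in llama2.c's run.c. The checkpoint paths, prompt, and sampling parameters are placeholder assumptions, not part of this commit:

extern "C" {
#include "llama2.c/llama2.h"
}

int main() {
  // Load model weights, tokenizer vocabulary, and sampler state.
  Transformer transformer;
  build_transformer(&transformer, (char*)"model.bin");
  Tokenizer tokenizer;
  build_tokenizer(&tokenizer, (char*)"tokenizer.bin", transformer.config.vocab_size);
  Sampler sampler;
  build_sampler(&sampler, transformer.config.vocab_size,
                /*temperature=*/1.0f, /*topp=*/0.9f, /*rng_seed=*/1234ULL);

  // Encode the prompt with a BOS token; +3 leaves room for BOS/EOS markers.
  char prompt[] = "Once upon a time";
  int* tokens = (int*)malloc((strlen(prompt) + 3) * sizeof(int));
  int n_tokens = 0;
  encode(&tokenizer, prompt, /*bos=*/1, /*eos=*/0, tokens, &n_tokens);

  // Feed the prompt tokens, then sample new ones until seq_len is exhausted.
  int token = tokens[0];
  for (int pos = 0; pos < transformer.config.seq_len - 1; pos++) {
    float* logits = forward(&transformer, token, pos);
    int next = (pos < n_tokens - 1) ? tokens[pos + 1] : sample(&sampler, logits);
    printf("%s", decode(&tokenizer, token, next));
    token = next;
  }
  printf("\n");

  free(tokens);
  free_sampler(&sampler);
  free_tokenizer(&tokenizer);
  free_transformer(&transformer);
  return 0;
}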
