separate vision ctx and llm ctx
ngxson committed Feb 6, 2025
1 parent ff77b15 commit fa55281
Showing 7 changed files with 139 additions and 35 deletions.
14 changes: 11 additions & 3 deletions examples/vision/vision.cpp
@@ -120,6 +120,14 @@ int main(int argc, char ** argv) {
return 1;
}

llama_vision_context_params vparams = llama_vision_context_default_params();
vparams.n_threads = llama_n_threads(ctx);
llama_vision_context * vctx = llama_vision_init_from_model(model, vparams);
if (!vctx) {
LOG_ERR("model does not have vision encoder\n");
return 1;
}

struct common_sampler * smpl = common_sampler_init(model, params.sampling);

llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);
@@ -136,12 +144,12 @@ int main(int argc, char ** argv) {
}
llama_vision_bitmap * img = load_image_from_file(img_path);
LOG_INF("loaded image %s, size = %d x %d\n", img_path, img->nx, img->ny);
img_tokens = llama_vision_tokenize(ctx, img);
img_tokens = llama_vision_tokenize(vctx, img);
if (!img_tokens) {
LOG_ERR("failed to create image tokens\n");
return 1;
}
if (llama_vision_encode(ctx, img_tokens)) {
if (llama_vision_encode(vctx, img_tokens)) {
LOG_ERR("failed to encode image\n");
return 1;
}
@@ -163,7 +171,7 @@ int main(int argc, char ** argv) {
return 1;
}
} else {
auto * img_embd = llama_vision_get_output_tensor(ctx);
auto * img_embd = llama_vision_get_output_tensor(vctx);
// std::vector<float> output_debug(ggml_nelements(img_embd));
// ggml_backend_tensor_get(img_embd, output_debug.data(), 0, ggml_nbytes(img_embd));
// for (int row = 0; row < 10; row++) {
23 changes: 20 additions & 3 deletions include/llama.h
@@ -229,6 +229,8 @@ extern "C" {
bool sorted;
} llama_token_data_array;

struct llama_vision_context;

// Structure represents the basic input unit of vision model
// This can be a processed image or slices of images under the hood
struct llama_vision_tokens;
@@ -365,6 +367,10 @@ extern "C" {
void * abort_callback_data;
};

struct llama_vision_context_params {
int32_t n_threads;
};

// model quantization parameters
typedef struct llama_model_quantize_params {
int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
@@ -402,6 +408,7 @@ extern "C" {
// TODO: update API to start accepting pointers to params structs (https://github.com/ggerganov/llama.cpp/discussions/9172)
LLAMA_API struct llama_model_params llama_model_default_params(void);
LLAMA_API struct llama_context_params llama_context_default_params(void);
LLAMA_API struct llama_vision_context_params llama_vision_context_default_params(void);
LLAMA_API struct llama_sampler_chain_params llama_sampler_chain_default_params(void);
LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);

@@ -1297,20 +1304,30 @@ extern "C" {
// Vision API
//

// Vision context
LLAMA_API struct llama_vision_context * llama_vision_init_from_model(
const struct llama_model * model,
struct llama_vision_context_params params);
LLAMA_API void llama_vision_free(struct llama_vision_context * ctx);

// Container for RGB bitmap
LLAMA_API struct llama_vision_bitmap * llama_vision_bitmap_init(uint32_t nx, uint32_t ny);
LLAMA_API void llama_vision_bitmap_free(struct llama_vision_bitmap * bmp);

// Create image tokens from the RGB bitmap
LLAMA_API struct llama_vision_tokens * llama_vision_tokenize(struct llama_context * ctx, llama_vision_bitmap * bmp);
LLAMA_API struct llama_vision_tokens * llama_vision_tokenize(
struct llama_vision_context * ctx,
struct llama_vision_bitmap * bmp);
LLAMA_API void llama_vision_tokens_free(struct llama_vision_tokens * img_tokens);

// User must reserve N number of tokens in tokenized text prompt for each image
// LLAMA_API int32_t llama_vision_get_n_tokens(const llama_vision_img_tokens * img_tokens);

// Encode patches into embeddings
LLAMA_API int32_t llama_vision_encode(struct llama_context * ctx, struct llama_vision_tokens * img_tokens);
LLAMA_API struct ggml_tensor * llama_vision_get_output_tensor(struct llama_context * ctx);
LLAMA_API int32_t llama_vision_encode(
struct llama_vision_context * ctx,
struct llama_vision_tokens * img_tokens);
LLAMA_API struct ggml_tensor * llama_vision_get_output_tensor(struct llama_vision_context * ctx);

//
// Model split
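With the declarations above, the vision encoder no longer piggybacks on llama_context; callers create a dedicated llama_vision_context instead. The snippet below is a minimal sketch (not part of this commit) that mirrors the updated examples/vision/vision.cpp. The helper name encode_image and its out_vctx parameter are illustrative only, and it assumes a loaded model, an existing llama_context, and an RGB bitmap (e.g. from llama_vision_bitmap_init or the example's load_image_from_file).

#include "llama.h"
#include "ggml.h"

// Illustrative helper (not part of this commit): encode one image with the
// vision context that is now separate from the text-side llama_context.
static struct ggml_tensor * encode_image(
        const struct llama_model    * model,
        struct llama_context        * ctx,   // text context, only used for its thread count here
        struct llama_vision_bitmap  * bmp,
        struct llama_vision_context ** out_vctx) {
    llama_vision_context_params vparams = llama_vision_context_default_params();
    vparams.n_threads = llama_n_threads(ctx);

    llama_vision_context * vctx = llama_vision_init_from_model(model, vparams);
    if (!vctx) {
        return nullptr; // model has no vision encoder
    }

    llama_vision_tokens * img_tokens = llama_vision_tokenize(vctx, bmp);
    if (!img_tokens || llama_vision_encode(vctx, img_tokens) != 0) {
        llama_vision_tokens_free(img_tokens); // deleting a null pointer is a no-op
        llama_vision_free(vctx);
        return nullptr;
    }
    llama_vision_tokens_free(img_tokens);

    // the output tensor lives in the vision context, so hand vctx back to the
    // caller and free it with llama_vision_free() only after the embedding has
    // been consumed (e.g. decoded through the llama_context)
    *out_vctx = vctx;
    return llama_vision_get_output_tensor(vctx);
}

Keeping the returned tensor tied to the vision context is the point of the separation: the text-side llama_context only sees the encoder output, not the encoder's backends, scheduler, or compute buffers.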
4 changes: 2 additions & 2 deletions src/llama-arch.cpp
@@ -1576,8 +1576,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_V_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_V_ENC_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_V_ENC_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_V_PRE_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_V_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_V_PRE_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL}},
{LLM_TENSOR_V_POST_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_V_RESMPL_POS_EMBD_K, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_ADD}},
{LLM_TENSOR_V_RESMPL_ATTN_Q, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}},
{LLM_TENSOR_V_RESMPL_ATTN_K, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}},
3 changes: 0 additions & 3 deletions src/llama-context.h
@@ -108,9 +108,6 @@ struct llama_context {
struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]

// vision
llama_vision_context vctx;
};

// TODO: make these methods of llama_context
112 changes: 98 additions & 14 deletions src/llama-vision.cpp
@@ -982,7 +982,7 @@ static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_
}

// alloc memory for graph
bool ok = ggml_backend_sched_alloc_graph(ctx.sched, gf);
bool ok = ggml_backend_sched_alloc_graph(ctx.sched.get(), gf);
if (!ok) {
LLAMA_LOG_ERROR("failed to alloc memory for graph\n");
return -1;
@@ -1064,7 +1064,7 @@ static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_
// compute
LLAMA_LOG_DEBUG("%s: compute start\n", __func__);
int64_t t_start = ggml_time_ms();
ggml_backend_sched_graph_compute(ctx.sched, gf);
ggml_backend_sched_graph_compute(ctx.sched.get(), gf);

// the last node is the embedding tensor
struct ggml_tensor * output_node = ggml_graph_node(gf, -1);
@@ -1091,6 +1091,92 @@ static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_
////////////////////////////////////////////////////////////////////////////////////////
// public API

struct llama_vision_context_params llama_vision_context_default_params() {
return {
/*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
};
}

struct llama_vision_context * llama_vision_init_from_model(const struct llama_model * model, struct llama_vision_context_params params) {
if (!model->has_vision) {
return nullptr;
}

llama_vision_context * ctx = new llama_vision_context;
ctx->model = &model->vit;

// TODO: this looks ugly, mostly copied from llama.cpp, refactor it in the future

// init backends
{
// add CPU backend
ctx->backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
if (ctx->backend_cpu == nullptr) {
LLAMA_LOG_ERROR("%s: failed to initialize CPU backend\n", __func__);
llama_vision_free(ctx);
return nullptr;
}
ctx->backends.emplace_back(ctx->backend_cpu);

// create a list of the set_n_threads functions in the backends
for (auto & backend : ctx->backends) {
ggml_backend_dev_t dev = ggml_backend_get_device(backend.get());
ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
if (reg) {
auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
ggml_backend_set_n_threads_fn(backend.get(), params.n_threads);
}
}
}

// scheduler and compute buffers
{
// buffer types used for the compute buffer of each backend
std::vector<ggml_backend_buffer_type_t> backend_buft;
std::vector<ggml_backend_t> backend_ptrs;
for (auto & backend : ctx->backends) {
auto * buft = ggml_backend_get_default_buffer_type(backend.get());
auto backend_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model->devices.empty()) {
// use the host buffer of the first device CPU for faster transfer of the intermediate state
auto * dev = model->devices[0];
auto * host_buft = ggml_backend_dev_host_buffer_type(dev);
if (host_buft) {
buft = host_buft;
}
}
backend_buft.push_back(buft);
backend_ptrs.push_back(backend.get());
}

const size_t max_nodes = model->max_nodes();

// buffer used to store the computation graph and the tensor meta data
ctx->buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));

// TODO: support pipeline_parallel
const bool pipeline_parallel = false;

ctx->sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));

if (pipeline_parallel) {
LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(ctx->sched.get()));
}
}

const size_t max_nodes = VISION_GRAPH_MAX_NODE; // TODO: make it dynamic
ctx->buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));

return ctx;
}

void llama_vision_free(struct llama_vision_context * ctx) {
if (ctx->ctx_ggml) {
ggml_free(ctx->ctx_ggml);
}
delete ctx;
}

struct llama_vision_bitmap * llama_vision_bitmap_init(uint32_t nx, uint32_t ny) {
llama_vision_bitmap * bmp = new llama_vision_bitmap;
bmp->nx = nx;
@@ -1105,16 +1191,15 @@ void llama_vision_bitmap_free(llama_vision_bitmap * bmp) {
}

struct llama_vision_tokens * llama_vision_tokenize(
struct llama_context * ctx,
llama_vision_bitmap * bmp) {
llama_vision_context & vctx = ctx->vctx;
switch (vctx.model->hparams.arch) {
struct llama_vision_context * ctx,
struct llama_vision_bitmap * bmp) {
switch (ctx->model->hparams.arch) {
case LLM_ARCH_VISION_LLAVA:
case LLM_ARCH_VISION_MOBILEVLM:
case LLM_ARCH_VISION_IDEFICS3:
return new llama_vision_tokens(llama_vision_processor_llava(vctx).tokenize(*bmp));
return new llama_vision_tokens(llama_vision_processor_llava(*ctx).tokenize(*bmp));
case LLM_ARCH_VISION_MINICPMV:
return new llama_vision_tokens(llama_vision_processor_llava(vctx).tokenize(*bmp));
return new llama_vision_tokens(llama_vision_processor_llava(*ctx).tokenize(*bmp));
default:
GGML_ASSERT(false && "unsupported arch");
}
@@ -1124,19 +1209,18 @@ void llama_vision_tokens_free(llama_vision_tokens * p) {
delete p;
}

int32_t llama_vision_encode(struct llama_context * ctx, llama_vision_tokens * p) {
int32_t llama_vision_encode(struct llama_vision_context * ctx, struct llama_vision_tokens * p) {
if (p->buf.empty()) {
LLAMA_LOG_ERROR("%s: nothing to encode\n", __func__);
return -1;
}

llama_vision_context & vctx = ctx->vctx;
auto & hparams = vctx.model->hparams;
auto & hparams = ctx->model->hparams;
switch (hparams.mm_patch_merge_type) {
case MM_PATCH_MERGE_FLAT:
{
// flat / default llava-1.5 type embedding
int32_t encoded = llama_vision_encode_impl(vctx, *p);
int32_t encoded = llama_vision_encode_impl(*ctx, *p);
if (encoded != 0) {
LLAMA_LOG_ERROR("Unable to encode image\n");
return encoded;
@@ -1154,8 +1238,8 @@ int32_t llama_vision_encode(struct llama_context * ctx, llama_vision_tokens * p)
return 0;
}

struct ggml_tensor * llama_vision_get_output_tensor(llama_context * ctx) {
return ctx->vctx.output;
struct ggml_tensor * llama_vision_get_output_tensor(struct llama_vision_context * ctx) {
return ctx->output;
}

////////////////////////////////////////////////////////////////////////////////////////
7 changes: 5 additions & 2 deletions src/llama-vision.h
@@ -1,6 +1,7 @@
#pragma once

#include "ggml.h"
#include "ggml-cpp.h"
#include "llama.h"
#include "llama-arch.h"

@@ -142,12 +143,14 @@ struct llama_vision_model {
struct llama_vision_context {
// memory buffers used to evaluate the model
std::vector<uint8_t> buf_compute_meta;
ggml_backend_sched_t sched = nullptr;
struct ggml_context * ctx_ggml = nullptr;
ggml_backend_sched_ptr sched;
std::vector<ggml_backend_ptr> backends;
ggml_backend_t backend_cpu;

const llama_vision_model * model;

// temporary output data, to be picked up by llama_decode()
struct ggml_context * ctx_ggml = nullptr;
struct ggml_tensor * output;
};

11 changes: 3 additions & 8 deletions src/llama.cpp
@@ -8460,7 +8460,9 @@ static int llama_prepare_sbatch(
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;

GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
GGML_ASSERT((batch.token && !batch.embd && !batch.embd_tensor)
|| (!batch.token && batch.embd && !batch.embd_tensor)
|| (!batch.token && !batch.embd && batch.embd_tensor)); // NOLINT
if (batch.token) {
for (uint32_t i = 0; i < n_tokens_all; ++i) {
if (batch.token[i] < 0 || uint32_t(batch.token[i]) >= model.vocab.n_tokens()) {
@@ -9893,13 +9895,6 @@ struct llama_context * llama_init_from_model(
}
}

if (model->has_vision) {
ctx->vctx.model = &model->vit;
ctx->vctx.sched = ctx->sched.get();
const size_t max_nodes = VISION_GRAPH_MAX_NODE; // TODO: make it dynamic
ctx->vctx.buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
}

return ctx;
}

