
llama : second attempt to refactor vision API #11292

Draft · wants to merge 22 commits into base: master

Conversation

@ngxson (Collaborator) commented Jan 18, 2025

Fix #8010

Supersede #9687

To test this, please refer to #9687 to convert the model to GGUF.

Then,

cmake --build build -j --target llama-vision
./build/bin/llama-vision -m ../models/llava-1.5-7b-hf/model.gguf --image ../models/bliss.png

# The image showcases a lush green field with a hill in the background. In the foreground, there is a large,
# bright, and vibrant green field with a Microsoft Windows XP desktop screen, possibly representing a
# screensaver, superimposed onto the scene. The field is expansive and covers most of

Goals of this PR:

  • Have the first version of a public API for llama_vision
  • Support llava, mobilevlm, minicpm-v 2.6, smolVLM
  • See how the API can adapt to encoder-decoder models like llama 3.2 vision (so we can add them soon)
  • Add an API to format the chat, equivalent to the Processor class in the HF library
  • See how quantization affects performance

Things that will be done in follow-up PRs:

  • Models with encoder-decoder arch like llama 3.2 vision
  • GPU support
  • Better image processing: a faster resize function, and maybe even abstracting out the image transformations to optimize them (example: if we run resize twice, it's better to detect that and only run it once)
  • Further clean up the mess in the convert_hf_to_gguf.py script

@github-actions bot added the python (python script changes) and server labels on Jan 18, 2025
@ngxson (Collaborator, Author) commented Jan 19, 2025

Hi @ggerganov @slaren , I would like to ask for an early review from you before proceeding further.

What will be interesting to discuss here is the usage of the new API, as demonstrated in the newly added llama-vision example. The idea is:

  • Call llama_vision_encode for each image (we don't support batching for now, to simplify the implementation)
  • Then, get the output embedding ggml_tensor, add it to a llama_batch, and llama_decode it (a rough sketch follows below)
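
For illustration, a rough sketch of that flow using the names from this PR (the exact signatures of llama_vision_get_output_tensor and llama_batch_get_one_from_tensor are assumptions here, and load_image is a hypothetical helper):

// sketch only - not the final API
llama_vision_bitmap  * bmp = load_image("bliss.png");              // hypothetical helper
llama_vision_patches * p   = llama_vision_patches_init(ctx, bmp);  // CPU-side preprocessing

if (llama_vision_encode(ctx, p) != 0) {
    // handle the error code
}

// fetch the image embeddings and decode them as a batch
ggml_tensor * embd  = llama_vision_get_output_tensor(ctx);
llama_batch   batch = llama_batch_get_one_from_tensor(embd, n_past, seq_id); // assumed signature
llama_decode(ctx, batch);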

I'm already able to get llava and mobilevlm working with llama-vision and convert_hf_to_gguf.py (for minicpm-v, I'm still struggling because the conversion is not straightforward).

Things that are different from the initial discussion in #8010 :

  • I added a helper function llama_batch_get_one_from_tensor for creating the batch from a tensor, with an appropriate n_past (for placing these tokens in the correct place in the chat template) and seq_id (for future usage in the server).
  • llama_vision_patches actually contains slices of the image, not patches, as explained in llava-uhd. The patches are actually produced in clip_image_build_graph by doing a ggml_conv_2d. I think I'll need to rename it to llama_vision_slices, but I would actually prefer a more appropriate name like llama_vision_preprocessed_img since we do more than just slicing (i.e. resize, padding, etc.) - feel free to suggest if you have any ideas.

And things that are still messy and will need more works:

  1. Naming: most functions are still prefixed with clip_, and I don't know whether I should prefix everything with llama_vision_clip_ or not. Please let me know your preference.
  2. Chat template support: we may need to introduce a new API that wraps llama_chat_apply_template, much like how transformers has a Processor class that wraps around the Tokenizer.
  3. Not sure how this API will be adapted for an encoder-decoder arch like llama 3.2 vision. In theory, llama_vision_get_output_tensor should become a no-op, but judging from this implementation, it's still needed. @danbev do you have any ideas?

I would love to hear your opinions about this. Thank you!

Comment on lines +862 to +873
if (ctx.ctx_ggml) {
    ggml_free(ctx.ctx_ggml);
}
ggml_init_params params = {
    /*.mem_size   =*/ ggml_tensor_overhead(),
    /*.mem_buffer =*/ NULL,
    /*.no_alloc   =*/ true,
};
ctx.ctx_ggml = ggml_init(params);
ctx.output = ggml_dup_tensor(ctx.ctx_ggml, output_node);
ggml_backend_alloc_ctx_tensors_from_buft(ctx.ctx_ggml, ctx.model->buft);
ggml_backend_tensor_copy(output_node, ctx.output);
@ngxson (Collaborator, Author):

@slaren Not sure if there is a better way, but I'm using a hacky solution here.

Without a dedicated context (and ggml_backend_tensor_copy), the underlying buffer is reallocated before the next llama_decode, rendering the data unusable.

Member:

If the vision part uses the same scheduler as the llama_context, that's unavoidable. You could pre-allocate the tensor in a different buffer to avoid the copy, but that's an optimization that can be done later.

Member:

If we have a separate encoder context for the clip model, the decoder context could reference tensors from it directly. They would be interpreted as inputs for the decoder.

@slaren (Member) commented Jan 20, 2025

  • llama_vision_patches actually contains slices of image, not patches, as explained in llava-uhd. The patches are actually produced in clip_image_build_graph by doing a ggml_conv_2d. I think I'll need to rename it to llama_vision_slices, but I actually prefer a more appropriate name like llama_vision_preprocessed_img since we do more than just slicing it (i.e. resize, padding, etc) - feel free to suggest if you have any ideas.

I am just wondering, is there any reason to expose the patches/slices to the user at all? Can the user do anything with the patches other than just immediately call llama_vision_encode and throw them away? If not, then maybe that could be hidden entirely from the user and llama_vision_encode could take directly an image.

@danbev (Collaborator) commented Jan 20, 2025

@ngxson I'll take a closer look at this today, specifically at how this could work with a cross-attention model like Llama 3.2 Vision 👍

One thing related to this work is something we discussed about how these models should be provided. I initially thought that creating a single .gguf for Llama 3.2 containing both the vision encoder and the language model would be the way to go, but as can be read in the linked discussion, having separate models is probably a better solution. It would be great to get some clarification on this and on whether vision encoders should be separate .gguf models.
I'm looking at updating the conversion for Llama 3.2 and making changes to convert_hf_to_gguf.py to produce two models (vision encoder and language model) instead of one. I'd like to try this out with this latest vision API proposal, but I'd prefer to know what the model(s) should look like before proceeding, to not waste time.

@ngxson (Collaborator, Author) commented Jan 20, 2025

@slaren In my first proposal, I made llama_vision_encode directly accept an image. But then I decided to split it into preprocess + encode because:

  • The most important reason is that the user will be able to retrieve the number of tokens that the image occupies (this can vary depending on image size, in the case of llava-uhd). This should be done before any decode/encode so that the user can leave the appropriate places for the image after the tokenizing step. This is also similar to the Processor class in HF transformers, which returns a preprocessed image and the tokenized prompt with the correct number of "placeholder" tokens for the image embeddings.
  • The second reason is that by making this a dedicated function, it's easier to manage error codes. This is mostly because this function works at the pixel level, not the tensor level.
  • And the third reason is that this preprocessing is thread-safe, so, for example, llama-server can do this step in the HTTP thread, much like how llama_tokenize is currently done in the HTTP thread (see the sketch below).
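
As a small sketch of that split, using the names from this PR (the comments about threading are illustrative only):

// phase 1: pixel-level preprocessing - thread-safe, can run e.g. in an HTTP worker thread
llama_vision_patches * p = llama_vision_patches_init(ctx, bmp);

// phase 2: tensor-level encoding - runs wherever the context is normally used,
// and returns an error code that is easy to surface to the caller
int32_t res = llama_vision_encode(ctx, p);
if (res != 0) {
    // report the preprocessing/encoding failure
}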

@ngxson (Collaborator, Author) commented Jan 20, 2025

Btw, I have repeatedly mentioned the Processor class, so I think it's better to give an example of how it works: https://gist.github.com/ngxson/ca46c72f0cc7b441c30dd85c2a24ee62

@ggerganov (Member) left a comment

Adding some thoughts that I have so far.

Continuing along the idea of having separate models and contexts for the encoder and the decoder, I think that with a proper llama_batch abstraction we can have the following API:

// vision
patches0 = llama_vision_tokenize(ctx_enc_v, img0);
patches1 = llama_vision_tokenize(ctx_enc_v, img1);

llama_batch_add_image(batch_enc_v, patches0);
llama_batch_add_image(batch_enc_v, patches1);

llama_encode(ctx_enc_v, batch_enc_v);

embd_enc_v = llama_get_embeddings(ctx_enc_v);

// audio
mel0 = llama_audio_tokenize(ctx_enc_a, audio0);
mel1 = llama_audio_tokenize(ctx_enc_a, audio1);

llama_batch_add_audio(batch_enc_a, mel0);
llama_batch_add_audio(batch_enc_a, mel1);

llama_encode(ctx_enc_a, batch_enc_a);

embd_enc_a = llama_get_embeddings(ctx_enc_a);

// text + vision + audio
tokens0 = llama_tokenize(ctx_dec, tokens0);
tokens1 = llama_tokenize(ctx_dec, tokens1);

llama_batch_add_text      (batch_dec, tokens0);
llama_batch_add_embd_image(batch_dec, embd_enc_v);
llama_batch_add_embd_audio(batch_dec, embd_enc_a);
llama_batch_add_text      (batch_dec, tokens1);

llama_decode(ctx_dec, batch_dec);

For cross-attention models such as Llama 3.2 Vision and Whisper, the decoding context ctx_dec could be initialized with a reference to the encoder context:

llama_context_params cparams_dec;
cparams_dec.ctx_cross[0] = ctx_enc_v;
cparams_dec.ctx_cross[1] = ctx_enc_a;

Edit: extended the example with audio input as well.

Comment on lines 558 to 570
static ggml_cgraph * clip_image_build_graph(clip_context & ctx, int batch_size, clip_image_size & image_size) {
    auto & model   = *ctx.model;
    auto & hparams = ctx.model->hparams;

    const int   hidden_size   = hparams.hidden_size;
    const int   n_head        = hparams.n_head;
    const int   d_head        = hidden_size / n_head;
    const int   patch_size    = hparams.patch_size;
    const float eps           = hparams.eps;
    const int   num_patches   = ((image_size.width / patch_size) * (image_size.height / patch_size));
    const int   num_positions = num_patches + (model.class_embedding ? 1 : 0);

    LLAMA_LOG_DEBUG("%s: num_patches = %d\n", __func__, num_patches);
Member:

The clip graph should be constructed as any other graph in src/llama.cpp, llm_build_context.

@ngxson (Collaborator, Author):

I'm not sure how to do this right now, as I can't see how I can re-use the existing build_* functions to make the cgraph of vision models "blend in" with the rest of llm_build_context.

But what I did so far is make an equivalent called llama_vision_graph_builder. This is meant to be a temporary solution, to simplify the migration in the near future.

Could you please have a look at my llama_vision_graph_builder to see how it can be merged into llm_build_context? Thanks!

    delete p;
}

int32_t llama_vision_encode(struct llama_context * ctx, llama_vision_patches * p) {
Member:

I don't think we need a separate function - we should be able to reuse llama_encode.

@ngxson (Collaborator, Author) commented Jan 21, 2025:

Hmm I don't think we can do this right now, as it requires llama_batch to also accept image tokens.

Do you think it's ok to keep llama_vision_encode(llama_img_tokens &) and refactor llama_batch later on?

Comment on lines 894 to 902
struct llama_vision_patches * llama_vision_patches_init(
        struct llama_context * ctx,
        llama_vision_bitmap * bmp) {
    clip_context & vctx = ctx->vctx;
    if (vctx.model->hparams.arch == VISION_ARCH_MINICPMV) {
        return new llama_vision_patches(clip_image_preprocess_minicpmv(vctx, *bmp));
    }
    return new llama_vision_patches(clip_image_preprocess(vctx, *bmp));
}
@ggerganov (Member) commented Jan 20, 2025:

I agree that the analogy of "tokenization" in the context of vision models is the conversion of "images -> patches". So the patches could be considered as "image tokens" and it seems reasonable to have a separate function to create patches, since this would have to be performed on the CPU.

I am just wondering, is there any reason to expose the patches/slices to the user at all? Can the user do anything with the patches other than just immediately call llama_vision_encode and throw them away? If not, then maybe that could be hidden entirely from the user and llama_vision_encode could take directly an image.

Even though the user cannot explicitly operate with the patches, it seems to make sense to expose this in order to be able to multi-thread the pre-processing step.

Note that we should also consider the case of Whisper in the context of this abstraction. The whisper model takes raw input audio in PCM format, which is first pre-processed into a mel spectrogram. This pre-processing step, similar to the image pre-processing for CLIP and the text tokenization in text models, is performed on the CPU and can be multi-threaded. Of course, any of the three types of pre-processing could be implemented on the GPU with enough effort, but the important aspect is that this pre-processing can be done in parallel for different inputs and, once computed, can be reused with different contexts.

In all cases, the pre-processed input is passed to the transformer graph, and the first step is always to convert this input into embeddings. For text, this conversion is trivial - ggml_get_rows(w, tokens). For Whisper, this process involves a couple of convolutions of the mel spectrogram:

https://github.com/ggerganov/whisper.cpp/blob/99b011a9f5e63f71201bfa583250506453a7b995/src/whisper.cpp#L1904-L1918

For CLIP, this appears to be again a convolution operator applied to the pre-processed input (the image patches) in order to obtain the initial embeddings:

https://github.com/ngxson/llama.cpp/blob/4a7ab89d7593ccb89f80e6e118875ee0b3ede3c7/src/llama-vision.cpp#L581-L616

All these conversions of the pre-processed input (tokens, mel, patches) into the initial embeddings should be implemented in a single place: build_inp_embd().
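
For illustration only, such a unified entry point could look roughly like this (the enum values, helper functions and member names below are made up for the sketch; only ggml_get_rows and build_inp_embd come from the discussion above):

// illustrative sketch, not the actual llama.cpp implementation
ggml_tensor * llm_build_context::build_inp_embd(llm_input_type type) {
    switch (type) {
        case LLM_INPUT_TOKENS:
            // text: trivial embedding lookup
            return ggml_get_rows(ctx0, model.tok_embd, inp_tokens);
        case LLM_INPUT_MEL:
            // whisper-style audio: convolution stem over the pre-computed mel spectrogram
            return build_audio_conv_stem(inp_mel);
        case LLM_INPUT_PATCHES:
            // clip-style vision: conv2d stem over the pre-processed image patches
            return build_vision_conv_stem(inp_patches);
    }
    return nullptr;
}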

@ngxson (Collaborator, Author) commented Jan 20, 2025:

I agree that the analogy of "tokenization" in the context of vision models is the conversion of "images -> patches". So the patches could be considered as "image tokens" and it seems reasonable to have a separate function to create patches

Makes sense then. I realized that I always associated the notion of "token" with "text", but a quick Google search tells me that: "In LLMs, a token is a basic unit of input or output [...]"

In that sense, I would propose calling it llama_vision_img_tokens (though it can be a bit confusing, because the user may expect it to be a std::vector due to the plural "tokens"):

// Structure represents the basic input unit of vision model
// This can be a processed image or slices of images under the hood
struct llama_vision_img_tokens;

// User must reserve N number of tokens in tokenized text prompt for each image
int32_t llama_vision_get_n_tokens(const llama_vision_img_tokens * img_tokens);
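
A hypothetical usage sketch for this proposal (the creation function below is made up; only llama_vision_img_tokens and llama_vision_get_n_tokens come from the snippet above):

// sketch: reserve the right number of placeholder positions for the image
llama_vision_img_tokens * img = llama_vision_img_tokens_init(vctx, bmp); // hypothetical name
int32_t n_img = llama_vision_get_n_tokens(img);

// tokenize the prompt leaving n_img positions where the image goes, then fill those
// positions with the embeddings produced by the vision encoder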

@danbev (Collaborator) commented Jan 22, 2025

@ngxson Sorry about the delay. I've been able to "force" support for mllama using the latest vision API, that is, get an example working. I'm now going to iterate on this and try to figure out how cross-attention will work. Just wanted to let you know that some progress is being made.

There is an issue I'm having with the vocab size, which I'm not exactly sure how to handle. If anyone has thoughts on this, please let me know.

@ngxson (Collaborator, Author) commented Jan 22, 2025

@danbev No worries, I was busy with minicpm-v too. It's still not working (inference works, but the llava-uhd preprocessor is missing). Will have a look at your implementation of mllama very soon.

@ngxson (Collaborator, Author) commented Jan 22, 2025

So, the minicpm-v template is more complicated because it contains both the image and all the slices. Here is what it looks like in minicpmv-cli:

<image> (if no slice, we only have one image) </image><slice><image> (first slice) </image><image> (second slice) </image> .... (n-th slice) </slice>

To get rid of this complication, my idea is to have the embeddings of these tokens (<image>, </image>, <slice> and </slice>) appended to the output tensor returned from llama_vision_encode.

This will make the formatting transparent to the text tokenizer, but it will require the embeddings of these tokens to be stored as one-hot vectors in the vision model (of course, we can use ggml_get_rows to get them, but it will be quite messy - see the rough sketch below).
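
For reference, the ggml_get_rows variant would look roughly like this (a sketch only; model.tok_embd, img_embd and ctx0 are assumed names, and the same would be repeated for </image>, <slice> and </slice>):

// rough sketch: fetch the embedding of the <image> token and prepend it to the image embeddings
ggml_tensor * ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
// ... mark ids as a graph input and fill it with the token id of <image> at compute time ...
ggml_tensor * img_start = ggml_get_rows(ctx0, model.tok_embd, ids);  // [n_embd, 1]
ggml_tensor * out = ggml_concat(ctx0, img_start, img_embd, 1);       // concat along the token dim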

@ngxson (Collaborator, Author) commented Jan 23, 2025

Ok, so I managed to get minicpm-v to kinda work out of the box with the API (no changes to user-space code are required).

Upon giving it the Windows XP wallpaper Bliss, it says: I see a serene landscape featuring a vast expanse of green grass under a clear blue sky

It currently operates with a resized version of the image (like llava), so performance will be bad for bigger images (with more details). I'll get llava-uhd to work, which breaks the image into slices and thus allows the LLM to "see" the image at different zoom levels, preserving details.

@ngxson (Collaborator, Author) commented Jan 23, 2025

I also got SmolVLM (tested with the 500M model) to work with this API without any change to user-space code. The image preprocessor may not be 100% correct, but I'll discuss with the SmolVLM team to learn more about it.

For the bliss.jpg test:

The image depicts a wide, rolling green field with a clear blue sky and scattered clouds. The field is expansive, stretching horizontally across the image, and it is lush with grass, indicating a temperate climate with ample rainfall. The grassy areas are uniformly green, suggesting it is a well-maintained field, possibly cultivated for recreational [...]

@wbraswell commented

So what remains to be done before this PR can be successfully merged?

@danbev mentioned this pull request on Feb 4, 2025
@ngxson (Collaborator, Author) commented Feb 4, 2025

I'm going back to this PR; my goals for this week are:

@wbraswell commented

@ngxson
Sounds good! So I guess this puts us back on the path to re-enabling multimodal?

@AIWintermuteAI commented

@ngxson looks very promising!
I wanted to try out your fork locally; however, perhaps there have been some changes since you created the PR description?

cmake --build build -j --target llama-vision 
gmake: *** No rule to make target 'llama-vision'.  Stop.

The llama.cpp main branch builds fine, following these build instructions:
https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md.

@ngxson (Collaborator, Author) commented Feb 6, 2025

Sounds good! So I guess this puts us back on the path to re-enabling multimodal?

yes

I wanted to try out your fork locally, however perhaps there were some changes since you created the PR description?

reconfigure your cmake, cmake -B build ...

llama_vision_context * ctx = new llama_vision_context;
ctx->model = &model->vit;

// TODO: this looks ugly, mostly copied from llama.cpp, refactor it in the future
@ngxson (Collaborator, Author):

@ggerganov So I finally got a first version of the vision API with a dedicated context:

common_init_result llama_init = common_init_from_params(params);
const llama_model * model = llama_init.model.get();
llama_vision_context_params vparams = llama_vision_context_default_params();
llama_vision_context * vctx = llama_vision_init_from_model(model, vparams);
// then, use vctx

I haven't yet had time to look deeper into your refactoring in #11213, but do you think the API that I introduced here can benefit from it?

For context, what I'm currently missing here is the code to set up the various backends. Ideally, that should be abstracted out into a "super" class, and llama_vision_init_from_model should only be responsible for setting up the output tensor, preprocessor, etc. So I just want to ask if you have the same idea.

Member:

Few notes:

  • llama_vision_context should not exist in the public API
  • llama_vision_init_from_model is the same as llama_init_from_model, but it will internally create a derived llama_vision_context : llama_context based on the model parameters. It can of course create different contexts based both on the model parameters and the context parameters (e.g. different KV cache types, decoder vs encoder context, etc.).
  • Any vision-related context parameters should be appended to the existing struct llama_context_params. No need for a new struct llama_vision_context_params.
  • The base llama_context would be responsible for setting up the backend logic and distributing the model tensors. This should be the same for all models.

I am wondering if we should first reimplement llama_batch to become private so we can make more non-breaking changes to it. Ideally, the prompt processing in the new vision example should be possible to do with a single llama_decode on a multi-modal batch created something like this:

llama_batch_add_text      (batch_dec, tokens0);
llama_batch_add_embd_image(batch_dec, embd_enc_v);
llama_batch_add_text      (batch_dec, tokens1);

llama_decode(ctx, batch_dec);

@ngxson (Collaborator, Author):

Ok thanks for the explanation, that seems good to me, I'll update this PR accordingly.

I am wondering if we should first reimplement llama_batch to become private so we can make more non-breaking changes to it.

Yes, it would be very nice to have this. Indeed, this will also resolve #10381, which is currently blocked. Not sure if you're going to take this task, or do you prefer letting me make a draft PR?

@ggerganov (Member) commented Feb 8, 2025:

Btw, some of the suggestions are quite dependent on the changes in #11213. I'm a bit stuck unfortunately and haven't made progress lately. Depending on how things go this week, I will either finish the refactor or just scrap the changes and let other people attempt to work on this. I think the ideas there are sound, but it's not very easy for me to implement it.

Yes it would be very nice to have this. Indeed, this will also resolve #10381 which is currently blocked. Not sure if you gonna take this task or you prefer letting me make a draft PR?

Before #11213 is resolved, I don't plan to make changes to this part. So feel free to give it a shot.

@AIWintermuteAI commented Feb 6, 2025

@ngxson
Thanks for the fast reply!
Actually, I just forgot to git switch xD working late evening. It compiles fine now!

I'm interested in running SmolVLM in particular, will try digging into #9687 to see how I can convert the model over the weekend.

@ngxson (Collaborator, Author) commented Feb 6, 2025

I'm interested in running SmolVLM in particular

SmolVLM 500M can already be run via the current PR; you should base your work on this one, not the other.

@agNihit928 commented

First of all, great work!
Just wanted to know: how did you convert SmolVLM to GGUF?
Because when I tried, I got this error:

ubuntu@deepstream-7-base:~/llama-vision/llama.cpp$ python3 convert_hf_to_gguf.py ../SmolVLM-500M-Instruct/
INFO:hf-to-gguf:Loading model: SmolVLM-500M-Instruct
INFO:gguf.gguf_writer:gguf: This GGUF file is for Little Endian only
INFO:hf-to-gguf:Exporting model...
INFO:hf-to-gguf:gguf: loading model part 'model.safetensors'
INFO:hf-to-gguf:output.weight,               torch.bfloat16 --> F16, shape = {960, 49280}
INFO:hf-to-gguf:v.mmproj.fc.weight,          torch.bfloat16 --> F16, shape = {12288, 960}
Traceback (most recent call last):
  File "/home/ubuntu/llama-vision/llama.cpp/convert_hf_to_gguf.py", line 5421, in <module>
    main()
  File "/home/ubuntu/llama-vision/llama.cpp/convert_hf_to_gguf.py", line 5415, in main
    model_instance.write()
  File "/home/ubuntu/llama-vision/llama.cpp/convert_hf_to_gguf.py", line 479, in write
    self.prepare_tensors()
  File "/home/ubuntu/llama-vision/llama.cpp/convert_hf_to_gguf.py", line 1843, in prepare_tensors
    super().prepare_tensors()
  File "/home/ubuntu/llama-vision/llama.cpp/convert_hf_to_gguf.py", line 338, in prepare_tensors
    for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
  File "/home/ubuntu/llama-vision/llama.cpp/convert_hf_to_gguf.py", line 1811, in modify_tensors
    return [(self.map_tensor_name(name), data_torch)]
  File "/home/ubuntu/llama-vision/llama.cpp/convert_hf_to_gguf.py", line 238, in map_tensor_name
    raise ValueError(f"Can not map tensor {name!r}")
ValueError: Can not map tensor 'model.text_model.embed_tokens.weight'

@liyimeng commented

Looking forward to seeing this get merged; a couple of other PRs seem to depend on it.

@ngxson (Collaborator, Author) commented Feb 15, 2025

@agNihit928 I think something got buggy when I rebased to the latest master; you can maybe go back to c3a654c to see if it works.

@agNihit928 commented

Sure @ngxson
Will check it out
Thanks

Labels: examples, python (python script changes), server

Successfully merging this pull request may close these issues: server: Bring back multimodal support

8 participants