
Optimized DeepSeek V2/V3 implementation (MLA) #11446

Draft
wants to merge 10 commits into base: master

Conversation

fairydreaming
Collaborator

@fairydreaming commented Jan 27, 2025

This PR introduces various optimizations for the DeepSeek V2/V3 implementation.

Note that you need to reconvert the model to use this implementation.

Performance compared to the previous "naive" implementation:

[Performance plots: deepseek-mla, deepseek-lite-mla-pp, deepseek-r1-mla, deepseek-mla-pp]

CUDA performance is worse for short context lengths, but the curve is flatter:

[Performance plots: deepseek-lite-mla, deepseek-lite-cuda-mla-pp]

TODO:

  • remove unused kv_b tensor from the model
  • maybe add support for old model files (compute k_b and v_b during inference with reduced performance)
  • wait for completion of: llama : refactor llama_kv_cache, llama_context and llm_build_context #11213
  • implement MLA KV cache
  • address regressions in prompt processing performance (different permutations of tensors?) - I don't think it's possible, as this implementation is more compute-intensive compared to the regular attention implementation

@fairydreaming marked this pull request as draft January 28, 2025 11:23
@wronkiew

@fairydreaming do you have a converted model available or instructions for replicating your setup? I would like to run some benchmarks on these changes.

@fairydreaming
Collaborator Author

@fairydreaming do you have a converted model available or instructions for replicating your setup? I would like to run some benchmarks on these changes.

@wronkiew What model would you like to test?

@wronkiew

@fairydreaming do you have a converted model available or instructions for replicating your setup? I would like to run some benchmarks on these changes.

@wronkiew What model would you like to test?

V3/R1, Q4_K_S.

@fairydreaming
Collaborator Author

@fairydreaming do you have a converted model available or instructions for replicating your setup? I would like to run some benchmarks on these changes.

@wronkiew What model would you like to test?

V3/R1, Q4_K_S.

@wronkiew I don't have the model uploaded (my upload bandwidth is too low). You have to download the original model, convert it to bf16, convert that to GGUF, and quantize it yourself (or download one that is already converted to bf16, which will save you one step).

@fairydreaming
Collaborator Author

I spent some time investigating this hint from the DeepSeek V2 paper:

Fortunately, due to the associative law of matrix multiplication, we can absorb $𝑊^{𝑈𝐾}$ into $𝑊^{𝑈𝑄}$ , and $𝑊^{𝑈𝑉}$ into $𝑊^𝑂$

At first glance it looks reasonable: each absorbed matrix allows replacing two matrix multiplications with a single one, thus reducing the number of operations.

However, when we look at the dimensions of these matrices, this stops being reasonable. For example, in DeepSeek V2 Lite:

  • $𝑊^{𝑈𝑄}$ tensor has shape [2048, 2048], that is [16, 2048, 128] after reshaping to 3d and permutation
  • $𝑊^{𝑈𝐾}$ tensor has shape [128, 8192], that is [16, 512, 128] after reshaping to 3d and permutation
  • combined "absorbed" tensor has shape [16, 512, 2048]

So (ignoring the head dimension) this replaces two multiplications, with a [2048, 128] matrix and a [512, 128] matrix, with a single multiplication with a [512, 2048] matrix. The combined matrix has over 3x the elements of the two individual matrices combined, so it will take more memory and it will actually be slower to multiply compared to the two multiplications with smaller matrices.
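As a quick sanity check of that claim, counting elements with the per-head shapes listed above (plain arithmetic, not code from the PR):

#include <cstdio>

int main() {
    const long q_dim    = 2048; // per-head W^UQ is [2048, 128]
    const long head_dim = 128;
    const long kv_rank  = 512;  // per-head W^UK is [512, 128]

    const long separate = q_dim * head_dim + kv_rank * head_dim; // two smaller matrices
    const long absorbed = kv_rank * q_dim;                       // one combined matrix

    printf("separate: %ld elements, absorbed: %ld elements, ratio %.2f\n",
           separate, absorbed, (double) absorbed / separate);
    // prints: separate: 327680 elements, absorbed: 1048576 elements, ratio 3.20
}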

With $𝑊^{𝑈𝑉}$ and $𝑊^𝑂$ it's the same story:

  • $𝑊^{𝑈𝑉}$ tensor has shape [2048, 512], that is [16, 512, 128] after reshaping to 3d and permutation
  • $𝑊^𝑂$ tensor has shape [2048, 2048], that is [16, 2048, 128] after reshaping to 3d and permutation
  • combined "absorbed" tensor has shape [16, 512, 2048]

I also found this blog post: https://github.com/xjdr-alt/mla_blog_translation where they mention:

Compared to performing projection with these particularly large low-rank matrices, it is obviously more advantageous to multiply them successively according to the low-rank decomposition form. Therefore, we believe that this optimization step is not very necessary.

So it looks like a dead end; it won't give us any speed gains.

@divine-taco

I ran into an issue with DeepSeek-R1-UD-Q2_K_XL from unsloth/DeepSeek-R1-GGUF

llama_model_load: error loading model: missing tensor 'blk.0.attn_k_b.weight'
llama_model_load_from_file_impl: failed to load model

@fairydreaming
Collaborator Author

fairydreaming commented Jan 31, 2025

I ran into an issue with DeepSeek-R1-UD-Q2_K_XL from unsloth/DeepSeek-R1-GGUF

llama_model_load: error loading model: missing tensor 'blk.0.attn_k_b.weight'
llama_model_load_from_file_impl: failed to load model

As I wrote in the PR:

Note that you need to reconvert the model to use this implementation.

Existing GGUFs won't work, you have to convert and quantize one with the code from this PR.

@danielhanchen
Contributor

Ohh hmm should I re-quantize the ones in https://huggingface.co/unsloth/DeepSeek-R1-GGUF?

@fairydreaming
Collaborator Author

Ohh hmm should I re-quantize the ones in https://huggingface.co/unsloth/DeepSeek-R1-GGUF?

I think it's best to wait a bit until this is stable and merged; it's possible that there will be further changes that would cause them to stop working, and you'd have to repeat the conversion again.

@fairydreaming
Collaborator Author

I updated the token generation performance plots in the PR post and also added some new ones showing prompt processing performance. The optimized implementation generally performs WORSE in prompt processing - DeepSeek R1 671B Q4_K_S running on a CPU performs only a little worse (~10% with a 4k prompt), but DeepSeek V2 Lite Q8_0 running on an RTX 4090 performs MUCH WORSE (~30% with a 16k prompt), and in both cases the gap widens as the prompt length increases. So it's not all sunshine and rainbows.

Considering all these performance regressions, I think the best course of action would be to put the optimized implementation into a separate model architecture (LLM_ARCH_DEEPSEEK2_MLA or something like that). This would prevent issues with existing GGUFs - they would keep working with the existing architecture. I guess in this case the convert script would have to allow selecting the target model architecture with some option, but that shouldn't be difficult to add. @ggerganov what do you think?

Comment on lines +6406 to +6409
// whether to use n_tokens as the matrix dimension during multiplication or n_head
// n_tokens is higher during prompt processing, this allows to optimize for this case
bool pp_opt = n_tokens > n_head;

Owner

I'm not really sure this is the right approach. Haven't followed through the logic yet, but it seems strange to involve so many permutes and conts.

I would first look into improving the FA kernels to support DeepSeek head sizes.

Collaborator Author

I'm not really sure this is the right approach. Haven't followed through the logic yet, but it seems strange to involve so many permutes and conts.

Hmm? I'm quite sure there's only one ggml_cont() call (excluding the ones for CUDA compatibility that already existed in the previous implementation).

As for the permutes, the idea is to multiply by a matrix whose second dimension equals the number of heads instead of the number of tokens (which is 1 during single-sequence token generation); that increased the performance on a CPU a bit.

So during prompt processing we have 2 permutes and 1 cont. During token generation we have 5 permutes (yeah, that may be a lot) and 0 conts.
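To make the idea concrete, here is a rough sketch of the kind of branching pp_opt implies (illustrative only, not the PR's actual code; the tensor names and the assumed [n_embd_head, n_head, n_tokens] layout are mine):

// With ggml_mul_mat the second dimension of the right-hand operand becomes the
// number of result rows, so we want the larger of n_tokens / n_head there.
const bool pp_opt = n_tokens > n_head;

// hypothetical query activation laid out as [n_embd_head, n_head, n_tokens]
struct ggml_tensor * q = ggml_reshape_3d(ctx, q_cur, n_embd_head, n_head, n_tokens);
if (pp_opt) {
    // prompt processing: batch over heads, with n_tokens (the large dim) as the row count
    q = ggml_permute(ctx, q, 0, 2, 1, 3); // -> [n_embd_head, n_tokens, n_head]
}
// token generation (n_tokens == 1): leave q as-is, so the matmul sees n_head rows
// instead of a single row - this is what gave the small CPU speedup mentioned above.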

Owner

Thanks for the correction - I did imagine the extra conts when I saw the permutes.

@ggerganov
Owner

Considering all these performance regressions I think the best course of action would be to put the optimized implementation into separate model architecture (LLM_ARCH_DEEPSEEK2_MLA or something like this). This will prevent issues with existing GGUFs - they would keep working with existing architecture. I guess in this case the convert script would have to allow selection of the target model architecture with some option, but that shouldn't be difficult to add. @ggerganov what do you think?

While this is possible to do, I think it has a lot of cons. It will make it difficult for everyone to know which model variation on which hardware to use for better performance. Ideally, we want to have a single implementation that is optimal in all use cases, which can be deprecated at some point for a better alternative. But having 2 alternatives neither of which is optimal is not great.

Also, I'm not sure how this implementation fits with multiple parallel sequences and it introduces extra KV cache logic, specific to this type of arch.

I know there is a lot of interest in the DeepSeek arch right now and such optimizations are really important for people. But I think that we have to keep this work in a PR for a while. It is much more important to fix the software architecture in libllama after which such changes should become easier.

@fairydreaming
Collaborator Author

Considering all these performance regressions I think the best course of action would be to put the optimized implementation into separate model architecture (LLM_ARCH_DEEPSEEK2_MLA or something like this). This will prevent issues with existing GGUFs - they would keep working with existing architecture. I guess in this case the convert script would have to allow selection of the target model architecture with some option, but that shouldn't be difficult to add. @ggerganov what do you think?

While this is possible to do, I think it has a lot of cons. It will make it difficult for everyone to know which model variation on which hardware to use for better performance. Ideally, we want to have a single implementation that is optimal in all use cases, which can be deprecated at some point for a better alternative. But having 2 alternatives neither of which is optimal is not great.

That may not be possible - IMHO an MLA attention implementation that caches "compressed" latent KV representations introduces unavoidable computational overhead due to the need to "decompress" these representations in order to calculate attention scores and the attention output. So a "naive" attention implementation that caches full K/V vectors will always use less compute but more memory bandwidth, while caching latent representations uses more compute but less memory bandwidth. So there can't be a single implementation that is optimal in all use cases. I'd be happy to be proven wrong about this, though.
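A back-of-envelope illustration of that trade-off, using approximate DeepSeek V2 Lite attention dimensions (the rope head size of 64 is an assumption; the other numbers follow from the shapes discussed earlier in this thread):

#include <cstdio>

int main() {
    const int n_head       = 16;
    const int qk_nope_dim  = 128; // per-head "nope" key part
    const int qk_rope_dim  = 64;  // per-head rope part (assumed)
    const int v_dim        = 128; // per-head value size
    const int kv_lora_rank = 512; // compressed latent size

    // "naive": cache full per-head K and V vectors
    const int naive_per_tok = n_head * (qk_nope_dim + qk_rope_dim) + n_head * v_dim;
    // MLA: cache one shared latent plus the rope part
    const int mla_per_tok   = kv_lora_rank + qk_rope_dim;

    printf("naive: %d elems/token/layer, MLA: %d elems/token/layer (%.1fx less to stream),\n"
           "at the cost of re-expanding the latent through k_b/v_b on every step.\n",
           naive_per_tok, mla_per_tok, (double) naive_per_tok / mla_per_tok);
}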

Also, I'm not sure how this implementation fits with multiple parallel sequences and it introduces extra KV cache logic, specific to this type of arch.

I think there shouldn't be any problems with this, as there is a straightforward direct mapping between the cached representations and full K/V vectors.

I know there is a lot of interest in the DeepSeek arch right now and such optimizations are really important for people. But I think that we have to keep this work in a PR for a while. It is much more important to fix the software architecture in libllama after which such changes should become easier.

That's fine with me. I'm taking a break from this anyway, got bored with tensor shuffling looking for 0.1 t/s more performance. 😉

@saood06

saood06 commented Feb 2, 2025

@fairydreaming
Is there any reason this should cause issues with RPC?
Encountered:

ggml_cuda_compute_forward: cannot compute kqv-31: src0->ne[3] = 1, src1->ne[3] = 2 - fallback to CPU
evaluate_and_capture_cuda_graph: op not supported kqv-31 (MUL_MAT)
[...]\llama.cpp\ggml\src\ggml-cuda\ggml-cuda.cu:2660: GGML_ASSERT(ok) failed

I don't have a quant on hand that I can test without this branch. This branch does give me a nice performance boost for TG at longer contexts, but RPC to CUDA does not work.

@slaren
Collaborator

slaren commented Feb 7, 2025

@slaren @ggerganov Is there a way to profile GGML in the same way to see if there could be any transpose or reshape ops hurting the performance?

View ops are "free", they have no runtime cost. However, most operations require at least the first dimension to be contiguous, so using a view often requires adding a ggml_cont to make the tensor contiguous, and that's where the expensive copy happens. But it is not hidden from the user, the cost will be clearly visible in the ggml_cont.

@slaren
Collaborator

slaren commented Feb 7, 2025

I can't use bfloat16 as there is no backend support in llama.cpp for CUDA (unless this has changed recently?)

CUDA should support bf16 after #11093

@jukofyork
Contributor

@slaren @ggerganov Is there a way to profile GGML in the same way to see if there could be any transpose or reshape ops hurting the performance?

View ops are "free", they have no runtime cost. However, most operations require at least the first dimension to be contiguous, so using a view often requires adding a ggml_cont to make the tensor contiguous, and that's where the expensive copy happens. But it is not hidden from the user, the cost will be clearly visible in the ggml_cont.

Thanks - I'm gonna try and have a look at this over the weekend.

I can't use bfloat16 as there is no backend support in llama.cpp for CUDA (unless this has changed recently?)

CUDA should support bf16 after #11093

Oh thanks - I might as well see if I can track down which operation is causing the float16 overflow now that I've started.

(I'm rechecking all the tensors for any values being over sqrt(65504) now too)

@jukofyork
Contributor

jukofyork commented Feb 7, 2025

So:

  • The ggml_mul_mat_set_prec(XXX, GGML_PREC_F32) didn't fix it.
  • There are no bf16 tensors in the original .safetensors files with any values >= sqrt(65504).

Trying bf16 on CUDA now, and if that works I'll try to whittle it down to exactly which tensor it is...

@jukofyork
Contributor

Ignore all I posted and just deleted - I had the BF16 and F16 switched for my tests 🤦

I have it down to one of the "_b" tensors so far.

@jukofyork
Contributor

jukofyork commented Feb 8, 2025

It is the attn_k_b.weight tensors that can't be float16.

@saood06

saood06 commented Feb 8, 2025

It is the attn_kv_b.weight tensors that can't be float16.

I'm confused, I thought that tensor doesn't get used? The PR description says "remove unused kv_b tensor from the model".

Do you mind posting the quantization log for the smallest of your quants that shows the differences you described (longer and more structured responses)? I'm planning on making a new quant anyway.

@jukofyork
Contributor

jukofyork commented Feb 8, 2025

It is the attn_kv_b.weight tensors that can't be float16.

I'm confused, I thought that tensor doesn't get used? The PR description says "remove unused kv_b tensor from the model".

Shit sorry, yes it's the attn_k_b.weight.

It's been a long day of testing these, but it is 100% this tensor type - with some more time it may be possible to identify if it is restricted to only 1 or a few layers, but I can't do this for now.

Do you mind posting the quantization log for the smallest of your quants that has the quantitative differences of longer and more structured responses. I'm planning on making a new quant anyway.

It's really obvious, as all you have to do is set the type like this in src/llama-quant.cpp:

static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
    const std::string name = ggml_get_name(tensor);

    // TODO: avoid hardcoded tensor names - use the TN_* constants
    const llm_arch arch = qs.model.arch;
    const auto       tn = LLM_TN(arch);

    auto use_more_bits = [](int i_layer, int n_layers) -> bool {
        return i_layer < n_layers/8 || i_layer >= 7*n_layers/8 || (i_layer - n_layers/8)%3 == 2;
    };
    const int n_expert = std::max(1, (int)qs.model.hparams.n_expert);
    auto layer_info = [n_expert] (int i_layer, int n_layer, const char * name) {
        if (n_expert > 1) {
            // Believe it or not, "experts" in the FFN of Mixtral-8x7B are not consecutive, but occasionally randomly
            // sprinkled in the model. Hence, simply dividing i_ffn_down by n_expert does not work
            // for getting the current layer as I initially thought, and we need to resort to parsing the
            // tensor name.
            if (sscanf(name, "blk.%d.", &i_layer) != 1) {
                throw std::runtime_error(format("Failed to determine layer for tensor %s", name));
            }
            if (i_layer < 0 || i_layer >= n_layer) {
                throw std::runtime_error(format("Bad layer %d for tensor %s. Must be in [0, %d)", i_layer, name, n_layer));
            }
        }
        return std::make_pair(i_layer, n_layer);
    };

    // ###
    if (name.find("attn_k_b.weight") != std::string::npos) {
        new_type = GGML_TYPE_F16;
    }
    else
    // ###

    // for arches that share the same tensor between the token embeddings and the output, we quantize the token embeddings
    // with the quantization of the output tensor
    if (name == tn(LLM_TENSOR_OUTPUT, "weight") || (!qs.has_output && name == tn(LLM_TENSOR_TOKEN_EMBD, "weight"))) {
    // ... rest of the function unchanged ...

and it will just repeat the same word over and over whatever else you choose for the quant type.

The drop in quality using Q8_0 is more subtle.

The solution for now is to hack in

    // ###
    if (name.find("attn_k_b.weight") != std::string::npos) {
        new_type = GGML_TYPE_BF16;
    }
    else
    // ###

or

    // ###
    if (name.find("attn_k_b.weight") != std::string::npos) {
        new_type = GGML_TYPE_F32;
    }
    else
    // ###

But to be safe I'm just straight up using this now:

    // ### JUK ###
    if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q6_K) {
        if (name.find("_exps") != std::string::npos) {
            if (name.find("ffn_down") != std::string::npos) {
                new_type = GGML_TYPE_Q6_K;
            }
            else {
                if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) {
                    new_type = GGML_TYPE_Q4_K;
                }
                else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) {
                    new_type = GGML_TYPE_Q5_K;
                }
                else {
                    new_type = GGML_TYPE_Q6_K;
                }
            }
        }
        else {
            new_type = GGML_TYPE_BF16;
        }
    }
    else
    // ### JUK ###

as I'm offloading all the _exps tensors to RAM and keeping everything else in VRAM.

I've also hacked this PR to use this inside src/llama-kv-cache.cpp

        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*1);//kv_size);
        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*1);//kv_size);

to hopefully avoid creating the original KV cache, but not tested it yet.

@saood06

saood06 commented Feb 8, 2025

It is the attn_kv_b.weight tensors that can't be float16.

I'm confused, I thought that tensor doesn't get used? The PR description says "remove unused kv_b tensor from the model".

Shit sorry, yes it's the attn_k_b.weight.

It's been a long day of testing these, but it is 100% this tensor type - with some more time it may be possible to identify if it is restricted to only 1 or a few layers, but I can't do this for now.

Thanks for spending the time to test it out. At F16 those tensors are ~1GB, and given that the smallest functional quants of this I've seen are ~130GB, narrowing it down further (even if possible) is not that important.

I've also hacked this PR to use this inside src/llama-kv-cache.cpp

        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*1);//kv_size);
        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*1);//kv_size);

to hopefully avoid creating the original KV cache, but not tested it yet.

Let me know if that works. Right now it doesn't matter for me, as I'm using fairydreaming's PR to mmap the KV cache, but like you I plan to offload everything besides the non-shared experts to VRAM, and the large, unused KV allocation in VRAM would be nice to avoid.

@jukofyork
Contributor

@saood06

I can run the perplexity test for you now, but I forgot what you said - default command line using wiki.test.raw?

I will also try to run it for you tomorrow with the same Q5_K_XL quant as above but without the MLA PR merged, to check.

@saood06

saood06 commented Feb 8, 2025

@saood06

I can run the perplexity test for you now, but forgot what you said - default command line using wiki.test.raw?

Yep, default command line using wiki.test.raw

I will also try to run it for you with the same Q5_K_XL quant as above, but without the MLA PR merged tomorrow to check.

Would be interesting to see if they diverge.

@jukofyork
Contributor

to hopefully avoid creating the original KV cache, but not tested it yet.

Let me know if that works, right now for me it doesn't matter as I'm using fairydreaming's PR to mmap the KV, but I plan to like you offload everything besides the non-shared experts to VRAM, and the large and unused KV allocation in VRAM would be nice to avoid.

I can confirm this works, eg for 32k context:

llama_init_from_model: n_ctx_per_seq (32768) < n_ctx_train (163840) -- the full capacity of the model will not be utilized
llama_kv_cache_init: kv_size = 32768, offload = 1, type_k = 'f16', type_v = 'f16', n_layer = 61, can_shift = 0
llama_kv_cache_init:      CUDA0 KV buffer size =  2110.42 MiB
llama_kv_cache_init:      CUDA1 KV buffer size =  2042.34 MiB
llama_init_from_model: KV self size  =    4.77 MiB, K (f16):    2.86 MiB, V (f16):    1.91 MiB
llama_init_from_model: KV self size  = 2196.00 MiB, K^R (f16):  244.00 MiB, c^KV (f16): 1952.00 MiB

I'm only using 35/48 GB per GPU even with all the GPU tensors as bfloat16, so I should probably be able to use the full context of the model (if I can face waiting 2 days for an answer, that is!).

@saood06

saood06 commented Feb 8, 2025

to hopefully avoid creating the original KV cache, but not tested it yet.

Let me know if that works, right now for me it doesn't matter as I'm using fairydreaming's PR to mmap the KV, but I plan to like you offload everything besides the non-shared experts to VRAM, and the large and unused KV allocation in VRAM would be nice to avoid.

I can confirm this works

Nice.

I'm only using 35/48 GB per GPU even with this as all GPU tensors as bfloat16, so should probably be able to use the full context of the model (if I can face waiting 2 days for an answer that is!).

The furthest I've pushed was 32K context used, with a total of 64K context available.

@saood06

saood06 commented Feb 9, 2025

@jukofyork

I can run the perplexity test for you now.

Any chance you have run the tests? Here are my results.

Quant [1] [2] [3] [4] [5] [6] [7] [8] [9] [10] [11] [12] [13] [14] [15] [16] [17] [18] [19] [20] [21] [22] [23] [24]
(V1) 2.5944 3.3242 2.4001 1.9949 1.8067 1.6666 1.5704 1.5055 1.4559 1.4154 1.3999 1.4404 1.4500 1.5786 1.7101 1.7729 1.9347 2.0639 2.0260 2.0157 2.1257 2.0994 2.0710 2.0844
(V2) 2.5474 3.3247 2.4001 2.0029 1.8181 1.6716 1.5734 1.5084 1.4592 1.4194 1.4035 1.4376 1.4476 1.5734 1.7047 1.7654 1.9276 2.0560 2.0189 2.0066 2.1138 2.0865 2.0588 2.0738
(V3) 2.5551 3.3239 2.3980 1.9980 1.8057 1.6631 1.5676 1.5029 1.4525 1.4122 1.3963 1.4421 1.4516 1.5784 1.7089 1.7692 1.9317 2.0597 2.0222 2.0123 2.1185 2.0922 2.0642 2.0772

They are all roughly the same size, although each version is bigger than the last. I won't keep that up though; if I make a V4, it will be sized between V1 and V2.

@jukofyork
Contributor

jukofyork commented Feb 10, 2025

@jukofyork

I can run the perplexity test for you now.

Any chance you have run the tests?

Sorry, I had a drive fail in a raid array where all this is stored, and it spent ages rebuilding.

I will try and run some benchmarks this week.

I also have something that may speed up this PR, but I'm not sure yet, as it has to do a lot of SVDs on huge matrices and is taking a while.

@jukofyork
Contributor

@fairydreaming Just saw your post on LocalLLaMA:

So the prompt processing rate is massively improved (3.38 times as fast as llama.cpp, thanks to the RTX 4090 I guess), while the token generation rate increased by 64%.

I'm not sure why llama.cpp prompt processing is so much worse, but the generation part isn't as impressive as it first seems:

(selectively using 6 experts, V0.3 only)

https://github.com/kvcache-ai/ktransformers/blob/dbaecd0ca5e99da4e2cb4d86b0d6ef3fa2b1eba3/doc/en/DeepseekR1_V3_tutorial.md

So that alone would account for a 33% speedup if llama.cpp did the same.

@jukofyork
Contributor

What tools are available for profiling GGML operations?

Is there some way we can see exactly what is happening during prompt processing? It seems oddly slow to me - I can see it pulling stuff in to process on the GPU, but it seems to gain almost nothing as you increase the batch size for a large prompt.

I get the feeling something weird is going on regarding the experts, but I didn't manage to get any further than tracing the GGML code to the mul_mat_id stuff, and nothing was obvious.

@fairydreaming
Collaborator Author

@fairydreaming Just saw your post on LocalLLaMA:

So the prompt processing rate is massively improved (3.38 times as fast as llama.cpp, thanks to the RTX 4090 I guess), while the token generation rate increased by 64%.

I'm not sure why llama.cpp prompt processing is so much worse, but the generation part isn't as impressive as it first seems:

(selectively using 6 experts, V0.3 only)

https://github.com/kvcache-ai/ktransformers/blob/dbaecd0ca5e99da4e2cb4d86b0d6ef3fa2b1eba3/doc/en/DeepseekR1_V3_tutorial.md

So that alone would account for a 33% speedup if llama.cpp did the same.

AFAIK in ktransformers the whole attention computation is offloaded to the GPU; that's a major source of the performance boost.

@jukofyork
Contributor

jukofyork commented Feb 10, 2025

@fairydreaming Just saw your post on LocalLLaMA:

So the prompt processing rate is massively improved (3.38 times as fast as llama.cpp, thanks to the RTX 4090 I guess), while the token generation rate increased by 64%.

I'm not sure why llama.cpp prompt processing is so much worse, but the generation part isn't as impressive as it first seems:

(selectively using 6 experts, V0.3 only)

https://github.com/kvcache-ai/ktransformers/blob/dbaecd0ca5e99da4e2cb4d86b0d6ef3fa2b1eba3/doc/en/DeepseekR1_V3_tutorial.md
So that alone would account for a 33% speedup if llama.cpp did the same.

AFAIK in ktransformers the whole attention computation is offloaded to GPU, that's a major source of performance boost.

I'm already doing this using this PR:

#11397

I have only the 57 sets of non-shared experts' tensor triplets on the CPU and the rest lives on the GPU.

I've tried lots of other configs too:

  • Sending it though 6 x 48GB GPUs using RPC (all in VRAM), but the prompt processing speed just never seems to improve.
  • Using a single GPU only.
  • Using 2 GPUs linked with an NVLink bridge and split-mode row on the non-MLA branch (the 3D view blocks this branch).
  • I've tried using bfloat16 and float16 (which currently doesn't generate coherent output, but isn't any faster) - in case there is some weird transpose causing the quantised weights to do 1 multiply per weight instead of 1 per block.

and so on...

Batch sizes seem to make no difference either: it clearly pulls something from RAM to the GPU during prompt processing, runs the GPU at about 75% load (and 1 single thread on the CPU), but a ubatch of 64 is no faster than a ubatch of 1024.

I strongly suspect something funky is going on, and that instead of pulling in the 256 experts to process a batch on them, it's pulling each one in/out of RAM or something, but sadly I have no idea how to profile GGML to really get to the bottom of what is happening :/

@fairydreaming
Collaborator Author

@jukofyork For starters you could make a debug build of llama.cpp and set the GGML_SCHED_DEBUG=2 environment variable; this will print the whole model graph to the console. You can see there whether a given tensor is on a GPU or on a CPU. Maybe this will give you some answers.

@jukofyork
Contributor

@jukofyork For starters you could make a debug build of llama.cpp and set GGML_SCHED_DEBUG=2 environment variable, this will print the whole model graph to console. You can see there if a given tensor is on a GPU or on a CPU. Maybe this will give you some answers.

Thanks - I'll try and look into this!

@jukofyork
Contributor

@fairydreaming

It's well worth playing with the #11397 PR if you have a GPU as well now: I got quite a nice generation speed increase by flushing the page cache, using numactl --interleave ... with --numa distribute, and only having the massive experts in system RAM.

@jukofyork
Contributor

jukofyork commented Feb 11, 2025

So I've spent all day on this:

1. It's not in ggml_cuda_mul_mat_id.

The if (ne12 == 1) case is really clear and if you force this for batches it barely makes any difference to the prompt processing speed.

2. Forcing const int min_batch_size = 1000000 in ggml_backend_cuda_device_offload_op actually makes the prompt processing about 1.5-2x faster.

From watching top and nvtop with const int min_batch_size = 1000000 you can actually see the 4 batches of 512 tokens go back and forth between the 2 GPUs (due to me using the other PR to place the attention tensors on the GPU and the experts on the CPU).

It looks like the cause then is simply that pulling 450GB through the PCIe 3.0 bus at a max of 32GB/s (for 16x) takes around 14 seconds alone (at best!), and the gain of using ggml_cuda_mul_mat_id is offset by this cost unless we set the ubatch size way higher than 512 and have a very large prompt to go with it.

I think we should probably make the minimum batch size to offload an option then:

static bool ggml_backend_cuda_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
    const int min_batch_size = 32;

    return get_op_batch_size(op) >= min_batch_size;

    GGML_UNUSED(dev);
}

rather than have it fixed like this (@slaren this follows up the discussion we had: #11397 (comment)).
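For example, a minimal sketch of what such an option could look like, say via an environment variable (purely illustrative; the variable name GGML_CUDA_MIN_OFFLOAD_BATCH is made up, not an existing flag):

static bool ggml_backend_cuda_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
    // read the threshold once; fall back to the current hardcoded 32
    // (needs <cstdlib> for getenv/atoi)
    static const int min_batch_size = [] {
        const char * s = getenv("GGML_CUDA_MIN_OFFLOAD_BATCH"); // hypothetical override
        return s ? atoi(s) : 32;
    }();

    return get_op_batch_size(op) >= min_batch_size;

    GGML_UNUSED(dev);
}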


BTW: If anybody wants to use this PR with --split-mode row I found a way to allow it:

and it won't split these tensors.

... Before this: I actually spent 4 hours redoing the conversion to use 3D tensors for the 2 new "temp" tensors:

INFO:hf-to-gguf:blk.0.attn_k_b.weight,        torch.bfloat16 --> BF16, shape = {128, 512, 128}
INFO:hf-to-gguf:blk.0.attn_v_b.weight,        torch.bfloat16 --> BF16, shape = {512, 128, 128}

So I could remove the two ggml_view_3d calls in this PR which were stopping --split-mode row from working, only to find that --split-mode row doesn't work with 3D/4D tensors anyway 😀


I still have no idea how KTransformers are getting the prompt processing speeds they are reporting - the newer Intel CPUs with AMX must be amazingly good!!!

@jukofyork
Contributor

jukofyork commented Feb 11, 2025

I've just spoken to my brother and he actually has a single-CPU server with a Sapphire Rapids CPU in it running 8 channels of 48GB 6800 DDR5 (no GPU in it though). I'll hopefully see him sometime this week and will test llama.cpp on it and report back what the prompt processing speed is like - it will be nice to know if KTransformers is using some clever hand-crafted AMX kernel or if we can get something similar via an optimising compiler for llama.cpp.


@saood06 I've started running the tests:

numactl --interleave=all ./llama-perplexity --n-gpu-layers 99 --override-tensor exps=CPU --override-tensor attn_kv_b=CPU --numa distribute --threads 30 --model ./DeepSeek-R1-mla-Q5_K_XL.gguf --file ./wiki.test.raw

Where Q5_K_XL is BF16 / Q5_K / Q6_K using this hacked version of llama_tensor_get_type():

    if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q6_K) {
        if (name.find("_exps") != std::string::npos) {
            if (name.find("ffn_down") != std::string::npos) {
                new_type = GGML_TYPE_Q6_K;
            }
            else {
                if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) {
                    new_type = GGML_TYPE_Q4_K;
                }
                else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) {
                    new_type = GGML_TYPE_Q5_K;
                }
                else {
                    new_type = GGML_TYPE_Q6_K;
                }
            }
        }
        else if (name.find("attn_") != std::string::npos && name.find("_output") == std::string::npos) {
            new_type = GGML_TYPE_BF16;
        }
        else {
            new_type = GGML_TYPE_Q8_0;
        }
    }
    else

[1]2.5030,[2]3.2798,[3]2.3704,[4]1.9793,[5]1.7866,[6]1.6453,[7]1.5536,[8]1.4883,[9]1.4388,[10]1.3993,[11]1.3838,[12]1.4188,[13]1.4298,[14]1.5565,[15]1.6874,[16]1.7464

I will try BF16 / Q4_K / Q6_K, Q8_0 / Q5_K / Q6_K, etc later today if I get time.

@saood06

saood06 commented Feb 11, 2025

I still have no idea how KTransformers are getting the prompt processing speeds they are reporting - the newer Intel CPUs with AMX must be amazingly good!!!

Even without AMX, reported results reach ~80 t/s for PP on KTransformers with a single 4090. I want to try KTransformers but it requires a local GPU, and mine is on a different machine.

@jukofyork
Contributor

jukofyork commented Feb 11, 2025

If we set const int min_batch_size = 999999; in ggml_backend_cuda_device_offload_op, can we also use -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=... for OpenBLAS or MKL, or will it not allow this at the same time as -DLLAMA_CUDA=ON?

system_info: n_threads = 30 (n_threads_batch = 30) / 88 | CUDA : ARCHS = 860 | USE_GRAPHS = 1 | PEER_MAX_BATCH_SIZE = 9999999 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | LLAMAFILE = 1 | OPENMP = 1 | AARCH64_REPACK = 1 |

I can't see anything about BLAS here, but it looks like it made a separate backend maybe:

-- Found BLAS: /usr/lib/x86_64-linux-gnu/libopenblas.so  
-- BLAS found, Libraries: /usr/lib/x86_64-linux-gnu/libopenblas.so
-- Found PkgConfig: /usr/bin/pkg-config (found version "1.8.1") 
-- Checking for module 'openblas64'
--   Package 'openblas64', required by 'virtual:world', not found
-- Checking for module 'openblas'
--   Found openblas, version 0.3.21
-- BLAS found, Includes: /usr/include/x86_64-linux-gnu/openblas-pthread/
-- Including BLAS backend

@saood06

saood06 commented Feb 11, 2025

I've just spoke to my brother and he actually has a single-CPU server with a Sapphire Rapid CPU in it running 8 channels of 48GB 6800 DDR5 (no GPU in it though). I'll hopefully see him sometime this week and will test llama.cpp on it and report back what the prompt processing speed is like - it will be nice to know if KTransformers is using some clever hand-crafted MLX kernel or if we can get similar via optimising compiler for llama.cpp.

Also, just for your reference, in KTransformers "The GPU part, which uses Marlin, will always run in 4-bit mode", whereas you are running your GPU weights at far higher bpw.

@saood06 I've started running the tests:

[1]2.5030,[2]3.2798,[3]2.3704,[4]1.9793,[5]1.7866,[6]1.6453,[7]1.5536,[8]1.4883,[9]1.4388,[10]1.3993,[11]1.3838,[12]1.4188,[13]1.4298,[14]1.5565,[15]1.6874,[16]1.7464

I will try BF16 / Q4_K / Q6_K, Q8_0 / Q5_K / Q6_K, etc later today if I get time.

Thanks. Taking the sum of the first 16 PPL values for each of our quants:

Quant Sum
My V1 29.0860
My V2 29.0574
My V3 29.0255
Your Q5_K_XL 28.6671

Yours is definitely better, which is expected given mine is closer in size to a Q4_K-based quant than a Q5_K one.

@jukofyork
Contributor

jukofyork commented Feb 11, 2025

I've just spoke to my brother and he actually has a single-CPU server with a Sapphire Rapid CPU in it running 8 channels of 48GB 6800 DDR5 (no GPU in it though). I'll hopefully see him sometime this week and will test llama.cpp on it and report back what the prompt processing speed is like - it will be nice to know if KTransformers is using some clever hand-crafted MLX kernel or if we can get similar via optimising compiler for llama.cpp.

Also just for your reference in KTransformers "The GPU part, which uses Marlin, will always run in 4-bit mode", since you are running your GPU weights with far higher bpw.

@saood06 I've started running the tests:

[1]2.5030,[2]3.2798,[3]2.3704,[4]1.9793,[5]1.7866,[6]1.6453,[7]1.5536,[8]1.4883,[9]1.4388,[10]1.3993,[11]1.3838,[12]1.4188,[13]1.4298,[14]1.5565,[15]1.6874,[16]1.7464

I will try BF16 / Q4_K / Q6_K, Q8_0 / Q5_K / Q6_K, etc later today if I get time.

Thanks taking the sum of the 16 PPL for each of our quant.
Quant Sum
My V1 29.0860
My V2 29.0574
My V3 29.0255,
Your Q5_K_XL 28.6671

Yours is definitely better, which is expected given mine is closer in size to Q4_K based not Q5_K.

I think there might still be a problem with the overflowing attn_k_b.weight and fp16, as I'm pretty sure Q8_0 is slightly dumber than BF16, and the stock Q4_0 (i.e. with only the input embedding as Q6_K and everything else Q4_0) that I tried earlier was completely braindead.

I can't be arsed doing any more today, but it should be possible to balance whatever is overflowing into the gamma parameter of the layernorm which gets stored as float32 IIRC. I'll see if I can get to the bottom of that tomorrow.

@slaren
Collaborator

slaren commented Feb 11, 2025

If we set const int min_batch_size = 999999; in ggml_backend_cuda_device_offload_op, can we also use -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=... for OpenBLAS or MKL, or will it not allow this at the same time as -DLLAMA_CUDA=ON?

Yes, but the BLAS backend does not support mul_mat_id. If you want to be sure, set GGML_SCHED_DEBUG=2 and look for operations being scheduled to the BLAS backend.
