Optimized DeepSeek V2/V3 implementation (MLA) #11446
base: master
Conversation
…kv representations
…nsposing the cache during inference
@fairydreaming do you have a converted model available or instructions for replicating your setup? I would like to run some benchmarks on these changes.
@wronkiew What model would you like to test?
V3/R1, Q4_K_S.
@wronkiew I don't have the model uploaded (my upload bandwidth is too low), you have to download, convert to bf16, convert to gguf and quantize the original model by yourself (or download one that is already converted to bf16, this will save you one step).
I spent some time investigating this hint from the DeepSeek V2 paper:
At first glance it looks reasonable: each absorbed matrix allows replacing two matrix multiplications with a single multiplication, thus reducing the number of operations. However, when we take a look at the dimensions of these matrices, it stops being reasonable. For example, in DeepSeek V2 Lite:
So (let's ignore the head dimension) this allows replacing two multiplications (with a [2048, 128] matrix and a [512, 128] matrix) by a single multiplication with a [512, 2048] matrix. The combined matrix has over 3x the elements of the two individual matrices combined, so it will take more memory and it will actually be slower to multiply compared to two multiplications with the smaller matrices (a rough element count is sketched below). With
I also found this blog post: https://github.com/xjdr-alt/mla_blog_translation where they mention:
So it looks like a dead end; it won't give us any speed gains.
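To make the size comparison above concrete, here is a small stand-alone sketch of the element counts, assuming the DeepSeek V2 Lite dimensions quoted in the comment (hidden size 2048, latent "compressed" dimension 512, per-head dimension 128, head count ignored as above). It is illustrative only and not part of the PR:

```cpp
#include <cstdio>

// Rough element-count comparison for the "absorbed" matrix idea, using the
// DeepSeek V2 Lite dimensions mentioned above (illustrative only).
int main() {
    const long hidden       = 2048; // model hidden size
    const long kv_lora_rank = 512;  // latent ("compressed") kv dimension
    const long head_dim     = 128;  // per-head dimension

    const long separate = hidden * head_dim + kv_lora_rank * head_dim; // two smaller matrices
    const long absorbed = kv_lora_rank * hidden;                       // one combined matrix

    printf("separate: %ld elements\n", separate);              // 327680
    printf("absorbed: %ld elements\n", absorbed);              // 1048576
    printf("ratio   : %.2fx\n", (double) absorbed / separate); // ~3.2x more weight data
    return 0;
}
```

So the absorbed form reads and multiplies roughly 3x as many weight elements, which matches the conclusion above.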
I ran into an issue with DeepSeek-R1-UD-Q2_K_XL from unsloth/DeepSeek-R1-GGUF
As I wrote in the PR:
Existing GGUFs won't work; you have to convert and quantize one with the code from this PR.
Ohh hmm, should I re-quantize the ones in https://huggingface.co/unsloth/DeepSeek-R1-GGUF?
I think it's best to wait a bit until this is stable and merged; it's possible that there will be some changes that would cause them to stop working, and you'd have to repeat the conversion again.
I updated the token generation performance plots in the PR post and also added some new ones showing prompt processing performance. The optimized implementation generally performs WORSE in prompt processing - DeepSeek R1 671B Q4_K_S running on CPU performs only a little worse (~10% with a 4k prompt), but DeepSeek V2 Lite Q8_0 running on an RTX 4090 performs MUCH WORSE (~30% with a 16k prompt), and in both cases the gap widens as the prompt length increases. So it's not all sunshine and rainbows. Considering all these performance regressions, I think the best course of action would be to put the optimized implementation into a separate model architecture (
// whether to use n_tokens as the matrix dimension during multiplication or n_head
// n_tokens is higher during prompt processing, this allows to optimize for this case
bool pp_opt = n_tokens > n_head;
I'm not really sure this is the right approach. Haven't followed through the logic yet, but it seems strange to involve so many permutes and conts.
I would first look into improving the FA kernels to support DeepSeek head sizes.
I'm not really sure this is the right approach. Haven't followed through the logic yet, but it seems strange to involve so many permutes and conts.
Hmm? I'm quite sure there's only one ggml_cont() call (excluding the ones for CUDA compatibility that already existed in the previous implementation).
As for the permutes, the idea is to multiply by a matrix whose second dimension is the number of heads instead of the number of tokens (which is 1 during single-sequence token generation); that increased the performance on a CPU a bit.
So during prompt processing we have 2 permutes and 1 cont. During token generation we have 5 permutes (yeah, that may be a lot) and 0 conts.
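For illustration, here is a minimal, self-contained ggml sketch of the shape trick being described. It is not the PR's actual graph: the sizes (kv_lora_rank 512, 16 heads, 1024 cached tokens) are made up, and the tensors are created directly in the convenient layout. The point is that during single-token generation the latent KV cache is shared by all heads, so one ggml_mul_mat whose second dimension is n_head (rather than n_tokens == 1) produces the scores for every head at once:

```cpp
#include "ggml.h"
#include <cstdio>

int main() {
    const int kv_lora_rank = 512;  // latent ("compressed") kv dimension
    const int n_head       = 16;   // number of attention heads
    const int n_kv         = 1024; // number of cached tokens

    struct ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ true };
    struct ggml_context * ctx = ggml_init(params);

    // latent KV cache shared by all heads: [kv_lora_rank, n_kv]
    struct ggml_tensor * kv = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, kv_lora_rank, n_kv);

    // queries for a single generated token, one row per head: [kv_lora_rank, n_head]
    struct ggml_tensor * q = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, kv_lora_rank, n_head);

    // a single matrix multiplication whose second dimension is n_head (not n_tokens == 1);
    // the result holds the attention scores for every head at once, shape [n_kv, n_head]
    struct ggml_tensor * scores = ggml_mul_mat(ctx, kv, q);
    printf("scores: %lld x %lld\n", (long long) scores->ne[0], (long long) scores->ne[1]);

    ggml_free(ctx);
    return 0;
}
```

In the real graph the activations presumably arrive with the head dimension in a different position, which would be where the extra ggml_permute calls come from.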
Thanks for the correction - I did imagine the extra conts when I saw the permutes.
While this is possible to do, I think it has a lot of cons. It will make it difficult for everyone to know which model variation to use on which hardware for better performance. Ideally, we want to have a single implementation that is optimal in all use cases, which can be deprecated at some point for a better alternative. But having 2 alternatives, neither of which is optimal, is not great.
Also, I'm not sure how this implementation fits with multiple parallel sequences, and it introduces extra KV cache logic specific to this type of arch.
I know there is a lot of interest in the DeepSeek arch right now and such optimizations are really important for people. But I think that we have to keep this work in a PR for a while. It is much more important to fix the software architecture in
That may not be possible - IMHO an MLA attention implementation that caches "compressed" latent KV representations introduces unavoidable computational overhead, due to the need to "decompress" these representations in order to calculate attention scores and attention output. So a "naive" attention implementation that caches full K/V vectors will always use less compute but more memory bandwidth, while caching latent representations results in using more compute but less memory bandwidth. So there can't be a single implementation that is optimal in all use cases. I'd be happy to be proven wrong about this, though.
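For a rough sense of that tradeoff, here is a back-of-envelope sketch of per-token, per-layer KV cache traffic, assuming the attention dimensions published for DeepSeek-V3 (128 heads, 128 "nope" + 64 "rope" key dims, 128 value dim, latent rank 512), which are not stated in this thread, and an fp16 cache. It only illustrates the bandwidth-versus-compute argument and does not model either implementation exactly:

```cpp
#include <cstdio>

// Back-of-envelope: per-token, per-layer cache size for "naive" full K/V caching
// vs caching the MLA latent, using DeepSeek-V3-like dimensions (assumed values).
int main() {
    const long n_head  = 128;
    const long qk_nope = 128, qk_rope = 64, v_dim = 128;
    const long kv_lora = 512;
    const long fp16    = 2; // bytes per element

    const long naive  = n_head * (qk_nope + qk_rope + v_dim); // full K and V for every head
    const long latent = kv_lora + qk_rope;                    // shared latent + decoupled rope key

    printf("naive : %ld elements (%ld bytes) per token per layer\n", naive,  naive  * fp16);
    printf("latent: %ld elements (%ld bytes) per token per layer\n", latent, latent * fp16);
    printf("cache reduction: ~%.0fx\n", (double) naive / latent);
    // The price: the latent has to be "decompressed" through the per-head projection
    // matrices on every attention step, trading memory bandwidth for extra compute.
    return 0;
}
```

Under these assumed numbers the latent cache is roughly 70x smaller per token, which is why it helps token generation at long contexts even though it costs extra matrix multiplications.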
I think there shouldn't be any problems with this, as there is a straightforward direct mapping between the cached representations and full K/V vectors.
That's fine with me. I'm taking a break from this anyway, got bored with tensor shuffling looking for 0.1 t/s more performance. 😉
@fairydreaming
I don't have a quant on hand that I can test without this branch, but this branch does give me a nice performance boost for TG at longer contexts. However, RPC to CUDA does not work.
View ops are "free"; they have no runtime cost. However, most operations require at least the first dimension to be contiguous, so using a view often requires adding a
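As a minimal, self-contained illustration of that point (nothing DeepSeek-specific, just the general ggml behavior): ggml_permute only produces a view, so it costs nothing by itself, but the view is no longer contiguous in its first dimension, and ggml_cont is the op that materializes a contiguous copy:

```cpp
#include "ggml.h"
#include <cstdio>

int main() {
    struct ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ true };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * t  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 512, 128);
    struct ggml_tensor * tp = ggml_permute(ctx, t, 1, 0, 2, 3); // a view: only strides change, no data is moved
    struct ggml_tensor * tc = ggml_cont(ctx, tp);               // op that copies into a contiguous tensor

    printf("permuted view contiguous: %d\n", ggml_is_contiguous(tp)); // 0
    printf("cont result contiguous  : %d\n", ggml_is_contiguous(tc)); // 1

    ggml_free(ctx);
    return 0;
}
```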
CUDA should support bf16 after #11093
Thanks - I'm gonna try and have a look at this over the weekend.
Oh thanks - I might as well see if I can track down which operation is causing the (I'm rechecking all the tensors for any values being over
So:
Trying
Ignore all I posted and just deleted - I had the
I have it down to one of the "_b" tensors so far.
It is the
I'm confused, I thought that tensor doesn't get used? The PR description says "remove unused kv_b tensor from the model". Do you mind posting the quantization log for the smallest of your quants that has the qualitative differences of longer and more structured responses? I'm planning on making a new quant anyway.
Shit sorry, yes it's the
It's been a long day of testing these, but it is 100% this tensor type - with some more time it may be possible to identify whether it is restricted to only one or a few layers, but I can't do this for now.
It's really obvious as all you have to do is set the type like this in

static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
const std::string name = ggml_get_name(tensor);
// TODO: avoid hardcoded tensor names - use the TN_* constants
const llm_arch arch = qs.model.arch;
const auto tn = LLM_TN(arch);
auto use_more_bits = [](int i_layer, int n_layers) -> bool {
return i_layer < n_layers/8 || i_layer >= 7*n_layers/8 || (i_layer - n_layers/8)%3 == 2;
};
const int n_expert = std::max(1, (int)qs.model.hparams.n_expert);
auto layer_info = [n_expert] (int i_layer, int n_layer, const char * name) {
if (n_expert > 1) {
// Believe it or not, "experts" in the FFN of Mixtral-8x7B are not consecutive, but occasionally randomly
// sprinkled in the model. Hence, simply dividing i_ffn_down by n_expert does not work
// for getting the current layer as I initially thought, and we need to resort to parsing the
// tensor name.
if (sscanf(name, "blk.%d.", &i_layer) != 1) {
throw std::runtime_error(format("Failed to determine layer for tensor %s", name));
}
if (i_layer < 0 || i_layer >= n_layer) {
throw std::runtime_error(format("Bad layer %d for tensor %s. Must be in [0, %d)", i_layer, name, n_layer));
}
}
return std::make_pair(i_layer, n_layer);
};
// ###
if (name.find("attn_k_b.weight") != std::string::npos) {
new_type = GGML_TYPE_F16;
}
else
// ###
// for arches that share the same tensor between the token embeddings and the output, we quantize the token embeddings
// with the quantization of the output tensor
if (name == tn(LLM_TENSOR_OUTPUT, "weight") || (!qs.has_output && name == tn(LLM_TENSOR_TOKEN_EMBD, "weight"))) {
.
.
.

and it will just repeat the same word over and over whatever else you choose for the quant type. The drop in quality using

The solution for now is to hack in

// ###
if (name.find("attn_k_b.weight") != std::string::npos) {
new_type = GGML_TYPE_BF16;
}
else
// ###

or

// ###
if (name.find("attn_k_b.weight") != std::string::npos) {
new_type = GGML_TYPE_F32;
}
else
// ###

But to be safe I'm just straight up using this now:

// ### JUK ###
if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q6_K) {
if (name.find("_exps") != std::string::npos) {
if (name.find("ffn_down") != std::string::npos) {
new_type = GGML_TYPE_Q6_K;
}
else {
if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) {
new_type = GGML_TYPE_Q4_K;
}
else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) {
new_type = GGML_TYPE_Q5_K;
}
else {
new_type = GGML_TYPE_Q6_K;
}
}
}
else {
new_type = GGML_TYPE_BF16;
}
}
else
// ### JUK ###

as I'm offloading all the

I've also hacked this PR to use this inside

ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*1);//kv_size);
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*1);//kv_size);

to hopefully avoid creating the original KV cache, but I haven't tested it yet.
Thanks for spending the time to test it out. At F16 those tensors are ~1GB, and given that the smallest functional quants of this I've seen are ~130GB, narrowing it down further, even if possible, is not as important.
Let me know if that works. Right now it doesn't matter for me, as I'm using fairydreaming's PR to mmap the KV, but I plan to, like you, offload everything besides the non-shared experts to VRAM, and the large, unused KV allocation in VRAM would be nice to avoid.
I can run the perplexity test for you now, but I forgot what you said - default command line using

I will also try to run it for you with the same
Yep, default command line using
Would be interesting to see if they diverge.
I can confirm this works, e.g. for 32k context:
I'm only using 35/48 GB per GPU even with this as all GPU tensors as
Nice.
The furthest I've pushed was 32K context used, with a total of 64K context available.
Any chance you have run the tests? Here are my results.
They are all roughly the same size, although each version is bigger than the last. I won't keep that up though; if I do a V4, it will be sized between V1 and V2.
Sorry, I had a drive fail in a RAID array where all this is stored and it spent ages rebuilding. I will try and run some benchmarks this week. I also have something that may speed up this PR, but I'm not sure yet as it has to do a lot of SVDs on huge matrices and is taking a while.
@fairydreaming Just saw your post on LocalLLaMA:
I'm not sure why
So that alone would account for a 33% speedup if
What tools are available for profiling GGML operations? Is there some way we can see exactly what is happening during prompt processing? It seems oddly low to me - I can see it pulling stuff in to process on the GPU, but it seems to gain almost nothing as you increase the batch size for a large prompt. I get the feeling something weird is going on regarding the experts, but I didn't manage to get any further than tracing the GGML code to the
AFAIK in ktransformers the whole attention computation is offloaded to the GPU; that's a major source of the performance boost.
I'm already doing this using this PR: I have only the 57 sets of non-shared experts' tensor triplets on the CPU, and the rest lives on the GPU. I've tried lots of other configs too:
and so on... Batch sizes seem to make no difference either: it clearly pulls something from RAM to the GPU during prompt processing, runs the GPU at about 75% load (and 1 single thread on the CPU), but a
I strongly suspect something funky is going on and that instead of pulling in the 256 experts to process a batch on, it's pulling each in/out of RAM or something, but sadly I have no idea how to profile GGML to really get to the bottom of what is happening :/
@jukofyork For starters you could make a debug build of llama.cpp and set the GGML_SCHED_DEBUG=2 environment variable; this will print the whole model graph to the console. There you can see whether a given tensor is on the GPU or on the CPU. Maybe this will give you some answers.
Thanks - I'll try and look into this!
It's well worth playing with the #11397 PR if you have a GPU as well now: I got quite a nice generation increase by flushing the page cache, using
So I've spent all day on this:
1. It's not in
I've just spoken to my brother and he actually has a single-CPU server with a Sapphire Rapids CPU in it, running 8 channels of 48GB 6800 DDR5 (no GPU in it though). I'll hopefully see him sometime this week and will test

@saood06 I've started running the tests:

numactl --interleave=all ./llama-perplexity --n-gpu-layers 99 --override-tensor exps=CPU --override-tensor attn_kv_b=CPU --numa distribute --threads 30 --model ./DeepSeek-R1-mla-Q5_K_XL.gguf --file ./wiki.test.raw

Where

if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q6_K) {
if (name.find("_exps") != std::string::npos) {
if (name.find("ffn_down") != std::string::npos) {
new_type = GGML_TYPE_Q6_K;
}
else {
if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) {
new_type = GGML_TYPE_Q4_K;
}
else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) {
new_type = GGML_TYPE_Q5_K;
}
else {
new_type = GGML_TYPE_Q6_K;
}
}
}
else if (name.find("attn_") != std::string::npos && name.find("_output") == std::string::npos) {
new_type = GGML_TYPE_BF16;
}
else {
new_type = GGML_TYPE_Q8_0;
}
}
else
I will try
Even without AMX, results got ~80 t/s for PP on KTransformers with a single 4090. I want to try KTransformers, but it requires a local GPU, and mine is on a different machine.
If we set
I can't see anything about BLAS here, but it looks like it made a separate backend maybe:
Also, just for your reference since you are running your GPU weights at far higher bpw: in KTransformers "The GPU part, which uses Marlin, will always run in 4-bit mode".
Thanks. Taking the sum of the 16 PPL for each of our quants.
Yours is definitely better, which is expected given mine is closer in size to a Q4_K-based quant, not Q5_K.
I think there might still be a problem with the overflowing
I can't be arsed doing any more today, but it should be possible to balance whatever is overflowing into the
Yes, but the BLAS backend does not support mul_mat_id. If you want to be sure, set
This PR introduces various optimizations for the DeepSeek V2/V3 implementation:
Note that you need to reconvert the model to use this implementation.
Performance compared to the previous "naive" implementation:
CUDA performance is worse for short context lengths, but the curve is flatter:
TODO:
address regressions in prompt processing performance (different permutations of tensors?) - I don't think it's possible, as this implementation is more compute-intensive compared to the regular attention implementation