IQ1_S_R4: better 1.5 bpw quants #185
Conversation
With this mix we arrive at PPL(512) = 9.4140 for DeepSeek-Lite using 1.766 bpw for the repeating layers. On the Ryzen-7950X we get PP-512 = 494 t/s and TG-128 = 52 t/s @ 16 threads.
I do.
It is.
Sadly, it doesn't really function. I haven't tried his IQ1_S, but yours might just be too small. Yours is 127 GB, and the Unsloth creator said on Reddit, "I had a 127GB version, but it didn't go that good".
@saood06 Do you have the quantization log by any chance? It would be useful to have it to verify that the intended tensors with higher bpw are correctly selected, and to see why it ends up being smaller than Unsloth's. Oh, the other thing is that I did not change the default quantization for the token embeddings.
Yes, I had to do some tweaks to it as well to work with the new tensor. It is in the log below. I want to say, I'm happy with my IQ4_K_R4: using saood06/pull/1 I got all the way up to 30K context fitting in 384 GB of RAM without any cache quantization.
Log
Do you have an imatrix with the changed attention tensors?
I can try that, will let you know later as this quant takes a bit of time to make.
No, and I don't have the dataset or the compute. The new tensors are split from an old one; is there a chance they could be converted from the old one?
In that case I would simply use […]
I'll do that. I'll probably remake my IQ4_K_R4 with these changes.
You may also want to change

```cpp
else if (qs.model.hparams.n_expert >= 8 && (name.find("blk.0.ffn_down") != std::string::npos ||
                                            name.find("blk.0.ffn_gate") != std::string::npos ||
                                            name.find("blk.0.ffn_up")   != std::string::npos)) {
    new_type = GGML_TYPE_IQ3_K_R4;
}
```

to

```cpp
else if (qs.model.hparams.n_expert >= 8 && (name.find("ffn_down.weight") != std::string::npos ||
                                            name.find("ffn_gate.weight") != std::string::npos ||
                                            name.find("ffn_up.weight")   != std::string::npos)) {
    new_type = GGML_TYPE_IQ4_K_R4;
}
```

This will cost ~0.4 GiB in quantized model size increase. The check is like this because in DeepSeek-Lite there is a single layer without MoE, but in DeepSeek-R1 there are 3 such layers, and my guess is that those are important to get things on the right track before the experts get involved.
Will do. Just a question: why do you use Q4_K_R4 and not IQ4_K_R4 for attn_q and attn_k? My IQ4_K_R4 uses IQ4_K_R4 for those.
Because of copy/paste. It can be changed to IQ4_K_R4.
I changed some things but it still didn't work. Log
When you say "it didn't work", how did it not work? Produced NaNs? Produced gibberish? Produced something like human language but with no real meaning? Or is it just not as coherent as a higher-bit quantization?
The original one produced just NaNs.
```cpp
else if (i_layer < n_layer/8) {
    new_type = GGML_TYPE_Q2_K_R4;
}
```
Could this need to be higher for R1? The Unsloth quant does this up to and including layer 8; my most recent attempt only did up to and including layer 6.
Yes, the early layers tend to be more important, so increasing the number of layers and/or increasing the bpw of the quantization used will improve results. It is basically a matter of the balance between quantization quality and model size.
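For example, a hypothetical tweak along these lines (my own illustration, not a change proposed in this PR) would widen the early-layer window that gets the higher-bpw type:

```cpp
// Sketch: give more of the early layers the higher-bpw type.
// With n_layer = 61 (DeepSeek-R1 has blk.0..blk.60), n_layer/8 = 7 covers layers 0..6,
// while a hard cutoff of 9 would cover layers 0..8, matching what the Unsloth quant does.
else if (i_layer < 9) {
    new_type = GGML_TYPE_Q2_K_R4;
}
```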
> in DeepSeek-Lite there is a single layer without MoE, but in DeepSeek-R1 there are 3 such layers
The additional 2 dense layers mean you hit 2 fewer MoE layers with this than you do on Lite. This is still the only meaningful way I can see that the quant I just made is worse than the Unsloth one; basically everything else is better or the same.
Hmm, not sure. The token probabilities are not completely useless (same top-4 tokens). It is possible the imatrix is not adequate. 4+ bpw quants work even without an imatrix, so a bad imatrix is not immediately recognizable. I see in the log that 315 chunks were used. We have 8 out of 256 experts being active, so each expert got on average less than 10 chunks. That's not a lot of data to properly determine the relative importance of the tensor columns. In case you have time and energy: […]

It is of course also possible that removing the super-block scale in IQ1_S_R4 is part of the problem.
The one Unsloth uses is significantly shorter, only 124 chunks. I also do believe the imatrix data I use is better: for the Arctic MoE, the person whose imatrices I use activated all but one expert, and they tried hard to get the last one, to no avail. All other imatrices activated far fewer experts.
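As a rough sanity check of the coverage numbers above (my own arithmetic, assuming the 8-of-256 routing spreads the chunks roughly uniformly across experts):

$$315 \times \frac{8}{256} \approx 9.8 \ \text{chunks per expert}, \qquad 124 \times \frac{8}{256} \approx 3.9 \ \text{chunks per expert}.$$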
I think this is to be expected. It is a whole different attention mechanism. MLA uses fewer bits to represent the KV; it is far better at conserving information while compressing the KV cache compared to GQA, but it is still fewer bits than MHA. They claim it is better than MHA because redundancy in information between heads means you get some effectively lossless compression, but I've seen enough people actually micro-benchmark MHA and MLA, and it does seem a bit worse.

The real benefit of MLA is that it uses fewer bits, and there was a branch I was working on which allowed me to make use of that (thanks to another one of fairydreaming's PRs). It uses mmap to avoid allocating KV until it is used, which means the old gigantic KV (full 128k is ~600 GB) does not get allocated up front and start paging me out. I was able to request 64K of context (CPU NUMA KV buffer size = 313101.56 MiB) from the server and used 30K before ending that test, and it never paged to disk thanks to the mmap only allocating what was used. I also did not quantize the cache at all, as with MLA it was already so small.

I saw your PR #188; there are some minor optimizations from fairydreaming that haven't made it into my PR (#180), along with some other stuff from fairydreaming that is experimental (mmap) and QoL stuff (MoE warmup actually loads in all experts) in this branch: saood06/pull/1. Although the mmap allocator is working for me (and I might create a PR with it being toggled via a CLI argument), I think when MLA is toggled on, the other KV cache should not be allocated.
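For readers unfamiliar with the trick, here is a minimal sketch (my own illustration for Linux, not the code from that branch) of how an mmap-backed reservation avoids committing physical memory for KV space that is never touched:

```cpp
// Minimal sketch: reserve a huge anonymous mapping for the KV cache.
// Physical pages are only committed when first written, so a reservation
// sized for the full context costs almost nothing until it is actually used.
#include <sys/mman.h>
#include <cstddef>
#include <cstdio>

int main() {
    const size_t kv_bytes = 300ull << 30;  // e.g. a ~300 GiB KV reservation
    void * kv = mmap(nullptr, kv_bytes, PROT_READ | PROT_WRITE,
                     MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
    if (kv == MAP_FAILED) { perror("mmap"); return 1; }

    // Writing to a page commits just that page; the untouched remainder of
    // the reservation never consumes RAM or causes paging.
    static_cast<char *>(kv)[0] = 1;

    munmap(kv, kv_bytes);
    return 0;
}
```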
When I have some more time I will.
** marks data that was posted by other people online, not my tests. Edit: […]
@saood06 Thanks for these results. I have added some extra guards in #191, but they never trigger with DeepSeek-Lite or LLaMA-3.1-8B-Instruct, so I'm not sure if this will help. It may be useful to try […]
I have tested #192 by merging it into my WIP testing branch, saood06/pull/1. In my single, very basic test, IQ1_S_R4 (V2) now functions (produced coherent output), but it still produced […]. Only including new results in the table below.
IQ4_K_R4 (V2) is slower for TG (2.63 t/s for V2 vs 3.22 t/s for V1), probably because it uses IQ6_K, as IQ6_K_R4 does not exist, and thus for now I still think I prefer V1 even with its flaws.

Off topic, but when should you use Q8_K_R8 vs Q8_0_R8?

Also, there may be some MLA quality issues; there is some discussion happening over at ggerganov/llama.cpp#11446 where setting GGML_TYPE_F32 for some tensors helped quality (GGML_TYPE_F16 for those tensors broke it, while Q8_0 worked but with noticeably degraded performance).

IQ4_K_R4 V1 quantization logs

```
load_imatrix: imatrix dataset='imatrix-training-full-3'
llama_tensor_get_type : tensor cols 128 x 65536 are not divisible by 256, required for iq4_k_r4 - using fallback quantization q5_0
====== llama_model_quantize_internal: did not find weights for blk.0.attn_k_b.weight
llama_tensor_get_type : tensor cols 128 x 65536 are not divisible by 256, required for iq4_k_r4 - using fallback quantization q5_0
====== llama_model_quantize_internal: did not find weights for blk.1.attn_k_b.weight
llama_tensor_get_type : tensor cols 128 x 65536 are not divisible by 256, required for iq4_k_r4 - using fallback quantization q5_0
====== llama_model_quantize_internal: did not find weights for blk.2.attn_k_b.weight
llama_tensor_get_type : tensor cols 128 x 65536 are not divisible by 256, required for iq4_k_r4 - using fallback quantization q5_0
====== llama_model_quantize_internal: did not find weights for blk.3.attn_k_b.weight
llama_tensor_get_type : tensor cols 128 x 65536 are not divisible by 256, required for iq4_k_r4 - using fallback quantization q5_0
====== llama_model_quantize_internal: did not find weights for blk.4.attn_k_b.weight
llama_tensor_get_type : tensor cols 128 x 65536 are not divisible by 256, required for iq4_k_r4 - using fallback quantization q5_0
====== llama_model_quantize_internal: did not find weights for blk.5.attn_k_b.weight
llama_tensor_get_type : tensor cols 128 x 65536 are not divisible by 256, required for iq4_k_r4 - using fallback quantization q5_0
====== llama_model_quantize_internal: did not find weights for blk.6.attn_k_b.weight
llama_tensor_get_type : tensor cols 128 x 65536 are not divisible by 256, required for iq4_k_r4 - using fallback quantization q5_0
====== llama_model_quantize_internal: did not find weights for blk.7.attn_k_b.weight
llama_tensor_get_type : tensor cols 128 x 65536 are not divisible by 256, required for iq4_k_r4 - using fallback quantization q5_0
====== llama_model_quantize_internal: did not find weights for blk.8.attn_k_b.weight
llama_tensor_get_type : tensor cols 128 x 65536 are not divisible by 256, required for iq4_k_r4 - using fallback quantization q5_0
====== llama_model_quantize_internal: did not find weights for blk.9.attn_k_b.weight
llama_tensor_get_type : tensor cols 128 x 65536 are not divisible by 256, required for iq4_k_r4 - using fallback quantization q5_0
====== llama_model_quantize_internal: did not find weights for blk.10.attn_k_b.weight
llama_tensor_get_type : tensor cols 128 x 65536 are not divisible by 256, required for iq4_k_r4 - using fallback quantization q5_0
====== llama_model_quantize_internal: did not find weights for blk.11.attn_k_b.weight
llama_tensor_get_type : tensor cols 128 x 65536 are not divisible by 256, required for iq4_k_r4 - using fallback quantization q5_0
====== llama_model_quantize_internal: did not find weights for blk.58.attn_k_b.weight
llama_tensor_get_type : tensor cols 128 x 65536 are not divisible by 256, required for iq4_k_r4 - using fallback quantization q5_0
====== llama_model_quantize_internal: did not find weights for blk.59.attn_k_b.weight
llama_tensor_get_type : tensor cols 128 x 65536 are not divisible by 256, required for iq4_k_r4 - using fallback quantization q5_0
====== llama_model_quantize_internal: did not find weights for blk.60.attn_k_b.weight
main: quantize time = 13788349.37 ms
```

IQ4_K_R4 V2 quantization logs

```
load_imatrix: imatrix dataset='imatrix-training-full-3'
main: quantize time = 10290932.85 ms
```

Quantization logs had to be truncated to fit GitHub comment length limits.
Just saw this thread linked from the main MLA PR:
This 128-token prompt: […]

seems to be a good test of the model getting dumber, e.g.: […]
I was just about to edit my comment and mention attn_k_b.weight. Since you found your way here, I want to tell you that with a 4.52 bpw quant (using quant types that are better than those that exist in mainline llama.cpp), on a dual-socket Xeon E5-2690 v3 without any offloading, I get this performance (I use batched-bench to test PP performance as context grows, and also spot-test TG performance at various context depths).
My initial tests with offloading (on mainline llama.cpp with the PR that lets you override tensor placement to keep non-shared experts on the CPU) showed worse performance the more layers I offloaded. This fork is currently missing some RPC fixes that would support this model, and also some RPC performance tweaks, but I do plan to bring those over here. Edit: […]
I've noticed this and it has bothered me, although I don't have much of a reference, as almost all of my usage has been with MLA, and the little that hasn't has been at low contexts.
Anytime the tiny difference in accuracy does not matter to you (and a block size of 256 is possible). It is faster than Q8_0_R8. Here is a PP performance comparison between the two:
And here is the same comparison on Zen4 (Ryzen-7950X):
In these tables […]. To put things in perspective, the best mainline […]. On the Ryzen-7950X, memory bandwidth is fully saturated with just 2 threads with […].
Concerning […]: if there are indeed activations that fall outside the […]
Given the hype around DeepSeek's models and Unsloth's sub-2 bpw quantization of DeepSeek-R1 using IQ1_S/IQ1_M, I decided to give some love to sub-2 bpw quants. This PR adds IQ1_S_R4, a 4-row interleaved version of IQ1_S. IQ1_S_R4 uses 1.5 bpw instead of the 1.5625 bpw needed by IQ1_S: the f16 super-block scale is removed and is replaced by an f16 scale per row. IQ1_S_R4 is implemented with a block size of 32. I wanted to have this because DeepSeek-Lite, the model I'm testing with, has a lot of tensors with row sizes not divisible by 256, so a significant fraction of tensors gets quantized to IQ4_NL when using IQ1_S.
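A quick back-of-the-envelope check of these numbers (my own arithmetic, using the standard 256-weight super-block of IQ1_S): a super-block stores 50 bytes, of which 2 bytes are the f16 super-block scale, so

$$\frac{50 \times 8}{256} = 1.5625\ \text{bpw}, \qquad \frac{48 \times 8}{256} = 1.5\ \text{bpw},$$

with the per-row f16 scale adding only 16 bits per row, which is negligible for typical row sizes.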
Quantization mixes that assign more bits to the more important tensors have been used in llama.cpp since the introduction of k-quants. The only reason this does not work well for DeepSeek's models is that the attention tensors have different names, so the heuristics used to assign a higher-bpw quantization to the attention tensors fail. Case in point: today's mainline llama.cpp arrives at a context-512 perplexity (PPL(512) in what follows) of 36.8 for DeepSeek-Lite using 2.62 bpw. The IQ1_S_R4 quantization in this PR gets PPL(512) = 9.4 with 1.766 bpw for the repeating layers. IQ1_S_R4 is also much faster on the CPU compared to IQ1_S (see tables below); I never implemented iqk-style GEMM for IQ1_S/IQ1_M, so these quantization types run at the snail speed of mainline llama.cpp.
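To illustrate the naming issue, here is a sketch of the kind of name-based check involved (my own illustration, not the actual mainline code; the DeepSeek tensor name is the one that appears in the quantization logs earlier in this thread):

```cpp
// A heuristic of this shape bumps attention tensors to a higher-bpw type.
// It matches LLaMA-style names such as "attn_k.weight" or "attn_v.weight", but
// DeepSeek's MLA attention tensors have names like "attn_k_b.weight" (and the other
// MLA projection tensors), so the bump never triggers and those tensors stay at the
// very low default bpw.
if (name.find("attn_k.weight") != std::string::npos ||
    name.find("attn_v.weight") != std::string::npos) {
    new_type = GGML_TYPE_Q2_K;  // placeholder; the real heuristics pick various types
}
```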
The following table compares prompt processing (pp512) and token generation (tg128) speed for LLaMA-3.1-8B on AVX2 (Ryzen-5975WX), Zen4 (Ryzen-7950X) and ARM_NEON (M2-Max CPU). I didn't use DeepSeek-Lite for this comparison to avoid the difference in quantization types one ends up with due to not all tensors having row sizes that are a multiple of 256.

I don't have the disk space and RAM to play with DeepSeek-R1, so I would be really curious to hear from someone trying this PR with this model. It should be quite a bit faster than mainline, and I wouldn't be surprised if quality is better than Unsloth's IQ1_S quantization.