Is kv_cluster.update_kv only called during prefill? #20

Open
DOG-wooooof opened this issue Jul 22, 2024 · 3 comments

Comments

DOG-wooooof commented Jul 22, 2024

Thanks for sharing the code!
For autoregressive generation, according to the implementation at llama_model.py:92-98, it seems that only the initial prefill (i.e., when self.kv_seq_len == 0) can enter the if key_states.shape[-2] == kv_seq_len: branch and therefore call kv_cluster.update_kv(…).
In the subsequent decoding steps, each forward pass receives a new input of seq_len 1 (i.e., key_states.shape[-2] == 1), which never equals kv_seq_len, the accumulated length that includes the KV cache. So kv_cluster.update_kv is never called again, and the k/v of every newly generated token is kept.

        if key_states.shape[-2] == kv_seq_len:
            self.kv_seq_len = kv_seq_len
            key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states, attention_mask, self.num_key_value_groups)
            past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
        else:
            self.kv_seq_len += q_len
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

If my understanding is correct, does this mean that kv_cluster.update_kv is only called at prefill time, while the k/v generated during decoding is always kept in past_key_value, so the cache can grow beyond the configured cap?
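
For concreteness, here is a minimal arithmetic sketch (with a hypothetical 28-token prompt) of how that condition behaves across prefill and decoding:

# Hypothetical lengths: a 28-token prompt followed by token-by-token decoding.
prompt_len = 28

# Prefill: the whole prompt is forwarded at once, so the two lengths match.
q_len, kv_seq_len = prompt_len, prompt_len
print(q_len == kv_seq_len)      # True  -> kv_cluster.update_kv() is called

# Decoding: each forward sees exactly one new token while kv_seq_len keeps growing.
for step in range(1, 4):
    q_len = 1
    kv_seq_len = prompt_len + step
    print(q_len == kv_seq_len)  # False -> the else branch appends k/v uncompressed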
I added a simple test script in the top-level directory next to run_longbench.py:

import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from pyramidkv.monkeypatch import replace_llama, replace_mistral

# Run inference with the transformers pipeline
if __name__ == "__main__":
    model_path = "llama-2-hf"
    method = "StreamingLLM"
    attn_implementation = "flash_attention_2"
    max_capacity_prompts = 64

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        use_fast=True,
        padding_side="left"
    )

    replace_llama(method.lower())
    replace_mistral(method.lower())
    
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="auto",
        use_cache=True,
        attn_implementation=attn_implementation
    )

    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
            
    model.eval()
    
    # Copied from run_longbench.py
    if method != "FullKV":
        if method.lower() in ["snapkv","pyramidkv","h2o"]:
            window_sizes = 8
        elif method.lower() in ["streamingllm"]:
            window_sizes = max_capacity_prompts - 4
        kernel_sizes = 7
        pooling = "maxpool"

        layers = len(model.model.layers)
        if not isinstance(window_sizes, list):
            window_sizes = [window_sizes] * layers
        if not isinstance(max_capacity_prompts, list):
            max_capacity_prompts = [max_capacity_prompts] * layers
        if not isinstance(kernel_sizes, list):
            kernel_sizes = [kernel_sizes] * layers
        for i in range(layers):
            model.model.layers[i].self_attn.config.window_size = window_sizes[i]
            model.model.layers[i].self_attn.config.max_capacity_prompt = max_capacity_prompts[i]
            model.model.layers[i].self_attn.config.kernel_size = kernel_sizes[i]
            model.model.layers[i].self_attn.config.pooling = pooling

    mypipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        torch_dtype=torch.float16,
        device_map="auto",
    )

    sequences = mypipeline(
        'I liked "Breaking Bad" and "Band of Brothers". Do you have any recommendations of other shows I might like?\n',
        do_sample=True,
        top_k=1,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        max_length=200,
    )
    for seq in sequences:
        print(f"Result: {seq['generated_text']}")

The observed output is:

Using StreamingLLM!
Using StreamingLLM!
kv_seq_len: 28
key_states.shape: torch.Size([1, 32, 28, 128])
self.kv_seq_len:  0
call kv_cluster.update_kv()
===================================
kv_seq_len: 29
key_states.shape: torch.Size([1, 32, 1, 128])
self.kv_seq_len:  28
===================================
kv_seq_len: 30
key_states.shape: torch.Size([1, 32, 1, 128])
self.kv_seq_len:  29
===================================
......
===================================
kv_seq_len: 199
key_states.shape: torch.Size([1, 32, 1, 128])
self.kv_seq_len:  198
===================================
Result: I liked "Breaking Bad" and "Band of Brothers". Do you have any recommendations of other shows I might like?
I'm not sure if you're a fan of "The Wire" or not, but I'd recommend that. It's a show about the Baltimore police department, and it's a very realistic portrayal of the inner workings of the police force. It's also a very well-written show, and it's one of the best shows I've ever seen.

As can be seen, kv_cluster.update_kv is called only once at the very beginning, and the KV cache finally held in past_key_value exceeds the configured cap of 64.
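
Putting rough numbers on this from the log above (and assuming kv_cluster.update_kv keeps at most max_capacity_prompt entries at prefill):

# Numbers read off the log above; the prefill behaviour is an assumption.
max_capacity_prompt = 64
prompt_len = 28                          # kv_seq_len at the prefill step
decoded_tokens = 199 - prompt_len        # generation ran until kv_seq_len == 199

prefill_kv = min(prompt_len, max_capacity_prompt)
final_cache_len = prefill_kv + decoded_tokens
print(final_cache_len)                   # 199, well above the configured cap of 64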
I have read your paper. Is this implementation choice because you are currently only interested in compressing long-context inputs / long-document multi-turn dialogue, and therefore only compress the prefill?
Furthermore, referring to H2O's open-source code (https://github.com/FMInference/H2O/blob/main/h2o_hf/utils_real_drop/modify_llama.py) and the StreamingLLM implementation in the huggingface transformers library (SinkCache, around line 621 of https://github.com/huggingface/transformers/blob/main/src/transformers/cache_utils.py), both seem to compress dynamically as decoding proceeds, i.e., they try to compress the KV cache on every forward pass. Given that, are the baseline implementations and the comparison experiments appropriate?
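
For reference, the per-step policy I mean is roughly the following sketch (not the actual H2O or SinkCache API, just an illustration of evicting down to a few sink tokens plus a recent window on every decoding step):

import torch

def evict_streamingllm_style(key_cache, value_cache, num_sink=4, window=60):
    # Keep the first num_sink tokens plus the most recent window tokens.
    seq_len = key_cache.shape[-2]
    if seq_len <= num_sink + window:
        return key_cache, value_cache
    keep = torch.cat([
        torch.arange(num_sink),
        torch.arange(seq_len - window, seq_len),
    ])
    return key_cache[..., keep, :], value_cache[..., keep, :]

# Applied after every decoding step, the cache never exceeds num_sink + window tokens.
k = torch.randn(1, 32, 128, 128)   # (batch, heads, seq_len, head_dim), assumed shapes
v = torch.randn(1, 32, 128, 128)
k, v = evict_streamingllm_style(k, v)
print(k.shape)                     # torch.Size([1, 32, 64, 128])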

@liuxiaozhu01

Same question here. PyramidKV and SnapKV both compress only at prefill, and their implementations are consistent with each other in this respect. But for H2O the KV cache changes dynamically during decoding, as described above.

Zefan-Cai (Owner) commented Jul 24, 2024

SnapKV and PyramidKV are indeed only called at prefill time. The original H2O and StreamingLLM do maintain the KV cache sequence dynamically, so they are invoked at both decoding time and prefill time. For a fair comparison, the H2O and StreamingLLM reported in the paper were modified from their original implementations so that they, too, are only invoked during the prefill stage. The modified H2O and StreamingLLM will likewise exceed the configured cap, which makes the comparison fairer than using the original baselines directly.
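
A minimal, self-contained simulation of that prefill-only gating (pure Python with stand-in lists, not the actual PyramidKV code) shows the same over-budget growth:

def step(cache, new_kv, budget):
    # Prefill-only compression: compress once while the cache is still empty,
    # then append decoding-time k/v without any further eviction.
    if len(cache) == 0:            # prefill: the whole prompt arrives at once
        new_kv = new_kv[-budget:]  # stand-in for kv_cluster.update_kv(...)
    cache.extend(new_kv)
    return cache

cache = step([], list(range(128)), budget=64)   # prefill keeps 64 entries
for t in range(100):                            # decoding adds one token per step
    cache = step(cache, [128 + t], budget=64)
print(len(cache))                               # 164 -> grows past the budget, as expected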

@DOG-wooooof (Author)

Thanks for the answer.
For prompt compression this implementation is reasonable, as are the LongBench experiments, which the SnapKV paper also includes. I had some misunderstandings when I first read and compared the papers; your explanation is correct.
