import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from pyramidkv.monkeypatch import replace_llama, replace_mistral

# Run inference through a transformers pipeline
if __name__ == "__main__":
    model_path = "llama-2-hf"
    method = "StreamingLLM"
    attn_implementation = "flash_attention_2"
    max_capacity_prompts = 64

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        use_fast=True,
        padding_side="left",
    )

    # Monkeypatch the attention implementation before loading the model
    replace_llama(method.lower())
    replace_mistral(method.lower())

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="auto",
        use_cache=True,
        attn_implementation=attn_implementation,
    )

    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    model.eval()

    # Copied from run_longbench.py
    if method != "FullKV":
        if method.lower() in ["snapkv", "pyramidkv", "h2o"]:
            window_sizes = 8
        elif method.lower() in ["streamingllm"]:
            window_sizes = max_capacity_prompts - 4
        kernel_sizes = 7
        pooling = "maxpool"

        layers = len(model.model.layers)
        if not isinstance(window_sizes, list):
            window_sizes = [window_sizes] * layers
        if not isinstance(max_capacity_prompts, list):
            max_capacity_prompts = [max_capacity_prompts] * layers
        if not isinstance(kernel_sizes, list):
            kernel_sizes = [kernel_sizes] * layers

        # Push the per-layer KV-cache compression settings into each attention module
        for i in range(layers):
            model.model.layers[i].self_attn.config.window_size = window_sizes[i]
            model.model.layers[i].self_attn.config.max_capacity_prompt = max_capacity_prompts[i]
            model.model.layers[i].self_attn.config.kernel_size = kernel_sizes[i]
            model.model.layers[i].self_attn.config.pooling = pooling

    mypipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        torch_dtype=torch.float16,
        device_map="auto",
    )

    sequences = mypipeline(
        'I liked "Breaking Bad" and "Band of Brothers". Do you have any recommendations of other shows I might like?\n',
        do_sample=True,
        top_k=1,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        max_length=200,
    )
    for seq in sequences:
        print(f"Result: {seq['generated_text']}")
The observed output is:
Using StreamingLLM!
Using StreamingLLM!
kv_seq_len: 28
key_states.shape: torch.Size([1, 32, 28, 128])
self.kv_seq_len: 0
call kv_cluster.update_kv()
===================================
kv_seq_len: 29
key_states.shape: torch.Size([1, 32, 1, 128])
self.kv_seq_len: 28
===================================
kv_seq_len: 30
key_states.shape: torch.Size([1, 32, 1, 128])
self.kv_seq_len: 29
===================================
......
===================================
kv_seq_len: 199
key_states.shape: torch.Size([1, 32, 1, 128])
self.kv_seq_len: 198
===================================
Result: I liked "Breaking Bad" and "Band of Brothers". Do you have any recommendations of other shows I might like?
I'm not sure if you're a fan of "The Wire" or not, but I'd recommend that. It's a show about the Baltimore police department, and it's a very realistic portrayal of the inner workings of the police force. It's also a very well-written show, and it's one of the best shows I've ever seen.
Thank you for sharing the code!
In autoregressive generation, according to the implementation in llama_model.py: 92-98, the branch `if key_states.shape[-2] == kv_seq_len:` seems to be entered only during the initial prefill, i.e. when self.kv_seq_len == 0, and that is the only point at which kv_cluster.update_kv(…) is called. In every subsequent decoding step, the seq_len of the new input to forward (i.e. key_states.shape[-2]) is 1, which never equals kv_seq_len, the length accumulated over the KV cache. As a result, kv_cluster.update_kv is never called again, and every newly generated k/v is retained. If my understanding is correct, does this mean kv_cluster.update_kv is only invoked at prefill time, while the k/v produced during decoding are always kept in past_key_value, so that the cache grows beyond the configured limit?
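To make the condition concrete, here is a small self-contained sketch of the check I am describing (would_call_update_kv is my own illustrative helper, not code from this repository):

```python
import torch

def would_call_update_kv(key_states: torch.Tensor, past_len: int) -> bool:
    """Paraphrase of the branch discussed above: kv_seq_len counts the cached
    tokens plus the new ones, so it only equals the number of new tokens when
    the cache is still empty, i.e. at prefill."""
    kv_seq_len = past_len + key_states.shape[-2]
    return key_states.shape[-2] == kv_seq_len

# Prefill: 28 prompt tokens, empty cache -> condition holds, update_kv() would run.
print(would_call_update_kv(torch.zeros(1, 32, 28, 128), past_len=0))   # True

# Decoding: 1 new token on top of 28 cached tokens -> condition fails, cache just grows.
print(would_call_update_kv(torch.zeros(1, 32, 1, 128), past_len=28))   # False
```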
I added a simple test script to the top-level directory alongside run_longbench.py; the script and its output are shown above.
It shows that kv_cluster.update_kv is called only at the very beginning, and the KV cache finally held in past_key_value exceeds the configured limit of 64.
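In case it helps, this is how I checked the final cache length directly (a sketch assuming a transformers version in which generate() returns past_key_values when return_dict_in_generate=True and use_cache=True; depending on the version the cache is a tuple of tuples or a Cache object, both of which support the indexing below):

```python
prompt = ('I liked "Breaking Bad" and "Band of Brothers". '
          'Do you have any recommendations of other shows I might like?\n')
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

out = model.generate(
    **inputs,
    max_new_tokens=128,
    do_sample=False,
    use_cache=True,
    return_dict_in_generate=True,
)

# Layer-0 key cache has shape [batch, num_kv_heads, cached_len, head_dim]
cached_len = out.past_key_values[0][0].shape[-2]
print(f"final KV cache length: {cached_len}")   # expected to exceed max_capacity_prompts = 64
```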
I have read your paper. Is this implementation chosen because you are currently only interested in compressing long-context inputs / long-document multi-turn dialogue, and therefore only compress the prefill?
Furthermore, referring to H2O's open-source code (https://github.com/FMInference/H2O/blob/main/h2o_hf/utils_real_drop/modify_llama.py) and the StreamingLLM implementation in the Hugging Face transformers library (SinkCache, around line 621 of https://github.com/huggingface/transformers/blob/main/src/transformers/cache_utils.py), both appear to compress dynamically as decoding proceeds, i.e. they try to compress the KV cache on every forward pass. Given that, are the baseline implementations and the comparison experiments appropriate?
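For clarity about what I mean by "dynamic compression during decoding", here is a toy sketch of an eviction rule applied after every step, in the spirit of the sink-plus-recent-window scheme (my own illustration, not code from PyramidKV, H2O, or transformers; evict_streaming and its parameters are made up for this example):

```python
import torch

def evict_streaming(key_cache: torch.Tensor, value_cache: torch.Tensor,
                    num_sink: int = 4, window: int = 60):
    """Keep the first `num_sink` tokens plus the most recent `window` tokens.
    Applied after every decoding step, the cache never exceeds num_sink + window.
    Cache shapes: [batch, num_kv_heads, seq_len, head_dim]."""
    seq_len = key_cache.shape[-2]
    if seq_len <= num_sink + window:
        return key_cache, value_cache
    keep = torch.cat([
        torch.arange(num_sink),
        torch.arange(seq_len - window, seq_len),
    ])
    return key_cache[:, :, keep, :], value_cache[:, :, keep, :]

# With num_sink + window = 64, even 200 cached tokens are trimmed back to 64.
k = torch.zeros(1, 32, 200, 128)
v = torch.zeros(1, 32, 200, 128)
k, v = evict_streaming(k, v)
print(k.shape)   # torch.Size([1, 32, 64, 128])
```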