Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create INT8 KV Cache on Qserve #2446

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion examples/llama/convert_checkpoint.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def convert_and_save_hf(args):
# llava_llama needs its own defined config.
logger.warning("AutoConfig cannot load the huggingface config.")

if args.smoothquant is not None or args.int8_kv_cache:
if (args.smoothquant is not None or args.int8_kv_cache) and not args.use_qserve:
assert not args.load_by_shard, "When using quantization, TRT-LLM needs to load the whole HF model, thus load by shard not supported"
mapping = Mapping(world_size=world_size,
tp_size=args.tp_size,
Expand Down Expand Up @@ -474,6 +474,10 @@ def convert_and_save_rank(args, rank):
args.dtype,
mapping=mapping,
quant_config=quant_config,
device='cpu' if args.load_model_on_cpu else 'cuda',
calib_dataset=args.calib_dataset,
calib_batches=args.calib_size,
calib_max_seq_length=args.calib_max_seq_length,
load_by_shard=load_by_shard,
**override_fields,
)
Expand Down
65 changes: 63 additions & 2 deletions tensorrt_llm/models/llama/convert.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
retrieved_layer_index_from_name, smooth_gemm,
smooth_gemm_fc1_gate, split, split_matrix_tp,
split_qkv_bias_tp, split_qkv_tp)
from ..modeling_utils import PretrainedConfig
from ..modeling_utils import PretrainedConfig, QuantConfig
from .config import LLaMAConfig


Expand Down Expand Up @@ -1921,7 +1921,16 @@ def process_and_assign_weight(v: List[torch.Tensor],
return weights


def load_weights_from_lmquant(lmquant_ckpt_path: str, config: LLaMAConfig):
def load_weights_from_lmquant(
lmquant_ckpt_path: str,
config: LLaMAConfig,
quant_config: QuantConfig,
hf_model_dir: str,
device: str = "cuda",
calib_dataset: str = "cnn_dailymail",
calib_batches: int = 512,
calib_max_seq_length: int = 512,
):
logger.info(
'Loading weights from lmquant torch checkpoint for QServe W4A8 inference...'
)
Expand All @@ -1945,6 +1954,40 @@ def load_weights_from_lmquant(lmquant_ckpt_path: str, config: LLaMAConfig):
quant_params = torch.load(lmquant_ckpt_path + '/scale.pt',
map_location='cpu')

int8_kv_cache = quant_config.kv_cache_quant_algo == QuantAlgo.INT8

act_range = {}
if int8_kv_cache:
hf_model = AutoModelForCausalLM.from_pretrained(
hf_model_dir,
device_map=device if device != 'cpu' else 'cpu',
torch_dtype='auto',
trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(
hf_model_dir,
trust_remote_code=True,
use_fast=False,
padding_side='left')

dataset = load_calib_dataset(calib_dataset)

if calib_batches == -1:
calib_batches = len(dataset)

model_prefix, layer_prefix, param_name_map = get_prefix_and_param_name_map(
config.architecture, use_safetensors=True)

lmquant_keys = fake_quant_weights.keys()
for name, param in hf_model.named_parameters():
if name in lmquant_keys:
param.data = fake_quant_weights[name].data.to(device)

act_range = capture_activation_range(hf_model,
tokenizer,
dataset,
num_samples=calib_batches,
seq_len=calib_max_seq_length)

def load(key):
if 'zero' in key:
v = quant_params[key]
Expand Down Expand Up @@ -2082,6 +2125,24 @@ def process_weight_and_params(v: List[torch.Tensor], tllm_prex: str):
]
weights.update(
process_weight_and_params(qkv, f'{tllm_prex}.attention.qkv'))

if int8_kv_cache:
act_range_prefix = f'{model_prefix}.{layer_prefix}.{layer_idx}.'
qkv_y = torch.cat([
# act_range.get(act_range_prefix +
# f'{param_name_map["attention.qkv"]}.q_proj')["y"],
act_range.get(act_range_prefix +
f'{param_name_map["attention.qkv"]}.k_proj')["y"],
act_range.get(act_range_prefix +
f'{param_name_map["attention.qkv"]}.v_proj')["y"]
], dim=0)

int8_kv_scales = qkv_y.max() / 127.

kv_cache_weights = {}

kv_cache_weights[f'{tllm_prex}.attention.kv_cache_scaling_factor'] = int8_kv_scales.reshape([1])
weights.update(kv_cache_weights)

# 4.2 attention.dense
v = [
Expand Down
8 changes: 7 additions & 1 deletion tensorrt_llm/models/llama/model.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,10 @@ def from_hugging_face(
dtype: str = 'auto',
mapping: Optional[Mapping] = None,
quant_config: Optional[QuantConfig] = None,
device: str = 'cuda',
calib_dataset: str = 'cnn_dailymail',
calib_batches: int = 512,
calib_max_seq_length: int = 512,
**kwargs):
''' Create a LLaMAForCausalLM object from give parameters
'''
Expand Down Expand Up @@ -413,7 +417,9 @@ def from_hugging_face(
if quant_config.quant_mode.is_int4_weight_only():
weights = load_weights_from_gptq(quant_ckpt_path, config)
elif quant_config.quant_mode.is_qserve_w4a8():
weights = load_weights_from_lmquant(quant_ckpt_path, config)
weights = load_weights_from_lmquant(quant_ckpt_path,
config, quant_config, hf_model_dir,
device, calib_dataset, calib_batches, calib_max_seq_length)
else:
raise ValueError(
"quant_ckpt_path should be specified only for GPTQ or QServe"
Expand Down
2 changes: 1 addition & 1 deletion tensorrt_llm/quantization/quantize.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ def qserve_quantize(model, quant_config: QuantConfig):
def kv_cache_quantize(model):
for name, module in model.named_modules():
if isinstance(module,
(Attention, SmoothQuantAttention, Fp8RowwiseAttention)):
(Attention, SmoothQuantAttention, Fp8RowwiseAttention, QServeAttention)):
module.kv_cache_scaling_factor = Parameter(shape=(1, ),
dtype='float32')
return model
Expand Down