diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index 7263900fbe1..0341988b1a5 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -121,13 +121,15 @@ def merge_qkv(module: torch.nn.Module):
 
 def scale_rope(module: torch.nn.Module):
     if module.__class__.__name__ == "LlamaLinearScalingRotaryEmbedding":
+        module.register_buffer("inv_freq_scaled", None, persistent=False)
         if hasattr(module, "scaling_factor"):
-            module.inv_freq /= module.scaling_factor
+            module.inv_freq_scaled = module.inv_freq / module.scaling_factor
         elif hasattr(module, "rope_kwargs"):
-            module.inv_freq /= module.rope_kwargs.factor
+            module.inv_freq_scaled = module.inv_freq / module.rope_kwargs.factor
     elif module.__class__.__name__ == "LlamaRotaryEmbedding":
         if hasattr(module, "rope_kwargs") and module.rope_kwargs.rope_type == "linear":
-            module.inv_freq /= module.rope_kwargs.factor
+            module.register_buffer("inv_freq_scaled", None, persistent=False)
+            module.inv_freq_scaled = module.inv_freq / module.rope_kwargs.factor
 
 
 def llama_attention_forward(
@@ -158,8 +160,12 @@ def llama_attention_forward(
         import xe_addons
         if hasattr(self, "rotary_emb"):
             # transformers < 4.46
-            xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
-                                           query_states, key_states)
+            if hasattr(self.rotary_emb, "inv_freq_scaled"):
+                xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq_scaled, position_ids,
+                                               query_states, key_states)
+            else:
+                xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                               query_states, key_states)
         else:
             # transformers >= 4.46
             cos, sin = position_embeddings
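
Note: the sketch below is not part of the patch. It is a minimal, self-contained illustration of the mechanical effect of the change: instead of dividing module.inv_freq in place, scale_rope registers a non-persistent inv_freq_scaled buffer and the fused-rope path prefers that buffer when it exists, falling back to inv_freq otherwise. The FakeLinearScalingRotaryEmbedding class and the standalone scale_rope here are hypothetical stand-ins, not the transformers or ipex-llm implementations.

# Minimal sketch, assuming a stand-in rotary-embedding module.
import torch


class FakeLinearScalingRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim=8, base=10000.0, scaling_factor=2.0):
        super().__init__()
        # Standard RoPE inverse frequencies.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.scaling_factor = scaling_factor


def scale_rope(module):
    # Same idea as the patched scale_rope: leave inv_freq untouched and expose
    # a derived, non-persistent inv_freq_scaled buffer for the fused kernel.
    module.register_buffer("inv_freq_scaled", None, persistent=False)
    module.inv_freq_scaled = module.inv_freq / module.scaling_factor


rope = FakeLinearScalingRotaryEmbedding()
before = rope.inv_freq.clone()
scale_rope(rope)
scale_rope(rope)  # running the pass twice no longer compounds the scaling
assert torch.equal(rope.inv_freq, before)  # original buffer is unchanged
assert torch.allclose(rope.inv_freq_scaled, before / rope.scaling_factor)

# Mirrors the forward-path change: prefer the scaled buffer when present.
inv_freq_for_kernel = (rope.inv_freq_scaled
                       if hasattr(rope, "inv_freq_scaled") else rope.inv_freq)

Because the buffer is registered with persistent=False, inv_freq_scaled stays out of the module's state_dict, so checkpoints are unaffected by the extra buffer.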