
Commit

Small fix to not affect other cases
Oscilloscope98 committed Jan 15, 2025
1 parent 15e25e8 commit b343681
Showing 1 changed file with 11 additions and 5 deletions.
python/llm/src/ipex_llm/transformers/models/llama.py (16 changes: 11 additions & 5 deletions)
@@ -121,13 +121,15 @@ def merge_qkv(module: torch.nn.Module):
 
 def scale_rope(module: torch.nn.Module):
     if module.__class__.__name__ == "LlamaLinearScalingRotaryEmbedding":
+        module.register_buffer("inv_freq_scaled", None, persistent=False)
         if hasattr(module, "scaling_factor"):
-            module.inv_freq /= module.scaling_factor
+            module.inv_freq_scaled = module.inv_freq / module.scaling_factor
         elif hasattr(module, "rope_kwargs"):
-            module.inv_freq /= module.rope_kwargs.factor
+            module.inv_freq_scaled = module.inv_freq / module.rope_kwargs.factor
     elif module.__class__.__name__ == "LlamaRotaryEmbedding":
         if hasattr(module, "rope_kwargs") and module.rope_kwargs.rope_type == "linear":
-            module.inv_freq /= module.rope_kwargs.factor
+            module.register_buffer("inv_freq_scaled", None, persistent=False)
+            module.inv_freq_scaled = module.inv_freq / module.rope_kwargs.factor
 
 
 def llama_attention_forward(
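Note on this hunk: the old code divided module.inv_freq in place, mutating the rotary embedding's shared buffer, so any other code path that later read inv_freq saw already-scaled values. Writing the scaled copy into a separate non-persistent inv_freq_scaled buffer leaves the original intact, which is presumably what the commit title means by "not affect other cases". A minimal sketch of the pattern; ToyRotaryEmbedding and its values are illustrative stand-ins, not the real LlamaRotaryEmbedding:

import torch

# Toy stand-in for the pattern in scale_rope: keep the stock inv_freq intact
# and register the linearly scaled copy as a separate non-persistent buffer.
class ToyRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim=8, base=10000.0, scaling_factor=2.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.scaling_factor = scaling_factor
        # Before this commit the equivalent of `inv_freq /= scaling_factor`
        # overwrote the buffer above; the fix stores the result separately.
        self.register_buffer("inv_freq_scaled", None, persistent=False)
        self.inv_freq_scaled = self.inv_freq / self.scaling_factor

rope = ToyRotaryEmbedding()
# The original frequencies are untouched; only the new buffer is scaled.
assert torch.allclose(rope.inv_freq / rope.scaling_factor, rope.inv_freq_scaled)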
@@ -158,8 +160,12 @@ def llama_attention_forward(
         import xe_addons
         if hasattr(self, "rotary_emb"):
             # transformers < 4.46
-            xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
-                                           query_states, key_states)
+            if hasattr(self.rotary_emb, "inv_freq_scaled"):
+                xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq_scaled, position_ids,
+                                               query_states, key_states)
+            else:
+                xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                               query_states, key_states)
         else:
             # transformers >= 4.46
             cos, sin = position_embeddings
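On the consumption side, scale_rope only creates inv_freq_scaled for the linear-scaling cases, so llama_attention_forward now feature-detects the buffer and falls back to the stock inv_freq for every other rope type. A hedged sketch of that dispatch, with pick_inv_freq as a hypothetical helper standing in for the inlined hasattr check (xe_addons.rotary_half_inplaced itself is an XPU kernel and not reproduced here):

import torch

def pick_inv_freq(rotary_emb: torch.nn.Module) -> torch.Tensor:
    # Mirrors the new branch in llama_attention_forward: scale_rope registers
    # inv_freq_scaled only when linear rope scaling applies, so its presence
    # decides which frequency table the rotary kernel should consume.
    if hasattr(rotary_emb, "inv_freq_scaled"):
        return rotary_emb.inv_freq_scaled
    return rotary_emb.inv_freq

# Usage with the toy module from the earlier sketch:
#   inv_freq = pick_inv_freq(rope)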
