From f1ae813615ff079d76a72e1a4ff69225fa165143 Mon Sep 17 00:00:00 2001
From: calpt
Date: Sat, 20 Apr 2024 13:39:56 +0200
Subject: [PATCH 1/2] Upgrade to Transformers v4.39.x

---
 hf_transformers | 2 +-
 setup.py        | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/hf_transformers b/hf_transformers
index a0857740c0..09f9f566de 160000
--- a/hf_transformers
+++ b/hf_transformers
@@ -1 +1 @@
-Subproject commit a0857740c0e6127485c11476650314df3accc2b6
+Subproject commit 09f9f566de83eef1f13ee83b5a1bbeebde5c80c1
diff --git a/setup.py b/setup.py
index 84beca00ae..ef6f386fbc 100644
--- a/setup.py
+++ b/setup.py
@@ -60,7 +60,7 @@
         "sphinx-multiversion==0.2.4",
         "timeout-decorator",
         "torch>=1.10,!=1.12.0",
-        "transformers~=4.38.1",
+        "transformers~=4.39.3",
     ]

From 9181bb81a14ed03dc43088170f74f77aa525c1f9 Mon Sep 17 00:00:00 2001
From: calpt
Date: Wed, 24 Apr 2024 22:33:26 +0200
Subject: [PATCH 2/2] llama fixes

---
 src/adapters/models/llama/modeling_llama.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/src/adapters/models/llama/modeling_llama.py b/src/adapters/models/llama/modeling_llama.py
index 752cfccb10..7c99f286e4 100644
--- a/src/adapters/models/llama/modeling_llama.py
+++ b/src/adapters/models/llama/modeling_llama.py
@@ -90,7 +90,7 @@ def forward(
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_value is not None:
-            # sin and cos are specific to RoPE models; position_ids needed for the static cache
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
@@ -107,8 +107,7 @@ def forward(
         bsz = key_states.shape[0]
 
         if attention_mask is not None:  # no matter the length, we just slice it
-            if cache_position is not None:
-                causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
             attn_weights = attn_weights + causal_mask
 
         # upcast attention to fp32
@@ -184,7 +183,7 @@ def forward(
         past_key_value = getattr(self, "past_key_value", past_key_value)
 
         if past_key_value is not None:
-            # sin and cos are specific to RoPE models; position_ids needed for the static cache
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
@@ -284,10 +283,11 @@ def forward(
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
+        # In case static cache is used, it is an instance attribute.
         past_key_value = getattr(self, "past_key_value", past_key_value)
 
         if past_key_value is not None:
-            # sin and cos are specific to RoPE models; position_ids needed for the static cache
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
@@ -302,8 +302,9 @@ def forward(
         bsz = key_states.shape[0]
 
         causal_mask = attention_mask
-        if attention_mask is not None and cache_position is not None:
-            causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
+        # if attention_mask is not None and cache_position is not None:
+        if attention_mask is not None:
+            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
 
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
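
Note (not part of the patches): a minimal, self-contained sketch of the mask handling the second patch switches to, using assumed toy tensor shapes rather than the adapters code. The mask handed to the attention forward already covers the cache length, so the patched code slices its key dimension to key_states.shape[-2] instead of indexing rows with cache_position.

# Illustrative sketch with assumed toy shapes; not the adapters implementation.
import torch

bsz, num_heads, q_len, kv_len, max_cache_len, head_dim = 2, 4, 1, 6, 16, 8

# A 4D additive mask covering the whole cache; slots beyond the current keys are masked out.
attention_mask = torch.zeros(bsz, 1, q_len, max_cache_len)
attention_mask[..., kv_len:] = torch.finfo(torch.float32).min

key_states = torch.randn(bsz, num_heads, kv_len, head_dim)
attn_weights = torch.randn(bsz, num_heads, q_len, kv_len)

# Patched behavior: slice the mask to the current key length only,
# instead of indexing its rows with cache_position.
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask  # broadcasts over the head dimension
print(attn_weights.shape)  # torch.Size([2, 4, 1, 6])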