diff --git a/lade/decoding.py b/lade/decoding.py index d0b5893..24348a3 100644 --- a/lade/decoding.py +++ b/lade/decoding.py @@ -411,40 +411,19 @@ def copy_from_last(): attention_mask = model_kwargs["attention_mask"] model_kwargs["attention_mask"] = torch.cat((attention_mask, torch.ones(1, max_hit, device=attention_mask.device, dtype=attention_mask.dtype)), dim=1) - #support awq - - if not USE_AWQ: - past_key_values = [] - for idx, kv in enumerate(outputs.past_key_values): - for hh in range(max_hit): - assert outputs.step_len == kv[idx][0].size(2) - kv[idx][0][:,:,outputs.kvcache_len + hh,:] = kv[idx][0][:,:,outputs.step_len-len(guess_tokens)+hit_point * GUESS_SIZE + hh,:] - kv[idx][1][:,:,outputs.kvcache_len + hh,:] = kv[idx][1][:,:,outputs.step_len-len(guess_tokens)+hit_point * GUESS_SIZE + hh,:] - past_key_values.append( (kv[idx][0][:,:,:outputs.kvcache_len + max_hit,:], kv[idx][1][:,:,:outputs.kvcache_len + max_hit,:]) ) - outputs.past_key_values = past_key_values - - else: - + #not support awq + #print("kv: ", outputs.past_key_values) + assert not USE_AWQ + past_key_values = [] + for idx, kv in enumerate(outputs.past_key_values): for hh in range(max_hit): - #print("cache: ", outputs.kvcache_len, max_hit, outputs.step_len, window_cache[0].k.size(), window_cache[0].v.size()) - for idx, kv in enumerate(window_cache): - kv.k[:,:,:,outputs.kvcache_len + hh,:] = kv.k[:,:,:,outputs.step_len-len(guess_tokens)+hit_point * GUESS_SIZE + hh,:] - kv.v[:,:,outputs.kvcache_len + hh,:] = kv.v[:,:,outputs.step_len-len(guess_tokens)+hit_point * GUESS_SIZE + hh,:] - - - past_key_values = [] - for idx, kv in enumerate(outputs.past_key_values): - for hh in range(max_hit): - assert outputs.step_len == kv[idx][0].size(2) - past_key_values.append( (kv[idx][0][:,:,:outputs.kvcache_len + max_hit,:], kv[idx][1][:,:,:outputs.kvcache_len + max_hit,:]) ) - outputs.past_key_values = past_key_values - + assert outputs.step_len == kv[0].size(2) + kv[0][:,:,outputs.kvcache_len + hh,:] = kv[0][:,:,outputs.step_len-len(guess_tokens)+hit_point * GUESS_SIZE + hh,:] + kv[1][:,:,outputs.kvcache_len + hh,:] = kv[1][:,:,outputs.step_len-len(guess_tokens)+hit_point * GUESS_SIZE + hh,:] + past_key_values.append( (kv[0][:,:,:outputs.kvcache_len + max_hit,:], kv[1][:,:,:outputs.kvcache_len + max_hit,:]) ) + outputs.past_key_values = past_key_values lst_token = hits[max_hit] - def sublist(lst1, lst2): - ls1 = [element for element in lst1 if element in lst2] - ls2 = [element for element in lst2 if element in lst1] - return ls1 == ls2 for hh in range(max_hit + 1): if eos_token_id is not None and hits[hh] == eos_token_id[0]: @@ -455,9 +434,6 @@ def sublist(lst1, lst2): max_hit = hh break else: - # - # - # all_old_tokens.append(hits[hh]) if chat: diff --git a/lade/models/llama.py b/lade/models/llama.py index 381bc43..b632d53 100644 --- a/lade/models/llama.py +++ b/lade/models/llama.py @@ -214,9 +214,9 @@ def LlamaModeljforward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None - for idx, decoder_layer in enumerate(self.layers): + for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -243,7 +243,7 @@ def LlamaModeljforward( hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -254,7 +254,9 @@ def LlamaModeljforward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( diff --git a/requirements.txt b/requirements.txt index fcc18e1..069d082 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -transformers==4.34.0 +transformers==4.36.2 accelerate==0.23.0 fschat==0.2.31 openai