From d15b46ae8f2014160a62558469dd974e7a90fc28 Mon Sep 17 00:00:00 2001
From: Artem Chumachenko
Date: Tue, 6 Feb 2024 13:59:44 +0400
Subject: [PATCH] Fix generation

---
 src/petals/models/mixtral/block.py | 15 ++++++++++++---
 src/petals/models/mixtral/model.py |  6 +++++-
 2 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/src/petals/models/mixtral/block.py b/src/petals/models/mixtral/block.py
index 7afd5ddd0..67eebce6a 100644
--- a/src/petals/models/mixtral/block.py
+++ b/src/petals/models/mixtral/block.py
@@ -26,7 +26,6 @@ def forward(
         use_cache: bool = False,
         **kwargs
     ):
-        print(self.layer_idx)
         batch_size, seq_length, _ = hidden_states.shape
 
         seq_length_with_past = seq_length
@@ -37,7 +36,6 @@ def forward(
             past_key_values_length = past_key_value[0].shape[2]
             seq_length_with_past = seq_length_with_past + past_key_values_length
             _past_key_value = self._reorder_cache_from_bloom(past_key_value, batch_size, past_key_values_length)
-            print(_past_key_value)
             # TODO: remove DynamicCache
             past_key_value = DynamicCache()
             for idx in range(self.layer_idx):
@@ -73,8 +71,19 @@ def forward(
             sliding_window=self.sliding_window,
         )
 
+        position_ids = torch.arange(
+            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=hidden_states.device
+        )
+        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+
         outputs = super().forward(
-            hidden_states, *args, attention_mask=attention_mask, past_key_value=past_key_value, use_cache=use_cache, **kwargs
+            hidden_states,
+            *args,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            use_cache=use_cache,
+            **kwargs
         )
 
         if use_cache:
diff --git a/src/petals/models/mixtral/model.py b/src/petals/models/mixtral/model.py
index e1a1675c9..92571f0c9 100644
--- a/src/petals/models/mixtral/model.py
+++ b/src/petals/models/mixtral/model.py
@@ -94,6 +94,10 @@ def forward(
         hidden_states = inputs_embeds
         output_shape = input_shape + (hidden_states.size(-1),)
 
+        if past_key_values is None:
+            past_key_values = RemotePastKeyValues()
+        past_key_values.update_seen(hidden_states.size(1))
+
         hidden_states = self.layers(
             hidden_states,
             prompts=intermediate_prompts,
@@ -109,7 +113,7 @@
         hidden_states = hidden_states.view(output_shape)
         return MoeModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=RemotePastKeyValues(),
+            past_key_values=past_key_values,
             hidden_states=None,
             attentions=None,
         )