Fix generation

artek0chumak committed Feb 27, 2024
1 parent cac4254 commit d15b46a
Showing 2 changed files with 17 additions and 4 deletions.
15 changes: 12 additions & 3 deletions src/petals/models/mixtral/block.py

@@ -26,7 +26,6 @@ def forward(
         use_cache: bool = False,
         **kwargs
     ):
-        print(self.layer_idx)
         batch_size, seq_length, _ = hidden_states.shape
 
         seq_length_with_past = seq_length
@@ -37,7 +36,6 @@ def forward(
             past_key_values_length = past_key_value[0].shape[2]
             seq_length_with_past = seq_length_with_past + past_key_values_length
             _past_key_value = self._reorder_cache_from_bloom(past_key_value, batch_size, past_key_values_length)
-            print(_past_key_value)
             # TODO: remove DynamicCache
             past_key_value = DynamicCache()
             for idx in range(self.layer_idx):
@@ -73,8 +71,19 @@ def forward(
                 sliding_window=self.sliding_window,
             )
 
+        position_ids = torch.arange(
+            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=hidden_states.device
+        )
+        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+
         outputs = super().forward(
-            hidden_states, *args, attention_mask=attention_mask, past_key_value=past_key_value, use_cache=use_cache, **kwargs
+            hidden_states,
+            *args,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            use_cache=use_cache,
+            **kwargs
         )
 
         if use_cache:
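Note: this hunk passes explicit position_ids to the parent forward so that position-dependent computations (e.g. rotary embeddings) continue from the cached offset instead of restarting at 0 on every decode step. A minimal sketch of the offset logic added above (make_position_ids is a hypothetical helper name, not part of the diff):

    import torch

    def make_position_ids(past_key_values_length: int, seq_length: int, device="cpu") -> torch.Tensor:
        # New tokens continue numbering from the end of the cached prefix.
        position_ids = torch.arange(
            past_key_values_length,
            seq_length + past_key_values_length,
            dtype=torch.long,
            device=device,
        )
        return position_ids.unsqueeze(0).view(-1, seq_length)

    print(make_position_ids(0, 4))  # prefill of 4 tokens: tensor([[0, 1, 2, 3]])
    print(make_position_ids(4, 1))  # next decode step:    tensor([[4]])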
6 changes: 5 additions & 1 deletion src/petals/models/mixtral/model.py

@@ -94,6 +94,10 @@ def forward(
         hidden_states = inputs_embeds
         output_shape = input_shape + (hidden_states.size(-1),)
 
+        if past_key_values is None:
+            past_key_values = RemotePastKeyValues()
+        past_key_values.update_seen(hidden_states.size(1))
+
         hidden_states = self.layers(
             hidden_states,
             prompts=intermediate_prompts,
@@ -109,7 +113,7 @@ def forward(
         hidden_states = hidden_states.view(output_shape)
         return MoeModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=RemotePastKeyValues(),
+            past_key_values=past_key_values,
             hidden_states=None,
             attentions=None,
         )
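Note: the model now reuses the caller's past_key_values (creating a RemotePastKeyValues only when none is passed) and advances its seen-token count, instead of returning a fresh empty cache on every call, which discarded generation state between steps. A minimal sketch of the bookkeeping this relies on (ToyPastKeyValues is a stand-in; only the update_seen call on the real RemotePastKeyValues is shown in the diff, its internals are assumed here):

    class ToyPastKeyValues:
        """Tracks how many tokens have already been processed remotely."""

        def __init__(self) -> None:
            self.seen_tokens = 0

        def update_seen(self, new_tokens: int) -> None:
            self.seen_tokens += new_tokens

    past = ToyPastKeyValues()
    past.update_seen(4)  # prefill with a 4-token prompt
    past.update_seen(1)  # one generated token on the next step
    assert past.seen_tokens == 5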
