
Commit 2a0bde4

Fixes
artek0chumak committed Feb 27, 2024
1 parent d15b46a commit 2a0bde4
Showing 2 changed files with 2 additions and 6 deletions.
src/petals/models/mixtral/block.py (0 additions, 6 deletions)

@@ -42,12 +42,6 @@ def forward(
                 past_key_value.update(torch.empty(_past_key_value[0].size()), torch.empty(_past_key_value[1].size()), idx)
             past_key_value.update(_past_key_value[0], _past_key_value[1], self.layer_idx)

-        # embed positions
-        if attention_mask is None:
-            attention_mask = torch.ones(
-                (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
-            )
-
         # TODO: make sure it's working
         if self._attn_implementation == "flash_attention_2":
             # 2d mask is passed through the layers
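Aside (not part of the commit): the deleted block filled in an all-ones attention_mask whenever the caller passed None. Below is a minimal sketch of why None can simply be passed through, assuming the layer delegates mask construction to transformers' mask utilities; the _prepare_4d_* helpers are real transformers functions, while build_mask and its wiring are illustrative only, not Petals code.

import torch
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)


def build_mask(attention_mask, hidden_states, past_key_values_length, attn_implementation):
    """Illustrative helper: every branch accepts attention_mask=None."""
    batch_size, seq_length, _ = hidden_states.shape
    if attn_implementation == "flash_attention_2":
        # Flash attention keeps the 2D mask; None already means "no padding anywhere".
        return attention_mask if (attention_mask is not None and 0 in attention_mask) else None
    if attn_implementation == "sdpa":
        # Returns a 4D causal mask, or None when SDPA can infer causality itself.
        return _prepare_4d_causal_attention_mask_for_sdpa(
            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
        )
    # Eager attention: build the full 4D causal mask from the (possibly None) 2D mask.
    return _prepare_4d_causal_attention_mask(
        attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
    )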
src/petals/server/from_pretrained.py (2 additions, 0 deletions)

@@ -54,6 +54,8 @@ def load_pretrained_block(
     with init_empty_weights():
         # TODO: Remove this
         if config.block_class == WrappedMixtralBlock:
+            # TODO: figure out why sdpa is always chosen
+            config._attn_implementation = "sdpa"
             block = config.block_class(config, block_index)
         else:
             block = config.block_class(config)
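For context, a standalone sketch of what the two added lines accomplish: pin the attention backend on the config before instantiating a Mixtral decoder layer on the meta device. MixtralConfig, MixtralDecoderLayer, and init_empty_weights are real transformers/accelerate APIs; using the plain decoder layer instead of Petals' WrappedMixtralBlock is an assumption made to keep the example self-contained.

from accelerate import init_empty_weights
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer

config = MixtralConfig()  # stand-in for the config built by Petals' loader
with init_empty_weights():
    # Pin PyTorch's scaled_dot_product_attention backend instead of letting
    # transformers auto-select one; this mirrors the two lines added above.
    config._attn_implementation = "sdpa"
    block = MixtralDecoderLayer(config, layer_idx=0)  # parameters live on the meta device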
