Choose right attn
artek0chumak committed Feb 27, 2024
1 parent 2a0bde4 commit 574f8eb
Showing 3 changed files with 3 additions and 7 deletions.
src/petals/models/mixtral/block.py (2 changes: 0 additions & 2 deletions)
@@ -36,13 +36,11 @@ def forward(
             past_key_values_length = past_key_value[0].shape[2]
             seq_length_with_past = seq_length_with_past + past_key_values_length
             _past_key_value = self._reorder_cache_from_bloom(past_key_value, batch_size, past_key_values_length)
-            # TODO: remove DynamicCache
             past_key_value = DynamicCache()
             for idx in range(self.layer_idx):
                 past_key_value.update(torch.empty(_past_key_value[0].size()), torch.empty(_past_key_value[1].size()), idx)
             past_key_value.update(_past_key_value[0], _past_key_value[1], self.layer_idx)
 
-        # TODO: make sure it's working
         if self._attn_implementation == "flash_attention_2":
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
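Note (added for context, not part of the commit): the DynamicCache kept above stores key/value tensors per layer index, so the loop first fills indices 0..layer_idx-1 with placeholder tensors and then puts this block's real tensors at self.layer_idx, where the Hugging Face attention layer will look them up. A minimal standalone sketch of that pattern, with assumed shapes and layer index:

# Sketch only; shapes and the layer index are assumptions, not values from the commit.
import torch
from transformers import DynamicCache

layer_idx = 2                                   # assumed block index
key = torch.randn(1, 8, 16, 64)                 # assumed (batch, heads, seq_len, head_dim)
value = torch.randn(1, 8, 16, 64)

cache = DynamicCache()
for idx in range(layer_idx):
    # placeholder entries so the real tensors end up at index `layer_idx`
    cache.update(torch.empty(key.size()), torch.empty(value.size()), idx)
cache.update(key, value, layer_idx)

assert cache.key_cache[layer_idx].shape == key.shape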
src/petals/models/mixtral/model.py (2 changes: 1 addition & 1 deletion)
@@ -78,7 +78,7 @@ def forward(
         assert not output_attentions, f"{output_attentions=} is not supported"
         assert not output_hidden_states, f"{output_hidden_states=} is not supported"
         assert return_dict is None or return_dict, f"{return_dict=} is not supported"
-        assert not output_router_logits, f"{output_router_logits=} is not supported"  # TODO: check this
+        assert not output_router_logits, f"{output_router_logits=} is not supported"
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
src/petals/server/from_pretrained.py (6 changes: 2 additions & 4 deletions)
@@ -19,7 +19,7 @@
 from hivemind.utils.logging import get_logger
 from huggingface_hub import get_hf_file_metadata, hf_hub_url
 from huggingface_hub.utils import EntryNotFoundError
-from transformers import PretrainedConfig
+from transformers import PretrainedConfig, PreTrainedModel
 from transformers.utils import get_file_from_repo
 
 from petals.constants import DTYPE_MAP
@@ -52,10 +52,8 @@ def load_pretrained_block(
     torch_dtype = resolve_block_dtype(config, torch_dtype)
 
     with init_empty_weights():
-        # TODO: Remove this
         if config.block_class == WrappedMixtralBlock:
-            # TODO: figure out why sdpa is always choosen
-            config._attn_implementation = "sdpa"
+            config = PreTrainedModel._autoset_attn_implementation(config)
             block = config.block_class(config, block_index)
         else:
             block = config.block_class(config)
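Note (added for context, not part of the commit): the replaced line used to force config._attn_implementation = "sdpa" for every Mixtral block; the new call hands that decision to transformers' own attention-dispatch helper, which validates any requested backend and otherwise auto-selects one for the current config and environment. A quick standalone way to see what it resolves to; the model id below is an assumption:

# Illustration only; the model id is an assumption, and _autoset_attn_implementation
# is a private transformers helper, called here exactly as in the diff above.
from transformers import AutoConfig, PreTrainedModel

config = AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
config = PreTrainedModel._autoset_attn_implementation(config)
print(config._attn_implementation)  # one of "eager", "sdpa", or "flash_attention_2"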
