From 574f8ebde44ba1cad52bc7ba3bb4ff5c693f8f79 Mon Sep 17 00:00:00 2001
From: Artem Chumachenko
Date: Tue, 27 Feb 2024 15:08:44 +0400
Subject: [PATCH] Choose right attn

---
 src/petals/models/mixtral/block.py   | 2 --
 src/petals/models/mixtral/model.py   | 2 +-
 src/petals/server/from_pretrained.py | 6 ++----
 3 files changed, 3 insertions(+), 7 deletions(-)

diff --git a/src/petals/models/mixtral/block.py b/src/petals/models/mixtral/block.py
index 5949b1584..8dee322cc 100644
--- a/src/petals/models/mixtral/block.py
+++ b/src/petals/models/mixtral/block.py
@@ -36,13 +36,11 @@ def forward(
             past_key_values_length = past_key_value[0].shape[2]
             seq_length_with_past = seq_length_with_past + past_key_values_length
             _past_key_value = self._reorder_cache_from_bloom(past_key_value, batch_size, past_key_values_length)
-            # TODO: remove DynamicCache
             past_key_value = DynamicCache()
             for idx in range(self.layer_idx):
                 past_key_value.update(torch.empty(_past_key_value[0].size()), torch.empty(_past_key_value[1].size()), idx)
             past_key_value.update(_past_key_value[0], _past_key_value[1], self.layer_idx)
 
-        # TODO: make sure it's working
         if self._attn_implementation == "flash_attention_2":
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
diff --git a/src/petals/models/mixtral/model.py b/src/petals/models/mixtral/model.py
index 92571f0c9..13a8b32e7 100644
--- a/src/petals/models/mixtral/model.py
+++ b/src/petals/models/mixtral/model.py
@@ -78,7 +78,7 @@ def forward(
         assert not output_attentions, f"{output_attentions=} is not supported"
         assert not output_hidden_states, f"{output_hidden_states=} is not supported"
         assert return_dict is None or return_dict, f"{return_dict=} is not supported"
-        assert not output_router_logits, f"{output_router_logits=} is not supported"  # TODO: check this
+        assert not output_router_logits, f"{output_router_logits=} is not supported"
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
diff --git a/src/petals/server/from_pretrained.py b/src/petals/server/from_pretrained.py
index bab17a3f1..95cfbd824 100644
--- a/src/petals/server/from_pretrained.py
+++ b/src/petals/server/from_pretrained.py
@@ -19,7 +19,7 @@
 from hivemind.utils.logging import get_logger
 from huggingface_hub import get_hf_file_metadata, hf_hub_url
 from huggingface_hub.utils import EntryNotFoundError
-from transformers import PretrainedConfig
+from transformers import PretrainedConfig, PreTrainedModel
 from transformers.utils import get_file_from_repo
 
 from petals.constants import DTYPE_MAP
@@ -52,10 +52,8 @@ def load_pretrained_block(
     torch_dtype = resolve_block_dtype(config, torch_dtype)
 
     with init_empty_weights():
-        # TODO: Remove this
         if config.block_class == WrappedMixtralBlock:
-            # TODO: figure out why sdpa is always choosen
-            config._attn_implementation = "sdpa"
+            config = PreTrainedModel._autoset_attn_implementation(config)
             block = config.block_class(config, block_index)
         else:
             block = config.block_class(config)
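
A minimal sketch of what the new selection path does, shown outside the patch for reference. It
assumes a transformers release that ships the private classmethod
PreTrainedModel._autoset_attn_implementation (the same call the last hunk adds); the MixtralConfig
instance below is an illustrative stand-in for the block config that load_pretrained_block()
actually receives.

    from transformers import MixtralConfig, PreTrainedModel

    config = MixtralConfig()  # hypothetical stand-in for the server's real block config

    # Before this patch the server pinned the backend unconditionally:
    #     config._attn_implementation = "sdpa"
    # After it, transformers runs its usual backend-selection checks and picks
    # flash_attention_2, sdpa, or eager depending on what is requested, installed, and supported:
    config = PreTrainedModel._autoset_attn_implementation(config)

    print(config._attn_implementation)  # whichever backend was resolved in this environment

Deferring to _autoset_attn_implementation keeps the server's choice consistent with the checks
transformers itself performs, instead of hard-coding "sdpa" as the removed TODO complained about.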