Commit b731ea9: pattern
Signed-off-by: Mayank Mishra <[email protected]>
mayank31398 committed Sep 26, 2024
1 parent 22b0f0e commit b731ea9
Showing 3 changed files with 5 additions and 7 deletions.
dolomite_engine/hf_models/models/rnn_dolomite/base.py (3 changes: 1 addition & 2 deletions)
@@ -20,7 +20,6 @@ class RNNDolomitePreTrainedModel(PreTrainedModelMixin):
     config_class = RNNDolomiteConfig
     layer_class = RNNDolomiteBlock
     _no_split_modules = ["RNNDolomiteBlock"]
-    _supports_sdpa = False

     def __init__(self, config: RNNDolomiteConfig, *args, **kwargs):
         super().__init__(config, *args, **kwargs)
@@ -52,7 +51,7 @@ def _init_model(self, config: RNNDolomiteConfig, **kwargs) -> None:
                 self.layer_class(
                     config,
                     normalization_implementation=self.normalization_implementation,
-                    attention_implementation=self.attention_pattern[i],
+                    attention_pattern=self.attention_pattern[i],
                     use_padding_free_transformer=self._use_padding_free_transformer,
                     layer_idx=i,
                 )
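The rename in base.py lines the keyword up with the parameter that RNNDolomiteBlock now declares: the model indexes into self.attention_pattern and hands each layer its own mixer name. Below is a minimal, self-contained sketch of that idea; the names (ToyBlock, build_blocks) are illustrative stand-ins, not the dolomite_engine API.

```python
# Minimal sketch, assuming a per-layer pattern list such as
# ["DeltaNet", "DeltaNet", "flash_attention_2"]. ToyBlock and build_blocks
# are illustrative stand-ins, not the dolomite_engine classes.
from dataclasses import dataclass


@dataclass
class ToyBlock:
    attention_pattern: str  # which sequence mixer this layer should use
    layer_idx: int


def build_blocks(attention_pattern: list[str]) -> list[ToyBlock]:
    # One block per layer; each block receives its own entry from the pattern.
    return [
        ToyBlock(attention_pattern=pattern, layer_idx=i)
        for i, pattern in enumerate(attention_pattern)
    ]


if __name__ == "__main__":
    for block in build_blocks(["DeltaNet", "DeltaNet", "flash_attention_2"]):
        print(block.layer_idx, block.attention_pattern)
```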
dolomite_engine/hf_models/models/rnn_dolomite/layer.py (8 changes: 4 additions & 4 deletions)
@@ -17,7 +17,7 @@ def __init__(
         self,
         config: RNNDolomiteConfig,
         normalization_implementation: str,
-        attention_implementation: str,
+        attention_pattern: str,
         use_padding_free_transformer: bool,
         layer_idx: int | None = None,
     ) -> None:
@@ -38,12 +38,12 @@ def __init__(
             normalization_implementation=normalization_implementation,
         )

-        if attention_implementation == "DeltaNet":
+        if attention_pattern == "DeltaNet":
             self.attn = DeltaNet(config=config, layer_idx=layer_idx)
-        elif attention_implementation == "flash_attention_2":
+        elif attention_pattern == "flash_attention_2":
             self.attn = RNNFlashAttention2(config, True, layer_idx)
         else:
-            raise ValueError(f"Attention implementation {attention_implementation} not supported.")
+            raise ValueError(f"Attention pattern {attention_pattern} not supported.")

         self.ln_2 = get_normalization_function(
             config.normalization_function,
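Inside the block, the renamed attention_pattern string picks the sequence mixer: "DeltaNet" builds a DeltaNet layer, "flash_attention_2" builds RNNFlashAttention2, and any other value raises a ValueError. One plausible way such a per-layer list could be spelled compactly and expanded before reaching the block is sketched below; this is a hypothetical helper, not necessarily how dolomite_engine derives self.attention_pattern.

```python
# Hypothetical helper: expand a compact pattern string into the per-layer
# mixer names consumed by the block above. Illustration only; dolomite_engine
# may derive `self.attention_pattern` differently.
_PATTERN_CODES = {
    "d": "DeltaNet",
    "f": "flash_attention_2",
}


def expand_attention_pattern(pattern: str) -> list[str]:
    """Map e.g. "dddf" to ["DeltaNet", "DeltaNet", "DeltaNet", "flash_attention_2"]."""
    names = []
    for code in pattern.lower():
        if code not in _PATTERN_CODES:
            raise ValueError(f"Attention pattern {code} not supported.")
        names.append(_PATTERN_CODES[code])
    return names


if __name__ == "__main__":
    print(expand_attention_pattern("dddf"))
```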
Third changed file (path not shown): 1 change, 0 additions & 1 deletion
@@ -50,7 +50,6 @@ model_args:
   init_method: mup
   tie_word_embeddings: true
   upcast_logits_for_loss: true
-  attention_implementation: flash_attention_2

 tuning_args:
   tuning_method: pretraining
