Fixing MPT model issue for being outdated
erfanzar committed May 16, 2024
1 parent d60370b commit 0d7d977
Showing 8 changed files with 430 additions and 379 deletions.
23 changes: 21 additions & 2 deletions python_test/test_models.py
@@ -141,7 +141,7 @@ def create_test_for_models(
         partition_specs = match_partition_rules(config.get_partition_rules(True), params)
         shard, _ = make_shard_and_gather_fns(partition_specs, jnp.float32)
 
-        params = jax.tree_map(lambda p, f: f(p), params, shard)
+        params = jax.tree_util.tree_map(lambda p, f: f(p), params, shard)
         config.add_basic_configurations(
             attn_mechanism=self.attn_mechanism,
             block_k=self.block_k,
@@ -171,7 +171,8 @@ def create_test_for_models(
             params=params,
             return_dict=True,
             add_params_field=False,
-            train=False
+            train=False,
+            determinstic=True
         )
         loss, _ = cross_entropy_loss_and_accuracy(
             ed_output.logits,
@@ -293,6 +294,24 @@ def test_llama(self):
             f"Llama model Failed [ERROR {err}]"
         )
 
+    def test_mpt(self):
+        self.header_config = ed.MptConfig(
+            d_model=self.hidden_size,
+            n_heads=self.num_attention_heads,
+            n_layers=1,
+            ffn_config=ed.DbrxFFNConfig(
+                ffn_hidden_size=self.intermediate_size,
+                moe_top_k=self.num_experts_per_tok,
+                moe_num_experts=self.num_local_experts,
+            ),
+            attn_config=ed.MptAttentionConfig()
+        )
+        res, err = self.create_test_for_models("mpt", transformers.MptForCausalLM)
+        self.assertTrue(
+            res,
+            f"MPT model Failed [ERROR {err}]"
+        )
+
     def test_falcon(self):
         res, err = self.create_test_for_models("falcon", transformers.FalconForCausalLM)
         self.assertTrue(
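The jax.tree_map call in the first hunk was replaced because JAX deprecated the top-level alias in favor of jax.tree_util.tree_map, and the alias has since been removed in newer JAX releases. A minimal self-contained sketch of the pattern the test uses, with placeholder trees standing in for the real params and shard functions:

    import jax
    import jax.numpy as jnp

    # Placeholder stand-ins for the params tree and the tree of shard fns
    # returned by make_shard_and_gather_fns; both trees must match leaf-for-leaf.
    params = {"wte": jnp.ones((4, 8)), "norm_f": jnp.ones((8,))}
    shard = {"wte": lambda p: p, "norm_f": lambda p: p.astype(jnp.float32)}

    # tree_map zips the two trees and applies each leaf fn to its parameter.
    params = jax.tree_util.tree_map(lambda p, f: f(p), params, shard)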
2 changes: 2 additions & 0 deletions src/python/easydel/__init__.py
@@ -56,6 +56,7 @@

 from .modules.mosaic_mpt import (
     MptConfig as MptConfig,
+    MptAttentionConfig as MptAttentionConfig,
     FlaxMptForCausalLM as FlaxMptForCausalLM,
     FlaxMptModel as FlaxMptModel
 )
@@ -318,6 +319,7 @@

     # Mpt Models
     "MptConfig",
+    "MptAttentionConfig",
     "FlaxMptForCausalLM",
     "FlaxMptModel",
 
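With MptAttentionConfig re-exported at the package root, user code can configure MPT attention without reaching into submodules. A short usage sketch; the field values below are illustrative placeholders (other fields keep whatever defaults the library defines), taken from the kwargs shown in the test_mpt diff above:

    import easydel as ed

    # Construct an MPT config with an explicit attention sub-config.
    config = ed.MptConfig(
        d_model=256,
        n_heads=4,
        n_layers=1,
        attn_config=ed.MptAttentionConfig(),
    )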
3 changes: 2 additions & 1 deletion src/python/easydel/modules/__init__.py
@@ -25,6 +25,7 @@
     FlaxMptModel as FlaxMptModel,
     FlaxMptForCausalLM as FlaxMptForCausalLM,
     MptConfig as MptConfig,
+    MptAttentionConfig as MptAttentionConfig
 )
 from .falcon import (
     FlaxFalconModel as FlaxFalconModel,
@@ -166,7 +167,7 @@
 
     "FlaxLTModel", "FlaxLTForCausalLM", "FlaxLTConfig",
 
-    "FlaxMptModel", "FlaxMptForCausalLM", "MptConfig",
+    "FlaxMptModel", "FlaxMptForCausalLM", "MptConfig", "MptAttentionConfig",
 
     "FlaxFalconModel", "FlaxFalconForCausalLM", "FalconConfig",
 
7 changes: 5 additions & 2 deletions src/python/easydel/modules/auto_easydel_model.py
@@ -87,8 +87,11 @@ def get_modules_by_type(model_type: str) -> Tuple[
             _FlaxMptForCausalLM,
             functools.partial(
                 huggingface_to_easydel,
-                embedding_layer_names="wte",
-                rnn_based_or_rwkv=False
+                embedding_layer_names=["wte"],
+                rnn_based_or_rwkv=False,
+                layer_norm_names=[
+                    "norm_1", "norm_2", "norm_f"
+                ]
             )
         )

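Passing embedding_layer_names as a list and adding layer_norm_names gives the converter explicit hints for MPT's wte embedding and its norm_1/norm_2/norm_f layer norms when porting Hugging Face weights. A hypothetical illustration of this kind of name-based routing (not EasyDeL's actual converter, just the general idea behind such hint lists):

    # Hypothetical sketch: route each parameter by name hints during conversion.
    EMBEDDING_NAMES = ["wte"]
    LAYER_NORM_NAMES = ["norm_1", "norm_2", "norm_f"]

    def classify_param(name: str) -> str:
        if any(n in name for n in EMBEDDING_NAMES):
            return "embedding"   # e.g., may need transposing or untying
        if any(n in name for n in LAYER_NORM_NAMES):
            return "layer_norm"  # e.g., often kept in a wider dtype
        return "linear"

    print(classify_param("transformer.wte.weight"))              # embedding
    print(classify_param("transformer.blocks.0.norm_1.weight"))  # layer_norm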
12 changes: 6 additions & 6 deletions src/python/easydel/modules/llama/modelling_llama_flax.py
@@ -370,7 +370,6 @@ def __call__(
             causal_mask=causal_mask
         )
 
-
         attn_output = self._merge_heads(attentions.attention_outputs)
         if self.config.shard_attention_computation:
             attn_output = with_sharding_constraint(
@@ -1012,11 +1011,12 @@ class FlaxLlamaForCausalLMModule(nn.Module):
     precision: Optional[Union[jax.lax.Precision, str]] = None
 
     def setup(self):
-        self.model = FlaxLlamaModule(self.config,
-                                     dtype=self.dtype,
-                                     param_dtype=self.param_dtype,
-                                     precision=self.precision,
-                                     )
+        self.model = FlaxLlamaModule(
+            self.config,
+            dtype=self.dtype,
+            param_dtype=self.param_dtype,
+            precision=self.precision,
+        )
 
         self.lm_head = Linear(
             self.config.vocab_size,
7 changes: 5 additions & 2 deletions src/python/easydel/modules/mosaic_mpt/__init__.py
@@ -1,9 +1,12 @@
-from .mosaic_configuration import MptConfig
+from .mosaic_configuration import (
+    MptConfig as MptConfig,
+    MptAttentionConfig as MptAttentionConfig
+)
 from .modelling_mpt_flax import (
     FlaxMptForCausalLM,
     FlaxMptForCausalLMModule,
     FlaxMptModel,
     FlaxMptModule
 )
 
-__all__ = "FlaxMptModel", "FlaxMptForCausalLM", "MptConfig"
+__all__ = "FlaxMptModel", "FlaxMptForCausalLM", "MptConfig", "MptAttentionConfig"
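Since the subpackage and the top-level package now export the same names, both import paths should resolve to the same class. A quick consistency check, assuming the package layout shown in this commit:

    from easydel import MptAttentionConfig
    from easydel.modules.mosaic_mpt import MptAttentionConfig as SubpackageConfig

    # Re-exports should be the identical object, not a copy.
    assert MptAttentionConfig is SubpackageConfig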