From 0d7d97703086b3c642487d5cba4b0d99ab01b697 Mon Sep 17 00:00:00 2001 From: erfanzar Date: Thu, 16 May 2024 16:00:53 +0330 Subject: [PATCH] Fixing `MPT` model issue for being outdated --- python_test/test_models.py | 23 +- src/python/easydel/__init__.py | 2 + src/python/easydel/modules/__init__.py | 3 +- .../easydel/modules/auto_easydel_model.py | 7 +- .../modules/llama/modelling_llama_flax.py | 12 +- .../easydel/modules/mosaic_mpt/__init__.py | 7 +- .../modules/mosaic_mpt/modelling_mpt_flax.py | 614 +++++++++--------- .../mosaic_mpt/mosaic_configuration.py | 141 ++-- 8 files changed, 430 insertions(+), 379 deletions(-) diff --git a/python_test/test_models.py b/python_test/test_models.py index 5c5388f39..d55a4db8c 100644 --- a/python_test/test_models.py +++ b/python_test/test_models.py @@ -141,7 +141,7 @@ def create_test_for_models( partition_specs = match_partition_rules(config.get_partition_rules(True), params) shard, _ = make_shard_and_gather_fns(partition_specs, jnp.float32) - params = jax.tree_map(lambda p, f: f(p), params, shard) + params = jax.tree_util.tree_map(lambda p, f: f(p), params, shard) config.add_basic_configurations( attn_mechanism=self.attn_mechanism, block_k=self.block_k, @@ -171,7 +171,8 @@ def create_test_for_models( params=params, return_dict=True, add_params_field=False, - train=False + train=False, + determinstic=True ) loss, _ = cross_entropy_loss_and_accuracy( ed_output.logits, @@ -293,6 +294,24 @@ def test_llama(self): f"Llama model Failed [ERROR {err}]" ) + def test_mpt(self): + self.header_config = ed.MptConfig( + d_model=self.hidden_size, + n_heads=self.num_attention_heads, + n_layers=1, + ffn_config=ed.DbrxFFNConfig( + ffn_hidden_size=self.intermediate_size, + moe_top_k=self.num_experts_per_tok, + moe_num_experts=self.num_local_experts, + ), + attn_config=ed.MptAttentionConfig() + ) + res, err = self.create_test_for_models("mpt", transformers.MptForCausalLM) + self.assertTrue( + res, + f"MPT model Failed [ERROR {err}]" + ) + def test_falcon(self): res, err = self.create_test_for_models("falcon", transformers.FalconForCausalLM) self.assertTrue( diff --git a/src/python/easydel/__init__.py b/src/python/easydel/__init__.py index a07ffc409..b28f701f1 100644 --- a/src/python/easydel/__init__.py +++ b/src/python/easydel/__init__.py @@ -56,6 +56,7 @@ from .modules.mosaic_mpt import ( MptConfig as MptConfig, + MptAttentionConfig as MptAttentionConfig, FlaxMptForCausalLM as FlaxMptForCausalLM, FlaxMptModel as FlaxMptModel ) @@ -318,6 +319,7 @@ # Mpt Models "MptConfig", + "MptAttentionConfig", "FlaxMptForCausalLM", "FlaxMptModel", diff --git a/src/python/easydel/modules/__init__.py b/src/python/easydel/modules/__init__.py index 680dde8c7..66c904f05 100644 --- a/src/python/easydel/modules/__init__.py +++ b/src/python/easydel/modules/__init__.py @@ -25,6 +25,7 @@ FlaxMptModel as FlaxMptModel, FlaxMptForCausalLM as FlaxMptForCausalLM, MptConfig as MptConfig, + MptAttentionConfig as MptAttentionConfig ) from .falcon import ( FlaxFalconModel as FlaxFalconModel, @@ -166,7 +167,7 @@ "FlaxLTModel", "FlaxLTForCausalLM", "FlaxLTConfig", - "FlaxMptModel", "FlaxMptForCausalLM", "MptConfig", + "FlaxMptModel", "FlaxMptForCausalLM", "MptConfig", "MptAttentionConfig", "FlaxFalconModel", "FlaxFalconForCausalLM", "FalconConfig", diff --git a/src/python/easydel/modules/auto_easydel_model.py b/src/python/easydel/modules/auto_easydel_model.py index bc9fb1f91..a849ddb27 100644 --- a/src/python/easydel/modules/auto_easydel_model.py +++ b/src/python/easydel/modules/auto_easydel_model.py @@
-87,8 +87,11 @@ def get_modules_by_type(model_type: str) -> Tuple[ _FlaxMptForCausalLM, functools.partial( huggingface_to_easydel, - embedding_layer_names="wte", - rnn_based_or_rwkv=False + embedding_layer_names=["wte"], + rnn_based_or_rwkv=False, + layer_norm_names=[ + "norm_1", "norm_2","norm_f" + ] ) ) diff --git a/src/python/easydel/modules/llama/modelling_llama_flax.py b/src/python/easydel/modules/llama/modelling_llama_flax.py index 72d495c53..656108daf 100644 --- a/src/python/easydel/modules/llama/modelling_llama_flax.py +++ b/src/python/easydel/modules/llama/modelling_llama_flax.py @@ -370,7 +370,6 @@ def __call__( causal_mask=causal_mask ) - attn_output = self._merge_heads(attentions.attention_outputs) if self.config.shard_attention_computation: attn_output = with_sharding_constraint( @@ -1012,11 +1011,12 @@ class FlaxLlamaForCausalLMModule(nn.Module): precision: Optional[Union[jax.lax.Precision, str]] = None def setup(self): - self.model = FlaxLlamaModule(self.config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - ) + self.model = FlaxLlamaModule( + self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + ) self.lm_head = Linear( self.config.vocab_size, diff --git a/src/python/easydel/modules/mosaic_mpt/__init__.py b/src/python/easydel/modules/mosaic_mpt/__init__.py index 7c3dc31fd..8405b8252 100644 --- a/src/python/easydel/modules/mosaic_mpt/__init__.py +++ b/src/python/easydel/modules/mosaic_mpt/__init__.py @@ -1,4 +1,7 @@ -from .mosaic_configuration import MptConfig +from .mosaic_configuration import ( + MptConfig as MptConfig, + MptAttentionConfig as MptAttentionConfig +) from .modelling_mpt_flax import ( FlaxMptForCausalLM, FlaxMptForCausalLMModule, @@ -6,4 +9,4 @@ FlaxMptModule ) -__all__ = "FlaxMptModel", "FlaxMptForCausalLM", "MptConfig" \ No newline at end of file +__all__ = "FlaxMptModel", "FlaxMptForCausalLM", "MptConfig", "MptAttentionConfig" diff --git a/src/python/easydel/modules/mosaic_mpt/modelling_mpt_flax.py b/src/python/easydel/modules/mosaic_mpt/modelling_mpt_flax.py index 7eae8ab5d..960d854f1 100644 --- a/src/python/easydel/modules/mosaic_mpt/modelling_mpt_flax.py +++ b/src/python/easydel/modules/mosaic_mpt/modelling_mpt_flax.py @@ -1,26 +1,24 @@ import math -from flax import linen as nn from flax.core import FrozenDict from typing import Optional, Union, Tuple - -from jax import numpy as jnp +from flax.linen import combine_masks +from jax import numpy as jnp, lax import jax -from jax.sharding import PartitionSpec from transformers.modeling_flax_outputs import FlaxCausalLMOutput, FlaxBaseModelOutput import flax from einops import rearrange from flax.linen.partitioning import remat -from ..attention_module import AttentionModule from ..flax_modelling_utils import ( get_gradient_checkpoint_policy, with_sharding_constraint, ACT2FN, - get_dot_general_by_bits, BaseJAXAttentionModule, block_wise_ffn + get_dot_general_by_bits, + BaseJAXAttentionModule, ) from ..easydel_modelling_utils import EasyDeLFlaxPretrainedModel import chex from fjformer.linen import Linear - +from fjformer import linen as nn from .mosaic_configuration import MptConfig @@ -31,12 +29,7 @@ class RMSNorm(nn.Module): param_dtype: jnp.dtype = jnp.float32 def setup(self) -> None: - self.weight = self.param( - 'kernel', - nn.initializers.ones, - (self.dim,), - self.param_dtype, - ) + self.weight = self.param("kernel", nn.initializers.ones, (self.dim,), self.param_dtype, ) def _norm(self, x: jnp.ndarray) -> jnp.ndarray: return x * 
jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.eps) @@ -44,7 +37,6 @@ def _norm(self, x: jnp.ndarray) -> jnp.ndarray: def __call__(self, x: jnp.ndarray) -> jnp.ndarray: x = x.astype(jnp.promote_types(self.dtype, jnp.bfloat16)) output = self._norm(x).astype(self.dtype) - weight = nn.linen.control_quantization(self.weight, self.dtype) return output * weight @@ -56,32 +48,39 @@ class FlaxMptMLP(nn.Module): precision: Optional[Union[jax.lax.Precision, str]] = None def setup(self) -> None: - self.up = Linear( - self.config.d_model * self.config.expansion_ratio, - kernel_init=jax.nn.initializers.normal(), + self.up_proj = Linear( + self.config.expansion_ratio * self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), use_bias=self.config.use_bias, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision, **get_dot_general_by_bits(self.config.bits, self.config.easy_method) ) - self.down = Linear( - self.config.d_model, - kernel_init=jax.nn.initializers.normal(), + self.down_proj = Linear( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), use_bias=self.config.use_bias, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision, **get_dot_general_by_bits(self.config.bits, self.config.easy_method) ) - self.act = ACT2FN[self.config.act_fn] + self.hidden_dropout = nn.Dropout(self.config.attn_config.attn_pdrop) def __call__( - self, - hidden_states: chex.Array, - e: bool = True # Ignored + self, hidden_states: chex.Array, residual: chex.Array, deterministic: bool = False ): - return self.down(self.act(self.up(hidden_states))) + return self.hidden_dropout( + self.down_proj( + jax.nn.gelu( + self.up_proj( + hidden_states + ), approximate=False + ) + ), + deterministic=deterministic + ) + residual class FlaxMptAttention(BaseJAXAttentionModule): @@ -92,75 +91,42 @@ class FlaxMptAttention(BaseJAXAttentionModule): def setup(self) -> None: - self.w_qkv = Linear( - self.config.d_model * 3, - kernel_init=jax.nn.initializers.normal(), + self.Wqkv = Linear( + self.config.hidden_size * 3, + kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), use_bias=self.config.use_bias, **get_dot_general_by_bits(self.config.bits, self.config.easy_method), dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision) - self.wo = Linear( - self.config.d_model, - kernel_init=jax.nn.initializers.normal(), + self.out_proj = Linear( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), use_bias=self.config.use_bias, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision, **get_dot_general_by_bits(self.config.bits, self.config.easy_method) ) - self.attention_performer = AttentionModule( - use_sharding_constraint=self.config.use_sharding_constraint, - block_k_major=self.config.block_k_major, - block_b=self.config.block_b, - block_q=self.config.block_q, - block_k=self.config.block_k, - block_q_major_dkv=self.config.block_q_major_dkv, - block_k_major_dkv=self.config.block_k_major_dkv, - block_k_major_dq=self.config.block_k_major_dq, - block_k_dkv=self.config.block_k_dkv, - block_q_dkv=self.config.block_q_dkv, - block_q_dq=self.config.block_q_dq, - block_k_dq=self.config.block_k_dq, - num_attention_heads=self.config.num_attention_heads, - attention_dropout=self.config.attention_dropout, - head_dims=self.head_dim, - attention_partition_spec=self.config.attention_partition_spec, - 
shard_attention_computation=self.config.shard_attention_computation, - precision=self.precision, - force_float32_tpu=True, - attn_mechanism=self.config.attn_mechanism, - dtype=self.dtype, - bias_partition_spec=self.config.bias_partition_spec, - key_partition_spec=self.config.key_partition_spec, - query_partition_spec=self.config.query_partition_spec, - generation_query_partition_spec=self.config.generation_query_partition_spec, - generation_bias_partition_spec=self.config.generation_bias_partition_spec, - generation_attention_partition_spec=self.config.generation_attention_partition_spec, - value_partition_spec=self.config.value_partition_spec, - scan_ring_attention=self.config.scan_ring_attention, - mesh=self.config.jax_mesh(), - sm_scale=1 / math.sqrt(self.config.n_heads), - axis_name=self.config.attention_axis_name, - backward_pass_impl=self.config.flash_attention_backward_pass_impl - ) - if self.config.qk_ln: - self.q_ln = nn.LayerNorm(use_bias=self.config.use_norm_bias) - self.k_ln = nn.LayerNorm(use_bias=self.config.use_norm_bias) - self.causal_mask = flax.linen.make_causal_mask( - jnp.ones( - (1, self.config.max_seq_len), - dtype="bool" - ), dtype="bool" - ) + self.dropout = nn.Dropout(self.config.attn_config.attn_pdrop) - def __call__(self, - hidden_states: chex.Array, - attention_mask: chex.Array, - position_ids: chex.Array, - attn_bias: chex.Array = None, - init_cache: bool = False - ): + self.hidden_size = self.config.hidden_size + self.n_heads = self.config.n_heads + self.max_seq_length = self.config.max_seq_len + self.head_dim = self.hidden_size // self.n_heads + self.softmax_scale = self.config.attn_config.softmax_scale + if self.softmax_scale is None: + self.softmax_scale = 1 / math.sqrt(self.hidden_size / self.n_heads) + + def __call__( + self, + hidden_states: chex.Array, + attention_mask: chex.Array, + position_bias: chex.Array, + causal_mask: chex.Array, + init_cache: bool = False, + deterministic: bool = False + ): """ The __call__ function is the main function of a JAX module. 
@@ -171,65 +137,87 @@ def __call__(self, :param self: Access variables that belong to the class :param hidden_states: chex.Array: Pass the input to the attention layer :param attention_mask: chex.Array: Mask out certain positions in the sequence - :param position_ids: chex.Array: Specify the position of each token in the sequence - :param attn_bias: chex.Array: Add a bias to the attention scores + :param position_bias: chex.Array: Add a bias to the attention scores + :param causal_mask: chex.Array: Mask out certain positions in the sequence :param init_cache: bool: Initialize the cache + :param deterministic: bool: deterministic to activate dropouts and detect training process :return: The output of the attention layer """ inp_shape = hidden_states.shape - b, s, ds = inp_shape - qkv = self.w_qkv(hidden_states) - q, k, v = jnp.split(qkv, 3, -1) - if self.config.qk_ln: - q = self.q_ln(q) - k = self.k_ln(k) - - q = rearrange(q, 'b s (h d) -> b s h d', h=self.config.n_heads) - k = rearrange(k, 'b s (h d) -> b s h d', h=self.config.n_heads) - v = rearrange(v, 'b s (h d) -> b s h d', h=self.config.n_heads) - attention_mask = attention_mask.reshape(b, 1, 1, -1) - if self.has_variable('cache', 'key_states') or init_cache: - k, v, attention_mask = self._concatenate_to_cache(key_states=k, value=v, query=q, - attention_mask=attention_mask) - # TODO: MPT WONT WORK CAUSE OF NEW ATTENTION MEC ON FJFORMER - - # if self.config.use_sharding_constraint: - # q = with_sharding_constraint( - # q, jax.sharding.PartitionSpec(("dp", "fsdp"), "sp" if q.shape[1] != 1 else None, "tp",None) - # ) - # k = with_sharding_constraint(k, jax.sharding.PartitionSpec(("dp", "fsdp"), "sp", "tp", None)) - # v = with_sharding_constraint(v, jax.sharding.PartitionSpec(("dp", "fsdp"), "sp", "tp",None)) - q_l = q.shape[1] - k_l = k.shape[1] - dropout_rng = None - deterministic = False - if deterministic: - dropout_rng = self.make_rng("dropout") - - d = q.shape[-1] - attn_output = jnp.einsum('...qhd,...khd->...hqk', q, k, precision=self.precision) * jax.lax.rsqrt( - jnp.asarray(d).astype(v.dtype)) - attn_output = with_sharding_constraint(attn_output, PartitionSpec( - ("dp", "fsdp"), - "sp" if attn_output.shape[1] != 1 else None, - None, - None) - ) - if attn_bias is not None: - attn_output += attn_bias[:, :, :, :attn_output.shape[-1]] - mask = jnp.where(self.causal_mask == 1, 0, jnp.finfo(attn_output).min) - if attention_mask is not None: - attention_mask = jnp.where( - attention_mask == 1, - 0, - jnp.finfo(attn_output).min + mixed_qkv = self.Wqkv(hidden_states) + query_states, key_states, value_states = jnp.split(mixed_qkv, 3, -1) + + query_states = rearrange(query_states, "b s (h d) -> b s h d", h=self.config.n_heads) + key_states = rearrange(key_states, "b s (h d) -> b s h d", h=self.config.n_heads) + value_states = rearrange(value_states, "b s (h d) -> b s h d", h=self.config.n_heads) + query_length, key_length = query_states.shape[1], key_states.shape[1] + + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + causal_mask, + (0, 0, mask_shift, 0), + (1, 1, query_length, max_decoder_length) + ) + else: + causal_mask = causal_mask[:, :, :query_length, :key_length] + + batch_size = hidden_states.shape[0] + causal_mask = jnp.broadcast_to( + causal_mask, (batch_size,) + causal_mask.shape[1:]) + attention_mask = jnp.broadcast_to(jnp.expand_dims( + attention_mask, axis=(-3, -2)), 
causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + if attention_mask.ndim == 2: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + if self.has_variable("cache", "cached_key") or init_cache: + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, + value_states, + query_states, + attention_mask ) - attn_output += attention_mask - attn_output += mask[:, :, :attn_output.shape[-2], :attn_output.shape[-1]] - attn_output = nn.softmax(attn_output, -1) - attn_output = jnp.einsum('...hqk,...khd->...qhd', attn_output, v) - return self.wo(attn_output.reshape(inp_shape)) + attention_bias = lax.select( + attention_mask.astype("bool"), + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo( + self.dtype).min).astype(self.dtype), + ) + qps = self.config.generation_query_partition_spec if query_length == 1 else self.config.query_partition_spec + kps = self.config.value_partition_spec + vps = self.config.key_partition_spec + bps = self.config.generation_bias_partition_spec if query_length == 1 else self.config.bias_partition_spec + aps = self.config.generation_attention_partition_spec if query_length == 1 else self.config.attention_partition_spec + query_states = with_sharding_constraint(query_states, qps) + key_states = with_sharding_constraint(key_states, kps) + value_states = with_sharding_constraint(value_states, vps) + attention_bias = with_sharding_constraint(attention_bias, bps) + attention_scores = jnp.einsum( + "bqhd,bkhd->bhqk", query_states, key_states, precision=self.precision + ) * self.softmax_scale + + if position_bias is not None: + key_length = key_states.shape[1] + + position_bias_query_index = max(0, position_bias.shape[2] - query_length) + position_bias_key_index = max(0, position_bias.shape[3] - key_length) + + position_bias = position_bias[:, :, position_bias_query_index:, position_bias_key_index:] + attention_scores = attention_scores + position_bias + attn_weights = jax.nn.softmax((attention_scores + attention_bias).astype("float32"), axis=-1) + attn_weights = self.dropout(attn_weights, deterministic=deterministic) + context_states = with_sharding_constraint( + jnp.einsum( + "bhqk,bkhd->bqhd", attn_weights, value_states, precision=self.precision + ), + aps + ) + attn_output = self.out_proj(context_states.reshape(inp_shape)) + + return attn_output, attn_weights class FlaxMptBlock(nn.Module): @@ -239,33 +227,39 @@ class FlaxMptBlock(nn.Module): precision: Optional[Union[jax.lax.Precision, str]] = None def setup(self) -> None: - self.norm_1 = nn.LayerNorm(use_bias=self.config.use_norm_bias) - self.norm_2 = nn.LayerNorm(use_bias=self.config.use_norm_bias) attn_block = FlaxMptAttention mlp_block = FlaxMptMLP if self.config.gradient_checkpointing != "": mlp_block = remat( mlp_block, policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing), - static_argnums=(1,) + static_argnums=(2,) ) - # hidden_states: chex.Array - # attention_mask: chex.Array - # position_ids: chex.Array - # attn_bias: chex.Array = None - # init_cache: bool = False - attn_block = remat( attn_block, policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing), - static_argnums=(4,) + static_argnums=(3, 4, 5) ) + + self.norm_1 = nn.LayerNorm( + epsilon=self.config.layer_norm_epsilon, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=self.config.use_norm_bias + ) self.attn = attn_block( config=self.config, dtype=self.dtype, param_dtype=self.param_dtype, 
precision=self.precision ) + + self.norm_2 = nn.LayerNorm( + epsilon=self.config.layer_norm_epsilon, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=self.config.use_norm_bias + ) self.ffn = mlp_block( config=self.config, dtype=self.dtype, @@ -273,46 +267,37 @@ def setup(self) -> None: precision=self.precision ) - def __call__(self, - hidden_states: chex.Array, - attention_mask: chex.Array, - position_ids: chex.Array, - attn_bias: chex.Array = None, - init_cache: bool = False - ): - - # hidden_states: chex.Array - # attention_mask: chex.Array - # position_ids: chex.Array - # attn_bias: chex.Array = None - # init_cache: bool = False - - hidden_states = ( - self.attn( - self.norm_1(hidden_states), - attention_mask, - position_ids, - attn_bias, - init_cache - ) + hidden_states + self.dropout_rate = self.config.attn_config.attn_pdrop + self.resid_attn_dropout = nn.Dropout(self.dropout_rate) + + def __call__( + self, + hidden_states: chex.Array, + attention_mask: chex.Array, + position_bias: chex.Array, + causal_mask: chex.Array, + init_cache: bool = False, + deterministic: bool = False, + output_attentions: bool = False, + ): + attn_outputs, attn_weights = self.attn( + self.norm_1(hidden_states), + attention_mask, + position_bias, + causal_mask, + init_cache, + deterministic ) - ffn_input = self.norm_2(hidden_states) - if self.config.use_scan_mlp: - feed_forward_hidden_states = block_wise_ffn( - self.ffn, - hidden_states, - self.config.scan_mlp_chunk_size, - False - ) - else: - feed_forward_hidden_states = self.ffn( - hidden_states, - False, - ) - return feed_forward_hidden_states + hidden_states + hidden_states = self.resid_attn_dropout(attn_outputs, deterministic=deterministic) + hidden_states + output = self.ffn(self.norm_2(hidden_states), hidden_states) + outputs = (output,) + if output_attentions: + outputs += (attn_weights,) + + return outputs # hidden_states, attentions -class FlaxMptCollection(nn.Module): +class FlaxMptDecoratorCollection(nn.Module): config: MptConfig dtype: jnp.dtype = jnp.float32 param_dtype: jnp.dtype = jnp.float32 @@ -337,37 +322,48 @@ def __call__( self, hidden_states: chex.Array, attention_mask: chex.Array, - position_ids: chex.Array, - attn_bias: chex.Array = None, + position_bias: chex.Array, + causal_mask: chex.Array, init_cache: bool = False, + deterministic: bool = False, + output_attentions: bool = False, output_hidden_states: bool = True ): all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None for block in self.blocks: - hidden_states = block( + output = block( hidden_states=hidden_states, - attn_bias=attn_bias, + deterministic=deterministic, attention_mask=attention_mask, - position_ids=position_ids, - init_cache=init_cache + causal_mask=causal_mask, + output_attentions=output_attentions, + init_cache=init_cache, + position_bias=position_bias, ) - + hidden_states = output[0] + if output_attentions: + all_attentions += (output[-1],) if output_hidden_states: all_hidden_states += (hidden_states,) - return hidden_states, all_hidden_states - - -def build_alibi(max_length, num_attention_heads, alibi_max: int = 8): - w_range = jnp.arange(1 - max_length, 1).reshape(1, 1, 1, max_length) - # cp2 = jnp.power(2, jnp.ceil(jnp.log2(num_attention_heads))) - cp2 = 2 ** math.ceil(math.log2(num_attention_heads)) - h_range = jnp.arange(1, 1 + num_attention_heads, ).reshape(1, -1, 1, 1) - h_range = jnp.matmul(h_range, jnp.asarray(alibi_max / cp2).reshape(1, 1)) - slop = 1 / jnp.power(2, h_range) - if cp2 
!= num_attention_heads: - slop = jnp.concatenate([slop[1::2], slop[::2]], axis=-1)[:num_attention_heads] - alibi = (w_range * slop).reshape(1, num_attention_heads, 1, max_length) + return hidden_states, all_hidden_states, all_attentions + + +def build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max=8): + alibi = jnp.arange(1 - sequence_length, 1, dtype="i4").reshape(1, 1, 1, sequence_length) + num_heads_power_of_2 = 2 ** math.ceil(math.log2(num_heads)) + jax.config.update("jax_enable_x64", True) + base = jnp.arange(1, num_heads_power_of_2 + 1, dtype=jnp.int64).astype("float32") + base = base * (alibi_bias_max / num_heads_power_of_2) + + slopes = 1.0 / jnp.pow(2, base) + slopes = slopes.reshape(1, num_heads_power_of_2, 1, 1) + + if num_heads_power_of_2 != num_heads: + slopes = jnp.concat([slopes[:, 1::2, ...], slopes[:, ::2, ...]], axis=1)[:, :num_heads, ...] + + alibi = alibi * slopes return alibi @@ -379,69 +375,86 @@ class FlaxMptModule(nn.Module): def setup(self) -> None: self.wte = nn.Embed(num_embeddings=self.config.vocab_size, features=self.config.d_model) - if not self.config.alibi: - self.wpe = nn.Embed(num_embeddings=self.config.vocab_size, features=self.config.max_seq_len) - self.h = FlaxMptCollection( + + self.blocks = FlaxMptDecoratorCollection( config=self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision ) - self.norm_f = nn.LayerNorm(use_bias=self.config.use_norm_bias) - self.alibi = build_alibi(self.config.max_seq_len, self.config.n_heads) + self.norm_f = nn.LayerNorm( + dtype=self.dtype, + param_dtype=self.param_dtype, + epsilon=self.config.layer_norm_epsilon, + use_bias=self.config.use_norm_bias, + ) + self.alibi = build_mpt_alibi_tensor( + sequence_length=self.config.max_seq_len, + num_heads=self.config.n_heads, + ) + self.causal_mask = jnp.tril( + jnp.ones( + (self.config.max_seq_len, self.config.max_seq_len), dtype="bool" + ) + ).reshape(1, 1, self.config.max_seq_len, self.config.max_seq_len) def __call__( self, input_ids: chex.Array, - attention_mask: chex.Array = None, - position_ids: chex.Array = None, + attention_mask: Optional[chex.Array] = None, + input_embeds: Optional[chex.Array] = None, + extra_embedding: Optional[chex.Array] = None, init_cache: bool = False, - return_dict: bool = True, + deterministic: bool = False, + output_attentions: bool = False, output_hidden_states: bool = True, - extra_embedding: Optional[Union[jnp.ndarray, None]] = None + return_dict: bool = True, ): - b, s = input_ids.shape - hidden_states = self.wte(input_ids) - hidden_states = hidden_states + extra_embedding if extra_embedding is not None else hidden_states + if input_embeds is None: + input_embeds = self.wte(input_ids) + hidden_states = input_embeds + extra_embedding if extra_embedding is not None else input_embeds - if self.config.alibi: - alibi = self.alibi - else: - pos_id = self.wpe(jnp.arange(s, dtype='i4').reshape(1, -1)) - hidden_states += pos_id - alibi = None - hidden_states, all_hidden_states = self.h( - hidden_states, - attn_bias=alibi, - attention_mask=attention_mask, - position_ids=position_ids, + hidden_states, all_hidden_states, all_attentions = self.blocks( + position_bias=self.alibi, + causal_mask=self.causal_mask, init_cache=init_cache, - output_hidden_states=output_hidden_states - ) - hidden_states = self.norm_f( - hidden_states + output_attentions=output_attentions, + attention_mask=attention_mask, + deterministic=deterministic, + output_hidden_states=output_hidden_states, + hidden_states=hidden_states, ) + 
hidden_states = self.norm_f(hidden_states) if output_hidden_states: all_hidden_states += (hidden_states,) if return_dict: - return FlaxBaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states) - else: - return hidden_states, all_hidden_states + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions + ) + + return ( + hidden_states, + all_hidden_states, + all_attentions + ) class FlaxMptPretrainedModel(EasyDeLFlaxPretrainedModel): module_class: nn.Module = None config_class: MptConfig = MptConfig - def __init__(self, - config, - dtype: jnp.dtype = jnp.float32, - param_dtype: jnp.dtype = jnp.float32, - _do_init: bool = False, - precision: Optional[Union[jax.lax.Precision, None]] = jax.lax.Precision("fastest"), - input_shape: Tuple = (1, 16), - **kwargs - ): + def __init__( + self, + config, + dtype: jnp.dtype = jnp.float32, + param_dtype: jnp.dtype = jnp.float32, + precision: lax.PrecisionLike = None, + _do_init: bool = False, + input_shape: Tuple = (1, 16), + **kwargs + ): module = self.module_class( config, dtype=dtype, @@ -452,7 +465,7 @@ def __init__(self, def init_cache(self, batch_size, max_length): - input_ids = jnp.ones((batch_size, max_length), dtype='i4') + input_ids = jnp.ones((batch_size, max_length), dtype="i4") attention_mask = jnp.ones_like(input_ids) position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) @@ -467,60 +480,59 @@ def init_cache(self, batch_size, max_length): return init_variables["cache"] def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: - input_ids = jnp.ones(input_shape, dtype='i4') + input_ids = jnp.ones(input_shape, dtype="i4") if params is None: return self.module.init( rngs=rng, input_ids=input_ids, - attention_mask=jnp.ones(input_shape, dtype='i4'), - position_ids=jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape), + attention_mask=jnp.ones(input_shape, dtype="i4"), init_cache=False - )['params'] + )["params"] else: return params - def __call__(self, - input_ids, - attention_mask=None, - past_key_values=None, - position_ids=None, - output_hidden_states: Optional[bool] = None, - init_cache: bool = False, - params: dict = None, - add_params_field: bool = False, - return_dict: bool = True, - extra_embedding: Optional[Union[jnp.ndarray, None]] = None, - **kwargs - ): + def __call__( + self, + input_ids: chex.Array, + attention_mask: Optional[chex.Array] = None, + input_embeds: Optional[chex.Array] = None, + extra_embedding: Optional[chex.Array] = None, + init_cache: bool = False, + deterministic: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = True, + return_dict: bool = True, + params: dict = None, + add_params_field: bool = False, + past_key_values: Optional[Tuple[Tuple[chex.Array]]] = None, + **kwargs + ): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - params = {'params': params or self.params} if add_params_field else params or self.params - input_ids = jnp.asarray(input_ids, dtype='i4') + params = {"params": params or self.params} if add_params_field else params or self.params + input_ids = jnp.asarray(input_ids, dtype="i4") mutable = False if attention_mask is None: - attention_mask = jnp.ones_like(input_ids, dtype='i4') - if position_ids is None: - position_ids = jnp.arange(0, attention_mask.shape[-1], 1, 
dtype='i4').reshape( - 1, -1 - ).repeat(input_ids.shape[0], axis=0) - + attention_mask = jnp.ones_like(input_ids, dtype="i4") if past_key_values is not None: - params['cache'] = past_key_values - mutable = ['cache'] + params["cache"] = past_key_values + mutable = ["cache"] rngs = {} if self.config.bits is not None: - rngs['params'] = jax.random.key(0) + rngs["params"] = jax.random.key(0) predict = self.module.apply( params, - input_ids=input_ids, - attention_mask=jnp.asarray(attention_mask, dtype='i4'), - return_dict=return_dict, + input_ids, + attention_mask=attention_mask, + input_embeds=input_embeds, extra_embedding=extra_embedding, - position_ids=position_ids, init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, output_hidden_states=output_hidden_states, + return_dict=return_dict, mutable=mutable, rngs=rngs ) @@ -558,21 +570,27 @@ def setup(self) -> None: precision=self.precision ) - if self.config.use_lm_head: - self.lm_head = Linear(self.config.vocab_size, kernel_init=jax.nn.initializers.normal(), - use_bias=self.config.use_bias, - dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision, - **get_dot_general_by_bits(self.config.bits, self.config.easy_method)) + self.lm_head = Linear( + self.config.vocab_size, + kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + use_bias=self.config.use_bias, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + **get_dot_general_by_bits(self.config.bits, self.config.easy_method) + ) def __call__( self, input_ids: chex.Array, - attention_mask: chex.Array = None, + attention_mask: Optional[chex.Array] = None, + input_embeds: Optional[chex.Array] = None, + extra_embedding: Optional[chex.Array] = None, init_cache: bool = False, - position_ids: chex.Array = None, - return_dict: bool = True, + deterministic: bool = False, + output_attentions: bool = False, output_hidden_states: bool = True, - extra_embedding: Optional[Union[jnp.ndarray, None]] = None + return_dict: bool = True, ): predict: FlaxBaseModelOutput = self.transformer( input_ids=input_ids, @@ -580,21 +598,26 @@ def __call__( return_dict=True, extra_embedding=extra_embedding, output_hidden_states=output_hidden_states, - position_ids=position_ids, - init_cache=init_cache + init_cache=init_cache, + output_attentions=output_attentions, + deterministic=deterministic, + input_embeds=input_embeds ) + last_hidden_state = predict.last_hidden_state + if self.config.use_lm_head: - logits = self.lm_head(predict.last_hidden_state) + shared_kernel = self.model.variables["params"]["wte"]["embedding"] + shared_kernel = nn.linen.control_quantization(shared_kernel, self.param_dtype).T + logits = self.lm_head.apply( + {"params": {"kernel": shared_kernel}}, last_hidden_state) else: - logits = predict.last_hidden_state @ self.transformer.wte.embedding.T + logits = self.lm_head(last_hidden_state) if return_dict: - return FlaxCausalLMOutput( logits=logits, hidden_states=predict.hidden_states ) - else: - return logits, predict.hidden_states if output_hidden_states else (logits,) + return logits, predict.hidden_states if output_hidden_states else (logits,) class FlaxMptForCausalLM(FlaxMptPretrainedModel): @@ -619,26 +642,17 @@ def get_output_embeddings(self): return self.module.lm_head def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[chex.Array] = None): - batch_size, seq_length = input_ids.shape - - past_key_values = self.init_cache( - batch_size, max_length - ) + 
past_key_values = self.init_cache(batch_size, max_length) extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 extended_attention_mask = jax.lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) return { "past_key_values": past_key_values, "attention_mask": extended_attention_mask, - "position_ids": position_ids, } def update_inputs_for_generation(self, model_outputs, model_kwargs): model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 return model_kwargs diff --git a/src/python/easydel/modules/mosaic_mpt/mosaic_configuration.py b/src/python/easydel/modules/mosaic_mpt/mosaic_configuration.py index 1cb978b04..2964bafc9 100644 --- a/src/python/easydel/modules/mosaic_mpt/mosaic_configuration.py +++ b/src/python/easydel/modules/mosaic_mpt/mosaic_configuration.py @@ -5,8 +5,58 @@ from ..easydel_modelling_utils import EasyDeLPretrainedConfig +class MptAttentionConfig(EasyDeLPretrainedConfig): + def __init__( + self, + attn_type="multihead_attention", + attn_pdrop=0, + attn_impl="torch", + clip_qkv=None, + softmax_scale=None, + prefix_lm=False, + qk_ln=False, + attn_uses_sequence_id=False, + alibi=True, + alibi_bias_max=8, + **kwargs, + ): + super().__init__() + self.attn_type = attn_type + self.attn_pdrop = attn_pdrop + self.attn_impl = attn_impl + self.clip_qkv = clip_qkv + self.softmax_scale = softmax_scale + self.prefix_lm = prefix_lm + self.attn_uses_sequence_id = attn_uses_sequence_id + self.alibi = alibi + self.qk_ln = qk_ln + self.alibi_bias_max = alibi_bias_max + + if attn_type not in ["multihead_attention", "multiquery_attention"]: + raise ValueError( + f"`attn_type` has to be either `multihead_attention` or `multiquery_attention`. 
Received: {attn_type}" + ) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path, + **kwargs + ) -> "EasyDeLPretrainedConfig": + cls._set_token_in_kwargs(kwargs) + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + if config_dict.get("model_type") == "mpt": + config_dict = config_dict["attn_config"] + return cls.from_dict(config_dict, **kwargs) + + class MptConfig(EasyDeLPretrainedConfig): - model_type: str = 'mpt' + model_type = "mpt" + attribute_map = { + "num_attention_heads": "n_heads", + "hidden_size": "d_model", + "num_hidden_layers": "n_layers", + } def __init__(self, d_model: int = 2048, @@ -16,16 +66,21 @@ def __init__(self, max_seq_len: int = 2048, vocab_size: int = 50368, resid_prob_drop: float = 0.0, + layer_norm_epsilon: float = 1e-5, emb_prob_drop: float = 0.0, - alibi: bool = True, - use_bias: bool = False, learned_pos_emb: bool = True, - act_fn: str = 'gelu', + attn_config: MptAttentionConfig = None, + init_device: str = "cpu", logit_scale: Optional[Union[float, str]] = None, - no_bias: bool = False, + no_bias: bool = True, verbose: int = 0, embedding_fraction: float = 1.0, + norm_type: str = "low_precision_layernorm", use_cache: bool = False, + initializer_range=0.02, + alibi: bool = True, + use_bias: bool = False, + act_fn: str = "gelu", qk_ln: bool = False, use_lm_head: bool = False, use_norm_bias: bool = False, @@ -46,6 +101,7 @@ def __init__(self, self.use_bias = use_bias self.emb_prob_drop = emb_prob_drop self.gradient_checkpointing = gradient_checkpointing + self.norm_type = norm_type self.learned_pos_emb = learned_pos_emb self.act_fn = act_fn self.logit_scale = logit_scale @@ -53,15 +109,14 @@ def __init__(self, self.qk_ln = qk_ln self.alibi = alibi self.verbose = verbose + self.initializer_range = initializer_range self.embedding_fraction = embedding_fraction + self.init_device = init_device self.use_cache = use_cache self.bits = bits - + self.layer_norm_epsilon = layer_norm_epsilon self.from_pt = False - if 'name' in kwargs: - del kwargs['name'] - if 'loss_fn' in kwargs: - del kwargs['loss_fn'] + self.attn_config = attn_config super().__init__( bits=bits, **kwargs @@ -74,8 +129,7 @@ def _set_config_defaults(config, config_defaults): config[k] = v return config - @staticmethod - def get_partition_rules(fully_sharded_data_parallel: bool = False): + def get_partition_rules(self, fully_sharded_data_parallel: bool = True): return ( ("transformer/wte/embedding", PartitionSpec("dp", "fsdp")), @@ -126,62 +180,17 @@ def get_partition_rules(fully_sharded_data_parallel: bool = False): (".*", PartitionSpec(("fsdp", "sp"))), ) - def add_jax_args(self, - d_model: int = 2048, - n_heads: int = 16, - n_layers: int = 24, - expansion_ratio: int = 4, - max_seq_len: int = 2048, - vocab_size: int = 50368, - resid_prob_drop: float = 0.0, - emb_prob_drop: float = 0.0, - alibi: bool = True, - use_bias: bool = True, - learned_pos_emb: bool = True, - act_fn: str = 'gelu', - logit_scale: Optional[Union[float, str]] = None, - no_bias: bool = False, - verbose: int = 0, - embedding_fraction: float = 1.0, - use_cache: bool = False, - qk_ln: bool = True, - use_lm_head: bool = False, - use_norm_bias: bool = False, - gradient_checkpointing: str = "nothing_saveable", - bits: Optional[int] = None, - **kwargs, - ): - if hasattr(self, 'attn_config'): - for k, v in self.attn_config.items(): + def add_jax_args( + self, + gradient_checkpointing: str = "nothing_saveable", + bits: Optional[int] = None, + **kwargs, + ): + if hasattr(self, 
"attn_config"): + for k, v in self.attn_config.__dict__.items(): setattr(self, k, v) - basics = dict( - bits=bits, - d_model=d_model, - n_heads=n_heads, - n_layers=n_layers, - expansion_ratio=expansion_ratio, - max_seq_len=max_seq_len, - vocab_size=vocab_size, - resid_prob_drop=resid_prob_drop, - emb_prob_drop=emb_prob_drop, - alibi=alibi, - use_bias=use_bias, - learned_pos_emb=learned_pos_emb, - act_fn=act_fn, - logit_scale=logit_scale, - no_bias=no_bias, - verbose=verbose, - embedding_fraction=embedding_fraction, - use_cache=use_cache, - qk_ln=qk_ln, - use_lm_head=use_lm_head, - use_norm_bias=use_norm_bias, - gradient_checkpointing=gradient_checkpointing, - **kwargs - ) + basics = dict(bits=bits, gradient_checkpointing=gradient_checkpointing, **kwargs) for k, v in basics.items(): if not hasattr(self, k): - print(f' Key {k} not found in loaded config setting that to default of {v}') setattr(self, k, v) - self.from_pt = False