diff --git a/src/instructlab/dolomite/hf_models/model_conversion/llama.py b/src/instructlab/dolomite/hf_models/model_conversion/llama.py
index cf028ab..e8e6754 100644
--- a/src/instructlab/dolomite/hf_models/model_conversion/llama.py
+++ b/src/instructlab/dolomite/hf_models/model_conversion/llama.py
@@ -229,6 +229,9 @@ def export_to_huggingface_llama(
         config.num_key_value_heads,
         config.n_embd // config.n_head,
         AttentionHeadType(config.attention_head_type),
+        m_emb=config.m_emb,
+        m_residual=config.m_residual,
+        # m_width=config.m_width,
     )
 
     SafeTensorsWeightsManager.save_state_dict(state_dict, save_path)
@@ -285,11 +288,27 @@ def _export_state_dict_to_huggingface(
     num_key_value_heads: int,
     head_dim: int,
     attention_head_type: AttentionHeadType,
+    m_residual: float = None,
+    m_emb: float = None,
+    m_width: float = None,
 ) -> None:
+    if m_residual is None:
+        m_residual = 1.
+    if m_emb is None:
+        m_emb = 1.
+    if m_width is None:
+        m_width = 1.
+
+    # NOTE: this will not work since the norms are tied
+    # has_m_width = False
+    # if m_width is None:
+    #     has_m_width = True
+    #     m_width = 1.
+
     state_dict = {
         "model.embed_tokens.weight": safetensors_weight_manager.get_tensor(
             "transformer.wte.weight"
-        ),
+        ) * m_emb,
         "model.norm.weight": safetensors_weight_manager.get_tensor(
             "transformer.ln_f.weight"
         ),
@@ -298,7 +317,12 @@ def _export_state_dict_to_huggingface(
     if safetensors_weight_manager.has_tensor("lm_head.weight"):
         state_dict["lm_head.weight"] = safetensors_weight_manager.get_tensor(
             "lm_head.weight"
-        )
+        ) / m_width
+    # elif has_m_width:
+    #     # in this case we cannot tie the weights
+    #     state_dict["lm_head.weight"] = safetensors_weight_manager.get_tensor(
+    #         "transformer.wte.weight"
+    #     ) / m_width
 
     for layer_idx in range(num_layers):
         state_dict[f"model.layers.{layer_idx}.input_layernorm.weight"] = (
@@ -332,13 +356,13 @@ def _export_state_dict_to_huggingface(
         state_dict[f"model.layers.{layer_idx}.mlp.down_proj.weight"] = (
             safetensors_weight_manager.get_tensor(
                 f"transformer.h.{layer_idx}.mlp.c_proj.weight"
-            )
+            ) * m_residual
         )
         if f"transformer.h.{layer_idx}.mlp.c_proj.bias" in safetensors_weight_manager:
             state_dict[f"model.layers.{layer_idx}.mlp.down_proj.bias"] = (
                 safetensors_weight_manager.get_tensor(
                     f"transformer.h.{layer_idx}.mlp.c_proj.bias"
-                )
+                ) * m_residual
             )
 
         query_weight, key_weight, value_weight = (
@@ -376,12 +400,12 @@ def _export_state_dict_to_huggingface(
             safetensors_weight_manager.get_tensor(
                 f"transformer.h.{layer_idx}.attn.c_proj.weight"
             )
-        )
+        ) * m_residual
         if f"transformer.h.{layer_idx}.attn.c_proj.bias" in safetensors_weight_manager:
             state_dict[f"model.layers.{layer_idx}.self_attn.o_proj.bias"] = (
                 safetensors_weight_manager.get_tensor(
                     f"transformer.h.{layer_idx}.attn.c_proj.bias"
                 )
-            )
+            ) * m_residual
 
     return state_dict
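
For reference, here is a minimal sketch of the scaling rules this patch folds into the exported weights. It is illustrative only: `scale_for_llama_export` and the plain-dict interface are hypothetical (the patch itself goes through `SafeTensorsWeightsManager`), but the rules mirror the diff above: the embedding is multiplied by `m_emb`, the attention and MLP output projections (`c_proj`) by `m_residual`, and `lm_head` is divided by `m_width`, with each multiplier defaulting to 1 when not set.

```python
# Hypothetical helper, not part of the patch: shows how the muP-style
# multipliers are baked into a dolomite-style state dict at export time.
import torch


def scale_for_llama_export(
    state_dict: dict[str, torch.Tensor],
    m_emb: float | None = None,
    m_residual: float | None = None,
    m_width: float | None = None,
) -> dict[str, torch.Tensor]:
    # unset multipliers behave as no-ops, matching the None checks in the patch
    m_emb = 1.0 if m_emb is None else m_emb
    m_residual = 1.0 if m_residual is None else m_residual
    m_width = 1.0 if m_width is None else m_width

    scaled = {}
    for name, tensor in state_dict.items():
        if name == "transformer.wte.weight":
            # embedding multiplier is folded into the embedding matrix
            scaled[name] = tensor * m_emb
        elif name == "lm_head.weight":
            # logit scaling (1 / m_width) is folded into the output head
            scaled[name] = tensor / m_width
        elif name.endswith(("attn.c_proj.weight", "attn.c_proj.bias",
                            "mlp.c_proj.weight", "mlp.c_proj.bias")):
            # residual-branch outputs are scaled by m_residual
            scaled[name] = tensor * m_residual
        else:
            scaled[name] = tensor
    return scaled
```

The point of baking these constants into the weights is presumably that the exported checkpoint can then run through the vanilla Hugging Face Llama forward pass, which has no notion of `m_emb`, `m_residual`, or `m_width`, and still reproduce the original model's outputs.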