From def738a5c8b1e25758edfb02b35c06ad821c912d Mon Sep 17 00:00:00 2001
From: "877825076@qq.com" <877825076@qq.com>
Date: Thu, 11 Jan 2024 13:43:04 +0800
Subject: [PATCH] fix(modeling): norm weight should be fp32

---
 internlm/model/modeling_internlm.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index a47a5cdd..faf0d7b1 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -13,7 +13,7 @@
 from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.core.context.random import _SEED_MANAGER
-from internlm.core.naive_amp import set_output_attr_to_module
+from internlm.core.naive_amp import set_fp32_attr_to_module, set_output_attr_to_module
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
 from internlm.initialize.launch import GLOBAL_SEED
 from internlm.model.embedding import Embedding1D
@@ -113,6 +113,8 @@ def __init__(
         else:
             self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
             self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
+        set_fp32_attr_to_module(self.norm1)
+        set_fp32_attr_to_module(self.norm2)
 
         if use_swiglu:
             self.mlp = FeedForward(
@@ -360,6 +362,7 @@ def __init__(
                 self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
             else:
                 self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
+            set_fp32_attr_to_module(self.norm)
             self.head = head_cls(
                 in_features=hidden_size,
                 out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
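
Note on the mechanism (not part of the patch): `set_fp32_attr_to_module` is imported from `internlm.core.naive_amp` and, by the commit message, is meant to tag the norm modules so the AMP wrapper keeps their weights and computation in float32. The sketch below is a minimal, hypothetical illustration of that pattern; the attribute name `is_fp32_module` and the `NaiveAMPBlock` wrapper are assumptions for illustration, not InternLM's actual implementation.

    # Hypothetical sketch of tagging a module as fp32 and having an AMP
    # wrapper honor the tag. Names are illustrative, not InternLM's API.
    import torch
    from torch import nn


    def set_fp32_attr_to_module(module: nn.Module) -> None:
        # Mark the module so the AMP wrapper leaves it in fp32 (assumed attribute name).
        setattr(module, "is_fp32_module", True)


    def module_runs_in_fp32(module: nn.Module) -> bool:
        return getattr(module, "is_fp32_module", False)


    class NaiveAMPBlock(nn.Module):
        """Toy wrapper: runs the wrapped module in fp16 unless it is tagged as fp32."""

        def __init__(self, inner: nn.Module, dtype: torch.dtype = torch.float16):
            super().__init__()
            self.dtype = dtype
            # Only cast untagged modules; tagged norms keep fp32 weights.
            self.inner = inner if module_runs_in_fp32(inner) else inner.to(dtype)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if module_runs_in_fp32(self.inner):
                # Norm layers are numerically sensitive: compute in fp32, cast back.
                return self.inner(x.float()).to(self.dtype)
            return self.inner(x.to(self.dtype))


    if __name__ == "__main__":
        norm = nn.LayerNorm(8)
        set_fp32_attr_to_module(norm)        # as the patch does for norm1/norm2/norm
        out = NaiveAMPBlock(norm)(torch.randn(2, 8, dtype=torch.float16))
        print(out.dtype, norm.weight.dtype)  # torch.float16 torch.float32

Under this assumed scheme, tagging `norm1`, `norm2`, and the final `norm` right after construction is enough: the wrapper checks the flag at wrap time and at forward time, so the LayerNorm/RMSNorm weights stay fp32 while the rest of the block runs in half precision.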