diff --git a/dolomite_engine/hf_models/modeling_utils_TP/TP.py b/dolomite_engine/hf_models/modeling_utils_TP/TP.py index 62298373..81e1cd5f 100644 --- a/dolomite_engine/hf_models/modeling_utils_TP/TP.py +++ b/dolomite_engine/hf_models/modeling_utils_TP/TP.py @@ -97,6 +97,7 @@ def modify_state_dict_to_dtensor_dict(module: nn.Module, state_dict: dict, prefi device_mesh = param.device_mesh placements = param.placements result[key] = DTensor.from_local(tensor, device_mesh=device_mesh, placements=placements) + return result