diff --git a/test/distributed/fsdp/test_fsdp_dtensor_state_dict.py b/test/distributed/fsdp/test_fsdp_dtensor_state_dict.py
index 971ac065032a5..9a0dada42d177 100644
--- a/test/distributed/fsdp/test_fsdp_dtensor_state_dict.py
+++ b/test/distributed/fsdp/test_fsdp_dtensor_state_dict.py
@@ -48,7 +48,8 @@ def get_input(self):
 class TestDtensorShardedOptimStateDict(DTensorTestBase):
     @with_comms
     @skip_if_lt_x_gpu(2)
-    def test_dtensor_sharded_optim_state_dict(self):
+    @parametrize("offload_to_cpu", [True, False])
+    def test_dtensor_sharded_optim_state_dict(self, offload_to_cpu):
         model = FSDP(TestDummyModel().cuda())
         optim = torch.optim.Adam(model.parameters(), lr=0.1)
         model(model.get_input()).sum().backward()
@@ -58,7 +59,7 @@ def test_dtensor_sharded_optim_state_dict(self):
             model,
             StateDictType.SHARDED_STATE_DICT,
             optim_state_dict_config=ShardedOptimStateDictConfig(
-                use_dtensor=True, offload_to_cpu=False
+                use_dtensor=True, offload_to_cpu=offload_to_cpu
             ),
         )
         dtensor_osd = FSDP.optim_state_dict(model, optim)
@@ -90,6 +91,59 @@ def test_dtensor_sharded_optim_state_dict(self):
                 # check whether device are the same
                 self.assertEqual(v1.to_local().device, v2.local_tensor().device)
 
+    @with_comms
+    @skip_if_lt_x_gpu(2)
+    @parametrize("offload_to_cpu", [True, False])
+    def test_dtensor_sharded_optim_load_state_dict(self, offload_to_cpu):
+        model = FSDP(TestDummyModel().cuda())
+        optim = torch.optim.Adam(model.parameters(), lr=0.1)
+        model(model.get_input()).sum().backward()
+        optim.step()
+
+        FSDP.set_state_dict_type(
+            model,
+            StateDictType.SHARDED_STATE_DICT,
+            optim_state_dict_config=ShardedOptimStateDictConfig(
+                use_dtensor=True,
+                offload_to_cpu=offload_to_cpu,
+            ),
+        )
+
+        checkpoint = io.BytesIO()
+        torch.save(FSDP.optim_state_dict(model, optim), checkpoint)
+        # Deepcopy to save current optim_state_dict to compare with the optim_state_dict loaded back below.
+        ref_optim_state_dict = deepcopy(FSDP.optim_state_dict(model, optim))
+
+        # Update the parameters so FSDP.optim_state_dict() will be different from ref_optim_state_dict.
+        model(model.get_input()).sum().backward()
+        optim.step()
+
+        # Load ref_optim_state_dict back.
+        checkpoint.seek(0)
+        load_ref_optim_state_dict = torch.load(checkpoint)
+        optim.load_state_dict(
+            FSDP.optim_state_dict_to_load(model, optim, load_ref_optim_state_dict)
+        )
+        new_optim_state_dict = FSDP.optim_state_dict(model, optim)
+
+        # Check whether new_optim_state_dict is the same as ref_optim_state_dict.
+        for new_optim_state_item, ref_optim_state_item in zip(
+            new_optim_state_dict["state"].items(),
+            ref_optim_state_dict["state"].items(),
+        ):
+            # check FQNs are the same
+            self.assertEqual(new_optim_state_item[0], ref_optim_state_item[0])
+            for new_optim_hyper_param, ref_optim_hyper_param in zip(
+                new_optim_state_item[1].items(),
+                ref_optim_state_item[1].items(),
+            ):
+                k1, v1 = new_optim_hyper_param
+                k2, v2 = ref_optim_hyper_param
+                # check whether keys are the same
+                self.assertEqual(k1, k2)
+                # check whether DTensors are the same
+                self.assertEqual(v1, v2)
+
 
 # TODO: consolidate test cases once we test all DTensor usage
 class TestDtensorShardedModelStateDict(DTensorTestBase):
@@ -97,6 +151,8 @@ class TestDtensorShardedModelStateDict(DTensorTestBase):
     @with_comms
     @skip_if_lt_x_gpu(2)
     @parametrize("offload_to_cpu", [True, False])
     def test_dtensor_sharded_model_state_dict(self, offload_to_cpu):
+        # Compare the results of SHARDED_STATE_DICT using ShardedTensor and DTensor
+        # and check whether they are identical.
         model = FSDP(TestDummyModel().cuda())
         model(model.get_input()).sum().backward()
@@ -150,7 +206,7 @@ def test_dtensor_sharded_model_load_state_dict(self, offload_to_cpu):
 
         checkpoint = io.BytesIO()
         torch.save(model.state_dict(), checkpoint)
-        # Deepcopy to save current state_dict to compare with the loaded state dict below.
+        # Deepcopy to save current state_dict to compare with the state_dict loaded back below.
         ref_state_dict = deepcopy(model.state_dict())
 
         # Update the parameters so model.state_dict() will be different from ref_dtensor_sd.
@@ -159,7 +215,6 @@ def test_dtensor_sharded_model_load_state_dict(self, offload_to_cpu):
 
         # Load ref_state_dict back.
         checkpoint.seek(0)
-        # Test both parameters in state_dict are loaded to CPU and GPU.
         load_ref_state_dict = torch.load(checkpoint)
         model.load_state_dict(load_ref_state_dict)
         new_state_dict = model.state_dict()
@@ -172,6 +227,7 @@ def test_dtensor_sharded_model_load_state_dict(self, offload_to_cpu):
             self.assertEqual(v1, v2)
 
 
+instantiate_parametrized_tests(TestDtensorShardedOptimStateDict)
 instantiate_parametrized_tests(TestDtensorShardedModelStateDict)
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/distributed/fsdp/_shard_utils.py b/torch/distributed/fsdp/_shard_utils.py
index 67adc803584d1..281263584b967 100644
--- a/torch/distributed/fsdp/_shard_utils.py
+++ b/torch/distributed/fsdp/_shard_utils.py
@@ -71,6 +71,8 @@ def _gather_state_dict(
             else:
                 tensor = output_tensor
         elif isinstance(tensor, DTensor):
+            if tensor.device.type != tensor.device_mesh.device_type:
+                tensor = tensor.to(tensor.device_mesh.device_type)
             tensor = tensor.redistribute(
                 device_mesh=tensor.device_mesh, placements=[Replicate()]
             )
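
Note on the _shard_utils.py hunk: with offload_to_cpu=True, the optimizer-state
DTensors saved above hold CPU local shards while their device mesh's device type
is "cuda", and redistribute() needs the local shard on the mesh's device before
it can all-gather it (e.g. NCCL collectives require CUDA tensors). The sketch
below is a minimal standalone illustration of that pre-move, not code from this
patch; gather_dtensor is a hypothetical name, and it assumes an already
initialized process group and a DTensor built on a CUDA device mesh.

    import torch
    from torch.distributed._tensor import DTensor, Replicate


    def gather_dtensor(tensor: DTensor) -> torch.Tensor:
        # device_mesh.device_type is a plain string such as "cuda", while
        # tensor.device is a torch.device, so compare the device *types*.
        if tensor.device.type != tensor.device_mesh.device_type:
            # Move the CPU-offloaded local shard back onto the mesh's device
            # so the collectives inside redistribute() can run.
            tensor = tensor.to(tensor.device_mesh.device_type)
        # Replicate the DTensor across the mesh, then unwrap the full tensor.
        replicated = tensor.redistribute(
            device_mesh=tensor.device_mesh, placements=[Replicate()]
        )
        return replicated.to_local()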