[8/n][FSDP] make use_dtensor=True work with offload_to_cpu=True for optim.load_state_dict() (pytorch#105690)

Fixes #ISSUE_NUMBER

Pull Request resolved: pytorch#105690
Approved by: https://github.com/fegin
wz337 authored and pytorchmergebot committed Jul 21, 2023
1 parent 72b223c commit 6b2d48e
Showing 2 changed files with 62 additions and 4 deletions.
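
For orientation, here is a minimal sketch of the save/load round trip this commit enables, adapted from the new test added below. It assumes a distributed run with at least two CUDA ranks (e.g. via torchrun) and an already initialized process group; the tiny Linear model and random input are placeholders, while the FSDP calls mirror the test code in this diff.

import torch
import torch.nn as nn
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardedOptimStateDictConfig,
    StateDictType,
)

# Placeholder model/optimizer; the real test uses TestDummyModel.
model = FSDP(nn.Linear(8, 8).cuda())
optim = torch.optim.Adam(model.parameters(), lr=0.1)
model(torch.rand(4, 8, device="cuda")).sum().backward()
optim.step()

# Request DTensor-based sharded optimizer state, offloaded to CPU.
FSDP.set_state_dict_type(
    model,
    StateDictType.SHARDED_STATE_DICT,
    optim_state_dict_config=ShardedOptimStateDictConfig(
        use_dtensor=True,
        offload_to_cpu=True,
    ),
)

# Save: each rank holds its DTensor shards on CPU.
osd = FSDP.optim_state_dict(model, optim)

# Load: the path this commit fixes for use_dtensor=True + offload_to_cpu=True.
optim.load_state_dict(FSDP.optim_state_dict_to_load(model, optim, osd))
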
64 changes: 60 additions & 4 deletions test/distributed/fsdp/test_fsdp_dtensor_state_dict.py
@@ -48,7 +48,8 @@ def get_input(self):
class TestDtensorShardedOptimStateDict(DTensorTestBase):
@with_comms
@skip_if_lt_x_gpu(2)
def test_dtensor_sharded_optim_state_dict(self):
@parametrize("offload_to_cpu", [True, False])
def test_dtensor_sharded_optim_state_dict(self, offload_to_cpu):
model = FSDP(TestDummyModel().cuda())
optim = torch.optim.Adam(model.parameters(), lr=0.1)
model(model.get_input()).sum().backward()
@@ -58,7 +59,7 @@ def test_dtensor_sharded_optim_state_dict(self):
model,
StateDictType.SHARDED_STATE_DICT,
optim_state_dict_config=ShardedOptimStateDictConfig(
use_dtensor=True, offload_to_cpu=False
use_dtensor=True, offload_to_cpu=offload_to_cpu
),
)
dtensor_osd = FSDP.optim_state_dict(model, optim)
@@ -90,13 +91,68 @@ def test_dtensor_sharded_optim_state_dict(self):
# check whether the devices are the same
self.assertEqual(v1.to_local().device, v2.local_tensor().device)

@with_comms
@skip_if_lt_x_gpu(2)
@parametrize("offload_to_cpu", [True, False])
def test_dtensor_sharded_optim_load_state_dict(self, offload_to_cpu):
model = FSDP(TestDummyModel().cuda())
optim = torch.optim.Adam(model.parameters(), lr=0.1)
model(model.get_input()).sum().backward()
optim.step()

FSDP.set_state_dict_type(
model,
StateDictType.SHARDED_STATE_DICT,
optim_state_dict_config=ShardedOptimStateDictConfig(
use_dtensor=True,
offload_to_cpu=offload_to_cpu,
),
)

checkpoint = io.BytesIO()
torch.save(FSDP.optim_state_dict(model, optim), checkpoint)
# Deepcopy to save current optim_state_dict to compare with the optim_state_dict loaded back below.
ref_optim_state_dict = deepcopy(FSDP.optim_state_dict(model, optim))

# Update the parameters so FSDP.optim_state_dict() will be different from ref_optim_state_dict.
model(model.get_input()).sum().backward()
optim.step()

# Load ref_optim_state_dict back.
checkpoint.seek(0)
load_ref_optim_state_dict = torch.load(checkpoint)
optim.load_state_dict(
FSDP.optim_state_dict_to_load(model, optim, load_ref_optim_state_dict)
)
new_optim_state_dict = FSDP.optim_state_dict(model, optim)

# Check whether new_optim_state_dict is the same as ref_optim_state_dict.
for new_optim_state_dict, ref_optim_state_dict in zip(
new_optim_state_dict["state"].items(),
ref_optim_state_dict["state"].items(),
):
# check whether the FQNs are the same
self.assertEqual(new_optim_state_dict[0], ref_optim_state_dict[0])
for new_optim_hyper_param, ref_optim_hyper_param in zip(
new_optim_state_dict[1].items(),
ref_optim_state_dict[1].items(),
):
k1, v1 = new_optim_hyper_param
k2, v2 = ref_optim_hyper_param
# check whether keys are the same
self.assertEqual(k1, k2)
# check whether the DTensors are the same
self.assertEqual(v1, v2)


# TODO: consolidate test cases once we test all DTensor usage
class TestDtensorShardedModelStateDict(DTensorTestBase):
@with_comms
@skip_if_lt_x_gpu(2)
@parametrize("offload_to_cpu", [True, False])
def test_dtensor_sharded_model_state_dict(self, offload_to_cpu):
# Compare the result of SHARDED_STATE_DICT using ShardedTensor and DTensor
# and check whether they are identical
model = FSDP(TestDummyModel().cuda())
model(model.get_input()).sum().backward()

@@ -150,7 +206,7 @@ def test_dtensor_sharded_model_load_state_dict(self, offload_to_cpu):

checkpoint = io.BytesIO()
torch.save(model.state_dict(), checkpoint)
# Deepcopy to save current state_dict to compare with the loaded state dict below.
# Deepcopy to save current state_dict to compare with the state_dict loaded back below.
ref_state_dict = deepcopy(model.state_dict())

# Update the parameters so model.state_dict() will be different from ref_dtensor_sd.
@@ -159,7 +215,6 @@ def test_dtensor_sharded_model_load_state_dict(self, offload_to_cpu):

# Load ref_state_dict back.
checkpoint.seek(0)
# Test both parameters in state_dict are loaded to CPU and GPU.
load_ref_state_dict = torch.load(checkpoint)
model.load_state_dict(load_ref_state_dict)
new_state_dict = model.state_dict()
@@ -172,6 +227,7 @@ def test_dtensor_sharded_model_load_state_dict(self, offload_to_cpu):
self.assertEqual(v1, v2)


instantiate_parametrized_tests(TestDtensorShardedOptimStateDict)
instantiate_parametrized_tests(TestDtensorShardedModelStateDict)
if __name__ == "__main__":
run_tests()
2 changes: 2 additions & 0 deletions torch/distributed/fsdp/_shard_utils.py
@@ -71,6 +71,8 @@ def _gather_state_dict(
else:
tensor = output_tensor
elif isinstance(tensor, DTensor):
if tensor.device != tensor.device_mesh.device_type:
tensor = tensor.to(tensor.device_mesh.device_type)
tensor = tensor.redistribute(
device_mesh=tensor.device_mesh, placements=[Replicate()]
)
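
A hedged reading of the guard added above: with offload_to_cpu=True the DTensor's local shard lives on CPU while its device_mesh still refers to a CUDA mesh, and redistribute() issues collectives on the mesh's device, so the shard is moved back first. Below is a simplified, self-contained rendering of this DTensor branch; the standalone helper name is illustrative, not the actual _gather_state_dict signature.

import torch
from torch.distributed._tensor import DTensor, Replicate

def _gather_dtensor(tensor: DTensor) -> torch.Tensor:
    # If the local shard was offloaded (e.g. offload_to_cpu=True), move it
    # back to the mesh's device type so the collective can run there,
    # mirroring the check added in this commit.
    if tensor.device != tensor.device_mesh.device_type:
        tensor = tensor.to(tensor.device_mesh.device_type)
    # Replicate the DTensor across its mesh, then return the full local tensor.
    tensor = tensor.redistribute(
        device_mesh=tensor.device_mesh, placements=[Replicate()]
    )
    return tensor.to_local()
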
