From 65744ac6f44eed4a572a0e33d86f03902d3d6a7b Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 6 Feb 2025 14:33:18 -0500 Subject: [PATCH] Shard load, full tensor sendaround --- torchdata/stateful_dataloader/ibm_rescalable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 3b1258999..f68204ff5 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -604,7 +604,7 @@ def load_distributed_state_dict( """ base = loader.state_dict() nworkers = base["_snapshot"]["_main_snapshot"]["_num_workers"] - dstate = __pop_dstate(base, device_mesh, [dtensor.placement_types.Replicate()], True) + dstate = __pop_dstate(base, device_mesh, [dtensor.placement_types.Shard(0)], True) inp = {"state":deepcopy(base), "dstate":dstate} # Read distributed state dict reader = checkpoint.FileSystemReader(path)