Skip to content

Commit

Permalink
Shard load, full tensor sendaround
Browse files Browse the repository at this point in the history
  • Loading branch information
daviswer committed Feb 6, 2025
1 parent 4675681 commit 65744ac
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchdata/stateful_dataloader/ibm_rescalable.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ def load_distributed_state_dict(
"""
base = loader.state_dict()
nworkers = base["_snapshot"]["_main_snapshot"]["_num_workers"]
dstate = __pop_dstate(base, device_mesh, [dtensor.placement_types.Replicate()], True)
dstate = __pop_dstate(base, device_mesh, [dtensor.placement_types.Shard(0)], True)
inp = {"state":deepcopy(base), "dstate":dstate}
# Read distributed state dict
reader = checkpoint.FileSystemReader(path)
Expand Down

0 comments on commit 65744ac

Please sign in to comment.