diff --git a/torch/distributed/_tensor/placement_types.py b/torch/distributed/_tensor/placement_types.py index 99255ccf785d2..8352b4cd0d42b 100644 --- a/torch/distributed/_tensor/placement_types.py +++ b/torch/distributed/_tensor/placement_types.py @@ -378,11 +378,20 @@ class DTensorSpec: def __hash__(self) -> int: # hashing and equality check for DTensorSpec are used to cache the sharding - # propagation results. We only need to consider the mesh, placements and shape + # propagation results. We only need to consider the mesh, placements, shape, + # dtype and stride. # Caveat: we need to keep this in mind and sync hash and eq if we add more - # fields to them, + # fields to them. if self.tensor_meta is not None: - return hash((self.mesh, self.placements, self.tensor_meta.shape)) + return hash( + ( + self.mesh, + self.placements, + self.tensor_meta.shape, + self.tensor_meta.dtype, + self.tensor_meta.stride, + ) + ) else: return hash((self.mesh, self.placements))