[TP][DTensor Perf] Fix DTensor Spec hash (pytorch#107181)
pytorch#106524 was merged so quickly that we did not realize we should also hash the stride and dtype in DTensorSpec. This is a forward fix.

Here is an analysis of why hashing only the shape is not enough:
1. We use the hash value as the key of the sharding propagation cache, and the cached output sharding contains the stride and size of the output DTensor. If we don't include the stride in the hash, we will see errors.
2. One concrete example can be found below: the two OpSchemas differ only in stride but hash to the same key.
```
OpSchema(func_schema=aten::t(Tensor(a) self) -> Tensor(a), args_schema=(DTensorSpec(mesh=DeviceMesh:([0, 1, 2, 3, 4, 5, 6, 7]), placements=(Shard(dim=0),), tensor_meta=TensorMetadata(shape=torch.Size([64, 128]), dtype=torch.float32, requires_grad=False, stride=(128, 1), memory_format=None, is_quantized=False, qparams={})),), kwargs_schema={})
```

```
OpSchema(func_schema=aten::t(Tensor(a) self) -> Tensor(a), args_schema=(DTensorSpec(mesh=DeviceMesh:([0, 1, 2, 3, 4, 5, 6, 7]), placements=(Shard(dim=0),), tensor_meta=TensorMetadata(shape=torch.Size([64, 128]), dtype=torch.float32, requires_grad=False, stride=(1, 64), memory_format=None, is_quantized=False, qparams={})),), kwargs_schema={})
```

The only difference between the two op_schemas is the tensor stride:
<img width="151" alt="image" src="https://github.com/pytorch/pytorch/assets/6937752/161335df-bdfb-47c5-ba79-82616d070d15">

Because the two specs collide in the cache, the transpose op generates a wrong result, which leads to the add_/addmm_ op failing with errors like:
```
Traceback (most recent call last):
  File "/data/users/fduwjj/pytorch/torch/multiprocessing/spawn.py", line 74, in _wrap
    fn(i, *args)
  File "/data/users/fduwjj/pytorch/benchmarks/distributed/tensor/tp_benchmark.py", line 210, in run_tp
    output.sum().backward()
  File "/data/users/fduwjj/pytorch/torch/_tensor.py", line 491, in backward
    torch.autograd.backward(
  File "/data/users/fduwjj/pytorch/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/data/users/fduwjj/pytorch/torch/distributed/_tensor/api.py", line 252, in __torch_dispatch__
    return op_dispatch.operator_dispatch(
  File "/data/users/fduwjj/pytorch/torch/distributed/_tensor/dispatch.py", line 116, in operator_dispatch
    out, _, _ = _operator_dispatch(op_call, args, kwargs, sharding_propagator)
  File "/data/users/fduwjj/pytorch/torch/distributed/_tensor/dispatch.py", line 246, in _operator_dispatch
    local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
  File "/data/users/fduwjj/pytorch/torch/_ops.py", line 435, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: The size of tensor a (64) must match the size of tensor b (8) at non-singleton dimension 1
```

The same applies to dtype: if we use DTensor under mixed precision, two specs that differ only in dtype will collide in the cache and we will run into similar failures.
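
To make both failure modes concrete, here is a minimal, self-contained sketch. `FakeTensorMeta`, `FakeSpec`, and the dict-based cache are hypothetical stand-ins (not the real TensorMetadata/DTensorSpec API) that mimic the pre-fix behavior of hashing and comparing only mesh, placements and shape:

```python
from typing import NamedTuple, Tuple


class FakeTensorMeta(NamedTuple):
    # hypothetical stand-in for the TensorMetadata shown in the OpSchemas above
    shape: Tuple[int, ...]
    stride: Tuple[int, ...]
    dtype: str


class FakeSpec:
    """Hypothetical stand-in for DTensorSpec that, like the pre-fix code,
    hashes and compares only mesh, placements and shape."""

    def __init__(self, mesh, placements, meta: FakeTensorMeta):
        self.mesh, self.placements, self.meta = mesh, placements, meta

    def _key(self):
        return (self.mesh, self.placements, self.meta.shape)  # stride/dtype ignored

    def __hash__(self):
        return hash(self._key())

    def __eq__(self, other):
        return self._key() == other._key()


mesh = tuple(range(8))
row_major = FakeSpec(mesh, ("Shard(0)",), FakeTensorMeta((64, 128), (128, 1), "float32"))
transposed = FakeSpec(mesh, ("Shard(0)",), FakeTensorMeta((64, 128), (1, 64), "float32"))
bf16 = FakeSpec(mesh, ("Shard(0)",), FakeTensorMeta((64, 128), (128, 1), "bfloat16"))

# A propagation-style cache keyed on the spec hands back the entry computed for
# the row-major float32 layout, even when queried with a transposed or bf16 spec.
cache = {row_major: "output sharding computed for stride (128, 1), float32"}
print(cache[transposed])  # stale hit despite the different stride
print(cache[bf16])        # stale hit despite the different dtype
```

This mirrors the aten::t case above: the cached output sharding carries the stride of the other layout, and the later add_/addmm_ call fails on mismatched sizes.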

Pull Request resolved: pytorch#107181
Approved by: https://github.com/wanchaol
ghstack dependencies: pytorch#106524
fduwjj authored and pytorchmergebot committed Aug 15, 2023
1 parent 2d841bc commit d6c120d
Showing 1 changed file with 12 additions and 3 deletions.
torch/distributed/_tensor/placement_types.py:

```diff
@@ -378,11 +378,20 @@ class DTensorSpec:
 
     def __hash__(self) -> int:
         # hashing and equality check for DTensorSpec are used to cache the sharding
-        # propagation results. We only need to consider the mesh, placements and shape
+        # propagation results. We only need to consider the mesh, placements, shape
+        # dtype and stride.
         # Caveat: we need to keep this in mind and sync hash and eq if we add more
-        # fields to them,
+        # fields to them.
         if self.tensor_meta is not None:
-            return hash((self.mesh, self.placements, self.tensor_meta.shape))
+            return hash(
+                (
+                    self.mesh,
+                    self.placements,
+                    self.tensor_meta.shape,
+                    self.tensor_meta.dtype,
+                    self.tensor_meta.stride,
+                )
+            )
         else:
             return hash((self.mesh, self.placements))
 
```
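
Outside the diff itself, a hedged sketch of what the widened key buys (the helper `new_key` is hypothetical and only mirrors the tuple now being hashed): specs that differ only in stride or only in dtype no longer map to the same cache key.

```python
import torch

mesh = tuple(range(8))
placements = ("Shard(0)",)
shape = (64, 128)


def new_key(shape, dtype, stride):
    # mirrors the fixed __hash__: mesh, placements, shape, dtype and stride
    return hash((mesh, placements, shape, dtype, stride))


row_major = new_key(shape, torch.float32, (128, 1))
transposed = new_key(shape, torch.float32, (1, 64))
bf16 = new_key(shape, torch.bfloat16, (128, 1))

# Different stride or dtype now produces a different key, so the sharding
# propagation cache computes and stores a fresh result for each layout/dtype.
print(row_major == transposed)  # False (different stride)
print(row_major == bf16)        # False (different dtype)
```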