Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TP][DTensor Perf]Fix DTensor Spec hash (pytorch#107181)
pytorch#106524 gets merged so fast that we didn't figure out that we should hash both stride and dtype in DTensorSpec. This is a forward fix. One analysis for why using just shape is not enough. 1. We use the hash value for sharding propogation cache. And the output sharding contains the stride, size of the output DTensor. If we don't consider stride, we will see errors. 2. One reason can be found below: ``` OpSchema(func_schema=aten::t(Tensor(a) self) -> Tensor(a), args_schema=(DTensorSpec(mesh=DeviceMesh:([0, 1, 2, 3, 4, 5, 6, 7]), placements=(Shard(dim=0),), tensor_meta=TensorMetadata(shape=torch.Size([64, 128]), dtype=torch.float32, requires_grad=False, stride=(128, 1), memory_format=None, is_quantized=False, qparams={})),), kwargs_schema={}) ``` ``` OpSchema(func_schema=aten::t(Tensor(a) self) -> Tensor(a), args_schema=(DTensorSpec(mesh=DeviceMesh:([0, 1, 2, 3, 4, 5, 6, 7]), placements=(Shard(dim=0),), tensor_meta=TensorMetadata(shape=torch.Size([64, 128]), dtype=torch.float32, requires_grad=False, stride=(1, 64), memory_format=None, is_quantized=False, qparams={})),), kwargs_schema={}) ``` The only difference between two op_schame is the tensor stride: <img width="151" alt="image" src="https://github.com/pytorch/pytorch/assets/6937752/161335df-bdfb-47c5-ba79-82616d070d15"> that makes the transpose op generates wrong result and leads to the add_/addmm_ op failing with errors: ``` Traceback (most recent call last): File "/data/users/fduwjj/pytorch/torch/multiprocessing/spawn.py", line 74, in _wrap fn(i, *args) File "/data/users/fduwjj/pytorch/benchmarks/distributed/tensor/tp_benchmark.py", line 210, in run_tp output.sum().backward() File "/data/users/fduwjj/pytorch/torch/_tensor.py", line 491, in backward torch.autograd.backward( File "/data/users/fduwjj/pytorch/torch/autograd/__init__.py", line 251, in backward Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass File "/data/users/fduwjj/pytorch/torch/distributed/_tensor/api.py", line 252, in __torch_dispatch__ return op_dispatch.operator_dispatch( File "/data/users/fduwjj/pytorch/torch/distributed/_tensor/dispatch.py", line 116, in operator_dispatch out, _, _ = _operator_dispatch(op_call, args, kwargs, sharding_propagator) File "/data/users/fduwjj/pytorch/torch/distributed/_tensor/dispatch.py", line 246, in _operator_dispatch local_results = op_call(*local_tensor_args, **local_tensor_kwargs) File "/data/users/fduwjj/pytorch/torch/_ops.py", line 435, in __call__ return self._op(*args, **kwargs or {}) RuntimeError: The size of tensor a (64) must match the size of tensor b (8) at non-singleton dimension 1 ``` Same thing with dtype, if we are using DTensor in the environment of mixed precision, we will run into situations like this. Pull Request resolved: pytorch#107181 Approved by: https://github.com/wanchaol ghstack dependencies: pytorch#106524
- Loading branch information