[optim] add fused_adam/adamw_kernel support for CPU device (pytorch#123074)
On par with the `CUDA` implementation. For the `autocast` logic, same as `CUDA` + `Fused Adam`:
- Check for `inf` in `GradScaler.step`.
- In the fused kernel, if there is `inf`, do nothing. Otherwise, unscale the grad (also writing it back) and update the param.

**TestPlan**:
```
# extend CUDA-only tests to the CPU fused adam
python test_optim.py -k test_fused_matches_forloop
python test_optim.py -k test_fused_large_tensor
python test_torch.py -k test_grad_scaling_autocast_fused

# extend fused test
python test_torch.py -k test_params_invalidated_with_grads_invalidated_between_unscale_and_step
python test_optim.py -k test_can_load_older_state_dict

# newly added test (follows https://github.com/pytorch/pytorch/blob/6b1f13ea2f3b1bcd575620eecd7d84a4d2e3eb76/test/test_cuda.py#L1108)
python test_optim.py -k test_grad_scaling_autocast_fused_optimizers
```

**Benchmark**: **5.1x** speedup on a 56-core SPR machine, **parameter size = 1M**, **nparams = 10**, [test script](https://gist.github.com/zhuhaozhe/ef9a290ad3f8f4067b3373a3bdaa33e7):
```
numactl -C 0-55 -m 0 python bench_adam.py
non-fused 6.0174267292022705 s
fused 1.1787631511688232 s
```

**Note: fused kernel accuracy**

The accuracy failure in CI shows differences slightly above the default tolerance:
```
2024-04-02T06:09:16.2213887Z Mismatched elements: 21 / 64 (32.8%)
2024-04-02T06:09:16.2214339Z Greatest absolute difference: 1.5735626220703125e-05 at index (6, 6) (up to 1e-05 allowed)
2024-04-02T06:09:16.2214813Z Greatest relative difference: 1.0073336852656212e-05 at index (4, 1) (up to 1.3e-06 allowed)
```
I have debugged it step by step, and unfortunately we may not be able to make the fused kernel produce exactly the same results as the non-fused one, due to compiler optimizations. For example, the non-fused implementation does
```
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
```
while the fused implementation does
```
exp_avg_sq_ptr[d] = scalar_t(beta2) * exp_avg_sq_ptr[d];
// std::cout << "exp_avg_sq " << exp_avg_sq_ptr[d] << std::endl;
exp_avg_sq_ptr[d] = exp_avg_sq_ptr[d] + scalar_t(exp_avg_sq_grad_coefficient) * grad_val * grad_val;
```
If I keep the `std::cout`, I get exactly the same results in the unit test:
```
===============param
0.6796758770942688
0.6796758770942688
```
But when I comment it out, there is a difference:
```
===============param
0.6796758770942688
0.6796759366989136
```
So I set the tolerance slightly higher than the default.

Co-authored-by: Jane Xu <[email protected]>

Pull Request resolved: pytorch#123074
Approved by: https://github.com/jgong5, https://github.com/janeyx99
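For context, a minimal usage sketch of what this change enables, assuming a PyTorch build that includes it: the existing `fused=True` flag on `torch.optim.Adam`/`AdamW` (previously CUDA-only) can now be passed with parameters living on the CPU device.

```python
import torch

# Minimal sketch (assumes a PyTorch build that includes this change):
# the same `fused=True` flag used to select the fused CUDA optimizer kernel
# now also accepts CPU parameters and dispatches to the fused CPU kernel.
model = torch.nn.Linear(1024, 1024)  # parameters live on CPU
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, fused=True)

x = torch.randn(64, 1024)
loss = model(x).pow(2).mean()
loss.backward()
opt.step()       # one fused kernel per parameter group instead of many per-tensor ops
opt.zero_grad()
```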
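To make the autocast/grad-scaling contract above concrete, here is a hedged pure-Python model of the per-tensor behavior; the function name `fused_adam_step_reference`, its signature, and the in-place `isfinite` check are illustrative only and are not the kernel's actual API. If the gradient contains `inf`/`nan`, the update is skipped entirely; otherwise the gradient is unscaled in place (written back) and a standard Adam step is applied.

```python
import torch

def fused_adam_step_reference(param, grad, exp_avg, exp_avg_sq, step,
                              lr=1e-3, beta1=0.9, beta2=0.999,
                              eps=1e-8, inv_scale=1.0):
    """Hypothetical per-tensor reference for the behavior described above."""
    # GradScaler.step found inf/nan for this gradient: do nothing.
    if not torch.isfinite(grad).all():
        return
    # Unscale the gradient and write it back, as the fused kernel does.
    grad.mul_(inv_scale)
    # Standard Adam moment updates and parameter step.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    denom = (exp_avg_sq / bias_correction2).sqrt().add_(eps)
    param.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)
```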