[optim] add fused_adam/adamw_kernel support for CPU device (pytorch#123074)

On par with the `CUDA` implementation.

For the `autocast` logic, this is the same as `CUDA` + `Fused Adam`:
 - check for `inf` in `GradScaler.step`
 - in the fused kernel, if there is an `inf`, do nothing; otherwise, unscale the grad (and write it back) and update the param (sketched below)
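A rough Python model of that per-tensor logic (plain Adam mode only, no amsgrad/maximize; written for illustration, not a line-for-line match of the C++ kernel):

```
import torch

def fused_adam_step_model(param, grad, exp_avg, exp_avg_sq, step,
                          lr, beta1, beta2, eps, weight_decay,
                          grad_scale=None, found_inf=None):
    # GradScaler.step already checked for inf/nan; if any was found, skip the update.
    if found_inf is not None and found_inf.item() == 1.0:
        return
    # Unscale the grad inside the kernel and write the unscaled value back.
    if grad_scale is not None:
        grad.div_(grad_scale)
    if weight_decay != 0:
        grad = grad.add(param, alpha=weight_decay)
    # Usual Adam moment updates and parameter step.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    denom = (exp_avg_sq / bias_correction2).sqrt().add_(eps)
    param.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)
```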

**Test Plan**:
```
# extend CUDA-only tests to cover CPU fused adam/adamw
python test_optim.py -k test_fused_matches_forloop
python test_optim.py -k test_fused_large_tensor
python test_torch.py -k test_grad_scaling_autocast_fused

# extend fused test
python test_torch.py -k test_params_invalidated_with_grads_invalidated_between_unscale_and_step
python test_optim.py -k test_can_load_older_state_dict

# newly added test (follow https://github.com/pytorch/pytorch/blob/6b1f13ea2f3b1bcd575620eecd7d84a4d2e3eb76/test/test_cuda.py#L1108)
python test_optim.py -k test_grad_scaling_autocast_fused_optimizers
```

**Benchmark**:
**5.1x** speedup on a 56-core SPR (Sapphire Rapids) machine
**Parameter size = 1M**
**Nparams = 10**
[test script](https://gist.github.com/zhuhaozhe/ef9a290ad3f8f4067b3373a3bdaa33e7)

```
numactl -C 0-55 -m 0 python bench_adam.py
non-fused 6.0174267292022705 s
fused 1.1787631511688232 s
```
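The gist above has the full script; a minimal sketch of an equivalent comparison (assuming a PyTorch build that includes this change, so `fused=True` is accepted for CPU tensors; sizes follow the setup above, absolute timings will vary):

```
import time
import torch

def bench(fused, n_params=10, numel=1_000_000, steps=100):
    params = [torch.randn(numel, requires_grad=True) for _ in range(n_params)]
    for p in params:
        p.grad = torch.randn_like(p)
    opt = torch.optim.Adam(params, lr=1e-3, fused=fused)
    opt.step()  # warm up
    start = time.time()
    for _ in range(steps):
        opt.step()
    return time.time() - start

print("non-fused", bench(fused=False), "s")
print("fused", bench(fused=True), "s")
```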

**Note: Fused kernel accuracy**
The accuracy failure in CI shows differences slightly above the default tolerance:
```
2024-04-02T06:09:16.2213887Z Mismatched elements: 21 / 64 (32.8%)
2024-04-02T06:09:16.2214339Z Greatest absolute difference: 1.5735626220703125e-05 at index (6, 6) (up to 1e-05 allowed)
2024-04-02T06:09:16.2214813Z Greatest relative difference: 1.0073336852656212e-05 at index (4, 1) (up to 1.3e-06 allowed)
```
I have debugged it step by step, and unfortunately we may not be able to make the `fused kernel` produce exactly the same results as the `non-fused` one due to compiler optimizations.
For example, in the non-fused impl:
```
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
```
and in the fused impl:
```
  exp_avg_sq_ptr[d] = scalar_t(beta2) * exp_avg_sq_ptr[d];
  //  std::cout << "exp_avg_sq " <<   exp_avg_sq_ptr[d] << std::endl;
  exp_avg_sq_ptr[d] = exp_avg_sq_ptr[d] +
      scalar_t(exp_avg_sq_grad_coefficient) * grad_val * grad_val;
```
If I keep the `std::cout`, I get exactly the same results in the UT:
```
===============param
0.6796758770942688
0.6796758770942688
```
But when I comment it out, there is a difference:
```
===============param
0.6796758770942688
0.6796759366989136
```
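For reference, the two printed values differ by exactly `2**-24`, i.e. one float32 ulp for values in `[0.5, 1)`, which is consistent with a change in floating-point evaluation order (e.g. FMA contraction with vs. without the intervening `std::cout`) rather than a logic bug; a quick check:

```
import torch

a = torch.tensor(0.6796758770942688, dtype=torch.float32)
b = torch.tensor(0.6796759366989136, dtype=torch.float32)
print((b - a).item())                  # 5.960464477539063e-08 == 2**-24
print(torch.finfo(torch.float32).eps)  # 1.1920928955078125e-07, the ulp at 1.0
```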
So I will make the tolerance a little higher than the default one.
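In test terms, that means comparing the fused and non-fused parameters with tolerances a bit looser than the float32 defaults (`rtol=1.3e-6`, `atol=1e-5`), for example (placeholder tensors and illustrative tolerance values, not necessarily what the final test uses):

```
import torch

# Mimic the CI report above: a max abs diff of ~1.5e-5 between the two paths.
nonfused_param = torch.full((8, 8), 0.5)
fused_param = nonfused_param + 1.5e-5

# torch.testing.assert_close(fused_param, nonfused_param)  # default tolerances would raise
torch.testing.assert_close(fused_param, nonfused_param, atol=5e-5, rtol=5e-5)
```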

Co-authored-by: Jane Xu <[email protected]>
Pull Request resolved: pytorch#123074
Approved by: https://github.com/jgong5, https://github.com/janeyx99
janeyx99 authored and pytorchmergebot committed Apr 19, 2024
1 parent 9a71d12 commit b412b75
Showing 10 changed files with 827 additions and 78 deletions.
175 changes: 175 additions & 0 deletions aten/src/ATen/native/FusedAdam.cpp
@@ -0,0 +1,175 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/FusedAdam.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fused_adam.h>
#include <ATen/ops/_fused_adam_native.h>
#include <ATen/ops/_fused_adamw.h>
#include <ATen/ops/_fused_adamw_native.h>
#endif
namespace at {

namespace native {

void _fused_adam_kernel_cpu_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const c10::optional<at::Tensor>& grad_scale,
    const c10::optional<at::Tensor>& found_inf) {
  const float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  const float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
  if (found_inf_ptr && *found_inf_ptr == 1.0) {
    return;
  }
  size_t n_tensors = params.size();
  TORCH_CHECK(grads.size() == n_tensors);
  TORCH_CHECK(exp_avgs.size() == n_tensors);
  TORCH_CHECK(exp_avg_sqs.size() == n_tensors);
  if (amsgrad) {
    TORCH_CHECK(max_exp_avg_sqs.size() == n_tensors);
  } else {
    TORCH_CHECK(max_exp_avg_sqs.size() == 0);
  }
  TORCH_CHECK(state_steps.size() == n_tensors);
  at::Tensor max_exp_avg_sq = at::Tensor();
  for (size_t i = 0; i < n_tensors; i++){
    if (amsgrad) max_exp_avg_sq = max_exp_avg_sqs[i];
    fused_adam_stub(
        kCPU,
        params[i],
        grads[i],
        exp_avgs[i],
        exp_avg_sqs[i],
        max_exp_avg_sq,
        state_steps[i],
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        amsgrad,
        maximize,
        grad_scale_ptr,
        ADAM_MODE::ORIGINAL);
  }
}

// The following overload simply has a Tensor lr
void _fused_adam_kernel_cpu_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const at::Tensor& lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const c10::optional<at::Tensor>& grad_scale,
    const c10::optional<at::Tensor>& found_inf) {
  _fused_adam_kernel_cpu_(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr.item<double>(), beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf);
}

void _fused_adamw_kernel_cpu_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const c10::optional<at::Tensor>& grad_scale,
    const c10::optional<at::Tensor>& found_inf) {
  const float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  const float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
  if (found_inf_ptr && *found_inf_ptr == 1.0) {
    return;
  }
  size_t n_tensors = params.size();
  TORCH_CHECK(grads.size() == n_tensors);
  TORCH_CHECK(exp_avgs.size() == n_tensors);
  TORCH_CHECK(exp_avg_sqs.size() == n_tensors);
  if (amsgrad) {
    TORCH_CHECK(max_exp_avg_sqs.size() == n_tensors);
  } else {
    TORCH_CHECK(max_exp_avg_sqs.size() == 0);
  }
  TORCH_CHECK(state_steps.size() == n_tensors);
  at::Tensor max_exp_avg_sq = at::Tensor();
  for (size_t i = 0; i < n_tensors; i++){
    if (amsgrad) max_exp_avg_sq = max_exp_avg_sqs[i];
    fused_adam_stub(
        kCPU,
        params[i],
        grads[i],
        exp_avgs[i],
        exp_avg_sqs[i],
        max_exp_avg_sq,
        state_steps[i],
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        amsgrad,
        maximize,
        grad_scale_ptr,
        ADAM_MODE::ADAMW);
  }
}

// The following overload simply has a Tensor lr
void _fused_adamw_kernel_cpu_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const at::Tensor& lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const c10::optional<at::Tensor>& grad_scale,
    const c10::optional<at::Tensor>& found_inf) {
  _fused_adamw_kernel_cpu_(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr.item<double>(), beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf);
}


DEFINE_DISPATCH(fused_adam_stub);

}
}
30 changes: 30 additions & 0 deletions aten/src/ATen/native/FusedAdam.h
@@ -0,0 +1,30 @@
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>

namespace at {

namespace native {

enum class ADAM_MODE : uint8_t { ORIGINAL = 0, ADAMW = 1 };

using fused_adam_fn = void (*)(
    const at::Tensor& param,
    const at::Tensor& grad,
    const at::Tensor& exp_avg,
    const at::Tensor& exp_avg_sq,
    const at::Tensor& max_exp_avg_sq,
    const at::Tensor& state_step,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const float* grad_scale_ptr,
    const ADAM_MODE);

DECLARE_DISPATCH(fused_adam_fn, fused_adam_stub);

}
}
