From b412b75b42b27918c7cccd7a4f1121c8da14cb71 Mon Sep 17 00:00:00 2001
From: Jane Xu
Date: Fri, 19 Apr 2024 09:54:05 +0000
Subject: [PATCH] [optim] add fused_adam/adamw_kernel support for CPU device (#123074)

On par with the `CUDA` implementation. The `autocast` logic is the same as for `CUDA` + `Fused Adam`:
- check for `inf` in `GradScaler.step`
- in the fused kernel, if there is an `inf`, do nothing; otherwise, unscale the grad (and write it back) and update the param.

**Test Plan**:
```
# extend CUDA-only tests to CPU fused Adam
python test_optim.py -k test_fused_matches_forloop
python test_optim.py -k test_fused_large_tensor
python test_torch.py -k test_grad_scaling_autocast_fused

# extend fused tests
python test_torch.py -k test_params_invalidated_with_grads_invalidated_between_unscale_and_step
python test_optim.py -k test_can_load_older_state_dict

# newly added test (follows https://github.com/pytorch/pytorch/blob/6b1f13ea2f3b1bcd575620eecd7d84a4d2e3eb76/test/test_cuda.py#L1108)
python test_optim.py -k test_grad_scaling_autocast_fused_optimizers
```

**Benchmark**: **5.1x** speedup on a 56-core SPR. **Parameter size = 1M**, **Nparams = 10**.
[test script](https://gist.github.com/zhuhaozhe/ef9a290ad3f8f4067b3373a3bdaa33e7)
```
numactl -C 0-55 -m 0 python bench_adam.py
non-fused 6.0174267292022705 s
fused 1.1787631511688232 s
```

**Note: fused kernel accuracy**

The accuracy failure in CI shows a difference slightly above the default tolerance:
```
2024-04-02T06:09:16.2213887Z Mismatched elements: 21 / 64 (32.8%)
2024-04-02T06:09:16.2214339Z Greatest absolute difference: 1.5735626220703125e-05 at index (6, 6) (up to 1e-05 allowed)
2024-04-02T06:09:16.2214813Z Greatest relative difference: 1.0073336852656212e-05 at index (4, 1) (up to 1.3e-06 allowed)
```
I have debugged this step by step, and unfortunately we may not be able to make the `fused kernel` match the `non-fused` one exactly due to compiler optimizations. For example, the non-fused impl does
```
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
```
while the fused impl does
```
exp_avg_sq_ptr[d] = scalar_t(beta2) * exp_avg_sq_ptr[d];
// std::cout << "exp_avg_sq " << exp_avg_sq_ptr[d] << std::endl;
exp_avg_sq_ptr[d] = exp_avg_sq_ptr[d] + scalar_t(exp_avg_sq_grad_coefficient) * grad_val * grad_val;
```
If I keep the `std::cout`, I get exactly the same results in the UT:
```
===============param
0.6796758770942688
0.6796758770942688
```
But when I comment it out, there is a difference:
```
===============param
0.6796758770942688
0.6796759366989136
```
So I will make the tolerance a little higher than the default one.
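To make the autocast flow above concrete, here is a minimal sketch of how the new CPU fused path is driven end-to-end. The model shape, `lr`, `init_scale`, and iteration count are arbitrary illustrative choices; the `torch.autocast('cpu', ...)` and `torch.cpu.amp.GradScaler` calls mirror the ones used in the new `test_grad_scaling_autocast_fused_optimizers` test below.
```python
# Illustrative sketch: CPU fused AdamW exercised through autocast + GradScaler.
import torch

model = torch.nn.Linear(8, 8)  # arbitrary small model
opt = torch.optim.AdamW(model.parameters(), lr=0.1, fused=True)  # new CPU fused path
scaler = torch.cpu.amp.GradScaler(init_scale=128.0)

for _ in range(3):
    inp = torch.randn(4, 8)
    opt.zero_grad()
    with torch.autocast('cpu', dtype=torch.half):  # fp16 autocast on CPU
        loss = model(inp).sum()
    scaler.scale(loss).backward()
    # scaler.step passes grad_scale/found_inf to the fused optimizer; the fused
    # kernel skips the update if an inf was found, otherwise it unscales the
    # grads in-place (writing them back) before updating the params.
    scaler.step(opt)
    scaler.update()
```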
Co-authored-by: Jane Xu Pull Request resolved: https://github.com/pytorch/pytorch/pull/123074 Approved by: https://github.com/jgong5, https://github.com/janeyx99 --- aten/src/ATen/native/FusedAdam.cpp | 175 +++++++++ aten/src/ATen/native/FusedAdam.h | 30 ++ aten/src/ATen/native/cpu/FusedAdamKernel.cpp | 379 +++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 4 + build_variables.bzl | 2 + test/test_optim.py | 209 +++++++--- test/test_torch.py | 7 +- torch/optim/adam.py | 2 +- torch/optim/adamw.py | 2 +- torch/testing/_internal/common_optimizers.py | 95 +++-- 10 files changed, 827 insertions(+), 78 deletions(-) create mode 100644 aten/src/ATen/native/FusedAdam.cpp create mode 100644 aten/src/ATen/native/FusedAdam.h create mode 100644 aten/src/ATen/native/cpu/FusedAdamKernel.cpp diff --git a/aten/src/ATen/native/FusedAdam.cpp b/aten/src/ATen/native/FusedAdam.cpp new file mode 100644 index 00000000000000..b3be769b24f185 --- /dev/null +++ b/aten/src/ATen/native/FusedAdam.cpp @@ -0,0 +1,175 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#endif +namespace at { + +namespace native { + +void _fused_adam_kernel_cpu_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool amsgrad, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf) { + const float* grad_scale_ptr = + grad_scale.has_value() ? grad_scale->data_ptr() : nullptr; + const float* found_inf_ptr = + found_inf.has_value() ? 
found_inf->data_ptr() : nullptr; + if (found_inf_ptr && *found_inf_ptr == 1.0) { + return; + } + size_t n_tensors = params.size(); + TORCH_CHECK(grads.size() == n_tensors); + TORCH_CHECK(exp_avgs.size() == n_tensors); + TORCH_CHECK(exp_avg_sqs.size() == n_tensors); + if (amsgrad) { + TORCH_CHECK(max_exp_avg_sqs.size() == n_tensors); + } else { + TORCH_CHECK(max_exp_avg_sqs.size() == 0); + } + TORCH_CHECK(state_steps.size() == n_tensors); + at::Tensor max_exp_avg_sq = at::Tensor(); + for (size_t i = 0; i < n_tensors; i++){ + if (amsgrad) max_exp_avg_sq = max_exp_avg_sqs[i]; + fused_adam_stub( + kCPU, + params[i], + grads[i], + exp_avgs[i], + exp_avg_sqs[i], + max_exp_avg_sq, + state_steps[i], + lr, + beta1, + beta2, + weight_decay, + eps, + amsgrad, + maximize, + grad_scale_ptr, + ADAM_MODE::ORIGINAL); + } +} + +// The following overload simply has a Tensor lr +void _fused_adam_kernel_cpu_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const at::Tensor& lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool amsgrad, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf) { + _fused_adam_kernel_cpu_(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr.item(), beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); +} + +void _fused_adamw_kernel_cpu_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool amsgrad, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf) { + const float* grad_scale_ptr = + grad_scale.has_value() ? grad_scale->data_ptr() : nullptr; + const float* found_inf_ptr = + found_inf.has_value() ? 
found_inf->data_ptr() : nullptr; + if (found_inf_ptr && *found_inf_ptr == 1.0) { + return; + } + size_t n_tensors = params.size(); + TORCH_CHECK(grads.size() == n_tensors); + TORCH_CHECK(exp_avgs.size() == n_tensors); + TORCH_CHECK(exp_avg_sqs.size() == n_tensors); + if (amsgrad) { + TORCH_CHECK(max_exp_avg_sqs.size() == n_tensors); + } else { + TORCH_CHECK(max_exp_avg_sqs.size() == 0); + } + TORCH_CHECK(state_steps.size() == n_tensors); + at::Tensor max_exp_avg_sq = at::Tensor(); + for (size_t i = 0; i < n_tensors; i++){ + if (amsgrad) max_exp_avg_sq = max_exp_avg_sqs[i]; + fused_adam_stub( + kCPU, + params[i], + grads[i], + exp_avgs[i], + exp_avg_sqs[i], + max_exp_avg_sq, + state_steps[i], + lr, + beta1, + beta2, + weight_decay, + eps, + amsgrad, + maximize, + grad_scale_ptr, + ADAM_MODE::ADAMW); + } +} + +// The following overload simply has a Tensor lr +void _fused_adamw_kernel_cpu_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const at::Tensor& lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool amsgrad, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf) { + _fused_adamw_kernel_cpu_(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr.item(), beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); +} + + +DEFINE_DISPATCH(fused_adam_stub); + +} +} diff --git a/aten/src/ATen/native/FusedAdam.h b/aten/src/ATen/native/FusedAdam.h new file mode 100644 index 00000000000000..6fbbaf2441297e --- /dev/null +++ b/aten/src/ATen/native/FusedAdam.h @@ -0,0 +1,30 @@ +#include +#include + +namespace at { + +namespace native { + +enum class ADAM_MODE : uint8_t { ORIGINAL = 0, ADAMW = 1 }; + +using fused_adam_fn = void (*)( + const at::Tensor& param, + const at::Tensor& grad, + const at::Tensor& exp_avg, + const at::Tensor& exp_avg_sq, + const at::Tensor& max_exp_avg_sq, + const at::Tensor& state_step, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool amsgrad, + const bool maximize, + const float* grad_scale_ptr, + const ADAM_MODE); + +DECLARE_DISPATCH(fused_adam_fn, fused_adam_stub); + +} +} diff --git a/aten/src/ATen/native/cpu/FusedAdamKernel.cpp b/aten/src/ATen/native/cpu/FusedAdamKernel.cpp new file mode 100644 index 00000000000000..4a10fe202c4a0e --- /dev/null +++ b/aten/src/ATen/native/cpu/FusedAdamKernel.cpp @@ -0,0 +1,379 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +namespace at::native { + +namespace{ + +template +typename std::enable_if< + std::is_same::value || std::is_same::value, + void>:: + type inline adam_math( + scalar_t* param_ptr, + scalar_t* exp_avg_ptr, + scalar_t* exp_avg_sq_ptr, + scalar_t* grad_ptr, + scalar_t* max_exp_avg_sq_ptr, + double lr, + double bias_correction1, + double bias_correction2, + double exp_avg_grad_coefficient, + double exp_avg_sq_grad_coefficient, + double bias_correction2_sqrt, + double eps, + double weight_decay, + double beta2, + bool amsgrad, + bool maximize, + const float* grad_scale_ptr, + int64_t size +){ + double step_size = lr / bias_correction1; + using lpVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + lpVec grad_vec_to_store; + int64_t d = 0; + fVec param_vec1, param_vec2; + fVec grad_vec1, grad_vec2; + fVec 
exp_avg_vec1, exp_avg_vec2; + fVec exp_avg_sq_vec1, exp_avg_sq_vec2; + fVec max_exp_avg_sq_vec1, max_exp_avg_sq_vec2; + for (; d < size - (size % lpVec::size()); d += lpVec::size()) { + lpVec param_lpvec = lpVec::loadu(param_ptr + d); + std::tie(param_vec1, param_vec2) = vec::convert_to_float(param_lpvec); + lpVec grad_lpvec = lpVec::loadu(grad_ptr + d); + std::tie(grad_vec1, grad_vec2) = vec::convert_to_float(grad_lpvec); + if (grad_scale_ptr) { + grad_vec1 = grad_vec1 / fVec(float(*grad_scale_ptr)); + grad_vec2 = grad_vec2 / fVec(float(*grad_scale_ptr)); + grad_vec_to_store = vec::convert_from_float(grad_vec1, grad_vec2); + grad_vec_to_store.store(grad_ptr + d); + } + if (maximize){ + grad_vec1 = grad_vec1 * fVec(opmath_t(-1.0)); + grad_vec2 = grad_vec2 * fVec(opmath_t(-1.0)); + } + if (weight_decay != 0.f){ + if constexpr (adam_mode == ADAM_MODE::ORIGINAL) { + grad_vec1 += param_vec1 * fVec(opmath_t(weight_decay)); + grad_vec2 += param_vec2 * fVec(opmath_t(weight_decay)); + } else if constexpr (adam_mode == ADAM_MODE::ADAMW) { + param_vec1 = param_vec1 * fVec(opmath_t(1 - lr * weight_decay)); + param_vec2 = param_vec2 * fVec(opmath_t(1 - lr * weight_decay)); + } + } + + lpVec exp_avg_lpvec = lpVec::loadu(exp_avg_ptr + d); + std::tie(exp_avg_vec1, exp_avg_vec2) = vec::convert_to_float(exp_avg_lpvec); + + // exp_avg.lerp_(grad, 1 - beta1) + const fVec lerp_weight = fVec(opmath_t(exp_avg_grad_coefficient)); + auto mask = lerp_weight.abs() < fVec(0.5); + auto coeff = fVec::blendv(lerp_weight - fVec(1), lerp_weight, mask); + + auto base1 = fVec::blendv(grad_vec1, exp_avg_vec1, mask); + exp_avg_vec1 = vec::fmadd(coeff, grad_vec1 - exp_avg_vec1, base1); + + auto base2 = fVec::blendv(grad_vec2, exp_avg_vec2, mask); + exp_avg_vec2 = vec::fmadd(coeff, grad_vec2 - exp_avg_vec2, base2); + + lpVec exp_avg_sq_lpvec = lpVec::loadu(exp_avg_sq_ptr + d); + std::tie(exp_avg_sq_vec1, exp_avg_sq_vec2) = vec::convert_to_float(exp_avg_sq_lpvec); + exp_avg_sq_vec1 = exp_avg_sq_vec1 * fVec(opmath_t(beta2)) + + fVec(opmath_t(exp_avg_sq_grad_coefficient)) * grad_vec1 * grad_vec1; + exp_avg_sq_vec2 = exp_avg_sq_vec2 * fVec(opmath_t(beta2)) + + fVec(opmath_t(exp_avg_sq_grad_coefficient)) * grad_vec2 * grad_vec2; + + vec::convert_from_float(exp_avg_vec1, exp_avg_vec2).store(exp_avg_ptr + d); + vec::convert_from_float(exp_avg_sq_vec1, exp_avg_sq_vec2).store(exp_avg_sq_ptr + d); + + fVec denom_vec1, denom_vec2; + if (amsgrad) { + lpVec max_exp_avg_sq_lpvec = lpVec::loadu(max_exp_avg_sq_ptr + d); + std::tie(max_exp_avg_sq_vec1, max_exp_avg_sq_vec2) = vec::convert_to_float(max_exp_avg_sq_lpvec); + max_exp_avg_sq_vec1 = maximum(max_exp_avg_sq_vec1, exp_avg_sq_vec1); + max_exp_avg_sq_vec2 = maximum(max_exp_avg_sq_vec2, exp_avg_sq_vec2); + vec::convert_from_float(max_exp_avg_sq_vec1, max_exp_avg_sq_vec2).store(max_exp_avg_sq_ptr + d); + denom_vec1 = + (max_exp_avg_sq_vec1.sqrt() / fVec(opmath_t(bias_correction2_sqrt))) + fVec(opmath_t(eps)); + denom_vec2 = + (max_exp_avg_sq_vec2.sqrt() / fVec(opmath_t(bias_correction2_sqrt))) + fVec(opmath_t(eps)); + } else { + denom_vec1 = + (exp_avg_sq_vec1.sqrt() / fVec(opmath_t(bias_correction2_sqrt))) + fVec(opmath_t(eps)); + denom_vec2 = + (exp_avg_sq_vec2.sqrt() / fVec(opmath_t(bias_correction2_sqrt))) + fVec(opmath_t(eps)); + } + param_vec1 = param_vec1 + fVec(opmath_t(-step_size)) * exp_avg_vec1 / denom_vec1; + param_vec2 = param_vec2 + fVec(opmath_t(-step_size)) * exp_avg_vec2 / denom_vec2; + vec::convert_from_float(param_vec1, param_vec2).store(param_ptr + d); + } + scalar_t 
grad_val_to_store; + for (; d < size; d++) { + opmath_t grad_val = grad_ptr[d]; + opmath_t param_val = param_ptr[d]; + if (grad_scale_ptr) { + grad_val = grad_ptr[d] / float(*grad_scale_ptr); + grad_val_to_store = scalar_t(grad_val); + grad_ptr[d] = grad_val_to_store; + } + if (maximize) grad_val = -grad_val; + if (weight_decay != 0.f){ + if constexpr (adam_mode == ADAM_MODE::ORIGINAL) { + grad_val += param_val * opmath_t(weight_decay); + } else if constexpr (adam_mode == ADAM_MODE::ADAMW) { + param_val = param_val * opmath_t(1 - lr * weight_decay); + } + } + // exp_avg.lerp_(grad, 1 - beta1) + opmath_t exp_avg_var = exp_avg_ptr[d]; + auto is_lerp_weight_small = std::abs(opmath_t(exp_avg_grad_coefficient)) < opmath_t(0.5); + if (is_lerp_weight_small) { + exp_avg_var = exp_avg_var + opmath_t(exp_avg_grad_coefficient) * (grad_val - exp_avg_var); + } else { + exp_avg_var = grad_val - (grad_val - exp_avg_var) * (opmath_t(1) - opmath_t(exp_avg_grad_coefficient)); + } + exp_avg_ptr[d] = scalar_t(exp_avg_var); + opmath_t exp_avg_sq_var = exp_avg_sq_ptr[d]; + exp_avg_sq_var = exp_avg_sq_var * opmath_t(beta2); + exp_avg_sq_var = exp_avg_sq_var + + opmath_t(exp_avg_sq_grad_coefficient) * grad_val * grad_val; + exp_avg_sq_ptr[d] = scalar_t(exp_avg_sq_var); + opmath_t demon_val; + if (amsgrad) { + opmath_t max_exp_avg_sq_var = max_exp_avg_sq_ptr[d]; + max_exp_avg_sq_var = std::max(max_exp_avg_sq_var, exp_avg_sq_var); + max_exp_avg_sq_ptr[d] = + scalar_t(max_exp_avg_sq_var); + demon_val = + std::sqrt(max_exp_avg_sq_var) / opmath_t(bias_correction2_sqrt) + opmath_t(eps); + } else { + demon_val = std::sqrt(exp_avg_sq_var) / opmath_t(bias_correction2_sqrt) + opmath_t(eps); + } + param_ptr[d] = param_val - opmath_t(step_size) * exp_avg_var / demon_val; + } +} + + +template +typename std::enable_if< + std::is_same::value || std::is_same::value, + void>:: + type inline adam_math( + scalar_t* param_ptr, + scalar_t* exp_avg_ptr, + scalar_t* exp_avg_sq_ptr, + scalar_t* grad_ptr, + scalar_t* max_exp_avg_sq_ptr, + double lr, + double bias_correction1, + double bias_correction2, + double exp_avg_grad_coefficient, + double exp_avg_sq_grad_coefficient, + double bias_correction2_sqrt, + double eps, + double weight_decay, + double beta2, + bool amsgrad, + bool maximize, + const float* grad_scale_ptr, + int64_t size +){ + double step_size = lr / bias_correction1; + using Vec = at::vec::Vectorized; + Vec grad_vec_to_store; + int64_t d = 0; + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec param_vec = Vec::loadu(param_ptr + d); + Vec grad_vec = Vec::loadu(grad_ptr + d); + if (grad_scale_ptr) { + grad_vec = grad_vec / Vec(scalar_t(*grad_scale_ptr)); + grad_vec_to_store = grad_vec; + grad_vec_to_store.store(grad_ptr + d); + } + if (maximize) grad_vec = grad_vec * Vec(scalar_t(-1.0)); + if (weight_decay != 0.f){ + if constexpr (adam_mode == ADAM_MODE::ORIGINAL) { + grad_vec += param_vec * Vec(scalar_t(weight_decay)); + } else if constexpr (adam_mode == ADAM_MODE::ADAMW) { + param_vec = param_vec * Vec(scalar_t(1 - lr * weight_decay)); + } + } + Vec exp_avg_vec = Vec::loadu(exp_avg_ptr + d); + // exp_avg.lerp_(grad, 1 - beta1) + const Vec lerp_weight = Vec(scalar_t(exp_avg_grad_coefficient)); + auto mask = lerp_weight.abs() < Vec(0.5); + auto coeff = Vec::blendv(lerp_weight - Vec(1), lerp_weight, mask); + auto base = Vec::blendv(grad_vec, exp_avg_vec, mask); + exp_avg_vec = vec::fmadd(coeff, grad_vec - exp_avg_vec, base); + + Vec exp_avg_sq_vec = Vec::loadu(exp_avg_sq_ptr + d) * Vec(scalar_t(beta2)) + + 
Vec(scalar_t(exp_avg_sq_grad_coefficient)) * grad_vec * grad_vec; + exp_avg_vec.store(exp_avg_ptr + d); + exp_avg_sq_vec.store(exp_avg_sq_ptr + d); + + Vec denom_vec; + if (amsgrad) { + Vec max_exp_avg_sq_vec = + maximum(Vec::loadu(max_exp_avg_sq_ptr + d), exp_avg_sq_vec); + max_exp_avg_sq_vec.store(max_exp_avg_sq_ptr + d); + denom_vec = + (max_exp_avg_sq_vec.sqrt() / Vec(scalar_t(bias_correction2_sqrt))) + Vec(scalar_t(eps)); + } else { + denom_vec = + (exp_avg_sq_vec.sqrt() / Vec(scalar_t(bias_correction2_sqrt))) + Vec(scalar_t(eps)); + } + param_vec = param_vec + Vec(scalar_t(-step_size)) * exp_avg_vec / denom_vec; + param_vec.store(param_ptr + d); + } + scalar_t grad_val_to_store; + for (; d < size; d++) { + scalar_t grad_val = grad_ptr[d]; + if (grad_scale_ptr) { + grad_val = grad_ptr[d] / scalar_t(*grad_scale_ptr); + grad_val_to_store = grad_val; + grad_ptr[d] = grad_val_to_store; + } + if (maximize) grad_val = -grad_val; + if (weight_decay != 0.f){ + if constexpr (adam_mode == ADAM_MODE::ORIGINAL) { + grad_val += param_ptr[d] * scalar_t(weight_decay); + } else if constexpr (adam_mode == ADAM_MODE::ADAMW) { + param_ptr[d] = param_ptr[d] * scalar_t(1 - lr * weight_decay); + } + } + // exp_avg.lerp_(grad, 1 - beta1) + auto is_lerp_weight_small = std::abs(scalar_t(exp_avg_grad_coefficient)) < scalar_t(0.5); + if (is_lerp_weight_small) { + exp_avg_ptr[d] = exp_avg_ptr[d] + scalar_t(exp_avg_grad_coefficient) * (grad_val - exp_avg_ptr[d]); + } else { + exp_avg_ptr[d] = grad_val - (grad_val - exp_avg_ptr[d]) * (scalar_t(1) - scalar_t(exp_avg_grad_coefficient)); + } + exp_avg_sq_ptr[d] = exp_avg_sq_ptr[d] * scalar_t(beta2); + exp_avg_sq_ptr[d] = exp_avg_sq_ptr[d] + + scalar_t(exp_avg_sq_grad_coefficient) * grad_val * grad_val; + scalar_t demon_val; + if (amsgrad) { + max_exp_avg_sq_ptr[d] = + std::max(max_exp_avg_sq_ptr[d], exp_avg_sq_ptr[d]); + demon_val = + std::sqrt(max_exp_avg_sq_ptr[d]) / scalar_t(bias_correction2_sqrt) + scalar_t(eps); + } else { + demon_val = std::sqrt(exp_avg_sq_ptr[d]) / scalar_t(bias_correction2_sqrt) + scalar_t(eps); + } + param_ptr[d] = param_ptr[d] - scalar_t(step_size) * exp_avg_ptr[d] / demon_val; + } +} + + +template +void adam_fused_step_impl( + const at::Tensor& param, + const at::Tensor& grad, + const at::Tensor& exp_avg, + const at::Tensor& exp_avg_sq, + const at::Tensor& max_exp_avg_sq, + const at::Tensor& state_step, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool amsgrad, + const bool maximize, + const float* grad_scale_ptr) { + using opmath_t = at::opmath_type; + double step = state_step.item(); + scalar_t* param_data = param.data_ptr(); + scalar_t* exp_avg_data = exp_avg.data_ptr(); + scalar_t* exp_avg_sq_data = exp_avg_sq.data_ptr(); + scalar_t* max_exp_avg_sq_data = amsgrad ? 
max_exp_avg_sq.data_ptr() : nullptr; + scalar_t* grad_data = grad.data_ptr(); + + // need to use double here to align with non-fused adam + double bias_correction1 = 1 - std::pow(beta1, step); + double bias_correction2 = 1 - std::pow(beta2, step); + double exp_avg_grad_coefficient = 1 - beta1; + double exp_avg_sq_grad_coefficient = 1 - beta2; + double bias_correction2_sqrt = std::sqrt(bias_correction2); + + + constexpr size_t cache_line_size = 64; + constexpr int64_t cache_line_aligned_task_unit = cache_line_size / sizeof(scalar_t); + size_t num_units = divup(param.numel(), cache_line_aligned_task_unit); + + auto adam_fn = [&](int64_t begin, int64_t end) { + // local pointers + begin *= cache_line_aligned_task_unit; + end = std::min(end * cache_line_aligned_task_unit, param.numel()); + scalar_t* param_ptr = param_data + begin; + scalar_t* exp_avg_ptr = exp_avg_data + begin; + scalar_t* exp_avg_sq_ptr = exp_avg_sq_data + begin; + scalar_t* grad_ptr = grad_data + begin; + scalar_t* max_exp_avg_sq_ptr = amsgrad ? max_exp_avg_sq_data + begin : nullptr; + + const int64_t size = end - begin; + adam_math( + param_ptr, + exp_avg_ptr, + exp_avg_sq_ptr, + grad_ptr, + max_exp_avg_sq_ptr, + lr, + bias_correction1, + bias_correction2, + exp_avg_grad_coefficient, + exp_avg_sq_grad_coefficient, + bias_correction2_sqrt, + eps, + weight_decay, + beta2, + amsgrad, + maximize, + grad_scale_ptr, + size + ); + }; + at::parallel_for( + 0, num_units, 0, adam_fn); +} + +void fused_adam_kernel( + const at::Tensor& param, + const at::Tensor& grad, + const at::Tensor& exp_avg, + const at::Tensor& exp_avg_sq, + const at::Tensor& max_exp_avg_sq, + const at::Tensor& state_step, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool amsgrad, + const bool maximize, + const float* grad_scale_ptr, + const ADAM_MODE adam_mode + ) { + Tensor grad_contiguous = grad.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, param.scalar_type(), "fused_adam_kernel", [&] { + if(adam_mode == ADAM_MODE::ORIGINAL){ + adam_fused_step_impl(param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq, state_step, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_ptr); + } else { + adam_fused_step_impl(param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq, state_step, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_ptr); + } + + }); +} + +} + +REGISTER_DISPATCH(fused_adam_stub, &fused_adam_kernel); +} // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 6e96a8a6aabc2c..f3f56833503b71 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -15517,6 +15517,7 @@ # Unlike "foreach" functions, lists of tensors should be guaranteed to be on the same device (for now). variants: function dispatch: + CPU: _fused_adam_kernel_cpu_ CUDA: _fused_adam_kernel_cuda_ autogen: _fused_adam, _fused_adam.out @@ -15526,6 +15527,7 @@ device_check: NoCheck variants: function dispatch: + CPU: _fused_adam_kernel_cpu_ CUDA: _fused_adam_kernel_cuda_ autogen: _fused_adam.tensor_lr, _fused_adam.tensor_lr_out @@ -15533,6 +15535,7 @@ # Unlike "foreach" functions, lists of tensors should be guaranteed to be on the same device (for now). 
variants: function dispatch: + CPU: _fused_adamw_kernel_cpu_ CUDA: _fused_adamw_kernel_cuda_ autogen: _fused_adamw, _fused_adamw.out @@ -15542,6 +15545,7 @@ device_check: NoCheck variants: function dispatch: + CPU: _fused_adamw_kernel_cpu_ CUDA: _fused_adamw_kernel_cuda_ autogen: _fused_adamw.tensor_lr, _fused_adamw.tensor_lr_out diff --git a/build_variables.bzl b/build_variables.bzl index 6a152fb9099b53..36e54ffda40ff1 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -1168,6 +1168,7 @@ aten_native_source_codegen_list = [ "aten/src/ATen/native/cpu/SpmmReduceKernel.cpp", "aten/src/ATen/native/cpu/SparseFactories.cpp", "aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp", + "aten/src/ATen/native/cpu/FusedAdamKernel.cpp", ] # This aten native source file list will not go through aten codegen process @@ -1402,6 +1403,7 @@ aten_native_source_non_codegen_list = [ "aten/src/ATen/native/xnnpack/OpContext.cpp", "aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp", "aten/src/ATen/native/xnnpack/Shim.cpp", + "aten/src/ATen/native/FusedAdam.cpp", # Files not in native, but depends on native symbols # "aten/src/ATen/TensorIndexing.cpp", "aten/src/ATen/TensorIterator.cpp", diff --git a/test/test_optim.py b/test/test_optim.py index 680d967a26d822..9eea11ffda14a5 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -21,9 +21,10 @@ from torch.testing._internal.common_optimizers import ( optim_db, optims, OptimizerErrorEnum, _get_optim_inputs_including_global_cliquey_kwargs, TensorTracker) from torch.testing._internal.common_device_type import ( - instantiate_device_type_tests, largeTensorTest, onlyCPU, onlyCUDA, skipMPS, TEST_WITH_ROCM) + instantiate_device_type_tests, largeTensorTest, onlyCPU, onlyCUDA, skipMPS, TEST_WITH_ROCM, onlyNativeDeviceTypes) from torch.testing._internal.common_utils import markDynamoStrictTest, parametrize, run_tests, TestCase - +from torch.testing._internal.common_cuda import _create_scaling_case +from torch.testing._internal.common_dtype import floating_types_and FP16_REDUCED_PRECISION = {'atol': 1e-5, 'rtol': 1e-4} @@ -581,6 +582,49 @@ def closure2(): self.assertTrue(a1_grad_imags.all_popped()) self.assertTrue(losses.all_popped()) + def _compare_between(self, inputs, models, optimizers, assert_eq_kwargs=None, assert_step_dtype=None): + # why 7? iteration 7 is where we start to see differences for RAdam + # params interacting with the small eps value, because that's right + # after rho_t becomes greater than 5 in step 6. 
+ if assert_eq_kwargs is None: + assert_eq_kwargs = {} + kIterations = 7 + tracker = TensorTracker(assert_eq_kwargs) + for i in range(kIterations): + state, updated_params = [], [] + if not isinstance(inputs, list): + inputs = [inputs, inputs] + for input, model, optimizer in zip(inputs, models, optimizers): + optimizer.zero_grad() + + # Test that step behaves as expected (a no-op) when grads are set to None + if i != 3: + output = model(input) + loss = output.sum() + loss.backward() + + optimizer.step() + state.append(optimizer.state) + updated_params.append(model.parameters()) + + og_state, new_state = state + for og_p, new_p in zip(updated_params[0], updated_params[1]): + tracker.add(og_p) + tracker.pop_check_set(new_p, self) + + # check that optimizer states are the same + og_p_state = og_state[og_p] + new_p_state = new_state[new_p] + if assert_step_dtype is not None: + if torch.is_tensor(og_p_state.get("step", None)): + self.assertEqual(og_p_state["step"].dtype, assert_step_dtype) + if torch.is_tensor(new_p_state.get("step", None)): + self.assertEqual(new_p_state["step"].dtype, assert_step_dtype) + for k in og_p_state: + tracker.add(og_p_state[k]) + tracker.pop_check_set(new_p_state[k], self) + + self.assertTrue(tracker.all_popped()) def _test_derived_optimizers(self, device, dtype, optim_info, flag, reduced_precision=False, assert_step_dtype=None): """ @@ -589,16 +633,12 @@ def _test_derived_optimizers(self, device, dtype, optim_info, flag, reduced_prec for provided optimizer configurations. """ assert flag in ("foreach", "fused") + assert_eq_kwargs = {} if not reduced_precision else FP16_REDUCED_PRECISION - # why 7? iteration 7 is where we start to see differences for RAdam - # params interacting with the small eps value, because that's right - # after rho_t becomes greater than 5 in step 6. 
- kIterations = 7 - - optim_inputs = optim_info.optim_inputs_func(device=device) + optim_inputs = optim_info.optim_inputs_func(device=device, dtype=dtype) optim_cls = optim_info.optim_cls for optim_input in optim_inputs: - updated_params, state = [], [] + models, optimizers = [], [] kwargs = deepcopy(optim_input.kwargs) if kwargs.get("capturable", False) and str(device) == "cpu": # capturable is not supported on CPU @@ -626,39 +666,10 @@ def _test_derived_optimizers(self, device, dtype, optim_info, flag, reduced_prec params = list(model.parameters()) + [empty_param] optimizer = optim_cls(params, **kwargs) + models.append(model) + optimizers.append(optimizer) - for i in range(kIterations): - optimizer.zero_grad() - - # Test that step behaves as expected (a no-op) when grads are set to None - if i != 3: - output = model(input) - loss = output.sum() - loss.backward() - - optimizer.step() - - if assert_step_dtype is not None: - p_state = optimizer.state[params[0]] - if torch.is_tensor(p_state.get("step", None)): - self.assertEqual(p_state["step"].dtype, assert_step_dtype) - - state.append(optimizer.state) - updated_params.append(model.parameters()) - - assert_eq_kwargs = {} if not reduced_precision else FP16_REDUCED_PRECISION - - og_state, new_state = state - for og_p, new_p in zip(updated_params[0], updated_params[1]): - self.assertEqual(og_p, new_p, **assert_eq_kwargs) - - # check that optimizer states are the same - og_p_state = og_state[og_p] - new_p_state = new_state[new_p] - - for k in og_p_state: - self.assertEqual(og_p_state[k], new_p_state[k], **assert_eq_kwargs) - + self._compare_between(input, models, optimizers, assert_eq_kwargs, assert_step_dtype) @skipMPS # MPS doesn't support torch.float64, see https://github.com/pytorch/pytorch/issues/115350 @optims([optim for optim in optim_db if "foreach" in optim.supported_impls], dtypes=[torch.float64]) @@ -847,16 +858,23 @@ def test_peak_memory_foreach(self, device, dtype, optim_info): self.assertLessEqual(mt_max_mem, expected_max_mem) - @onlyCUDA - @optims([optim for optim in optim_db if "fused" in optim.supported_impls], dtypes=[torch.float64]) + @onlyNativeDeviceTypes + @optims( + [optim for optim in optim_db if "fused" in optim.supported_impls], + dtypes=floating_types_and(torch.bfloat16, torch.float16, ) + ) def test_fused_matches_forloop(self, device, dtype, optim_info): + if device not in optim_info.supports_fused_on: + self.skipTest(f"{device} is not supported for fused on {optim_info.optim_cls.__name__}") self._test_derived_optimizers(device, dtype, optim_info, "fused") - @onlyCUDA - @largeTensorTest("64GB", "cuda") + @onlyNativeDeviceTypes + @largeTensorTest("64GB") @optims([optim for optim in optim_db if "fused" in optim.supported_impls], dtypes=[torch.float16]) def test_fused_large_tensor(self, device, dtype, optim_info): + if device not in optim_info.supports_fused_on: + self.skipTest(f"{device} is not supported for fused on {optim_info.optim_cls.__name__}") optim_cls = optim_info.optim_cls optim_inputs = optim_info.optim_inputs_func(device=device) for optim_input in optim_inputs: @@ -1304,10 +1322,11 @@ def closure(): # Make sure that device of state['step'] is still CPU _unless_ torch.compile() added a capturable! 
capturable = state_dict_cpu["param_groups"][0].get("capturable", False) + fused = state_dict_cpu["param_groups"][0].get("fused", False) new_state_dict = optimizer_cuda.state_dict() for state_cpu, state_cuda in zip(state_dict_cpu["state"].values(), new_state_dict["state"].values()): if "step" in state_cpu and torch.is_tensor(state_cpu["step"]): - self.assertEqual(state_cuda["step"].device.type, "cuda" if capturable else "cpu") + self.assertEqual(state_cuda["step"].device.type, "cuda" if capturable or fused else "cpu") for _ in range(5): optimizer.step(closure) @@ -1615,6 +1634,104 @@ def closure(): res2 = optim_neg_inf.step(closure) self.assertEqual(type(res1), type(res2)) + @onlyCUDA + @optims( + [optim for optim in optim_db if "cpu" in optim.supports_fused_on and "cuda" in optim.supports_fused_on], + dtypes=floating_types_and(torch.bfloat16, torch.float16,) + ) + def test_fused_cpu_matches_cuda(self, device, dtype, optim_info): + optim_cls = optim_info.optim_cls + optim_inputs = optim_info.optim_inputs_func(device="cpu") + for optim_input in optim_inputs: + inpts, models, optimizers = [], [], [] + for dev in ('cpu', 'cuda'): + kwargs = optim_input.kwargs + kwargs["fused"] = True + inpt = torch.tensor( + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=dtype, device=dev + ).reshape(3, 2) + + torch.manual_seed(1) + model = torch.nn.Sequential( + torch.nn.Linear(2, 3), + torch.nn.Sigmoid(), + torch.nn.Linear(3, 1), + torch.nn.Sigmoid(), + ) + model.to(dtype=dtype, device=dev) + + # foreach/fused optimizers should be tested with a + # zero_size tensor as its last param. + # ref: https://github.com/pytorch/pytorch/issues/100701 + empty_param = torch.empty((), device=dev, dtype=dtype, requires_grad=True) + empty_param.grad = torch.rand_like(empty_param) + params = list(model.parameters()) + [empty_param] + + optimizer = optim_cls(params, **kwargs) + inpts.append(inpt) + models.append(model) + optimizers.append(optimizer) + self._compare_between(inpts, models, optimizers) + + @onlyCPU + @optims([optim for optim in optim_db if "fused" in optim.supported_impls], dtypes=[torch.float32]) + def test_grad_scaling_autocast_fused_optimizers(self, device, dtype, optim_info): + # This ut is from test_cuda.py test_grad_scaling_autocast_fused_optimizers + # but only test Adam/AdamW on CPU + # TODO: haozhe, support SGD and unified this ut with the CUDA only one + if device not in optim_info.supports_fused_on: + self.skipTest(f"{device} is not supported for fused on {optim_info.optim_cls.__name__}") + optim_inputs = optim_info.optim_inputs_func(device=device) + optim_cls = optim_info.optim_cls + for optim_input in optim_inputs: + kwargs = optim_input.kwargs + for _separate_unscale in (True, False): + self._grad_scaling_autocast_fused_optimizers( + optimizer_ctor=optim_cls, optimizer_kwargs=kwargs, separate_unscale=_separate_unscale) + + def _grad_scaling_autocast_fused_optimizers(self, optimizer_ctor, optimizer_kwargs, separate_unscale): + ( + mod_control, mod_scaling, opt_control, opt_scaling, data, loss_fn, _, + ) = _create_scaling_case(optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs, device='cpu') + kwargs = deepcopy(optimizer_kwargs) + kwargs["fused"] = False + if 'lr' not in optimizer_kwargs: + # _create_scaling_case will set lr = 1.0 if optimizer_kwargs do not set lr + kwargs['lr'] = 1.0 + opt_control = optimizer_ctor(mod_control.parameters(), **kwargs) + + scaler = torch.cpu.amp.GradScaler(init_scale=128.0) + for input, target in data: + opt_control.zero_grad() + with torch.autocast('cpu', 
dtype=torch.half): + output_control = mod_control(input) + loss_control = loss_fn(output_control, target) + scaler.scale(loss_control).backward() + scaler.step(opt_control) + scaler.update() + + opt_scaling.zero_grad() + with torch.autocast('cpu', dtype=torch.half): + output_scaling = mod_scaling(input) + loss_scaling = loss_fn(output_scaling, target) + scaler.scale(loss_scaling).backward() + if separate_unscale: + scaler.unscale_(opt_scaling) + scaler.step(opt_scaling) + scaler.update() + + self.assertEqual(loss_control, loss_scaling,) + for param_control, param_scaling in zip(mod_control.parameters(), mod_scaling.parameters()): + self.assertEqual(param_control.grad, param_scaling.grad,) + self.assertEqual(param_control, param_scaling,) + + state_control, state_scaling = opt_control.state[param_control], opt_scaling.state[param_scaling] + + for k in state_control: + actual = state_scaling[k] + if k == "step": + actual = actual.squeeze() + self.assertEqual(state_control[k], actual,) @onlyCUDA @optims([o for o in optim_db if "foreach" in o.supported_impls], dtypes=[torch.float32]) diff --git a/test/test_torch.py b/test/test_torch.py index 9a1b619903254f..a11919328b67f3 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -47,8 +47,7 @@ instantiate_device_type_tests, onlyCUDA, onlyCPU, dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast, - skipMeta, - PYTORCH_CUDA_MEMCHECK, largeTensorTest, onlyNativeDeviceTypes, + skipMeta, PYTORCH_CUDA_MEMCHECK, largeTensorTest, onlyNativeDeviceTypes, get_all_device_types, skipXLA) from typing import Tuple import torch.backends.quantized @@ -5932,7 +5931,7 @@ def test_grad_scaling_autocast_foreach(self, device): for optimizer_ctor in (torch.optim.SGD, torch.optim.Adam, torch.optim.AdamW): self._grad_scaling_autocast_test(device=device.type, optimizer_ctor=optimizer_ctor, optimizer_kwargs={"foreach": True}) - @onlyCUDA + @onlyNativeDeviceTypes def test_grad_scaling_autocast_fused(self, device): device = torch.device(device) for optimizer_ctor in (torch.optim.Adam, torch.optim.AdamW): @@ -5952,8 +5951,6 @@ def test_params_invalidated_with_grads_invalidated_between_unscale_and_step(self {"foreach": False, "fused": True}, ), ): - if device.type != "cuda": - optimizer_kwargs['fused'] = False with self.subTest(optimizer=optimizer_ctor, optimizer_kwargs=optimizer_kwargs): self._test_grads_invalidated_between_unscale_and_step(device.type, optimizer_ctor, optimizer_kwargs) diff --git a/torch/optim/adam.py b/torch/optim/adam.py index cd45a197b378be..e74ad4e1abb87f 100644 --- a/torch/optim/adam.py +++ b/torch/optim/adam.py @@ -76,7 +76,7 @@ def __init__( # Support AMP with FP16/BF16 model params which would need # higher prec copy of params to do update math in higher prec to # alleviate the loss of information. - fused_supported_devices = _get_fused_kernels_supported_devices() + fused_supported_devices = _get_fused_kernels_supported_devices() + ["cpu"] if not all( p.device.type in fused_supported_devices and torch.is_floating_point(p) for pg in self.param_groups diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py index bbe03c1ce5cd4e..89e776558d5e6c 100644 --- a/torch/optim/adamw.py +++ b/torch/optim/adamw.py @@ -75,7 +75,7 @@ def __init__( # Suppor AMP with FP16/BF16 model params which would need # higher prec copy of params to do update math in higher prec to # alleviate the loss of information. 
- fused_supported_devices = _get_fused_kernels_supported_devices() + fused_supported_devices = _get_fused_kernels_supported_devices() + ["cpu"] if not all( p.device.type in fused_supported_devices and torch.is_floating_point(p) for pg in self.param_groups diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py index 6cf83218ded8b9..a1bea634ffb8fb 100644 --- a/torch/testing/_internal/common_optimizers.py +++ b/torch/testing/_internal/common_optimizers.py @@ -44,10 +44,7 @@ skipIfTorchDynamo, TEST_WITH_TORCHDYNAMO, ) -from torch.utils._foreach_utils import ( - _get_foreach_kernels_supported_devices, - _get_fused_kernels_supported_devices, -) +from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices class OptimizerInput: @@ -143,6 +140,7 @@ def __init__( skips=(), # Indicates which tests to skip decorators=None, # Additional decorators to apply to generated tests optim_error_inputs_func=None, # Function to generate optim inputs that error + supports_fused_on: Tuple[str] = (), ): self.optim_cls = optim_cls self.optim_inputs_func = optim_inputs_func @@ -160,6 +158,7 @@ def __init__( *(skips if skips else []), ) self.optim_error_inputs_func = optim_error_inputs_func + self.supports_fused_on = supports_fused_on def get_decorators(self, test_class, test_name, device, dtype, param_kwargs): result = [set_single_threaded_if_parallel_tbb] @@ -291,7 +290,7 @@ def get_error_inputs_for_all_optims(device, dtype): # global-cliquey flags to individual tests and fully expect tests to edit OptimizerInput.kwargs. -def optim_inputs_func_adadelta(device): +def optim_inputs_func_adadelta(device, dtype=None): cuda_supported_configs = [ OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"), OptimizerInput( @@ -340,7 +339,7 @@ def optim_error_inputs_func_adadelta(device, dtype): return error_inputs -def optim_inputs_func_adagrad(device): +def optim_inputs_func_adagrad(device, dtype=None): return [ OptimizerInput(params=None, kwargs={}, desc="default"), OptimizerInput( @@ -384,7 +383,7 @@ def optim_error_inputs_func_adagrad(device, dtype): # TODO: consider tensor LR! See multi_tensor_optimizer_configs in test_optim.py --> tensor LR should work # with all implementation code paths... 
-def optim_inputs_func_adam(device): +def optim_inputs_func_adam(device, dtype=None): cuda_supported_configs = [ OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"), OptimizerInput( @@ -399,7 +398,7 @@ def optim_inputs_func_adam(device): ), ] - return [ + total = [ OptimizerInput(params=None, kwargs={}, desc="default"), OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"), OptimizerInput( @@ -414,6 +413,19 @@ def optim_inputs_func_adam(device): params=None, kwargs={"weight_decay": 0.1, "amsgrad": True}, desc="amsgrad" ), ] + (cuda_supported_configs if "cuda" in str(device) else []) + if dtype in (torch.float16,): + for input in total: + """ + Too small eps will make denom to be zero for low precision dtype + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + For example, + >>> a + tensor([0.], dtype=torch.float16) + >>> a + 1e-8 + tensor([0.], dtype=torch.float16) + """ + input.kwargs["eps"] = 0.1 + return total def optim_error_inputs_func_adam(device, dtype): @@ -473,7 +485,7 @@ def optim_error_inputs_func_adam(device, dtype): return error_inputs -def optim_inputs_func_adamax(device): +def optim_inputs_func_adamax(device, dtype=None): cuda_supported_configs = [ OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"), OptimizerInput( @@ -524,15 +536,15 @@ def optim_error_inputs_func_adamax(device, dtype): return error_inputs -def optim_inputs_func_adamw(device): - return optim_inputs_func_adam(device) +def optim_inputs_func_adamw(device, dtype=None): + return optim_inputs_func_adam(device, dtype) def optim_error_inputs_func_adamw(device, dtype): return optim_error_inputs_func_adam(device, dtype) -def optim_inputs_func_asgd(device): +def optim_inputs_func_asgd(device, dtype=None): cuda_supported_configs = [ OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"), OptimizerInput( @@ -584,7 +596,7 @@ def optim_error_inputs_func_asgd(device, dtype): return error_inputs -def optim_inputs_func_lbfgs(device): +def optim_inputs_func_lbfgs(device, dtype=None): return [ OptimizerInput(params=None, kwargs={}, desc="default"), OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"), @@ -605,7 +617,7 @@ def optim_error_inputs_func_lbfgs(device, dtype): # Weird story bro, NAdam and RAdam do not have maximize. -def optim_inputs_func_nadam(device): +def optim_inputs_func_nadam(device, dtype=None): cuda_supported_configs = [ OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"), OptimizerInput( @@ -676,7 +688,7 @@ def optim_error_inputs_func_nadam(device, dtype): # Weird story bro, NAdam and RAdam do not have maximize. 
-def optim_inputs_func_radam(device=None): +def optim_inputs_func_radam(device=None, dtype=None): cuda_supported_configs = [ OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"), OptimizerInput( @@ -738,7 +750,7 @@ def optim_error_inputs_func_radam(device, dtype): return error_inputs -def optim_inputs_func_rmsprop(device): +def optim_inputs_func_rmsprop(device, dtype=None): cuda_supported_configs = [ OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"), OptimizerInput( @@ -799,7 +811,7 @@ def optim_error_inputs_func_rmsprop(device, dtype): return error_inputs -def optim_inputs_func_rprop(device): +def optim_inputs_func_rprop(device, dtype=None): cuda_supported_configs = [ OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"), OptimizerInput( @@ -841,7 +853,7 @@ def optim_error_inputs_func_rprop(device, dtype): return error_inputs -def optim_inputs_func_sgd(device): +def optim_inputs_func_sgd(device, dtype=None): return [ OptimizerInput(params=None, kwargs={}, desc="default"), OptimizerInput(params=None, kwargs={"lr": 1e-2}, desc="non-default lr"), @@ -886,7 +898,7 @@ def optim_error_inputs_func_sgd(device, dtype): return error_inputs -def optim_inputs_func_sparseadam(device): +def optim_inputs_func_sparseadam(device, dtype=None): return [ OptimizerInput(params=None, kwargs={}, desc="default"), OptimizerInput( @@ -995,10 +1007,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( x for x in optim_info.supported_impls if x not in skip - and ( - _get_device_type(device) in _get_fused_kernels_supported_devices() - or x != "fused" - ) + and (_get_device_type(device) in optim_info.supports_fused_on or x != "fused") and ( _get_device_type(device) in _get_foreach_kernels_supported_devices() or x != "foreach" @@ -1196,6 +1205,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( ), optim_error_inputs_func=optim_error_inputs_func_adam, supported_impls=("foreach", "differentiable", "fused"), + supports_fused_on=("cpu", "cuda"), decorators=( # Expected floating point error between fused and compiled forloop DecorateInfo( @@ -1205,6 +1215,21 @@ def _get_optim_inputs_including_global_cliquey_kwargs( active_if=lambda kwargs: TEST_WITH_TORCHDYNAMO and kwargs["dtype"] == torch.float64, ), + DecorateInfo( + # Note on tolerances: + # difference comes from the fact that the non fused kernel have + # more dtype cast operations. We have another test test_fused_cpu_matches_cuda + # to make sure there is no discrepancies between cuda fused kernel + # and cpu fused kernel + toleranceOverride( + { + torch.bfloat16: tol(atol=5e-3, rtol=5e-3), + torch.float16: tol(atol=5e-3, rtol=5e-3), + } + ), + "TestOptimRenewed", + "test_fused_matches_forloop", + ), ), skips=( DecorateInfo( @@ -1364,6 +1389,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_adamw, optim_error_inputs_func=optim_error_inputs_func_adamw, supported_impls=("foreach", "differentiable", "fused"), + supports_fused_on=("cpu", "cuda"), decorators=( # Expected error between compiled forloop and fused optimizers DecorateInfo( @@ -1373,6 +1399,21 @@ def _get_optim_inputs_including_global_cliquey_kwargs( active_if=lambda kwargs: TEST_WITH_TORCHDYNAMO and kwargs["dtype"] == torch.float64, ), + DecorateInfo( + toleranceOverride( + # Note on tolerances: + # difference comes from the fact that the non fused kernel have + # more dtype cast operations. 
We have another test test_fused_cpu_matches_cuda + # to make sure there is no discrepancies between cuda fused kernel + # and cpu fused kernel + { + torch.bfloat16: tol(atol=5e-3, rtol=5e-3), + torch.float16: tol(atol=5e-3, rtol=5e-3), + } + ), + "TestOptimRenewed", + "test_fused_matches_forloop", + ), ), skips=( DecorateInfo( @@ -1865,6 +1906,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( }, [lambda opt: StepLR(opt, gamma=0.99999, step_size=300)], ), + supports_fused_on=("cuda",), skips=( DecorateInfo( skipIfTorchDynamo( @@ -2060,7 +2102,10 @@ class TensorTracker: numerical discrepancies, and so when the test fails, it is likely a real problem. """ - def __init__(self): + def __init__(self, assert_eq_kwargs=None): + if assert_eq_kwargs is None: + assert_eq_kwargs = {} + self.assert_eq_kwargs = assert_eq_kwargs self.tensors = [] def add(self, tensor): @@ -2080,7 +2125,7 @@ def pop_check_set(self, tensor_to_set, testcase): ref = self.tensors.pop(0) testcase.assertTrue(isinstance(ref, Tensor), f"{type(ref)=}") - testcase.assertEqual(tensor_to_set, ref) + testcase.assertEqual(tensor_to_set, ref, **self.assert_eq_kwargs) with torch.no_grad(): tensor_to_set.copy_(ref)