diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
index 89c484493f3df..57190f9ae092d 100644
--- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -352,7 +352,7 @@ static void sinc_kernel(TensorIteratorBase& iter) {
 }
 
 static void sinh_kernel(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, iter.dtype(), "sinh_cpu", [&]() {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "sinh_cpu", [&]() {
     cpu_kernel_vec(
         iter,
         [=](scalar_t a) -> scalar_t { return std::sinh(a); },
@@ -361,7 +361,7 @@ static void sinh_kernel(TensorIteratorBase& iter) {
 }
 
 static void cosh_kernel(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, iter.dtype(), "cosh_cpu", [&]() {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "cosh_cpu", [&]() {
     cpu_kernel_vec(
         iter,
         [=](scalar_t a) -> scalar_t { return std::cosh(a); },
@@ -425,7 +425,7 @@ static void polygamma_kernel(TensorIteratorBase& iter, int64_t n) {
   } else if (n == 1) {
     trigamma_kernel(iter);
   } else {
-    AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.dtype(), "polygamma", [&]() {
+    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "polygamma", [&]() {
       cpu_kernel(
           iter, [=](scalar_t a) -> scalar_t { return calc_polygamma(a, n); });
     });
@@ -460,10 +460,12 @@ static void nan_to_num_kernel(
 }
 
 static void kaiser_window_kernel(TensorIteratorBase& iter, int64_t window_length, double beta){
-  AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.dtype(), "kaiser_window_cpu", [&](){
-    const scalar_t alpha = static_cast<scalar_t>((window_length - 1) / 2.0);
+  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "kaiser_window_cpu", [&](){
+    using opmath_t = at::opmath_type<scalar_t>;
+    const opmath_t alpha = static_cast<opmath_t>((window_length - 1) / 2.0);
+    const opmath_t beta_ = static_cast<opmath_t>(beta);
     cpu_kernel(iter, [=](scalar_t a){
-        return calc_i0(static_cast<scalar_t>(beta) * std::sqrt(1 - std::pow((a - alpha) / alpha, static_cast<scalar_t>(2.0)))) / calc_i0(static_cast<scalar_t>(beta));
+        return calc_i0(beta_ * std::sqrt(1 - std::pow((static_cast<opmath_t>(a) - alpha) / alpha, static_cast<opmath_t>(2.0)))) / calc_i0(beta_);
     });
   });
 }
@@ -480,8 +482,8 @@ void rsqrt_kernel(TensorIteratorBase& iter) {
 }
 
 static void entr_kernel(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND(
-      kBFloat16, iter.common_dtype(), "entr_cpu", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      kBFloat16, kHalf, iter.common_dtype(), "entr_cpu", [&] {
         cpu_kernel(iter, [](scalar_t x) -> scalar_t {
           if (at::_isnan(x)) {
             return x;
@@ -528,8 +530,8 @@ static void log_ndtr_kernel(TensorIteratorBase& iter) {
 
 static void i0e_kernel(TensorIteratorBase& iter) {
   TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);
-  AT_DISPATCH_FLOATING_TYPES_AND(
-      kBFloat16, iter.common_dtype(), "i0e_cpu", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      kBFloat16, kHalf, iter.common_dtype(), "i0e_cpu", [&]() {
         cpu_kernel_vec(
             iter,
             [](scalar_t x) { return calc_i0e(x); },
@@ -560,18 +562,19 @@ static void erfcx_kernel(TensorIteratorBase& iter){
 }
 
 static void round_decimals_kernel(TensorIteratorBase& iter, int64_t decimals) {
-  AT_DISPATCH_FLOATING_TYPES_AND(
-      ScalarType::BFloat16, iter.dtype(), "round_cpu", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      kBFloat16, kHalf, iter.dtype(), "round_cpu", [&]() {
+        using opmath_t = at::opmath_type<scalar_t>;
         bool neg_flag = false;
-        scalar_t ten_pow_decimals;
+        opmath_t ten_pow_decimals;
         if (decimals < 0) {
           decimals = -decimals;
           neg_flag = true;
         }
-        ten_pow_decimals = static_cast<scalar_t>(std::pow(10, decimals));
+        ten_pow_decimals = static_cast<opmath_t>(std::pow(10, decimals));
         cpu_kernel(iter, [ten_pow_decimals, neg_flag](scalar_t a) -> scalar_t {
-          return neg_flag ? std::nearbyint(a / ten_pow_decimals) * ten_pow_decimals
-                          : std::nearbyint(a * ten_pow_decimals) / ten_pow_decimals;
+          return neg_flag ? std::nearbyint(static_cast<opmath_t>(a) / ten_pow_decimals) * ten_pow_decimals
+                          : std::nearbyint(static_cast<opmath_t>(a) * ten_pow_decimals) / ten_pow_decimals;
        });
      });
 }
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 1cf3e6a2c9f48..b49416db56ae1 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -10149,7 +10149,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
                    )),
     UnaryUfuncInfo('cosh',
                    ref=np_unary_ufunc_integer_promotion_wrapper(np.cosh),
-                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
                    assert_autodiffed=True,
                    supports_forward_ad=True,
@@ -14072,7 +14072,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
                    )),
     UnaryUfuncInfo('sinh',
                    ref=np_unary_ufunc_integer_promotion_wrapper(np.sinh),
-                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
                    assert_autodiffed=True,
                    supports_forward_ad=True,
@@ -14881,7 +14881,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
                    op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
                    variant_test_name='polygamma_n_0',
                    ref=reference_polygamma if TEST_SCIPY else None,
-                   dtypes=all_types_and(torch.bool, torch.bfloat16),
+                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
                    dtypesIfCUDA=all_types_and(torch.bool, torch.half),
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
diff --git a/torch/testing/_internal/opinfo/definitions/special.py b/torch/testing/_internal/opinfo/definitions/special.py
index 0e68df69df06d..22a6543791375 100644
--- a/torch/testing/_internal/opinfo/definitions/special.py
+++ b/torch/testing/_internal/opinfo/definitions/special.py
@@ -105,8 +105,7 @@ def sample_inputs_entr(op_info, device, dtype, requires_grad, **kwargs):
         aten_name="special_i0e",
         ref=scipy.special.i0e if TEST_SCIPY else None,
         decorators=(precisionOverride({torch.bfloat16: 3e-1, torch.float16: 3e-1}),),
-        dtypes=all_types_and(torch.bool, torch.bfloat16),
-        dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
         backward_dtypes=floating_types(),
         sample_inputs_func=sample_inputs_i0_i1,
         supports_forward_ad=True,
@@ -177,7 +176,7 @@ def sample_inputs_entr(op_info, device, dtype, requires_grad, **kwargs):
         op=lambda x, n, **kwargs: torch.special.polygamma(n, x, **kwargs),
         variant_test_name="special_polygamma_n_0",
         ref=reference_polygamma if TEST_SCIPY else None,
-        dtypes=all_types_and(torch.bool, torch.bfloat16),
+        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
         dtypesIfCUDA=all_types_and(torch.bool, torch.half),
         supports_forward_ad=True,
         supports_fwgrad_bwgrad=True,
@@ -249,8 +248,7 @@ def sample_inputs_entr(op_info, device, dtype, requires_grad, **kwargs):
         supports_forward_ad=True,
         supports_fwgrad_bwgrad=True,
         decorators=(precisionOverride({torch.float16: 1e-1, torch.bfloat16: 1e-1}),),
-        dtypes=all_types_and(torch.bool, torch.bfloat16),
-        dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
         skips=(
             DecorateInfo(
                 unittest.skip("Skipped!"),
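Not part of the patch, just a usage sketch for review context: assuming a build that includes the kernel changes above, the ops touched here accept torch.half inputs on CPU, with kaiser_window and round(..., decimals=...) accumulating in float via at::opmath_type to limit half-precision rounding error. Everything below uses existing public torch APIs; the specific sample values are illustrative only.

import torch

x = torch.randn(8).half()                      # Half tensor on CPU

torch.sinh(x)                                  # sinh_cpu now dispatches for kHalf
torch.cosh(x)                                  # cosh_cpu likewise
torch.special.i0e(x)                           # i0e_cpu
torch.special.entr(x.abs())                    # entr_cpu
torch.polygamma(2, x.abs() + 1)                # polygamma kernel (n >= 2 path)
torch.round(x, decimals=2)                     # round_decimals_kernel, float accumulation
torch.kaiser_window(8, dtype=torch.half)       # kaiser_window_cpu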