add Half support for sinh, cosh, polygamma, entr and i0e on CPU (pyto…
jiayisunx authored and pytorchmergebot committed May 23, 2023
1 parent f7c736e commit 5c3cf76
Showing 3 changed files with 25 additions and 24 deletions.
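The user-facing effect of the change: the five ops named in the title now dispatch for torch.half (float16) tensors on CPU. A minimal sketch of the calls this enables, not taken from the diff (input values are arbitrary):

import torch

# Ops touched by this commit, called on a CPU float16 tensor. Before this change
# the CPU dispatch macros below had no kHalf entry, so these calls failed with a
# "not implemented for 'Half'"-style dispatch error.
x = torch.linspace(0.1, 2.0, steps=8).half()

print(torch.sinh(x))          # hyperbolic sine
print(torch.cosh(x))          # hyperbolic cosine
print(torch.polygamma(2, x))  # the n >= 2 branch changes here; n = 0, 1 route to digamma/trigamma
print(torch.special.entr(x))  # elementwise -x * log(x)
print(torch.special.i0e(x))   # exponentially scaled modified Bessel function of the first kind, order 0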
35 changes: 19 additions & 16 deletions aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -352,7 +352,7 @@ static void sinc_kernel(TensorIteratorBase& iter) {
}

static void sinh_kernel(TensorIteratorBase& iter) {
- AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, iter.dtype(), "sinh_cpu", [&]() {
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "sinh_cpu", [&]() {
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return std::sinh(a); },
@@ -361,7 +361,7 @@ static void sinh_kernel(TensorIteratorBase& iter) {
}

static void cosh_kernel(TensorIteratorBase& iter) {
- AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, iter.dtype(), "cosh_cpu", [&]() {
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "cosh_cpu", [&]() {
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return std::cosh(a); },
@@ -425,7 +425,7 @@ static void polygamma_kernel(TensorIteratorBase& iter, int64_t n) {
} else if (n == 1) {
trigamma_kernel(iter);
} else {
- AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.dtype(), "polygamma", [&]() {
+ AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "polygamma", [&]() {
cpu_kernel(
iter, [=](scalar_t a) -> scalar_t { return calc_polygamma(a, n); });
});
@@ -460,10 +460,12 @@ static void nan_to_num_kernel(
}

static void kaiser_window_kernel(TensorIteratorBase& iter, int64_t window_length, double beta){
- AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.dtype(), "kaiser_window_cpu", [&](){
- const scalar_t alpha = static_cast<scalar_t>((window_length - 1) / 2.0);
+ AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "kaiser_window_cpu", [&](){
+ using opmath_t = at::opmath_type<scalar_t>;
+ const opmath_t alpha = static_cast<opmath_t>((window_length - 1) / 2.0);
+ const opmath_t beta_ = static_cast<opmath_t>(beta);
cpu_kernel(iter, [=](scalar_t a){
- return calc_i0(static_cast<scalar_t>(beta) * std::sqrt(1 - std::pow((a - alpha) / alpha, static_cast<scalar_t>(2.0)))) / calc_i0(static_cast<scalar_t>(beta));
+ return calc_i0(beta_ * std::sqrt(1 - std::pow((static_cast<opmath_t>(a) - alpha) / alpha, static_cast<opmath_t>(2.0)))) / calc_i0(beta_);
});
});
}
@@ -480,8 +482,8 @@ void rsqrt_kernel(TensorIteratorBase& iter) {
}

static void entr_kernel(TensorIteratorBase& iter) {
- AT_DISPATCH_FLOATING_TYPES_AND(
- kBFloat16, iter.common_dtype(), "entr_cpu", [&] {
+ AT_DISPATCH_FLOATING_TYPES_AND2(
+ kBFloat16, kHalf, iter.common_dtype(), "entr_cpu", [&] {
cpu_kernel(iter, [](scalar_t x) -> scalar_t {
if (at::_isnan(x)) {
return x;
@@ -528,8 +530,8 @@ static void log_ndtr_kernel(TensorIteratorBase& iter) {

static void i0e_kernel(TensorIteratorBase& iter) {
TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);
- AT_DISPATCH_FLOATING_TYPES_AND(
- kBFloat16, iter.common_dtype(), "i0e_cpu", [&]() {
+ AT_DISPATCH_FLOATING_TYPES_AND2(
+ kBFloat16, kHalf, iter.common_dtype(), "i0e_cpu", [&]() {
cpu_kernel_vec(
iter,
[](scalar_t x) { return calc_i0e(x); },
@@ -560,18 +562,19 @@ static void erfcx_kernel(TensorIteratorBase& iter){
}

static void round_decimals_kernel(TensorIteratorBase& iter, int64_t decimals) {
- AT_DISPATCH_FLOATING_TYPES_AND(
- ScalarType::BFloat16, iter.dtype(), "round_cpu", [&]() {
+ AT_DISPATCH_FLOATING_TYPES_AND2(
+ kBFloat16, kHalf, iter.dtype(), "round_cpu", [&]() {
+ using opmath_t = at::opmath_type<scalar_t>;
bool neg_flag = false;
- scalar_t ten_pow_decimals;
+ opmath_t ten_pow_decimals;
if (decimals < 0) {
decimals = -decimals;
neg_flag = true;
}
- ten_pow_decimals = static_cast<scalar_t>(std::pow(10, decimals));
+ ten_pow_decimals = static_cast<opmath_t>(std::pow(10, decimals));
cpu_kernel(iter, [ten_pow_decimals, neg_flag](scalar_t a) -> scalar_t {
- return neg_flag ? std::nearbyint(a / ten_pow_decimals) * ten_pow_decimals
- : std::nearbyint(a * ten_pow_decimals) / ten_pow_decimals;
+ return neg_flag ? std::nearbyint(static_cast<opmath_t>(a) / ten_pow_decimals) * ten_pow_decimals
+ : std::nearbyint(static_cast<opmath_t>(a) * ten_pow_decimals) / ten_pow_decimals;
});
});
}
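Besides the extra kHalf dispatch entries, kaiser_window_kernel and round_decimals_kernel now do their intermediate arithmetic in at::opmath_type<scalar_t> (float when scalar_t is Half or BFloat16) and only cast back to scalar_t at the end. A rough Python analogue of that pattern, using the round-decimals logic as the example (illustrative only; round_decimals_ref is a made-up helper, not an ATen or torch function):

import torch

def round_decimals_ref(a: torch.Tensor, decimals: int) -> torch.Tensor:
    # Upcast reduced-precision inputs so the scaling and rounding happen in float32,
    # mirroring the opmath_t usage in the kernel above.
    opmath = a.float() if a.dtype in (torch.half, torch.bfloat16) else a
    ten_pow_decimals = float(10 ** abs(decimals))
    if decimals < 0:
        out = torch.round(opmath / ten_pow_decimals) * ten_pow_decimals
    else:
        out = torch.round(opmath * ten_pow_decimals) / ten_pow_decimals
    return out.to(a.dtype)  # cast back to the input dtype, as the kernel's lambda does

x = torch.tensor([1.2345, 2.7183, 3.1416], dtype=torch.half)
print(round_decimals_ref(x, 2))  # each value rounded to 2 decimal places, returned as float16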
6 changes: 3 additions & 3 deletions torch/testing/_internal/common_methods_invocations.py
@@ -10149,7 +10149,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
)),
UnaryUfuncInfo('cosh',
ref=np_unary_ufunc_integer_promotion_wrapper(np.cosh),
- dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
assert_autodiffed=True,
supports_forward_ad=True,
@@ -14072,7 +14072,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
)),
UnaryUfuncInfo('sinh',
ref=np_unary_ufunc_integer_promotion_wrapper(np.sinh),
- dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
assert_autodiffed=True,
supports_forward_ad=True,
@@ -14881,7 +14881,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
variant_test_name='polygamma_n_0',
ref=reference_polygamma if TEST_SCIPY else None,
- dtypes=all_types_and(torch.bool, torch.bfloat16),
+ dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.bool, torch.half),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
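Adding torch.half to the dtypes= lists above opts cosh, sinh and polygamma_n_0 into the OpInfo-driven tests for CPU float16. A hand-rolled sketch of the kind of check that coverage implies (illustrative; the real tests run through torch.testing's OpInfo machinery, and the tolerances here are assumptions):

import torch

x = torch.randn(1024).half()

for op in (torch.sinh, torch.cosh):
    half_out = op(x)                # computed by the CPU Half kernels added above
    ref_out = op(x.float()).half()  # same inputs, math done in float32, cast back down
    torch.testing.assert_close(half_out, ref_out, rtol=1e-3, atol=1e-3)  # loose float16 tolerances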
8 changes: 3 additions & 5 deletions torch/testing/_internal/opinfo/definitions/special.py
@@ -105,8 +105,7 @@ def sample_inputs_entr(op_info, device, dtype, requires_grad, **kwargs):
aten_name="special_i0e",
ref=scipy.special.i0e if TEST_SCIPY else None,
decorators=(precisionOverride({torch.bfloat16: 3e-1, torch.float16: 3e-1}),),
- dtypes=all_types_and(torch.bool, torch.bfloat16),
- dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+ dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
backward_dtypes=floating_types(),
sample_inputs_func=sample_inputs_i0_i1,
supports_forward_ad=True,
@@ -177,7 +176,7 @@ def sample_inputs_entr(op_info, device, dtype, requires_grad, **kwargs):
op=lambda x, n, **kwargs: torch.special.polygamma(n, x, **kwargs),
variant_test_name="special_polygamma_n_0",
ref=reference_polygamma if TEST_SCIPY else None,
- dtypes=all_types_and(torch.bool, torch.bfloat16),
+ dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.bool, torch.half),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
@@ -249,8 +248,7 @@ def sample_inputs_entr(op_info, device, dtype, requires_grad, **kwargs):
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
decorators=(precisionOverride({torch.float16: 1e-1, torch.bfloat16: 1e-1}),),
- dtypes=all_types_and(torch.bool, torch.bfloat16),
- dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+ dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
skips=(
DecorateInfo(
unittest.skip("Skipped!"),
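For the special-namespace OpInfos, the CPU dtype list now matches the CUDA one for i0e and entr, so the separate dtypesIfCUDA entries are dropped, and polygamma_n_0 gains torch.half on CPU. A hedged sketch of the float16-vs-SciPy comparison this coverage implies, reusing the 3e-1 float16 precision override from the i0e entry above; the entr tolerance, atol values and sample inputs are assumptions:

import torch
from scipy import special  # the same SciPy references the OpInfos above point at

x = torch.rand(64).half() + 0.1  # positive CPU float16 inputs

ref_i0e = torch.from_numpy(special.i0e(x.float().numpy())).half()
ref_entr = torch.from_numpy(special.entr(x.float().numpy())).half()

torch.testing.assert_close(torch.special.i0e(x), ref_i0e, rtol=3e-1, atol=1e-2)
torch.testing.assert_close(torch.special.entr(x), ref_entr, rtol=1e-1, atol=1e-2)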
