[Inductor] Use sleef implementation for CPP backend asinh codegen (pytorch#142360)

**Summary**
Fixes pytorch#142345. Previously, the CPP backend computed `asinh` via the identity `asinh(x) = log(x + sqrt(1 + x**2))`. For large-magnitude negative inputs such as `-10000.1`, the term `x + sqrt(1 + x**2)` suffers catastrophic cancellation and rounds to 0 in single precision, so the `log` evaluates to `-inf`. This PR switches the codegen to the `sleef` implementation, which handles such inputs correctly.
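
To see the failure mode concretely, here is a minimal standalone sketch (not part of the diff; printed values are approximate and platform-dependent) of the cancellation in single precision:

```cpp
#include <cmath>
#include <cstdio>

int main() {
    float x = -10000.1f;
    // Old formulation: for large-magnitude negative x, sqrt(1 + x*x)
    // rounds to -x in float, so the sum cancels to 0 and log(0)
    // evaluates to -inf.
    float naive = std::log(x + std::sqrt(1.0f + x * x));
    // std::asinh avoids the cancellation; expected roughly -9.9035.
    float reference = std::asinh(x);
    std::printf("naive = %f, asinh = %f\n", naive, reference);
    return 0;
}
```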

**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_repro.py -k test_asinh_with_corner_inputs
```

Pull Request resolved: pytorch#142360
Approved by: https://github.com/jgong5
leslie-fang-intel authored and pytorchmergebot committed Dec 14, 2024
1 parent d531648 commit 00b0210
Showing 16 changed files with 59 additions and 3 deletions.
3 changes: 3 additions & 0 deletions aten/src/ATen/cpu/vec/sve/vec_double.h
```diff
@@ -147,6 +147,9 @@ template <> class Vectorized<double> {
   Vectorized<double> asin() const {
     return USE_SLEEF(Vectorized<double>(Sleef_asindx_u10sve(values)),map(std::asin));
   }
+  Vectorized<double> asinh() const {
+    return USE_SLEEF(Vectorized<double>(Sleef_asinhdx_u10sve(values)),map(std::asinh));
+  }
   Vectorized<double> atan() const {
     return USE_SLEEF(Vectorized<double>(Sleef_atandx_u10sve(values)),map(std::atan));
   }
```

3 changes: 3 additions & 0 deletions aten/src/ATen/cpu/vec/sve/vec_float.h
```diff
@@ -147,6 +147,9 @@ template <> class Vectorized<float> {
   Vectorized<float> asin() const {
     return USE_SLEEF(Vectorized<float>(Sleef_asinfx_u10sve(values)),map(std::asin));
   }
+  Vectorized<float> asinh() const {
+    return USE_SLEEF(Vectorized<float>(Sleef_asinhfx_u10sve(values)),map(std::asinh));
+  }
   Vectorized<float> atan() const {
     return USE_SLEEF(Vectorized<float>(Sleef_atanfx_u10sve(values)),map(std::atan));
   }
```

1 change: 1 addition & 0 deletions aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h
```diff
@@ -276,6 +276,7 @@ template <> class Vectorized<float> {
   DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(acos)
   DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(acosh)
   DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(asin)
+  DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(asinh)
   DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(atan)
   DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(atanh)
 
```

```diff
@@ -145,6 +145,9 @@ struct Vectorized16 {
   Derived asin() const {
     return static_cast<const Derived*>(this)->map_with_vec_float_method(&Vectorized<float>::asin);
   }
+  Derived asinh() const {
+    return static_cast<const Derived*>(this)->map_with_vec_float_method(&Vectorized<float>::asinh);
+  }
   Derived atan() const {
     return static_cast<const Derived*>(this)->map_with_vec_float_method(&Vectorized<float>::atan);
   }
```

3 changes: 3 additions & 0 deletions aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h
```diff
@@ -359,6 +359,9 @@ static_assert(
   Vectorized<T> asin() const {
     return map(Sleef_asinf8_u10);
   }
+  Vectorized<T> asinh() const {
+    return map(Sleef_asinhf8_u10);
+  }
   Vectorized<T> atan() const {
     return map(Sleef_atanf8_u10);
   }
```

3 changes: 3 additions & 0 deletions aten/src/ATen/cpu/vec/vec256/vec256_double.h
```diff
@@ -147,6 +147,9 @@ template <> class Vectorized<double> {
   Vectorized<double> asin() const {
     return Vectorized<double>(Sleef_asind4_u10(values));
   }
+  Vectorized<double> asinh() const {
+    return Vectorized<double>(Sleef_asinhd4_u10(values));
+  }
   Vectorized<double> atan() const {
     return Vectorized<double>(Sleef_atand4_u10(values));
   }
```

3 changes: 3 additions & 0 deletions aten/src/ATen/cpu/vec/vec256/vec256_float.h
```diff
@@ -157,6 +157,9 @@ template <> class Vectorized<float> {
   Vectorized<float> asin() const {
     return Vectorized<float>(Sleef_asinf8_u10(values));
   }
+  Vectorized<float> asinh() const {
+    return Vectorized<float>(Sleef_asinhf8_u10(values));
+  }
   Vectorized<float> atan() const {
     return Vectorized<float>(Sleef_atanf8_u10(values));
   }
```

3 changes: 3 additions & 0 deletions aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h
```diff
@@ -225,6 +225,9 @@ class Vectorized<double> {
   Vectorized<double> C10_ALWAYS_INLINE asin() const {
     return {Sleef_asind2_u10(_vec0), Sleef_asind2_u10(_vec1)};
   }
+  Vectorized<double> C10_ALWAYS_INLINE asinh() const {
+    return {Sleef_asinhd2_u10(_vec0), Sleef_asinhd2_u10(_vec1)};
+  }
   Vectorized<double> atan() const {
     return {Sleef_atand2_u10(_vec0), Sleef_atand2_u10(_vec1)};
   }
```

3 changes: 3 additions & 0 deletions aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h
```diff
@@ -273,6 +273,9 @@ class Vectorized<float> {
   Vectorized<float> C10_ALWAYS_INLINE asin() const {
     return {Sleef_asinf4_u10(_vec0), Sleef_asinf4_u10(_vec1)};
   }
+  Vectorized<float> C10_ALWAYS_INLINE asinh() const {
+    return {Sleef_asinhf4_u10(_vec0), Sleef_asinhf4_u10(_vec1)};
+  }
   Vectorized<float> atan() const {
     return {Sleef_atanf4_u10(_vec0), Sleef_atanf4_u10(_vec1)};
   }
```

3 changes: 3 additions & 0 deletions aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h
```diff
@@ -377,6 +377,9 @@ static_assert(
   Vectorized<T> asin() const {
     return map(Sleef_asinf16_u10);
   }
+  Vectorized<T> asinh() const {
+    return map(Sleef_asinhf16_u10);
+  }
   Vectorized<T> atan() const {
     return map(Sleef_atanf16_u10);
   }
```

3 changes: 3 additions & 0 deletions aten/src/ATen/cpu/vec/vec512/vec512_double.h
```diff
@@ -156,6 +156,9 @@ template <> class Vectorized<double> {
   Vectorized<double> asin() const {
     return Vectorized<double>(Sleef_asind8_u10(values));
   }
+  Vectorized<double> asinh() const {
+    return Vectorized<double>(Sleef_asinhd8_u10(values));
+  }
   Vectorized<double> atan() const {
     return Vectorized<double>(Sleef_atand8_u10(values));
   }
```

3 changes: 3 additions & 0 deletions aten/src/ATen/cpu/vec/vec512/vec512_float.h
```diff
@@ -178,6 +178,9 @@ template <> class Vectorized<float> {
   Vectorized<float> asin() const {
     return Vectorized<float>(Sleef_asinf16_u10(values));
   }
+  Vectorized<float> asinh() const {
+    return Vectorized<float>(Sleef_asinhf16_u10(values));
+  }
   Vectorized<float> atan() const {
     return Vectorized<float>(Sleef_atanf16_u10(values));
   }
```

3 changes: 3 additions & 0 deletions aten/src/ATen/cpu/vec/vec_base.h
```diff
@@ -406,6 +406,9 @@ struct Vectorized {
   Vectorized<T> asin() const {
     return map(std::asin);
   }
+  Vectorized<T> asinh() const {
+    return map(std::asinh);
+  }
   Vectorized<T> atan() const {
     return map(std::atan);
   }
```

1 change: 1 addition & 0 deletions aten/src/ATen/cpu/vec/vec_n.h
```diff
@@ -251,6 +251,7 @@ class VectorizedN {
   VECTORIZEDN_DEFINE_UNARY_OP(acos)
   VECTORIZEDN_DEFINE_UNARY_OP(acosh)
   VECTORIZEDN_DEFINE_UNARY_OP(asin)
+  VECTORIZEDN_DEFINE_UNARY_OP(asinh)
   VECTORIZEDN_DEFINE_UNARY_OP(atan)
   VECTORIZEDN_DEFINE_UNARY_OP(atanh)
   VECTORIZEDN_DEFINE_BINARY_OP(atan2)
```

20 changes: 20 additions & 0 deletions test/inductor/test_cpu_repro.py
```diff
@@ -845,6 +845,26 @@ def fn(input):
             (_x,),
         )
 
+    @requires_vectorization
+    def test_asinh_with_corner_inputs(self):
+        # https://github.com/pytorch/pytorch/issues/142345
+
+        def fn(input):
+            out = torch.asinh(input)
+            return out
+
+        x = torch.tensor([0, 0, 0, -10000.1]).repeat(3, 4)
+
+        bit_widths = [isa._bit_width for isa in cpu_vec_isa.valid_vec_isa_list()]
+        for dtype in [torch.float32, torch.bfloat16, torch.float16, torch.double]:
+            for simdlen in bit_widths:
+                with torch.no_grad(), config.patch({"cpp.simdlen": simdlen}):
+                    torch._dynamo.reset()
+                    metrics.reset()
+                    _x = x.to(dtype)
+                    self.common(fn, (_x,))
+                    check_metrics_vec_kernel_count(1)
+
     @config.patch(implicit_fallbacks=True)
     def test_repeat_interleave(self):
         def fn(y):
```

4 changes: 1 addition & 3 deletions torch/_inductor/codegen/cpp.py
```diff
@@ -1404,9 +1404,7 @@ def atanh(x):
 
     @staticmethod
     def asinh(x):
-        # For real x, asinh(x) = log(x + sqrt(1 + x**2))
-        vec_one = f"decltype({x})(1)"
-        return f"({x} + ({vec_one} + {x}*{x}).sqrt()).log()"
+        return f"{x}.asinh()"
 
     @staticmethod
     def acosh(x):
```

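For reference, a minimal sketch of exercising the new vectorized path directly. This is an illustration, not part of the commit: it assumes the ATen headers above are on the include path and that `Vectorized<float>` broadcasts a scalar and is printable, as in `vec_base.h`; the SLEEF kernels are used only when the build has SLEEF support, otherwise the scalar `std::asinh` fallback applies:

```cpp
#include <ATen/cpu/vec/vec.h>
#include <iostream>

int main() {
    // Broadcast the problematic input from pytorch#142345 across all lanes.
    at::vec::Vectorized<float> v(-10000.1f);
    // Dispatches to the Sleef_asinh*_u10 kernels added in this commit
    // (or to the std::asinh map fallback on non-SLEEF builds).
    at::vec::Vectorized<float> out = v.asinh();
    std::cout << out << std::endl;  // each lane should be ~ -9.9035
    return 0;
}
```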
