Skip to content

Commit

Permalink
Changes to make code compile under HIPRTC
Browse files Browse the repository at this point in the history
  • Loading branch information
stijnh committed Nov 18, 2024
1 parent 5490ea7 commit 9385655
Show file tree
Hide file tree
Showing 10 changed files with 270 additions and 249 deletions.
2 changes: 1 addition & 1 deletion include/kernel_float/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ using promoted_vector_value_type = promote_t<vector_value_type<Vs>...>;

/// Converts `input` into its vector storage representation by delegating to
/// `into_vector_impl<V>::call`. The argument is passed on as `V&&`, so its
/// original value category (lvalue or rvalue) is preserved.
template<typename V>
KERNEL_FLOAT_INLINE vector_storage_type<V> into_vector_storage(V&& input) {
return into_vector_impl<V>::call(std::forward<V>(input));
return into_vector_impl<V>::call(static_cast<V&&>(input));
}

} // namespace kernel_float
Expand Down
69 changes: 28 additions & 41 deletions include/kernel_float/bf16.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
#include "macros.h"

#if KERNEL_FLOAT_BF16_AVAILABLE
//#define CUDA_NO_BFLOAT16 (1)
//#define __CUDA_NO_BFLOAT16_OPERATORS__ (1)
//#define __CUDA_NO_BFLOAT162_OPERATORS__ (1)
//#define __CUDA_NO_BFLOAT16_CONVERSIONS__ (1)

#if KERNEL_FLOAT_IS_CUDA
#include <cuda_bf16.h>
#elif KERNEL_FLOAT_IS_HIP
Expand Down Expand Up @@ -76,21 +81,24 @@ struct allow_float_fallback<__bfloat16> {
}; \
}

KERNEL_FLOAT_BF16_UNARY_FUN(abs, ::__habs, ::__habs2)
KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2)
KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil)
KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin)
KERNEL_FLOAT_BF16_UNARY_FUN(cos, ::hcos, ::h2cos)

KERNEL_FLOAT_BF16_UNARY_FUN(exp, ::hexp, ::h2exp)
KERNEL_FLOAT_BF16_UNARY_FUN(exp10, ::hexp10, ::h2exp10)
KERNEL_FLOAT_BF16_UNARY_FUN(floor, ::hfloor, ::h2floor)
KERNEL_FLOAT_BF16_UNARY_FUN(log, ::hlog, ::h2log)
// log10 for bfloat16: the second argument is the scalar intrinsic and the third
// is the paired (`h2*`) intrinsic. Both must compute log10; the paired variant
// was previously `::h2log2`, which computes a base-2 logarithm instead.
KERNEL_FLOAT_BF16_UNARY_FUN(log10, ::hlog10, ::h2log10)
KERNEL_FLOAT_BF16_UNARY_FUN(rint, ::hrint, ::h2rint)
KERNEL_FLOAT_BF16_UNARY_FUN(rsqrt, ::hrsqrt, ::h2rsqrt)
KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin)

KERNEL_FLOAT_BF16_UNARY_FUN(sqrt, ::hsqrt, ::h2sqrt)
KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc)
KERNEL_FLOAT_BF16_UNARY_FUN(rsqrt, ::hrsqrt, ::h2rsqrt)
KERNEL_FLOAT_BF16_UNARY_FUN(rcp, ::hrcp, ::h2rcp)

KERNEL_FLOAT_BF16_UNARY_FUN(abs, ::__habs, ::__habs2)
KERNEL_FLOAT_BF16_UNARY_FUN(floor, ::hfloor, ::h2floor)
KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil)
KERNEL_FLOAT_BF16_UNARY_FUN(rint, ::hrint, ::h2rint)
KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc)
KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2)
#endif

#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
Expand All @@ -99,7 +107,7 @@ KERNEL_FLOAT_BF16_UNARY_FUN(rcp, ::hrcp, ::h2rcp)
template<> \
struct NAME<__bfloat16> { \
KERNEL_FLOAT_INLINE __bfloat16 operator()(__bfloat16 left, __bfloat16 right) const { \
return FUN1(left, right); \
return ops::cast<decltype(FUN1(left, right)), __bfloat16> {}(FUN1(left, right)); \
} \
}; \
} \
Expand Down Expand Up @@ -159,29 +167,6 @@ struct apply_impl<ops::fma<__bfloat16>, 2, __bfloat16, __bfloat16, __bfloat16, _
} // namespace detail
#endif

// Explicit `ops::cast` specializations between `__bfloat16` and the built-in
// floating-point types. Each functor wraps a single conversion intrinsic.
namespace ops {
// double -> __bfloat16 via __double2bfloat16.
template<>
struct cast<double, __bfloat16> {
KERNEL_FLOAT_INLINE __bfloat16 operator()(double input) {
return __double2bfloat16(input);
};
};

// float -> __bfloat16 via __float2bfloat16.
template<>
struct cast<float, __bfloat16> {
KERNEL_FLOAT_INLINE __bfloat16 operator()(float input) {
return __float2bfloat16(input);
};
};

// __bfloat16 -> float via __bfloat162float.
template<>
struct cast<__bfloat16, float> {
KERNEL_FLOAT_INLINE float operator()(__bfloat16 input) {
return __bfloat162float(input);
};
};
} // namespace ops

#define KERNEL_FLOAT_BF16_CAST(T, TO_HALF, FROM_HALF) \
namespace ops { \
template<> \
Expand All @@ -198,31 +183,33 @@ struct cast<__bfloat16, float> {
}; \
}

KERNEL_FLOAT_BF16_CAST(float, __float2bfloat16(input), __bfloat162float(input))
KERNEL_FLOAT_BF16_CAST(double, __double2bfloat16(input), __bfloat162float(input))

#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
// clang-format off
// there are no official char casts. Instead, cast to int and then to char
KERNEL_FLOAT_BF16_CAST(char, __int2bfloat16_rn(input), (char)__bfloat162int_rz(input));
KERNEL_FLOAT_BF16_CAST(signed char, __int2bfloat16_rn(input), (signed char)__bfloat162int_rz(input));
KERNEL_FLOAT_BF16_CAST(unsigned char, __int2bfloat16_rn(input), (unsigned char)__bfloat162int_rz(input));

KERNEL_FLOAT_BF16_CAST(signed short, __bfloat162short_rz(input), __short2bfloat16_rn(input));
KERNEL_FLOAT_BF16_CAST(signed int, __bfloat162int_rz(input), __int2bfloat16_rn(input));
KERNEL_FLOAT_BF16_CAST(signed short, __short2bfloat16_rn(input), __bfloat162short_rz(input));
KERNEL_FLOAT_BF16_CAST(signed int, __int2bfloat16_rn(input), __bfloat162int_rz(input));
KERNEL_FLOAT_BF16_CAST(signed long, __ll2bfloat16_rn(input), (signed long)(__bfloat162ll_rz(input)));
KERNEL_FLOAT_BF16_CAST(signed long long, __ll2bfloat16_rn(input), __bfloat162ll_rz(input));

KERNEL_FLOAT_BF16_CAST(unsigned short, __bfloat162ushort_rz(input), __ushort2bfloat16_rn(input));
KERNEL_FLOAT_BF16_CAST(unsigned int, __bfloat162uint_rz(input), __uint2bfloat16_rn(input));
KERNEL_FLOAT_BF16_CAST(unsigned short, __ushort2bfloat16_rn(input), __bfloat162ushort_rz(input));
KERNEL_FLOAT_BF16_CAST(unsigned int, __uint2bfloat16_rn(input), __bfloat162uint_rz(input));
KERNEL_FLOAT_BF16_CAST(unsigned long, __ull2bfloat16_rn(input), (unsigned long)(__bfloat162ull_rz(input)));
KERNEL_FLOAT_BF16_CAST(unsigned long long, __ull2bfloat16_rn(input), __bfloat162ull_rz(input));
// clang-format on
#endif

#if KERNEL_FLOAT_IS_CUDA
// bool <-> bfloat16 casts built directly on the raw 16-bit representation.
//  * bool -> bfloat16: `true` maps to 1.0, whose bfloat16 bit pattern is 0x3F80
//    (the upper half of the float pattern 0x3F800000); `false` maps to +0.0.
//    The previous expression was inverted and used 0x3C00, which is the
//    encoding of 1.0 in IEEE half precision, not in bfloat16.
//  * bfloat16 -> bool: the sign bit is masked off, so only +0.0 and -0.0
//    convert to `false`.
KERNEL_FLOAT_BF16_CAST(
bool,
__nv_bfloat16_raw {input ? (unsigned short)0x3F80 : (unsigned short)0},
(__nv_bfloat16_raw(input).x & 0x7FFF) != 0);

//KERNEL_FLOAT_BF16_CAST(
// bool,
// __nv_bfloat16_raw {input ? (unsigned short)0x3F80 : (unsigned short)0},
// (__nv_bfloat16_raw(input).x & 0x7FFF) != 0);
#elif KERNEL_FLOAT_IS_HIP
KERNEL_FLOAT_BF16_CAST(
bool,
Expand Down
8 changes: 4 additions & 4 deletions include/kernel_float/binops.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, co
return result;
}

// Defines a free function `NAME(left, right)` that hands both operands, with
// their value categories preserved, to `zip_common` together with a
// default-constructed `ops::NAME<C>` functor. `C` defaults to
// `promoted_vector_value_type<L, R>` and can be overridden by the caller.
#define KERNEL_FLOAT_DEFINE_BINARY_FUN(NAME) \
template<typename L, typename R, typename C = promoted_vector_value_type<L, R>> \
KERNEL_FLOAT_INLINE zip_common_type<ops::NAME<C>, L, R> NAME(L&& left, R&& right) { \
return zip_common(ops::NAME<C> {}, std::forward<L>(left), std::forward<R>(right)); \
#define KERNEL_FLOAT_DEFINE_BINARY_FUN(NAME) \
template<typename L, typename R, typename C = promoted_vector_value_type<L, R>> \
KERNEL_FLOAT_INLINE zip_common_type<ops::NAME<C>, L, R> NAME(L&& left, R&& right) { \
return zip_common(ops::NAME<C> {}, static_cast<L&&>(left), static_cast<R&&>(right)); \
}

#define KERNEL_FLOAT_DEFINE_BINARY(NAME, EXPR, EXPR_F64, EXPR_F32) \
Expand Down
68 changes: 38 additions & 30 deletions include/kernel_float/fp16.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
#include "macros.h"

#if KERNEL_FLOAT_FP16_AVAILABLE
//#define CUDA_NO_HALF (1)
//#define __CUDA_NO_HALF_OPERATORS__ (1)
//#define __CUDA_NO_HALF2_OPERATORS__ (1)
//#define __CUDA_NO_HALF_CONVERSIONS__ (1)

#if KERNEL_FLOAT_IS_CUDA
#include <cuda_fp16.h>
#elif KERNEL_FLOAT_IS_HIP
Expand Down Expand Up @@ -64,41 +69,44 @@ struct allow_float_fallback<__half> {
#define KERNEL_FLOAT_FP16_UNARY_FUN(NAME, FUN1, FUN2)
#endif

KERNEL_FLOAT_FP16_UNARY_FUN(abs, __habs, __habs2)
KERNEL_FLOAT_FP16_UNARY_FUN(negate, __hneg, __hneg2)
KERNEL_FLOAT_FP16_UNARY_FUN(ceil, hceil, h2ceil)
KERNEL_FLOAT_FP16_UNARY_FUN(sin, hsin, h2sin)
KERNEL_FLOAT_FP16_UNARY_FUN(cos, hcos, h2cos)

KERNEL_FLOAT_FP16_UNARY_FUN(exp, hexp, h2exp)
KERNEL_FLOAT_FP16_UNARY_FUN(exp10, hexp10, h2exp10)
KERNEL_FLOAT_FP16_UNARY_FUN(floor, hfloor, h2floor)
KERNEL_FLOAT_FP16_UNARY_FUN(log, hlog, h2log)
// log10 for half precision: the second argument is the scalar intrinsic and the
// third is the paired (`h2*`) intrinsic. Both must compute log10; the paired
// variant was previously `h2log2`, which computes a base-2 logarithm instead.
KERNEL_FLOAT_FP16_UNARY_FUN(log10, hlog10, h2log10)
KERNEL_FLOAT_FP16_UNARY_FUN(rint, hrint, h2rint)
KERNEL_FLOAT_FP16_UNARY_FUN(rsqrt, hrsqrt, h2rsqrt)
KERNEL_FLOAT_FP16_UNARY_FUN(sin, hsin, h2sin)

KERNEL_FLOAT_FP16_UNARY_FUN(sqrt, hsqrt, h2sqrt)
KERNEL_FLOAT_FP16_UNARY_FUN(trunc, htrunc, h2trunc)
KERNEL_FLOAT_FP16_UNARY_FUN(rsqrt, hrsqrt, h2rsqrt)
KERNEL_FLOAT_FP16_UNARY_FUN(rcp, hrcp, h2rcp)

KERNEL_FLOAT_FP16_UNARY_FUN(abs, __habs, __habs2)
KERNEL_FLOAT_FP16_UNARY_FUN(floor, hfloor, h2floor)
KERNEL_FLOAT_FP16_UNARY_FUN(ceil, hceil, h2ceil)
KERNEL_FLOAT_FP16_UNARY_FUN(rint, hrint, h2rint)
KERNEL_FLOAT_FP16_UNARY_FUN(trunc, htrunc, h2trunc)
KERNEL_FLOAT_FP16_UNARY_FUN(negate, __hneg, __hneg2)

#if KERNEL_FLOAT_IS_DEVICE
// Specializes a binary operation NAME for `__half`:
//  * `ops::NAME<__half>` applies the scalar intrinsic FUN1 to one pair of values.
//  * `detail::apply_impl<ops::NAME<__half>, 2, ...>` handles two elements at
//    once: it packs a[0], a[1] and b[0], b[1] into `__half2` values, applies
//    the paired intrinsic FUN2, and writes the `.x` / `.y` components of the
//    result to result[0] / result[1].
#define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) \
namespace ops { \
template<> \
struct NAME<__half> { \
KERNEL_FLOAT_INLINE __half operator()(__half left, __half right) const { \
return FUN1(left, right); \
} \
}; \
} \
namespace detail { \
template<> \
struct apply_impl<ops::NAME<__half>, 2, __half, __half, __half> { \
KERNEL_FLOAT_INLINE static void \
call(ops::NAME<__half>, __half* result, const __half* a, const __half* b) { \
__half2 r = FUN2(__half2 {a[0], a[1]}, __half2 {b[0], b[1]}); \
result[0] = r.x, result[1] = r.y; \
} \
}; \
#define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) \
namespace ops { \
template<> \
struct NAME<__half> { \
KERNEL_FLOAT_INLINE __half operator()(__half left, __half right) const { \
return ops::cast<decltype(FUN1(left, right)), __half> {}(FUN1(left, right)); \
} \
}; \
} \
namespace detail { \
template<> \
struct apply_impl<ops::NAME<__half>, 2, __half, __half, __half> { \
KERNEL_FLOAT_INLINE static void \
call(ops::NAME<__half>, __half* result, const __half* a, const __half* b) { \
__half2 r = FUN2(__half2 {a[0], a[1]}, __half2 {b[0], b[1]}); \
result[0] = r.x, result[1] = r.y; \
} \
}; \
}
#else
#define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2)
Expand Down Expand Up @@ -190,13 +198,13 @@ KERNEL_FLOAT_FP16_CAST(char, __int2half_rn(input), (char)__half2int_rz(input));
KERNEL_FLOAT_FP16_CAST(signed char, __int2half_rn(input), (signed char)__half2int_rz(input));
KERNEL_FLOAT_FP16_CAST(unsigned char, __int2half_rn(input), (unsigned char)__half2int_rz(input));

KERNEL_FLOAT_FP16_CAST(signed short, __half2short_rz(input), __short2half_rn(input));
KERNEL_FLOAT_FP16_CAST(signed int, __half2int_rz(input), __int2half_rn(input));
KERNEL_FLOAT_FP16_CAST(signed short, __short2half_rn(input), __half2short_rz(input));
KERNEL_FLOAT_FP16_CAST(signed int, __int2half_rn(input), __half2int_rz(input));
KERNEL_FLOAT_FP16_CAST(signed long, __ll2half_rn(input), (signed long)(__half2ll_rz(input)));
KERNEL_FLOAT_FP16_CAST(signed long long, __ll2half_rn(input), __half2ll_rz(input));

KERNEL_FLOAT_FP16_CAST(unsigned short, __half2ushort_rz(input), __ushort2half_rn(input));
KERNEL_FLOAT_FP16_CAST(unsigned int, __half2uint_rz(input), __uint2half_rn(input));
KERNEL_FLOAT_FP16_CAST(unsigned short, __ushort2half_rn(input), __half2ushort_rz(input));
KERNEL_FLOAT_FP16_CAST(unsigned int, __uint2half_rn(input), __half2uint_rz(input));
KERNEL_FLOAT_FP16_CAST(unsigned long, __ull2half_rn(input), (unsigned long)(__half2ull_rz(input)));
KERNEL_FLOAT_FP16_CAST(unsigned long long, __ull2half_rn(input), __half2ull_rz(input));
#endif
Expand Down
3 changes: 0 additions & 3 deletions include/kernel_float/meta.h
Original file line number Diff line number Diff line change
Expand Up @@ -270,9 +270,6 @@ struct enable_if_impl<true, T> {
template<bool C, typename T = void>
using enable_if_t = typename detail::enable_if_impl<C, T>::type;

/// Alias that always resolves to `T`; any further template arguments are
/// accepted and ignored.
template<typename T, typename...>
using identity_t = T;

KERNEL_FLOAT_INLINE
constexpr size_t round_up_to_power_of_two(size_t n) {
size_t result = 1;
Expand Down
8 changes: 4 additions & 4 deletions include/kernel_float/prelude.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ KERNEL_FLOAT_TYPE_ALIAS(float16x, __half)
#endif

#if KERNEL_FLOAT_BF16_AVAILABLE
KERNEL_FLOAT_TYPE_ALIAS(bfloat16x, __nv_bfloat16)
KERNEL_FLOAT_TYPE_ALIAS(bf16x, __nv_bfloat16)
KERNEL_FLOAT_TYPE_ALIAS(bfloat16x, __bfloat16)
KERNEL_FLOAT_TYPE_ALIAS(bf16x, __bfloat16)
#endif

#if KERNEL_FLOAT_BF8_AVAILABLE
Expand All @@ -82,12 +82,12 @@ static constexpr extent<N> kextent = {};

/// Builds a `kvec` from the given arguments by delegating to
/// `::kernel_float::make_vec`. The element type is `promote_t<Args...>` and
/// the length equals the number of arguments. Each argument is passed on with
/// its original value category.
template<typename... Args>
KERNEL_FLOAT_INLINE kvec<promote_t<Args...>, sizeof...(Args)> make_kvec(Args&&... args) {
return ::kernel_float::make_vec(std::forward<Args>(args)...);
return ::kernel_float::make_vec(static_cast<Args&&>(args)...);
};

/// Converts `input` to `into_vector_type<V>` by delegating to
/// `::kernel_float::into_vec`, passing the argument on with its original
/// value category.
template<typename V>
KERNEL_FLOAT_INLINE into_vector_type<V> into_kvec(V&& input) {
return ::kernel_float::into_vec(std::forward<V>(input));
return ::kernel_float::into_vec(static_cast<V&&>(input));
}

template<typename T = double>
Expand Down
66 changes: 35 additions & 31 deletions include/kernel_float/unops.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ KERNEL_FLOAT_DEFINE_UNARY_OP(logical_not, !, (ops::cast<bool, T> {}(!ops::cast<T
template<typename T> \
struct NAME<T, enable_if_t<detail::allow_float_fallback<T>::value>> { \
KERNEL_FLOAT_INLINE T operator()(T input_arg) { \
float input = float(input_arg); \
return T(EXPR_F32); \
float input = ops::cast<T, float> {}(input_arg); \
return ops::cast<decltype(EXPR_F32), T> {}(EXPR_F32); \
} \
}; \
\
Expand All @@ -140,52 +140,56 @@ KERNEL_FLOAT_DEFINE_UNARY_OP(logical_not, !, (ops::cast<bool, T> {}(!ops::cast<T
KERNEL_FLOAT_DEFINE_UNARY_STRUCT(NAME, ::NAME(input), ::NAME(input)) \
KERNEL_FLOAT_DEFINE_UNARY_FUN(NAME)

KERNEL_FLOAT_DEFINE_UNARY_MATH(sin)
KERNEL_FLOAT_DEFINE_UNARY_MATH(cos)
KERNEL_FLOAT_DEFINE_UNARY_MATH(tan)
KERNEL_FLOAT_DEFINE_UNARY_MATH(asin)
KERNEL_FLOAT_DEFINE_UNARY_MATH(acos)
KERNEL_FLOAT_DEFINE_UNARY_MATH(abs)
KERNEL_FLOAT_DEFINE_UNARY_MATH(atan)

KERNEL_FLOAT_DEFINE_UNARY_MATH(sinh)
KERNEL_FLOAT_DEFINE_UNARY_MATH(cosh)
KERNEL_FLOAT_DEFINE_UNARY_MATH(tanh)
KERNEL_FLOAT_DEFINE_UNARY_MATH(acosh)
KERNEL_FLOAT_DEFINE_UNARY_MATH(asin)
KERNEL_FLOAT_DEFINE_UNARY_MATH(asinh)
KERNEL_FLOAT_DEFINE_UNARY_MATH(atan)
KERNEL_FLOAT_DEFINE_UNARY_MATH(atanh)
KERNEL_FLOAT_DEFINE_UNARY_MATH(cbrt)
KERNEL_FLOAT_DEFINE_UNARY_MATH(ceil)
KERNEL_FLOAT_DEFINE_UNARY_MATH(cos)
KERNEL_FLOAT_DEFINE_UNARY_MATH(cosh)
KERNEL_FLOAT_DEFINE_UNARY_MATH(erf)
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfc)
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfcinv)
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfcx)
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfinv)

KERNEL_FLOAT_DEFINE_UNARY_MATH(exp)
KERNEL_FLOAT_DEFINE_UNARY_MATH(exp10)
KERNEL_FLOAT_DEFINE_UNARY_MATH(exp2)
KERNEL_FLOAT_DEFINE_UNARY_MATH(exp10)
KERNEL_FLOAT_DEFINE_UNARY_MATH(expm1)
KERNEL_FLOAT_DEFINE_UNARY_MATH(fabs)
KERNEL_FLOAT_DEFINE_UNARY_MATH(floor)
KERNEL_FLOAT_DEFINE_UNARY_MATH(ilogb)
KERNEL_FLOAT_DEFINE_UNARY_MATH(lgamma)
KERNEL_FLOAT_DEFINE_UNARY_MATH(log)
KERNEL_FLOAT_DEFINE_UNARY_MATH(log10)
KERNEL_FLOAT_DEFINE_UNARY_MATH(log2)
KERNEL_FLOAT_DEFINE_UNARY_MATH(nearbyint)
KERNEL_FLOAT_DEFINE_UNARY_MATH(log10)
KERNEL_FLOAT_DEFINE_UNARY_MATH(log1p)

KERNEL_FLOAT_DEFINE_UNARY_MATH(erf)
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfinv)
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfc)
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfcx)
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfcinv)
KERNEL_FLOAT_DEFINE_UNARY_MATH(normcdf)
KERNEL_FLOAT_DEFINE_UNARY_MATH(rcbrt)
KERNEL_FLOAT_DEFINE_UNARY_MATH(sin)
KERNEL_FLOAT_DEFINE_UNARY_MATH(sinh)
KERNEL_FLOAT_DEFINE_UNARY_MATH(sqrt)
KERNEL_FLOAT_DEFINE_UNARY_MATH(tan)
KERNEL_FLOAT_DEFINE_UNARY_MATH(tanh)
KERNEL_FLOAT_DEFINE_UNARY_MATH(lgamma)
KERNEL_FLOAT_DEFINE_UNARY_MATH(tgamma)
KERNEL_FLOAT_DEFINE_UNARY_MATH(trunc)
KERNEL_FLOAT_DEFINE_UNARY_MATH(rint)

KERNEL_FLOAT_DEFINE_UNARY_MATH(sqrt)
KERNEL_FLOAT_DEFINE_UNARY_MATH(rsqrt)
KERNEL_FLOAT_DEFINE_UNARY_MATH(cbrt)
KERNEL_FLOAT_DEFINE_UNARY_MATH(rcbrt)

KERNEL_FLOAT_DEFINE_UNARY_MATH(abs)
KERNEL_FLOAT_DEFINE_UNARY_MATH(fabs)
KERNEL_FLOAT_DEFINE_UNARY_MATH(floor)
KERNEL_FLOAT_DEFINE_UNARY_MATH(round)
KERNEL_FLOAT_DEFINE_UNARY_MATH(ceil)
KERNEL_FLOAT_DEFINE_UNARY_MATH(trunc)
KERNEL_FLOAT_DEFINE_UNARY_MATH(rint)

// These are not supported on HIP
#if !KERNEL_FLOAT_IS_HIP
KERNEL_FLOAT_DEFINE_UNARY_MATH(signbit)
KERNEL_FLOAT_DEFINE_UNARY_MATH(isinf)
KERNEL_FLOAT_DEFINE_UNARY_MATH(isnan)
KERNEL_FLOAT_DEFINE_UNARY_MATH(isinf)
KERNEL_FLOAT_DEFINE_UNARY_MATH(isfinite)
#endif

// CUDA offers special reciprocal functions (rcp), but only on the device.
Expand Down
Loading

0 comments on commit 9385655

Please sign in to comment.