Skip to content

Commit

Permalink
[6/N] Fix Wextra-semi warning (pytorch#139605)
Browse files Browse the repository at this point in the history
Fixes #ISSUE_NUMBER

Pull Request resolved: pytorch#139605
Approved by: https://github.com/ezyang
  • Loading branch information
cyyever authored and pytorchmergebot committed Nov 4, 2024
1 parent 2ce2e4d commit 419a7e1
Show file tree
Hide file tree
Showing 105 changed files with 399 additions and 399 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/core/dispatch/OperatorEntry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ void OperatorEntry::reportSignatureError(const CppSignature& call_signature, con
"This likely happened in a call to OperatorHandle::typed<Return (Args...)>(). ",
"Please make sure that the function signature matches the signature in the operator registration call."
);
};
}

#ifndef STRIP_ERROR_MESSAGES
static std::string post_process_dispatch_key_str(std::string dispatch_key) {
Expand Down
10 changes: 5 additions & 5 deletions aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ template <typename T, typename std::enable_if_t<is_reduced_floating_point_v<T>,
inline void cvt_to_fp32(const __m128i& a, __m256& o);
template <> inline void cvt_to_fp32<BFloat16>(const __m128i& a, __m256& o) {
cvtbf16_fp32(a, o);
};
}
template <> inline void cvt_to_fp32<Half>(const __m128i& a, __m256& o) {
cvtfp16_fp32(a, o);
}
Expand Down Expand Up @@ -1071,8 +1071,8 @@ inline std::tuple<Vectorized<float>, Vectorized<float>> convert_##name##_float(c
inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const Vectorized<float>& b) { \
return cvt_from_fp32<type>(__m256(a), __m256(b)); \
}
CONVERT_VECTORIZED_INIT(BFloat16, bfloat16);
CONVERT_VECTORIZED_INIT(Half, half);
CONVERT_VECTORIZED_INIT(BFloat16, bfloat16)
CONVERT_VECTORIZED_INIT(Half, half)

#else // defined(CPU_CAPABILITY_AVX2)

Expand All @@ -1096,9 +1096,9 @@ inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const V
convert(arr, arr2, K); \
return Vectorized<type>::loadu(arr2); \
}
CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16);
CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16)
#if !(defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && !defined(CPU_CAPABILITY_SVE256))
CONVERT_NON_VECTORIZED_INIT(Half, half);
CONVERT_NON_VECTORIZED_INIT(Half, half)
#endif

#endif // defined(CPU_CAPABILITY_AVX2)
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/functorch/Interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,11 @@ void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
}

void Interpreter::process(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
INTERPRETER_DISPATCH(key_, SINGLE_ARG(processImpl(op, stack)));
INTERPRETER_DISPATCH(key_, SINGLE_ARG(processImpl(op, stack)))
}

void Interpreter::sendToNextInterpreter(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case) {
INTERPRETER_DISPATCH(key_, SINGLE_ARG(sendToNextInterpreterImpl(op, stack, grad_special_case)));
INTERPRETER_DISPATCH(key_, SINGLE_ARG(sendToNextInterpreterImpl(op, stack, grad_special_case)))
}

} // namespace at::functorch
2 changes: 1 addition & 1 deletion aten/src/ATen/mps/MPSAllocator.mm
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

namespace at::mps {

C10_DEFINE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback);
C10_DEFINE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback)

namespace HeapAllocator {

Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/BlasKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ void gemv_fast_path<at::Half>(
y,
*incy);
}
INSTANTIATE(c10::BFloat16);
INSTANTIATE(c10::BFloat16)
#else
template <>
bool scal_use_fast_path<at::Half>(
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/EmbeddingBag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1251,7 +1251,7 @@ embedding_bag(const Tensor &weight, const Tensor &indices,
mode, sparse, per_sample_weights, include_last_offset, padding_idx);
}
return out;
};
}

std::tuple<Tensor, Tensor, Tensor, Tensor>
embedding_bag(const Tensor &weight, const Tensor &indices,
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/MaxUnpooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ Tensor& max_unpooling2d_forward_out_cpu(
}

return output;
};
}

Tensor max_unpooling2d_forward_cpu(
const Tensor& self,
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/NaiveConvolutionTranspose2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,7 @@ static std::tuple<Tensor, Tensor, Tensor> slow_conv_transpose2d_backward_cpu(
return std::tuple<Tensor, Tensor, Tensor>(grad_input, grad_weight, grad_bias);
}

REGISTER_ALL_CPU_DISPATCH(slow_conv_transpose2d_backward_stub, &slow_conv_transpose2d_backward_cpu);
REGISTER_ALL_CPU_DISPATCH(slow_conv_transpose2d_backward_stub, &slow_conv_transpose2d_backward_cpu)

} // namespace native
} // namespace at
4 changes: 2 additions & 2 deletions aten/src/ATen/native/NaiveDilatedConvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ static std::tuple<Tensor, Tensor, Tensor> slow_conv_dilated3d_backward_cpu(
return std::tie(grad_input, grad_weight, grad_bias);
}

REGISTER_ALL_CPU_DISPATCH(slow_conv_dilated2d_backward_stub, &slow_conv_dilated2d_backward_cpu);
REGISTER_ALL_CPU_DISPATCH(slow_conv_dilated3d_backward_stub, &slow_conv_dilated3d_backward_cpu);
REGISTER_ALL_CPU_DISPATCH(slow_conv_dilated2d_backward_stub, &slow_conv_dilated2d_backward_cpu)
REGISTER_ALL_CPU_DISPATCH(slow_conv_dilated3d_backward_stub, &slow_conv_dilated3d_backward_cpu)

} // namespace at::native
2 changes: 1 addition & 1 deletion aten/src/ATen/native/Pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ using max_pool2d_fn = void(*)(const Tensor& output, const Tensor& indices, const
int kW, int kH, int dW, int dH, int padW, int padH, int dilationW, int dilationH);
using max_pool2d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);

DECLARE_DISPATCH(max_pool2d_fn, max_pool2d_kernel);
DECLARE_DISPATCH(max_pool2d_fn, max_pool2d_kernel)
DECLARE_DISPATCH(max_pool2d_backward_fn, max_pool2d_backward_kernel)

// average pooling has same signature for forward and backward
Expand Down
34 changes: 17 additions & 17 deletions aten/src/ATen/native/RNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1187,10 +1187,10 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _thnn_fused_lstm_cell_backwar
DEFINE_DISPATCH(NAME##_miopen_stub); \
DEFINE_DISPATCH(NAME##_packed_cudnn_stub); \
DEFINE_DISPATCH(NAME##_packed_miopen_stub); \
REGISTER_NO_CPU_DISPATCH(NAME##_cudnn_stub); \
REGISTER_NO_CPU_DISPATCH(NAME##_miopen_stub); \
REGISTER_NO_CPU_DISPATCH(NAME##_packed_cudnn_stub); \
REGISTER_NO_CPU_DISPATCH(NAME##_packed_miopen_stub); \
REGISTER_NO_CPU_DISPATCH(NAME##_cudnn_stub) \
REGISTER_NO_CPU_DISPATCH(NAME##_miopen_stub) \
REGISTER_NO_CPU_DISPATCH(NAME##_packed_cudnn_stub) \
REGISTER_NO_CPU_DISPATCH(NAME##_packed_miopen_stub) \
\
std::tuple<Tensor, Tensor> NAME( \
const Tensor& _input, \
Expand Down Expand Up @@ -1415,17 +1415,17 @@ static std::tuple<Tensor, Tensor> quantized_gru_data_legacy(
using tanf_cell_type = SimpleCell<tanh_f, CellParams>;
ONE_HIDDEN_RNN(rnn_tanh, tanf_cell_type)
using relu_cell_type = SimpleCell<relu_f, CellParams>;
ONE_HIDDEN_RNN(rnn_relu, relu_cell_type);
ONE_HIDDEN_RNN(rnn_relu, relu_cell_type)

DEFINE_DISPATCH(lstm_cudnn_stub);
DEFINE_DISPATCH(lstm_packed_cudnn_stub);
DEFINE_DISPATCH(lstm_miopen_stub);
DEFINE_DISPATCH(lstm_packed_miopen_stub);
DEFINE_DISPATCH(lstm_mkldnn_stub);
REGISTER_NO_CPU_DISPATCH(lstm_cudnn_stub);
REGISTER_NO_CPU_DISPATCH(lstm_packed_cudnn_stub);
REGISTER_NO_CPU_DISPATCH(lstm_miopen_stub);
REGISTER_NO_CPU_DISPATCH(lstm_packed_miopen_stub);
REGISTER_NO_CPU_DISPATCH(lstm_cudnn_stub)
REGISTER_NO_CPU_DISPATCH(lstm_packed_cudnn_stub)
REGISTER_NO_CPU_DISPATCH(lstm_miopen_stub)
REGISTER_NO_CPU_DISPATCH(lstm_packed_miopen_stub)

std::tuple<Tensor, Tensor, Tensor> lstm(
const Tensor& _input, TensorList hx,
Expand Down Expand Up @@ -1857,9 +1857,9 @@ static std::tuple<Tensor, Tensor> prepare_quantized_lstm_hx(TensorList hx) {
// Quantized LSTM cell
using quantized_lstm_cell_dynamic_type = LSTMCell<QuantizedCellParamsDynamic>;

DEFINE_QUANTIZED_RNN_CELL(quantized_lstm_cell, TensorList, quantized_lstm_cell_type, quantized_lstm_return_type, prepare_quantized_lstm_hx);
DEFINE_QUANTIZED_RNN_CELL(quantized_lstm_cell, TensorList, quantized_lstm_cell_type, quantized_lstm_return_type, prepare_quantized_lstm_hx)

static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_lstm_cell_dynamic, TensorList, quantized_lstm_cell_dynamic_type, quantized_lstm_return_type, prepare_quantized_lstm_hx);
static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_lstm_cell_dynamic, TensorList, quantized_lstm_cell_dynamic_type, quantized_lstm_return_type, prepare_quantized_lstm_hx)

// Helpers for simpler cells
using simple_hx_type = const Tensor&;
Expand All @@ -1871,21 +1871,21 @@ static simple_hx_type prepare_quantized_hx(simple_hx_type hx) {
using quantized_gru_cell_type = GRUCell<QuantizedCellParams>;
using quantized_gru_cell_dynamic_type = GRUCell<QuantizedCellParamsDynamic>;

DEFINE_QUANTIZED_RNN_CELL(quantized_gru_cell, simple_hx_type, quantized_gru_cell_type, Tensor, prepare_quantized_hx);
DEFINE_QUANTIZED_RNN_CELL(quantized_gru_cell, simple_hx_type, quantized_gru_cell_type, Tensor, prepare_quantized_hx)

static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_gru_cell_dynamic, simple_hx_type, quantized_gru_cell_dynamic_type, Tensor, prepare_quantized_hx);
static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_gru_cell_dynamic, simple_hx_type, quantized_gru_cell_dynamic_type, Tensor, prepare_quantized_hx)

// Quantized RNN w/ ReLU cell
using quantized_rnn_relu_cell_type = SimpleCell<relu_f, QuantizedCellParams>;
DEFINE_QUANTIZED_RNN_CELL(quantized_rnn_relu_cell, simple_hx_type, quantized_rnn_relu_cell_type, Tensor, prepare_quantized_hx);
DEFINE_QUANTIZED_RNN_CELL(quantized_rnn_relu_cell, simple_hx_type, quantized_rnn_relu_cell_type, Tensor, prepare_quantized_hx)
using quantized_rnn_relu_cell_dynamic_type = SimpleCell<relu_f, QuantizedCellParamsDynamic>;
static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_rnn_relu_cell_dynamic, simple_hx_type, quantized_rnn_relu_cell_dynamic_type, Tensor, prepare_quantized_hx);
static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_rnn_relu_cell_dynamic, simple_hx_type, quantized_rnn_relu_cell_dynamic_type, Tensor, prepare_quantized_hx)

// Quantized RNN w/ tanh cell
using quantized_rnn_tanh_cell_type = SimpleCell<tanh_f, QuantizedCellParams>;
DEFINE_QUANTIZED_RNN_CELL(quantized_rnn_tanh_cell, simple_hx_type, quantized_rnn_tanh_cell_type, Tensor, prepare_quantized_hx);
DEFINE_QUANTIZED_RNN_CELL(quantized_rnn_tanh_cell, simple_hx_type, quantized_rnn_tanh_cell_type, Tensor, prepare_quantized_hx)
using quantized_rnn_tanh_cell_dynamic_type = SimpleCell<tanh_f, QuantizedCellParamsDynamic>;
static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_rnn_tanh_cell_dynamic, simple_hx_type, quantized_rnn_tanh_cell_dynamic_type, Tensor, prepare_quantized_hx);
static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_rnn_tanh_cell_dynamic, simple_hx_type, quantized_rnn_tanh_cell_dynamic_type, Tensor, prepare_quantized_hx)

namespace {

Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/native/UnaryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -932,11 +932,11 @@ Tensor& mvlgamma_out(const Tensor& self, int64_t p, Tensor& result) {

Tensor special_multigammaln(const Tensor& self, int64_t p) {
return self.mvlgamma(p);
};
}

Tensor& special_multigammaln_out(const Tensor& self, int64_t p, Tensor& result) {
return at::mvlgamma_out(result, self, p);
};
}

std::tuple<Tensor, Tensor> frexp(const Tensor& self) {
Tensor mantissa = at::empty_like(self);
Expand Down
8 changes: 4 additions & 4 deletions aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1415,16 +1415,16 @@ REGISTER_DISPATCH(laguerre_polynomial_l_stub, &laguerre_polynomial_l_kernel)
REGISTER_DISPATCH(legendre_polynomial_p_stub, &legendre_polynomial_p_kernel)
REGISTER_DISPATCH(
shifted_chebyshev_polynomial_t_stub,
&shifted_chebyshev_polynomial_t_kernel);
&shifted_chebyshev_polynomial_t_kernel)
REGISTER_DISPATCH(
shifted_chebyshev_polynomial_u_stub,
&shifted_chebyshev_polynomial_u_kernel);
&shifted_chebyshev_polynomial_u_kernel)
REGISTER_DISPATCH(
shifted_chebyshev_polynomial_v_stub,
&shifted_chebyshev_polynomial_v_kernel);
&shifted_chebyshev_polynomial_v_kernel)
REGISTER_DISPATCH(
shifted_chebyshev_polynomial_w_stub,
&shifted_chebyshev_polynomial_w_kernel);
&shifted_chebyshev_polynomial_w_kernel)
// Might enable AVX512 dispatch after enabling explicit vectorization for them.
REGISTER_DISPATCH(chebyshev_polynomial_u_stub, &chebyshev_polynomial_u_kernel)
REGISTER_DISPATCH(hermite_polynomial_h_stub, &hermite_polynomial_h_kernel)
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cpu/MultinomialKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,5 +241,5 @@ static void multinomial_with_replacement_kernel_impl(

REGISTER_DISPATCH(
multinomial_with_replacement_stub,
&multinomial_with_replacement_kernel_impl);
&multinomial_with_replacement_kernel_impl)
} // namespace at::native
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ inline namespace CPU_CAPABILITY {
constexpr auto kF32RegisterPairsPerIteration = 4;
constexpr auto kF32RegistersPerIteration = kF32RegisterPairsPerIteration * 2;
constexpr auto kF32ElementsPerRegister = vec::Vectorized<float>::size();
constexpr auto kF32ElementsPerIteration = kF32RegistersPerIteration * kF32ElementsPerRegister;;
constexpr auto kF32ElementsPerIteration = kF32RegistersPerIteration * kF32ElementsPerRegister;

namespace {
template <typename T>
Expand Down Expand Up @@ -328,8 +328,8 @@ void fp16_gemv_trans(
#if !defined(C10_MOBILE)
// NOTE: we don't *need* to go through dispatch for the ARM-only
// implementation right now, but we will need it when we cover x86.
REGISTER_DISPATCH(fp16_dot_with_fp32_arith_stub, &fp16_dot_with_fp32_arith);
REGISTER_DISPATCH(fp16_gemv_trans_stub, &fp16_gemv_trans);
REGISTER_DISPATCH(fp16_dot_with_fp32_arith_stub, &fp16_dot_with_fp32_arith)
REGISTER_DISPATCH(fp16_gemv_trans_stub, &fp16_gemv_trans)
#else
#endif // !defined(C10_MOBILE)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace at::native {
#if !defined(C10_MOBILE)
using fp16_dot_fn = float(*)(const Half*, const Half*, int64_t);
using fp16_gemv_fn = void(*)(int, int, float, const Half*, int, const Half*, int, float, Half*, int);
DECLARE_DISPATCH(fp16_dot_fn, fp16_dot_with_fp32_arith_stub);
DECLARE_DISPATCH(fp16_gemv_fn, fp16_gemv_trans_stub);
DECLARE_DISPATCH(fp16_dot_fn, fp16_dot_with_fp32_arith_stub)
DECLARE_DISPATCH(fp16_gemv_fn, fp16_gemv_trans_stub)
#endif // !defined(C10_MOBILE)
} // namespace at::native
6 changes: 3 additions & 3 deletions aten/src/ATen/native/cpu/SoftMaxKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1295,15 +1295,15 @@ ALSO_REGISTER_AVX512_DISPATCH(softmax_lastdim_kernel, &softmax_lastdim_kernel_im
ALSO_REGISTER_AVX512_DISPATCH(log_softmax_lastdim_kernel, &log_softmax_lastdim_kernel_impl)
ALSO_REGISTER_AVX512_DISPATCH(
softmax_backward_lastdim_kernel,
&softmax_backward_lastdim_kernel_impl);
&softmax_backward_lastdim_kernel_impl)
ALSO_REGISTER_AVX512_DISPATCH(
log_softmax_backward_lastdim_kernel,
&log_softmax_backward_lastdim_kernel_impl);
&log_softmax_backward_lastdim_kernel_impl)

ALSO_REGISTER_AVX512_DISPATCH(softmax_kernel, &softmax_kernel_impl)
ALSO_REGISTER_AVX512_DISPATCH(log_softmax_kernel, &log_softmax_kernel_impl)
ALSO_REGISTER_AVX512_DISPATCH(softmax_backward_kernel, &softmax_backward_kernel_impl)
ALSO_REGISTER_AVX512_DISPATCH(
log_softmax_backward_kernel,
&log_softmax_backward_kernel_impl);
&log_softmax_backward_kernel_impl)
} // namespace at::native
46 changes: 23 additions & 23 deletions aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -830,15 +830,15 @@ REGISTER_DISPATCH(special_i0e_stub, &CPU_CAPABILITY::i0e_kernel)
REGISTER_DISPATCH(special_ndtri_stub, &CPU_CAPABILITY::ndtri_kernel)
REGISTER_DISPATCH(special_modified_bessel_k0_stub, &CPU_CAPABILITY::modified_bessel_k0_kernel)
REGISTER_DISPATCH(special_modified_bessel_k1_stub, &CPU_CAPABILITY::modified_bessel_k1_kernel)
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(ceil);
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(floor);
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(round);
IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(sqrt);
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(trunc);
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(i0);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(sin);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(cos);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(tan);
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(ceil)
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(floor)
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(round)
IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(sqrt)
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(trunc)
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(i0)
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(sin)
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(cos)
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(tan)

// The following kernels are compute-intensive & are compiled with both AVX512
// & AVX2
Expand Down Expand Up @@ -871,19 +871,19 @@ REGISTER_DISPATCH(special_bessel_y1_stub, &CPU_CAPABILITY::bessel_y1_kernel)
REGISTER_DISPATCH(special_modified_bessel_i0_stub, &CPU_CAPABILITY::modified_bessel_i0_kernel)
REGISTER_DISPATCH(special_modified_bessel_i1_stub, &CPU_CAPABILITY::modified_bessel_i1_kernel)

STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(acos);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(asin);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(atan);
IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(erf);
IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(erfc);
IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(erfinv);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(exp);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(expm1);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log10);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log1p);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log2);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(tanh);
IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(lgamma);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(acos)
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(asin)
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(atan)
IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(erf)
IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(erfc)
IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(erfinv)
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(exp)
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(expm1)
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log)
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log10)
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log1p)
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log2)
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(tanh)
IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(lgamma)

} // namespace at::native
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/DepthwiseConv2d.cu
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,6 @@ std::tuple<Tensor, Tensor> conv_depthwise2d_backward_cuda(
grad_weight);
}

REGISTER_CUDA_DISPATCH(conv_depthwise2d_backward_stub, &conv_depthwise2d_backward_cuda);
REGISTER_CUDA_DISPATCH(conv_depthwise2d_backward_stub, &conv_depthwise2d_backward_cuda)

} // namespace at::native
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/DepthwiseConv3d.cu
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ std::tuple<Tensor, Tensor, Tensor> conv_depthwise3d_backward_cuda(

}

REGISTER_CUDA_DISPATCH(conv_depthwise3d_backward_stub, &conv_depthwise3d_backward_cuda);
REGISTER_CUDA_DISPATCH(conv_depthwise3d_backward_stub, &conv_depthwise3d_backward_cuda)

#undef DWCONV3D_BACKWARD_INPUT_DISPATCH_SPECIALIZATION
#undef DWCONV3D_BACKWARD_INPUT_DISPATCH_OTHERS
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/FlattenIndicesKernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ Tensor flatten_indices_cuda_kernel(const Tensor& indices, IntArrayRef size) {

}

REGISTER_CUDA_DISPATCH(flatten_indices_stub, &flatten_indices_cuda_kernel);
REGISTER_CUDA_DISPATCH(flatten_indices_stub, &flatten_indices_cuda_kernel)

} // namespace at::native
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/IndexKernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,6 @@ REGISTER_DISPATCH(put_stub, &put_kernel)
REGISTER_DISPATCH(take_stub, &take_kernel)
REGISTER_DISPATCH(flip_stub, &flip_kernel)

REGISTER_CUDA_DISPATCH(index_put_kernel_quantized_stub, &index_put_kernel_quantized_cuda);
REGISTER_CUDA_DISPATCH(index_put_kernel_quantized_stub, &index_put_kernel_quantized_cuda)

} // namespace at::native
Loading

0 comments on commit 419a7e1

Please sign in to comment.