Enable BFloat16 support for Convolutions on ROCm (pytorch#30948)
Summary:
This PR adds bfloat16 support for convolutions on ROCm.

- Integrates MIOpen bfloat16 convolution support into PyTorch

- Enables bfloat16 convolution for non-MIOpen paths, i.e. THCUNN and native HIP kernels

- Enables the bfloat16 type for probability distribution functions (included in this PR because the convolution unit tests use bfloat16 random number generators)

The native CUDA kernels for convolution and the random functions will be compiled for CUDA builds as well.
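As a rough illustration of what this enables, a bfloat16 convolution can be exercised through the ATen C++ API as in the minimal sketch below (assumptions: a ROCm build with MIOpen, arbitrary shapes, and a hypothetical helper name; the bias is left undefined because MIOpen does not yet support bias with bfloat16):

#include <ATen/ATen.h>

// Minimal sketch (assumption: ROCm/MIOpen build; at::kCUDA maps to HIP there).
// at::randn works for bfloat16 here because this PR enables the bfloat16
// normal-distribution kernel; the bias is omitted because MIOpen does not yet
// support bias with bfloat16 convolutions.
at::Tensor bfloat16_conv_sketch() {
  auto opts   = at::device(at::kCUDA).dtype(at::kBFloat16);
  auto input  = at::randn({1, 3, 32, 32}, opts);
  auto weight = at::randn({8, 3, 3, 3}, opts);
  return at::conv2d(input, weight, /*bias=*/{}, /*stride=*/1, /*padding=*/1);
}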

iotamudelta bddppq
Pull Request resolved: pytorch#30948

Differential Revision: D19274164

Pulled By: ezyang

fbshipit-source-id: c0888a6ac72a2c5749b1ebb2195ac6f2209996be
rohithkrn authored and facebook-github-bot committed Jan 7, 2020
1 parent a561a84 commit 985fd97
Showing 20 changed files with 140 additions and 50 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/AccumulateType.h
@@ -22,6 +22,7 @@ struct AccumulateType { };

#if defined(__CUDACC__) || defined(__HIPCC__)
template <> struct AccumulateType<half, true> { using type = float; };
template <> struct AccumulateType<BFloat16, true> {using type = float; };
#endif
template <> struct AccumulateType<Half, true> { using type = float; };
template <> struct AccumulateType<float, true> { using type = float; };
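As a side note, the effect of the new specialization is that bfloat16 accumulates in float on the GPU path; a minimal sketch (assumption: compiled as device code with hipcc/nvcc, since the specialization sits behind the __CUDACC__/__HIPCC__ guard):

#include <ATen/AccumulateType.h>
#include <type_traits>

// Sketch: in a .cu/.hip translation unit, the accumulation type selected for
// BFloat16 on the GPU path resolves to float.
static_assert(
    std::is_same<at::acc_type<c10::BFloat16, /*is_cuda=*/true>, float>::value,
    "BFloat16 accumulates in float on the CUDA/HIP path");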
17 changes: 17 additions & 0 deletions aten/src/ATen/Dispatch.h
@@ -138,6 +138,23 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
} \
}()

#define AT_DISPATCH_FLOATING_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
[&] { \
const auto& the_type = TYPE; \
/* don't use TYPE again in case it is an expensive or side-effect op */ \
at::ScalarType _st = ::detail::scalar_type(the_type); \
switch (_st) { \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(SCALARTYPE1, \
decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE1>::t), __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(SCALARTYPE2, \
decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE2>::t), __VA_ARGS__) \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
} \
}()

#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& the_type = TYPE; \
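For reference, the new AT_DISPATCH_FLOATING_TYPES_AND2 macro is used like the existing AT_DISPATCH_FLOATING_TYPES_AND_HALF, but with two extra scalar types; a minimal usage sketch with a hypothetical kernel name:

// The lambda is instantiated once per dispatched type, with scalar_t bound to
// double, float, or one of the two extra types (here Half and BFloat16).
AT_DISPATCH_FLOATING_TYPES_AND2(
    at::ScalarType::Half, at::ScalarType::BFloat16,
    self.scalar_type(), "my_op_cuda", [&] {
      my_op_kernel<scalar_t>(self);  // hypothetical per-type implementation
    });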
7 changes: 6 additions & 1 deletion aten/src/ATen/miopen/Descriptors.cpp
@@ -11,8 +11,11 @@ inline miopenDataType_t getDataType(const at::Tensor& t) {
return miopenFloat;
} else if (scalar_type == at::kHalf) {
return miopenHalf;
} else if (scalar_type == at::kBFloat16) {
return miopenBFloat16;
} else {
throw std::runtime_error("TensorDescriptor only supports float, half and bfloat16 tensors");
}
throw std::runtime_error("TensorDescriptor only supports float and half tensors");
}

} // anonymous namespace
@@ -51,6 +54,8 @@ std::string miopenTypeToString(miopenDataType_t dtype) {
return "miopenFloat";
case miopenHalf:
return "miopenHalf";
case miopenBFloat16:
return "miopenBFloat16";
default:
std::ostringstream oss;
oss << "(unknown data-type " << static_cast<int>(dtype) << ")";
3 changes: 2 additions & 1 deletion aten/src/ATen/miopen/Descriptors.h
@@ -13,6 +13,7 @@ inline int dataSize(miopenDataType_t dataType)
switch (dataType) {
case miopenHalf: return 2;
case miopenFloat: return 4;
case miopenBFloat16: return 2;
default: return 8;
}
}
@@ -145,7 +146,7 @@ union Constant
float f;
double d;
Constant(miopenDataType_t dataType, double value) {
if (dataType == miopenHalf || dataType == miopenFloat) {
if (dataType == miopenHalf || dataType == miopenFloat || dataType == miopenBFloat16) {
f = static_cast<float>(value);
} else {
d = value;
2 changes: 2 additions & 0 deletions aten/src/ATen/miopen/Types.cpp
@@ -10,6 +10,8 @@ miopenDataType_t getMiopenDataType(const at::Tensor& tensor) {
return miopenFloat;
} else if (tensor.scalar_type() == at::kHalf) {
return miopenHalf;
} else if (tensor.scalar_type() == at::kBFloat16) {
return miopenBFloat16;
}
std::string msg("getMiopenDataType() not supported for ");
msg += toString(tensor.scalar_type());
11 changes: 6 additions & 5 deletions aten/src/ATen/native/Convolution.cpp
@@ -40,7 +40,7 @@ struct ConvParams {
bool use_cudnn(const at::Tensor& input, const at::Tensor& weight) const;
bool use_cudnn_depthwise(const at::Tensor& input, const at::Tensor& weight) const;
bool cudnn_use_channels_last(const at::Tensor& input, const at::Tensor& weight) const;
bool use_miopen(const at::Tensor& input) const;
bool use_miopen(const at::Tensor& input, bool bias_defined) const;
bool use_mkldnn(const at::Tensor& input) const;
bool use_nnpack(const at::Tensor& input) const;
bool is_depthwise(const at::Tensor& input, const at::Tensor& weight) const;
@@ -198,13 +198,14 @@ auto ConvParams::use_cudnn(const at::Tensor& input, const at::Tensor& weight) co
return !is_output_padding_big();
}

auto ConvParams::use_miopen(const at::Tensor& input) const -> bool {
auto ConvParams::use_miopen(const at::Tensor& input, bool bias_defined) const -> bool {

return ((input.scalar_type() == at::kFloat) || (input.scalar_type() == at::kHalf))
return ((input.scalar_type() == at::kFloat) || (input.scalar_type() == at::kHalf) || (input.scalar_type() == at::kBFloat16))
&& detail::getCUDAHooks().compiledWithMIOpen()
&& input.is_cuda()
&& input.dim() <= MIOPEN_DIM_MAX
&& !(groups > 1 && is_dilated()) // MIOpen currently does not support dilation with groups of size > 1
&& !(input.scalar_type() == at::kBFloat16 && bias_defined) // MIOpen currently doesn't support bias with bfloat16
;
}

@@ -637,7 +638,7 @@ at::Tensor _convolution(
output = output + reshape_bias(input.dim(), bias);
}

} else if (params.use_miopen(input)){
} else if (params.use_miopen(input, bias.defined())){
output = at::miopen_depthwise_convolution(
input.contiguous(), weight, bias,
padding, stride, dilation, params.groups, params.benchmark, params.deterministic);
@@ -667,7 +668,7 @@ at::Tensor _convolution(
output = output + reshape_bias(input.dim(), bias);
}
}
} else if (params.use_miopen(input)) {
} else if (params.use_miopen(input, bias.defined())) {
TORCH_CHECK(input.options().type_equal(weight.options()),
"Input type (", input.toString(), ") and weight type (", weight.toString(),
") should be the same");
34 changes: 17 additions & 17 deletions aten/src/ATen/native/cuda/Distributions.cu
@@ -338,7 +338,7 @@ Tensor _s_poisson_cuda(const Tensor& lambda, Generator* gen_) {
rng_engine_inputs = gen->philox_engine_inputs(20);
}
Tensor ret = at::empty(lambda.sizes(), lambda.options());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(ret.scalar_type(), "poisson_cuda", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, ret.scalar_type(), "poisson_cuda", [&] {
poisson_cuda_kernel<scalar_t>(ret, lambda, rng_engine_inputs);
});
return ret;
@@ -353,7 +353,7 @@ Tensor _s_gamma_cuda(const Tensor& alpha, Generator* gen_) {
rng_engine_inputs = gen->philox_engine_inputs(10);
}
Tensor ret = at::empty(alpha.sizes(), alpha.options());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(ret.scalar_type(), "gamma_cuda", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, ret.scalar_type(), "gamma_cuda", [&] {
gamma_cuda_kernel<scalar_t>(ret, alpha, rng_engine_inputs);
});
return ret;
@@ -368,7 +368,7 @@ Tensor _s_dirichlet_cuda(const Tensor& alpha, Generator* gen_) {
rng_engine_inputs = gen->philox_engine_inputs(10);
}
Tensor ret = at::empty(alpha.sizes(), alpha.options());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(ret.scalar_type(), "dirichlet", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, ret.scalar_type(), "dirichlet", [&] {
Tensor gamma = at::empty(alpha.sizes(), alpha.options());
gamma_cuda_kernel<scalar_t>(gamma, alpha, rng_engine_inputs);
dirichlet_scalar_cuda_kernel<scalar_t>(ret, gamma);
@@ -378,7 +378,7 @@ Tensor _s_dirichlet_cuda(const Tensor& alpha, Generator* gen_) {

Tensor _standard_gamma_grad_cuda(const Tensor& self, const Tensor& output) {
Tensor ret = at::empty(self.sizes(), self.options());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "_standard_gamma_grad_cuda", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "_standard_gamma_grad_cuda", [&] {
gamma_grad_cuda_kernel<scalar_t>(ret, self, output);
});
return ret;
@@ -402,10 +402,10 @@ Tensor& bernoulli_tensor_cuda_(Tensor &self, const Tensor& p_, Generator* gen_)
rng_engine_inputs = gen->philox_engine_inputs(10);
}
auto p = std::get<0>(expand_inplace(self, p_.to(kCUDA)));
AT_DISPATCH_ALL_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "bernoulli_tensor_cuda_self_", [&] {
AT_DISPATCH_ALL_TYPES_AND3(
at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, self.scalar_type(), "bernoulli_tensor_cuda_self_", [&] {
using self_t = scalar_t;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(p.scalar_type(), "bernoulli_tensor_cuda_p_", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, p.scalar_type(), "bernoulli_tensor_cuda_p_", [&] {
using p_t = scalar_t;
return bernoulli_tensor_cuda_kernel<self_t, p_t>(self, p, rng_engine_inputs);
});
@@ -415,7 +415,7 @@ Tensor& bernoulli_tensor_cuda_(Tensor &self, const Tensor& p_, Generator* gen_)

void uniform_kernel_cuda(TensorIterator& iter, double from_, double to_, Generator* gen_) {
auto gen = get_generator_or_default<CUDAGenerator>(gen_, cuda::detail::getDefaultCUDAGenerator());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "uniform_cuda", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "uniform_cuda", [&] {
auto from = static_cast<scalar_t>(from_);
auto to = static_cast<scalar_t>(to_);
TORCH_CHECK(from <= to,
@@ -454,7 +454,7 @@ void uniform_kernel_cuda(TensorIterator& iter, double from_, double to_, Generat

void random_kernel_cuda(TensorIterator& iter, uint64_t range, int64_t base, Generator* gen_) {
auto gen = get_generator_or_default<CUDAGenerator>(gen_, cuda::detail::getDefaultCUDAGenerator());
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::Half, iter.dtype(), "random_cuda", [&] {
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "random_cuda", [&] {
if (std::is_same<scalar_t, double>::value || std::is_same<scalar_t, int64_t>::value) {
// define lambda to mod with range and add base
auto random_func = [range, base] __device__ (uint64_t rand) {
@@ -486,7 +486,7 @@ void random_kernel_cuda(TensorIterator& iter, uint64_t range, int64_t base, Gene

void normal_kernel_cuda(TensorIterator& iter, double mean_, double std_, Generator* gen_) {
auto gen = get_generator_or_default<CUDAGenerator>(gen_, cuda::detail::getDefaultCUDAGenerator());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "normal_cuda", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "normal_cuda", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;
auto mean = static_cast<accscalar_t>(mean_);
auto std = static_cast<accscalar_t>(std_);
@@ -510,7 +510,7 @@ void normal_kernel_cuda(TensorIterator& iter, double mean_, double std_, Generat

void cauchy_kernel_cuda(TensorIterator& iter, double median_, double sigma_, Generator* gen_) {
auto gen = get_generator_or_default<CUDAGenerator>(gen_, cuda::detail::getDefaultCUDAGenerator());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "cauchy_cuda", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "cauchy_cuda", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;
auto median = static_cast<accscalar_t>(median_);
auto sigma = static_cast<accscalar_t>(sigma_);
@@ -543,7 +543,7 @@ void exponential_kernel_cuda(TensorIterator& iter, double lambda_, Generator* ge
// Note that HIP doesn't support std::nextafter in device code.
auto nextafter_1_0_float = std::nextafter(1.0f, 0.0f);
auto nextafter_1_0_double = std::nextafter(1.0, 0.0);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "exponential_cuda", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_cuda", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;
auto lambda = static_cast<accscalar_t>(lambda_);
if (std::is_same<scalar_t, double>::value) {
@@ -584,7 +584,7 @@ void exponential_kernel_cuda(TensorIterator& iter, double lambda_, Generator* ge

void geometric_kernel_cuda(TensorIterator& iter, double p_, Generator* gen_) {
auto gen = get_generator_or_default<CUDAGenerator>(gen_, cuda::detail::getDefaultCUDAGenerator());
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "geometric_cuda", [&] {
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cuda", [&] {
if (std::is_same<scalar_t, double>::value) {
// define lambda for geometric transformation
auto geometric_func = [p_] __device__ (double rand) {
@@ -610,7 +610,7 @@ void geometric_kernel_cuda(TensorIterator& iter, double p_, Generator* gen_) {

void log_normal_kernel_cuda(TensorIterator& iter, double mean_, double std_, Generator* gen_) {
auto gen = get_generator_or_default<CUDAGenerator>(gen_, cuda::detail::getDefaultCUDAGenerator());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "log_normal_cuda", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_cuda", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;
auto mean = static_cast<accscalar_t>(mean_);
auto std = static_cast<accscalar_t>(std_);
@@ -638,8 +638,8 @@ void log_normal_kernel_cuda(TensorIterator& iter, double mean_, double std_, Gen

void bernoulli_scalar_cuda_kernel(TensorIterator& iter, double p_, Generator* gen_) {
auto gen = get_generator_or_default<CUDAGenerator>(gen_, cuda::detail::getDefaultCUDAGenerator());
AT_DISPATCH_ALL_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::Bool, iter.dtype(), "bernoulli_scalar_cuda_", [&] {
AT_DISPATCH_ALL_TYPES_AND3(
at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "bernoulli_scalar_cuda_", [&] {
if (std::is_same<scalar_t, double>::value) {
// define lambda for bernoulli transformation
auto bernoulli_func = [p_] __device__ (double rand) {
@@ -673,7 +673,7 @@ Tensor& random_cuda_(Tensor& self, Generator* gen) {
uint64_t range;
auto iter_scalar_type = iter.dtype();
if (isFloatingType(iter_scalar_type)) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter_scalar_type, "random_cuda_range_calc", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter_scalar_type, "random_cuda_range_calc", [&] {
range = static_cast<uint64_t>((1ULL << std::numeric_limits<scalar_t>::digits) + 1);
});
} else {
6 changes: 3 additions & 3 deletions aten/src/ATen/native/cuda/NaiveConvolutionTranspose2d.cu
@@ -256,7 +256,7 @@ void slow_conv_transpose2d_out_cuda_template(
ones.fill_(1);
}

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
input.scalar_type(), "slow_conv_transpose2d_out_cuda", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;

@@ -460,7 +460,7 @@ static void slow_conv_transpose2d_backward_out_cuda_template(
grad_columns.resize_({n_output_plane * kernel_width * kernel_height,
input_height * input_width});

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
grad_output.scalar_type(), "slow_conv_transpose2d_backward_out_cuda", [&] {
// Helpers
Tensor grad_input_n = Tensor();
@@ -663,7 +663,7 @@ void slow_conv_transpose2d_acc_grad_parameters_cuda_template(
columns.resize_({n_output_plane * kernel_width * kernel_height,
input_height * input_width});

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
input.scalar_type(), "slow_conv_transpose2d_acc_grad_parameters_cuda", [&] {
// Helpers
Tensor input_n = Tensor();
6 changes: 3 additions & 3 deletions aten/src/ATen/native/cuda/NaiveConvolutionTranspose3d.cu
@@ -297,7 +297,7 @@ void slow_conv_transpose3d_out_cuda_template(
ones.fill_(1);
}

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
input.scalar_type(), "slow_conv_transpose3d_out_cuda", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;

@@ -531,7 +531,7 @@ void slow_conv_transpose3d_backward_out_cuda_template(
{n_output_plane * kernel_width * kernel_height * kernel_depth,
input_depth * input_height * input_width});

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
input.scalar_type(), "slow_conv_transpose3d_backward_out_cuda", [&] {
// Helpers
Tensor grad_input_n;
@@ -761,7 +761,7 @@ void slow_conv_transpose3d_acc_grad_parameters_cuda(
columns.resize_({n_output_plane * kernel_width * kernel_height * kernel_depth,
input_depth * input_height * input_width});

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
input.scalar_type(),
"slow_conv_transpose3d_acc_grad_parameters_cuda",
[&] {
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/NaiveDilatedConvolution.cu
@@ -244,7 +244,7 @@ void slow_conv_dilated_all_cuda_template(
std::vector<int64_t> dims(dim);
std::iota(dims.begin(), dims.end(), 1);

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
input.scalar_type(), "slow_conv_dilated<>", [&] {
// For each elt in batch, do:
for (int elt = 0; elt < batchSize; elt++) {
6 changes: 6 additions & 0 deletions aten/src/ATen/nn.yaml
@@ -58,7 +58,13 @@
CPU:
forward_scalar_types: ['Float', 'Double', 'Long', 'BFloat16']
backward_scalar_types: ['Float', 'Double', 'BFloat16']
CUDA:
forward_scalar_types: ['Float', 'Double', 'Half', 'BFloat16']
backward_scalar_types: ['Float', 'Double', 'Half', 'BFloat16']

- name: _thnn_conv_depthwise2d(Tensor self, Tensor weight, IntArrayRef[2] kernel_size, Tensor? bias, IntArrayRef[2] stride, IntArrayRef[2] padding, IntArrayRef[2] dilation)
cname: SpatialDepthwiseConvolution
buffers: []
CUDA:
forward_scalar_types: ['Float', 'Double', 'Half', 'BFloat16']
backward_scalar_types: ['Float', 'Double', 'Half', 'BFloat16']
2 changes: 2 additions & 0 deletions aten/src/ATen/nn_parse.py
@@ -233,6 +233,8 @@ def function_info(name, arguments, cimpls, buffers, backends, inplace, backend_t
'name': name,
'cpu_bfloat16': True if backend_types is not None and 'CPU' in backend_types and
'BFloat16' in backend_types['CPU'] else False,
'cuda_bfloat16': True if backend_types is not None and 'CUDA' in backend_types and
'BFloat16' in backend_types['CUDA'] else False,
'backend_types': backend_types,
'arguments': arguments,
'return': 'argument 0' if inplace else get_return(arguments),
3 changes: 3 additions & 0 deletions aten/src/THCUNN/SpatialConvolutionMM.cu
@@ -8,3 +8,6 @@

#include <THCUNN/generic/SpatialConvolutionMM.cu>
#include <THC/THCGenerateFloatTypes.h>

#include <THCUNN/generic/SpatialConvolutionMM.cu>
#include <THC/THCGenerateBFloat16Type.h>
3 changes: 3 additions & 0 deletions aten/src/THCUNN/SpatialDepthwiseConvolution.cu
@@ -266,3 +266,6 @@ __global__ void spatialDepthwiseConvolutionAccGradParameters(

#include <THCUNN/generic/SpatialDepthwiseConvolution.cu>
#include <THC/THCGenerateFloatTypes.h>

#include <THCUNN/generic/SpatialDepthwiseConvolution.cu>
#include <THC/THCGenerateBFloat16Type.h>
3 changes: 3 additions & 0 deletions aten/src/THCUNN/THCUNN.h
@@ -8,3 +8,6 @@ typedef int64_t THCIndex_t;

#include <THCUNN/generic/THCUNN.h>
#include <THC/THCGenerateFloatTypes.h>

#include <THCUNN/generic/THCUNN.h>
#include <THC/THCGenerateBFloat16Type.h>