forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Intel GPU oneDNN upstreaming for primitive integration (pytorch#117112)
# Motivation As proposed in pytorch#114848 and pytorch#114723, oneDNN library is an important component for Intel GPU software ecosystem. Current PR is based on pytorch#117098, where oneDNN library for Intel GPU should be ready. This PR is the integration code from aten to oneDNN. GEMM integration code is the core part in this PR. Accompanied with GEMM, more basic support like runtime (device, stream), primitive attr is also included. We put the oneDNN integration code in directory `aten/src/ATen/native/mkldnn/xpu/detail`. We add a namespace `at::native::xpu::onednn` for oneDNN integration. The code in this PR would be used in following PRs, where aten operators would call the functions in these integration code.. We separate the prs due to onednn integration is logically separable with aten operator implementation. Also, this can ease the burden of reviewing by avoid too much codes in single PR. Co-authored-by: xiaolil1 <[email protected]> Co-authored-by: lei,zhenyuan <[email protected]> Pull Request resolved: pytorch#117112 Approved by: https://github.com/EikanWang, https://github.com/jgong5, https://github.com/albanD
- Loading branch information
1 parent
944d046
commit cc18afa
Showing
7 changed files
with
1,139 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,365 @@ | ||
#pragma once | ||
|
||
#include <ATen/ATen.h> | ||
#include <oneapi/dnnl/dnnl.hpp> | ||
#include <oneapi/dnnl/dnnl_types.h> | ||
#include <ATen/native/mkldnn/xpu/detail/Utils.h> | ||
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h> | ||
|
||
namespace at::native::onednn { | ||
/* oneDNN quantization usage: | ||
https://oneapi-src.github.io/oneDNN/dev_guide_attributes_quantization.html# | ||
src_fp32 = scale_src * (src_int8 - zero_point) | ||
wei_fp32 = scale_wei * (wei_int8 - zero_point) | ||
dst_fp32 = scale_dst * (dst_int8 - zero_point) | ||
fp32 Convolution: dst_fp32 = src_fp32 * wei_fp32 | ||
Int8 Convolution: dst_fp32 = (src_int8 * wei_int8) * (scale_src * scale_wei) | ||
Int8 Convolution: dst_int8 = 1 / scale_dst * dst_fp32; | ||
Considering zero-point (asymmetric): | ||
dst_fp32 = (src_int8 - src_zp) * src_sc * wei_int8 * wei_sc | ||
dst_sc * (dst_int8 - dst_zp) = (src_int8 - src_zp) * wei_int8 * src_sc * | ||
wei_sc | ||
dst_int8 = (src_int8 - src_zp) * wei_int8 * src_sc * wei_sc / dst_sc + | ||
dst_zp | ||
considering bias: | ||
fp32 Convolution: dst_fp32 = src_fp32 * wei_fp32 + bias | ||
Int8 Convolution: dst_fp32 = (src_int8 * wei_int8) * (scale_src * scale_wei) | ||
+ bias Int8 Convolution: dst_fp32 = (src_int8 * wei_int8 + bias/(scale_src * | ||
scale_wei)) * (scale_src * scale_wei) Int8 Convolution: dst_int8 = 1 / | ||
scale_dst * dst_fp32; | ||
*/ | ||
|
||
/* | ||
oneDNN postops usage: | ||
Currently, oneDNN supports 5 kinds of post ops. More details can be refered | ||
to oneDNN doc. | ||
https://oneapi-src.github.io/oneDNN/dev_guide_attributes_post_ops.html#doxid-dev-guide-attributes-post-ops-1dev-guide-attributes-post-ops-eltwise | ||
0. without post ops | ||
dst = Conv(src, wei) + bias; | ||
dst_int8 = 1/q_scale * dst; q_scale is the op output quantization scale | ||
fp32 API: Attr attr; | ||
int8 API: Attr attr(q_scale); | ||
1. append eltwise post op | ||
dst = elt_scale * Eltwise{conv_scale * [Conv(src, wei) + bias], alpha, beta} | ||
dst_int8 = 1/q_scale * dst; | ||
fp32 API: | ||
Attr attr; | ||
attr.append_post_eltwise(1.f, conv_scale, 0.f, kind_with_linear) | ||
attr.append_post_eltwise(elt_scale, alpha, beta, eltwise_algorithm) | ||
int8 API: | ||
Attr attr(q_scale); | ||
attr.append_post_eltwise(1.f, conv_scale, 0.f, kind_with_linear) | ||
attr.append_post_eltwise(elt_scale, alpha, beta, eltwise_algorithm) | ||
2. append sum post op | ||
dst = conv_scale * Conv(src, wei) + sum_scale * (dst - zp) | ||
dst_int8 = 1/q_scale * dst; | ||
fp32 API: | ||
Attr attr; | ||
attr.append_post_eltwise(1.f, conv_scale, 0.f, kind_with_linear) | ||
attr.append_post_sum(sum_scale) | ||
int8 API: | ||
Attr attr(q_scale); | ||
attr.append_post_eltwise(1.f, conv_scale, 0.f, kind_with_linear) | ||
attr.append_post_sum(sum_scale) | ||
3. append binary post op | ||
dst = Binary[Conv(src, wei)] | ||
*/ | ||
using kind_t = dnnl::primitive::kind; | ||
struct PostOpParam { | ||
// eltwise post op constructor | ||
PostOpParam(float scale, float alpha, float beta, dnnl::algorithm algo, kind_t kind) | ||
: scale_(scale), alpha_(alpha), beta_(beta), algo_(algo), kind_(kind) {} | ||
// sum post op constructor | ||
PostOpParam(float scale, kind_t kind) : scale_(scale), kind_(kind) {} | ||
// binary post op constructor | ||
PostOpParam( | ||
at::Tensor& binary, | ||
dnnl::memory::desc& binary_md, | ||
dnnl::memory::desc& expected_md, | ||
dnnl::algorithm algo, | ||
kind_t kind) | ||
: binary_(binary), | ||
meta_(binary_md), | ||
expected_meta_(expected_md), | ||
algo_(algo), | ||
kind_(kind) {} | ||
// prelu post op constructor | ||
PostOpParam(int mask, kind_t kind) : mask_(mask), kind_(kind) {} | ||
|
||
// post sum or binary with scale post op constructor | ||
PostOpParam(at::Tensor& binary, float scale, dnnl::algorithm algo, kind_t kind) | ||
: scale_(scale), binary_(binary), algo_(algo), kind_(kind) {} | ||
|
||
// for int8 sum/eltwise | ||
float scale_ = 1.0; | ||
// for eltwise | ||
float alpha_ = 0.0; | ||
float beta_ = 0.0; | ||
// for binary | ||
at::Tensor binary_ = at::Tensor(); | ||
at::Tensor expected_binary_ = at::Tensor(); | ||
void* binary_ptr_ = nullptr; | ||
dnnl::memory::desc meta_ = dnnl::memory::desc(); | ||
dnnl::memory::desc expected_meta_ = dnnl::memory::desc(); | ||
// for prelu | ||
int mask_ = 0; | ||
// common | ||
dnnl::algorithm algo_ = dnnl::algorithm::eltwise_relu; | ||
kind_t kind_ = kind_t::eltwise; | ||
}; | ||
|
||
class Attr { | ||
public: | ||
Attr() : q_scale_(1.f), q_zero_point_(0) {} | ||
Attr(float q_scale, int64_t zp = 0) : q_scale_(q_scale), q_zero_point_(zp) {} | ||
|
||
/***** eltwise *****/ | ||
dnnl::algorithm kind_with_relu = dnnl::algorithm::eltwise_relu; | ||
dnnl::algorithm kind_with_sigmoid = dnnl::algorithm::eltwise_logistic; | ||
dnnl::algorithm kind_with_gelu_tanh = dnnl::algorithm::eltwise_gelu_tanh; | ||
dnnl::algorithm kind_with_gelu_erf = dnnl::algorithm::eltwise_gelu_erf; | ||
dnnl::algorithm kind_with_mish = dnnl::algorithm::eltwise_mish; | ||
dnnl::algorithm kind_with_linear = dnnl::algorithm::eltwise_linear; | ||
dnnl::algorithm kind_with_swish = dnnl::algorithm::eltwise_swish; | ||
dnnl::algorithm kind_with_sqrt = dnnl::algorithm::eltwise_sqrt; | ||
dnnl::algorithm kind_with_tanh = dnnl::algorithm::eltwise_tanh; | ||
dnnl::algorithm kind_with_square = dnnl::algorithm::eltwise_square; | ||
dnnl::algorithm kind_with_abs = dnnl::algorithm::eltwise_abs; | ||
dnnl::algorithm kind_with_exp = dnnl::algorithm::eltwise_exp; | ||
dnnl::algorithm kind_with_log = dnnl::algorithm::eltwise_log; | ||
dnnl::algorithm kind_with_round = dnnl::algorithm::eltwise_round; | ||
dnnl::algorithm kind_with_hardswish = dnnl::algorithm::eltwise_hardswish; | ||
dnnl::algorithm kind_with_soft_relu = dnnl::algorithm::eltwise_soft_relu; | ||
dnnl::algorithm kind_with_elu = dnnl::algorithm::eltwise_elu; | ||
dnnl::algorithm kind_with_pow = dnnl::algorithm::eltwise_pow; | ||
dnnl::algorithm kind_with_clip = dnnl::algorithm::eltwise_clip; | ||
// note: hardsigmoid seems oneDNN still not support | ||
dnnl::algorithm kind_with_hardsigmoid = dnnl::algorithm::eltwise_hardsigmoid; | ||
|
||
/***** binary *****/ | ||
dnnl::algorithm kind_with_binary_mul = dnnl::algorithm::binary_mul; | ||
dnnl::algorithm kind_with_binary_add = dnnl::algorithm::binary_add; | ||
dnnl::algorithm kind_with_binary_sub = dnnl::algorithm::binary_sub; | ||
dnnl::algorithm kind_with_binary_div = dnnl::algorithm::binary_div; | ||
dnnl::algorithm kind_with_binary_eq = dnnl::algorithm::binary_eq; | ||
dnnl::algorithm kind_with_binary_ne = dnnl::algorithm::binary_ne; | ||
dnnl::algorithm kind_with_binary_ge = dnnl::algorithm::binary_ge; | ||
dnnl::algorithm kind_with_binary_gt = dnnl::algorithm::binary_gt; | ||
dnnl::algorithm kind_with_binary_le = dnnl::algorithm::binary_le; | ||
dnnl::algorithm kind_with_binary_lt = dnnl::algorithm::binary_lt; | ||
dnnl::algorithm kind_with_binary_max = dnnl::algorithm::binary_max; | ||
dnnl::algorithm kind_with_binary_min = dnnl::algorithm::binary_min; | ||
|
||
// append sum post op | ||
Attr& append_post_sum( | ||
float sum_scale, | ||
float sum_q_scale = 1.f, | ||
int64_t zp = 0) { | ||
ops_params_.push_back( | ||
PostOpParam(/*scale_sum*/ sum_scale * sum_q_scale, kind_t::sum)); | ||
return *this; | ||
} | ||
|
||
// append eltwise post op | ||
Attr& append_post_eltwise( | ||
float scale, | ||
float alpha, | ||
float beta, | ||
dnnl::algorithm algo) { | ||
ops_params_.push_back( | ||
PostOpParam(scale, alpha, beta, algo, kind_t::eltwise)); | ||
return *this; | ||
} | ||
|
||
// append binary post op | ||
Attr& append_post_binary(dnnl::algorithm algo, const at::Tensor& binary) { | ||
auto binary_ = binary.is_quantized() ? at::dequantize(binary) : binary; | ||
bool binary_is_channels_last = (binary_.suggest_memory_format() == at::MemoryFormat::ChannelsLast || | ||
binary_.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d); | ||
|
||
binary_ = binary_is_channels_last ? binary_ : binary_.contiguous(); | ||
dnnl::memory::desc md = get_onednn_md(binary_); | ||
auto expected_md = dnnl::memory::desc( | ||
md.get_dims(), md.get_data_type(), dnnl::memory::format_tag::any); | ||
ops_params_.push_back( | ||
PostOpParam(binary_, md, expected_md, algo, kind_t::binary)); | ||
return *this; | ||
} | ||
|
||
Attr& append_scale_binary( | ||
dnnl::algorithm algo, | ||
at::Tensor binary, | ||
float scale, | ||
float sum_q_scale = 1.f, | ||
int64_t zp = 0) { | ||
ops_params_.push_back(PostOpParam( | ||
binary, /*scale_sum*/ scale * sum_q_scale, algo, kind_t::binary)); | ||
return *this; | ||
} | ||
|
||
// append bias with binary_add method (only used for QConv now) | ||
template <int N> | ||
Attr& append_bias(const at::Tensor& binary) { | ||
// In PyTorch, bias are in shape of [OC], | ||
// we expand its shape according to Conv dimension | ||
// Conv1d [OC, 1, 1], Conv2d [1, OC, 1, ,1], Conv3d [1, OC, 1, 1, 1] | ||
at::Tensor binary_ = binary.contiguous(); | ||
dnnl::memory::desc binary_md; | ||
switch (N) { | ||
case 1: | ||
binary_md = dnnl::memory::desc( | ||
{binary.size(0), 1, 1}, | ||
dnnl::memory::data_type::f32, | ||
dnnl::memory::format_tag::abc); | ||
break; | ||
case 2: | ||
binary_md = dnnl::memory::desc( | ||
{1, binary.size(0), 1, 1}, | ||
dnnl::memory::data_type::f32, | ||
dnnl::memory::format_tag::abcd); | ||
break; | ||
case 3: | ||
binary_md = dnnl::memory::desc( | ||
{1, binary.size(0), 1, 1, 1}, | ||
dnnl::memory::data_type::f32, | ||
dnnl::memory::format_tag::abcde); | ||
break; | ||
default: | ||
TORCH_INTERNAL_ASSERT(0, | ||
"XPU only supports append_bias for Conv1d, Conv2d and Conv3d."); | ||
} | ||
// In this case, expected_md = binary_md | ||
ops_params_.push_back(PostOpParam( | ||
binary_, binary_md, binary_md, kind_with_binary_add, kind_t::binary)); | ||
return *this; | ||
} | ||
|
||
// append prelu post op | ||
Attr& append_post_prelu(int mask) { | ||
ops_params_.push_back(PostOpParam(mask, kind_t::prelu)); | ||
return *this; | ||
} | ||
|
||
dnnl::post_ops extract_post_ops(const at::Tensor& dst){ | ||
// this function is used to extract post ops params from the ops_params_ | ||
// and put them into onednn post ops | ||
for (size_t i = 0; i < ops_params_.size(); ++i) { | ||
kind_t kind = ops_params_[i].kind_; | ||
switch (kind) { | ||
case kind_t::eltwise: { | ||
dnnl::algorithm algo = ops_params_[i].algo_; | ||
float alpha = ops_params_[i].alpha_; | ||
float beta = ops_params_[i].beta_; | ||
dnnl_post_ops_.append_eltwise(algo, alpha, beta); | ||
break; | ||
} | ||
case kind_t::sum: { | ||
float scale = ops_params_[i].scale_; | ||
// TODO [Asymmetric]: | ||
// Post-sum zp for gpu is not supported currently | ||
dnnl_post_ops_.append_sum(scale); | ||
break; | ||
} | ||
case kind_t::binary: { | ||
dnnl::algorithm algo = ops_params_[i].algo_; | ||
auto expected_md = ops_params_[i].expected_meta_; | ||
// In this case user may create src1 memory descriptor with | ||
// format_tag::any or set a specific tag. However, in later case if | ||
// tags mismatch with dst, it would result in suboptimal performance. | ||
// So here we use format_tag::any to make sure the fast can be | ||
// selected. | ||
// Thus we use expected_md (with format_any) here to create pd instead | ||
// of original md | ||
dnnl_post_ops_.append_binary(algo, expected_md); | ||
break; | ||
} | ||
default: | ||
break; | ||
} | ||
} | ||
|
||
// if output is quantized, then append the eltwise linear to adjust the | ||
// output scale/zero_point | ||
if (dst.is_quantized()) { | ||
// [Note: Gap of u8 qtensor scale between oneDNN and PyTorch] | ||
// The /2 here is for output_scale collected by observer is different | ||
// from quantization requirements in oneDNN. | ||
// For Observer, the conv_scale (activation scale in other case) is | ||
// computed through 2max_v/(qmax - qmin). The max_v is collected | ||
// from the tensor to be observerd. | ||
// (https://pytorch.org/docs/stable/generated/torch.quantization.observer.MinMaxObserver.html#torch.quantization.observer.MinMaxObserver) | ||
// On the other hand, for u8 in oneDNN, the scale for quantization is | ||
// defined as max_v/(qmax-qmin). Hence, we need to divide by 2 here. | ||
// (https://oneapi-src.github.io/oneDNN/dev_guide_inference_int8.html) | ||
dnnl_post_ops_.append_eltwise( | ||
kind_with_linear, 1.f / q_scale_, q_zero_point_); | ||
} | ||
return dnnl_post_ops_; | ||
} | ||
|
||
bool with_sum() { | ||
for (size_t i = 0; i < ops_params_.size(); ++i) { | ||
if (ops_params_[i].kind_ == kind_t::sum) { | ||
return true; | ||
} | ||
} | ||
return false; | ||
} | ||
|
||
bool with_binary() { | ||
for (size_t i = 0; i < ops_params_.size(); ++i) { | ||
if (ops_params_[i].kind_ == kind_t::binary) { | ||
return true; | ||
} | ||
} | ||
return false; | ||
} | ||
|
||
void construct_post_binary( | ||
dnnl::primitive_desc& pd, | ||
std::unordered_map<int, dnnl::memory>& args) { | ||
// This function is used to construct binary memory desc in binary post ops. | ||
// According to oneDNN doc, the binary tensor can be in shape of | ||
// [1, 1, 1, 1], tensor broadcast | ||
// [1, C, 1, 1], channel broadcast | ||
// [dst.shape], no broadcast and eltwise-wise binary operations on dst | ||
|
||
auto engine = | ||
GpuEngineManager::Instance().get_engine({c10::kXPU, c10::xpu::current_device()}); | ||
for (size_t i = 0; i < ops_params_.size(); ++i) { | ||
kind_t kind = ops_params_[i].kind_; | ||
if (kind == kind_t::binary) { | ||
dnnl::memory binary_m; | ||
auto binary = ops_params_[i].binary_; | ||
auto md = ops_params_[i].meta_; | ||
// qeury expected_md to achieve peak performance | ||
auto expected_md = pd.query_md( | ||
dnnl::query::exec_arg_md, | ||
DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1); | ||
|
||
binary_m = at::native::onednn::make_onednn_memory( | ||
md, engine, binary.data_ptr() | ||
); | ||
|
||
args.insert( | ||
{DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1, binary_m}); | ||
} | ||
} | ||
} | ||
|
||
float q_scale_ = 1.0; // the scale used to quantize the fused result from fp32 | ||
// to int8, only works for int8 case | ||
int64_t q_zero_point_ = 0; | ||
std::vector<PostOpParam> ops_params_; // series of post ops | ||
dnnl::post_ops dnnl_post_ops_; | ||
}; | ||
|
||
} // namespace at::native::onednn |
Oops, something went wrong.