From cc18afa25f81aeeb6e817254e06cf537a20c656c Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Tue, 16 Apr 2024 06:58:14 +0000 Subject: [PATCH] Intel GPU oneDNN upstreaming for primitive integration (#117112) # Motivation As proposed in https://github.com/pytorch/pytorch/issues/114848 and https://github.com/pytorch/pytorch/issues/114723, oneDNN library is an important component for Intel GPU software ecosystem. Current PR is based on #117098, where oneDNN library for Intel GPU should be ready. This PR is the integration code from aten to oneDNN. GEMM integration code is the core part in this PR. Accompanied with GEMM, more basic support like runtime (device, stream), primitive attr is also included. We put the oneDNN integration code in directory `aten/src/ATen/native/mkldnn/xpu/detail`. We add a namespace `at::native::xpu::onednn` for oneDNN integration. The code in this PR would be used in following PRs, where aten operators would call the functions in these integration code.. We separate the prs due to onednn integration is logically separable with aten operator implementation. Also, this can ease the burden of reviewing by avoid too much codes in single PR. Co-authored-by: xiaolil1 Co-authored-by: lei,zhenyuan Pull Request resolved: https://github.com/pytorch/pytorch/pull/117112 Approved by: https://github.com/EikanWang, https://github.com/jgong5, https://github.com/albanD --- aten/src/ATen/native/mkldnn/xpu/detail/Attr.h | 365 ++++++++++++++++++ .../ATen/native/mkldnn/xpu/detail/Matmul.cpp | 244 ++++++++++++ .../ATen/native/mkldnn/xpu/detail/Utils.cpp | 352 +++++++++++++++++ .../src/ATen/native/mkldnn/xpu/detail/Utils.h | 56 +++ .../ATen/native/mkldnn/xpu/detail/oneDNN.h | 20 + .../mkldnn/xpu/detail/oneDNNContext.cpp | 27 ++ .../native/mkldnn/xpu/detail/oneDNNContext.h | 75 ++++ 7 files changed, 1139 insertions(+) create mode 100644 aten/src/ATen/native/mkldnn/xpu/detail/Attr.h create mode 100644 aten/src/ATen/native/mkldnn/xpu/detail/Matmul.cpp create mode 100644 aten/src/ATen/native/mkldnn/xpu/detail/Utils.cpp create mode 100644 aten/src/ATen/native/mkldnn/xpu/detail/Utils.h create mode 100644 aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h create mode 100644 aten/src/ATen/native/mkldnn/xpu/detail/oneDNNContext.cpp create mode 100644 aten/src/ATen/native/mkldnn/xpu/detail/oneDNNContext.h diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/Attr.h b/aten/src/ATen/native/mkldnn/xpu/detail/Attr.h new file mode 100644 index 00000000000000..56e587084959dc --- /dev/null +++ b/aten/src/ATen/native/mkldnn/xpu/detail/Attr.h @@ -0,0 +1,365 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace at::native::onednn { +/* oneDNN quantization usage: + https://oneapi-src.github.io/oneDNN/dev_guide_attributes_quantization.html# + + src_fp32 = scale_src * (src_int8 - zero_point) + wei_fp32 = scale_wei * (wei_int8 - zero_point) + dst_fp32 = scale_dst * (dst_int8 - zero_point) + fp32 Convolution: dst_fp32 = src_fp32 * wei_fp32 + Int8 Convolution: dst_fp32 = (src_int8 * wei_int8) * (scale_src * scale_wei) + Int8 Convolution: dst_int8 = 1 / scale_dst * dst_fp32; + + Considering zero-point (asymmetric): + dst_fp32 = (src_int8 - src_zp) * src_sc * wei_int8 * wei_sc + dst_sc * (dst_int8 - dst_zp) = (src_int8 - src_zp) * wei_int8 * src_sc * + wei_sc + dst_int8 = (src_int8 - src_zp) * wei_int8 * src_sc * wei_sc / dst_sc + + dst_zp + + considering bias: + fp32 Convolution: dst_fp32 = src_fp32 * wei_fp32 + bias + Int8 Convolution: dst_fp32 = (src_int8 * wei_int8) * (scale_src * scale_wei) + + bias Int8 Convolution: dst_fp32 = (src_int8 * wei_int8 + bias/(scale_src * + scale_wei)) * (scale_src * scale_wei) Int8 Convolution: dst_int8 = 1 / + scale_dst * dst_fp32; +*/ + +/* + oneDNN postops usage: + Currently, oneDNN supports 5 kinds of post ops. More details can be refered +to oneDNN doc. + https://oneapi-src.github.io/oneDNN/dev_guide_attributes_post_ops.html#doxid-dev-guide-attributes-post-ops-1dev-guide-attributes-post-ops-eltwise + +0. without post ops + dst = Conv(src, wei) + bias; + dst_int8 = 1/q_scale * dst; q_scale is the op output quantization scale + fp32 API: Attr attr; + int8 API: Attr attr(q_scale); + +1. append eltwise post op + dst = elt_scale * Eltwise{conv_scale * [Conv(src, wei) + bias], alpha, beta} + dst_int8 = 1/q_scale * dst; + fp32 API: + Attr attr; + attr.append_post_eltwise(1.f, conv_scale, 0.f, kind_with_linear) + attr.append_post_eltwise(elt_scale, alpha, beta, eltwise_algorithm) + int8 API: + Attr attr(q_scale); + attr.append_post_eltwise(1.f, conv_scale, 0.f, kind_with_linear) + attr.append_post_eltwise(elt_scale, alpha, beta, eltwise_algorithm) + +2. append sum post op + dst = conv_scale * Conv(src, wei) + sum_scale * (dst - zp) + dst_int8 = 1/q_scale * dst; + fp32 API: + Attr attr; + attr.append_post_eltwise(1.f, conv_scale, 0.f, kind_with_linear) + attr.append_post_sum(sum_scale) + int8 API: + Attr attr(q_scale); + attr.append_post_eltwise(1.f, conv_scale, 0.f, kind_with_linear) + attr.append_post_sum(sum_scale) + +3. append binary post op + dst = Binary[Conv(src, wei)] + +*/ +using kind_t = dnnl::primitive::kind; +struct PostOpParam { + // eltwise post op constructor + PostOpParam(float scale, float alpha, float beta, dnnl::algorithm algo, kind_t kind) + : scale_(scale), alpha_(alpha), beta_(beta), algo_(algo), kind_(kind) {} + // sum post op constructor + PostOpParam(float scale, kind_t kind) : scale_(scale), kind_(kind) {} + // binary post op constructor + PostOpParam( + at::Tensor& binary, + dnnl::memory::desc& binary_md, + dnnl::memory::desc& expected_md, + dnnl::algorithm algo, + kind_t kind) + : binary_(binary), + meta_(binary_md), + expected_meta_(expected_md), + algo_(algo), + kind_(kind) {} + // prelu post op constructor + PostOpParam(int mask, kind_t kind) : mask_(mask), kind_(kind) {} + + // post sum or binary with scale post op constructor + PostOpParam(at::Tensor& binary, float scale, dnnl::algorithm algo, kind_t kind) + : scale_(scale), binary_(binary), algo_(algo), kind_(kind) {} + + // for int8 sum/eltwise + float scale_ = 1.0; + // for eltwise + float alpha_ = 0.0; + float beta_ = 0.0; + // for binary + at::Tensor binary_ = at::Tensor(); + at::Tensor expected_binary_ = at::Tensor(); + void* binary_ptr_ = nullptr; + dnnl::memory::desc meta_ = dnnl::memory::desc(); + dnnl::memory::desc expected_meta_ = dnnl::memory::desc(); + // for prelu + int mask_ = 0; + // common + dnnl::algorithm algo_ = dnnl::algorithm::eltwise_relu; + kind_t kind_ = kind_t::eltwise; +}; + +class Attr { + public: + Attr() : q_scale_(1.f), q_zero_point_(0) {} + Attr(float q_scale, int64_t zp = 0) : q_scale_(q_scale), q_zero_point_(zp) {} + + /***** eltwise *****/ + dnnl::algorithm kind_with_relu = dnnl::algorithm::eltwise_relu; + dnnl::algorithm kind_with_sigmoid = dnnl::algorithm::eltwise_logistic; + dnnl::algorithm kind_with_gelu_tanh = dnnl::algorithm::eltwise_gelu_tanh; + dnnl::algorithm kind_with_gelu_erf = dnnl::algorithm::eltwise_gelu_erf; + dnnl::algorithm kind_with_mish = dnnl::algorithm::eltwise_mish; + dnnl::algorithm kind_with_linear = dnnl::algorithm::eltwise_linear; + dnnl::algorithm kind_with_swish = dnnl::algorithm::eltwise_swish; + dnnl::algorithm kind_with_sqrt = dnnl::algorithm::eltwise_sqrt; + dnnl::algorithm kind_with_tanh = dnnl::algorithm::eltwise_tanh; + dnnl::algorithm kind_with_square = dnnl::algorithm::eltwise_square; + dnnl::algorithm kind_with_abs = dnnl::algorithm::eltwise_abs; + dnnl::algorithm kind_with_exp = dnnl::algorithm::eltwise_exp; + dnnl::algorithm kind_with_log = dnnl::algorithm::eltwise_log; + dnnl::algorithm kind_with_round = dnnl::algorithm::eltwise_round; + dnnl::algorithm kind_with_hardswish = dnnl::algorithm::eltwise_hardswish; + dnnl::algorithm kind_with_soft_relu = dnnl::algorithm::eltwise_soft_relu; + dnnl::algorithm kind_with_elu = dnnl::algorithm::eltwise_elu; + dnnl::algorithm kind_with_pow = dnnl::algorithm::eltwise_pow; + dnnl::algorithm kind_with_clip = dnnl::algorithm::eltwise_clip; + // note: hardsigmoid seems oneDNN still not support + dnnl::algorithm kind_with_hardsigmoid = dnnl::algorithm::eltwise_hardsigmoid; + + /***** binary *****/ + dnnl::algorithm kind_with_binary_mul = dnnl::algorithm::binary_mul; + dnnl::algorithm kind_with_binary_add = dnnl::algorithm::binary_add; + dnnl::algorithm kind_with_binary_sub = dnnl::algorithm::binary_sub; + dnnl::algorithm kind_with_binary_div = dnnl::algorithm::binary_div; + dnnl::algorithm kind_with_binary_eq = dnnl::algorithm::binary_eq; + dnnl::algorithm kind_with_binary_ne = dnnl::algorithm::binary_ne; + dnnl::algorithm kind_with_binary_ge = dnnl::algorithm::binary_ge; + dnnl::algorithm kind_with_binary_gt = dnnl::algorithm::binary_gt; + dnnl::algorithm kind_with_binary_le = dnnl::algorithm::binary_le; + dnnl::algorithm kind_with_binary_lt = dnnl::algorithm::binary_lt; + dnnl::algorithm kind_with_binary_max = dnnl::algorithm::binary_max; + dnnl::algorithm kind_with_binary_min = dnnl::algorithm::binary_min; + + // append sum post op + Attr& append_post_sum( + float sum_scale, + float sum_q_scale = 1.f, + int64_t zp = 0) { + ops_params_.push_back( + PostOpParam(/*scale_sum*/ sum_scale * sum_q_scale, kind_t::sum)); + return *this; + } + + // append eltwise post op + Attr& append_post_eltwise( + float scale, + float alpha, + float beta, + dnnl::algorithm algo) { + ops_params_.push_back( + PostOpParam(scale, alpha, beta, algo, kind_t::eltwise)); + return *this; + } + + // append binary post op + Attr& append_post_binary(dnnl::algorithm algo, const at::Tensor& binary) { + auto binary_ = binary.is_quantized() ? at::dequantize(binary) : binary; + bool binary_is_channels_last = (binary_.suggest_memory_format() == at::MemoryFormat::ChannelsLast || + binary_.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d); + + binary_ = binary_is_channels_last ? binary_ : binary_.contiguous(); + dnnl::memory::desc md = get_onednn_md(binary_); + auto expected_md = dnnl::memory::desc( + md.get_dims(), md.get_data_type(), dnnl::memory::format_tag::any); + ops_params_.push_back( + PostOpParam(binary_, md, expected_md, algo, kind_t::binary)); + return *this; + } + + Attr& append_scale_binary( + dnnl::algorithm algo, + at::Tensor binary, + float scale, + float sum_q_scale = 1.f, + int64_t zp = 0) { + ops_params_.push_back(PostOpParam( + binary, /*scale_sum*/ scale * sum_q_scale, algo, kind_t::binary)); + return *this; + } + + // append bias with binary_add method (only used for QConv now) + template + Attr& append_bias(const at::Tensor& binary) { + // In PyTorch, bias are in shape of [OC], + // we expand its shape according to Conv dimension + // Conv1d [OC, 1, 1], Conv2d [1, OC, 1, ,1], Conv3d [1, OC, 1, 1, 1] + at::Tensor binary_ = binary.contiguous(); + dnnl::memory::desc binary_md; + switch (N) { + case 1: + binary_md = dnnl::memory::desc( + {binary.size(0), 1, 1}, + dnnl::memory::data_type::f32, + dnnl::memory::format_tag::abc); + break; + case 2: + binary_md = dnnl::memory::desc( + {1, binary.size(0), 1, 1}, + dnnl::memory::data_type::f32, + dnnl::memory::format_tag::abcd); + break; + case 3: + binary_md = dnnl::memory::desc( + {1, binary.size(0), 1, 1, 1}, + dnnl::memory::data_type::f32, + dnnl::memory::format_tag::abcde); + break; + default: + TORCH_INTERNAL_ASSERT(0, + "XPU only supports append_bias for Conv1d, Conv2d and Conv3d."); + } + // In this case, expected_md = binary_md + ops_params_.push_back(PostOpParam( + binary_, binary_md, binary_md, kind_with_binary_add, kind_t::binary)); + return *this; + } + + // append prelu post op + Attr& append_post_prelu(int mask) { + ops_params_.push_back(PostOpParam(mask, kind_t::prelu)); + return *this; + } + + dnnl::post_ops extract_post_ops(const at::Tensor& dst){ + // this function is used to extract post ops params from the ops_params_ + // and put them into onednn post ops + for (size_t i = 0; i < ops_params_.size(); ++i) { + kind_t kind = ops_params_[i].kind_; + switch (kind) { + case kind_t::eltwise: { + dnnl::algorithm algo = ops_params_[i].algo_; + float alpha = ops_params_[i].alpha_; + float beta = ops_params_[i].beta_; + dnnl_post_ops_.append_eltwise(algo, alpha, beta); + break; + } + case kind_t::sum: { + float scale = ops_params_[i].scale_; + // TODO [Asymmetric]: + // Post-sum zp for gpu is not supported currently + dnnl_post_ops_.append_sum(scale); + break; + } + case kind_t::binary: { + dnnl::algorithm algo = ops_params_[i].algo_; + auto expected_md = ops_params_[i].expected_meta_; + // In this case user may create src1 memory descriptor with + // format_tag::any or set a specific tag. However, in later case if + // tags mismatch with dst, it would result in suboptimal performance. + // So here we use format_tag::any to make sure the fast can be + // selected. + // Thus we use expected_md (with format_any) here to create pd instead + // of original md + dnnl_post_ops_.append_binary(algo, expected_md); + break; + } + default: + break; + } + } + + // if output is quantized, then append the eltwise linear to adjust the + // output scale/zero_point + if (dst.is_quantized()) { + // [Note: Gap of u8 qtensor scale between oneDNN and PyTorch] + // The /2 here is for output_scale collected by observer is different + // from quantization requirements in oneDNN. + // For Observer, the conv_scale (activation scale in other case) is + // computed through 2max_v/(qmax - qmin). The max_v is collected + // from the tensor to be observerd. + // (https://pytorch.org/docs/stable/generated/torch.quantization.observer.MinMaxObserver.html#torch.quantization.observer.MinMaxObserver) + // On the other hand, for u8 in oneDNN, the scale for quantization is + // defined as max_v/(qmax-qmin). Hence, we need to divide by 2 here. + // (https://oneapi-src.github.io/oneDNN/dev_guide_inference_int8.html) + dnnl_post_ops_.append_eltwise( + kind_with_linear, 1.f / q_scale_, q_zero_point_); + } + return dnnl_post_ops_; + } + + bool with_sum() { + for (size_t i = 0; i < ops_params_.size(); ++i) { + if (ops_params_[i].kind_ == kind_t::sum) { + return true; + } + } + return false; + } + + bool with_binary() { + for (size_t i = 0; i < ops_params_.size(); ++i) { + if (ops_params_[i].kind_ == kind_t::binary) { + return true; + } + } + return false; + } + + void construct_post_binary( + dnnl::primitive_desc& pd, + std::unordered_map& args) { + // This function is used to construct binary memory desc in binary post ops. + // According to oneDNN doc, the binary tensor can be in shape of + // [1, 1, 1, 1], tensor broadcast + // [1, C, 1, 1], channel broadcast + // [dst.shape], no broadcast and eltwise-wise binary operations on dst + + auto engine = + GpuEngineManager::Instance().get_engine({c10::kXPU, c10::xpu::current_device()}); + for (size_t i = 0; i < ops_params_.size(); ++i) { + kind_t kind = ops_params_[i].kind_; + if (kind == kind_t::binary) { + dnnl::memory binary_m; + auto binary = ops_params_[i].binary_; + auto md = ops_params_[i].meta_; + // qeury expected_md to achieve peak performance + auto expected_md = pd.query_md( + dnnl::query::exec_arg_md, + DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1); + + binary_m = at::native::onednn::make_onednn_memory( + md, engine, binary.data_ptr() + ); + + args.insert( + {DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1, binary_m}); + } + } + } + + float q_scale_ = 1.0; // the scale used to quantize the fused result from fp32 + // to int8, only works for int8 case + int64_t q_zero_point_ = 0; + std::vector ops_params_; // series of post ops + dnnl::post_ops dnnl_post_ops_; +}; + +} // namespace at::native::onednn diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/Matmul.cpp b/aten/src/ATen/native/mkldnn/xpu/detail/Matmul.cpp new file mode 100644 index 00000000000000..7dfd31b93ba8dd --- /dev/null +++ b/aten/src/ATen/native/mkldnn/xpu/detail/Matmul.cpp @@ -0,0 +1,244 @@ + +#include + +#include +#include + +#include +#include + +#include + +namespace at::native::onednn { + +sycl::event matmul( + at::Tensor& result, + const at::Tensor& mat1, + const at::Tensor& mat2, + const at::Tensor& b_raw, + bool m2_trans, + Attr attr, + const std::vector& deps) { + int64_t dims = result.dim(); + TORCH_CHECK( + dims == 2 || dims == 3, + "oneDNN matmul only works with 2D or 3D, got ", + dims); + TORCH_CHECK( + dims == mat1.dim() && dims == mat2.dim(), + "oneDNN input matrixes must have the same ranks"); + TORCH_CHECK(result.defined(), "oneDNN matmul result should be defined"); + + at::Device cur_device = at::Device(at::kXPU, c10::xpu::current_device()); + auto engine = GpuEngineManager::Instance().get_engine(cur_device); + auto stream = GpuStreamManager::Instance().get_stream(); + + at::Tensor m1 = is_onednn_matmul_strides(mat1) ? mat1 : mat1.contiguous(); + at::Tensor m2 = is_onednn_matmul_strides(mat2) ? mat2 : mat2.contiguous(); + at::Tensor dst = is_onednn_matmul_strides(result, true) ? result : result.contiguous(); + + int64_t m = dst.size(-2); + int64_t n = dst.size(-1); + int64_t k = m1.size(-1); + int64_t mb = 1; + + if (dims == 3) { + mb = dst.size(0); + TORCH_CHECK( + mb == m1.size(0) && mb == m2.size(0), + "batch size mismatch, dst mb: ", + mb, + "m1 mb", + m1.size(0), + " m2 mb: ", + m2.size(0)); + } + + // validate bias and make it compatible with oneDNN implementation + bool with_bias = false; + at::Tensor b = b_raw; + if (b.defined()) { + with_bias = true; + if (b.dim() == 1) { + TORCH_CHECK( + b.size(0) == n || b.size(0) == 1, + "matmul supports [n] or [1] when bias dim is 1 ..."); + if (b.size(0) == 0) { + with_bias = false; + } else if (m1.dim() == 3) { + b = b.expand({mb, m, n}).contiguous(); + } else if (m1.dim() == 2) { + b = b.expand({1, n}).contiguous(); + } + } else if (b.dim() == 2) { + TORCH_CHECK( + (b.size(0) == m && b.size(1) == n) || + (b.size(0) == 1 && b.size(1) == n) || + (b.size(0) == m && b.size(1) == 1) || + (b.size(0) == 1 && b.size(1) == 1), + "matmul supports [m, n] or [1, n] or [m, 1] or [1, 1] when bias dim is 2 ..."); + if (b.size(0) == 1 && b.size(1) == 1) + b = b.expand({1, n}).contiguous(); + } else if (b.dim() == 3) { + TORCH_CHECK( + at::are_expandable({mb, m, n}, b.sizes()), + "matmul bias must be expandable to:", + dst.sizes(), + " but got:", + b.sizes()); + b = b.expand({mb, m, n}).contiguous(); + } else if (b.dim() == 0) { + TORCH_CHECK( + b.numel() == 1, "matmul supports 1 numel when bias dim is [] ..."); + if (m1.dim() == 3) { + b = b.expand({mb, m, n}).contiguous(); + } else { + b = b.expand({1, n}).contiguous(); + } + } else { + TORCH_CHECK(0, "unsupported bias dim in matmul ..."); + } + } + + b = b.contiguous(); // avoid reorder 2 times + + // xpu matmul support both ab/ba shape for m2 tensor, we don't check any more + auto m1_usr_dt = get_onednn_dtype(m1); + auto m2_usr_dt = get_onednn_dtype(m2); + auto dst_usr_dt = get_onednn_dtype(dst); + + auto m1_dt = m1_usr_dt; + auto m2_dt = m2_usr_dt; + auto dst_dt = dst_usr_dt; + dnnl::memory::data_type bias_dt; + + dnnl::memory::desc m1_md, m1_usr_md, m1_any_md; + dnnl::memory::desc m2_md, m2_usr_md, m2_any_md; + dnnl::memory::desc dst_md, dst_usr_md, dst_any_md; + dnnl::memory::desc bias_md; + + // Naive Master weight + if (m1_dt == dnnl::memory::data_type::bf16 && m2_dt == dnnl::memory::data_type::f32) { + m2_dt = dnnl::memory::data_type::bf16; + dst_dt = dnnl::memory::data_type::bf16; + } else if ( + m1_dt == dnnl::memory::data_type::f32 && m2_dt == dnnl::memory::data_type::bf16) { + m1_dt = dnnl::memory::data_type::bf16; + dst_dt = dnnl::memory::data_type::bf16; + } + + dnnl::memory::dims m1_dims, m2_dims, dst_dims, bias_dims; + dnnl::memory::dims m1_strides, m2_strides, dst_strides, bias_strides; + if (dims == 2) { + m1_dims = {m, k}; + m2_dims = {k, n}; + dst_dims = {m, n}; + + m1_strides = {m1.stride(0), m1.stride(1)}; + if (m2_trans) { + m2_strides = {m2.stride(0), m2.stride(1)}; + } else { + m2_strides = {m2.stride(1), m2.stride(0)}; + } + dst_strides = {dst.stride(0), dst.stride(1)}; + } else { + m1_dims = {mb, m, k}; + m2_dims = {mb, k, n}; + dst_dims = {mb, m, n}; + + m1_strides = {m1.stride(0), m1.stride(1), m1.stride(2)}; + if (m2_trans) { + m2_strides = {m2.stride(0), m2.stride(1), m2.stride(2)}; + } else { + m2_strides = {m2.stride(0), m2.stride(2), m2.stride(1)}; + } + dst_strides = {dst.stride(0), dst.stride(1), dst.stride(2)}; + } + + if (with_bias) { + bias_dims = get_onednn_dims(b); + bias_dt = get_onednn_dtype(b); + bias_strides = get_onednn_strides(b); + } + + dnnl::post_ops po = attr.extract_post_ops(dst); + + std::unordered_map args; + dnnl::matmul matmul_p; + dnnl::matmul::primitive_desc matmul_pd; + + // STEP1: create memory desc + m1_md = dnnl::memory::desc(m1_dims, m1_dt, m1_strides); + m2_md = dnnl::memory::desc(m2_dims, m2_dt, m2_strides); + dst_md = dnnl::memory::desc(dst_dims, dst_dt, dst_strides); + + // STEP2: creat attribute + dnnl::primitive_attr pattr; + pattr.set_post_ops(po); + + #if ONEDNN_SUPPORT_DETERMINISTIC + if(at::globalContext().deterministicAlgorithms()) + pattr.set_deterministic(true); + #endif + + // scratchpad + pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user); + + if (m1_dt == dnnl::memory::data_type::f32) { + pattr.set_fpmath_mode(dnnl::fpmath_mode::strict); + } + + // STEP3: create primitive + if (with_bias) { + bias_md = dnnl::memory::desc(bias_dims, bias_dt, bias_strides); + matmul_pd = + dnnl::matmul::primitive_desc(engine, m1_md, m2_md, bias_md, dst_md, pattr); + } else { + matmul_pd = dnnl::matmul::primitive_desc(engine, m1_md, m2_md, dst_md, pattr); + } + + matmul_p = dnnl::matmul(matmul_pd); + + m1_usr_md = dnnl::memory::desc(m1_dims, m1_usr_dt, m1_strides); + m2_usr_md = dnnl::memory::desc(m2_dims, m2_usr_dt, m2_strides); + dst_usr_md = dnnl::memory::desc(dst_dims, dst_usr_dt, dst_strides); + + // STEP4: create memory + auto m1_usr_m = make_onednn_memory(m1_usr_md, engine, m1.data_ptr()); + auto m2_usr_m = make_onednn_memory(m2_usr_md, engine, m2.data_ptr()); + auto dst_usr_m = make_onednn_memory(dst_usr_md, engine, dst.data_ptr()); + + auto expected_m1_md = matmul_pd.src_desc(); + auto expected_m2_md = matmul_pd.weights_desc(); + auto expected_dst_md = matmul_pd.dst_desc(); + + dnnl::memory m1_m = m1_usr_m, m2_m = m2_usr_m, dst_m = dst_usr_m; + at::Tensor m1_, m2_, dst_; + + if (attr.with_binary()) + attr.construct_post_binary(matmul_pd, args); + + size_t scratchpad_size = matmul_pd.scratchpad_desc().get_size(); + at::Tensor scratchpad_tensor = at::empty( + {static_cast(scratchpad_size)}, m1.options().dtype(at::kByte), c10::nullopt); + auto scratchpad_memory = make_onednn_memory( + matmul_pd.scratchpad_desc(), engine, scratchpad_tensor.data_ptr()); + args.insert({DNNL_ARG_SCRATCHPAD, scratchpad_memory}); + + args.insert({DNNL_ARG_SRC, m1_m}); + args.insert({DNNL_ARG_WEIGHTS, m2_m}); + args.insert({DNNL_ARG_DST, dst_m}); + if (with_bias) { + auto bias_m = make_onednn_memory(bias_md, engine, b.data_ptr()); + args.insert({DNNL_ARG_BIAS, bias_m}); + } + + sycl::event matmul_event = dnnl::sycl_interop::execute(matmul_p, stream, args, deps); + + if (!dst.is_same(result)) + result.copy_(dst); + + return matmul_event; +} + +} // namespace at::native::onednn diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/Utils.cpp b/aten/src/ATen/native/mkldnn/xpu/detail/Utils.cpp new file mode 100644 index 00000000000000..73a37d275b4897 --- /dev/null +++ b/aten/src/ATen/native/mkldnn/xpu/detail/Utils.cpp @@ -0,0 +1,352 @@ +#include + +namespace at::native::onednn { + +dnnl::memory make_onednn_memory( + dnnl::memory::desc md, + dnnl::engine& engine, + void* ptr){ + return dnnl::sycl_interop::make_memory( + md, + engine, + dnnl::sycl_interop::memory_kind::usm, + ptr == nullptr ? DNNL_MEMORY_ALLOCATE : ptr); +} + +dnnl::memory::format_tag get_dnnl_default_format( + int ndims, + bool is_channels_last, + bool allow_undef) { + switch (ndims) { + case 1: + return dnnl::memory::format_tag::a; + case 2: + return dnnl::memory::format_tag::ab; + case 3: + return is_channels_last ? dnnl::memory::format_tag::acb + : dnnl::memory::format_tag::abc; + case 4: + return is_channels_last ? dnnl::memory::format_tag::acdb + : dnnl::memory::format_tag::abcd; + case 5: + return is_channels_last ? dnnl::memory::format_tag::acdeb + : dnnl::memory::format_tag::abcde; + case 6: + return dnnl::memory::format_tag::abcdef; + case 7: + return dnnl::memory::format_tag::abcdefg; + case 8: + return dnnl::memory::format_tag::abcdefgh; + case 9: + return dnnl::memory::format_tag::abcdefghi; + case 10: + return dnnl::memory::format_tag::abcdefghij; + case 11: + return dnnl::memory::format_tag::abcdefghijk; + case 12: + return dnnl::memory::format_tag::abcdefghijkl; + default: + if (!allow_undef) { + TORCH_CHECK(false, "oneDNN doesn't support tensor dimension > 12"); + } + return dnnl::memory::format_tag::undef; + } +} + +dnnl::memory::data_type get_onednn_dtype( + const at::Tensor& tensor, + bool allow_undef) { + switch (tensor.scalar_type()) { + case at::ScalarType::Byte: + return dnnl::memory::data_type::u8; + case at::ScalarType::Char: + return dnnl::memory::data_type::s8; + case at::ScalarType::QInt8: + return dnnl::memory::data_type::s8; + case at::ScalarType::QUInt8: + return dnnl::memory::data_type::u8; + case at::ScalarType::Int: + return dnnl::memory::data_type::s32; + case at::ScalarType::Half: + return dnnl::memory::data_type::f16; + case at::ScalarType::Float: + return dnnl::memory::data_type::f32; + case at::ScalarType::BFloat16: + return dnnl::memory::data_type::bf16; + default: + if (!allow_undef) { + TORCH_CHECK( + false, + c10::toString(tensor.scalar_type()), + " is not supported in oneDNN!"); + } + return dnnl::memory::data_type::undef; + }; +} + +dnnl::memory::data_type get_onednn_dtype_include_double( + const at::Tensor& tensor, + bool allow_undef) { + if (tensor.scalar_type() == at::ScalarType::Double) + return dnnl::memory::data_type::f64; + return get_onednn_dtype(tensor, allow_undef); +} + +bool is_supported_onednn_dtype(const at::Tensor& tensor) { + return get_onednn_dtype(tensor, /*allow_undef*/ true) == + dnnl::memory::data_type::undef + ? false + : true; +} + +dnnl::memory::dims get_onednn_dims(const at::Tensor& tensor) { + dnnl::memory::dims dims; + for (size_t i = 0; i < tensor.sizes().size(); i++) + dims.push_back(tensor.size(i)); + return dims; +} + +dnnl::memory::dims get_onednn_strides(const at::Tensor& tensor) { + dnnl::memory::dims strides; + for (size_t i = 0; i < tensor.strides().size(); i++) + strides.push_back(tensor.stride(i)); + return strides; +} + +dnnl::memory::desc get_onednn_md(const at::Tensor& tensor) { + return { + get_onednn_dims(tensor), + get_onednn_dtype(tensor), + get_onednn_strides(tensor)}; +} + +bool onednn_strides_check(const at::Tensor& src) { + auto adims = get_onednn_dims(src); + int ndims = (int)adims.size(); + auto dims = adims.data(); + auto data_type = static_cast( + get_onednn_dtype(src, /*allow_undef*/ true)); + auto strides_info = get_onednn_strides(src); + auto strides = strides_info.empty() ? nullptr : &strides_info[0]; + + dnnl_memory_desc_t md; + dnnl_memory_desc_create_with_strides(&md, ndims, dims, data_type, strides); + dnnl_format_kind_t md_fmt_kind; + int md_ndims; + int md_inner_nblks; + dnnl_dims_t* md_padded_dims = nullptr; + + dnnl_memory_desc_query(md, dnnl_query_inner_nblks_s32, &md_inner_nblks); + dnnl_memory_desc_query(md, dnnl_query_format_kind, &md_fmt_kind); + dnnl_memory_desc_query(md, dnnl_query_ndims_s32, &md_ndims); + dnnl_memory_desc_query(md, dnnl_query_padded_dims, &md_padded_dims); + if (strides == nullptr || md_ndims == 0 || + md_fmt_kind != dnnl_format_kind_t::dnnl_blocked) + return true; + + dnnl_dims_t blocks = {0}; + int perm[DNNL_MAX_NDIMS] = {0}; + for (int d = 0; d < md_ndims; ++d) { + // no strides check needed for empty tensor + if (md_padded_dims[d] == 0) + return true; + + // no strides verification for runtime dims + if (strides[d] == DNNL_RUNTIME_DIM_VAL) + return true; + + perm[d] = d; + blocks[d] = 1; + } + + auto block_size = 1; + dnnl_dims_t md_inner_blks; + dnnl_dims_t md_blk_inner_idxs; + dnnl_memory_desc_query(md, dnnl_query_inner_idxs, &md_blk_inner_idxs); + dnnl_memory_desc_query(md, dnnl_query_inner_blks, &md_inner_blks); + for (int iblk = 0; iblk < md_inner_nblks; ++iblk) { + blocks[md_blk_inner_idxs[iblk]] *= md_inner_blks[iblk]; + block_size *= md_inner_blks[iblk]; + } + + // A custom comparator to yield linear order on perm + auto idx_sorter = [&](const int a, const int b) -> bool { + if (strides[a] == strides[b] && md_padded_dims[a] == md_padded_dims[b]) + return a < b; + else if (strides[a] == strides[b]) + return md_padded_dims[a] < md_padded_dims[b]; + else + return strides[a] < strides[b]; + }; + std::sort(perm, perm + md_ndims, idx_sorter); + + auto min_stride = block_size; + for (int idx = 0; idx < md_ndims; ++idx) { + const int d = perm[idx]; + + // Make an exception for strides[d] == 0 as it has broadcast semantics + // Note: owing to being sorted, these are the initial strides + if (strides[d] == 0) + continue; + else if (strides[d] < min_stride) + return false; + + // update min_stride for next iteration + const auto padded_dim = *md_padded_dims[d]; + min_stride = block_size * strides[d] * (padded_dim / blocks[d]); + } + return true; +} + +bool is_broadcast(const at::Tensor& t) { + for (int i = 0; i < t.dim(); i++) { + if (t.stride(i) == 0) + return true; + } + return false; +} + +bool is_onednn_matmul_strides( + const at::Tensor& tensor, + bool is_dst) { + // https://oneapi-src.github.io/oneDNN/dev_guide_matmul.html + // oneDNN matmul only support 2-dim and 3-dim + // 2D src(Mxk), wei(KxN), dst(MxN) + // 3D src(SxMxK), wei(WxKxN), dst(DxMxN) + auto sizes = tensor.sizes(); + auto tensor_dim = sizes.size(); + if (tensor_dim != 2 && tensor_dim != 3) + return false; + + if (tensor.is_contiguous()) + return true; + + // the overlaped cases are not supported + dnnl::memory::dims strides = get_onednn_strides(tensor); + int64_t storage_size = 1; + for (size_t dim = 0; dim < tensor_dim; ++dim) + storage_size += (sizes[dim] - 1) * strides[dim]; + if (storage_size < tensor.numel()) + return false; + + // the broadcast cases are not supported + if (is_broadcast(tensor)) { + return false; + } + + if (is_dst) { + // The memory format of the destination tensor should always + // be plain with n axis contiguous + if (strides[-1] != 1) + return false; + } else { + // the src and weight must have at least one of the axes + // m or k and n or k contiguous (i.e., stride=1) respectively. + if (strides[tensor_dim - 1] != 1 && strides[tensor_dim - 2] != 1) + return false; + } + + if (!onednn_strides_check(tensor)) + return false; + return true; +} + +bool is_broadcast_from_other_to_self( + const at::Tensor& self, + const at::Tensor& other) { + return ( + self.sizes() != other.sizes() && + at::is_expandable_to(other.sizes(), self.sizes())); +} + +at::MemoryFormat get_cl_tag_by_ndim(const int64_t ndim) { + TORCH_CHECK( + 3 == ndim || 4 == ndim || 5 == ndim, + "ndim must be 3, 4 or 5 when get cl tag"); + if (3 == ndim) { + return at::MemoryFormat::Contiguous; + } else if (5 == ndim) { + return at::MemoryFormat::ChannelsLast3d; + } else { + return at::MemoryFormat::ChannelsLast; + } +} + +bool binary_valid( + const at::Tensor& self, + const at::Tensor& other, + bool is_fusion) { + if (self.sizes() != other.sizes() && + !is_broadcast_from_other_to_self(self, other)) + return false; + + /* If the following conditions are satisfied, then oneDNN path will be + selected: + * 1. self and other should be xpu tensor and be defined. + * 2. self or other should not be scalar (wrapped tensor). + * 3. dim of self and other should be equal and must be larger than 0 and + smaller than 7. + * 4. the datatype should be supported by oneDNN primitive. + * 5. self and other should be in the same datatype. + * 6. self and other should be contiguous or channel-last contiguous.*/ + + + // 1. self and other should be xpu tensor and be defined. + if ((!self.defined()) || (!other.defined()) || (!self.is_xpu()) || + (!other.is_xpu())) + return false; + + // 2. self or other should not be scalar (wrapped tensor). + if (self.unsafeGetTensorImpl()->is_wrapped_number() || other.unsafeGetTensorImpl()->is_wrapped_number()) + return false; + + // 3. dim of self and other should be equal and must be larger than 0 and + // smaller than 7. + if ((self.dim() <= 0) || (other.dim() <= 0) || (self.dim() != other.dim()) || + (self.dim() > 6) || (other.dim() > 6)) + return false; + + // 4. the datatype should be supported by oneDNN primitive. + switch (self.scalar_type()) { + case at::ScalarType::Char: + break; + case at::ScalarType::Byte: + break; + case at::ScalarType::Half: + break; + case at::ScalarType::Float: + break; + case at::ScalarType::BFloat16: + break; + default: + return false; + }; + + // 5. datatype check + if (is_fusion) { + // for fusion case, the fusion can be performed on scalar_type or Float + // datatype. + if (self.scalar_type() != other.scalar_type() && + other.scalar_type() != at::ScalarType::Float) { + return false; + } + } else { + if (self.scalar_type() != other.scalar_type()) { + // for non-fusion case: self and other should be in the same datatype. + return false; + } + } + + // 6. self and other should be contiguous or channel-last contiguous. + const auto ndim = self.ndimension(); + auto cl_tag = at::MemoryFormat::ChannelsLast; + if (3 == ndim || 4 == ndim || 5 == ndim) { + cl_tag = get_cl_tag_by_ndim(ndim); + } + if ((self.is_contiguous() && other.is_contiguous()) || + (self.is_contiguous(cl_tag) && other.is_contiguous(cl_tag))) + return true; + return false; +} + +} diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/Utils.h b/aten/src/ATen/native/mkldnn/xpu/detail/Utils.h new file mode 100644 index 00000000000000..1fcb669d534b76 --- /dev/null +++ b/aten/src/ATen/native/mkldnn/xpu/detail/Utils.h @@ -0,0 +1,56 @@ +#pragma once +#include +#include +#include +#include + +#include +#include +#include +#include +#include + + +#define ONEDNN_SUPPORT_DETERMINISTIC (DNNL_VERSION_MAJOR >=3 && DNNL_VERSION_MINOR >=4) + +namespace at::native::onednn { + +dnnl::memory::format_tag get_dnnl_default_format( + int ndims, + bool is_channels_last = false, + bool allow_undef = false); + +dnnl::memory::data_type get_onednn_dtype( + const at::Tensor& tensor, + bool allow_undef = false); + +dnnl::memory::data_type get_onednn_dtype_include_double( + const at::Tensor& tensor, + bool allow_undef = false); + +bool is_supported_onednn_dtype(const at::Tensor& tensor); + +dnnl::memory::dims get_onednn_dims(const at::Tensor& tensor); + +dnnl::memory::dims get_onednn_strides(const at::Tensor& tensor); +dnnl::memory::desc get_onednn_md(const at::Tensor& tensor); + +bool onednn_strides_check(const at::Tensor& src); +bool is_broadcast(const at::Tensor& t); + +bool is_onednn_matmul_strides( + const at::Tensor& tensor, + bool is_dst = false); + +bool is_broadcast_from_other_to_self( + const at::Tensor& self, + const at::Tensor& other); + +at::MemoryFormat get_cl_tag_by_ndim(const int64_t ndim); + +bool binary_valid( + const at::Tensor& self, + const at::Tensor& other, + bool is_fusion = false); + +} // namespace at::native::onednn diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h b/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h new file mode 100644 index 00000000000000..a34edfff363619 --- /dev/null +++ b/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h @@ -0,0 +1,20 @@ +#pragma once + +#include +#include +#include +#include + + +namespace at::native::onednn{ + +TORCH_API sycl::event matmul( + at::Tensor& result, + const at::Tensor& mat1, + const at::Tensor& mat2, + const at::Tensor& b_raw, + bool m2_trans, + Attr attr, + const std::vector& deps = {}); + +} // namespace at::native::onednn diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/oneDNNContext.cpp b/aten/src/ATen/native/mkldnn/xpu/detail/oneDNNContext.cpp new file mode 100644 index 00000000000000..9bec64c8c0248f --- /dev/null +++ b/aten/src/ATen/native/mkldnn/xpu/detail/oneDNNContext.cpp @@ -0,0 +1,27 @@ +#include +#include + +/* * + * Do NOT put any kernels or call any device binaries here! + * Only maintain oneDNN runtime states in this file. + * */ +namespace at::native::onednn { + +using namespace dnnl; + +GpuEngineManager& GpuEngineManager::Instance() { + static GpuEngineManager myInstance; + return myInstance; +} + +GpuStreamManager& GpuStreamManager::Instance() { + static thread_local GpuStreamManager myInstance; + return myInstance; +} + +bool set_onednn_verbose(int level) { + dnnl::status rs = dnnl::set_verbose(level); + return rs == dnnl::status::success; +} + +} // namespace at::native::onednn diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/oneDNNContext.h b/aten/src/ATen/native/mkldnn/xpu/detail/oneDNNContext.h new file mode 100644 index 00000000000000..c7e7a5e94b406b --- /dev/null +++ b/aten/src/ATen/native/mkldnn/xpu/detail/oneDNNContext.h @@ -0,0 +1,75 @@ +#pragma once + +#include + +#include +#include +#include + +#include +#include +#include + +namespace at::native::onednn { + +TORCH_API dnnl::memory make_onednn_memory( + dnnl::memory::desc md, + dnnl::engine& engine, + void* ptr); + +// Keep non-static and non-inline +bool set_onednn_verbose(int level); + +// GpuEngineManager singleton +struct TORCH_API GpuEngineManager { + static GpuEngineManager& Instance(); // Singleton + + dnnl::engine& get_engine(const Device& device) { + TORCH_INTERNAL_ASSERT(device.type() == kXPU); + TORCH_INTERNAL_ASSERT(device.index() < c10::xpu::device_count()); + return *engine_pool[device.index()]; + } + + GpuEngineManager(GpuEngineManager const&) = delete; + GpuEngineManager& operator=(GpuEngineManager const&) = delete; + + protected: + GpuEngineManager() { + int device_count = (int)c10::xpu::device_count(); + TORCH_INTERNAL_ASSERT(device_count > 0); + for (int i = 0; i < device_count; i++) { + engine_pool.push_back( + std::make_shared(dnnl::sycl_interop::make_engine( + c10::xpu::get_raw_device(i), c10::xpu::get_device_context() + ))); + } + } + ~GpuEngineManager() {} + + private: + std::vector> engine_pool; +}; + +// GpuStreamManager singleton +struct TORCH_API GpuStreamManager { + static GpuStreamManager& Instance(); // Singleton + + dnnl::stream get_stream() { + c10::DeviceIndex device_index = c10::xpu::current_device(); + TORCH_INTERNAL_ASSERT(device_index < c10::xpu::device_count()); + return dnnl::sycl_interop::make_stream( + GpuEngineManager::Instance().get_engine({c10::kXPU, device_index}), + c10::xpu::getCurrentXPUStream(device_index).queue()); + } + + GpuStreamManager(GpuStreamManager const&) = delete; + GpuStreamManager& operator=(GpuStreamManager const&) = delete; + + protected: + GpuStreamManager() { + } + ~GpuStreamManager() {} + +}; + +} // namespace at::native::onednn