From 6f6705c18d7499d91613a3e4ae33832869052ba1 Mon Sep 17 00:00:00 2001
From: shivadbhavsar <105248561+shivadbhavsar@users.noreply.github.com>
Date: Wed, 11 Dec 2024 19:45:25 -0800
Subject: [PATCH 01/11] catch python buffer unsupported types (#3701)

---
 src/py/migraphx_py.cpp | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/src/py/migraphx_py.cpp b/src/py/migraphx_py.cpp
index bafe0fdedf8..925c4546c1f 100644
--- a/src/py/migraphx_py.cpp
+++ b/src/py/migraphx_py.cpp
@@ -263,6 +263,13 @@ migraphx::shape to_shape(const py::buffer_info& info)
 {
     migraphx::shape::type_t t;
     std::size_t n = 0;
+    // Unsupported pybuffer types lead to undefined behaviour when comparing with migraphx type enum
+    if(info.format == "z")
+    {
+        MIGRAPHX_THROW(
+            "MIGRAPHX PYTHON: Unsupported data type. For fp8 and bf16 literals try using "
+            "migraphx.generate_argument with migraphx.add_literal");
+    }
     visit_types([&](auto as) {
         if(info.format == py::format_descriptor::format() or
            (info.format == "l" and py::format_descriptor::format() == "q") or
@@ -388,6 +395,12 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
             py::arg("op"),
             py::arg("args"),
             py::arg("mod_args") = std::vector{})
+        .def(
+            "add_literal",
+            [](migraphx::module& mm, migraphx::argument a) {
+                return mm.add_literal(a.get_shape(), a.data());
+            },
+            py::arg("data"))
         .def(
             "add_literal",
             [](migraphx::module& mm, py::buffer data) {
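Note: the new error message steers bf16/fp8 users to a host-side construction path instead of the pybind11 buffer protocol. A minimal sketch of that workaround in Python follows; the shape keyword names and the generate_argument seed parameter are assumptions, not taken from this patch:

    import migraphx

    p = migraphx.program()
    mm = p.get_main_module()

    # bf16 values cannot cross pybind11's buffer protocol ("z" format),
    # so build the argument on the MIGraphX side and hand it to the new
    # add_literal overload that takes a migraphx.argument.
    s = migraphx.shape(type="bf16_type", lens=[2, 2])  # keyword names assumed
    arg = migraphx.generate_argument(s, 0)             # seed parameter assumed
    lit = mm.add_literal(arg)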
From 79a2561f83cc51b0f5559bb2bec765985a925ad6 Mon Sep 17 00:00:00 2001
From: Chris Austen
Date: Thu, 12 Dec 2024 09:38:48 -0500
Subject: [PATCH 02/11] Updates to CHANGELOG for 6.3 (#3703)

Changelog updates based on taking all the commits in 6.3 since 6.2 and
boiling them down to human-readable descriptions of the changes.
---
 CHANGELOG.md | 77 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 77 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 98f4f775feb..8461e3a372d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,6 +3,83 @@
 Full documentation for MIGraphX is available at
 [https://rocmdocs.amd.com/projects/AMDMIGraphX/en/latest/](https://rocmdocs.amd.com/projects/AMDMIGraphX/en/latest/).
 
+## MIGraphX 2.11 for ROCm 6.3.0
+
+### Added
+
+* Initial code to run on Windows
+* Support for gfx120x GPU
+* Support for FP8, and INT4
+* Support for the Log2 internal operator
+* Support for the GCC 14 compiler
+* The BitwiseAnd, Scan, SoftmaxCrossEntropyLoss, GridSample, and NegativeLogLikelihoodLoss ONNX operators
+* The MatMulNBits, QuantizeLinear/DequantizeLinear, GroupQueryAttention, SkipSimplifiedLayerNormalization, and SimplifiedLayerNormalization Microsoft Contrib operators
+* Dynamic batch parameter support to OneHot operator
+* Split-K as an optional performance improvement
+* Scripts to validate ONNX models from the ONNX Model Zoo
+* GPU Pooling Kernel
+* --mlir flag to the migraphx-driver program to offload entire module to mlir
+* Fusing split-reduce with MLIR
+* Multiple outputs for the MLIR + Pointwise fusions
+* Pointwise fusions with MLIR across reshape operations
+* MIGRAPHX_MLIR_DUMP environment variable to dump MLIR modules to MXRs
+* The 3 option to MIGRAPHX_TRACE_BENCHMARKING to print the MLIR program for improved debug output
+* MIGRAPHX_ENABLE_HIPBLASLT_GEMM environment variable to call hipBLASLt libraries
+* MIGRAPHX_VERIFY_DUMP_DIFF to improve the debugging of accuracy issues
+* reduce_any and reduce_all options to the Reduce operation via Torch MIGraphX
+* Examples for RNNT, and ControlNet
+
+
+### Changed
+
+* Switched to MLIR's 3D Convolution operator.
+* MLIR is now used for Attention operations by default on gfx942 and newer ASICs.
+* Names and locations for VRM specific libraries have changed.
+* Use random mode for benchmarking GEMMs and convolutions.
+* Python version is now printed with an actual version number.
+
+
+### Removed
+
+* Disabled requirements for MIOpen and rocBLAS when running on Windows.
+* Removed inaccurate warning messages when using exhaustive-tune.
+* Removed the hard-coded path in MIGRAPHX_CXX_COMPILER, allowing the compiler to be installed in different locations.
+
+
+### Optimized
+
+* Improved:
+  * Infrastructure code to enable better Kernel fusions with all supported data types
+  * Subsequent model compile time by creating a cache for already performant kernels
+  * Use of Attention fusion with models
+  * Performance of the Softmax JIT kernel and of the Pooling operator
+  * Tuning operations through a new 50ms delay before running the next kernel
+  * Performance of several convolution based models through an optimized NHWC layout
+  * Performance for the FP8 datatype
+  * GPU utilization
+  * Verification tools
+  * Debug prints
+  * Documentation, including gpu-driver utility documentation
+  * Summary section of the migraphx-driver perf command
+* Reduced model compilation time
+* Reordered some compiler passes to allow for more fusions
+* Preloaded tiles into LDS to improve performance of pointwise transposes
+* Exposed the external_data_path property in onnx_options to set the path from onnxruntime
+
+
+### Resolved Issues
+
+* Fixed a bug with gfx1030 that overwrote dpp_reduce.
+* Fixed a bug in 1arg dynamic reshape that created a failure.
+* Fixed a bug with dot_broadcast and inner_broadcast that caused compile failures.
+* Fixed a bug where some configs were failing when using exhaustive-tune.
+* Fixed the ROCM Install Guide URL.
+* Fixed an issue while building a whl package due to an apostrophe.
+* Fixed the BERT Squad example requirements file to support different versions of Python.
+* Fixed a bug that stopped the Vicuna model from compiling.
+* Fixed failures with the verify option of migraphx-driver that would cause the application to exit early.
+
+
 ## MIGraphX 2.10 for ROCm 6.2.0
 
 ### Additions
From f56b1b4f14bfa198ad4c17befb7c35592fbae7ef Mon Sep 17 00:00:00 2001
From: Paul Fultz II
Date: Fri, 13 Dec 2024 14:07:48 -0600
Subject: [PATCH 03/11] Enable split reduce by default (#3709)

---
 docs/dev/env_vars.rst | 6 +++---
 src/fuse_pointwise_reduce.cpp | 13 +++++++++++++
 src/include/migraphx/fuse_pointwise_reduce.hpp | 1 +
 src/split_reduce.cpp | 2 --
 src/targets/gpu/target.cpp | 2 --
 test/split_reduce.cpp | 1 +
 6 files changed, 18 insertions(+), 7 deletions(-)

diff --git a/docs/dev/env_vars.rst b/docs/dev/env_vars.rst
index cc7915df879..a05b99cef64 100644
--- a/docs/dev/env_vars.rst
+++ b/docs/dev/env_vars.rst
@@ -116,9 +116,9 @@ Disables the ``schedule`` pass.
 Set to "1", "enable", "enabled", "yes", or "true" to use.
 Disables the ``fuse_reduce`` pass.
 
-.. envvar:: MIGRAPHX_ENABLE_SPLIT_REDUCE
+.. envvar:: MIGRAPHX_SPLIT_REDUCE_SIZE
 
-Set to "1", "enable", "enabled", "yes", or "true" to use.
-Enable split_reduce.
+Set to the minimum size of a reduction to do a split reduce. Overrides what
+is set in the backend. Set to -1 to disable split reduce completely.
 
 .. envvar:: MIGRAPHX_ENABLE_NHWC

diff --git a/src/fuse_pointwise_reduce.cpp b/src/fuse_pointwise_reduce.cpp
index dfd3f474ba2..eb2feb565b7 100644
--- a/src/fuse_pointwise_reduce.cpp
+++ b/src/fuse_pointwise_reduce.cpp
@@ -26,9 +26,20 @@
 #include
 #include
 #include
+#include
+#include
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_SPLIT_REDUCE_SIZE);
+
+static std::size_t get_split_size(std::size_t default_split)
+{
+    std::string value = string_value_of(MIGRAPHX_SPLIT_REDUCE_SIZE{});
+    if(value.empty())
+        return default_split;
+    return std::stoul(value);
+}
 
 void fuse_pointwise_reduce::apply(module_pass_manager& mpm) const
 {
@@ -36,6 +47,8 @@ void fuse_pointwise_reduce::apply(module_pass_manager& mpm) const
     mpm.run_pass(fuse_reduce{.enable_rewrite_reshapes = false});
     mpm.run_pass(fuse_pointwise{.enable_rewrite_reshapes = true});
     mpm.run_pass(fuse_reduce{.enable_rewrite_reshapes = true});
+    mpm.run_pass(split_reduce{.split_size = get_split_size(split_size)});
+    mpm.run_pass(fuse_pointwise{.enable_rewrite_broadcasts = true});
 }
 
 } // namespace MIGRAPHX_INLINE_NS

diff --git a/src/include/migraphx/fuse_pointwise_reduce.hpp b/src/include/migraphx/fuse_pointwise_reduce.hpp
index 68bdc4e9951..63d78d2360b 100644
--- a/src/include/migraphx/fuse_pointwise_reduce.hpp
+++ b/src/include/migraphx/fuse_pointwise_reduce.hpp
@@ -35,6 +35,7 @@ struct module_pass_manager;
 
 struct MIGRAPHX_EXPORT fuse_pointwise_reduce
 {
+    std::size_t split_size = 32768;
     std::string name() const { return "fuse_pointwise_reduce"; }
     void apply(module_pass_manager& mpm) const;
 };

diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp
index 91bdfc9924f..3188b00563f 100644
--- a/src/split_reduce.cpp
+++ b/src/split_reduce.cpp
@@ -237,8 +237,6 @@ void split_reduce::apply(module_pass_manager& mpm) const
         assert(replaced.size() == 1);
         mpm.get_module().replace_instruction(ins, replaced.front());
     }
-
-    mpm.run_pass(fuse_pointwise{.enable_rewrite_broadcasts = true});
 }
 
 } // namespace MIGRAPHX_INLINE_NS

diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp
index b70c8cd1c10..9320ed86f9f 100644
--- a/src/targets/gpu/target.cpp
+++ b/src/targets/gpu/target.cpp
@@ -77,7 +77,6 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
-MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_SPLIT_REDUCE)
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)
 #ifndef _WIN32
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK)
@@ -211,7 +210,6 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti
         dead_code_elimination{},
         optimize_module{},
         fuse_pointwise_reduce{},
-        enable_pass(enabled(MIGRAPHX_ENABLE_SPLIT_REDUCE{}), split_reduce{}),
         dead_code_elimination{},
 #ifndef _WIN32
         enable_pass(enabled(MIGRAPHX_ENABLE_CK{}), fuse_ck{}),

diff --git a/test/split_reduce.cpp b/test/split_reduce.cpp
index e837b577882..9dd15eab203 100644
--- a/test/split_reduce.cpp
+++ b/test/split_reduce.cpp
@@ -41,6 +41,7 @@ void run_pass(migraphx::program& p)
                        {migraphx::fuse_pointwise{},
                         migraphx::fuse_reduce{},
                         migraphx::split_reduce{.split_size = 8192},
+                        migraphx::fuse_pointwise{.enable_rewrite_broadcasts = true},
                         migraphx::dead_code_elimination{}});
 }
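Note: with this patch split reduce is always on, and MIGRAPHX_SPLIT_REDUCE_SIZE only overrides the backend default split size (32768 per fuse_pointwise_reduce.hpp above). A sketch of exercising the override from Python — the variable is read while the passes run, so it must be set before compile; the model path is hypothetical:

    import os
    import migraphx

    # Lower the split-size threshold, or use "-1" to disable split reduce
    # entirely (stoul wraps -1 to a huge value no reduction can reach).
    os.environ["MIGRAPHX_SPLIT_REDUCE_SIZE"] = "8192"

    p = migraphx.parse_onnx("model.onnx")  # hypothetical model
    p.compile(migraphx.get_target("gpu"))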
From de20bd06225e95d6e6c34432c5f14ce4774daabb Mon Sep 17 00:00:00 2001
From: Ahsan Saghir <142340507+ahsan-ca@users.noreply.github.com>
Date: Sat, 14 Dec 2024 11:16:36 -0500
Subject: [PATCH 04/11] normalize standard input shapes to hipblaslt gemms (#3712)

---
 src/targets/gpu/fuse_ops.cpp | 2 +-
 src/targets/gpu/hip_gemm_impl.cpp | 7 ++++---
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/targets/gpu/fuse_ops.cpp b/src/targets/gpu/fuse_ops.cpp
index 50ab5d184f3..70d77b5037c 100644
--- a/src/targets/gpu/fuse_ops.cpp
+++ b/src/targets/gpu/fuse_ops.cpp
@@ -735,7 +735,7 @@ struct find_hipblas_gemm_pointwise : gemm_pointwise
         shape s = c_ins->get_shape();
         // const-fold input if not standard shape
         // Updated for a case where "standard" shape has out-of-sequence strides
-        if(not s.standard() or s.normalize_standard() != s)
+        if(not s.standard())
         {
             auto c = make_op("contiguous");
             auto l = c.compute(c.compute_shape({c_ins->get_shape()}), {c_ins->eval()});

diff --git a/src/targets/gpu/hip_gemm_impl.cpp b/src/targets/gpu/hip_gemm_impl.cpp
index 03bda69081e..e639e30706e 100644
--- a/src/targets/gpu/hip_gemm_impl.cpp
+++ b/src/targets/gpu/hip_gemm_impl.cpp
@@ -92,10 +92,11 @@ hipDataType get_type_hipblas(shape::type_t type)
     MIGRAPHX_THROW("HIPBLAS_GEMM: data type not supported!");
 }
 
-void blas_shape_hip(const shape& s)
+void blas_shape_hip(const shape& in_shape)
 {
-    if(s.lens().size() < 2)
+    if(in_shape.lens().size() < 2)
         return;
+    auto s = in_shape.normalize_standard();
     if(std::none_of(s.strides().end() - 2, s.strides().end(), [](auto i) { return i == 1; }))
         MIGRAPHX_THROW("GPU_GEMM: needs to have one matrix stride as 1");
     if(std::any_of(s.strides().end() - 2, s.strides().end(), [](auto i) { return i == 0; }))
@@ -669,7 +670,7 @@ void hip_gemm_compute(context& ctx,
     std::transform(args.begin(),
                    args.end(),
                    std::back_inserter(input_shapes),
-                   [](const argument& x) { return x.get_shape(); });
+                   [](const argument& x) { return x.get_shape().normalize_standard(); });
     auto gemm_item = hip_gemm_impl(output_shape, input_shapes, alpha, beta);
     gemm_item.run(ctx, args, solution_idx);
 }
From e8bfc2c645e9dbb1ac56fbb489d7916054b2c651 Mon Sep 17 00:00:00 2001
From: Richa Gadgil
Date: Sat, 14 Dec 2024 08:24:26 -0800
Subject: [PATCH 05/11] Add inference percentile details to driver perf (#3504)

Use the --r or --runtimes flag with perf to export times to JSON.

Updated perf output:

Batch size: 1
Rate: 898.852 inferences/sec
Total time: 1.11253ms (Min: 1.095ms, Max: 3.7384ms, Mean: 1.14045ms, Median: 1.11241ms)
Percentiles (90%, 95%, 99%): (1.12391ms, 1.13541ms, 1.22018ms)
Total instructions time: 2.40689ms
Overhead time: 0.0542455ms, -1.29436ms
Overhead: 5%, -116%
---
 src/program.cpp | 43 +++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 41 insertions(+), 2 deletions(-)

diff --git a/src/program.cpp b/src/program.cpp
index b39f3936371..2d43f3f8d55 100644
--- a/src/program.cpp
+++ b/src/program.cpp
@@ -44,6 +44,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -845,6 +846,31 @@ double common_average(const std::vector& v)
     return total / std::distance(v.begin() + n, v.end() - n);
 }
 
+double mean(const std::vector& v)
+{
+    double total = std::accumulate(v.begin(), v.end(), 0.0);
+    return total / v.size();
+}
+
+double median(const std::vector& v)
+{
+    size_t mid = v.size() / 2;
+    if(v.size() % 2 == 0)
+    {
+        return (v[mid - 1] + v[mid]) / 2.0;
+    }
+    else
+    {
+        return v[mid];
+    }
+}
+
+double percentile(const std::vector& v, double percentile)
+{
+    size_t index = (percentile * (v.size() - 1));
+    return v[index];
+}
+
 std::string perf_group(instruction_ref ins, bool detailed)
 {
     std::string result;
@@ -925,8 +951,14 @@ void program::perf_report(
     {
         overhead_vec.push_back(time([&] { dry_run(params); }));
     }
     double total_time = common_average(total_vec);
+    double min_time = total_vec.front();
+    double max_time = total_vec.back();
+    double mean_time = mean(total_vec);
+    double median_time = median(total_vec);
+    double percentile_90_time = percentile(total_vec, 0.90);
+    double percentile_95_time = percentile(total_vec, 0.95);
+    double percentile_99_time = percentile(total_vec, 0.99);
     double rate = 1000.0 / total_time;
     double overhead_time = common_average(overhead_vec);
     double overhead_percent = overhead_time * 100.0 / total_time;
@@ -978,7 +1010,14 @@ void program::perf_report(
 
     os << "Batch size: " << batch << std::endl;
     os << "Rate: " << rate * batch << " inferences/sec" << std::endl;
-    os << "Total time: " << total_time << "ms" << std::endl;
+    os << "Total time: " << total_time << "ms ";
+    os << "(Min: " << min_time << "ms, ";
+    os << "Max: " << max_time << "ms, ";
+    os << "Mean: " << mean_time << "ms, ";
+    os << "Median: " << median_time << "ms)" << std::endl;
+    os << "Percentiles (90%, 95%, 99%): (";
+    os << percentile_90_time << "ms, " << percentile_95_time << "ms, " << percentile_99_time
+       << "ms)" << std::endl;
     os << "Total instructions time: " << total_instruction_time << "ms" << std::endl;
     os << "Overhead time: " << overhead_time << "ms"
        << ", " << calculate_overhead_time << "ms" << std::endl;
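Note: perf_report relies on total_vec already being sorted (front() as the minimum, back() as the maximum), and percentile() indexes without interpolation. The same arithmetic in a few lines of Python, on made-up times:

    def percentile(sorted_times, p):
        # element at floor(p * (n - 1)); no interpolation, matching the C++ above
        return sorted_times[int(p * (len(sorted_times) - 1))]

    times = sorted([1.095, 1.112, 1.124, 1.135, 1.220, 3.738])
    print(percentile(times, 0.90), percentile(times, 0.95), percentile(times, 0.99))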
From b0072d92145bed23525a69be2b29a1c8443fb125 Mon Sep 17 00:00:00 2001
From: Ahsan Saghir <142340507+ahsan-ca@users.noreply.github.com>
Date: Sat, 14 Dec 2024 11:25:02 -0500
Subject: [PATCH 06/11] Add changes for contiguous transpose gemm fusion for hipblaslt (#3706)

---
 src/targets/gpu/fuse_ops.cpp | 87 ++++++++++++++++---
 src/targets/gpu/hip_gemm_impl.cpp | 14 +++
 .../gpu/include/migraphx/gpu/hip_gemm.hpp | 8 +-
 3 files changed, 97 insertions(+), 12 deletions(-)

diff --git a/src/targets/gpu/fuse_ops.cpp b/src/targets/gpu/fuse_ops.cpp
index 70d77b5037c..53d3e3b563c 100644
--- a/src/targets/gpu/fuse_ops.cpp
+++ b/src/targets/gpu/fuse_ops.cpp
@@ -753,16 +753,8 @@ struct find_hipblas_gemm_pointwise : gemm_pointwise
 };
 #endif
 
-struct find_contiguous_tranpose_gemm
+struct contiguous_transpose_gemm
 {
-    auto matcher() const
-    {
-        return match::name("gpu::contiguous")(match::arg(0)(
-            match::name("transpose")(
-                match::arg(0)(match::name("gpu::gemm")(match::used_once()).bind("gemm")))
-                .bind("transpose")));
-    }
-
     template
     static bool is_swapped(const Vector& perm, std::size_t i, std::size_t j)
     {
@@ -773,6 +765,17 @@ struct find_contiguous_tranpose_gemm
         std::swap(perm2[i], perm2[j]);
         return perm2 == perm;
     }
+};
+
+struct find_contiguous_transpose_rocblas_gemm : contiguous_transpose_gemm
+{
+    auto matcher() const
+    {
+        return match::name("gpu::contiguous")(match::arg(0)(
+            match::name("transpose")(
+                match::arg(0)(match::name("gpu::gemm")(match::used_once()).bind("gemm")))
+                .bind("transpose")));
+    }
 
     void apply(module& m, const match::matcher_result& r) const
     {
@@ -811,6 +814,67 @@
     }
 };
 
+#if MIGRAPHX_USE_HIPBLASLT
+struct find_contiguous_transpose_hip_gemm : contiguous_transpose_gemm
+{
+    auto matcher() const
+    {
+        return match::name("gpu::contiguous")(match::arg(0)(
+            match::name("transpose")(
+                match::arg(0)(
+                    match::name("gpu::hipblaslt_op")(match::used_once()).bind("hip_gemm")))
+                .bind("transpose")));
+    }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins      = r.result;
+        auto gemm_ins = r.instructions["hip_gemm"];
+        auto gemm_op  = any_cast(gemm_ins->get_operator()).op;
+
+        if(gemm_op.name() != "gpu::hip_gemm")
+            return;
+
+        auto gemm = any_cast>(gemm_op);
+
+        auto alloc     = gemm_ins->inputs().back();
+        auto transpose = r.instructions["transpose"];
+        auto perm      = transpose->get_operator().to_value()["permutation"].to_vector();
+        auto iperm     = invert_permutation(perm);
+
+        if(perm.size() < 3)
+            return;
+
+        if(not is_swapped(perm, perm.size() - 3, perm.size() - 2))
+            return;
+
+        auto lens = gemm_ins->get_shape().lens();
+        if(lens.size() > 3 and
+           not std::all_of(lens.begin(), lens.end() - 3, [](auto i) { return i == 1; }))
+            return;
+
+        gemm.trans_batch = 1;
+
+        auto s = shape{alloc->get_shape().type(), reorder_dims(alloc->get_shape().lens(), iperm)};
+        auto new_alloc =
+            m.insert_instruction(gemm_ins, make_op("allocate", {{"shape", to_value(s)}}));
+
+        auto alloc_transpose = m.insert_instruction(
+            gemm_ins, make_op("transpose", {{"permutation", perm}}), new_alloc);
+
+        auto inputs   = gemm_ins->inputs();
+        inputs.back() = alloc_transpose;
+        operation new_gemm_op = gemm;
+        auto new_gemm = m.insert_instruction(
+            gemm_ins, make_op("gpu::hipblaslt_op", {{"op", to_value(new_gemm_op)}}), inputs);
+
+        auto gemm_transpose = m.insert_instruction(gemm_ins, transpose->get_operator(), new_gemm);
+
+        m.replace_instruction(ins, gemm_transpose);
+    }
+};
+#endif
+
 struct find_commutative_broadcast
 {
     auto matcher() const
@@ -980,7 +1044,10 @@ void fuse_ops::apply(module& m) const
     match::find_matches(m,
                         find_layernorm_pointwise{},
                         find_concat_pointwise{},
-                        find_contiguous_tranpose_gemm{},
+                        find_contiguous_transpose_rocblas_gemm{},
+#if MIGRAPHX_USE_HIPBLASLT
+                        find_contiguous_transpose_hip_gemm{},
+#endif
                         find_commutative_broadcast{});
     match::find_matches(m, find_contiguous{});
 }

diff --git a/src/targets/gpu/hip_gemm_impl.cpp b/src/targets/gpu/hip_gemm_impl.cpp
index e639e30706e..966927da7b5 100644
--- a/src/targets/gpu/hip_gemm_impl.cpp
+++ b/src/targets/gpu/hip_gemm_impl.cpp
@@ -31,6 +31,7 @@
 #include
 #include
 #include
+#include
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -111,6 +112,19 @@ void blas_shape_hip(const shape& in_shape)
         MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible");
 }
 
+shape transpose_batch_hip(const shape& s, unsigned trans_batch)
+{
+    if(trans_batch == 0)
+        return s;
+    if(s.lens().size() < 3)
+        return s;
+    auto batch = s.lens().size() - 3;
+    std::vector perm(s.lens().size());
+    std::iota(perm.begin(), perm.end(), 0);
+    std::swap(perm[batch], perm[batch + trans_batch]);
+    return shape::from_permutation(s.type(), s.lens(), perm);
+}
+
 static bool is_transposed_hip(const shape& s) { return s.transposed() and s.strides().back() != 1; }
 
 static int32_t get_batch_stride_hip(const shape& s)

diff --git a/src/targets/gpu/include/migraphx/gpu/hip_gemm.hpp b/src/targets/gpu/include/migraphx/gpu/hip_gemm.hpp
index 9f74bc02813..8c3d67bcd93 100644
--- a/src/targets/gpu/include/migraphx/gpu/hip_gemm.hpp
+++ b/src/targets/gpu/include/migraphx/gpu/hip_gemm.hpp
@@ -41,6 +41,7 @@ namespace gpu {
 struct context;
 
 void blas_shape_hip(const shape& s);
+shape transpose_batch_hip(const shape& s, unsigned trans_batch);
 
 template
 struct hip_gemm
@@ -48,13 +49,16 @@ struct hip_gemm
     Op op;
     float alpha = 1;
     float beta = 0;
+    unsigned trans_batch = 0;
     int32_t solution_idx = 0;
+
     template
     static auto reflect(Self& self, F f)
     {
         return pack_join(migraphx::reflect(self.op, f),
                          pack(f(self.alpha, "alpha"),
                               f(self.beta, "beta"),
+                              f(self.trans_batch, "trans_batch"),
                               f(self.solution_idx, "solution_idx")));
     }
@@ -98,10 +102,10 @@ struct hip_gemm
                               to_string(cmat_shape.type()) +
                               ", it must be: " + to_string(op_out_shape.type()));
             }
-            return op_out_shape;
+            return transpose_batch_hip(op_out_shape, trans_batch);
         }
-        return op.compute_shape(in_shapes);
+        return transpose_batch_hip(op.compute_shape(in_shapes), trans_batch);
     }
 
     argument
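Note: the fusion above removes the transpose + contiguous pair by making the GEMM write straight into a buffer that is already laid out for the transposed result. A NumPy sketch of the same idea (illustration only, not MIGraphX code):

    import numpy as np

    a = np.random.rand(3, 4, 5)
    b = np.random.rand(3, 5, 6)

    # Unfused: gemm, then transpose, then a copy to make it contiguous.
    unfused = np.ascontiguousarray(np.matmul(a, b).transpose(1, 0, 2))

    # Fused: allocate with the swapped dims and let the gemm write through
    # a transposed (strided) view, so no separate transpose copy is needed.
    buf = np.empty((4, 3, 6))
    np.matmul(a, b, out=buf.transpose(1, 0, 2))
    assert np.allclose(buf, unfused)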
From e90d1cf3b45fadba1e6ebe6e7581ab97e6c7ddc6 Mon Sep 17 00:00:00 2001
From: Ahsan Saghir <142340507+ahsan-ca@users.noreply.github.com>
Date: Sat, 14 Dec 2024 11:25:45 -0500
Subject: [PATCH 07/11] Limit hip gemm pointwise fusion to dot (#3702)

This PR makes changes to ensure that hip gemm pointwise fusion is
performed for dot, and not for quant_dot.
---
 src/targets/gpu/fuse_ops.cpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/targets/gpu/fuse_ops.cpp b/src/targets/gpu/fuse_ops.cpp
index 53d3e3b563c..5e93ccf5ecf 100644
--- a/src/targets/gpu/fuse_ops.cpp
+++ b/src/targets/gpu/fuse_ops.cpp
@@ -715,6 +715,9 @@ struct find_hipblas_gemm_pointwise : gemm_pointwise
 
         auto gemm_op = any_cast(gemm_ins->get_operator()).op;
 
+        if(gemm_op.name() != "gpu::hip_gemm")
+            return;
+
         auto gemm = any_cast>(gemm_op);
 
         // Already fused gemm
From 11e2c4774f275604a2cb48a4648409b0720920cc Mon Sep 17 00:00:00 2001
From: Richa Gadgil
Date: Sat, 14 Dec 2024 08:26:25 -0800
Subject: [PATCH 08/11] Bf16 gpu support (#3630)

---
 src/driver/main.cpp | 10 +++
 src/driver/precision.hpp | 1 +
 src/driver/verify.cpp | 13 +++-
 src/include/migraphx/quantization.hpp | 3 +
 src/py/migraphx_py.cpp | 4 +
 src/quantization.cpp | 10 +++
 .../migraphx/gpu/device/float_equal.hpp | 10 +++
 .../include/migraphx/gpu/device/types.hpp | 74 +++++++++++++++----
 .../include/migraphx/gpu/device/visit.hpp | 4 +
 src/targets/gpu/fuse_mlir.cpp | 8 +-
 src/targets/gpu/gemm_impl.cpp | 2 +-
 .../gpu/include/migraphx/gpu/miopen.hpp | 2 +
 .../include/migraphx/kernels/types.hpp | 1 +
 src/targets/gpu/lowering.cpp | 3 +-
 src/targets/gpu/mlir.cpp | 24 +++---
 src/targets/gpu/target.cpp | 1 +
 16 files changed, 138 insertions(+), 32 deletions(-)

diff --git a/src/driver/main.cpp b/src/driver/main.cpp
index 04fa0cfe3bc..aba29975e4b 100644
--- a/src/driver/main.cpp
+++ b/src/driver/main.cpp
@@ -482,6 +482,7 @@ struct compiler
     compiler_target ct;
     compile_options co;
     bool to_fp16 = false;
+    bool to_bf16 = false;
    bool to_fp8 = false;
     bool to_int8 = false;
     bool to_int4 = false;
@@ -506,6 +507,7 @@ struct compiler
            ap.help("Exhastively search for best tuning parameters for kernels"),
            ap.set_value(true));
         ap(to_fp16, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(true));
+        ap(to_bf16, {"--bf16"}, ap.help("Quantize for bf16"), ap.set_value(true));
         ap(to_int8, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(true));
         ap(to_fp8, {"--fp8"}, ap.help("Quantize for fp8"), ap.set_value(true));
         ap(to_int4, {"--int4-weights"}, ap.help("Quantize weights for int4"), ap.set_value(true));
@@ -555,6 +557,10 @@ struct compiler
         {
             quantize_fp16(p);
         }
+        if(to_bf16)
+        {
+            quantize_bf16(p);
+        }
         if(to_int8)
         {
             quantize_int8(p, t, {host_params(p)});
@@ -639,6 +645,10 @@ struct verify : command
         {
             vo.quantize = precision::fp16;
         }
+        if(c.to_bf16)
+        {
+            vo.quantize = precision::bf16;
+        }
         if(c.to_int8)
         {
             vo.quantize = precision::int8;

diff --git a/src/driver/precision.hpp b/src/driver/precision.hpp
index d7d7cecf00e..9ed1f402f9d 100644
--- a/src/driver/precision.hpp
+++ b/src/driver/precision.hpp
@@ -32,6 +32,7 @@ enum class precision
 {
     fp32,
     fp16,
+    bf16,
     int8
 };

diff --git a/src/driver/verify.cpp b/src/driver/verify.cpp
index 92bae3eee86..14f9e71f70f 100644
--- a/src/driver/verify.cpp
+++ b/src/driver/verify.cpp
@@ -50,11 +50,14 @@ verify::tolerance get_tolerances(const program& p,
                                  std::optional atol,
                                  std::optional rtol)
 {
-    bool has_fp16 = any_of(p.get_modules(), [](auto&& m) {
-        return any_of(*m, [](auto&& ins) { return (ins.get_shape().type() == shape::half_type); });
+    bool has_16bit = any_of(p.get_modules(), [](auto&& m) {
+        return any_of(*m, [](auto&& ins) {
+            return (ins.get_shape().type() == shape::half_type or
+                    ins.get_shape().type() == shape::bf16_type);
+        });
     });
     migraphx::verify::tolerance result{};
-    if(has_fp16 or vo.quantize == precision::fp16)
+    if(has_16bit or vo.quantize == precision::fp16 or vo.quantize == precision::bf16)
     {
         result.rms_tol = 8e-2;
         result.atol = 4e-2;
@@ -100,6 +103,10 @@ std::vector run_target(program p,
     {
         quantize_fp16(p);
     }
+    if(vo.quantize == precision::bf16)
+    {
+        quantize_bf16(p);
+    }
     p.compile(t, options);
 
     parameter_map m;

diff --git a/src/include/migraphx/quantization.hpp b/src/include/migraphx/quantization.hpp
index d849023b6cf..eead5e40ba1 100644
--- a/src/include/migraphx/quantization.hpp
+++ b/src/include/migraphx/quantization.hpp
@@ -51,6 +51,9 @@ quantize_fp8(program& prog, const target& t, const std::vector& c
 MIGRAPHX_EXPORT void quantize_int4_weights(program& prog);
 
+MIGRAPHX_EXPORT void quantize_bf16(program& prog,
+                                   const std::vector& ins_names = {"all"});
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx

diff --git a/src/py/migraphx_py.cpp b/src/py/migraphx_py.cpp
index 925c4546c1f..75f7fab09d9 100644
--- a/src/py/migraphx_py.cpp
+++ b/src/py/migraphx_py.cpp
@@ -664,6 +664,10 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
           },
           "Auto-convert FP8 parameters and return values to Float for MIGraphX Program",
           py::arg("prog"));
+    m.def("quantize_bf16",
+          &migraphx::quantize_bf16,
+          py::arg("prog"),
+          py::arg("ins_names") = std::vector{"all"});
 
 #ifdef HAVE_GPU
     m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);

diff --git a/src/quantization.cpp b/src/quantization.cpp
index 7e02ae66685..276012bbf73 100644
--- a/src/quantization.cpp
+++ b/src/quantization.cpp
@@ -74,6 +74,16 @@ void quantize_fp16(program& prog, const std::vector& ins_names)
                quant_tracer());
 }
 
+void quantize_bf16(program& prog, const std::vector& ins_names)
+{
+    run_passes(prog,
+               {normalize_ops{},
+                optimize_module{{"quantizelinear", "dequantizelinear"}},
+                truncate_float_pass{ins_names, shape::bf16_type},
+                optimize_module{{"quantizelinear", "dequantizelinear"}}},
+               quant_tracer());
+}
+
 void quantize_8bits(program& prog,
                     const target& t,
                     shape::type_t precision,

diff --git a/src/targets/gpu/device/include/migraphx/gpu/device/float_equal.hpp b/src/targets/gpu/device/include/migraphx/gpu/device/float_equal.hpp
index 9fb6f858d18..a5f18fc5aa8 100644
--- a/src/targets/gpu/device/include/migraphx/gpu/device/float_equal.hpp
+++ b/src/targets/gpu/device/include/migraphx/gpu/device/float_equal.hpp
@@ -44,6 +44,16 @@ __device__ bool float_equal_device(T x, T y)
            std::nextafter(x, std::numeric_limits::max()) >= y;
 }
 
+template <>
+__device__ bool float_equal_device(__bf16 x, __bf16 y) // NOLINT(misc-definitions-in-headers)
+{
+    float xf = x;
+    float yf = y;
+    return std::isfinite(xf) and std::isfinite(yf) and
+           std::nextafter(xf, std::numeric_limits::lowest()) <= yf and
+           std::nextafter(xf, std::numeric_limits::max()) >= yf;
+}
+
 template {})>
 __device__ bool float_equal_device(T x, T y)
 {
diff --git a/src/targets/gpu/device/include/migraphx/gpu/device/types.hpp b/src/targets/gpu/device/include/migraphx/gpu/device/types.hpp
index 19fb02763fb..c9f2e3d7cd4 100644
--- a/src/targets/gpu/device/include/migraphx/gpu/device/types.hpp
+++ b/src/targets/gpu/device/include/migraphx/gpu/device/types.hpp
@@ -27,6 +27,7 @@
 
 #include
 #include
+#include
 #include
 #include
 
@@ -67,6 +68,7 @@ auto pack_vec(Ts... xs)
 }
 
 using gpu_half = __fp16;
+using gpu_bf16 = __bf16;
 
 namespace detail {
 template
@@ -87,6 +89,12 @@ struct device_type
     using type = gpu_half;
 };
 
+template <>
+struct device_type
+{
+    using type = gpu_bf16;
+};
+
 template
 struct host_type
 {
@@ -99,6 +107,12 @@ struct host_type
     using type = half;
 };
 
+template <>
+struct host_type
+{
+    using type = bf16;
+};
+
 } // namespace detail
 
 template
@@ -143,23 +157,53 @@ __device__ __host__ T to_hip_type(T x)
     return x;
 }
 
-// Hip doens't support __fp16
+// Hip doesn't support __fp16 and __bf16
 inline __device__ __host__ float to_hip_type(gpu_half x) { return x; }
+inline __device__ __host__ float to_hip_type(gpu_bf16 x) { return x; }
+
+template
+struct is_floating_point : std::is_floating_point
+{
+};
+
+template <>
+struct is_floating_point<__fp16> : std::true_type
+{
+};
+
+template
+struct is_signed : std::is_signed
+{
+};
+
+template <>
+struct is_signed<__fp16> : std::true_type
+{
+};
+
+template
+struct is_arithmetic : std::is_arithmetic
+{
+};
+
+template <>
+struct is_arithmetic<__fp16> : std::true_type
+{
+};
 
-#define MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
-    template \
-    struct trait : std::trait \
-    { \
-    }; \
-    \
-    template <> \
-    struct trait : std::true_type \
-    { \
-    };
-
-MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, __fp16)
-MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_signed, __fp16)
-MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, __fp16)
+// Redo for __bf16
+template <>
+struct is_floating_point<__bf16> : std::true_type
+{
+};
+template <>
+struct is_signed<__bf16> : std::true_type
+{
+};
+template <>
+struct is_arithmetic<__bf16> : std::true_type
+{
+};
 
 } // namespace device
 } // namespace gpu

diff --git a/src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp b/src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
index 18981399364..78f28a552bd 100644
--- a/src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
+++ b/src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
@@ -98,6 +98,10 @@ template <>
 struct is_hip_type : std::true_type
 {
 };
+template <>
+struct is_hip_type : std::true_type
+{
+};
 
 template {})>
 void hip_visitor_invoke(T as, V&& v)

diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp
index 1cf97aa66b4..65a27a76ad1 100644
--- a/src/targets/gpu/fuse_mlir.cpp
+++ b/src/targets/gpu/fuse_mlir.cpp
@@ -390,6 +390,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
     const auto& name = i.name();
     const auto result_type = i.get_shape().type();
     const std::initializer_list allowed_types = {type_t::float_type,
+                                                 type_t::bf16_type,
                                                  type_t::half_type,
                                                  type_t::fp8e4m3fnuz_type,
                                                  type_t::fp8e5m2fnuz_type,
@@ -439,6 +440,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
     };
     std::set float_types = {type_t::float_type,
                             type_t::half_type,
+                            type_t::bf16_type,
                             type_t::fp8e4m3fnuz_type,
                             type_t::fp8e5m2fnuz_type,
                             type_t::fp8e4m3fn_type,
@@ -459,7 +461,8 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
             return false;
         } // else
         return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
-            return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type());
+            return contains({type_t::float_type, type_t::half_type, type_t::bf16_type},
+                            arg->get_shape().type());
         });
     }
     return false;
@@ -472,10 +475,12 @@ bool is_reduce_op_supported_by_mlir(const instruction& i)
     const auto result_type = i.get_shape().type();
     const std::initializer_list allowed_types = {type_t::float_type,
                                                  type_t::half_type,
                                                  type_t::bf16_type,
                                                  type_t::fp8e4m3fnuz_type,
                                                  type_t::fp8e5m2fnuz_type,
                                                  type_t::fp8e4m3fn_type,
                                                  type_t::fp8e5m2_type};
+
     // Preliminary type check.
     if(not contains(allowed_types, result_type))
     {
@@ -732,6 +737,7 @@ struct find_mlir_standalone_op
         if(std::any_of(gemm_based_op->inputs().begin(), gemm_based_op->inputs().end(), [&](auto i) {
                return not contains({shape::type_t::float_type,
                                     shape::type_t::half_type,
+                                    shape::type_t::bf16_type,
                                     shape::type_t::int8_type,
                                     shape::type_t::fp8e4m3fnuz_type,
                                     shape::type_t::fp8e5m2fnuz_type,

diff --git a/src/targets/gpu/gemm_impl.cpp b/src/targets/gpu/gemm_impl.cpp
index 6d21a29e2a0..d0f750a2501 100644
--- a/src/targets/gpu/gemm_impl.cpp
+++ b/src/targets/gpu/gemm_impl.cpp
@@ -224,7 +224,7 @@ struct gemm_impl
         compute_type = rb_compute_type{output_type};
         if(compute_fp32)
         {
-            if(arg_type == rocblas_datatype_f16_r)
+            if(arg_type == rocblas_datatype_f16_r or arg_type == rocblas_datatype_bf16_r)
                 compute_type = rocblas_datatype_f32_r;
         }
         if(arg_type == rocblas_datatype_f8_r)

diff --git a/src/targets/gpu/include/migraphx/gpu/miopen.hpp b/src/targets/gpu/include/migraphx/gpu/miopen.hpp
index fb61103538d..87a561ad6f4 100644
--- a/src/targets/gpu/include/migraphx/gpu/miopen.hpp
+++ b/src/targets/gpu/include/migraphx/gpu/miopen.hpp
@@ -143,6 +143,8 @@ inline tensor_descriptor make_tensor(const migraphx::shape& os)
         d = miopenInt32;
     else if(s.type() == shape::int8_type)
         d = miopenInt8;
+    else if(s.type() == shape::bf16_type)
+        d = miopenBFloat16;
     else
         MIGRAPHX_THROW("MAKE_TENSOR: unsupported type");
     miopenSetTensorDescriptor(t.get(), d, s.lens().size(), lens.data(), strides.data());

diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
index f65cdfbba34..c88343ce16d 100644
--- a/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
+++ b/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
@@ -76,6 +76,7 @@ using vec = T __attribute__((ext_vector_type(N)));
 
 using half = _Float16;
 using half2 = migraphx::vec;
+using bf16 = __bf16;
 
 } // namespace migraphx

diff --git a/src/targets/gpu/lowering.cpp b/src/targets/gpu/lowering.cpp
index bc2e887fa87..adba5466135 100644
--- a/src/targets/gpu/lowering.cpp
+++ b/src/targets/gpu/lowering.cpp
@@ -325,7 +325,8 @@ struct miopen_apply
 
     static bool use_miopen_pooling(instruction_ref ins)
     {
-        if(enabled(MIGRAPHX_DISABLE_MIOPEN_POOLING{}))
+        if(enabled(MIGRAPHX_DISABLE_MIOPEN_POOLING{}) or
+           not contains({shape::float_type, shape::half_type}, ins->get_shape().type()))
             return false;
         auto&& op = ins->get_operator();
         auto op_val = op.to_value();

diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp
index 154fa762c26..61e0325ac96 100644
--- a/src/targets/gpu/mlir.cpp
+++ b/src/targets/gpu/mlir.cpp
@@ -312,6 +312,8 @@ struct mlir_program
             result = mlirF32TypeGet(ctx.get());
         else if(as.type_enum() == shape::half_type)
             result = mlirF16TypeGet(ctx.get());
+        else if(as.type_enum() == shape::bf16_type)
+            result = mlirBF16TypeGet(ctx.get());
         else if(as.type_enum() == shape::fp8e4m3fnuz_type)
             result = mlirFloat8E4M3FNUZTypeGet(ctx.get());
        else if(as.type_enum() == shape::fp8e5m2fnuz_type)
@@ -444,15 +446,15 @@ struct mlir_program
     }
 
     using attribute_t = std::variant,
-                                     std::uint64_t,
-                                     unsigned char,
-                                     bool,
-                                     double,
-                                     std::string,
-                                     value,
-                                     std::vector,
-                                     MlirType,
-                                     MlirAttribute>;
+                                      std::uint64_t,
+                                      unsigned char,
+                                      bool,
+                                      double,
+                                      std::string,
+                                      value,
+                                      std::vector,
+                                      MlirType,
+                                      MlirAttribute>;
     using named_attribute_t = std::pair;
 
     MlirNamedAttribute name_attribute(const named_attribute_t& na) const
@@ -1155,7 +1157,7 @@ mlir_code_object compile_mlir(const context& migraphx_ctx,
         const std::lock_guard lock(mutex);
         std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
     }
-    auto co = mp.compile(solution);
+    auto co            = mp.compile(solution);
 
     co.expected_inputs = in_shapes;
     auto out_shapes = m.get_output_shapes();
@@ -1248,7 +1250,7 @@ void dump_mlir_to_mxr(module m,
         sizes.insert(sizes.end(), ins->inputs().begin(), ins->inputs().end());
     }
     auto name = compute_dump_name(m, ".mxr");
-    auto f = location / name;
+    auto f    = location / name;
     std::cout << "Dumping MXR file to: " << f << std::endl;
     save(program{std::move(m)}, f.string());
 }

diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp
index 9320ed86f9f..ad98fb680fe 100644
--- a/src/targets/gpu/target.cpp
+++ b/src/targets/gpu/target.cpp
@@ -100,6 +100,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti
     unsupported_types.erase(shape::type_t::uint8_type);
     unsupported_types.erase(shape::type_t::int32_type);
     unsupported_types.erase(shape::type_t::tuple_type);
+    unsupported_types.erase(shape::type_t::bf16_type);
 
     // whiltelist supported Ops for the FP8 types
     // different between fp8e4m3fnuz and OCP types because rocBLAS only has
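Note: besides the --bf16 driver flag, this patch exposes the bf16 quantization entry point to Python. A short end-to-end sketch using the quantize_bf16 binding added above (the model path is hypothetical):

    import migraphx

    p = migraphx.parse_onnx("model.onnx")  # hypothetical model
    migraphx.quantize_bf16(p)              # binding added in this patch
    p.compile(migraphx.get_target("gpu"))

    # Driver equivalent: migraphx-driver perf model.onnx --bf16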
From 9247ae0b7650af881788a112ff1758b8f67445bf Mon Sep 17 00:00:00 2001
From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com>
Date: Sat, 14 Dec 2024 11:29:26 -0500
Subject: [PATCH 09/11] Update onnxruntime main 62e7e24f172a062242acae11575f7ea11529dd09 (#3711)

---
 test/onnx/.onnxrt-commit | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/onnx/.onnxrt-commit b/test/onnx/.onnxrt-commit
index 5dc54bfc42d..75454c00440 100644
--- a/test/onnx/.onnxrt-commit
+++ b/test/onnx/.onnxrt-commit
@@ -1 +1 @@
-d27fecd3d3837864a268bc96f00f2b8dce294697
+62e7e24f172a062242acae11575f7ea11529dd09
From 429396091411b15bebd471712f79b963a67420d4 Mon Sep 17 00:00:00 2001
From: Chris Austen
Date: Sat, 14 Dec 2024 11:30:27 -0500
Subject: [PATCH 10/11] Bump Dockers to 6.3 (#3687)

---
 Dockerfile | 2 +-
 Jenkinsfile | 2 +-
 hip-clang.docker | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/Dockerfile b/Dockerfile
index 53cf679bba7..5051b7b5cd8 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -10,7 +10,7 @@ RUN apt-get update && apt-get install -y software-properties-common gnupg2 --no-
     curl -sL http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -
 
 # Add rocm repository
-RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/6.2/ jammy main > /etc/apt/sources.list.d/rocm.list'
+RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/6.3/ jammy main > /etc/apt/sources.list.d/rocm.list'
 
 # From docs.amd.com for installing rocm. Needed to install properly
 RUN sh -c "echo 'Package: *\nPin: release o=repo.radeon.com\nPin-priority: 600' > /etc/apt/preferences.d/rocm-pin-600"

diff --git a/Jenkinsfile b/Jenkinsfile
index 2c47e568614..0ce0026d454 100755
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -42,7 +42,7 @@ def rocmtestnode(Map conf) {
             rm -rf build
             mkdir build
             cd build
-            cmake -DCTEST_TIMEOUT=3600 -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DBUILD_DEV=On -DCMAKE_EXECUTE_PROCESS_COMMAND_ECHO=STDOUT -DMIGRAPHX_DISABLE_VIRTUAL_ENV=ON ${flags} ..
+            cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DBUILD_DEV=On -DCMAKE_EXECUTE_PROCESS_COMMAND_ECHO=STDOUT -DMIGRAPHX_DISABLE_VIRTUAL_ENV=ON ${flags} ..
             git diff
            git diff-index --quiet HEAD || (echo "Git repo is not clean after running cmake." && exit 1)
             make -j\$(nproc) generate VERBOSE=1
diff --git a/hip-clang.docker b/hip-clang.docker
index 8e3f9a9af28..6a2d57243c1 100755
--- a/hip-clang.docker
+++ b/hip-clang.docker
@@ -6,7 +6,7 @@ ARG PREFIX=/usr/local
 RUN dpkg --add-architecture i386
 
 # Add rocm repository
-RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/6.2/ focal main > /etc/apt/sources.list.d/rocm.list'
+RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/6.3/ jammy main > /etc/apt/sources.list.d/rocm.list'
 
 # From docs.amd.com for installing rocm. Needed to install properly
 RUN sh -c "echo 'Package: *\nPin: release o=repo.radeon.com\nPin-priority: 600' > /etc/apt/preferences.d/rocm-pin-600"

From f9e276bcd700d440fbfd1370b57aa685c32b1afd Mon Sep 17 00:00:00 2001
From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com>
Date: Mon, 16 Dec 2024 11:55:22 -0500
Subject: [PATCH 11/11] Update rocMLIR main 13065c4b3a216e1b13dfb8f746b8a0d421f124e8 (#3716)

---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index bb466bc6e98..4daf866ed52 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -28,4 +28,4 @@ pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build
 msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
 sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On
 ROCm/composable_kernel@b7775add2d28251674d81e220cd4a857b90b997a -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
-ROCm/rocMLIR@e61b0f0e516f09144445b3c8eb372f39eb82d53b -DBUILD_FAT_LIBROCKCOMPILER=On
\ No newline at end of file
+ROCm/rocMLIR@13065c4b3a216e1b13dfb8f746b8a0d421f124e8 -DBUILD_FAT_LIBROCKCOMPILER=On
\ No newline at end of file