From 05c9d1ba5467e695d0b52839f5e95b8240fdb48a Mon Sep 17 00:00:00 2001 From: Chris Austen Date: Sat, 20 Jul 2024 00:28:31 -0400 Subject: [PATCH] remove splitk support (#3286) --- CHANGELOG.md | 2 - Jenkinsfile | 2 +- docs/dev/env_vars.rst | 5 +++ src/targets/gpu/mlir.cpp | 8 +++- test/gpu/mlir.cpp | 86 +++++++++++++++++++++++++++------------- 5 files changed, 71 insertions(+), 32 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 91843ae6216..98f4f775feb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,6 @@ Full documentation for MIGraphX is available at * Added a `--test` flag in migraphx-driver to validate the installation * Added support for ONNX Operator: Einsum * Added uint8 support in ONNX Operators -* Enabled Split-k kernel configurations for performance improvements * Added fusion for group convolutions * Added rocMLIR conv3d support * Added rocgdb to the Dockerfile @@ -46,7 +45,6 @@ Full documentation for MIGraphX is available at * Added support for multi outputs in pointwise ops * Improve reduction fusion with reshape operators * Use the quantized output when an operator is used again -* Enabled Split-k GEMM perf configs for rocMLIR based GEMM kernels for better performance on all Hardware ### Fixes diff --git a/Jenkinsfile b/Jenkinsfile index 7cb184f51d1..eb247971ad9 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -144,7 +144,7 @@ rocmtest clang_debug: rocmnode('mi100+') { cmake_build -> } }, mlir_debug: rocmnode('mi100+') { cmake_build -> stage('MLIR Debug') { - withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1', 'MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,attention,convolution,dot', 'MIGRAPHX_ENABLE_MLIR_INPUT_FUSION=1']) { + withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1', 'MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,attention,convolution,dot', 'MIGRAPHX_ENABLE_MLIR_INPUT_FUSION=1', 'MIGRAPHX_MLIR_ENABLE_SPLITK=1']) { def sanitizers = "undefined" // Note: the -fno-sanitize= is copied from upstream LLVM_UBSAN_FLAGS. def debug_flags_cxx = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize=vptr,function -fno-sanitize-recover=${sanitizers}" diff --git a/docs/dev/env_vars.rst b/docs/dev/env_vars.rst index ad65dc83c6a..b739a82072c 100644 --- a/docs/dev/env_vars.rst +++ b/docs/dev/env_vars.rst @@ -278,6 +278,11 @@ Limits the number of solutions available to MLIR for tuning. Set to "1", "enable", "enabled", "yes", or "true" to use. Enable input fusions in MLIR. +.. envvar:: MIGRAPHX_MLIR_ENABLE_SPLITK + +Set to "1", "enable", "enabled", "yes", or "true" to use. +Enable Split-k perf configs when tuning with MLIR. + CK vars ----------- diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index b08c02d074f..94badfe5bbd 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -78,6 +78,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNE_LIMIT); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_DB); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_CFG); +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_ENABLE_SPLITK); #ifdef MIGRAPHX_MLIR template // NOLINT @@ -595,8 +596,11 @@ struct mlir_program {"sym_name", sym_name}, {"kernel", std::string("mixr")}, {"arch", target_arch}, - {"num_cu", num_cu}, - {"enable_splitk_for_tuning", mlirUnitAttrGet(ctx.get())}}); + {"num_cu", num_cu}}); + if(enabled(MIGRAPHX_MLIR_ENABLE_SPLITK{})) + { + ops.add_attributes({{"enable_splitk_for_tuning", mlirUnitAttrGet(ctx.get())}}); + } ops.add_region(std::move(region)); insert(body, std::move(ops)); diff --git a/test/gpu/mlir.cpp b/test/gpu/mlir.cpp index d335e219594..7d41148876f 100644 --- a/test/gpu/mlir.cpp +++ b/test/gpu/mlir.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -37,6 +38,8 @@ #include #include +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_ENABLE_SPLITK); + struct mlir_gpu_target : migraphx::gpu::target { std::string name() const { return "mlir"; } @@ -154,11 +157,20 @@ bool verify_mlir(const migraphx::module& mmlir) "mlir", run_gpu(mlir, inputs), migraphx::verify::expected{run_ref(ref, inputs)}); } +std::string get_attrs() +{ + if(migraphx::enabled(MIGRAPHX_MLIR_ENABLE_SPLITK{})) + { + return R"({arch = "", enable_splitk_for_tuning, kernel = "mixr", num_cu = 0 : i64})"; + } + return R"({arch = "", kernel = "mixr", num_cu = 0 : i64})"; +} + TEST_CASE(conv) { - const std::string mlir_output = R"__migraphx__( + std::string mlir_output = R"__migraphx__( module { - func.func @mlir_convolution(%arg0: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x2x2x2xf32, 8x4x2x1> attributes {arch = "", enable_splitk_for_tuning, kernel = "mixr", num_cu = 0 : i64} { + func.func @mlir_convolution(%arg0: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x2x2x2xf32, 8x4x2x1> attributes ${attrs} { %0 = migraphx.convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x16x4x1>, <2x8x3x3xf32, 72x9x3x1> -> <1x2x2x2xf32, 8x4x2x1> return %0 : !migraphx.shaped<1x2x2x2xf32, 8x4x2x1> } @@ -173,15 +185,17 @@ module { // Skip test if MLIR is not enabled if(s.empty()) return; - CHECK(encode(s) == encode(mlir_output)); + auto mlir_output_with_attrs = + migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}}); + CHECK(encode(s) == encode(mlir_output_with_attrs)); EXPECT(verify_mlir(m)); } TEST_CASE(conv_nhwc) { - const std::string mlir_output = R"__migraphx__( + std::string mlir_output = R"__migraphx__( module { - func.func @mlir_convolution(%arg0: !migraphx.shaped<2x8x3x3xf32, 72x1x24x8>, %arg1: !migraphx.shaped<1x8x4x4xf32, 128x1x32x8>) -> !migraphx.shaped<1x2x2x2xf32, 8x1x4x2> attributes {arch = "", enable_splitk_for_tuning, kernel = "mixr", num_cu = 0 : i64} { + func.func @mlir_convolution(%arg0: !migraphx.shaped<2x8x3x3xf32, 72x1x24x8>, %arg1: !migraphx.shaped<1x8x4x4xf32, 128x1x32x8>) -> !migraphx.shaped<1x2x2x2xf32, 8x1x4x2> attributes ${attrs} { %0 = migraphx.convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x1x32x8>, <2x8x3x3xf32, 72x1x24x8> -> <1x2x2x2xf32, 8x1x4x2> return %0 : !migraphx.shaped<1x2x2x2xf32, 8x1x4x2> } @@ -196,15 +210,17 @@ module { // Skip test if MLIR is not enabled if(s.empty()) return; - CHECK(encode(s) == encode(mlir_output)); + auto mlir_output_with_attrs = + migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}}); + CHECK(encode(s) == encode(mlir_output_with_attrs)); EXPECT(verify_mlir(m)); } TEST_CASE(conv_add_relu) { - const std::string mlir_output = R"__migraphx__( + std::string mlir_output = R"__migraphx__( module { - func.func @mlir_convolution_add_relu(%arg0: !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>, %arg1: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg2: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x2x2x2xf32, 8x4x2x1> attributes {arch = "", enable_splitk_for_tuning, kernel = "mixr", num_cu = 0 : i64} { + func.func @mlir_convolution_add_relu(%arg0: !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>, %arg1: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg2: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x2x2x2xf32, 8x4x2x1> attributes ${attrs} { %0 = migraphx.convolution %arg2, %arg1 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x16x4x1>, <2x8x3x3xf32, 72x9x3x1> -> <1x2x2x2xf32, 8x4x2x1> %1 = migraphx.add %0, %arg0 : <1x2x2x2xf32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1> %2 = migraphx.relu %1 : <1x2x2x2xf32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1> @@ -224,16 +240,19 @@ module { // Skip test if MLIR is not enabled if(s.empty()) return; - CHECK(encode(s) == encode(mlir_output)); + auto mlir_output_with_attrs = + migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}}); + CHECK(encode(s) == encode(mlir_output_with_attrs)); + EXPECT(verify_mlir(m)); } // The following test checks that a dimension -1, within reshape operator is handled properly.. TEST_CASE(conv_reshape_dim_minus_one) { - const std::string mlir_output = R"__migraphx__( + std::string mlir_output = R"__migraphx__( module { - func.func @mlir_convolution_reshape(%arg0: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x4x1x2xf32, 8x2x2x1> attributes {arch = "", enable_splitk_for_tuning, kernel = "mixr", num_cu = 0 : i64} { + func.func @mlir_convolution_reshape(%arg0: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x4x1x2xf32, 8x2x2x1> attributes ${attrs} { %0 = migraphx.convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x16x4x1>, <2x8x3x3xf32, 72x9x3x1> -> <1x2x2x2xf32, 8x4x2x1> %1 = migraphx.reshape %0 {dims = [1, 4, 1, 2]} : <1x2x2x2xf32, 8x4x2x1> -> <1x4x1x2xf32, 8x2x2x1> return %1 : !migraphx.shaped<1x4x1x2xf32, 8x2x2x1> @@ -250,15 +269,17 @@ module { // Skip test if MLIR is not enabled if(s.empty()) return; - CHECK(encode(s) == encode(mlir_output)); + auto mlir_output_with_attrs = + migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}}); + CHECK(encode(s) == encode(mlir_output_with_attrs)); EXPECT(verify_mlir(m)); } TEST_CASE(quant_dot_add) { - const std::string mlir_output = R"__migraphx__( + std::string mlir_output = R"__migraphx__( module { - func.func @mlir_quant_dot_add(%arg0: !migraphx.shaped<1x5x4xi8, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xi8, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xi32, 15x3x1>) -> !migraphx.shaped<1x5x3xi32, 15x3x1> attributes {arch = "", enable_splitk_for_tuning, kernel = "mixr", num_cu = 0 : i64} { + func.func @mlir_quant_dot_add(%arg0: !migraphx.shaped<1x5x4xi8, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xi8, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xi32, 15x3x1>) -> !migraphx.shaped<1x5x3xi32, 15x3x1> attributes ${attrs} { %0 = migraphx.quant_dot %arg0, %arg1 : <1x5x4xi8, 20x4x1>, <1x4x3xi8, 12x3x1> -> <1x5x3xi32, 15x3x1> %1 = migraphx.add %0, %arg2 : <1x5x3xi32, 15x3x1>, <1x5x3xi32, 15x3x1> -> <1x5x3xi32, 15x3x1> return %1 : !migraphx.shaped<1x5x3xi32, 15x3x1> @@ -277,15 +298,17 @@ module { // Skip test if MLIR is not enabled if(s.empty()) return; - CHECK(encode(s) == encode(mlir_output)); + auto mlir_output_with_attrs = + migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}}); + CHECK(encode(s) == encode(mlir_output_with_attrs)); EXPECT(verify_mlir(m)); } TEST_CASE(dot_add) { - const std::string mlir_output = R"__migraphx__( + std::string mlir_output = R"__migraphx__( module { - func.func @mlir_dot_add(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes {arch = "", enable_splitk_for_tuning, kernel = "mixr", num_cu = 0 : i64} { + func.func @mlir_dot_add(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes ${attrs} { %0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1> %1 = migraphx.add %0, %arg2 : <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 15x3x1> return %1 : !migraphx.shaped<1x5x3xf32, 15x3x1> @@ -303,15 +326,17 @@ module { // Skip test if MLIR is not enabled if(s.empty()) return; - CHECK(encode(s) == encode(mlir_output)); + auto mlir_output_with_attrs = + migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}}); + CHECK(encode(s) == encode(mlir_output_with_attrs)); EXPECT(verify_mlir(m)); } TEST_CASE(conv_int8_dequantize_quantize) { - const std::string mlir_output = R"__migraphx__( + std::string mlir_output = R"__migraphx__( module { - func.func @mlir_quant_convolution_dequantizelinear_quantizelinear(%arg0: !migraphx.shaped<2x8x3x3xi8, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xi8, 128x16x4x1>, %arg2: !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>, %arg3: !migraphx.shaped<1x2x2x2xi32, 8x4x2x1>) -> !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> attributes {arch = "", enable_splitk_for_tuning, kernel = "mixr", num_cu = 0 : i64} { + func.func @mlir_quant_convolution_dequantizelinear_quantizelinear(%arg0: !migraphx.shaped<2x8x3x3xi8, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xi8, 128x16x4x1>, %arg2: !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>, %arg3: !migraphx.shaped<1x2x2x2xi32, 8x4x2x1>) -> !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> attributes ${attrs} { %0 = migraphx.quant_convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xi8, 128x16x4x1>, <2x8x3x3xi8, 72x9x3x1> -> <1x2x2x2xi32, 8x4x2x1> %1 = migraphx.dequantizelinear %0, %arg2, %arg3 : <1x2x2x2xi32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1>, !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1> %2 = migraphx.quantizelinear %1, %arg2, %arg3 : <1x2x2x2xf32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1>, !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> -> <1x2x2x2xi32, 8x4x2x1> @@ -336,15 +361,17 @@ module { // Skip test if MLIR is not enabled if(s.empty()) return; - CHECK(encode(s) == encode(mlir_output)); + auto mlir_output_with_attrs = + migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}}); + CHECK(encode(s) == encode(mlir_output_with_attrs)); EXPECT(verify_mlir(m)); } TEST_CASE(dot_convert) { - const std::string mlir_output = R"__migraphx__( + std::string mlir_output = R"__migraphx__( module { - func.func @mlir_dot_convert(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>) -> !migraphx.shaped<1x5x3xf16, 15x3x1> attributes {arch = "", enable_splitk_for_tuning, kernel = "mixr", num_cu = 0 : i64} { + func.func @mlir_dot_convert(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>) -> !migraphx.shaped<1x5x3xf16, 15x3x1> attributes ${attrs} { %0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1> %1 = migraphx.convert %0 {target_type = 1 : i64} : <1x5x3xf32, 15x3x1> to <1x5x3xf16, 15x3x1> return %1 : !migraphx.shaped<1x5x3xf16, 15x3x1> @@ -362,15 +389,17 @@ module { // Skip test if MLIR is not enabled if(s.empty()) return; - CHECK(encode(s) == encode(mlir_output)); + auto mlir_output_with_attrs = + migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}}); + CHECK(encode(s) == encode(mlir_output_with_attrs)); EXPECT(verify_mlir(m)); } TEST_CASE(dot_where) { - const std::string mlir_output = R"__migraphx__( + std::string mlir_output = R"__migraphx__( module { - func.func @mlir_dot_where(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xi8, 15x3x1>, %arg3: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes {arch = "", enable_splitk_for_tuning, kernel = "mixr", num_cu = 0 : i64} { + func.func @mlir_dot_where(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xi8, 15x3x1>, %arg3: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes ${attrs} { %0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1> %1 = migraphx.where %arg2, %0, %arg3 : <1x5x3xi8, 15x3x1>, <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 15x3x1> return %1 : !migraphx.shaped<1x5x3xf32, 15x3x1> @@ -389,7 +418,10 @@ module { // Skip test if MLIR is not enabled if(s.empty()) return; - CHECK(encode(s) == encode(mlir_output)); + auto mlir_output_with_attrs = + migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}}); + CHECK(encode(s) == encode(mlir_output_with_attrs)); + EXPECT(verify_mlir(m)); }