From 0075ab89e92a00e9032d086cedc26170ad4aa424 Mon Sep 17 00:00:00 2001 From: IgorMirosavljevicHTEC Date: Wed, 22 May 2024 20:48:51 +0200 Subject: [PATCH] testing real branch --- CMakeLists.txt | 14 +- Dockerfile | 9 +- Jenkinsfile | 14 +- hip-clang.docker | 2 +- requirements.txt | 2 +- src/CMakeLists.txt | 6 +- src/api/include/migraphx/migraphx.hpp | 22 +- src/argument.cpp | 20 +- src/cpp_generator.cpp | 18 +- src/dom_info.cpp | 6 +- src/driver/CMakeLists.txt | 6 +- src/driver/main.cpp | 32 +- src/driver/models.cpp | 45 +++ src/driver/models.hpp | 6 +- src/fuse_pointwise.cpp | 46 +-- src/fuse_pointwise_reduce.cpp | 42 +++ src/fuse_reduce.cpp | 105 +++++-- src/include/migraphx/argument.hpp | 4 +- src/include/migraphx/as_number.hpp | 43 +++ src/include/migraphx/cpp_generator.hpp | 4 +- src/include/migraphx/dom_info.hpp | 4 +- src/include/migraphx/fuse_pointwise.hpp | 4 +- .../migraphx/fuse_pointwise_reduce.hpp | 44 +++ src/include/migraphx/fuse_reduce.hpp | 4 +- src/include/migraphx/matcher.hpp | 18 ++ src/include/migraphx/module.hpp | 16 +- src/include/migraphx/onnx.hpp | 3 + src/include/migraphx/op/dot.hpp | 9 +- src/include/migraphx/op/pointwise.hpp | 34 +- src/include/migraphx/op/squeeze.hpp | 12 +- src/include/migraphx/par_for.hpp | 4 +- src/include/migraphx/param_utils.hpp | 2 +- src/include/migraphx/pass_manager.hpp | 3 +- src/include/migraphx/program.hpp | 3 + src/include/migraphx/raw_data.hpp | 8 + src/include/migraphx/rewrite_reshapes.hpp | 39 ++- src/include/migraphx/shape.hpp | 3 + src/include/migraphx/stringutils.hpp | 6 +- src/include/migraphx/tensor_view.hpp | 10 +- src/include/migraphx/tf.hpp | 11 +- src/include/migraphx/type_name.hpp | 13 +- src/instruction.cpp | 7 +- src/module.cpp | 119 +++++-- .../include/migraphx/onnx/onnx_parser.hpp | 1 + src/onnx/onnx.cpp | 1 + src/onnx/onnx_parser.cpp | 10 +- src/onnx/parse_convolution.cpp | 228 +++++++++++++- src/onnx/parse_expand.cpp | 8 +- src/param_utils.cpp | 10 +- src/pass_manager.cpp | 9 + 
src/program.cpp | 20 ++ src/propagate_constant.cpp | 8 +- src/register_target.cpp | 19 +- src/rewrite_reduce.cpp | 76 ++++- src/shape.cpp | 18 ++ src/simplify_algebra.cpp | 294 ++++++++++++++---- src/simplify_qdq.cpp | 30 +- src/simplify_reshapes.cpp | 201 +++++++++--- src/split_single_dyn_dim.cpp | 2 + .../cpu/include/migraphx/cpu/parallel.hpp | 12 +- src/targets/cpu/lowering.cpp | 1 - src/targets/gpu/CMakeLists.txt | 70 +++-- src/targets/gpu/code_object_op.cpp | 9 +- src/targets/gpu/compile_gen.cpp | 77 +++-- src/targets/gpu/compile_miopen.cpp | 3 +- src/targets/gpu/compile_ops.cpp | 35 ++- src/targets/gpu/compile_pointwise.cpp | 50 +++ src/targets/gpu/fuse_ck.cpp | 10 +- src/targets/gpu/fuse_mlir.cpp | 93 ++---- src/targets/gpu/fuse_ops.cpp | 6 +- src/targets/gpu/gemm_impl.cpp | 24 +- .../gpu/include/migraphx/gpu/compile_gen.hpp | 8 +- .../include/migraphx/gpu/compile_miopen.hpp | 2 +- .../migraphx/gpu/compile_pointwise.hpp | 45 +++ .../gpu/include/migraphx/gpu/context.hpp | 11 +- src/targets/gpu/include/migraphx/gpu/gemm.hpp | 5 +- src/targets/gpu/include/migraphx/gpu/mlir.hpp | 22 +- .../include/migraphx/gpu/prepare_reduce.hpp | 47 +++ .../gpu/include/migraphx/gpu/rocblas.hpp | 7 +- .../gpu/include/migraphx/gpu/time_op.hpp | 7 +- src/targets/gpu/jit/mlir.cpp | 157 +++++++++- src/targets/gpu/jit/pointwise.cpp | 45 ++- src/targets/gpu/jit/reduce.cpp | 12 +- src/targets/gpu/kernel.cpp | 28 +- .../include/migraphx/kernels/array.hpp | 24 +- .../include/migraphx/kernels/functional.hpp | 31 +- .../include/migraphx/kernels/layernorm.hpp | 54 ++-- .../include/migraphx/kernels/pointwise.hpp | 22 +- .../include/migraphx/kernels/print.hpp | 35 ++- .../include/migraphx/kernels/reduce.hpp | 18 +- .../include/migraphx/kernels/tuple.hpp | 164 ++++++++++ src/targets/gpu/lowering.cpp | 7 +- src/targets/gpu/mlir.cpp | 55 +++- src/targets/gpu/prefuse_ops.cpp | 11 +- src/targets/gpu/prepare_reduce.cpp | 134 ++++++++ src/targets/gpu/rocblas.cpp | 8 +- src/targets/gpu/target.cpp | 
13 +- src/targets/gpu/time_op.cpp | 46 ++- src/tf/tf.cpp | 26 +- src/tf/tf_parser.cpp | 13 + tools/check_stamped.py | 16 +- .../migraphx_with_onnxruntime_pytorch.docker | 13 +- tools/docker/sles.docker | 2 +- tools/docker/ubuntu_2204.dockerfile | 2 +- tools/format.py | 5 +- 105 files changed, 2587 insertions(+), 655 deletions(-) create mode 100644 src/driver/models.cpp create mode 100644 src/fuse_pointwise_reduce.cpp create mode 100644 src/include/migraphx/as_number.hpp create mode 100644 src/include/migraphx/fuse_pointwise_reduce.hpp create mode 100644 src/targets/gpu/compile_pointwise.cpp create mode 100644 src/targets/gpu/include/migraphx/gpu/compile_pointwise.hpp create mode 100644 src/targets/gpu/include/migraphx/gpu/prepare_reduce.hpp create mode 100644 src/targets/gpu/kernels/include/migraphx/kernels/tuple.hpp create mode 100644 src/targets/gpu/prepare_reduce.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 72cefdb6c88..d3a2fde80be 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -61,6 +61,8 @@ else() option(MIGRAPHX_ENABLE_PYTHON "Enable python bindings" ON) endif() +option(MIGRAPHX_USE_ROCBLAS "Enable MIGraphX to use rocBLAS" ON) + # By default build shared libraries option(BUILD_SHARED_LIBS "Create shared libraries" ON) @@ -167,6 +169,9 @@ rocm_enable_clang_tidy( -cert-dcl51-cpp -cert-err33-c -cert-str34-c + # We seed random numbers with constants for reproducibility + -cert-msc32-c + -cert-msc51-cpp # Disable all alpha checks by default -clang-analyzer-alpha* # Enable some alpha checks @@ -334,11 +339,18 @@ else() set(DEPENDS_HIP_RUNTIME "hip-runtime-amd" ) endif() +if(MIGRAPHX_USE_ROCBLAS) + list(APPEND PACKAGE_DEPENDS rocblas) +endif() + +rocm_package_add_deb_dependencies(SHARED_DEPENDS "hip-dev") +rocm_package_add_rpm_dependencies(SHARED_DEPENDS "hip-devel") + rocm_create_package( NAME MIGraphX DESCRIPTION "AMD's graph optimizer" MAINTAINER "AMDMIGraphX Maintainer " LDCONFIG PTH - DEPENDS miopen-hip rocblas ${DEPENDS_HIP_RUNTIME} hip-base 
half ${PACKAGE_DEPENDS} + DEPENDS miopen-hip ${DEPENDS_HIP_RUNTIME} half ${PACKAGE_DEPENDS} ) diff --git a/Dockerfile b/Dockerfile index 7650225ef8c..cc5f14a7ca0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,7 +6,7 @@ ARG PREFIX=/usr/local RUN dpkg --add-architecture i386 # Install rocm key -RUN apt-get update && apt-get install -y gnupg2 --no-install-recommends curl && \ +RUN apt-get update && apt-get install -y software-properties-common gnupg2 --no-install-recommends curl && \ curl -sL http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - # Add rocm repository @@ -15,6 +15,9 @@ RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/6.0 # From docs.amd.com for installing rocm. Needed to install properly RUN sh -c "echo 'Package: *\nPin: release o=repo.radeon.com\nPin-priority: 600' > /etc/apt/preferences.d/rocm-pin-600" +# rocgdb doesn't work on 22.04, workaround by installing the older python packages that are in 20.04 +RUN add-apt-repository -y ppa:deadsnakes/ppa + # Install dependencies RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ apt-utils \ @@ -32,10 +35,10 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- python3 \ python3-dev \ python3-pip \ - software-properties-common \ + libpython3.8 \ wget \ rocm-device-libs \ - hip-base \ + hip-dev \ libnuma-dev \ miopen-hip \ rocblas \ diff --git a/Jenkinsfile b/Jenkinsfile index 935af40b6e8..de6fa059a0b 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -155,13 +155,13 @@ rocmtest clang_debug: rocmnode('mi100+') { cmake_build -> cmake_build(flags: "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_MLIR=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags_cxx}' -DCMAKE_C_FLAGS_DEBUG='${debug_flags}' -DGPU_TARGETS='${gpu_targets}'") } } -}, ck_hiprtc: rocmnode('mi100+') { cmake_build -> - stage('CK hipRTC') { - withEnv(['MIGRAPHX_ENABLE_CK=1', 'MIGRAPHX_TUNE_CK=1', 'MIGRAPHX_DISABLE_MLIR=1']) { - def 
gpu_targets = getgputargets() - cmake_build(flags: "-DCMAKE_BUILD_TYPE=release -DMIGRAPHX_USE_HIPRTC=On -DGPU_TARGETS='${gpu_targets}'") - } - } +//}, ck_hiprtc: rocmnode('mi100+') { cmake_build -> +// stage('CK hipRTC') { +// withEnv(['MIGRAPHX_ENABLE_CK=1', 'MIGRAPHX_TUNE_CK=1', 'MIGRAPHX_DISABLE_MLIR=1']) { +// def gpu_targets = getgputargets() +// cmake_build(flags: "-DCMAKE_BUILD_TYPE=release -DMIGRAPHX_USE_HIPRTC=On -DGPU_TARGETS='${gpu_targets}'") +// } +// } }, clang_asan: rocmnode('nogpu') { cmake_build -> stage('Clang ASAN') { def sanitizers = "undefined,address" diff --git a/hip-clang.docker b/hip-clang.docker index 6b090607b1e..73a3e8edbba 100755 --- a/hip-clang.docker +++ b/hip-clang.docker @@ -27,7 +27,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- software-properties-common \ wget \ rocm-device-libs \ - hip-base \ + hip-dev \ libnuma-dev \ miopen-hip \ rocblas \ diff --git a/requirements.txt b/requirements.txt index ba8f6d4644a..49b42573add 100755 --- a/requirements.txt +++ b/requirements.txt @@ -28,4 +28,4 @@ pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCm/composable_kernel@57cdd70b7cb14e5e3b60cd9a5f96ba8dc343763e -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCm/rocMLIR@ce62762a4f4f929e463091373f6d1f201da61204 -DBUILD_FAT_LIBROCKCOMPILER=On +ROCm/rocMLIR@e50d72fc6ab9a7a792d92a1ba7db6db45e4c508c -DBUILD_FAT_LIBROCKCOMPILER=On diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f38380225d7..eedc7a27501 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -59,6 +59,7 @@ add_library(migraphx fp_to_double.cpp fuse_concat.cpp fuse_pointwise.cpp + fuse_pointwise_reduce.cpp fuse_reduce.cpp generate.cpp inline_module.cpp @@ -330,7 +331,10 @@ target_link_libraries(migraphx_all_targets INTERFACE migraphx_cpu) target_compile_definitions(migraphx_all_targets INTERFACE 
-DHAVE_CPU) endif() if(MIGRAPHX_ENABLE_GPU) -list(APPEND MIGRAPHX_CONFIG_DEPENDS PACKAGE MIOpen PACKAGE rocblas) + if(MIGRAPHX_USE_ROCBLAS) + list(APPEND MIGRAPHX_CONFIG_DEPENDS PACKAGE rocblas) + endif() + list(APPEND MIGRAPHX_CONFIG_DEPENDS PACKAGE MIOpen) add_subdirectory(targets/gpu) target_link_libraries(migraphx_all_targets INTERFACE migraphx_gpu) target_compile_definitions(migraphx_all_targets INTERFACE -DHAVE_GPU) diff --git a/src/api/include/migraphx/migraphx.hpp b/src/api/include/migraphx/migraphx.hpp index 336cbb71e69..af30020f4d0 100644 --- a/src/api/include/migraphx/migraphx.hpp +++ b/src/api/include/migraphx/migraphx.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -67,8 +67,24 @@ std::string compute_type_name() { std::string name; #if defined(_MSC_VER) && !defined(__clang__) - name = typeid(PrivateMigraphTypeNameProbe).name(); - name = name.substr(7); + const char struct_name[] = "struct "; + const char class_name[] = "class "; + const char function_name[] = "compute_type_name<"; + const char parameter_name[] = ">(void)"; + const char cdecl_name[] = "__cdecl"; + + name = __FUNCSIG__; + + auto begin = name.find(function_name) + sizeof(function_name) - 1; + auto length = name.find(parameter_name) - begin; + name = name.substr(begin, length); + if(name.find(class_name) == 0) + name = name.substr(sizeof(class_name) - 1); + else if(name.find(struct_name) == 0) + name = name.substr(sizeof(struct_name) - 1); + begin = name.find(cdecl_name); + if(begin != std::string::npos) + name.erase(begin, sizeof(cdecl_name) - 1); #else const char parameter_name[] = "PrivateMigraphTypeNameProbe ="; // NOLINT diff --git a/src/argument.cpp b/src/argument.cpp 
index ba8e821d632..733e27acac1 100644 --- a/src/argument.cpp +++ b/src/argument.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -102,6 +102,24 @@ void argument::assign_buffer(std::function d) })(s); } +std::vector flatten(const std::vector& args) +{ + std::vector result; + for(const auto& arg : args) + { + if(arg.get_shape().type() == shape::tuple_type) + { + auto subs = flatten(arg.get_sub_objects()); + result.insert(result.end(), subs.begin(), subs.end()); + } + else + { + result.push_back(arg); + } + } + return result; +} + std::vector to_shapes(const std::vector& args) { std::vector shapes; diff --git a/src/cpp_generator.cpp b/src/cpp_generator.cpp index da292b83cce..433ccaadb5b 100644 --- a/src/cpp_generator.cpp +++ b/src/cpp_generator.cpp @@ -38,6 +38,7 @@ inline namespace MIGRAPHX_INLINE_NS { cpp_generator::function& cpp_generator::function::set_body(const module& m, const cpp_generator::generate_module_callback& g) { + const std::string prefix = "zz"; std::unordered_map names; std::stringstream ss; @@ -53,12 +54,13 @@ cpp_generator::function::set_body(const module& m, const cpp_generator::generate } else if(ins->name() == "@return") { - assert(ins->inputs().size() == 1); - return_ins = ins->inputs().front(); + names[ins] = prefix + "return"; + ss << "auto " << names[ins] << " = " << g(ins, names) << ";\n"; + return_ins = ins; } else { - std::string n = "z" + std::to_string(names.size()); + std::string n = prefix + std::to_string(names.size()); names[ins] = n; ss << "auto " << n << " = " << g(ins, names) << ";\n"; } @@ -125,6 +127,7 @@ struct cpp_generator_impl std::function fmap = nullptr; std::function fresult = nullptr; std::unordered_map 
point_op_map = {}; + bool always_return_tuple = false; }; cpp_generator::cpp_generator() : impl(std::make_unique()) {} @@ -142,6 +145,8 @@ void cpp_generator::fmap(const std::function& f) { imp void cpp_generator::fresult(const std::function& f) { impl->fresult = f; } +void cpp_generator::always_return_tuple(bool b) { impl->always_return_tuple = b; } + void cpp_generator::add_point_op(const std::string& op_name, const std::string& code) { impl->point_op_map[op_name] = code; @@ -222,6 +227,13 @@ cpp_generator::function cpp_generator::generate_module(const module& m, }); return shape::cpp_type(ins->get_shape().type()) + "(" + string_literal + ")"; } + if(ins->name() == "@return") + { + // TODO: Customize the make_tuple call + if(impl->always_return_tuple or ins->inputs().size() != 1) + return "make_tuple(" + join_strings(to_args(ins->inputs(), names), ", ") + ")"; + return names.at(ins->inputs().front()); + } auto s = g(ins, names); if(impl->fresult) return impl->fresult(ins->get_shape()) + '(' + s + ')'; diff --git a/src/dom_info.cpp b/src/dom_info.cpp index 400dd80e0e1..89ff19140ce 100644 --- a/src/dom_info.cpp +++ b/src/dom_info.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -31,7 +31,7 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -bool dominator_info::strictly_dominate(instruction_ref ins1, instruction_ref ins2) +bool dominator_info::strictly_dominate(instruction_ref ins1, instruction_ref ins2) const { if(ins1 == ins2) return false; @@ -65,7 +65,7 @@ dominator_info compute_dominator_generic(Visitor v) if(children.size() == 1) { info.ins2idom[ins] = children.front(); - instr2_doms[ins].insert(children.front()); + instr2_doms[ins] = instr2_doms[children.front()]; } else if(children.size() > 1) { diff --git a/src/driver/CMakeLists.txt b/src/driver/CMakeLists.txt index f9f200a0abe..658fddf132e 100755 --- a/src/driver/CMakeLists.txt +++ b/src/driver/CMakeLists.txt @@ -1,7 +1,7 @@ ##################################################################################### # The MIT License (MIT) # -# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -26,10 +26,8 @@ add_executable(driver main.cpp verify.cpp passes.cpp + models.cpp perf.cpp - resnet50.cpp - inceptionv3.cpp - alexnet.cpp marker_roctx.cpp ) set_target_properties(driver PROPERTIES OUTPUT_NAME migraphx-driver) diff --git a/src/driver/main.cpp b/src/driver/main.cpp index f10b875c734..1faeef4844f 100644 --- a/src/driver/main.cpp +++ b/src/driver/main.cpp @@ -70,11 +70,11 @@ inline std::string get_version() struct loader { - std::string model; std::string file; std::string file_type; unsigned batch = 1; bool is_nhwc = true; + bool is_test = false; unsigned trim = 0; bool optimize = false; bool skip_unknown_operators = false; @@ -91,11 +91,10 @@ struct loader void parse(argument_parser& ap) { ap(file, {}, ap.metavar(""), ap.file_exist(), ap.required(), ap.group("input")); - ap(model, - {"--model"}, - ap.help("Load model"), - ap.type("resnet50|inceptionv3|alexnet"), - ap.matches({"resnet50", "inceptionv3", "alexnet"}), + ap(is_test, + {"--test"}, + ap.help("Run a single GEMM to test MIGraphX"), + ap.set_value(true), ap.group("input")); ap(file_type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx")); ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf")); @@ -312,7 +311,11 @@ struct loader program load() { program p; - if(model.empty()) + if(is_test) + { + p = test_gemm(); + } + else { if(file_type.empty()) { @@ -344,17 +347,6 @@ struct loader p = migraphx::load(file); } } - else - { - if(model == "resnet50") - p = resnet50(batch); - else if(model == "inceptionv3") - p = inceptionv3(batch); - else if(model == "alexnet") - p = alexnet(batch); - else - MIGRAPHX_THROW("Unknown model: " + model); - } if(trim > 0) { auto* mm = p.get_main_module(); @@ -396,7 +388,7 @@ struct loader std::ofstream fs; if(not output.empty()) { - fs.open(output); + fs.open(output, 
std::ios::binary); os = &fs; } @@ -690,7 +682,7 @@ struct run_cmd : command struct perf : command { compiler c; - unsigned n = 100; + unsigned n = 100; bool detailed = false; void parse(argument_parser& ap) { diff --git a/src/driver/models.cpp b/src/driver/models.cpp new file mode 100644 index 00000000000..a2b0021238f --- /dev/null +++ b/src/driver/models.cpp @@ -0,0 +1,45 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include "models.hpp" +#include +#include + +namespace migraphx { +namespace driver { +inline namespace MIGRAPHX_INLINE_NS { + +migraphx::program test_gemm() +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {4, 5}}); + auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {5, 3}}); + mm->add_instruction(migraphx::make_op("dot"), a, b); + return p; +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace driver +} // namespace migraphx diff --git a/src/driver/models.hpp b/src/driver/models.hpp index d2256b12f8d..8d4209285b2 100644 --- a/src/driver/models.hpp +++ b/src/driver/models.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -28,9 +28,7 @@ namespace migraphx { namespace driver { inline namespace MIGRAPHX_INLINE_NS { -migraphx::program resnet50(unsigned batch); -migraphx::program inceptionv3(unsigned batch); -migraphx::program alexnet(unsigned batch); +migraphx::program test_gemm(); } // namespace MIGRAPHX_INLINE_NS } // namespace driver diff --git a/src/fuse_pointwise.cpp b/src/fuse_pointwise.cpp index 10b560ff52e..46e3934908f 100644 --- a/src/fuse_pointwise.cpp +++ b/src/fuse_pointwise.cpp @@ -31,6 +31,7 @@ #include #include #include +#include #include #include @@ -41,7 +42,7 @@ inline namespace MIGRAPHX_INLINE_NS { static literal get_scalar(instruction_ref ins) { - if(ins->name() == "contiguous") + if(contains({"contiguous", "broadcast", "multibroadcast"}, ins->name())) return get_scalar(ins->inputs().front()); const auto& s = ins->get_shape(); if(s.elements() != 1 and not(s.scalar())) @@ -88,7 +89,7 @@ static void 
create_pointwise_modules(module_pass_manager& mpm) { pointwise_inputs.push_back(input); param_map[input] = - pm->add_parameter("x" + std::to_string(i), shape{input->get_shape().type()}); + pm->add_parameter(param_name(i), shape{input->get_shape().type()}); i++; } else @@ -113,18 +114,17 @@ static void create_pointwise_modules(module_pass_manager& mpm) } } -static std::vector append_pointwise_module(instruction_ref ins, - instruction_ref output) +static module::with_inputs append_pointwise_module(instruction_ref ins, instruction_ref output) { assert(contains(output->inputs(), ins)); - module_ref pm = ins->module_inputs().at(0); + module pm = *ins->module_inputs().at(0); module_ref xm = output->module_inputs().at(0); - auto last = std::prev(pm->end()); + auto last = std::prev(pm.end()); assert(last->name() == "@return"); assert(last->inputs().size() == 1); - assert(pm->get_parameter_names().size() == ins->inputs().size()); + assert(pm.get_parameter_names().size() == ins->inputs().size()); assert(xm->get_parameter_names().size() == output->inputs().size()); std::vector inputs = ins->inputs(); @@ -134,15 +134,15 @@ static std::vector append_pointwise_module(instruction_ref ins, for(auto i : range(inputs.size())) { auto input = inputs[i]; - auto param = pm->get_parameter("x" + std::to_string(i)); - assert(param != pm->end()); + auto param = pm.get_parameter(param_name(i)); + assert(param != pm.end()); input_map[input] = param; } // Add the new parameter and additional inputs for(auto i : range(output->inputs().size())) { auto input = output->inputs()[i]; - auto param = xm->get_parameter("x" + std::to_string(i)); + auto param = xm->get_parameter(param_name(i)); assert(param != xm->end()); if(input == ins) { @@ -157,20 +157,20 @@ static std::vector append_pointwise_module(instruction_ref ins, else { map_ins[param] = - pm->add_parameter("x" + std::to_string(inputs.size()), {input->get_shape().type()}); + pm.add_parameter(param_name(inputs.size()), 
{input->get_shape().type()}); inputs.push_back(input); input_map[input] = map_ins[param]; } } - pm->replace_return(pm->insert_instructions(last, xm, &map_ins)); - return inputs; + pm.replace_return(pm.insert_instructions(last, xm, &map_ins)); + return {std::move(pm), inputs}; } -static bool find_pointwise_modules(module& m) +static bool find_pointwise_modules(module_pass_manager& mpm) { bool changed = false; - auto last = std::prev(m.end()); - for(auto ins : iterator_for(m)) + auto last = std::prev(mpm.get_module().end()); + for(auto ins : iterator_for(mpm.get_module())) { if(ins->name() != "pointwise") continue; @@ -183,10 +183,11 @@ static bool find_pointwise_modules(module& m) continue; auto input = *it; - auto new_inputs = append_pointwise_module(input, ins); - m.replace_instruction(input, input->get_operator(), new_inputs, input->module_inputs()); - m.replace_instruction(ins, input); - m.move_instruction(input, ins); + auto fused = append_pointwise_module(input, ins); + auto name = fused.mod.name(); + mpm.rename_module(name, name + ":" + ins->module_inputs().front()->name() + "-deleted"); + auto* new_pm = mpm.create_module(name, std::move(fused.mod)); + mpm.get_module().replace_instruction(ins, input->get_operator(), fused.inputs, {new_pm}); changed = true; } @@ -212,8 +213,9 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const } for(int i = 0; i < 8; i++) { - mpm.run_pass(rewrite_reshapes{}); - if(not find_pointwise_modules(mpm.get_module())) + if(enable_rewrite_reshapes) + mpm.run_pass(rewrite_reshapes{}); + if(not find_pointwise_modules(mpm)) break; mpm.run_pass(dead_code_elimination{}); } diff --git a/src/fuse_pointwise_reduce.cpp b/src/fuse_pointwise_reduce.cpp new file mode 100644 index 00000000000..dfd3f474ba2 --- /dev/null +++ b/src/fuse_pointwise_reduce.cpp @@ -0,0 +1,42 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +void fuse_pointwise_reduce::apply(module_pass_manager& mpm) const +{ + mpm.run_pass(fuse_pointwise{.enable_rewrite_reshapes = false}); + mpm.run_pass(fuse_reduce{.enable_rewrite_reshapes = false}); + mpm.run_pass(fuse_pointwise{.enable_rewrite_reshapes = true}); + mpm.run_pass(fuse_reduce{.enable_rewrite_reshapes = true}); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/fuse_reduce.cpp b/src/fuse_reduce.cpp index 7f60d5ebe70..e533856c38a 100644 --- a/src/fuse_reduce.cpp +++ b/src/fuse_reduce.cpp @@ -22,15 +22,16 @@ * THE SOFTWARE. 
*/ #include -#include +#include #include +#include #include -#include -#include #include -#include -#include +#include #include +#include +#include +#include #include #include #include @@ -39,6 +40,8 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION) + struct fused_reduce { std::vector axes{}; @@ -64,25 +67,28 @@ struct fused_reduce if(not equal(names, inputs, [&](const auto& name, const auto& input) { return shapes.at(name).lens() == input.lens(); })) - MIGRAPHX_THROW("Dimenstion does not match the submodule."); - const auto& s = inputs.at(0); - auto lens = s.lens(); - if(lens != sm->get_output_shapes().front().lens()) - { - for(const auto& axis : axes) - { - lens[axis] = 1; - } - } + MIGRAPHX_THROW("Input dimension does not match the submodule."); - return shape::from_permutation( - sm->get_output_shapes().front().type(), lens, find_permutation(inputs)); + return shape::from_permutation(sm->get_output_shapes().front().type(), + sm->get_output_shapes().front().lens(), + find_permutation(inputs)); } std::string name() const { return "fused_reduce"; } }; MIGRAPHX_REGISTER_OP(fused_reduce); +/* + * Predicate matcher checks that input and output shapes have the same rank. This is assumed + * for broadcast instructions for these fusions. 
+ */ +MIGRAPHX_PRED_MATCHER(input_output_ndim_match, instruction_ref ins) +{ + auto input_shape = ins->inputs().front()->get_shape(); + auto output_shape = ins->get_shape(); + return input_shape.ndim() == output_shape.ndim(); +} + static void insert_params(module_ref sm, const std::vector& inputs, std::unordered_map& map_ins) @@ -193,11 +199,49 @@ static void create_reduce_modules(module_pass_manager& mpm) } } +namespace { + +instruction_ref get_broadcast_output(instruction_ref broadcast) +{ + if(broadcast->outputs().size() != 1) + return broadcast; + auto output = broadcast->outputs().front(); + if(output->name() == "contiguous") + return get_broadcast_output(output); + return output; +} + +MIGRAPHX_PRED_MATCHER(used_once_except_broadcast, instruction_ref ins) +{ + if(ins->outputs().size() == 1) + return true; + if(ins->outputs().size() == 2) + { + auto is_broadcast = [](instruction_ref output) { + return contains(output->name(), "broadcast"); + }; + auto broadcast = std::find_if(ins->outputs().begin(), ins->outputs().end(), is_broadcast); + if(broadcast == ins->outputs().end()) + return false; + auto non_broadcast = + std::find_if_not(ins->outputs().begin(), ins->outputs().end(), is_broadcast); + if(non_broadcast == ins->outputs().end()) + return false; + auto output = get_broadcast_output(*broadcast); + return output == *non_broadcast; + } + + return false; +} +} // namespace template static auto match_broadcast(Ms... ms) { return match::skip(match::name("contiguous"))( - match::name("multibroadcast")(match::arg(0)(ms...), match::used_once()).bind("broadcast")); + match::name("multibroadcast")( + match::arg(0)(ms...), match::used_once(), input_output_ndim_match()) + .bind("broadcast")) + .bind("final_broadcast"); } template @@ -208,30 +252,36 @@ static auto any_input(Ms... 
ms) static auto match_broadcastable_input(const std::string& op, const std::string& name) { - auto match_op = match::name(op)(match::used_once()).bind(name); + auto match_op = match::name(op)(used_once_except_broadcast()).bind(name); auto match_op_input = any_input(match_op, match::used_once()); auto broadcast_match_op_input = any_input(match_broadcast(match_op), match::used_once()); return match::any_of(match_op_input, broadcast_match_op_input); } +static void finalize_reduce_module(module_ref m) +{ + eliminate_common_subexpression{}.apply(*m); + dead_code_elimination{}.apply(*m); +} + namespace { struct find_pointwise_reduce { auto matcher() const { + // fused_reduce instruction with pointwise inputs. return match::name("fused_reduce")(match_broadcastable_input("pointwise", "pointwise")); } void apply(module_pass_manager& mpm, const match::matcher_result& r) const { auto reduce = r.result; - auto input = r.instructions["pointwise"]; - + auto input = r.instructions["pointwise"]; const auto* pm = input->module_inputs().front(); const auto* old_rm = reduce->module_inputs().front(); + auto* rm = mpm.create_module(pm->name() + ":" + old_rm->name()); rm->set_bypass(); - std::unordered_map map_ins; // Insert pointwise auto rins = insert_ins_in_submodule(rm, input, map_ins).front(); @@ -240,11 +290,15 @@ struct find_pointwise_reduce if(contains(r.instructions, "broadcast")) { auto broadcast = r.instructions["broadcast"]; + auto fbroadcast = r.instructions["final_broadcast"]; map_ins[broadcast] = insert_ins_in_submodule(rm, broadcast, map_ins).front(); + if(fbroadcast != broadcast) + map_ins[fbroadcast] = map_ins[broadcast]; } // Insert fused_reduce rm->add_return(insert_module_in_submodule(rm, reduce, map_ins)); + finalize_reduce_module(rm); auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins); mpm.get_module().replace_instruction(reduce, reduce->get_operator(), new_inputs, {rm}); @@ -286,6 +340,7 @@ struct find_reduce_pointwise auto out = 
insert_ins_in_submodule(rm, pw, map_ins); rm->replace_return(out); + finalize_reduce_module(rm); auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins); mpm.get_module().replace_instruction(pw, reduce->get_operator(), new_inputs, {rm}); @@ -330,6 +385,7 @@ struct find_reduce_reduce auto out = insert_module_in_submodule(rm, reduce1, map_ins); rm->replace_return(out); + finalize_reduce_module(rm); auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins); mpm.get_module().replace_instruction(reduce1, reduce1->get_operator(), new_inputs, {rm}); @@ -399,11 +455,14 @@ struct reduce_reshape : rewrite_reshapes_base void fuse_reduce::apply(module_pass_manager& mpm) const { + if(enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{})) + return; create_reduce_modules(mpm); mpm.run_pass(dead_code_elimination{}); for(int i = 0; i < 4; i++) { - mpm.run_pass(rewrite_reshapes{}); + if(enable_rewrite_reshapes) + mpm.run_pass(rewrite_reshapes{}); match::find_matches( mpm, find_reduce_pointwise{}, find_pointwise_reduce{}, find_reduce_reduce{}); mpm.run_pass(dead_code_elimination{}); diff --git a/src/include/migraphx/argument.hpp b/src/include/migraphx/argument.hpp index 30c0df40c56..d02a009e317 100644 --- a/src/include/migraphx/argument.hpp +++ b/src/include/migraphx/argument.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -117,6 +117,8 @@ struct MIGRAPHX_EXPORT argument : raw_data data_t m_data{}; }; +MIGRAPHX_EXPORT std::vector flatten(const std::vector& args); + MIGRAPHX_EXPORT std::vector to_shapes(const std::vector& args); MIGRAPHX_EXPORT void migraphx_to_value(value& v, const argument& a); MIGRAPHX_EXPORT void migraphx_from_value(const value& v, argument& a); diff --git a/src/include/migraphx/as_number.hpp b/src/include/migraphx/as_number.hpp new file mode 100644 index 00000000000..987808b603e --- /dev/null +++ b/src/include/migraphx/as_number.hpp @@ -0,0 +1,43 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ +#ifndef MIGRAPHX_GUARD_RTGLIB_AS_NUMBER_HPP +#define MIGRAPHX_GUARD_RTGLIB_AS_NUMBER_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +template +T as_number(T x) +{ + return x; +} +inline int32_t as_number(int8_t x) { return static_cast(x); } +inline uint32_t as_number(uint8_t x) { return static_cast(x); } + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_RTGLIB_AS_NUMBER_HPP diff --git a/src/include/migraphx/cpp_generator.hpp b/src/include/migraphx/cpp_generator.hpp index ef052558051..9f34ba159d8 100644 --- a/src/include/migraphx/cpp_generator.hpp +++ b/src/include/migraphx/cpp_generator.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -95,6 +95,8 @@ struct MIGRAPHX_EXPORT cpp_generator void fresult(const std::function& f); + void always_return_tuple(bool b = true); + void add_point_op(const std::string& op_name, const std::string& code); std::string generate_point_op(const operation& op, const std::vector& args); diff --git a/src/include/migraphx/dom_info.hpp b/src/include/migraphx/dom_info.hpp index 7fd6db3a18e..ce589a44503 100644 --- a/src/include/migraphx/dom_info.hpp +++ b/src/include/migraphx/dom_info.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -36,7 +36,7 @@ struct module; struct MIGRAPHX_EXPORT dominator_info { - bool strictly_dominate(instruction_ref ins1, instruction_ref ins2); + bool strictly_dominate(instruction_ref ins1, instruction_ref ins2) const; std::unordered_map ins2idom; }; diff --git a/src/include/migraphx/fuse_pointwise.hpp b/src/include/migraphx/fuse_pointwise.hpp index bf1472c2602..5d17f0011fc 100644 --- a/src/include/migraphx/fuse_pointwise.hpp +++ b/src/include/migraphx/fuse_pointwise.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -36,6 +36,8 @@ struct MIGRAPHX_EXPORT fuse_pointwise { std::string name() const { return "fuse_pointwise"; } void apply(module_pass_manager& mpm) const; + + bool enable_rewrite_reshapes = true; }; } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/fuse_pointwise_reduce.hpp b/src/include/migraphx/fuse_pointwise_reduce.hpp new file mode 100644 index 00000000000..68bdc4e9951 --- /dev/null +++ b/src/include/migraphx/fuse_pointwise_reduce.hpp @@ -0,0 +1,44 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_REDUCE_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_REDUCE_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct module_pass_manager; + +struct MIGRAPHX_EXPORT fuse_pointwise_reduce +{ + std::string name() const { return "fuse_pointwise_reduce"; } + void apply(module_pass_manager& mpm) const; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_REDUCE_HPP diff --git a/src/include/migraphx/fuse_reduce.hpp b/src/include/migraphx/fuse_reduce.hpp index 90ee5ff2301..c5176eda812 100644 --- a/src/include/migraphx/fuse_reduce.hpp +++ b/src/include/migraphx/fuse_reduce.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. 
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -36,6 +36,8 @@ struct MIGRAPHX_EXPORT fuse_reduce { std::string name() const { return "fuse_reduce"; } void apply(module_pass_manager& mpm) const; + + bool enable_rewrite_reshapes = true; }; } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/matcher.hpp b/src/include/migraphx/matcher.hpp index 6e7e57c67bf..6956f5916e8 100644 --- a/src/include/migraphx/matcher.hpp +++ b/src/include/migraphx/matcher.hpp @@ -330,9 +330,27 @@ struct matcher_result }); } + void debug_print() const + { + for(const auto& it : ins_map) + { + std::cout << it.first << ": \n"; + it.second->debug_print(); + } + } + private: std::unordered_map ins_map; }; + + void debug_print() const + { + std::cout << "matcher_container: \n instructions:"; + instructions.debug_print(); + std::cout << " result: \n"; + result->debug_print(); + } + instruction_container instructions; instruction_ref result; }; diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index f9d41121159..a842c4aac28 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -202,6 +202,20 @@ struct MIGRAPHX_EXPORT module instruction_ref begin() const; instruction_ref end() const; + struct compute_shapes_options + { + std::string name = "compute_shapes"; + bool strict_type = false; + bool strict_lens = false; + std::vector scalar_const_out_lens = {}; + }; + + /// Compute a new ouput shape by replacing each parameter with input + /// shapes passed in. 
+ std::vector compute_shapes(const std::vector& inputs, + compute_shapes_options options) const; + std::vector compute_shapes(const std::vector& inputs) const; + std::vector get_output_shapes() const; instruction_ref validate() const; @@ -292,7 +306,7 @@ struct MIGRAPHX_EXPORT module std::unique_ptr impl; }; -struct module_with_inputs +struct MIGRAPHX_EXPORT module_with_inputs { module mod; std::vector inputs; diff --git a/src/include/migraphx/onnx.hpp b/src/include/migraphx/onnx.hpp index c0d7637c093..b9bf42a4f0e 100644 --- a/src/include/migraphx/onnx.hpp +++ b/src/include/migraphx/onnx.hpp @@ -58,6 +58,9 @@ struct onnx_options int64_t limit_max_iterations = std::numeric_limits::max(); /// Use dynamic output for operators when available bool use_dyn_output = false; + /// Path to use for the external data if it is stored at different location compared to onnx + /// file + std::string external_data_path = ""; }; /// Create a program from an onnx file diff --git a/src/include/migraphx/op/dot.hpp b/src/include/migraphx/op/dot.hpp index 0ef3787b38c..37e38f76264 100644 --- a/src/include/migraphx/op/dot.hpp +++ b/src/include/migraphx/op/dot.hpp @@ -57,15 +57,18 @@ struct dot auto s1 = b.to_dynamic(); std::vector out_dyn_dims; - // check outer dynamic dimensions are the same + // Check outer dynamic dimensions are compatible. + // Must allow for intersection because of how simplify_dyn_ops + // simplifies each broadcast_for_dot individually. 
bool same_outers = std::equal(s0.dyn_dims().begin(), s0.dyn_dims().end() - 2, s1.dyn_dims().begin(), s1.dyn_dims().end() - 2, [&](auto x, auto y) { - if(x == y) + auto intersect = x.intersection(y); + if(intersect.has_value()) { - out_dyn_dims.push_back(x); + out_dyn_dims.push_back(intersect.value()); return true; } return false; diff --git a/src/include/migraphx/op/pointwise.hpp b/src/include/migraphx/op/pointwise.hpp index 6e7d6f92f30..c76276a9f08 100644 --- a/src/include/migraphx/op/pointwise.hpp +++ b/src/include/migraphx/op/pointwise.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -45,23 +45,18 @@ struct pointwise { MIGRAPHX_THROW("should have one submodule."); } - auto* pm = mods.front(); - if(pm->get_output_shapes().size() != 1) - MIGRAPHX_THROW("pointwise should have only one output."); if(inputs.empty()) MIGRAPHX_THROW("pointwise should have at least one input"); + auto* pm = mods.front(); auto pnames = pm->get_parameter_names(); - std::sort(pnames.begin(), pnames.end()); check_shapes{inputs, *this}.has(pnames.size()).same_dims(); - auto type = pm->get_output_shapes().front().type(); - - // Scalar output if all inputs are scalar - if(inputs.front().elements() == 1 and - all_of(inputs, [](const auto& s) { return s.scalar(); })) - return shape{type}; - - return shape::from_permutation(type, inputs.front().lens(), find_permutation(inputs)); + auto result = pm->compute_shapes( + inputs, + {.name = name(), .strict_type = true, .scalar_const_out_lens = inputs.front().lens()}); + if(result.size() == 1) + return result.front(); + return shape{result}; } argument compute(const shape& output_shape, @@ -75,7 +70,7 @@ struct pointwise auto pnames = 
pm->get_parameter_names(); std::sort(pnames.begin(), pnames.end()); - par_for(output_shape.elements(), [&](auto i) { + par_for(args[0].get_shape().elements(), [&](auto i) { std::unordered_map params; std::transform( @@ -86,8 +81,15 @@ struct pointwise [&](auto&& name, auto&& arg) { return std::make_pair(name, arg.element(i)); }); auto results = run(pm, params); - assert(results.size() == 1); - visit_all(output, results.front())([&](auto out, auto x) { out[i] = x.front(); }); + assert(results.size() == output.get_sub_objects().size() or + (results.size() == 1 and output.get_sub_objects().empty())); + std::vector outputs; + if(results.size() == 1) + outputs = {output.share()}; + else + outputs = output.share().get_sub_objects(); + for(auto j : range(results.size())) + visit_all(outputs[j], results[j])([&](auto out, auto x) { out[i] = x.front(); }); }); return output; } diff --git a/src/include/migraphx/op/squeeze.hpp b/src/include/migraphx/op/squeeze.hpp index 5d09c1250f4..a237a0a5875 100644 --- a/src/include/migraphx/op/squeeze.hpp +++ b/src/include/migraphx/op/squeeze.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -59,12 +59,18 @@ struct squeeze auto input_shape = inputs[0]; if(input_shape.dynamic()) { + // Allow for any dynamic_dimension that intersects with {1, 1}. + // Assuming that the shape at run-time will be compatible. 
if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { - return input_shape.dyn_dims()[axis] != 1; + return not input_shape.dyn_dims() + .at(axis) + .intersection(shape::dynamic_dimension{1, 1}) + .has_value(); + ; })) { MIGRAPHX_THROW( - "SQUEEZE: dynamic axis dimension should be equal to {1, 1, 0} or {1, 1, 1}"); + "SQUEEZE: dynamic axis dimension should have an intersection with {1, 1}"); } std::vector dyn_dims = {}; if(axes.empty()) diff --git a/src/include/migraphx/par_for.hpp b/src/include/migraphx/par_for.hpp index 3f8673a3e9f..d34b6c5c170 100644 --- a/src/include/migraphx/par_for.hpp +++ b/src/include/migraphx/par_for.hpp @@ -38,9 +38,9 @@ void par_for(std::size_t n, F f) } template -void par_for(std::size_t n, std::size_t, F f) +void par_for(std::size_t n, std::size_t min_grain, F f) { - par_for(n, f); + simple_par_for(n, min_grain, f); } } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/param_utils.hpp b/src/include/migraphx/param_utils.hpp index 1552c28300b..3506fccbb30 100644 --- a/src/include/migraphx/param_utils.hpp +++ b/src/include/migraphx/param_utils.hpp @@ -33,7 +33,7 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -std::string param_name(std::size_t i, const std::string& prefix = "x"); +MIGRAPHX_EXPORT std::string param_name(std::size_t i, const std::string& prefix = "x"); void sort_params(std::vector& params); diff --git a/src/include/migraphx/pass_manager.hpp b/src/include/migraphx/pass_manager.hpp index fdbdc123a12..535c7b5a3a3 100644 --- a/src/include/migraphx/pass_manager.hpp +++ b/src/include/migraphx/pass_manager.hpp @@ -39,7 +39,8 @@ struct module_pass_manager module_pass_manager(const module_pass_manager&) = delete; virtual module& get_module() = 0; virtual module* create_module(const std::string& name) = 0; - virtual module* create_module(const std::string& name, module m) = 0; + virtual module* create_module(const std::string& name, module m) = 0; + virtual void rename_module(const std::string& 
old_name, const std::string& new_name) = 0; virtual module* get_common_parent() = 0; virtual module* get_root_module() = 0; virtual void run_pass(const pass& p) = 0; diff --git a/src/include/migraphx/program.hpp b/src/include/migraphx/program.hpp index 741063743e6..e86ba628656 100644 --- a/src/include/migraphx/program.hpp +++ b/src/include/migraphx/program.hpp @@ -80,6 +80,8 @@ struct MIGRAPHX_EXPORT program std::vector eval(parameter_map params, execution_environment exec_env = execution_environment{}) const; + std::vector eval_with_context(std::vector& ctx, parameter_map params) const; + void finish() const; std::size_t size() const; @@ -152,6 +154,7 @@ struct MIGRAPHX_EXPORT program std::unordered_multimap get_module_tree(); void remove_module(const std::string& name); + void rename_module(const std::string& old_name, const std::string& new_name); void remove_unused_modules(); private: diff --git a/src/include/migraphx/raw_data.hpp b/src/include/migraphx/raw_data.hpp index 91a9deb20e8..19373bab6d4 100644 --- a/src/include/migraphx/raw_data.hpp +++ b/src/include/migraphx/raw_data.hpp @@ -191,6 +191,14 @@ struct raw_data : raw_data_base ss << static_cast(*this); return ss.str(); } + + template + std::vector to_vector() const + { + std::vector result(static_cast(*this).get_shape().elements()); + this->visit([&](auto x) { result.assign(x.begin(), x.end()); }); + return result; + } }; namespace detail { diff --git a/src/include/migraphx/rewrite_reshapes.hpp b/src/include/migraphx/rewrite_reshapes.hpp index bd6ed1ec27d..e44413a09a1 100644 --- a/src/include/migraphx/rewrite_reshapes.hpp +++ b/src/include/migraphx/rewrite_reshapes.hpp @@ -74,13 +74,16 @@ struct rewrite_reshapes { auto reshape = match::name("reshape", "squeeze", "unsqueeze", "flatten")(match::used_once()); - auto skip_contiguous = [](auto... 
ms) { - return match::arg(0)(match::skip( - match::name("contiguous", "multibroadcast")(match::used_once()))(ms...)); + auto skip_contiguous_broadcast = + match::skip(match::name("contiguous", "multibroadcast")(match::used_once())); + auto skip_contiguous_broadcast_arg = [&](auto... ms) { + return match::arg(0)(skip_contiguous_broadcast(ms...)); }; auto pointwise = match::name(op1)(match::used_once()); - auto reshape_pointwise = reshape(skip_contiguous(pointwise.bind("x"))).bind("reshape"); - return match::name(op2)(match::any_of[match::inputs()](reshape_pointwise)); + auto reshape_pointwise = + reshape(skip_contiguous_broadcast_arg(pointwise.bind("x"))).bind("reshape"); + return match::name(op2)(match::any_of[match::inputs()]( + skip_contiguous_broadcast(reshape_pointwise).bind("input"))); } template @@ -107,17 +110,33 @@ struct rewrite_reshapes return x_ins == input; } + static std::optional is_broadcasted(instruction_ref start, instruction_ref last) + { + auto broadcast_ins = + find_input_if(start, last, [&](auto i) { return i->name() == "multibroadcast"; }); + bool result = broadcast_ins != last; + if(result and not match_input(broadcast_ins, last)) + return nullopt; + return result; + } + void apply(module_pass_manager& mpm, const match::matcher_result& r) const { auto ins = r.result; auto x_ins = r.instructions["x"]; auto reshape_ins = r.instructions["reshape"]; + auto input_ins = r.instructions["input"]; - auto broadcast_ins = find_input_if( - reshape_ins, x_ins, [&](auto i) { return i->name() == "multibroadcast"; }); - const bool has_broadcast = broadcast_ins != x_ins; - if(has_broadcast and not match_input(broadcast_ins, x_ins)) + const auto has_broadcast_before_reshape = is_broadcasted(reshape_ins, x_ins); + const auto has_broadcast_after_reshape = is_broadcasted(input_ins, reshape_ins); + if(not has_broadcast_before_reshape.has_value()) + return; + if(not has_broadcast_after_reshape.has_value()) + return; + if(*has_broadcast_after_reshape and 
*has_broadcast_before_reshape) return; + const bool has_broadcast = + *has_broadcast_after_reshape or *has_broadcast_before_reshape; auto dims1 = T::base_dims(ins); auto dims2 = T::base_dims(x_ins); @@ -153,7 +172,7 @@ struct rewrite_reshapes auto inputs = ins->inputs(); std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) { - if(input == reshape_ins) + if(input == input_ins) return new_x_ins; return reshape_input(ins)(input); }); diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index b84dbaa8728..bfa3253f137 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -428,6 +428,9 @@ struct MIGRAPHX_EXPORT shape std::shared_ptr impl; }; +/// Flatten subshapes to a single vector of non-tuple type of shapes +MIGRAPHX_EXPORT std::vector flatten(const std::vector& shapes); + MIGRAPHX_EXPORT void migraphx_to_value(value& v, const shape& s); MIGRAPHX_EXPORT void migraphx_from_value(const value& v, shape& s); diff --git a/src/include/migraphx/stringutils.hpp b/src/include/migraphx/stringutils.hpp index fa6b1afcb0c..bc49cb3d733 100644 --- a/src/include/migraphx/stringutils.hpp +++ b/src/include/migraphx/stringutils.hpp @@ -30,7 +30,7 @@ #include #include #include -#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -195,8 +195,8 @@ inline std::string to_string_range(Iterator start, Iterator last, const char* de std::stringstream ss; if(start != last) { - ss << *start; - std::for_each(std::next(start), last, [&](auto&& x) { ss << delim << x; }); + ss << as_number(*start); + std::for_each(std::next(start), last, [&](auto&& x) { ss << delim << as_number(x); }); } return ss.str(); } diff --git a/src/include/migraphx/tensor_view.hpp b/src/include/migraphx/tensor_view.hpp index a04812bce06..9c5ff978e83 100644 --- a/src/include/migraphx/tensor_view.hpp +++ b/src/include/migraphx/tensor_view.hpp @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include #include 
@@ -36,14 +36,6 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -template -T as_number(T x) -{ - return x; -} -inline int32_t as_number(int8_t x) { return static_cast(x); } -inline uint32_t as_number(uint8_t x) { return static_cast(x); } - template struct tensor_view_iterator_read { diff --git a/src/include/migraphx/tf.hpp b/src/include/migraphx/tf.hpp index 3ffbddbce30..2f8fa536fb8 100644 --- a/src/include/migraphx/tf.hpp +++ b/src/include/migraphx/tf.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -45,6 +45,15 @@ struct tf_options MIGRAPHX_TF_EXPORT program parse_tf(const std::string& name, const tf_options& options = tf_options{}); +/// Create a program from an tf buffer +MIGRAPHX_TF_EXPORT program parse_tf_buffer(const std::string& buffer, + const tf_options& options = tf_options{}); + +/// Create a program from tf buffer +MIGRAPHX_TF_EXPORT program parse_tf_buffer(const void* data, + std::size_t size, + const tf_options& options = tf_options{}); + MIGRAPHX_TF_EXPORT std::vector get_tf_operators(); } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/type_name.hpp b/src/include/migraphx/type_name.hpp index 72a1b53aad8..50396e1a4c6 100644 --- a/src/include/migraphx/type_name.hpp +++ b/src/include/migraphx/type_name.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -33,14 +33,9 @@ inline namespace MIGRAPHX_INLINE_NS { template std::string compute_type_name() { - std::string name; -#if defined(_MSC_VER) && !defined(__clang__) - name = typeid(PrivateMigraphTypeNameProbe).name(); - name = name.substr(7); -#else const char parameter_name[] = "PrivateMigraphTypeNameProbe ="; // NOLINT - name = __PRETTY_FUNCTION__; + std::string name = __PRETTY_FUNCTION__; auto begin = name.find(parameter_name) + sizeof(parameter_name); #if(defined(__GNUC__) && !defined(__clang__) && __GNUC__ == 4 && __GNUC_MINOR__ < 7) @@ -48,9 +43,7 @@ std::string compute_type_name() #else auto length = name.find_first_of("];", begin) - begin; #endif - name = name.substr(begin, length); -#endif - return name; + return name.substr(begin, length); } template diff --git a/src/instruction.cpp b/src/instruction.cpp index 543e7a13fbb..4a5537c32d2 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -400,9 +400,10 @@ void instruction::print(std::ostream& os, // skip return instruction shape if(ins->name() != "@return") os << " -> " << ins->get_shape(); - // print tid - os << ", target_id=" << ins->target_id; + // print tid + if(ins->target_id != 0) + os << ", target_id=" << ins->target_id; } static void debug_name(std::ostream& os, const instruction& ins) diff --git a/src/module.cpp b/src/module.cpp index 8c28c15d22d..fc93193acff 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -663,6 +663,71 @@ std::vector module::get_output_shapes() const } } +std::vector module::compute_shapes(const std::vector& inputs, + compute_shapes_options options) const +{ + auto params = this->get_parameter_names(); + std::sort(params.begin(), params.end()); + std::unordered_map ins_shapes; + std::unordered_map adjusted_param_shapes; + std::transform(inputs.begin(), + inputs.end(), + params.begin(), + std::inserter(adjusted_param_shapes, adjusted_param_shapes.end()), + [](auto ps, auto name) { return std::make_pair(name, ps); }); + for(auto ins : iterator_for(*this)) + { + if(ins->name() == "@param") + { + ins_shapes[ins] = + adjusted_param_shapes[any_cast(ins->get_operator()).parameter]; + if(options.strict_type and ins->get_shape().type() != ins_shapes[ins].type()) + { + MIGRAPHX_THROW(options.name + ": Mismatched type: expected " + + ins->get_shape().type_string() + " but passed " + + ins_shapes[ins].type_string()); + } + if(options.strict_lens and ins->get_shape().lens() != ins_shapes[ins].lens()) + { + MIGRAPHX_THROW(options.name + ": Mismatched lens: expected {" + + to_string_range(ins->get_shape().lens()) + "} but passed {" + + to_string_range(ins_shapes[ins].lens()) + "}"); + } + } + else if(ins->name() == "@literal") + { + if(not options.scalar_const_out_lens.empty() and ins->get_shape().scalar()) + { + 
std::vector strides(options.scalar_const_out_lens.size()); + ins_shapes[ins] = + shape{ins->get_shape().type(), options.scalar_const_out_lens, strides}; + } + else + { + ins_shapes[ins] = ins->get_shape(); + } + } + else + { + std::vector input_shapes; + input_shapes.resize(ins->inputs().size()); + std::transform(ins->inputs().begin(), + ins->inputs().end(), + input_shapes.begin(), + [&](auto in) { return ins_shapes.at(in); }); + if(ins->name() == "@return") + return input_shapes; + ins_shapes[ins] = ins->get_operator().compute_shape(input_shapes); + } + } + MIGRAPHX_THROW("No return found in the submodule"); +} + +std::vector module::compute_shapes(const std::vector& inputs) const +{ + return compute_shapes(inputs, {}); +} + std::vector module::get_returns() const { auto last = std::prev(this->end()); @@ -857,8 +922,7 @@ generic_split(const module& m, instructions2.push_back(ins); } - std::vector inputs2 = select_params(instructions2, param_map); - inputs2.insert(inputs2.begin(), splits.begin(), splits.end()); + std::vector inputs2 = splits; module m2; std::size_t n = 0; std::unordered_map map_ins2; @@ -870,6 +934,7 @@ generic_split(const module& m, continue; if(not contains(instructions2, ins)) continue; + inputs2.push_back(param_map.at(ins)); map_ins2[ins] = m2.add_parameter(param_name(n++), ins->get_shape().as_standard()); } auto r = m2.add_instructions(instructions2, &map_ins2); @@ -991,11 +1056,12 @@ std::unordered_map module::print( const std::unordered_map&)>& print_func, std::unordered_map names) const { + const bool is_root = names.empty(); int count = 0; for(auto ins : iterator_for(*this)) { std::string var_name; - if(not this->name().empty() and this->name() != "main") + if(not this->name().empty() and not is_root) var_name = this->name() + ":"; if(ins->name() == "@param") { @@ -1094,10 +1160,10 @@ static void print_make_op(std::ostream& os, const operation& op) static void print_py_shape(std::ostream& os, const migraphx::shape& s) { - os << 
"migraphx.shape(type=" << to_json_string(s.type_string()) - << ", lens=" << to_json_string(s.lens()); + os << "migraphx.shape(type=" << to_json_string(s.type_string()) << ", lens=[" + << to_string_range(s.lens()) << "]"; if(not s.standard()) - os << ", strides=" << to_json_string(s.strides()); + os << ", strides=[" << to_string_range(s.strides()) << "]"; os << ")"; } @@ -1130,25 +1196,34 @@ module::print_py(std::ostream& os, if(ins->name() == "@literal") { os << mname << ".add_literal("; - const bool use_abs = false; - // Disable abs for now - // ins->get_literal().visit([&](auto v) { - // use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; }); - // }); - if(use_abs) - os << "migraphx.abs_literal("; - os << "migraphx.generate_argument("; - print_py_shape(os, ins->get_shape()); - os << ", " << seed << ")"; - if(use_abs) - os << ")"; + if(ins->get_shape().elements() < 10) + { + os << "migraphx.create_argument("; + print_py_shape(os, ins->get_shape()); + os << ", [" << ins->get_literal() << "])"; + } + else + { + const bool use_abs = false; + // Disable abs for now + // ins->get_literal().visit([&](auto v) { + // use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; }); + // }); + if(use_abs) + os << "migraphx.abs_literal("; + os << "migraphx.generate_argument("; + print_py_shape(os, ins->get_shape()); + os << ", " << seed << ")"; + if(use_abs) + os << ")"; + seed++; + } os << ")" << std::endl; - seed++; } else if(ins->name() == "@param") { std::string name = any_cast(ins->get_operator()).parameter; - os << mname << ".add_parameter(" << enclose_name(name) << ","; + os << mname << ".add_parameter(" << enclose_name(name) << ", "; print_py_shape(os, ins->get_shape()); os << ")" << std::endl; } @@ -1163,7 +1238,9 @@ module::print_py(std::ostream& os, os << mname << ".add_instruction("; print_py_op(os, ins->get_operator()); os << ", [" << join_strings(input_vars, ", ") << "]"; - os << ")" << std::endl; + os << ") # "; + print_py_shape(os, 
ins->get_shape()); + os << std::endl; } }, names); diff --git a/src/onnx/include/migraphx/onnx/onnx_parser.hpp b/src/onnx/include/migraphx/onnx/onnx_parser.hpp index 37c1721ac12..493ad845a8e 100644 --- a/src/onnx/include/migraphx/onnx/onnx_parser.hpp +++ b/src/onnx/include/migraphx/onnx/onnx_parser.hpp @@ -45,6 +45,7 @@ struct onnx_parser { std::string filename; fs::path path; + std::string external_data_path; using attribute_map = std::unordered_map; struct node_info { diff --git a/src/onnx/onnx.cpp b/src/onnx/onnx.cpp index d8e10d0d6f9..a2c3db80b50 100644 --- a/src/onnx/onnx.cpp +++ b/src/onnx/onnx.cpp @@ -41,6 +41,7 @@ template program parse_onnx_from(const onnx_options& options, Ts&&... xs) { onnx::onnx_parser parser; + parser.external_data_path = options.external_data_path; parser.map_input_dims = options.map_input_dims; parser.dim_params = options.dim_params; parser.map_dyn_input_dims = options.map_dyn_input_dims; diff --git a/src/onnx/onnx_parser.cpp b/src/onnx/onnx_parser.cpp index ef00d241947..7bfaaa10213 100644 --- a/src/onnx/onnx_parser.cpp +++ b/src/onnx/onnx_parser.cpp @@ -514,7 +514,15 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const { nbytes = std::stoul(t.external_data().at(2).value()); } - auto raw_buffer = read_buffer(path / data_file, offset, nbytes); + std::vector raw_buffer; + if(not external_data_path.empty()) + { + raw_buffer = read_buffer(fs::path{external_data_path} / data_file, offset, nbytes); + } + else + { + raw_buffer = read_buffer(path / data_file, offset, nbytes); + } std::string s(raw_buffer.begin(), raw_buffer.end()); return create_literal(type, dims, s.data()); } diff --git a/src/onnx/parse_convolution.cpp b/src/onnx/parse_convolution.cpp index 155faf10b93..f0791222bd4 100644 --- a/src/onnx/parse_convolution.cpp +++ b/src/onnx/parse_convolution.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. 
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -30,6 +30,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -42,18 +43,199 @@ struct parse_convolution : op_parser return {{"Conv", "convolution"}, {"ConvInteger", "quant_convolution"}}; } + // Convert to half prior to a shift to ensure we preserve accuracy here then + // convert back to int8 + static instruction_ref add_int8_shift(const onnx_parser::node_info& info, + const instruction_ref& offset_op, + instruction_ref& unshifted_input) + { + auto unshifted_input_half = info.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), + unshifted_input); + + auto input_shifted_half = info.add_common_op("add", unshifted_input_half, offset_op); + + return info.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int8_type}}), + input_shifted_half); + } + + static void shift_input_and_bias(const onnx_parser::node_info& info, + const instruction_ref& offset_op, + const bool has_bias, + instruction_ref& input, + instruction_ref& input_bias) + { + input = add_int8_shift(info, offset_op, input); + if(has_bias) + { + input_bias = add_int8_shift(info, offset_op, input_bias); + } + } + + static float get_symmetric_value(const instruction_ref& input) + { + float symmetric_value = 0; + // adjust symmetric zero point value for uint8 types + if(input->get_shape().type() == migraphx::shape::uint8_type) + { + symmetric_value = 128; + } + return symmetric_value; + } + + static instruction_ref gen_symmetric_literal(const instruction_ref& input, + const bool is_quant_conv, + onnx_parser::node_info& info) + { + instruction_ref ret = input; + if(is_quant_conv) + { + float symmetric_value = get_symmetric_value(input); + ret = 
info.add_literal(migraphx::literal{ + migraphx::shape{input->get_shape().type(), {1}, {0}}, {symmetric_value}}); + } + + return ret; + } + + static instruction_ref get_zero_point(const instruction_ref& input, + int index, + const bool is_quant_conv, + onnx_parser::node_info& info, + const std::vector& args) + { + instruction_ref ret = input; + if(args.size() > index) + { + // Check for type mismatch on parse + if(input->get_shape().type() != args[index]->get_shape().type()) + MIGRAPHX_THROW("PARSE:Conv Data and Data Zero Point must have same type"); + + ret = args[index]; + if(is_symmetric_zero_point(ret)) + { + ret = gen_symmetric_literal(ret, is_quant_conv, info); + } + } + else + { + ret = gen_symmetric_literal(ret, is_quant_conv, info); + } + + return ret; + } + + static bool is_symmetric_zero_point(instruction_ref zp) + { + if(not zp->can_eval()) + return false; + + float symmetric_value = get_symmetric_value(zp); + + bool all_zeros = false; + zp->eval().visit([&](auto z) { + all_zeros = std::all_of( + z.begin(), z.end(), [&](auto val) { return float_equal(val, symmetric_value); }); + }); + return all_zeros; + } + + static auto + qparam_broadcast_op(instruction_ref qparam, std::vector lens, std::size_t axis) + { + if(qparam->get_shape().scalar()) + { + return migraphx::make_op("multibroadcast", {{"out_lens", lens}}); + } + else + { + return migraphx::make_op("broadcast", {{"out_lens", lens}, {"axis", axis}}); + } + } + + static instruction_ref handle_quant_bias(const operation& op, + const instruction_ref& input, + const instruction_ref& x, + const instruction_ref& weights, + const instruction_ref& x_zp, + const instruction_ref& w_zp, + onnx_parser::node_info& info) + { + instruction_ref ret = input; + if(not is_symmetric_zero_point(x_zp)) + { + auto out_zp_1 = info.add_common_op(op.name(), x_zp, weights); + ret = info.add_common_op("sub", ret, out_zp_1); + } + + if(not is_symmetric_zero_point(w_zp)) + { + auto out_zp_2 = info.add_common_op(op.name(), x, 
w_zp); + ret = info.add_common_op("sub", ret, out_zp_2); + } + + if(not(is_symmetric_zero_point(x_zp)) and not(is_symmetric_zero_point(w_zp))) + { + auto x_zp_bc = + info.add_instruction(qparam_broadcast_op(x_zp, x->get_shape().lens(), 0), x_zp); + auto w_zp_bc = info.add_instruction( + qparam_broadcast_op(w_zp, weights->get_shape().lens(), 0), w_zp); + + auto out_zp_3 = info.add_instruction(op, x_zp_bc, w_zp_bc); + + ret = info.add_common_op("add", ret, out_zp_3); + } + return ret; + } + + static void handle_quant_inputs(const bool is_quant_conv, + instruction_ref& input, + instruction_ref& weights, + instruction_ref& input_zp, + instruction_ref& weight_zp, + onnx_parser::node_info& info) + { + if(not is_quant_conv) + return; + + auto input_type = input->get_shape().type(); + auto weight_type = weights->get_shape().type(); + + // Handle uint8 bias and input shifts + instruction_ref offset_op; + if(((input_type == migraphx::shape::uint8_type) or + (weight_type == migraphx::shape::uint8_type))) + { + offset_op = info.add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {-128}}); + } + + if(input_type == migraphx::shape::uint8_type) + { + shift_input_and_bias( + info, offset_op, (not is_symmetric_zero_point(input_zp)), input, input_zp); + } + + if(weight_type == migraphx::shape::uint8_type) + { + shift_input_and_bias( + info, offset_op, (not is_symmetric_zero_point(weight_zp)), weights, weight_zp); + } + } + instruction_ref parse(const op_desc& opd, const onnx_parser& parser, onnx_parser::node_info info, std::vector args) const { - auto op = make_op(opd.op_name); - auto values = op.to_value(); - auto l0 = args[0]; - auto weights = args[1]; - auto l0_shape = l0->get_shape(); - auto w_shape = weights->get_shape(); - auto in_lens = l0_shape.max_lens(); + auto op = make_op(opd.op_name); + auto values = op.to_value(); + auto x = args[0]; + auto weights = args[1]; + auto x_shape = x->get_shape(); + auto w_shape = weights->get_shape(); + auto 
in_lens = x_shape.max_lens(); assert(in_lens.size() > 2); auto kdims = in_lens.size() - 2; @@ -92,9 +274,9 @@ struct parse_convolution : op_parser // check if image shape is dynamic bool image_shape_dynamic = false; - if(l0_shape.dynamic()) + if(x_shape.dynamic()) { - auto dyn_dims = l0_shape.dyn_dims(); + auto dyn_dims = x_shape.dyn_dims(); std::for_each(dyn_dims.begin() + 2, dyn_dims.end(), [&](auto dyn_dim) { if(not dyn_dim.is_fixed()) { @@ -149,9 +331,31 @@ struct parse_convolution : op_parser recalc_conv_attributes(values, kdims); + instruction_ref ret; + // parse a_zero_point and b_zero_point values + auto is_quant_conv = opd.op_name == "quant_convolution"; + + auto x_zp = get_zero_point(x, 2, is_quant_conv, info, args); + auto w_zp = get_zero_point(weights, 3, is_quant_conv, info, args); + op.from_value(values); - auto l1 = info.add_instruction(op, l0, args[1]); - return info.add_bias(args, l1, 1); + + handle_quant_inputs(is_quant_conv, x, weights, x_zp, w_zp, info); + + ret = info.add_instruction(op, x, weights); + + // Handle quant_conv residuals between input/weights to avoid overflow + if(is_quant_conv) + { + ret = handle_quant_bias(op, ret, x, weights, x_zp, w_zp, info); + } + else + { + // Handle Convolution case with bias to output + ret = info.add_bias(args, ret, 1); + } + + return ret; } }; diff --git a/src/onnx/parse_expand.cpp b/src/onnx/parse_expand.cpp index 0762fbb6243..e468cad1005 100644 --- a/src/onnx/parse_expand.cpp +++ b/src/onnx/parse_expand.cpp @@ -41,7 +41,6 @@ struct parse_expand : op_parser const onnx_parser::node_info& info, std::vector args) const { - auto in_lens = args[0]->get_shape().lens(); migraphx::argument arg_s = args[1]->eval(); if(arg_s.empty()) { @@ -50,6 +49,13 @@ struct parse_expand : op_parser } else { + const shape& shape_0 = args[0]->get_shape(); + if(shape_0.dynamic()) + { + MIGRAPHX_THROW( + "PARSE_EXPAND: dynamic input tensor with fixed dims input not supported"); + } + const auto& in_lens = shape_0.lens(); 
std::vector dims; arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); }); auto out_lens = compute_broadcasted_lens(in_lens, dims); diff --git a/src/param_utils.cpp b/src/param_utils.cpp index 61302a0afba..20c1ad8b0e2 100644 --- a/src/param_utils.cpp +++ b/src/param_utils.cpp @@ -25,14 +25,20 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { std::string param_name(std::size_t i, const std::string& prefix) { - assert(i < 10); - return prefix + std::to_string(i); + if(i < 10) + return prefix + std::to_string(i); + const std::size_t max_digits = 5; + if(i >= std::pow(10, max_digits)) + MIGRAPHX_THROW("Too many parameters."); + std::size_t n = log10(i) + 1; + return prefix + ":" + std::string(max_digits - n, '0') + std::to_string(i); } void sort_params(std::vector& params) diff --git a/src/pass_manager.cpp b/src/pass_manager.cpp index 3748e094773..af8f0e4d6a7 100644 --- a/src/pass_manager.cpp +++ b/src/pass_manager.cpp @@ -105,6 +105,15 @@ struct module_pm : module_pass_manager return prog->create_module(name, std::move(m)); } + virtual void rename_module(const std::string& old_name, const std::string& new_name) override + { + assert(prog); + assert(mod); + assert( + any_of(mod->get_sub_modules(), [&](module_ref sm) { return sm->name() == old_name; })); + prog->rename_module(old_name, new_name); + } + virtual module* get_common_parent() override { return common_parent; } virtual module* get_root_module() override diff --git a/src/program.cpp b/src/program.cpp index fb24685fcbc..ddccd40eed3 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -523,6 +523,13 @@ std::vector generic_eval(const program& p, return generic_eval(mm, ctx, params, {}, trace); } +std::vector program::eval_with_context(std::vector& ctx, + parameter_map params) const +{ + const module* mm = this->get_main_module(); + return generic_eval(mm, ctx, std::move(params), {}, [](auto&&, auto f) { return f(); }); +} + std::vector 
program::eval(parameter_map params, execution_environment exec_env) const { auto& contexts = this->impl->contexts; @@ -1074,10 +1081,12 @@ const module* program::get_module(const std::string& name) const { return &impl- module* program::create_module(const std::string& name) { + assert(not contains(impl->modules, name)); auto r = impl->modules.emplace(name, name); return &(r.first->second); } + module* program::create_module(const std::string& name, module m) { assert(not contains(impl->modules, name)); @@ -1202,6 +1211,17 @@ void program::remove_module(const std::string& name) impl->modules.erase(name); } +void program::rename_module(const std::string& old_name, const std::string& new_name) +{ + assert(old_name != new_name); + assert(contains(impl->modules, old_name)); + assert(not contains(impl->modules, new_name)); + auto node = impl->modules.extract(old_name); + node.key() = new_name; + node.mapped().set_name(new_name); + impl->modules.insert(std::move(node)); +} + void program::remove_unused_modules() { std::vector unused; diff --git a/src/propagate_constant.cpp b/src/propagate_constant.cpp index c917d29b3ed..3a6856d9450 100644 --- a/src/propagate_constant.cpp +++ b/src/propagate_constant.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #include namespace migraphx { @@ -83,7 +84,12 @@ void propagate_constant::apply(module& m) const // Compute literals in parallel std::vector const_instrs_vec{const_instrs.begin(), const_instrs.end()}; std::vector literals(const_instrs_vec.size()); - simple_par_for(const_instrs_vec.size(), 1, [&](const auto i) { + std::size_t grainsize = 1; +#if !MIGRAPHX_HAS_EXECUTORS + std::size_t n = std::max(2048 / std::thread::hardware_concurrency(), 1); + grainsize = const_instrs_vec.size() / n; +#endif + simple_par_for(const_instrs_vec.size(), grainsize, [&](const auto i) { literals[i] = const_instrs_vec[i]->eval(); }); diff --git a/src/register_target.cpp b/src/register_target.cpp index 76b9488a8a0..241b080c0e6 100644 --- 
a/src/register_target.cpp +++ b/src/register_target.cpp @@ -58,22 +58,23 @@ target make_target(const std::string& name) { if(not contains(target_map(), name)) { - std::string so_version = "." + std::to_string(MIGRAPHX_SO_MAJOR_VERSION) + ".0"; - auto target_name = make_shared_object_filename("migraphx_" + name); + std::string so_major_version = "." + std::to_string(MIGRAPHX_SO_MAJOR_VERSION); + auto target_name = make_shared_object_filename("migraphx_" + name); - // Try to load library with so version appended to the name. - // If library with so version name is not found, - // try loading the library without the so version name appended. - // For example, if "libmigraphx_ref.so.2010000.0" is not found, + // Try to load library with so_major_version appended to the name. + // If library with so_major_version name is not found, + // try loading the library without the so_major_version name appended. + // For example, if "libmigraphx_ref.so.2010000" is not found, // try loading "libmigraphx_ref.so". try { - // Default to loading shared libraries with so version appended. - store_target_lib(dynamic_loader(target_name + so_version)); + // Default to loading shared libraries with + // so_major_version appended. + store_target_lib(dynamic_loader(target_name + so_major_version)); } catch(...) { - // Load the library without the so version in the name. + // Load the library without the so_major_version in the name. 
store_target_lib(dynamic_loader(target_name)); } } diff --git a/src/rewrite_reduce.cpp b/src/rewrite_reduce.cpp index c90f0cafe95..80b2fc7765b 100644 --- a/src/rewrite_reduce.cpp +++ b/src/rewrite_reduce.cpp @@ -55,44 +55,87 @@ struct find_softmax } }; +struct find_reduce_mean_variance +{ + auto matcher() const + { + auto reduce_mean = match::name("reduce_mean"); + auto x_minus_mean = + match::name("sub")(match::arg(0)(match::any().bind("x")), + match::arg(1)(match::skip_broadcasts(reduce_mean.bind("mean")))); + auto pow_x_minus_mean = + match::name("pow")(match::arg(0)(x_minus_mean), match::arg(1)(match::has_value(2.0f))); + auto mul_x_minus_mean = + match::name("mul")(match::arg(0)(x_minus_mean), match::arg(1)(x_minus_mean)); + auto sqdiff = match::name("sqdiff")(match::either_arg(0, 1)( + match::any().bind("x"), skip_broadcasts(reduce_mean.bind("mean")))); + return reduce_mean( + match::arg(0)(match::any_of(pow_x_minus_mean, mul_x_minus_mean, sqdiff))); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + auto x_ins = r.instructions["x"]; + auto mean = r.instructions["mean"]; + + if(ins->get_operator() != mean->get_operator()) + return; + + if(mean->inputs().front() != x_ins) + return; + + auto x2 = m.insert_instruction(ins, make_op("mul"), x_ins, x_ins); + auto mean_x2 = m.insert_instruction(ins, mean->get_operator(), x2); + auto mean_x_2 = m.insert_instruction(ins, make_op("mul"), mean, mean); + m.replace_instruction(ins, make_op("sub"), mean_x2, mean_x_2); + } +}; + struct find_reduce_mean { auto matcher() const { return match::name("reduce_mean"); } void apply(module& m, const match::matcher_result& r) const { - auto reduce_mean = r.result; - auto op = reduce_mean->get_operator().to_value(); - auto axes = op["axes"].to_vector(); - auto input = reduce_mean->inputs().front(); + auto ins = r.result; + auto op = ins->get_operator().to_value(); + auto axes = op["axes"].to_vector(); + auto input = ins->inputs().front(); 
bool is_integral = false; double max_n = 0; + std::size_t size = 0; input->get_shape().visit_type([&](auto t) { is_integral = t.is_integral(); max_n = t.max(); + size = t.size(); }); - auto n = input->get_shape().elements() / reduce_mean->get_shape().elements(); + auto n = input->get_shape().elements() / ins->get_shape().elements(); - // avoid overflow (the larger value will be later handled) - if(n >= max_n / 4) - return; + if(n >= max_n / 4 and size < 3) + { + shape::type_t t = is_integral ? shape::int32_type : shape::float_type; + input = m.insert_instruction(ins, make_op("convert", {{"target_type", t}}), input); + } auto n_literal = m.add_literal(literal{{input->get_shape().type(), {1}}, {n}}); if(is_integral) { auto reduce_sum = - m.insert_instruction(reduce_mean, make_op("reduce_sum", {{"axes", axes}}), input); - auto div = insert_common_op(m, reduce_mean, make_op("div"), {reduce_sum, n_literal}); - m.replace_instruction(reduce_mean, div); + m.insert_instruction(ins, make_op("reduce_sum", {{"axes", axes}}), input); + auto div = insert_common_op(m, ins, make_op("div"), {reduce_sum, n_literal}); + m.replace_instruction( + ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), div); } else { - auto new_input = insert_common_op(m, reduce_mean, make_op("div"), {input, n_literal}); - auto reduce_sum = m.insert_instruction( - reduce_mean, make_op("reduce_sum", {{"axes", axes}}), new_input); - m.replace_instruction(reduce_mean, reduce_sum); + auto new_input = insert_common_op(m, ins, make_op("div"), {input, n_literal}); + auto reduce_sum = + m.insert_instruction(ins, make_op("reduce_sum", {{"axes", axes}}), new_input); + m.replace_instruction( + ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), reduce_sum); } } }; @@ -101,7 +144,8 @@ struct find_reduce_mean void rewrite_reduce::apply(module& m) const { - match::find_matches(m, find_softmax{}, find_reduce_mean{}); + match::find_matches(m, find_softmax{}, find_reduce_mean_variance{}); + 
match::find_matches(m, find_reduce_mean{}); } } // namespace MIGRAPHX_INLINE_NS diff --git a/src/shape.cpp b/src/shape.cpp index 79294debcb1..073089c2cd2 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -728,6 +728,24 @@ shape::type_t shape::parse_type(const std::string& s) const std::vector& shape::sub_shapes() const { return impl->m_shapes; } +std::vector flatten(const std::vector& shapes) +{ + std::vector result; + for(const auto& s : shapes) + { + if(s.type() == shape::tuple_type) + { + auto subs = flatten(s.sub_shapes()); + result.insert(result.end(), subs.begin(), subs.end()); + } + else + { + result.push_back(s); + } + } + return result; +} + void migraphx_to_value(value& v, const shape& s) { value result; diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index 4e79caefd90..16df584593f 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -482,23 +482,174 @@ struct find_double_add_lit_broadcast } }; +/// Find elementswise operators that have all broadcast inputs. It then +/// rewrites the elementwise to do the computation on the non-broadcasted +/// axes, and then broadcast that result. 
struct find_inner_broadcast { auto matcher() const { return pointwise(match::all_of[match::inputs()](match::broadcast())); } - static auto non_scalar_op(const std::string& name) + static auto get_non_broadcast_input(instruction_ref ins) { - return [=](instruction_ref ins) { - if(ins->get_shape().scalar()) - return false; - return ins->name() == name; - }; + if(ins->inputs().size() != 1) + return ins; + auto input = ins->inputs().front(); + if(contains(input->name(), "broadcast")) + return get_non_broadcast_input(input); + return input; + } + + static bool is_unsqueeze_needed_for_multibroadcast(const shape& input, const shape& output) + { + if(input.elements() == 1) + return false; + auto shift = output.ndim() - input.ndim(); + if(shift == 0) + return false; + if(std::equal(input.lens().begin(), + input.lens().end(), + output.lens().begin() + shift, + output.lens().end())) + { + return std::all_of(output.lens().begin(), output.lens().begin() + shift, [](auto x) { + return x == 1; + }); + } + return true; + } + // Simple case + void apply_same_broadcasts(module& m, instruction_ref ins) const + { + const auto& broadcasts = ins->inputs(); + // Scalars can have different ndim, so find the largest ndim input + auto max_broadcast = *std::max_element( + broadcasts.begin(), broadcasts.end(), by(std::less<>{}, [](instruction_ref broadcast) { + return get_non_broadcast_input(broadcast)->get_shape().ndim(); + })); + auto max_ndim = max_broadcast->get_shape().ndim(); + std::vector inputs; + std::transform(broadcasts.begin(), + broadcasts.end(), + std::back_inserter(inputs), + [&](instruction_ref broadcast) { + auto input = get_non_broadcast_input(broadcast); + auto s = input->get_shape(); + // If scalar doesnt match the other input dims then add a squeeze + if(s.elements() == 1 and s.ndim() > 1 and s.ndim() != max_ndim) + return m.insert_instruction(broadcast, make_op("squeeze"), input); + return input; + }); + auto op = insert_common_op(m, ins, ins->get_operator(), inputs); + 
m.replace_instruction(ins, broadcasts.front()->get_operator(), op); + } + + void apply_diff_broadcasts(module& m, instruction_ref ins) const + { + const auto& broadcasts = ins->inputs(); + auto ndim = ins->get_shape().ndim(); + // Compute the inner dimensions and axes that the computation will + // use. Also compute the axes that will be broadcasted + std::vector idims; + std::vector iaxes; + std::vector axes; + for(auto axis : range(ndim)) + { + if(std::all_of(broadcasts.begin(), broadcasts.end(), [&](instruction_ref i) { + auto s = i->get_shape(); + return s.lens()[axis] == 1 or s.strides()[axis] == 0; + })) + { + axes.push_back(axis); + } + else + { + iaxes.push_back(axis); + idims.push_back(ins->get_shape().lens()[axis]); + } + } + // If the inner axes are the same as the original operator then + // there is no reason to do this transformation. + if(iaxes.size() == ndim) + return; + std::vector inputs; + std::transform( + broadcasts.begin(), + broadcasts.end(), + std::back_inserter(inputs), + [&](instruction_ref broadcast) { + auto input = broadcast->inputs().front(); + auto s = input->get_shape(); + + // If its a single element then just return that as an input + if(s.elements() == 1) + { + if(s.lens().size() > 1) + return m.insert_instruction(broadcast, make_op("squeeze"), input); + return input; + } + + // Find how the axes are shifted from the broadcast + std::int64_t shift = ndim - s.ndim(); + if(broadcast->name() == "broadcast") + shift = broadcast->get_operator().to_value()["axis"].to(); + // Compute the squeeze axes to be used by taking the inner + // axes and shifting to what the axes will be on the + // input + std::vector sq_axes; + for(auto axis : axes) + { + auto iaxis = axis - shift; + if(iaxis < 0) + continue; + if(iaxis >= s.ndim()) + continue; + sq_axes.push_back(iaxis); + } + instruction_ref result = input; + if(not sq_axes.empty()) + result = m.insert_instruction( + broadcast, make_op("squeeze", {{"axes", sq_axes}}), result); + // If the 
number of dimension are still smaller than the + // number of inner axes, then we need to insert a + // broadcast to have the same dimensions for all inputs. + if(result->get_shape().ndim() < iaxes.size()) + { + // We find the first inner axis that can be mapped to the input + auto start_axis = std::find_if(iaxes.begin(), + iaxes.end(), + [&](auto x) { return x >= shift; }) - + iaxes.begin(); + result = m.insert_instruction( + broadcast, + make_op("broadcast", {{"axis", start_axis}, {"out_lens", idims}}), + result); + } + return result; + }); + auto op = insert_common_op(m, ins, ins->get_operator(), inputs); + if(iaxes.size() == 1) + { + m.replace_instruction( + ins, + make_op("broadcast", + {{"axis", iaxes.front()}, {"out_lens", ins->get_shape().lens()}}), + op); + } + else + { + auto unsqueeze = + is_unsqueeze_needed_for_multibroadcast(op->get_shape(), ins->get_shape()) + ? m.insert_instruction(ins, make_op("unsqueeze", {{"axes", axes}}), op) + : op; + m.replace_instruction( + ins, make_op("multibroadcast", {{"out_lens", ins->get_shape().lens()}}), unsqueeze); + } } void apply(module& m, const match::matcher_result& r) const { - auto ins = r.result; - auto broadcasts = ins->inputs(); + auto ins = r.result; + const auto& broadcasts = ins->inputs(); if(broadcasts.empty()) return; // Skip if different data types are used @@ -506,65 +657,45 @@ struct find_inner_broadcast return i->get_shape().type() != broadcasts.front()->get_shape().type(); })) return; - bool mixed_broadcasts = any_of(broadcasts, non_scalar_op("broadcast")) and - any_of(broadcasts, non_scalar_op("multibroadcast")); - // If the broadcast is not a single dimension, then dont perform inner_broadcast - if(mixed_broadcasts and any_of(broadcasts, [&](instruction_ref i) { - if(i->get_shape().scalar()) - return false; - if(i->name() == "multibroadcast") - return false; - auto input = i->inputs().at(0); - const auto& lens = input->get_shape().lens(); - return std::count_if(lens.begin(), lens.end(), 
[&](std::size_t d) { - return d == 1; - }) < (lens.size() - 1); + + // All inputs should have less elements + if(not all_of(broadcasts, [&](instruction_ref broadcast) { + auto input = broadcast->inputs().front(); + return input->get_shape().elements() < ins->get_shape().elements(); })) return; - if(broadcasts.size() > 1) + // Find first broadcast that is not a scalar + auto first = + std::find_if(broadcasts.begin(), broadcasts.end(), [&](instruction_ref broadcast) { + return not broadcast->get_shape().scalar(); + }); + // Try to see if we can do a simple case that just applies the op to + // the inputs of the broadcasts, and then just put that same + // broadcast after the op. For this case we need each of the + // broadcasts to be the same and the inputs to have the same dimesion + // (or be scalar). + const bool same_broadcasts = + std::all_of(first, broadcasts.end(), [&](instruction_ref broadcast) { + if(broadcast->get_operator() != (*first)->get_operator()) + return false; + auto s1 = get_non_broadcast_input(broadcast)->get_shape(); + auto s2 = get_non_broadcast_input(*first)->get_shape(); + if(s1.elements() == 1) + return true; + return s1.lens() == s2.lens(); + }); + if(same_broadcasts) { - auto bcast_strides = broadcasts.front()->get_shape().strides().size(); - std::vector common_axis(bcast_strides, 0); - // go through the strides of each broadcast, - // keep track of values that are equal to 0 in a dimension - for(auto i = 0; i < bcast_strides; i++) - { - for(const auto& broadcast : broadcasts) - { - if(broadcast->get_shape().strides()[i] == 0) - common_axis[i]++; - } - } - // if no common broadcast axis, transformation is not useful - if(std::find_if(common_axis.begin(), common_axis.end(), [](auto num_common) { - return num_common > 1; - }) == common_axis.end()) - return; + apply_same_broadcasts(m, ins); + } + // Skip if any input to the broadcasted inputs is already broadcasted + // as the below algorithm may not be able to handle such case. 
+ else if(std::none_of(broadcasts.begin(), broadcasts.end(), [](instruction_ref broadcast) { + return broadcast->inputs().front()->get_shape().broadcasted(); + })) + { + apply_diff_broadcasts(m, ins); } - - std::vector inputs; - std::transform(broadcasts.begin(), - broadcasts.end(), - std::back_inserter(inputs), - [&](instruction_ref i) { - auto input = i->inputs().front(); - if(mixed_broadcasts and not i->get_shape().scalar() and - i->get_shape().lens().size() > 1) - return m.insert_instruction(i, make_op("squeeze"), input); - return input; - }); - - std::sort(broadcasts.begin(), broadcasts.end(), by(std::less<>{}, [](instruction_ref i) { - if(i->get_shape().scalar()) - return 2; - else if(i->name() == "broadcast") - return 0; - if(i->name() == "multibroadcast") - return 1; - return 3; - })); - auto op = insert_common_op(m, ins, ins->get_operator(), inputs); - m.replace_instruction(ins, broadcasts.front()->get_operator(), op); } }; @@ -792,6 +923,9 @@ void move_instructions_back(module& m, instruction_ref pos, std::vector get_splits(instruction_ref ins) { std::vector result; @@ -803,16 +937,22 @@ std::vector get_splits(instruction_ref ins) return {}; auto get_slice = [](auto& i) -> auto& { return any_cast(i->get_operator()); }; auto&& axes = get_slice(result.front()).axes; + + // "slice" instructions must all have the same axes if(std::any_of(result.begin(), result.end(), [&](auto i) { return get_slice(i).axes != axes; })) return {}; auto get_start = [&](auto& i) -> auto& { return get_slice(i).starts; }; auto get_end = [&](auto& i) -> auto& { return get_slice(i).ends; }; + + // Sort the "slice" instructions in order of starts std::sort( result.begin(), result.end(), [&](auto x, auto y) { return get_start(x) < get_start(y); }); if(std::any_of(get_start(result.front()).begin(), get_start(result.front()).end(), [&](auto i) { return i != 0; })) return {}; + + // one slice must "start" where the last slice "end" auto it = std::adjacent_find( result.begin(), 
result.end(), [&](auto x, auto y) { return get_end(x) != get_start(y); }); if(it != result.end()) @@ -998,6 +1138,10 @@ struct find_splits } }; +/** + * Matcher for a sequence of "slice" operations whose outputs are put back + * together by a "concat". + */ struct find_split_concat { auto matcher() const @@ -1008,40 +1152,56 @@ struct find_split_concat void apply(module& m, const match::matcher_result& r) const { + // Verifies that the slices meet several conditions: they must all output to the same + // concat instruction, slice on the same (1 only) axis, and the end of one slice + // must match the start of the next. auto ins = r.result; auto splits = get_splits(ins); if(splits.empty()) return; + // Each slice must output to only one instruction if(std::any_of( splits.begin(), splits.end(), [](auto i) { return i->outputs().size() != 1; })) return; - // Check for concat operator + // The single output instruction for all items in the list must be the same one auto concat = splits.front()->outputs().front(); if(std::any_of(splits.begin(), splits.end(), [&](auto i) { return i->outputs().front() != concat; })) return; - // Check axis match + + // The axis for the common output instruction must be the same as for the split ops auto concat_op = any_cast(concat->get_operator()); auto split_op = any_cast(splits.front()->get_operator()); if(split_op.axes.size() != 1) return; if(split_op.axes.front() != concat_op.axis) return; - // Replace args + + // Find where the slices are in the concat instruction's inputs (concat can have + // any number of inputs) auto args = concat->inputs(); auto it = std::find_if(args.begin(), args.end(), [&](auto i) { return i == splits.front(); }); + // Verify the slices were found, and the list is long enough if(std::distance(it, args.end()) < splits.size()) return; - // If the slices are not in order then stop + // Don't do anything if the "slice" inputs to the concat op have other operations mixed in + // among them + if(std::any_of(it, it + 
splits.size(), [](instruction_ref x) { + return x->get_operator().name() != "slice"; + })) + return; + // Check that the slices passed to concat are in order. if(not std::is_sorted(it, it + splits.size(), [](instruction_ref x, instruction_ref y) { auto xop = any_cast(x->get_operator()); auto yop = any_cast(y->get_operator()); return std::tie(xop.starts, xop.ends) < std::tie(yop.starts, yop.ends); })) return; + + // Perform the substitution *it = splits.front()->inputs().front(); args.erase(std::next(it), it + splits.size()); diff --git a/src/simplify_qdq.cpp b/src/simplify_qdq.cpp index 21c69834811..d22876a274b 100644 --- a/src/simplify_qdq.cpp +++ b/src/simplify_qdq.cpp @@ -43,8 +43,8 @@ namespace { template auto skip_post_dq_ops(Ms... ms) { - return match::skip( - match::name("broadcast", "multibroadcast", "contiguous", "transpose", "reshape"))(ms...); + return match::skip(match::name( + "broadcast", "multibroadcast", "contiguous", "transpose", "reshape", "convert"))(ms...); } std::unordered_set get_quantizable_op_names() @@ -90,8 +90,10 @@ struct match_find_quantizable_ops // Helper function to insert quantized versions of any broadcasts and transpose ops that // occur between dequantizelinear and the quantized op - static auto - propagate_quantized_ins(module& m, const instruction_ref dqins, const instruction_ref qop_arg) + static auto propagate_quantized_ins(module& m, + const instruction_ref dqins, + const instruction_ref qop_arg, + bool is_fp16_model = false) { auto prev_ins = qop_arg; std::vector ins_inbetween; @@ -105,6 +107,10 @@ struct match_find_quantizable_ops auto qinp = dqins->inputs().front(); for(auto ins : reverse_iterator_for(ins_inbetween)) { + if((*ins)->name() == "convert" and is_fp16_model) + { + continue; + } qinp = m.insert_instruction(dqins, (*ins)->get_operator(), {qinp}); } return qinp; @@ -143,8 +149,15 @@ struct match_find_quantizable_ops // Propagate q1 and q2 through any broadcasts and transposes before qop auto qop_args = 
qop->inputs(); - qop_args.at(0) = propagate_quantized_ins(m, dq1, qop_args[0]); - qop_args.at(1) = propagate_quantized_ins(m, dq2, qop_args[1]); + bool is_fp16_model = false; + if(dq1->get_shape().type() != qop->get_shape().type() and + qop->get_shape().type() == migraphx::shape::half_type) + { + assert(dq1->get_shape().type() == migraphx::shape::float_type); + is_fp16_model = true; + } + qop_args.at(0) = propagate_quantized_ins(m, dq1, qop_args[0], is_fp16_model); + qop_args.at(1) = propagate_quantized_ins(m, dq2, qop_args[1], is_fp16_model); auto arg1_lens = qop_args[0]->get_shape().lens(); auto arg2_lens = qop_args[1]->get_shape().lens(); instruction_ref dq; @@ -260,6 +273,11 @@ struct match_find_quantizable_ops } dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, out_scale, out_zp); + if(is_fp16_model) + { + dq = m.insert_instruction( + qop, make_op("convert", {{"target_type", migraphx::shape::half_type}}), dq); + } m.replace_instruction(qop, dq); } }; diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 099a7cd35a1..a0a952d6aac 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -120,6 +120,7 @@ struct find_nop_reshapes reshapes.insert("multibroadcast"); reshapes.insert("pad"); reshapes.insert("slice"); + reshapes.insert("step"); reshapes.insert("transpose"); reshapes.insert("reduce_mean"); reshapes.insert("reduce_max"); @@ -243,6 +244,21 @@ struct find_nested_slice } }; +/** + * Example case + * From: + * param0: lens = [3, 4], strides = [4, 1] + * param1: lens = [3, 4], strides = [4, 1] + * mb0: multibroadcast(param0, output_lens = [2, 3, 4]) + * mb1: multibroadcast(param1, output_lens = [2, 3, 4]) + * concat(mb0, mb1, axis = 2) + * + * To: + * param0: lens = [3, 4], strides = [4, 1] + * param1: lens = [3, 4], strides = [4, 1] + * con0: concat(param0, param1, axis = 1) + * multibroadcast(con0, lens = [2, 3, 4]) + */ struct find_concat_multibroadcasts { auto matcher() const @@ -252,32 +268,62 @@ struct 
find_concat_multibroadcasts void apply(module& m, const match::matcher_result& mr) const { - auto ins = mr.result; - auto op = any_cast(ins->get_operator()); - auto out_lens = ins->get_shape().lens(); - auto inputs = ins->inputs(); - auto in_strides = inputs.front()->get_shape().strides(); + auto concat_ins = mr.result; + auto concat_op = any_cast(concat_ins->get_operator()); + auto concat_out_lens = concat_ins->get_shape().lens(); + auto concat_inputs = concat_ins->inputs(); + auto front_mb_strides = concat_inputs.front()->get_shape().strides(); + assert(concat_op.axis >= 0); // Only apply when concat axis is not a broadcasted dimension - if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) { - return i->get_shape().strides()[op.axis] == 0; + if(std::any_of(concat_inputs.begin(), concat_inputs.end(), [&](auto i) { + return i->get_shape().strides()[concat_op.axis] == 0; })) { return; } - // Use inputs of multibroadcast ops as inputs to new concat op - std::transform(inputs.begin(), inputs.end(), inputs.begin(), [](auto i) { + // Get the inputs of multibroadcast ops. 
Will be used as inputs to new concat op + std::vector mb_inputs(concat_inputs.size()); + std::transform(concat_inputs.begin(), concat_inputs.end(), mb_inputs.begin(), [](auto i) { return i->inputs().front(); }); + // Check that the inputs into the multibroadcasts have the same rank + const auto& first_shape = mb_inputs.front()->get_shape(); + if(not std::all_of(mb_inputs.begin() + 1, mb_inputs.end(), [&](auto mb_in) { + return mb_in->get_shape().ndim() == first_shape.ndim(); + })) + { + return; + } + // Reduce axis by number of leading broadcasted dimensions - if(inputs.front()->get_shape().lens().size() < out_lens.size()) - op.axis -= std::count(in_strides.begin(), in_strides.begin() + op.axis, 0); + if(mb_inputs.front()->get_shape().lens().size() < concat_out_lens.size()) + { + concat_op.axis -= + std::count(front_mb_strides.begin(), front_mb_strides.begin() + concat_op.axis, 0); + } - auto concat = m.insert_instruction(ins, op, inputs); - m.replace_instruction( - ins, migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), concat); + // Inputs to multibroadcasts should have the same dimensions except for the axis to + // concatenate over + const auto& front_in_lens = mb_inputs.front()->get_shape().lens(); + if(not std::all_of(mb_inputs.begin() + 1, mb_inputs.end(), [&](auto input_to_mb) { + const auto& lens = input_to_mb->get_shape().lens(); + return std::equal( + lens.begin(), lens.begin() + concat_op.axis, front_in_lens.begin()) and + std::equal(lens.begin() + concat_op.axis + 1, + lens.end(), + front_in_lens.begin() + concat_op.axis + 1); + })) + { + return; + } + + auto new_concat_ins = m.insert_instruction(concat_ins, concat_op, mb_inputs); + m.replace_instruction(concat_ins, + migraphx::make_op("multibroadcast", {{"out_lens", concat_out_lens}}), + new_concat_ins); } }; @@ -645,41 +691,108 @@ struct find_reshape_cont } }; -// match sequence of transpose --> contiguous --> reshaper_op -auto match_transpose_contiguous_reshaper() -{ - return 
match::name({"reshape", "squeeze", "unsqueeze"})( - match::used_once(), - match::args( - match::name("contiguous")( - match::used_once(), match::args(match::transpose_shape().bind("trans_ins"))) - .bind("cont_ins"))) - .bind("reshaper_ins"); -}; - -// finds the pattern of transpose --> contiguous --> reshaper_op --> unary -// application of this matcher moves the unary operation before the contiguous so it becomes -// transpose --> unary --> contiguous --> reshaper_op. later pointwise sub-module can be created out -// of unary --> contiguous --> reshaper_op. Such pattern appears in depthToSpace or spaceToDepth -// operator. -struct find_transpose_contiguous_reshaper_unary +struct find_unary_shape_transforms { + static const auto& shape_transforms() + { + static const std::unordered_set names = { + "flatten", + "reshape", + "squeeze", + "unsqueeze", + "transpose", + "broadcast", + "multibroadcast", + }; + return names; + } auto matcher() const { - return pointwise(match::used_once(), - match::nargs(1), - match::args(match_transpose_contiguous_reshaper())); + auto output_not_pointwise = + match::none_of(match::skip_output(match::name("contiguous"))(match::pointwise())); + auto input_has_shape_transform = + match::args(match::skip(match::name("contiguous"))(match::name(shape_transforms()))); + return match::pointwise( + match::used_once(), input_has_shape_transform, output_not_pointwise); } - void apply(module& m, const match::matcher_result& r) const + static bool is_shape_transform(instruction_ref ins) { - auto ins = r.result; - auto reshaper_ins = r.instructions["reshaper_ins"]; - auto trans_ins = r.instructions["trans_ins"]; - auto cont_ins = r.instructions["cont_ins"]; - auto unary_ins = m.insert_instruction(cont_ins, ins->get_operator(), trans_ins); - // older cont and reshape are removed by deadcode elimination - m.replace_instruction(ins, reshaper_ins->get_operator(), unary_ins); + return ins->inputs().size() == 1 and + (contains(shape_transforms(), 
ins->name()) or ins->name() == "contiguous"); + } + + static bool can_fuse_unary(instruction_ref ins) + { + return ins->name() == "@literal" or + ins->get_operator().attributes().contains("pointwise") or + contains(ins->name(), "reduce"); + } + + void apply(module& m, const match::matcher_result& mr) const + { + auto ins = mr.result; + if(ins->outputs().empty()) + return; + auto input = ins->inputs().front(); + auto output = ins->outputs().front(); + + auto insert_ops = [&](const auto& ops, instruction_ref z) { + for(const auto& op : ops) + { + z = m.insert_instruction(ins, op, z); + } + return z; + }; + + std::vector xops; + auto x = input; + while(is_shape_transform(x)) + { + xops.push_back(x->get_operator()); + x = x->inputs().front(); + } + std::reverse(xops.begin(), xops.end()); + + std::vector yops; + auto y = output; + auto last_transform = m.end(); + while(is_shape_transform(y) and y->outputs().size() == 1) + { + yops.push_back(y->get_operator()); + last_transform = y; + y = y->outputs().front(); + } + + bool move_up = can_fuse_unary(x); + bool move_down = can_fuse_unary(y); + + if(move_up and move_down) + { + if(x->name() == "@literal") + move_down = false; // NOLINT(bugprone-branch-clone) + else if(yops.empty()) + move_up = false; + else + move_down = false; + } + else if(not move_up and not move_down) + { + if(not yops.empty()) + move_up = true; + } + + if(move_up) + { + auto z = m.insert_instruction(ins, ins->get_operator(), x); + z = insert_ops(xops, z); + m.replace_instruction(ins, z); + } + else if(move_down and not yops.empty()) + { + auto z = insert_ops(yops, input); + m.replace_instruction(last_transform, ins->get_operator(), z); + } } }; @@ -967,7 +1080,7 @@ void simplify_reshapes::apply(module& m) const find_transpose_slice{}, find_broadcast_transpose{}, find_slice_transpose{}, - find_transpose_contiguous_reshaper_unary{}, + find_unary_shape_transforms{}, find_reshape_reshape_dot{}, find_scalar_multibroadcast_reshape_or_transpose{}); 
dead_code_elimination{}.apply(m); diff --git a/src/split_single_dyn_dim.cpp b/src/split_single_dyn_dim.cpp index a5707c38674..5998896724e 100644 --- a/src/split_single_dyn_dim.cpp +++ b/src/split_single_dyn_dim.cpp @@ -144,6 +144,8 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const submod->add_return({outputs}); submodules.push_back(submod); } + // sort parameters by name for consistency (vs. parameter order attr) + std::sort(param_names.begin(), param_names.end()); // redirect to select_module operator and return std::vector sm_inputs; std::transform(param_names.cbegin(), diff --git a/src/targets/cpu/include/migraphx/cpu/parallel.hpp b/src/targets/cpu/include/migraphx/cpu/parallel.hpp index b25483facc5..cb3b9ed6456 100644 --- a/src/targets/cpu/include/migraphx/cpu/parallel.hpp +++ b/src/targets/cpu/include/migraphx/cpu/parallel.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -26,6 +26,7 @@ // #define MIGRAPHX_DISABLE_OMP #include +#include #include #ifdef MIGRAPHX_DISABLE_OMP #include @@ -68,8 +69,10 @@ void parallel_for_impl(std::size_t n, std::size_t threadsize, F f) std::size_t work = 0; std::generate(threads.begin(), threads.end(), [=, &work] { - auto result = - joinable_thread([=]() mutable { f(work, std::min(n, work + grainsize)); }); + auto result = joinable_thread([=]() mutable { + assert(work < n); + f(work, std::min(n, work + grainsize)); + }); work += grainsize; return result; }); @@ -91,10 +94,11 @@ void parallel_for_impl(std::size_t n, std::size_t threadsize, F f) else { std::size_t grainsize = std::ceil(static_cast(n) / threadsize); -#pragma omp parallel for num_threads(threadsize) schedule(static, 1) private(grainsize, n) +#pragma omp parallel for num_threads(threadsize) schedule(static, 1) for(std::size_t tid = 0; tid < threadsize; tid++) { std::size_t work = tid * grainsize; + assert(work < n); f(work, std::min(n, work + grainsize)); } } diff --git a/src/targets/cpu/lowering.cpp b/src/targets/cpu/lowering.cpp index c54d3032cac..3747306158e 100644 --- a/src/targets/cpu/lowering.cpp +++ b/src/targets/cpu/lowering.cpp @@ -352,7 +352,6 @@ struct cpu_apply extend_op("logsoftmax", "dnnl::logsoftmax"); extend_op("lrn", "dnnl::lrn"); extend_op("softmax", "dnnl::softmax"); - extend_op("sub", "cpu::sub"); extend_op("im2col", "cpu::im2col", false); extend_op("leaky_relu", "cpu::leaky_relu", false); diff --git a/src/targets/gpu/CMakeLists.txt b/src/targets/gpu/CMakeLists.txt index a49d1dde0a6..d2e73ef1de7 100644 --- a/src/targets/gpu/CMakeLists.txt +++ b/src/targets/gpu/CMakeLists.txt @@ -33,9 +33,13 @@ endif() find_package(miopen REQUIRED) message(STATUS "MIGraphX is using MIOpen") -# rocblas -find_package(rocblas REQUIRED) -message(STATUS "MIGraphX build with 
rocBLAS") +if(MIGRAPHX_USE_ROCBLAS) + # rocblas + find_package(rocblas REQUIRED) + message(STATUS "MIGraphX build with rocBLAS") +else() + message(STATUS "MIGraphX build without rocBLAS") +endif() if(MIGRAPHX_USE_COMPOSABLEKERNEL) find_package(composable_kernel 1.0.0 REQUIRED COMPONENTS jit_library) @@ -124,6 +128,7 @@ add_library(migraphx_gpu compile_hip.cpp compile_hip_code_object.cpp compile_miopen.cpp + compile_pointwise.cpp compiler.cpp device_name.cpp fuse_ck.cpp @@ -142,6 +147,7 @@ add_library(migraphx_gpu nonzero.cpp pack_args.cpp prefuse_ops.cpp + prepare_reduce.cpp perfdb.cpp pooling.cpp problem_cache.cpp @@ -187,10 +193,12 @@ register_op(migraphx_gpu HEADER migraphx/gpu/rnn_variable_seq_lens.hpp OPERATORS gpu::hip_rnn_var_sl_shift_sequence gpu::hip_rnn_var_sl_shift_output gpu::hip_rnn_var_sl_last_output INCLUDES migraphx/gpu/context.hpp) -register_op(migraphx_gpu - HEADER migraphx/gpu/gemm.hpp - OPERATORS gpu::rocblas_gemm gpu::rocblas_gemm - INCLUDES migraphx/gpu/context.hpp) +if(MIGRAPHX_USE_ROCBLAS) + register_op(migraphx_gpu + HEADER migraphx/gpu/gemm.hpp + OPERATORS gpu::rocblas_gemm gpu::rocblas_gemm + INCLUDES migraphx/gpu/context.hpp) +endif() register_op(migraphx_gpu HEADER migraphx/gpu/convolution.hpp OPERATORS gpu::miopen_convolution gpu::miopen_convolution gpu::miopen_convolution INCLUDES migraphx/gpu/context.hpp) @@ -258,13 +266,19 @@ target_compile_definitions(migraphx_gpu PUBLIC MIGRAPHX_CXX_COMPILER="${CMAKE_CX include(CheckLibraryExists) get_target_property(MIOPEN_LOCATION MIOpen LOCATION) -get_target_property(ROCBLAS_LOCATION roc::rocblas LOCATION) -check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCATION}" HAS_FIND_MODE_API) check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API) -# Beta API for automated GEMM tuning -check_library_exists(roc::rocblas "rocblas_gemm_ex_get_solutions" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_TUNING_BETA_FEATURE_API) -# rocblas FP8 API 
-check_library_exists(roc::rocblas "rocblas_gemm_strided_batched_ex3" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_FP8_BETA_API) +check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCATION}" HAS_FIND_MODE_API) + +if(MIGRAPHX_USE_ROCBLAS) + get_target_property(ROCBLAS_LOCATION roc::rocblas LOCATION) + target_compile_definitions(migraphx_gpu PUBLIC MIGRAPHX_USE_ROCBLAS=1) + # Beta API for automated GEMM tuning + check_library_exists(roc::rocblas "rocblas_gemm_ex_get_solutions" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_TUNING_BETA_FEATURE_API) + # rocblas FP8 API + check_library_exists(roc::rocblas "rocblas_gemm_strided_batched_ex3" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_FP8_BETA_API) +else() + target_compile_definitions(migraphx_gpu PUBLIC MIGRAPHX_USE_ROCBLAS=0) +endif() set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "") @@ -287,21 +301,27 @@ else() message(STATUS "MIOpen does not have find mode api") endif() -if(HAS_ROCBLAS_TUNING_BETA_FEATURE_API) - target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_TUNING_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS) - message(STATUS "MIGraphx is using Beta API of rocBLAS") -else() - message(STATUS "rocBLAS does not have User Tuning Beta API") -endif() +if(MIGRAPHX_USE_ROCBLAS) + if(HAS_ROCBLAS_TUNING_BETA_FEATURE_API) + target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_TUNING_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS) + message(STATUS "MIGraphx is using Beta API of rocBLAS") + else() + message(STATUS "rocBLAS does not have User Tuning Beta API") + endif() -if(HAS_ROCBLAS_FP8_BETA_API) - target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_FP8_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS) - message(STATUS "MIGraphX is using Beta API of rocBLAS for FP8 computations") -else() - message(STATUS "rocBLAS does not have Fp8 Beta API") + if(HAS_ROCBLAS_FP8_BETA_API) + target_compile_definitions(migraphx_gpu 
PUBLIC -DMIGRAPHX_USE_ROCBLAS_FP8_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS) + message(STATUS "MIGraphX is using Beta API of rocBLAS for FP8 computations") + else() + message(STATUS "rocBLAS does not have Fp8 Beta API") + endif() + + + target_link_libraries(migraphx_gpu PUBLIC roc::rocblas) endif() -target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas) +target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen) + target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels) if(MIGRAPHX_USE_COMPOSABLEKERNEL) target_link_libraries(migraphx_gpu PRIVATE composable_kernel::jit_library) diff --git a/src/targets/gpu/code_object_op.cpp b/src/targets/gpu/code_object_op.cpp index 67c9d59472e..3f640e59d63 100644 --- a/src/targets/gpu/code_object_op.cpp +++ b/src/targets/gpu/code_object_op.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -40,7 +40,7 @@ shape code_object_op::compute_shape(std::vector inputs) const std::transform(einputs.begin(), einputs.end(), einputs.begin(), [](const shape& s) { return s.normalize_standard(); }); - if(einputs != inputs) + if(einputs != flatten(inputs)) MIGRAPHX_THROW("Input shapes have changed: [" + to_string_range(einputs) + "] -> [" + to_string_range(inputs) + "]"); return output; @@ -48,9 +48,10 @@ shape code_object_op::compute_shape(std::vector inputs) const argument code_object_op::compute(context& ctx, const shape&, const std::vector& args) const { - std::vector kargs(args.size()); + auto fargs = flatten(args); + std::vector kargs(fargs.size()); std::transform( - args.begin(), args.end(), kargs.begin(), [](const argument& a) { return a.data(); }); + fargs.begin(), fargs.end(), kargs.begin(), [](const argument& a) { return a.data(); }); auto [start, stop] = ctx.get_perf_events(); k.launch(ctx.get_stream().get(), global, local, std::move(kargs), start, stop); return args[get_output_arg(args.size())]; diff --git a/src/targets/gpu/compile_gen.cpp b/src/targets/gpu/compile_gen.cpp index 2469c6571a8..f4c0d66e326 100644 --- a/src/targets/gpu/compile_gen.cpp +++ b/src/targets/gpu/compile_gen.cpp @@ -23,6 +23,7 @@ */ #include #include +#include #include #include #include @@ -180,12 +181,16 @@ std::string make_transformer_args(std::vector transformers) return join_strings(std::move(transformers), ", "); } -void generate_pointwise(cpp_generator& gg, const module& pm, const std::string& name) +static void generate_pointwise(cpp_generator& gg, + const module& pm, + const std::string& name, + bool always_return_tuple = false) { module m = pm; run_passes(m, {rewrite_quantization{}, optimize_module{}}); m.sort(); cpp_generator g; + g.always_return_tuple(always_return_tuple); g.fmap([](const std::string& fname) { 
return "migraphx::" + fname; }); g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})"); g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})"); @@ -202,28 +207,30 @@ void generate_pointwise(cpp_generator& gg, const module& pm, const std::string& .set_generic_types(m) .set_name(name)); } -std::string generate_pointwise(const module& pm, const std::string& name) +std::string generate_pointwise(const module& pm, const std::string& name, bool always_return_tuple) { cpp_generator g; - generate_pointwise(g, pm, name); + generate_pointwise(g, pm, name, always_return_tuple); return g.str(); } std::string reduce_op::str() const { - return write + "(r.reduce(" + reduction + ", " + init + ", " + read + ")(" + input + "))"; + return write + "(r.reduce(" + reduction + ", " + init + ", " + read + ")(" + + join_strings(inputs, ", ") + "))"; } -void reduce_op::set(instruction_ref ins, const operation& op) +void reduce_op::set(const std::string& name, const shape& input, const shape& output) { - if(op.name() == "reduce_sum") + assert(input.type() != shape::tuple_type); + assert(output.type() != shape::tuple_type); + if(name == "reduce_sum") { reduction = "op::sum{}"; } - else if(op.name() == "reduce_mean") + else if(name == "reduce_mean") { - auto s = ins->inputs().front()->get_shape(); - auto reduce_elements = s.elements() / ins->get_shape().elements(); - auto reduce_type = s.type(); + auto reduce_elements = input.elements() / output.elements(); + auto reduce_type = input.type(); reduction = "op::sum{}"; std::string mean = "op::mean<" + std::to_string(reduce_elements) + ">{}"; // Use float accumulator when reduction size is too large for half @@ -234,17 +241,17 @@ void reduce_op::set(instruction_ref ins, const operation& op) else write = mean; } - else if(op.name() == "reduce_max") + else if(name == "reduce_max") { reduction = "op::max{}"; init = "lowest{}"; } - else if(op.name() == "reduce_min") + else if(name == "reduce_min") { reduction = 
"op::min{}"; init = "highest{}"; } - else if(op.name() == "reduce_prod") + else if(name == "reduce_prod") { reduction = "op::product{}"; init = "1"; @@ -254,7 +261,23 @@ void reduce_op::set(instruction_ref ins, const operation& op) MIGRAPHX_THROW("Unsupported reduce"); } } -std::string reduce_op::generate(instruction_ref ins, const std::string& x) + +void reduce_op::set(instruction_ref ins, const operation& op) +{ + if(op.name() == "gpu::parallel_reduce") + { + auto rop = from_value(op.to_value().at("op")); + auto input = ins->inputs().front()->get_shape(); + auto output = ins->get_shape().sub_shapes().front(); + set(rop.name(), input, output); + read = "compose(array_apply(" + read + "), MIGRAPHX_LIFT(make_array))"; + } + else + { + set(op.name(), ins->inputs().front()->get_shape(), ins->get_shape()); + } +} +std::string reduce_op::generate(instruction_ref ins, const std::vector& x) { reduce_op r{x}; r.set(ins, ins->get_operator()); @@ -265,6 +288,15 @@ static bool use_lazy_inner(instruction_ref ins) { if(ins->outputs().size() != 1) return false; + // When the inputs are broadcasted, it means the lambda will capture SGPRs + // when doing block/wave reduction. This can cause register spilling in + // the compiler when the lambda is evaluated at a later time although it + // shouldn't. Instead, use `inner` to workaround this issue in the + // compiler. 
+ if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](instruction_ref input) { + return input->get_shape().broadcasted(); + })) + return false; auto output = ins->outputs().front(); return contains(output->name(), "reduce") or output->name() == "@return"; } @@ -285,7 +317,7 @@ void preload_params(module& m) std::string generate_reduce(module m, const std::string& name) { preload_params(m); - run_passes(m, {optimize_module{}}); + run_passes(m, {optimize_module{}, prepare_reduce{}, optimize_module{}}); m.sort(); cpp_generator g; auto param_shapes = m.get_parameter_shapes(); @@ -298,9 +330,9 @@ std::string generate_reduce(module m, const std::string& name) auto f = g.generate_module(m, [&](instruction_ref ins, const auto& names) { if(contains(ins->name(), "reduce")) { - return reduce_op::generate(ins, names.at(ins->inputs().front())); + return reduce_op::generate(ins, cpp_generator::to_args(ins->inputs(), names)); } - else if(ins->name() == "pointwise") + if(ins->name() == "pointwise") { auto pointwise_name = "pointwise" + std::to_string(i); i++; @@ -342,11 +374,18 @@ std::string generate_reduce(module m, const std::string& name) {"args", join_strings(args, ", ")}, {"call", call_function}}); } - else if(ins->name() == "multibroadcast") + if(ins->name() == "multibroadcast") { return names.at(ins->inputs().front()); } - else if(ins->name() == "identity") + if(ins->name() == "get_tuple_elem") + { + const auto& x = names.at(ins->inputs().front()); + auto index = ins->get_operator().to_value()["index"].to(); + return interpolate_string("${x}[${index}]", + {{"x", x}, {"index", std::to_string(index)}}); + } + if(ins->name() == "identity") { const auto& x = names.at(ins->inputs().front()); return "r.inner(op::id{})(" + x + ")"; diff --git a/src/targets/gpu/compile_miopen.cpp b/src/targets/gpu/compile_miopen.cpp index ce1583e8451..583601bdda1 100644 --- a/src/targets/gpu/compile_miopen.cpp +++ b/src/targets/gpu/compile_miopen.cpp @@ -1,7 +1,7 @@ /* * The MIT License 
(MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -29,7 +29,6 @@ #include #include #include -#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { diff --git a/src/targets/gpu/compile_ops.cpp b/src/targets/gpu/compile_ops.cpp index 66bfb7e2052..730708c143c 100644 --- a/src/targets/gpu/compile_ops.cpp +++ b/src/targets/gpu/compile_ops.cpp @@ -21,6 +21,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ +#include #include #include #include @@ -185,17 +186,29 @@ struct compile_plan std::cout << "No binary" << std::endl; return std::numeric_limits::max(); } - // Time all the code objects for a given perf config and calculate total - // time e.g. in case of split-K GEMM, it may or may not support fusion. - // In that case MLIR compile would return code objects for individual - // GEMM and pre/post fusion code objects. - auto cobjs = cr->replace.code_objects; - double t = transform_accumulate( - cobjs.begin(), - cobjs.end(), - double{0}, - std::plus<>{}, - [&](const operation& op) { return time_op(*ctx, op, 20); }); + /* + create a small program with instruction being compiled and call "replace" + on that which would insert all the compiled code objects, prefills etc. 
+ necessary to run candidate code object + */ + program bench_prog; + auto* bench_mm = bench_prog.get_main_module(); + std::vector bench_ins_inputs; + + std::transform(cr->ins->inputs().begin(), + cr->ins->inputs().end(), + std::back_inserter(bench_ins_inputs), + [&](const auto& arg) { + return bench_mm->add_parameter( + std::to_string(bench_ins_inputs.size()), + arg->get_shape()); + }); + auto bench_ins = bench_mm->add_instruction( + cr->ins->get_operator(), bench_ins_inputs, cr->ins->module_inputs()); + cr->replace.replace(*bench_mm, bench_ins); + // do dead code elimination by directly removing instruction + bench_mm->remove_instruction(bench_ins); + auto t = time_program(*ctx, bench_prog, 20); if(trace_level > 1) std::cout << t << "ms" << std::endl; return t; diff --git a/src/targets/gpu/compile_pointwise.cpp b/src/targets/gpu/compile_pointwise.cpp new file mode 100644 index 00000000000..ee682cf2c3a --- /dev/null +++ b/src/targets/gpu/compile_pointwise.cpp @@ -0,0 +1,50 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +operation +compile_pointwise(context& ctx, const std::vector& in_shapes, const_module_ref pm) +{ + auto pf = gen::generate_pointwise(*pm, "inner_pointwise", true); + std::string lambda = "MIGRAPHX_LIFT(inner_pointwise)"; + auto kernel_name = gen::generate_name_from_ops(*pm, "kernel"); + return gpu::compile_op("pointwise", + ctx, + in_shapes, + {{"lambda", lambda}, {"preamble", pf}, {"kernel", kernel_name}}); +} + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/fuse_ck.cpp b/src/targets/gpu/fuse_ck.cpp index d2ae8584d46..bf9a269f3e1 100644 --- a/src/targets/gpu/fuse_ck.cpp +++ b/src/targets/gpu/fuse_ck.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -21,11 +21,13 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. 
*/ -#include -#include + #include #include +#include #include +#include +#include #include namespace migraphx { @@ -106,7 +108,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) return false; } auto device_name = trim(split_string(get_device_name(), ':').front()); - if(device_name == "gfx940") + if(starts_with(device_name, "gfx94")) { if(ins->get_shape().type() == shape::half_type) { diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index a0a16512358..70163a61365 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -29,6 +29,7 @@ #include #include #include +#include #include namespace migraphx { @@ -150,59 +151,11 @@ struct mlir_op if(inputs.size() < 2) MIGRAPHX_THROW("should have at least two inputs."); - auto type = mod->get_output_shapes().front().type(); - auto mod_params = mod->get_parameter_names(); - std::sort(mod_params.begin(), mod_params.end()); - std::unordered_map mod_ins_shapes; - std::unordered_map adjusted_mod_param_shapes; - std::transform(inputs.begin(), - inputs.end(), - mod_params.begin(), - std::inserter(adjusted_mod_param_shapes, adjusted_mod_param_shapes.end()), - [](auto ps, auto name) { return std::make_pair(name, ps); }); - for(auto ins : iterator_for(*mod)) - { - if(ins->name() == "@param") - { - mod_ins_shapes[ins] = - adjusted_mod_param_shapes[any_cast(ins->get_operator()) - .parameter]; - if(ins->get_shape().type() != mod_ins_shapes[ins].type()) - { - MIGRAPHX_THROW( - "MLIR_OP: adjusted mod parameter doesn't have the same type lens as " - "original input. Type changed from : " + - ins->get_shape().type_string() + " to " + - mod_ins_shapes[ins].type_string()); - } - if(ins->get_shape().lens() != mod_ins_shapes[ins].lens()) - { - MIGRAPHX_THROW("MLIR_OP: adjusted mod parameter doesn't have the same lens as " - "original input. 
Lens changed from " + - to_string_range(ins->get_shape().lens()) + " to " + - to_string_range(mod_ins_shapes[ins].lens())); - } - } - else if(ins->name() == "@literal") - { - mod_ins_shapes[ins] = ins->get_shape(); - } - else if(ins->name() == "@return") - { - return mod_ins_shapes[ins->inputs().at(0)].with_type(type); - } - else - { - std::vector input_shapes; - input_shapes.resize(ins->inputs().size()); - std::transform(ins->inputs().begin(), - ins->inputs().end(), - input_shapes.begin(), - [&](auto in) { return mod_ins_shapes[in]; }); - mod_ins_shapes[ins] = ins->get_operator().compute_shape(input_shapes); - } - } - MIGRAPHX_THROW("No return found in the submodule"); + auto result = + mod->compute_shapes(inputs, {.name = name(), .strict_type = true, .strict_lens = true}); + if(result.size() == 1) + return result.front(); + return shape{result}; } }; MIGRAPHX_REGISTER_OP(mlir_op); @@ -248,8 +201,8 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, { auto [upper_input, op_stream] = get_fusable_input_op_stream(input); top_inputs.push_back(upper_input); - instruction_ref prev_input = mm->add_parameter("y" + std::to_string(input_cnt++), - upper_input->get_shape().as_standard()); + instruction_ref prev_input = + mm->add_parameter(param_name(input_cnt++, "y"), upper_input->get_shape().as_standard()); for(const auto& op : reverse(op_stream)) { prev_input = mm->add_instruction(op, {prev_input}); @@ -275,12 +228,12 @@ auto is_mlir_dot(mlir_mode mode) return false; if(ins->name() != "dot" and ins->name() != "quant_dot") return false; - if(mode != mlir_mode::fast) - return true; // dot operation where (FP8 * FP8 = FP8) is not available in MLIR. rocBLAS has the support // for it. 
if(ins->get_shape().type() == migraphx::shape::fp8e4m3fnuz_type) return false; + if(mode != mlir_mode::fast) + return true; auto a = ins->inputs().front()->get_shape(); auto b = ins->inputs().back()->get_shape(); // auto m = a.lens()[a.lens().size() - 2]; @@ -448,6 +401,20 @@ MIGRAPHX_PRED_MATCHER(mlir_pointwise, instruction_ref ins) }); } +std::vector mlir_contiguous(module_pass_manager& mpm, + const std::vector& inputs) +{ + std::vector result; + std::transform( + inputs.begin(), inputs.end(), std::back_inserter(result), [&](instruction_ref input) { + if(input->get_shape().packed() or input->get_shape().broadcasted()) + return input; + return mpm.get_module().insert_instruction( + std::next(input), make_op("contiguous"), input); + }); + return result; +} + struct find_mlir_fused_ops { mlir_mode conv_mode = mlir_mode::none; @@ -480,7 +447,7 @@ struct find_mlir_fused_ops [&](auto input) { return input != gemm_based_op; }); inputs.insert(inputs.end(), top_inputs.begin(), top_inputs.end()); mpm.get_module().replace_instruction( - ins, mlir_op{gemm_based_op->get_operator()}, inputs, {mm}); + ins, mlir_op{gemm_based_op->get_operator()}, mlir_contiguous(mpm, inputs), {mm}); } }; @@ -509,8 +476,10 @@ struct find_mlir_standalone_op auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op( mm, gemm_based_op->inputs(), gemm_based_op->get_operator()); mm->add_return({anchor_op}); - mpm.get_module().replace_instruction( - gemm_based_op, mlir_op{gemm_based_op->get_operator()}, top_inputs, {mm}); + mpm.get_module().replace_instruction(gemm_based_op, + mlir_op{gemm_based_op->get_operator()}, + mlir_contiguous(mpm, top_inputs), + {mm}); } }; @@ -589,7 +558,7 @@ struct find_mlir_standalone_attention_op mm->add_return({ins_to_replace}); mpm.get_module().replace_instruction( - ins_to_be_replaced, mlir_op{gemm1->get_operator()}, inputs, {mm}); + ins_to_be_replaced, mlir_op{gemm1->get_operator()}, mlir_contiguous(mpm, inputs), {mm}); } }; @@ -612,7 +581,7 @@ void 
fuse_mlir::apply(module_pass_manager& mpm) const { #ifdef MIGRAPHX_MLIR const auto& device_name = ctx == nullptr ? "" : ctx->get_current_device().get_gfx_name(); - const bool is_navi = starts_with(device_name, "gfx110"); + const bool is_navi = starts_with(device_name, "gfx11"); auto get_mode = [&](std::string_view option, mlir_mode m1, mlir_mode m2 = mlir_mode::fast) { if(specific_op(option)) diff --git a/src/targets/gpu/fuse_ops.cpp b/src/targets/gpu/fuse_ops.cpp index e3a0aa06393..579f4bfd552 100644 --- a/src/targets/gpu/fuse_ops.cpp +++ b/src/targets/gpu/fuse_ops.cpp @@ -166,7 +166,7 @@ struct fusion const std::unordered_set& get_supported_archs() { static std::unordered_set supported_archs{ - "gfx900", "gfx906", "gfx908", "gfx1030", "gfx940"}; + "gfx900", "gfx906", "gfx908", "gfx1030", "gfx940", "gfx941", "gfx942"}; return supported_archs; } @@ -550,6 +550,7 @@ struct find_conv_pointwise } }; +#if MIGRAPHX_USE_ROCBLAS struct find_gemm_pointwise { auto matcher() const @@ -675,6 +676,7 @@ struct find_gemm_pointwise m.replace_instruction(ins, gemm, inputs); } }; +#endif struct find_contiguous_tranpose_gemm { @@ -893,7 +895,9 @@ void fuse_ops::apply(module& m) const match::find_matches(m, find_conv_pointwise{ctx}, find_conv_bias_relu{ctx}, find_conv_bias{ctx}); run_passes(m, {dead_code_elimination{}}); match::find_matches(m, +#if MIGRAPHX_USE_ROCBLAS find_gemm_pointwise{}, +#endif find_layernorm_pointwise{}, find_concat_pointwise{}, find_contiguous_tranpose_gemm{}, diff --git a/src/targets/gpu/gemm_impl.cpp b/src/targets/gpu/gemm_impl.cpp index 88da849bdc0..fbe49eb6797 100644 --- a/src/targets/gpu/gemm_impl.cpp +++ b/src/targets/gpu/gemm_impl.cpp @@ -36,7 +36,7 @@ using microseconds = std::chrono::duration; namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace gpu { - +#if MIGRAPHX_USE_ROCBLAS /* Regular rocBLAS API takes compute_type as `rocblas_datatype` enum value v/s "ex3" BETA API takes it as `rocblas_computetype` enum value. 
`rb_compute_type` is faciliator to implictly cast integer enum @@ -79,8 +79,10 @@ void blas_shape(const shape& s) { if(s.lens().size() < 2) return; - if(std::none_of(s.strides().end() - 2, s.strides().end(), [&](auto i) { return i == 1; })) + if(std::none_of(s.strides().end() - 2, s.strides().end(), [](auto i) { return i == 1; })) MIGRAPHX_THROW("GPU_GEMM: needs to have one matrix stride as 1"); + if(std::any_of(s.strides().end() - 2, s.strides().end(), [](auto i) { return i == 0; })) + MIGRAPHX_THROW("GPU_GEMM: matrix dimensions can't be broadcasted"); if(s.lens().size() < 3) return; shape batch_shape{s.type(), @@ -129,7 +131,21 @@ auto rocblas_invoke(F f, Pack p, Ts... xs) }); } -static bool is_transposed(const shape& s) { return s.transposed() and s.strides().back() != 1; } +static bool is_transposed(const shape& s) +{ + if(s.transposed()) + { + return s.strides().back() != 1; + } + + if(not s.broadcasted() and s.strides() != s.as_standard().strides()) + { + auto perm = find_permutation(s); + return not std::is_sorted(perm.begin(), perm.end()); + } + + return false; +} static rocblas_int get_batch_stride(const shape& s) { @@ -662,7 +678,7 @@ int32_t gemm_finalize(context& ctx, return gemm_finalize_impl( ctx, output_shape, input_shapes, alpha, beta, compute_fp32, solution_idx); } - +#endif } // namespace gpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/targets/gpu/include/migraphx/gpu/compile_gen.hpp b/src/targets/gpu/include/migraphx/gpu/compile_gen.hpp index 0ed50920584..645788aaee8 100644 --- a/src/targets/gpu/include/migraphx/gpu/compile_gen.hpp +++ b/src/targets/gpu/include/migraphx/gpu/compile_gen.hpp @@ -72,7 +72,8 @@ std::string make_transformer_args(Ts... 
xs) return make_transformer_args({xs.str()...}); } -std::string generate_pointwise(const module& pm, const std::string& name); +std::string +generate_pointwise(const module& pm, const std::string& name, bool always_return_tuple = false); std::string generate_reduce(module m, const std::string& name); @@ -80,15 +81,16 @@ std::string generate_name_from_ops(const module& m, const std::string& postname struct reduce_op { - std::string input = ""; + std::vector inputs = {}; std::string reduction = ""; std::string init = "0"; std::string read = "op::id{}"; std::string write = "op::id{}"; void set(instruction_ref ins, const operation& op); + void set(const std::string& name, const shape& input, const shape& output); std::string str() const; - static std::string generate(instruction_ref ins, const std::string& x); + static std::string generate(instruction_ref ins, const std::vector& x); }; } // namespace gen diff --git a/src/targets/gpu/include/migraphx/gpu/compile_miopen.hpp b/src/targets/gpu/include/migraphx/gpu/compile_miopen.hpp index 197560d2cf1..03dd669e536 100644 --- a/src/targets/gpu/include/migraphx/gpu/compile_miopen.hpp +++ b/src/targets/gpu/include/migraphx/gpu/compile_miopen.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/targets/gpu/include/migraphx/gpu/compile_pointwise.hpp b/src/targets/gpu/include/migraphx/gpu/compile_pointwise.hpp new file mode 100644 index 00000000000..8e6dc229aaf --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/compile_pointwise.hpp @@ -0,0 +1,45 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ +#ifndef MIGRAPHX_GUARD_GPU_COMPILE_POINTWISE_HPP +#define MIGRAPHX_GUARD_GPU_COMPILE_POINTWISE_HPP + +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +namespace gpu { + +operation +compile_pointwise(context& ctx, const std::vector& in_shapes, const_module_ref pm); + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_GPU_COMPILE_POINTWISE_HPP diff --git a/src/targets/gpu/include/migraphx/gpu/context.hpp b/src/targets/gpu/include/migraphx/gpu/context.hpp index ef1d6259a4d..457e3712f87 100644 --- a/src/targets/gpu/include/migraphx/gpu/context.hpp +++ b/src/targets/gpu/include/migraphx/gpu/context.hpp @@ -107,7 +107,7 @@ struct hip_device assert(mihandle.get() != nullptr); return mihandle.get(); } - +#if MIGRAPHX_USE_ROCBLAS auto get_rocblas() { setup(); @@ -116,6 +116,7 @@ struct hip_device assert(rbhandle.get() != nullptr); return rbhandle.get(); } +#endif void wait() const { @@ -144,10 +145,12 @@ struct hip_device } private: - std::size_t id = 0; - shared s = nullptr; - shared mihandle = nullptr; + std::size_t id = 0; + shared s = nullptr; + shared mihandle = nullptr; +#if MIGRAPHX_USE_ROCBLAS shared rbhandle = nullptr; +#endif }; void add_stream() { streams.emplace_back(device_id); } diff --git a/src/targets/gpu/include/migraphx/gpu/gemm.hpp b/src/targets/gpu/include/migraphx/gpu/gemm.hpp index 321662888bf..f4ec57760a8 100644 --- a/src/targets/gpu/include/migraphx/gpu/gemm.hpp +++ b/src/targets/gpu/include/migraphx/gpu/gemm.hpp @@ -52,7 +52,6 @@ struct rocblas_gemm bool compute_fp32 = false; unsigned trans_batch = 0; int32_t solution_idx = 0; - template static auto reflect(Self& self, F f) { @@ -158,9 +157,7 @@ struct rocblas_gemm #endif } }; - +#endif } // namespace gpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx - -#endif diff --git a/src/targets/gpu/include/migraphx/gpu/mlir.hpp 
b/src/targets/gpu/include/migraphx/gpu/mlir.hpp index c8b6a27e1a9..dc395b8eece 100644 --- a/src/targets/gpu/include/migraphx/gpu/mlir.hpp +++ b/src/targets/gpu/include/migraphx/gpu/mlir.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -37,10 +38,21 @@ struct module; namespace gpu { MIGRAPHX_GPU_EXPORT std::string dump_mlir(const module& m); -MIGRAPHX_GPU_EXPORT code_object_op compile_mlir(const context& migraphx_ctx, - module m, - const std::vector& inputs, - const value& solution); + +MIGRAPHX_GPU_EXPORT bool +is_module_fusible(const module& m, const context& migraphx_ctx, const value& solution); + +struct MIGRAPHX_GPU_EXPORT mlir_code_object +{ + code_object_op cop; + std::vector prefill_indices = {}; + std::vector prefill_values = {}; +}; + +MIGRAPHX_GPU_EXPORT mlir_code_object compile_mlir(const context& migraphx_ctx, + module m, + const std::vector& in_shapes, + const value& solution); MIGRAPHX_GPU_EXPORT instruction_ref insert_mlir(module& m, instruction_ref ins, diff --git a/src/targets/gpu/include/migraphx/gpu/prepare_reduce.hpp b/src/targets/gpu/include/migraphx/gpu/prepare_reduce.hpp new file mode 100644 index 00000000000..3c6bfdd42f8 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/prepare_reduce.hpp @@ -0,0 +1,47 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ +#ifndef MIGRAPHX_GUARD_GPU_PREPARE_REDUCE_HPP +#define MIGRAPHX_GUARD_GPU_PREPARE_REDUCE_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct module; + +namespace gpu { + +struct prepare_reduce +{ + std::string name() const { return "gpu::prepare_reduce"; } + void apply(module& m) const; +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_GPU_PREPARE_REDUCE_HPP diff --git a/src/targets/gpu/include/migraphx/gpu/rocblas.hpp b/src/targets/gpu/include/migraphx/gpu/rocblas.hpp index e72666e25ae..d23c40f9d05 100644 --- a/src/targets/gpu/include/migraphx/gpu/rocblas.hpp +++ b/src/targets/gpu/include/migraphx/gpu/rocblas.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. 
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -25,17 +25,20 @@ #define MIGRAPHX_GUARD_MIGRAPHLIB_ROCBLAS_HPP #include #include +#if MIGRAPHX_USE_ROCBLAS #include +#endif namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace gpu { +#if MIGRAPHX_USE_ROCBLAS using rocblas_handle_ptr = MIGRAPHX_MANAGE_PTR(rocblas_handle, rocblas_destroy_handle); rocblas_handle_ptr create_rocblas_handle_ptr(); rocblas_handle_ptr create_rocblas_handle_ptr(hipStream_t s); - +#endif struct context; MIGRAPHX_GPU_EXPORT bool get_compute_fp32_flag(); diff --git a/src/targets/gpu/include/migraphx/gpu/time_op.hpp b/src/targets/gpu/include/migraphx/gpu/time_op.hpp index 69a4767afcf..2c5893eed2f 100644 --- a/src/targets/gpu/include/migraphx/gpu/time_op.hpp +++ b/src/targets/gpu/include/migraphx/gpu/time_op.hpp @@ -24,6 +24,7 @@ #ifndef MIGRAPHX_GUARD_GPU_DRIVER_PERF_HPP #define MIGRAPHX_GUARD_GPU_DRIVER_PERF_HPP +#include #include #include #include @@ -33,10 +34,12 @@ inline namespace MIGRAPHX_INLINE_NS { namespace gpu { MIGRAPHX_GPU_EXPORT double -time_op(context& ictx, operation op, const std::vector& inputs, int n = 100); +time_op(const context& ictx, operation op, const std::vector& inputs, int n = 100); + +MIGRAPHX_GPU_EXPORT double time_program(const context& ictx, program p, int n = 100); /* benchmark gpu::code_object with expected input shapes over n iterations */ -MIGRAPHX_GPU_EXPORT double time_op(context& ictx, operation op, int n = 100); +MIGRAPHX_GPU_EXPORT double time_op(const context& ictx, operation op, int n = 100); } // namespace gpu } // namespace MIGRAPHX_INLINE_NS diff --git a/src/targets/gpu/jit/mlir.cpp b/src/targets/gpu/jit/mlir.cpp index e319cb5c716..656c28d2ec9 100644 --- a/src/targets/gpu/jit/mlir.cpp +++ b/src/targets/gpu/jit/mlir.cpp @@ -21,15 +21,50 @@ * OUT 
OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ -#include + +#include +#include +#include +#include #include +#include +#include #include +#include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace gpu { +static module create_pointwise_module(module_ref in_mod) +{ + module pw_mod; + std::unordered_map map_ins; + for(auto param : in_mod->get_parameters()) + { + map_ins[param] = + pw_mod.add_parameter(any_cast(param->get_operator()).parameter, + shape{param->get_shape().type()}); + } + auto return_args = pw_mod.add_instructions( + in_mod, + &map_ins, + [](module& m, + instruction_ref ins, + const operation& op, + const std::vector& inputs, + const std::vector& mod_args) -> instruction_ref { + if(op.name() == "multibroadcast" and inputs.front()->name() == "@literal") + return inputs.front(); + else + return m.insert_instruction(ins, op, inputs, mod_args); + }); + pw_mod.add_return(return_args); + return pw_mod; +} + struct mlir_compiler : compiler { std::vector names() const { return {"gpu::mlir_op"}; } @@ -37,23 +72,127 @@ struct mlir_compiler : compiler operation compile_op(context&, const std::vector&, const value&) const { return {}; } compiler_replace - compile(const context& ctx, instruction_ref ins, const operation&, const value& solution) const + compile(context& ctx, instruction_ref ins, const operation&, const value& solution) const { auto* smod = ins->module_inputs().front(); assert(smod->get_parameter_names().size() == ins->inputs().size() - 1); - return insert(compile_mlir(ctx, *smod, ins->inputs(), solution)); + auto gemm_ins = std::find_if(smod->begin(), smod->end(), [&](const auto& i) { + return i.name() == "dot" or i.name() == "quant_dot"; + }); + // check if (a) module is fused (b) contains a dot instruction and (c) perfConfig can not + // allow fused module + if(gemm_ins != smod->end() and std::distance(gemm_ins, smod->end()) > 2 and + not is_module_fusible(*smod, ctx, 
solution)) + { + auto input_args = ins->inputs(); + input_args.pop_back(); + auto mod_splits = smod->split(input_args, {gemm_ins}); + auto dot_mlir_inputs = to_shapes(mod_splits[0].inputs); + dot_mlir_inputs.push_back(mod_splits[0].mod.get_output_shapes().front()); + mlir_code_object cop1 = compile_mlir(ctx, mod_splits[0].mod, dot_mlir_inputs, solution); + auto pw_inputs = mod_splits[1].inputs; + auto dot_ins_idx = std::distance( + std::find(pw_inputs.begin(), pw_inputs.end(), gemm_ins), pw_inputs.begin()); + auto pw_shapes = to_shapes(mod_splits[1].inputs); + pw_shapes[dot_ins_idx] = cop1.cop.output; + pw_shapes.push_back(mod_splits[1].mod.get_output_shapes().front()); + assert(pw_shapes.back() == ins->get_shape()); + auto pw_mod = create_pointwise_module(&mod_splits[1].mod); + auto cop2 = compile_pointwise(ctx, pw_shapes, &pw_mod); + std::vector cops = {cop1, + mlir_code_object{any_cast(cop2)}}; + return insert(cops, mod_splits, ins, gemm_ins); + } + return insert(compile_mlir(ctx, *smod, to_shapes(ins->inputs()), solution)); } - compiler_replace insert(code_object_op co) const + compiler_replace insert(const mlir_code_object& mco) const { - return {std::vector{std::move(co)}, - [](module& m, instruction_ref ins, const std::vector& op) { - auto mlir = - insert_mlir(m, ins, any_cast(op.front()), ins->inputs()); - m.replace_instruction(ins, mlir); + return {std::vector{mco.cop}, + [=](module& m, instruction_ref ins, const std::vector& ops) { + std::vector inputs = ins->inputs(); + for(const auto i : range(mco.prefill_indices.size())) + { + auto prefilled_ins = m.insert_instruction( + ins, + migraphx::make_op("hip::fill", {{"value", mco.prefill_values[i]}}), + inputs[mco.prefill_indices[i]]); + replace(inputs, inputs[mco.prefill_indices[i]], prefilled_ins); + } + auto mlir = insert_mlir(m, ins, any_cast(ops.front()), inputs); + return m.replace_instruction(ins, mlir); }}; } + compiler_replace insert(const std::vector& mcos, + const std::array& mods, + 
instruction_ref precompile_ins, + instruction_ref split_ins) const + { + std::vector cobjs(mcos.size()); + std::transform( + mcos.begin(), mcos.end(), cobjs.begin(), [](const auto& mco) { return mco.cop; }); + return { + cobjs, [=](module& m, instruction_ref ins, const std::vector& ops) { + auto compiled_inputs = ins->inputs(); + auto precompiled_inputs = precompile_ins->inputs(); + std::unordered_map inputs_rep_map; + for(const auto i : range(precompiled_inputs.size())) + { + inputs_rep_map[precompiled_inputs[i]] = compiled_inputs[i]; + } + auto dot_inputs = mods[0].inputs; + auto dot_mod_out_shape = mods[0].mod.get_output_shapes().front(); + auto dot_alloc = m.insert_instruction( + ins, + migraphx::make_op("hip::allocate", {{"shape", to_value(dot_mod_out_shape)}})); + dot_inputs.push_back(dot_alloc); + for(const auto i : range(mcos[0].prefill_indices.size())) + { + auto prefilled_ins = m.insert_instruction( + ins, + migraphx::make_op("hip::fill", {{"value", mcos[0].prefill_values[i]}}), + dot_inputs[mcos[0].prefill_indices[i]]); + replace(dot_inputs, dot_inputs[mcos[0].prefill_indices[i]], prefilled_ins); + } + + std::vector dot_inputs_updated; + std::transform(dot_inputs.begin(), + dot_inputs.end(), + std::back_inserter(dot_inputs_updated), + [&](const auto& i) { + if(inputs_rep_map.find(i) != inputs_rep_map.end()) + { + assert(inputs_rep_map.at(i)->get_shape() == i->get_shape()); + return inputs_rep_map.at(i); + } + return i; + }); + auto mlir = + insert_mlir(m, ins, any_cast(ops[0]), dot_inputs_updated); + assert(contains(mods[1].inputs, split_ins)); + auto pwm = mods[1]; + pwm.replace(split_ins, mlir); + auto pw_inputs = pwm.inputs; + pw_inputs.push_back(ins->inputs().back()); + std::vector pw_inputs_updated; + std::transform(pw_inputs.begin(), + pw_inputs.end(), + std::back_inserter(pw_inputs_updated), + [&](const auto& i) { + if(inputs_rep_map.find(i) != inputs_rep_map.end()) + { + assert(inputs_rep_map.at(i)->get_shape() == i->get_shape()); + return 
inputs_rep_map.at(i); + } + return i; + }); + auto pw_ins = + insert_mlir(m, ins, any_cast(ops[1]), pw_inputs_updated); + return m.replace_instruction(ins, pw_ins); + }}; + } + optional get_tuning_config(const context& ctx, instruction_ref ins, const operation&, diff --git a/src/targets/gpu/jit/pointwise.cpp b/src/targets/gpu/jit/pointwise.cpp index bc0e7580cdf..42beb6de070 100644 --- a/src/targets/gpu/jit/pointwise.cpp +++ b/src/targets/gpu/jit/pointwise.cpp @@ -21,12 +21,13 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ +#include #include #include #include #include #include -#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -47,7 +48,7 @@ extern "C" { MIGRAPHX_GLOBAL void ${kernel}(${params}) { auto idx = make_index(); - pointwise(idx, ${transformers})(${lambda}, ${args}); + pointwise<${noutputs}>(idx, ${transformers})(${lambda}, ${args}); } } @@ -70,22 +71,25 @@ struct pointwise_compiler : compiler operation compile_op(context& ctx, const std::vector& inputs, const value& v) const { hip_compile_options options; - options.inputs = inputs; + options.inputs = flatten(inputs); options.output = inputs.back(); - options.virtual_inputs = reduce_dims(normalize_permutation(inputs)); + options.virtual_inputs = reduce_dims(normalize_permutation(options.inputs)); options.emplace_param("-Wno-float-equal"); auto axis = find_fast_axis(options.virtual_inputs); auto vec = vectorize::elements(ctx, axis, options.virtual_inputs); options.kernel_name = v.get("kernel", "kernel"); options.set_launch_params( - v, compute_global_for(ctx, options.output.elements() / vec.size, 256)); - auto src = interpolate_string(pointwise_kernel, - {{"kernel", options.kernel_name}, - {"params", enum_params(inputs.size(), "void * private_p")}, - {"args", enum_params(inputs.size(), "private_p")}, - {"lambda", v.at("lambda").to()}, - {"transformers", make_transformer_args(vec)}, - {"preamble", v.get("preamble", 
std::string{})}}); + v, compute_global_for(ctx, options.inputs.front().elements() / vec.size, 256)); + auto noutputs = options.inputs.size() - inputs.size() + 1; + auto src = + interpolate_string(pointwise_kernel, + {{"kernel", options.kernel_name}, + {"params", enum_params(options.inputs.size(), "void * private_p")}, + {"args", enum_params(options.inputs.size(), "private_p")}, + {"lambda", v.at("lambda").to()}, + {"transformers", make_transformer_args(vec)}, + {"noutputs", std::to_string(noutputs)}, + {"preamble", v.get("preamble", std::string{})}}); return compile_hip_code_object(src, options); } @@ -93,21 +97,16 @@ struct pointwise_compiler : compiler { if(contains({"layout", "contiguous"}, op.name())) { - return compile_op( - ctx, - to_shapes(ins->inputs()), - {{"lambda", "[](auto x) { return x; }"}, {"kernel", op.name() + "_kernel"}}); + return compile_op(ctx, + to_shapes(ins->inputs()), + {{"lambda", "[](auto x) { return make_tuple(x); }"}, + {"kernel", op.name() + "_kernel"}}); } else { assert(not ins->module_inputs().empty()); - auto* pm = ins->module_inputs().front(); - auto pf = generate_pointwise(*pm, "inner_pointwise"); - std::string lambda = "MIGRAPHX_LIFT(inner_pointwise)"; - auto kernel_name = generate_name_from_ops(*pm, "kernel"); - return compile_op(ctx, - to_shapes(ins->inputs()), - {{"lambda", lambda}, {"preamble", pf}, {"kernel", kernel_name}}); + const_module_ref pm = ins->module_inputs().front(); + return compile_pointwise(ctx, to_shapes(ins->inputs()), pm); } } }; diff --git a/src/targets/gpu/jit/reduce.cpp b/src/targets/gpu/jit/reduce.cpp index 0dc8e34b855..25773abc115 100644 --- a/src/targets/gpu/jit/reduce.cpp +++ b/src/targets/gpu/jit/reduce.cpp @@ -308,13 +308,21 @@ struct fused_reduce_compiler : compiler { std::vector names() const { return {"fused_reduce", "split_fused_reduce"}; } + static shape get_input_shape(const std::vector& inputs) + { + auto it = std::max_element(inputs.begin(), + inputs.end(), + by(std::less<>{}, [](const 
shape& s) { return s.elements(); })); + return *it; + } + operation compile_op(context& ctx, const std::vector& inputs, const value& v) const { auto assign = v.get("assign", "assign_none"); auto axes = v.at("axes").to_vector(); auto virtual_inputs = inputs; - virtual_inputs.push_back(get_reduced_shape(inputs.front(), axes)); - virtual_inputs.push_back(get_output_shape(inputs.front(), axes)); + virtual_inputs.push_back(get_reduced_shape(get_input_shape(inputs), axes)); + virtual_inputs.push_back(get_output_shape(get_input_shape(inputs), axes)); virtual_inputs = reduce_dims(normalize_permutation(virtual_inputs)); if(assign != "assign_none") virtual_inputs = split_reduce(virtual_inputs); diff --git a/src/targets/gpu/kernel.cpp b/src/targets/gpu/kernel.cpp index 1cbb45852b1..19e8b1dd2e5 100644 --- a/src/targets/gpu/kernel.cpp +++ b/src/targets/gpu/kernel.cpp @@ -28,20 +28,20 @@ #include // extern declare the function since hip/hip_ext.h header is broken -extern hipError_t hipExtModuleLaunchKernel(hipFunction_t, // NOLINT - uint32_t, - uint32_t, - uint32_t, - uint32_t, - uint32_t, - uint32_t, - size_t, - hipStream_t, - void**, - void**, - hipEvent_t = nullptr, - hipEvent_t = nullptr, - uint32_t = 0); +extern "C" hipError_t hipExtModuleLaunchKernel(hipFunction_t, // NOLINT + uint32_t, + uint32_t, + uint32_t, + uint32_t, + uint32_t, + uint32_t, + size_t, + hipStream_t, + void**, + void**, + hipEvent_t = nullptr, + hipEvent_t = nullptr, + uint32_t = 0); namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/array.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/array.hpp index 0aa086c3352..de9a82c4eb0 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/array.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/array.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. 
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -123,6 +123,22 @@ struct array { using value_type = T; T d[N]; + + constexpr array() = default; + + template {} and ...))> + constexpr array(Ts... xs) : d{xs...} + { + } + + template {} and (N > 1))> + constexpr explicit array(U x) + { + for(index_int i = 0; i < N; i++) + d[i] = x; + } + constexpr T& operator[](index_int i) { MIGRAPHX_ASSERT(i < N); @@ -260,6 +276,12 @@ struct array } }; +template +constexpr auto array_apply(F f) +{ + return [=](auto&& x) { return x.apply(f); }; +} + template constexpr array make_array(T x, Ts... xs) { diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp index 019350c54e0..fab865c0587 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -291,16 +291,39 @@ inline constexpr auto transform_args() return make_transform([](auto f, auto... xs) { return f(xs...); }); } -// Rotate the first argument to the last argument -inline constexpr auto rotate_last() +// Rotate the last N arguments to the first N arguments +template +constexpr auto rotate_last() { return make_transform([](auto f, auto... xs) { return sequence_c([&](auto... 
is) { constexpr auto size = sizeof...(is); - return f(arg_c<(is + size - 1) % size>()(xs...)...); + return f(arg_c<(is + size - N) % size>()(xs...)...); }); }); } +inline constexpr auto rotate_last() { return rotate_last<1>(); } + +// Pack the first N arguments +template +constexpr auto pack_first() +{ + return make_transform([](auto f, auto... xs) { + return sequence_c([&](auto... is) { + return sequence_c([&](auto... js) { + return f(pack(arg_c()(xs...)...), arg_c()(xs...)...); + }); + }); + }); +} + +// Rotate the last N arguments as the first argument packed +template +constexpr auto rotate_and_pack_last() +{ + return transform_args(rotate_last(), pack_first()); +} + } // namespace migraphx #endif // MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp index b52a61eb498..c64ab553159 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -30,6 +30,18 @@ namespace migraphx { +template +struct acc_type +{ + using type = float; +}; + +template <> +struct acc_type +{ + using type = double; +}; + template constexpr auto vec_reduce(const array& a, Op op) { @@ -50,33 +62,33 @@ __device__ void generic_binary_layernorm( using reduce_output = reduce::with_axis; block::template run([&](auto, auto r) { - auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2); - using value_type = typename Input1::type; - using vec_value_type = vec_type; + using value_type = typename Input1::type; + using vec_value_type = typename acc_type>::type; + + auto input = r.inner([&](auto x1, auto x2) { + return migraphx::convert(op(x1, x2)); + })(input1, input2); + constexpr auto relements = r.template elements(); constexpr auto relements_r = vec_value_type{1.0 / relements}; auto relements_rsqrt = sqrt(relements_r); - auto means = r.reduce(op::sum{}, - make_array(vec_value_type{0}, vec_value_type{0}), - [&](auto x) { - auto x_out = x * relements_r; - // dividing x by sqrt(relements) before squaring allows computing - // higher values before overflow in low precision - auto x2_sqrt = x * relements_rsqrt; - return make_array(x_out, x2_sqrt * x2_sqrt); - })(input); + auto means = r.reduce(op::sum{}, make_array(0, 0), [&](auto x) { + auto x_out = x * relements_r; + // dividing x by sqrt(relements) before squaring allows computing + // higher values before overflow in low precision + auto x2_sqrt = x * relements_rsqrt; + return make_array(x_out, x2_sqrt * x2_sqrt); + })(input); - auto mean_x = means[0]; - auto mean_x2 = means[1]; - auto variance = mean_x2 - (mean_x * mean_x); - value_type eps_val = implicit_conversion(eps); + auto mean_x = means[0]; + auto mean_x2 = means[1]; + auto variance = mean_x2 - (mean_x * mean_x); + vec_value_type eps_val = 
implicit_conversion(eps); + auto rsqrt_val = rsqrt(variance + eps_val); r.inner([&](auto& y, auto x, auto... xs) { - auto m = x - mean_x; - - // m * rsqrt(mean(m ^ 2) + epsilon) - y = compute(m * rsqrt(variance + eps_val), xs...); + y = compute(migraphx::convert>((x - mean_x) * rsqrt_val), xs...); })(output, input, inputs...); }); } diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp index 4b5f9fc865c..e7dc2fd845e 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -30,21 +30,29 @@ #include #include #include +#include namespace migraphx { -template -__device__ void pointwise_tensor(index idx, F f, T out, Ts... xs) +template +__device__ void pointwise_tensor(index idx, F f, Output out, T x, Ts... xs) { - idx.global_stride(out.get_shape().elements(), - [&](auto i) { out[i] = implicit_conversion(f(xs[i]...)); }); + idx.global_stride(x.get_shape().elements(), [&](auto i) { + auto r = f(x[i], xs[i]...); + out([&](auto... outs) { + r([&](auto... rs) { + static_assert(sizeof...(outs) == sizeof...(rs)); + swallow{(outs[i] = implicit_conversion(rs))...}; + }); + }); + }); } -template +template __device__ auto pointwise(index idx, Transforms... transforms) { return [=](auto f, auto*... ps) { - auto t = transform_args(make_tensors(), rotate_last(), transforms...); + auto t = transform_args(make_tensors(), transforms..., rotate_and_pack_last()); t(ps...)([&](auto... 
xs) { pointwise_tensor(idx, f, xs...); }); }; } diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/print.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/print.hpp index 8d197570ce2..a1242453516 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/print.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/print.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -215,42 +215,55 @@ inline __device__ auto coutln() return make_printer([](auto f) { f(); }, [] { printf("\n"); }); } -template -__device__ void print_each(F f, Ts... xs) +template +__device__ void unsafe_print_each(Stream s, T x, Ts... xs) { - each_args([&](auto x) { f() << x; }, xs...); + s << x; + each_args([&](auto xx) { s << ' ' << xx; }, xs...); } -template -__device__ void print_each_once(F f, Ts... xs) +template +__device__ void print_each(Stream s, Ts... xs) +{ + auto idx = make_index(); + for(auto i = 0; i < idx.nglobal(); i++) + { + if(i == idx.global) + unsafe_print_each(s, xs...); + __syncthreads(); + } +} + +template +__device__ void print_each_once(Stream s, Ts... xs) { auto idx = make_index(); if(idx.global == 0) - print_each(f, xs...); + unsafe_print_each(s, xs...); } template __device__ void print(Ts... xs) { - print_each(&cout, xs...); + print_each(cout(), xs...); } template __device__ void print_once(Ts... xs) { - print_each_once(&cout, xs...); + print_each_once(cout(), xs...); } template __device__ void println(Ts... xs) { - print_each(&cout, xs..., '\n'); + print_each(cout(), xs..., '\n'); } template __device__ void println_once(Ts... 
xs) { - print_each_once(&cout, xs..., '\n'); + print_each_once(cout(), xs..., '\n'); } } // namespace migraphx diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp index b2fb0f4b00f..5689e180030 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp @@ -312,6 +312,18 @@ constexpr auto compute_reduce_axis() return make_shape(lens, get_shape_c{}.strides); } +template +constexpr auto final_reduce(T x, F f) +{ + return vec_reduce(x, f); +} + +template +constexpr auto final_reduce(array a, F f) +{ + return a.apply([&](auto x) { return final_reduce(x, f); }); +} + template using with_axis = decltype(compute_reduce_axis()); @@ -455,7 +467,7 @@ struct block __device__ auto reduce_impl(Op op, T init, Read read, N n, Ts&&... xs) const { return block_reduce(idx, op, init, n, [&](auto j, auto d) { - return vec_reduce(read(xs(j, d)...), op); + return final_reduce(read(xs(j, d)...), op); }); } @@ -512,7 +524,7 @@ struct block_large __device__ auto reduce_impl(Op op, T init, Read read, N n, Ts&&... xs) const { return block_reduce(idx, op, init, index_int{n}, [&](auto j, auto d) { - return vec_reduce(read(xs(j, d)...), op); + return final_reduce(read(xs(j, d)...), op); }); } @@ -585,7 +597,7 @@ struct subwave __device__ auto reduce_impl(Op op, T init, Read read, N n, Ts&&... xs) const { return subwave_reduce(idx, op, init, n, [&](auto j, auto d) { - return vec_reduce(read(xs(j, d)...), op); + return final_reduce(read(xs(j, d)...), op); }); } diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/tuple.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/tuple.hpp new file mode 100644 index 00000000000..eceefa4714f --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/tuple.hpp @@ -0,0 +1,164 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. 
All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ +#ifndef MIGRAPHX_GUARD_KERNELS_TUPLE_HPP +#define MIGRAPHX_GUARD_KERNELS_TUPLE_HPP + +#include + +namespace migraphx { + +namespace tuple_detail { + +template +struct element_storage +{ + [[no_unique_address]] T element; +}; + +template +constexpr const auto& get_element(const element_storage& x) +{ + return x.element; +} + +template +constexpr auto& get_element(element_storage& x) +{ + return x.element; +} + +template +struct tuple_storage; + +template +struct tuple_storage, Ts...> : element_storage... +{ + template + constexpr tuple_storage(Us... ys) : element_storage{ys}... 
+ { + } + + template + constexpr auto operator()(F f) const + { + return f(static_cast&>(*this).element...); + } + + template + constexpr auto operator()(F f) + { + return f(static_cast&>(*this).element...); + } + + template + constexpr auto& operator[](IntegralConstant i) + { + static_assert(i < sizeof...(Ts), "Out of bounds tuple access"); + return get_element(*this); + } + + template + constexpr auto& operator[](IntegralConstant i) const + { + static_assert(i < sizeof...(Ts), "Out of bounds tuple access"); + return get_element(*this); + } + + constexpr index_constant size() const { return {}; } + constexpr auto empty() const { return size() == _c<0>; } +}; + +template +using tuple_base = tuple_detail::tuple_storage::type, Ts...>; + +} // namespace tuple_detail + +// NOLINTNEXTLINE +#define MIGRAPHX_DEVICE_TUPLE_OP(op, binary_op) \ + template \ + constexpr tuple& operator op(const tuple& rhs) \ + { \ + (*this)( \ + [&](auto&... xs) { rhs([&](const auto&... ys) { swallow{((xs op ys), 0)...}; }); }); \ + return *this; \ + } \ + template \ + friend constexpr auto operator binary_op(const tuple& lhs, const tuple& rhs) \ + { \ + using result = tuple() binary_op declval())...>; \ + return lhs([&](auto&... xs) { \ + return rhs([&](const auto&... ys) { return result{xs op ys...}; }); \ + }); \ + } + +template +struct tuple : tuple_detail::tuple_base +{ + using base = tuple_detail::tuple_base; + + template + constexpr tuple(Us... ys) : base(ys...) + { + } + + MIGRAPHX_DEVICE_TUPLE_OP(+=, +) + MIGRAPHX_DEVICE_TUPLE_OP(-=, -) + MIGRAPHX_DEVICE_TUPLE_OP(*=, *) + MIGRAPHX_DEVICE_TUPLE_OP(/=, /) + MIGRAPHX_DEVICE_TUPLE_OP(%=, %) + MIGRAPHX_DEVICE_TUPLE_OP(&=, &) + MIGRAPHX_DEVICE_TUPLE_OP(|=, |) + MIGRAPHX_DEVICE_TUPLE_OP(^=, ^) + + friend constexpr bool operator==(const tuple& x, const tuple& y) + { + return x([&](const auto&... xs) { + return y([&](const auto&... 
ys) { return ((xs == ys) and ...); }); + }); + } + friend constexpr bool operator!=(const tuple& x, const tuple& y) { return not(x == y); } + friend constexpr bool operator<(const tuple& x, const tuple& y) + { + return x([&](const auto&... xs) { + return y([&](const auto&... ys) { + fold([&](auto a, auto b) { return a == 0 ? b() : 0; })(0, [&] { + return (xs < ys) ? -1 : (ys < xs) ? 1 : 0; + }...); + }); + }); + } + friend constexpr bool operator>(const tuple& x, const tuple& y) { return y < x; } + friend constexpr bool operator<=(const tuple& x, const tuple& y) { return not(x > y); } + friend constexpr bool operator>=(const tuple& x, const tuple& y) { return not(x < y); } +}; + +template +constexpr tuple make_tuple(Ts... xs) +{ + return {xs...}; +} + +} // namespace migraphx +#endif // MIGRAPHX_GUARD_KERNELS_TUPLE_HPP diff --git a/src/targets/gpu/lowering.cpp b/src/targets/gpu/lowering.cpp index 51d924f6a5f..6c87d8fdc45 100644 --- a/src/targets/gpu/lowering.cpp +++ b/src/targets/gpu/lowering.cpp @@ -82,8 +82,9 @@ struct miopen_apply { assert(mod != nullptr); assert(pass != nullptr); - +#if MIGRAPHX_USE_ROCBLAS compute_fp32 = get_compute_fp32_flag(); +#endif offload_copy = (mod == mpm->get_root_module()) ? 
pass->offload_copy : false; add_generic_op("contiguous"); @@ -104,8 +105,10 @@ struct miopen_apply add_convolution_op("convolution"); add_convolution_op("convolution_backwards"); add_convolution_op("quant_convolution"); +#if MIGRAPHX_USE_ROCBLAS add_gemm_op("dot"); add_gemm_op("quant_dot"); +#endif add_if_op(); add_loop_op(); add_neg_op(); @@ -232,6 +235,7 @@ struct miopen_apply return mod->insert_instruction(ins, make_op("allocate", {{"shape", to_value(s)}})); } +#if MIGRAPHX_USE_ROCBLAS template void add_gemm_op(const std::string& name) { @@ -243,6 +247,7 @@ struct miopen_apply return mod->replace_instruction(ins, rocblas_gemm{Op{}, 1, 0, compute_fp32}, refs); }); } +#endif void add_convolution_op(const std::string& name) { diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index 76f18563fd5..d12e9b56a2d 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #ifdef MIGRAPHX_MLIR @@ -594,7 +595,8 @@ struct mlir_program {"sym_name", sym_name}, {"kernel", std::string("mixr")}, {"arch", target_arch}, - {"num_cu", num_cu}}); + {"num_cu", num_cu}, + {"enable_splitk_for_tuning", true}}); ops.add_region(std::move(region)); insert(body, std::move(ops)); @@ -945,6 +947,15 @@ struct mlir_program std::string sym_name; }; +bool is_module_fusible(const module& m, const context& migraphx_ctx, const value& solution) +{ + mlir_program mp; + mp.set_gpu_properties(migraphx_ctx); + mp.parse(m); + mp.run_high_level_pipeline(); + return mlirIsModuleFusible(mp.mmodule.get(), make_mlir_string_ref(*solution.if_string())); +} + std::string dump_mlir(const module& m) { mlir_program mp; @@ -971,12 +982,12 @@ void adjust_param_shapes(module& m, const std::vector& inputs) } } -code_object_op compile_mlir(const context& migraphx_ctx, - module m, - const std::vector& inputs, - const value& solution) +mlir_code_object compile_mlir(const context& migraphx_ctx, + module m, + const std::vector& in_shapes, + 
const value& solution) { - adjust_param_shapes(m, to_shapes(inputs)); + adjust_param_shapes(m, in_shapes); const bool trace = enabled(MIGRAPHX_TRACE_MLIR{}); static std::mutex mutex; @@ -987,6 +998,7 @@ code_object_op compile_mlir(const context& migraphx_ctx, } mlir_program mp; + mp.set_gpu_properties(migraphx_ctx); mp.parse(m); auto mod_op = mlirModuleGetOperation(mp.mmodule.get()); @@ -996,9 +1008,33 @@ code_object_op compile_mlir(const context& migraphx_ctx, std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl; } auto co = mp.compile(solution); - co.expected_inputs = to_shapes(inputs); + + co.expected_inputs = in_shapes; co.output = m.get_output_shapes().front(); - return co; + mlir_code_object mco; + mco.cop = co; + size_t num_prefill_args = mlirGetNumPrefillArgs(mp.mmodule.get()); + if(num_prefill_args > 0) + { + std::vector prefill_indices(num_prefill_args); + std::vector prefill_mlir_values(num_prefill_args); + mlirGetPrefillArgsInfo( + mp.mmodule.get(), prefill_indices.data(), prefill_mlir_values.data(), num_prefill_args); + std::vector prefill_values(prefill_mlir_values.size()); + std::transform(prefill_mlir_values.begin(), + prefill_mlir_values.end(), + prefill_values.begin(), + [](const auto& v) { + // mlir sets fill attribute as float but migx hip::fill operator only + // supports integer type. + // TODO: Need to add checks that it is indeed an integer. + double dv = mlirFloatAttrGetValueDouble(v); + return static_cast(dv); + }); + mco.prefill_indices = prefill_indices; + mco.prefill_values = prefill_values; + } + return mco; } instruction_ref insert_mlir(module& m, @@ -1050,8 +1086,7 @@ void use(T&) // Disabling clang-tidy warning on non-real useage. 
// NOLINTBEGIN(performance-unnecessary-value-param) -code_object_op -compile_mlir(const context&, module, const std::vector&, const value&) +mlir_code_object compile_mlir(const context&, module, const std::vector&, const value&) { return {}; } diff --git a/src/targets/gpu/prefuse_ops.cpp b/src/targets/gpu/prefuse_ops.cpp index f16b02d59eb..f06c8c1775d 100644 --- a/src/targets/gpu/prefuse_ops.cpp +++ b/src/targets/gpu/prefuse_ops.cpp @@ -38,6 +38,8 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace gpu { +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_LAYERNORM_FUSION); + namespace { template @@ -241,9 +243,12 @@ struct find_gemm_softmax_gemm void prefuse_ops::apply(module_pass_manager& mpm) const { - match::find_matches(mpm.get_module(), find_layernorm{}); - mpm.run_pass(dead_code_elimination{}); - match::find_matches(mpm.get_module(), find_add_layernorm{}); + if(not enabled(MIGRAPHX_DISABLE_LAYERNORM_FUSION{})) + { + match::find_matches(mpm.get_module(), find_layernorm{}); + mpm.run_pass(dead_code_elimination{}); + match::find_matches(mpm.get_module(), find_add_layernorm{}); + } match::find_matches(mpm, find_gemm_softmax_gemm{enable_attention}); } diff --git a/src/targets/gpu/prepare_reduce.cpp b/src/targets/gpu/prepare_reduce.cpp new file mode 100644 index 00000000000..ebb1752eeed --- /dev/null +++ b/src/targets/gpu/prepare_reduce.cpp @@ -0,0 +1,134 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ * + */ +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct parallel_reduce +{ + operation op; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.op, "op")); + } + + std::string name() const { return "gpu::parallel_reduce"; } + + shape compute_shape(const std::vector& inputs) const + { + std::vector result; + std::transform(inputs.begin(), inputs.end(), std::back_inserter(result), [&](auto input) { + return op.compute_shape({input}); + }); + return shape{result}; + } +}; +MIGRAPHX_REGISTER_OP(parallel_reduce); + +namespace { + +std::vector find_reduce(module& m) +{ + std::vector result; + auto im = iterator_for(m); + std::copy_if(im.begin(), im.end(), std::back_inserter(result), [](auto ins) { + if(contains({"gpu::parallel_reduce", "reduce_mean"}, ins->name())) + return false; + return contains(ins->name(), "reduce"); + }); + return result; +} + +bool reaches(instruction_ref start, instruction_ref end) +{ + std::unordered_set visited; + return fix([&](auto self, auto ins) -> bool { + if(ins == start) + return true; + if(not visited.insert(ins).second) + return false; + return std::any_of(ins->inputs().begin(), ins->inputs().end(), self); + })(end); +} + +std::vector find_parallel_reduce(const std::vector& r) +{ + std::vector result; + auto ir = iterator_for(r); + transform_if( + ir.begin(), + ir.end(), + std::back_inserter(result), + [&](auto x) { + return std::none_of( + std::next(x), r.end(), [&](auto reduce) { return reaches(*x, reduce); }); + }, + [](auto x) { return *x; }); + return result; +} + +void fuse_reductions(module& m) +{ + auto rs = find_parallel_reduce(find_reduce(m)); + if(rs.size() < 2) + return; + // Only handle the same reduction operator for now + if(std::any_of(std::next(rs.begin()), rs.end(), [&](auto r) { + return rs.front()->name() != r->name(); + })) + return; + auto last = rs.front(); + auto op = 
last->get_operator(); + std::vector inputs; + std::transform(rs.begin(), rs.end(), std::back_inserter(inputs), [&](auto r) { + return r->inputs().front(); + }); + auto pr = m.insert_instruction(last, parallel_reduce{op}, inputs); + int i = 0; + for(auto r : rs) + { + m.replace_instruction(r, make_op("get_tuple_elem", {{"index", i}}), pr); + i++; + } + m.sort(); +} + +} // namespace + +void prepare_reduce::apply(module& m) const { fuse_reductions(m); } + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/rocblas.cpp b/src/targets/gpu/rocblas.cpp index 1b37f08e1ed..798fefbb811 100644 --- a/src/targets/gpu/rocblas.cpp +++ b/src/targets/gpu/rocblas.cpp @@ -32,7 +32,7 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace gpu { - +#if MIGRAPHX_USE_ROCBLAS rocblas_handle_ptr create_rocblas_handle_ptr() { // add a call to rocblas_initialize() to workaround a rocblas bug SWDEV-438929 @@ -48,7 +48,7 @@ rocblas_handle_ptr create_rocblas_handle_ptr(hipStream_t s) rocblas_set_stream(rb.get(), s); return rb; } - +#endif bool get_compute_fp32_flag() { const auto device_name = trim(split_string(get_device_name(), ':').front()); @@ -57,11 +57,15 @@ bool get_compute_fp32_flag() bool rocblas_fp8_available() { +#if MIGRAPHX_USE_ROCBLAS #ifndef MIGRAPHX_USE_ROCBLAS_FP8_API return false; #else return gfx_has_fp8_intrinsics(); #endif +#else + return false; +#endif } } // namespace gpu diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 47a45b967d3..4a18e25aab5 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -32,8 +32,7 @@ #include #include #include -#include -#include +#include #include #include #include @@ -77,7 +76,6 @@ inline namespace MIGRAPHX_INLINE_NS { namespace gpu { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS) -MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_SPLIT_REDUCE) 
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC) #ifndef _WIN32 @@ -133,8 +131,8 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti simplify_qdq{}, enable_pass(not mlir_enabled(), rewrite_quantization{}), dead_code_elimination{}, - // workaround for rocBLAS unsupported error when using uint8 in quant_dot & quant_convolution - eliminate_data_type{{migraphx::shape::uint8_type}, shape::float_type, {"quant_convolution", "quant_dot"}}, + // workaround for rocBLAS unsupported error when using uint8 in quant_dot, quant_convolution & pooling + eliminate_data_type{{migraphx::shape::uint8_type}, shape::float_type, {"quant_convolution", "quant_dot", "pooling"}}, eliminate_data_type{unsupported_types, shape::type_t::float_type}, simplify_reshapes{}, eliminate_identity{}, @@ -160,10 +158,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti rewrite_low_precision{}, dead_code_elimination{}, optimize_module{}, - fuse_pointwise{}, - dead_code_elimination{}, - enable_pass(not enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}), fuse_reduce{}), - dead_code_elimination{}, + fuse_pointwise_reduce{}, enable_pass(enabled(MIGRAPHX_ENABLE_SPLIT_REDUCE{}), split_reduce{}), dead_code_elimination{}, fuse_concat{}, diff --git a/src/targets/gpu/time_op.cpp b/src/targets/gpu/time_op.cpp index 51459b64da5..5321bc9d775 100644 --- a/src/targets/gpu/time_op.cpp +++ b/src/targets/gpu/time_op.cpp @@ -21,6 +21,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. 
*/ +#include #include #include #include @@ -41,35 +42,58 @@ std::vector generate_arguments(const std::vector& shapes, unsig return args; } -double time_op(context& ictx, operation op, const std::vector& inputs, int n) +template +double time_loop(migraphx::gpu::context& gctx, int n, F f) { - // TODO: Use std::ref - migraphx::context ctx = ictx; - auto& gctx = any_cast(ctx); - auto output = op.compute_shape(inputs); - op.finalize(ctx, output, inputs); - auto args = generate_arguments(inputs); auto start = context::create_event_for_timing(); auto stop = context::create_event_for_timing(); - auto run = [&] { op.compute(ctx, output, args); }; - run(); + f(); gctx.get_stream().record(start.get()); for(auto i : range(n)) { (void)i; - run(); + f(); } gctx.get_stream().record(stop.get()); gctx.finish(); return context::get_elapsed_ms(start.get(), stop.get()) / n; } -double time_op(context& ictx, operation op, int n) +double time_op(const context& ictx, operation op, const std::vector& inputs, int n) +{ + // TODO: Use std::ref + migraphx::context ctx = ictx; + auto& gctx = any_cast(ctx); + auto output = op.compute_shape(inputs); + op.finalize(ctx, output, inputs); + auto args = generate_arguments(inputs); + auto run = [&] { op.compute(ctx, output, args); }; + return time_loop(gctx, n, run); +} + +double time_op(const context& ictx, operation op, int n) { auto inputs = any_cast(op).expected_inputs; return time_op(ictx, op, inputs, n); } +double time_program(const context& ictx, program p, int n) +{ + std::vector ctx_vec = {ictx}; + auto& gctx = any_cast(ctx_vec.front()); + auto* mm = p.get_main_module(); + mm->finalize(ctx_vec); + auto in_shapes = p.get_parameter_shapes(); + std::unordered_map param_map; + unsigned long seed = 0; + for(const auto& [name, shape] : in_shapes) + { + param_map[name] = to_gpu(generate_argument(shape, seed++)); + } + auto run = [&] { p.eval_with_context(ctx_vec, param_map); }; + return time_loop(gctx, n, run); +} + } // namespace gpu } // namespace 
MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/tf/tf.cpp b/src/tf/tf.cpp index e5b5cdff055..7b6c1322d66 100644 --- a/src/tf/tf.cpp +++ b/src/tf/tf.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -37,9 +37,9 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -program parse_tf(const std::string& name, const tf_options& options) +template +program parse_tf_from(const tf_options& options, Ts&&... xs) { - std::fstream input(name.c_str(), std::ios::in | std::ios::binary); tf::tf_parser parser; parser.is_nhwc = options.is_nhwc; parser.batch_size = options.batch_size; @@ -50,7 +50,7 @@ program parse_tf(const std::string& name, const tf_options& options) // Log the program when it can't be parsed try { - parser.parse_from(input); + parser.parse_from(std::forward(xs)...); } catch(...) 
{ @@ -58,11 +58,27 @@ program parse_tf(const std::string& name, const tf_options& options) throw; } #else - parser.parse_from(input); + parser.parse_from(std::forward(xs)...); #endif return std::move(parser.prog); } +program parse_tf(const std::string& name, const tf_options& options) +{ + std::fstream input(name.c_str(), std::ios::in | std::ios::binary); + return parse_tf_from(options, input); +} + +program parse_tf_buffer(const std::string& buffer, const tf_options& options) +{ + return parse_tf_from(options, buffer.data(), buffer.size()); +} + +program parse_tf_buffer(const void* data, std::size_t size, const tf_options& options) +{ + return parse_tf_from(options, data, size); +} + std::vector get_tf_operators() { return tf::get_op_parsers(); } } // namespace MIGRAPHX_INLINE_NS diff --git a/src/tf/tf_parser.cpp b/src/tf/tf_parser.cpp index 6d95c34d988..a53c7b29ec1 100644 --- a/src/tf/tf_parser.cpp +++ b/src/tf/tf_parser.cpp @@ -393,6 +393,19 @@ void tf_parser::parse_from(std::istream& is) } } +void tf_parser::parse_from(const void* data, std::size_t size) +{ + tensorflow::GraphDef graph; + if(graph.ParseFromArray(data, size)) + { + this->parse_graph(graph); + } + else + { + throw std::runtime_error("Failed reading tf buffer array"); + } +} + shape::type_t tf_parser::parse_type(const tensorflow::DataType t) const { shape::type_t shape_type{}; diff --git a/tools/check_stamped.py b/tools/check_stamped.py index 16dde757ea0..793f60d1ac1 100644 --- a/tools/check_stamped.py +++ b/tools/check_stamped.py @@ -30,7 +30,7 @@ # in the license stamp, with the assumption being that any modifications/creations will need to be stamped to the year that the # modification/creation was made. 
##################################################################################### -import subprocess, sys, datetime +import subprocess, sys, datetime, argparse debug = False @@ -111,14 +111,15 @@ def check_filename(filename: str, fileTuple: tuple or list) -> bool: return False -def main() -> None: +def main(branch) -> None: unsupported_file_types.extend(specificIgnores) ## Get a list of all files (not including deleted) that have changed/added in comparison to the latest Dev branch from MI Graphx # Subprocess 1 is fetching the latest dev branch from the MIgraphX Url and naming it as 'FETCH_HEAD' subprocess.run( - "git fetch https://github.com/ROCmSoftwarePlatform/AMDMIGraphX develop --quiet", + "git fetch https://github.com/ROCmSoftwarePlatform/AMDMIGraphX {0} --quiet" + .format(branch), shell=True, stdout=subprocess.PIPE) @@ -153,7 +154,7 @@ def main() -> None: elif len(stampedFilesWithBadYear) > 0: print( - f"\nError: The licenses for the following {str(len(stampedFilesWithBadYear))} file(s) either... do not match the year of commit, have a different copyright format or have not been synced from the latest develop branch:\n{str(stampedFilesWithBadYear)}\nThere is a license_stamper script (./tools/license_stamper.py), which you can run to automatically update and add any needed license stamps" + f"\nError: The licenses for the following {str(len(stampedFilesWithBadYear))} file(s) either... 
do not match the year of commit, have a different copyright format or have not been synced from the latest {branch} branch:\n{str(stampedFilesWithBadYear)}\nThere is a license_stamper script (./tools/license_stamper.py), which you can run to automatically update and add any needed license stamps" ) sys.exit(1) @@ -168,4 +169,9 @@ def main() -> None: if __name__ == "__main__": - main() + + parser = argparse.ArgumentParser() + parser.add_argument("branch") + args = parser.parse_args() + + main(args.branch) diff --git a/tools/docker/migraphx_with_onnxruntime_pytorch.docker b/tools/docker/migraphx_with_onnxruntime_pytorch.docker index c4b630da3a0..6b1eb0c5c23 100644 --- a/tools/docker/migraphx_with_onnxruntime_pytorch.docker +++ b/tools/docker/migraphx_with_onnxruntime_pytorch.docker @@ -6,19 +6,20 @@ ARG PREFIX=/usr/local RUN apt update && apt install -y wget #Aquire and install ROCm -RUN wget https://repo.radeon.com/amdgpu-install/6.0.2/ubuntu/jammy/amdgpu-install_6.0.60002-1_all.deb -RUN apt install -y ./amdgpu-install_6.0.60002-1_all.deb -RUN amdgpu-install --usecase=rocm -y && rm amdgpu-install_6.0.60002-1_all.deb +RUN wget https://repo.radeon.com/amdgpu-install/6.1/ubuntu/jammy/amdgpu-install_6.1.60100-1_all.deb +RUN apt install -y ./amdgpu-install_6.1.60100-1_all.deb +RUN amdgpu-install --usecase=rocm -y && rm amdgpu-install_6.1.60100-1_all.deb #Install MIGraphX from package manager RUN apt install -y migraphx #Pieces for Onnxruntime for ROCm and MIGraphX Execution Provider Support -RUN pip3 install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.0.2/onnxruntime_rocm-inference-1.17.0-cp310-cp310-linux_x86_64.whl +RUN pip3 install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.1/onnxruntime_rocm-inference-1.17.0-cp310-cp310-linux_x86_64.whl #Pieces for pytorch -RUN pip3 install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.0/torch-2.1.1+rocm6.0-cp310-cp310-linux_x86_64.whl -RUN pip3 install 
https://repo.radeon.com/rocm/manylinux/rocm-rel-6.0/torchvision-0.16.1+rocm6.0-cp310-cp310-linux_x86_64.whl +RUN pip3 install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.1/pytorch_triton_rocm-2.1.0%2Brocm6.1.4d510c3a44-cp310-cp310-linux_x86_64.whl +RUN pip3 install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.1/torch-2.1.2+rocm6.1-cp310-cp310-linux_x86_64.whl +RUN pip3 install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.1/torchvision-0.16.1+rocm6.1-cp310-cp310-linux_x86_64.whl #Adjust final path for ability to use rocm components ENV PATH=$PATH:/opt/rocm/bin/ diff --git a/tools/docker/sles.docker b/tools/docker/sles.docker index 5006593a500..b38645a2d95 100644 --- a/tools/docker/sles.docker +++ b/tools/docker/sles.docker @@ -5,7 +5,7 @@ RUN sh -c 'echo -e "\ name=rocm\n\ baseurl=https://repo.radeon.com/rocm/zyp/6.0.2/main\n\ enabled=1\n\ -gpgcheck=1\n\ +gpgcheck=0\n\ gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key\n\ " > /etc/zypp/repos.d/rocm.repo' diff --git a/tools/docker/ubuntu_2204.dockerfile b/tools/docker/ubuntu_2204.dockerfile index 906b256eebe..15afe6d56d7 100644 --- a/tools/docker/ubuntu_2204.dockerfile +++ b/tools/docker/ubuntu_2204.dockerfile @@ -36,7 +36,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- software-properties-common \ wget \ rocm-device-libs \ - hip-base \ + hip-dev \ libnuma-dev \ miopen-hip \ rocblas \ diff --git a/tools/format.py b/tools/format.py index f04b59b3609..ad55450b2a0 100644 --- a/tools/format.py +++ b/tools/format.py @@ -63,8 +63,9 @@ def get_merge_base(branch): def get_files_changed(against, ext=('.py')): - files = eval(f"git diff-index --cached --name-only {against}", - cwd=get_top()).splitlines() + files = eval( + f"git diff-index --cached --name-only --diff-filter=d {against}", + cwd=get_top()).splitlines() return (f for f in files if f.endswith(ext) and not is_excluded(f))