Skip to content

Commit

Permalink
[ROCm] Fix kernel launch dimension
Browse files Browse the repository at this point in the history
Launch dimensions should be of the form
((block.x, 1, 1), (thread.x, thread.y, 1)) to accommodate checks in
[parallel_loop_emitter.cc](https://github.com/openxla/xla/blob/main/xla/service/gpu/parallel_loop_emitter.cc#L169-L171)
  • Loading branch information
hsharsha committed Dec 13, 2024
1 parent 2e3b40e commit bb2d621
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
1 change: 1 addition & 0 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ cc_library(
deps = [
"//xla:shape_util",
"//xla:util",
"//xla/service:platform_util",
"//xla/stream_executor:device_description",
"//xla/stream_executor:launch_dim",
"@com_google_absl//absl/log",
Expand Down
15 changes: 12 additions & 3 deletions xla/service/gpu/launch_dimensions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include <algorithm>
#include <cstdint>

#include "xla/service/platform_util.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/stream_executor/device_description.h"
Expand All @@ -37,11 +38,19 @@ LaunchDimensions CalculateLaunchDimensions(
num_elements = CeilOfRatio(num_elements, int64_t{dim_config.unroll_factor});

const int kWarpSchedulers = 4;
int64_t threads_per_block = std::min<int64_t>(
int64_t threads_per_block_x = std::min<int64_t>(
gpu_device_info.threads_per_warp() * kWarpSchedulers, num_elements);
int64_t num_blocks = CeilOfRatio(num_elements, threads_per_block);
int64_t num_blocks = CeilOfRatio(num_elements, threads_per_block_x);
CHECK(num_blocks < gpu_device_info.block_dim_limit().x);
int threads_per_block_y = 1;
if (xla::PlatformUtil::CanonicalPlatformName("gpu").value() == "rocm") {
while ((num_blocks * threads_per_block_x) > std::numeric_limits<uint32_t>::max()) {
threads_per_block_x /= 2;
threads_per_block_y *= 2;
}
}
return LaunchDimensions(se::BlockDim(num_blocks, 1, 1),
se::ThreadDim(threads_per_block, 1, 1));
se::ThreadDim(threads_per_block_x, threads_per_block_y, 1));
}

} // namespace gpu
Expand Down

0 comments on commit bb2d621

Please sign in to comment.