
Commit

Merge pull request #80 from ROCm/rocm-jaxlib-v0.4.35-qa-fix_launch_dims
[ROCm] Fix kernel launch dimension
hsharsha authored Jan 2, 2025
2 parents 2e3b40e + bb2d621 commit e38372e
Showing 2 changed files with 13 additions and 3 deletions.
1 change: 1 addition & 0 deletions xla/service/gpu/BUILD
@@ -135,6 +135,7 @@ cc_library(
     deps = [
         "//xla:shape_util",
         "//xla:util",
+        "//xla/service:platform_util",
         "//xla/stream_executor:device_description",
         "//xla/stream_executor:launch_dim",
         "@com_google_absl//absl/log",
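The new //xla/service:platform_util dependency is what makes xla::PlatformUtil visible to this target; the second file in this commit uses it to detect the ROCm platform at runtime. A minimal sketch of the call the dependency enables (the helper name IsRocmPlatform is illustrative, not part of the commit):

#include "xla/service/platform_util.h"

// Resolves the canonical GPU platform name and checks for ROCm. The .value()
// call mirrors the commit and will terminate the process if name resolution
// fails.
bool IsRocmPlatform() {
  return xla::PlatformUtil::CanonicalPlatformName("gpu").value() == "rocm";
}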
15 changes: 12 additions & 3 deletions xla/service/gpu/launch_dimensions.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include <algorithm>
 #include <cstdint>
 
+#include "xla/service/platform_util.h"
 #include "xla/shape.h"
 #include "xla/shape_util.h"
 #include "xla/stream_executor/device_description.h"
@@ -37,11 +38,19 @@ LaunchDimensions CalculateLaunchDimensions(
   num_elements = CeilOfRatio(num_elements, int64_t{dim_config.unroll_factor});
 
   const int kWarpSchedulers = 4;
-  int64_t threads_per_block = std::min<int64_t>(
+  int64_t threads_per_block_x = std::min<int64_t>(
       gpu_device_info.threads_per_warp() * kWarpSchedulers, num_elements);
-  int64_t num_blocks = CeilOfRatio(num_elements, threads_per_block);
+  int64_t num_blocks = CeilOfRatio(num_elements, threads_per_block_x);
   CHECK(num_blocks < gpu_device_info.block_dim_limit().x);
+  int threads_per_block_y = 1;
+  if (xla::PlatformUtil::CanonicalPlatformName("gpu").value() == "rocm") {
+    while ((num_blocks * threads_per_block_x) >
+           std::numeric_limits<uint32_t>::max()) {
+      threads_per_block_x /= 2;
+      threads_per_block_y *= 2;
+    }
+  }
   return LaunchDimensions(se::BlockDim(num_blocks, 1, 1),
-                          se::ThreadDim(threads_per_block, 1, 1));
+                          se::ThreadDim(threads_per_block_x, threads_per_block_y, 1));
 }
 
 }  // namespace gpu
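To illustrate the fix in isolation: the added while loop folds excess x-threads into the block's y dimension until num_blocks * threads_per_block_x fits in a uint32_t on the ROCm path, without changing the number of blocks or the total threads per block (x * y stays constant), so the launch still covers every element. The sketch below reproduces that arithmetic outside XLA; the plain structs, the is_rocm flag, and the wavefront size of 64 are stand-ins for se::BlockDim/se::ThreadDim, the PlatformUtil check, and gpu_device_info.threads_per_warp(), and are assumptions, not part of the commit.

// Standalone sketch of the adjusted launch-dimension calculation (assumptions
// noted above). Compiles as plain C++17.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <limits>

struct LaunchDims {
  int64_t block_x, block_y, block_z;     // grid size in blocks
  int64_t thread_x, thread_y, thread_z;  // block size in threads
};

// Stand-in for xla::CeilOfRatio.
int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }

LaunchDims CalculateLaunchDims(int64_t num_elements, bool is_rocm) {
  const int64_t threads_per_warp = 64;  // assumed AMD wavefront size
  const int kWarpSchedulers = 4;

  int64_t threads_per_block_x = std::min<int64_t>(
      threads_per_warp * kWarpSchedulers, num_elements);
  int64_t num_blocks = CeilOfRatio(num_elements, threads_per_block_x);

  // ROCm path from the commit: halve x and double y until
  // num_blocks * threads_per_block_x fits in uint32_t. The product
  // threads_per_block_x * threads_per_block_y stays constant, so the
  // launch still covers every element.
  int64_t threads_per_block_y = 1;
  if (is_rocm) {
    while (num_blocks * threads_per_block_x >
           std::numeric_limits<uint32_t>::max()) {
      threads_per_block_x /= 2;
      threads_per_block_y *= 2;
    }
  }
  return LaunchDims{num_blocks, 1, 1,
                    threads_per_block_x, threads_per_block_y, 1};
}

int main() {
  // 2^33 elements: a 1-D launch would need 2^25 blocks * 256 threads = 2^33
  // x-threads, which overflows uint32_t, so the loop runs twice.
  LaunchDims d = CalculateLaunchDims(int64_t{1} << 33, /*is_rocm=*/true);
  std::cout << "blocks: " << d.block_x << "  threads: (" << d.thread_x << ", "
            << d.thread_y << ", " << d.thread_z << ")\n";
  // Prints: blocks: 33554432  threads: (64, 4, 1)
  return 0;
}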
