Enable Triton Auto-tuning in XLA #81

Merged

zoranjovanovic-ns

No description provided.

@@ -565,6 +607,9 @@ ENTRY e {

// TODO(b/344770374): Make this test not fragile.
TEST_F(GemmFusionAutotunerTest, DoNotRunAutotuningKernelSpillingRegisters) {
  if (isRocm()) {
    GTEST_SKIP() << "Not supported on ROCm.";
  }
  const std::string kHloText = R"(

Hmm, does Triton autotuning on ROCm not have register-spilling prevention?

Author

It seems that this test case does not trigger register spilling on ROCm.

Oh yes, that rings a bell. I remember you mentioned it and we discussed it a long while ago. It's more about the test case itself. But do we know how Triton triggers register spilling on ROCm?

Author

We probably need to create a corresponding test case for ROCm, but I have not had time to focus on it yet.
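For readers unfamiliar with the behavior under test: the test name DoNotRunAutotuningKernelSpillingRegisters suggests that candidates whose kernels spill registers are dropped before a winner is picked. A minimal sketch of that filtering idea follows. The `ProfiledConfig` record and `PickFastestNonSpilling` function are hypothetical names introduced for illustration; this is not XLA's implementation, and it says nothing about how spilling is detected on ROCm.

```cpp
#include <optional>
#include <string>
#include <vector>

// Hypothetical record of one profiled candidate; not an XLA type.
struct ProfiledConfig {
  std::string name;
  double runtime_us;           // measured kernel runtime in microseconds
  int spilled_register_bytes;  // assumed compiler-reported spill size; 0 means none
};

// Returns the fastest candidate that does not spill registers,
// or std::nullopt if every candidate spills (or the list is empty).
std::optional<ProfiledConfig> PickFastestNonSpilling(
    const std::vector<ProfiledConfig>& candidates) {
  std::optional<ProfiledConfig> best;
  for (const ProfiledConfig& c : candidates) {
    if (c.spilled_register_bytes > 0) continue;  // discard spilling kernels
    if (!best.has_value() || c.runtime_us < best->runtime_us) best = c;
  }
  return best;
}
```

A ROCm-specific test would need an HLO input whose Triton kernel is actually reported as spilling on that platform, which is the open question in this thread.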

@@ -758,6 +803,9 @@ ENTRY main {
}

TEST_F(GemmFusionAutotunerDumpTest, DumpingWorks) {
  if (isRocm()) {
    GTEST_SKIP() << "cuBLAS not selected on ROCM.";
  }

Does this not fall back to rocBLAS or hipBLASLt?

Author

If I remember correctly, Triton is selected on ROCm: the difference between rocBLAS and Triton seems to be small, and Triton ends up being selected.

@i-chaochen i-chaochen Jan 8, 2025

I thought it was designed so that the GEMM autotuner falls back to cuBLAS when the Triton GEMM is not good enough. But after checking gemm_fusion_autotuner.cc, and correct me if I'm wrong, this fallback flag is for some particular cases that are not related to the GEMM fusion autotuner's choice among cuDNN, Triton, cuBLAS, and custom kernels.

Other than that, it is only used here as well:

TEST_P(FloatNormalizationTest, Fp8Normalization) {
  // TODO(b/344573710) Make this test not require a GPU when AutotuneCacheKey is
  // more stable.
  const PrimitiveType lhs_type = GetParam().first;
  const PrimitiveType rhs_type = GetParam().second;
  const std::string lhs_name =
      primitive_util::LowercasePrimitiveTypeName(lhs_type);
  const std::string rhs_name =
      primitive_util::LowercasePrimitiveTypeName(rhs_type);
  const std::string module_str = absl::Substitute(R"(
    HloModule sch
    ENTRY main {
      parameter = $0[1600,1600]{1,0} parameter(0)
      parameter.1 = $1[1600,1600]{1,0} parameter(1)
      neg = $1[1600,1600]{1,0} negate(parameter.1)
      dot = f16[1600,1600]{1,0} dot(parameter,neg), lhs_contracting_dims={1}, rhs_contracting_dims={0}
      constant = f16[] constant(0)
      broadcast = f16[1600,1600]{1,0} broadcast(constant), dimensions={}
      ROOT maximum = f16[1600,1600]{1,0} maximum(dot,broadcast)
    })",
      lhs_name, rhs_name);
  auto optimize_module = [&](bool enable_triton, bool enable_blas,
                             bool enable_blas_fallback)
      -> absl::StatusOr<std::unique_ptr<HloModule>> {
    HloModuleConfig config;
    DebugOptions debug_options = GetDebugOptionsForTest();
    debug_options.set_xla_gpu_cublas_fallback(enable_blas_fallback);
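To make the fallback discussion concrete, here is a minimal sketch of one possible reading of the selection being debated: the fastest candidate wins, and the BLAS candidate is only eligible when fallback is enabled. `BackendTiming` and `SelectBackend` are hypothetical names, and the parameter name `enable_blas_fallback` is borrowed from the excerpt above. Whether the real flag behaves this way is exactly what this thread leaves open, so treat this as an illustration rather than a description of XLA.

```cpp
#include <string>
#include <vector>

// Hypothetical timing result for one backend; not an XLA type.
struct BackendTiming {
  std::string backend;  // e.g. "triton" or "blas"
  double runtime_us;    // measured runtime in microseconds
};

// Picks the fastest eligible backend. The "blas" candidate is skipped
// unless enable_blas_fallback is true. On a tie the earlier entry wins.
// Returns an empty string when no candidate is eligible.
std::string SelectBackend(const std::vector<BackendTiming>& timings,
                          bool enable_blas_fallback) {
  const BackendTiming* best = nullptr;
  for (const BackendTiming& t : timings) {
    if (t.backend == "blas" && !enable_blas_fallback) continue;
    if (best == nullptr || t.runtime_us < best->runtime_us) best = &t;
  }
  return best != nullptr ? best->backend : std::string();
}
```

Under this reading, a small timing difference in Triton's favor would leave Triton selected even with fallback enabled, which matches what the author reports seeing on ROCm.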

@@ -1148,6 +1202,9 @@ TEST_F(GemmFusionAutotunerTest, CreatesCustomKernelFusionConfigs) {
}

TEST_F(GemmFusionAutotunerTest, GeneratesConfigForUpcastGemmWithPrologue) {
  if (isRocm()) {
    GTEST_SKIP() << "Not supported on ROCm.";
  }

@i-chaochen i-chaochen Jan 8, 2025

May I ask why GeneratesConfigForUpcastGemmWithPrologue and GeneratesConfigForUpcastGemmWithPrologueAndEpilogue are not supported? Is it because they expect a CustomKernelFusionConfig? If so, could you mention that in the comment?
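One way to act on this request is to keep the platform check and the skip reason together, so the message says why the test is skipped. The sketch below is self-contained and uses hypothetical stand-in types (`CudaCapability`, `RocmCapability`) rather than XLA's own. The reason string is a placeholder based on the reviewer's guess, which this thread does not confirm.

```cpp
#include <optional>
#include <string>
#include <variant>

// Hypothetical stand-ins for the CUDA and ROCm capability types.
struct CudaCapability { int cc_major; int cc_minor; };
struct RocmCapability { std::string gfx_version; };
using GpuCapability = std::variant<CudaCapability, RocmCapability>;

// True when the device capability is the ROCm alternative.
bool IsRocm(const GpuCapability& capability) {
  return std::holds_alternative<RocmCapability>(capability);
}

// Returns a descriptive reason to skip the test on this platform,
// or std::nullopt if the test should run.
std::optional<std::string> SkipReasonForUpcastGemmTest(
    const GpuCapability& capability) {
  if (IsRocm(capability)) {
    // Placeholder reason; the actual cause is unconfirmed in the review.
    return "Expects a CustomKernelFusionConfig (assumed unavailable on ROCm).";
  }
  return std::nullopt;
}
```

In a GoogleTest body the returned string would be streamed into `GTEST_SKIP()` in place of the generic "Not supported on ROCm." message.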

@i-chaochen i-chaochen left a comment

LGTM

thanks!

@i-chaochen i-chaochen merged commit f55fc88 into rocm-jaxlib-v0.4.35-qa Jan 8, 2025
6 of 9 checks passed