Enable Triton Auto-tuning in XLA #81
@@ -565,6 +607,9 @@ ENTRY e {

// TODO(b/344770374): Make this test not fragile.
TEST_F(GemmFusionAutotunerTest, DoNotRunAutotuningKernelSpillingRegisters) {
  if (isRocm()) {
    GTEST_SKIP() << "Not supported on ROCm.";
  }
  const std::string kHloText = R"(
Hmm, does Triton autotuning on ROCm not have register-spilling prevention?
It seems that this test case does not trigger register spilling on ROCm.
Oh, yes, that rings a bell. I remember you mentioned it and we discussed it a long while ago. It's more about the test case itself. But do we know how Triton triggers register spilling on ROCm?
We probably need to create a corresponding test case for ROCm, but I have not had time to focus on it yet.
@@ -758,6 +803,9 @@ ENTRY main {
}

TEST_F(GemmFusionAutotunerDumpTest, DumpingWorks) {
  if (isRocm()) {
    GTEST_SKIP() << "cuBLAS not selected on ROCM.";
  }
Does this not fall back to rocBLAS or hipBLASLt?
If I remember correctly, Triton is selected on ROCm: the difference between rocBLAS and Triton seems to be small, so Triton ends up being selected.
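To illustrate the point in the reply above, here is a minimal sketch of "pick the fastest measured candidate". It is not XLA code: `Candidate`, `PickFastest`, and the runtimes used with it are hypothetical names and made-up numbers, and the real autotuner in gemm_fusion_autotuner.cc may apply additional rules. It only shows that a marginally faster Triton candidate wins a plain minimum-runtime selection, in which case a test that expects the BLAS path would not hold.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for one autotuning candidate; not an XLA type.
struct Candidate {
  std::string backend;  // e.g. "triton" or "rocblas" (labels for illustration)
  double runtime_us;    // measured runtime in microseconds (made-up values)
};

// Returns the candidate with the smallest measured runtime.
// Assumes `candidates` is non-empty and outlives the returned reference.
inline const Candidate& PickFastest(const std::vector<Candidate>& candidates) {
  assert(!candidates.empty());
  return *std::min_element(
      candidates.begin(), candidates.end(),
      [](const Candidate& a, const Candidate& b) {
        return a.runtime_us < b.runtime_us;
      });
}
```

With candidates `{"rocblas", 101.0}` and `{"triton", 100.0}`, `PickFastest` returns the Triton entry even though the gap is only about 1%.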
I thought it was designed so that the GEMM autotuner falls back to cuBLAS when the Triton GEMM is not good enough. But after checking gemm_fusion_autotuner.cc (correct me if I'm wrong), this fallback flag is for some particular cases that are not related to the GEMM fusion autotuner choosing among cuDNN, Triton, cuBLAS, and custom kernels.
Apart from that, it is only used here:
xla/xla/service/gpu/gpu_compiler_test.cc
Lines 547 to 575 in 2bd6f7e
TEST_P(FloatNormalizationTest, Fp8Normalization) {
  // TODO(b/344573710) Make this test not require a GPU when AutotuneCacheKey is
  // more stable.
  const PrimitiveType lhs_type = GetParam().first;
  const PrimitiveType rhs_type = GetParam().second;
  const std::string lhs_name =
      primitive_util::LowercasePrimitiveTypeName(lhs_type);
  const std::string rhs_name =
      primitive_util::LowercasePrimitiveTypeName(rhs_type);
  const std::string module_str = absl::Substitute(R"(
HloModule sch
ENTRY main {
  parameter = $0[1600,1600]{1,0} parameter(0)
  parameter.1 = $1[1600,1600]{1,0} parameter(1)
  neg = $1[1600,1600]{1,0} negate(parameter.1)
  dot = f16[1600,1600]{1,0} dot(parameter,neg), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  constant = f16[] constant(0)
  broadcast = f16[1600,1600]{1,0} broadcast(constant), dimensions={}
  ROOT maximum = f16[1600,1600]{1,0} maximum(dot,broadcast)
})",
                                                  lhs_name, rhs_name);
  auto optimize_module = [&](bool enable_triton, bool enable_blas,
                             bool enable_blas_fallback)
      -> absl::StatusOr<std::unique_ptr<HloModule>> {
    HloModuleConfig config;
    DebugOptions debug_options = GetDebugOptionsForTest();
    debug_options.set_xla_gpu_cublas_fallback(enable_blas_fallback);
@@ -1148,6 +1202,9 @@ TEST_F(GemmFusionAutotunerTest, CreatesCustomKernelFusionConfigs) {
}

TEST_F(GemmFusionAutotunerTest, GeneratesConfigForUpcastGemmWithPrologue) {
  if (isRocm()) {
    GTEST_SKIP() << "Not supported on ROCm.";
  }
May I ask why GeneratesConfigForUpcastGemmWithPrologue and GeneratesConfigForUpcastGemmWithPrologueAndEpilogue are not supported? Is it because they expect CustomKernelFusionConfig? If so, could you mention that in the comment?
LGTM
thanks!
No description provided.