[XLA:GPU] Clean up flags (and uses of) `--xla_gpu_enable_bf16_{3,6}way_gemm`.

Those are no longer necessary now that algorithms can be requested explicitly.

PiperOrigin-RevId: 717885033
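
For context on "requested explicitly": the replacement mechanism is the `algorithm` attribute on a dot instruction's precision config rather than a module-wide debug flag. The sketch below is illustrative only (module name and shapes are invented for this note, not taken from the commit) and assumes the standard HLO text syntax for the attribute:

HloModule explicit_algorithm_example

ENTRY e {
  p0 = f32[32,64]{1,0} parameter(0)
  p1 = f32[64,16]{1,0} parameter(1)
  ROOT dot = f32[32,16] dot(p0, p1),
    lhs_contracting_dims={1}, rhs_contracting_dims={0},
    algorithm=dot_bf16_bf16_f32_x6
}

The 3-pass variant can likewise be requested with algorithm=dot_bf16_bf16_f32_x3; these correspond to PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6 and ALG_DOT_BF16_BF16_F32_X3, which the emitter now checks exclusively (see Is6xBfloat16MatMul / Is3xBfloat16MatMul below).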
bchetioui authored and Google-ML-Automation committed Jan 22, 2025
1 parent 28a75be commit 15af0b3
Showing 8 changed files with 17 additions and 222 deletions.
157 changes: 0 additions & 157 deletions xla/backends/gpu/codegen/triton/dot_algorithms_test.cc
@@ -137,25 +137,6 @@ class Triton6xBF16GemmTest : public AlgorithmTest {
}
};

// In these tests, we depend on debug option flags for selecting the 6XBF16
// algorithm.
// TODO(b/379905071): Remove this class and the --xla_gpu_enable_bf16_6way_gemm
// flag after we will support the algorithm values through the entire stack.
class Triton6xBF16GemmTestWithFlag : public AlgorithmTest {
public:
DebugOptions GetDebugOptionsForTest() const override {
DebugOptions debug_options = AlgorithmTest::GetDebugOptionsForTest();
// Do not fall back to cuBLAS, we are testing Triton.
debug_options.set_xla_gpu_cublas_fallback(false);
// Do not autotune split-k by default, since this prevents deterministically
// matching the optimized HLO.
debug_options.set_xla_gpu_enable_split_k_autotuning(false);
// Enable bf16_6way gemm to compute F32 matmul.
debug_options.set_xla_gpu_enable_bf16_6way_gemm(true);
return debug_options;
}
};

class BlasAlgorithmTest : public AlgorithmTest {
public:
DebugOptions GetDebugOptionsForTest() const override {
@@ -518,47 +499,6 @@ CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf
/*arel=*/1e-6}));
}

TEST_F(Triton6xBF16GemmTestWithFlag, Emit6xBF16GemmWhenBothInputsAreF32) {
constexpr absl::string_view kHloText = R"(
HloModule Emit6xBF16GemmWhenBothInputsAreF32
triton_dot {
p0 = f32[5,7] parameter(0)
p1 = f32[7,33] parameter(1)
ROOT dot = f32[5,33] dot(p0, p1),
lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
ENTRY e {
p0 = f32[5,7]{1,0} parameter(0)
p1 = f32[7,33]{1,0} parameter(1)
ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot,
backend_config={"fusion_backend_config": {kind: "__triton_gemm",
triton_gemm_config:
{"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}}
}
)";
TF_ASSERT_OK(
CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"(
CHECK: %[[INFINITY:.*]] = arith.constant dense<0x7F800000> : tensor<32x32xf32>
CHECK: %[[C_MASK:.*]] = arith.constant dense<-65536> : tensor<32x32xi32>
CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<32x32xf32>
CHECK: %[[CAST_I32:.*]] = tt.bitcast %{{.*}} : tensor<32x32xf32> -> tensor<32x32xi32>
CHECK: %[[EXTRACT_HI:.*]] = arith.andi %[[CAST_I32]], %[[C_MASK]] : tensor<32x32xi32>
CHECK: %[[CAST_HI:.*]] = tt.bitcast %[[EXTRACT_HI]] : tensor<32x32xi32> -> tensor<32x32xf32>
CHECK: %[[TRUNC_TO_BF16:.*]] = arith.truncf %[[CAST_HI]] : tensor<32x32xf32> to tensor<32x32xbf16>
CHECK-COUNT-5: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32>
CHECK: %[[ABS:.*]] = math.absf
CHECK: %[[CMP:.*]] = arith.cmpf ogt, %[[INFINITY]], %[[ABS]] : tensor<32x32xf32>
CHECK: %[[SELECT:.*]] = arith.select %[[CMP]], %{{.*}}, %[[C0]] : tensor<32x32xi1>, tensor<32x32xf32>
CHECK: %[[DOT_LAST:.*]] = tt.dot %{{.*}}, %{{.*}}, %[[SELECT]] : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32>
CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf32>
)"));

EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/1e-6,
/*arel=*/1e-6}));
}

TEST_F(Triton6xBF16GemmTest, Triton6xBF16GemmWorksForLongContractingDimension) {
constexpr absl::string_view kHloText = R"(
HloModule Triton6xBF16GemmWorksForLongContractingDimension
@@ -635,34 +575,6 @@ class Triton3xBF16GemmTest : public AlgorithmTest {
}
};

// In these tests, we depend on debug option flags for selecting the 3XBF16
// algorithm.
// TODO(b/379905071): Remove this class and the --xla_gpu_enable_bf16_3way_gemm
// flag after we will support the algorithm values through the entire stack.
class Triton3xBF16GemmTestWithFlag : public AlgorithmTest {
public:
DebugOptions GetDebugOptionsForTest() const override {
DebugOptions debug_options = AlgorithmTest::GetDebugOptionsForTest();
// Enable triton fusion for all supported GEMMs.
debug_options.set_xla_gpu_triton_gemm_any(true);
// Do not fall back to cuBLAS, we are testing Triton.
debug_options.set_xla_gpu_cublas_fallback(false);
// Do not autotune split-k by default, since this prevents deterministically
// matching the optimized HLO.
debug_options.set_xla_gpu_enable_split_k_autotuning(false);
// Enable bf16_3way gemm to compute F32 matmul.
debug_options.set_xla_gpu_enable_bf16_3way_gemm(true);
return debug_options;
}

protected:
void SetUp() override {
if (!SupportsBF16(GpuComputeComp())) {
GTEST_SKIP() << "BF16 not supported.";
}
}
};

TEST_F(Triton3xBF16GemmTest, Emit3xBF16GemmWhenBothInputsAreF32) {
constexpr absl::string_view kHloText = R"(
HloModule Emit3xBF16GemmWhenBothInputsAreF32
@@ -705,75 +617,6 @@ CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf
/*arel=*/1e-5}));
}

TEST_F(Triton3xBF16GemmTestWithFlag, Emit3xBF16GemmWhenBothInputsAreF32) {
constexpr absl::string_view kHloText = R"(
HloModule Emit3xBF16GemmWhenBothInputsAreF32
triton_dot {
p0 = f32[5,7] parameter(0)
p1 = f32[7,33] parameter(1)
ROOT dot = f32[5,33] dot(p0, p1),
lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
ENTRY e {
p0 = f32[5,7]{1,0} parameter(0)
p1 = f32[7,33]{1,0} parameter(1)
ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot,
backend_config={"fusion_backend_config": {kind: "__triton_gemm",
triton_gemm_config:
{"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}}
}
)";
TF_ASSERT_OK(
CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"(
CHECK: %[[INFINITY:.*]] = arith.constant dense<0x7F800000> : tensor<32x32xf32>
CHECK: %[[C_MASK:.*]] = arith.constant dense<-65536> : tensor<32x32xi32>
CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<32x32xf32>
CHECK: %[[CAST_I32:.*]] = tt.bitcast %{{.*}} : tensor<32x32xf32> -> tensor<32x32xi32>
CHECK: %[[EXTRACT_HI:.*]] = arith.andi %[[CAST_I32]], %[[C_MASK]] : tensor<32x32xi32>
CHECK: %[[CAST_HI:.*]] = tt.bitcast %[[EXTRACT_HI]] : tensor<32x32xi32> -> tensor<32x32xf32>
CHECK: %[[TRUNC_TO_BF16:.*]] = arith.truncf %[[CAST_HI]] : tensor<32x32xf32> to tensor<32x32xbf16>
CHECK-COUNT-2: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32>
CHECK: %[[ABS:.*]] = math.absf
CHECK: %[[CMP:.*]] = arith.cmpf ogt, %[[INFINITY]], %[[ABS]] : tensor<32x32xf32>
CHECK: %[[SELECT:.*]] = arith.select %[[CMP]], %{{.*}}, %[[C0]] : tensor<32x32xi1>, tensor<32x32xf32>
CHECK: %[[DOT_LAST:.*]] = tt.dot %{{.*}}, %{{.*}}, %[[SELECT]] : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32>
CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf32>
)"));

EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/1e-5,
/*arel=*/1e-5}));
}

TEST_F(Triton3xBF16GemmTestWithFlag, NoEmit3xBF16GemmWhenBothInputsAreNotF32) {
constexpr absl::string_view kHloText = R"(
HloModule NoEmit3xBF16GemmWhenBothInputsAreNotF32
triton_dot {
p0 = f16[5,7] parameter(0)
p1 = f16[7,33] parameter(1)
ROOT dot = f16[5,33] dot(p0, p1),
lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
ENTRY e {
p0 = f16[5,7]{1,0} parameter(0)
p1 = f16[7,33]{1,0} parameter(1)
ROOT _ = f16[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot,
backend_config={"fusion_backend_config": {kind: "__triton_gemm",
triton_gemm_config:
{"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}}
}
)";
TF_ASSERT_OK(
CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"(
CHECK: tt.dot
CHECK-SAME: tensor<32x32xf16> * tensor<32x32xf16> -> tensor<32x32xf32>
CHECK-NOT: tt.dot
)"));
}

TEST_F(Triton3xBF16GemmTest, Triton3xBF16GemmWorksForLongContractingDimension) {
constexpr absl::string_view kHloText = R"(
HloModule Triton3xBF16GemmWorksForLongContractingDimension
31 changes: 0 additions & 31 deletions xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc
@@ -1664,18 +1664,6 @@ bool Is6xBfloat16MatMul(const HloDotInstruction* dot_instr,
const PrecisionConfig::Algorithm algorithm =
dot_instr->precision_config().algorithm();

if (algorithm == PrecisionConfig::ALG_UNSET) {
const HloModule* hlo_module = dot_instr->GetModule();
Type f32 = b.getF32Type();
return hlo_module->config()
.debug_options()
.xla_gpu_enable_bf16_6way_gemm() &&
mlir::cast<ShapedType>(dot_input_lhs.getType()).getElementType() ==
f32 &&
mlir::cast<ShapedType>(dot_input_rhs.getType()).getElementType() ==
f32;
}

return algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6;
}

@@ -1686,18 +1674,6 @@ bool Is3xBfloat16MatMul(const HloDotInstruction* dot_instr,
const PrecisionConfig::Algorithm algorithm =
dot_instr->precision_config().algorithm();

if (algorithm == PrecisionConfig::ALG_UNSET) {
const HloModule* hlo_module = dot_instr->GetModule();
Type f32 = b.getF32Type();
return hlo_module->config()
.debug_options()
.xla_gpu_enable_bf16_3way_gemm() &&
mlir::cast<ShapedType>(dot_input_lhs.getType()).getElementType() ==
f32 &&
mlir::cast<ShapedType>(dot_input_rhs.getType()).getElementType() ==
f32;
}

return algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3;
}

@@ -2105,13 +2081,6 @@ absl::Status EmitMatMul(EmitterLocOpBuilder& b,
return;
}

const HloModule* hlo_module = dot_instr->GetModule();
if (hlo_module->config().debug_options().xla_gpu_enable_bf16_3way_gemm() &&
hlo_module->config().debug_options().xla_gpu_enable_bf16_6way_gemm()) {
LOG(WARNING) << "Both BF16 6way gemm and 3way gemm are enabled."
<< " Fallback to BF16 6way gemm.";
}

Value accumulator_next;
if (Is6xBfloat16MatMul(dot_instr, b, dot_input_lhs, dot_input_rhs,
device_info)) {
20 changes: 4 additions & 16 deletions xla/debug_options_flags.cc
@@ -205,7 +205,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {

opts.set_xla_gpu_collective_permute_decomposer_threshold(
std::numeric_limits<int64_t>::max());
opts.set_xla_gpu_enable_experimental_pipeline_parallelism_opt(false);
opts.set_xla_gpu_experimental_enable_pipeline_parallelism_opt(false);

opts.set_xla_cpu_enable_mlir_tiling_and_fusion(true);
opts.set_xla_cpu_enable_custom_matmul_tiling(false);
@@ -268,8 +266,6 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
stream_executor::IsLibNvPtxCompilerSupported());
opts.set_xla_gpu_libnvjitlink_mode(DebugOptions::LIB_NV_JIT_LINK_MODE_AUTO);

opts.set_xla_gpu_enable_bf16_6way_gemm(false);
opts.set_xla_gpu_enable_bf16_3way_gemm(false);
opts.set_xla_gpu_nccl_collective_max_nchannels(0);
opts.set_xla_gpu_nccl_p2p_max_nchannels(0);
opts.set_xla_gpu_multi_streamed_windowed_einsum(true);
@@ -1713,11 +1711,11 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
debug_options->xla_gpu_collective_permute_decomposer_threshold(),
"Collective permute decomposer threshold."));
flag_list->push_back(tsl::Flag(
"xla_gpu_enable_experimental_pipeline_parallelism_opt",
"xla_gpu_experimental_enable_pipeline_parallelism_opt",
bool_setter_for(
&DebugOptions::
set_xla_gpu_enable_experimental_pipeline_parallelism_opt),
debug_options->xla_gpu_enable_experimental_pipeline_parallelism_opt(),
set_xla_gpu_experimental_enable_pipeline_parallelism_opt),
debug_options->xla_gpu_experimental_enable_pipeline_parallelism_opt(),
"Experimental optimizations for SPMD-based pipeline parallelism on "
"GPU."));
flag_list->push_back(tsl::Flag(
@@ -1993,16 +1991,6 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
flag_list->push_back(tsl::Flag("xla_gpu_enable_dot_strength_reduction",
noop_flag_setter<bool>, true,
"[Deprecated, do not use]"));
flag_list->push_back(tsl::Flag(
"xla_gpu_enable_bf16_6way_gemm",
bool_setter_for(&DebugOptions::set_xla_gpu_enable_bf16_6way_gemm),
debug_options->xla_gpu_enable_bf16_6way_gemm(),
"Use BF16 6way gemm to compute F32 gemm."));
flag_list->push_back(tsl::Flag(
"xla_gpu_enable_bf16_3way_gemm",
bool_setter_for(&DebugOptions::set_xla_gpu_enable_bf16_3way_gemm),
debug_options->xla_gpu_enable_bf16_3way_gemm(),
"Use BF16 3way gemm to compute F32 gemm."));
flag_list->push_back(
tsl::Flag("xla_gpu_nccl_collective_max_nchannels",
int64_setter_for(
6 changes: 3 additions & 3 deletions xla/service/gpu/gpu_compiler.cc
@@ -954,7 +954,7 @@ absl::Status RunCollectiveOptimizationPasses(

if (hlo_module->config()
.debug_options()
.xla_gpu_enable_experimental_pipeline_parallelism_opt()) {
.xla_gpu_experimental_enable_pipeline_parallelism_opt()) {
collectives_pipeline.AddPass<CollectiveSelectFolder>();
}

@@ -971,7 +971,7 @@ absl::Status RunCollectiveOptimizationPasses(
collectives_pipeline,
hlo_module->config()
.debug_options()
.xla_gpu_enable_experimental_pipeline_parallelism_opt());
.xla_gpu_experimental_enable_pipeline_parallelism_opt());
}

// Run algebraic simplifier to reshape(broadcast) into a broadcast when
@@ -2669,7 +2669,7 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines(

if (!module->config()
.debug_options()
.xla_gpu_enable_experimental_pipeline_parallelism_opt() &&
.xla_gpu_experimental_enable_pipeline_parallelism_opt() &&
(module->config()
.debug_options()
.xla_gpu_enable_pipelined_collectives() ||
4 changes: 2 additions & 2 deletions xla/service/gpu/gpu_latency_hiding_scheduler.cc
@@ -272,13 +272,13 @@ void GpuAsyncTrackerBase::PostProcessScheduleGraph(
// Schedule partially pipelined send/recv instructions late so that they can
// overlap with compute. Schedule send/recv late and, when unblocked,
// schedule send-done/recv-done early.
if (debug_options.xla_gpu_enable_experimental_pipeline_parallelism_opt() &&
if (debug_options.xla_gpu_experimental_enable_pipeline_parallelism_opt() &&
IsPartiallyPipelinedSendRecv(inst)) {
HloGraphNode& node = schedule_graph->GetNode(inst);
node.SetForceDelay(true);
VLOG(5) << "Setting force delay for instruction: " << inst->ToString();
}
if (debug_options.xla_gpu_enable_experimental_pipeline_parallelism_opt() &&
if (debug_options.xla_gpu_experimental_enable_pipeline_parallelism_opt() &&
IsPartiallyPipelinedSendRecvDone(inst)) {
HloGraphNode& node = schedule_graph->GetNode(inst);
node.SetForceEarly(true);
2 changes: 1 addition & 1 deletion xla/service/gpu/gpu_latency_hiding_scheduler_test.cc
@@ -84,7 +84,7 @@ class GpuLatencyHidingSchedulerBaseTest : public HloTestBase {
HloModuleConfig config;
DebugOptions debug_options = GetDebugOptionsForTest();
debug_options.set_xla_gpu_enable_latency_hiding_scheduler(true);
debug_options.set_xla_gpu_enable_experimental_pipeline_parallelism_opt(
debug_options.set_xla_gpu_experimental_enable_pipeline_parallelism_opt(
enable_experimental_pipeline_parallelism_opt);
config.set_debug_options(debug_options);
config.set_fdo_profile(fdo_profile);
2 changes: 1 addition & 1 deletion xla/tests/collective_pipeline_parallelism_test.cc
@@ -56,7 +56,7 @@ class CollectivePipelineParallelismTest

// Set debug options.
DebugOptions debug_options = GetDebugOptionsForTest();
debug_options.set_xla_gpu_enable_experimental_pipeline_parallelism_opt(
debug_options.set_xla_gpu_experimental_enable_pipeline_parallelism_opt(
GetParam());
config.set_debug_options(debug_options);

17 changes: 6 additions & 11 deletions xla/xla.proto
@@ -348,12 +348,6 @@ message DebugOptions {

bool xla_gpu_enable_approx_costly_collectives = 305;

// If enabled, uses bf16_3way gemm to compute F32 gemm.
bool xla_gpu_enable_bf16_3way_gemm = 279;

// If enabled, uses bf16_6way gemm to compute F32 gemm.
bool xla_gpu_enable_bf16_6way_gemm = 271;

// Determine the types of commands that are recorded into command buffers.
repeated CommandBufferCmdType xla_gpu_enable_command_buffer = 258;

@@ -386,10 +380,6 @@ message DebugOptions {
// dynamic-update-slice operations around library calls.
bool xla_gpu_enable_dynamic_slice_fusion = 105;

// Experimental optimizations for SPMD-based pipeline parallelism on GPU.
// TODO(bchetioui): adjust this name to follow the naming convention.
bool xla_gpu_enable_experimental_pipeline_parallelism_opt = 351;

// When true we lower the Minimum and Maximum hlos in the GPU backend such
// that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NotNaN. In other words, if flag
// this is true we don't propagate NaNs through Min and Max.
@@ -513,6 +503,9 @@ message DebugOptions {
// Pre-existing block-level fusions are left unmodified.
bool xla_gpu_experimental_enable_fusion_block_level_rewriter = 334;

// Experimental optimizations for SPMD-based pipeline parallelism on GPU.
bool xla_gpu_experimental_enable_pipeline_parallelism_opt = 351;

// When enabled, the PriorityFusion pass will try to make Triton fusions first
// and foremost where it is possible.
//
@@ -1152,8 +1145,10 @@ message DebugOptions {
// xla_gpu_enable_heuristic_pass_configuration
// xla_gpu_enable_dot_strength_reduction
// xla_gpu_triton_fusion_level
// xla_gpu_enable_bf16_3way_gemm
// xla_gpu_enable_bf16_6way_gemm
reserved 5, 117, 133, 139, 176, 178, 180, 193, 214, 194, 221, 242, 206, 320,
325, 326, 332, 361, 270, 229;
325, 326, 332, 361, 270, 229, 271, 279;
}

// Contains flags which affects the GPU compilation result.
