From 15af0b3e4c18bbb702f02ada5103fc8ccfe71cc4 Mon Sep 17 00:00:00 2001
From: Benjamin Chetioui <bchetioui@google.com>
Date: Tue, 21 Jan 2025 06:11:55 -0800
Subject: [PATCH] [XLA:GPU] Clean up flags (and uses of)
 `--xla_gpu_enable_bf16_{3,6}way_gemm`.

Those are no longer necessary now that algorithms can be requested explicitly.

PiperOrigin-RevId: 717885033
---
 .../gpu/codegen/triton/dot_algorithms_test.cc | 157 ------------------
 .../triton/fusion_emitter_legacy_matmul.cc    |  31 ----
 xla/debug_options_flags.cc                    |  20 +--
 xla/service/gpu/gpu_compiler.cc               |   6 +-
 .../gpu/gpu_latency_hiding_scheduler.cc       |   4 +-
 .../gpu/gpu_latency_hiding_scheduler_test.cc  |   2 +-
 .../collective_pipeline_parallelism_test.cc   |   2 +-
 xla/xla.proto                                 |  17 +-
 8 files changed, 17 insertions(+), 222 deletions(-)

diff --git a/xla/backends/gpu/codegen/triton/dot_algorithms_test.cc b/xla/backends/gpu/codegen/triton/dot_algorithms_test.cc
index a64e1153ccf80..4a9d95c15fb31 100644
--- a/xla/backends/gpu/codegen/triton/dot_algorithms_test.cc
+++ b/xla/backends/gpu/codegen/triton/dot_algorithms_test.cc
@@ -137,25 +137,6 @@ class Triton6xBF16GemmTest : public AlgorithmTest {
   }
 };
 
-// In these tests, we depend on debug option flags for selecting the 6XBF16
-// algorithm.
-// TODO(b/379905071): Remove this class and the --xla_gpu_enable_bf16_6way_gemm
-// flag after we will support the algorithm values through the entire stack.
-class Triton6xBF16GemmTestWithFlag : public AlgorithmTest {
- public:
-  DebugOptions GetDebugOptionsForTest() const override {
-    DebugOptions debug_options = AlgorithmTest::GetDebugOptionsForTest();
-    // Do not fall back to cuBLAS, we are testing Triton.
-    debug_options.set_xla_gpu_cublas_fallback(false);
-    // Do not autotune split-k by default, since this prevents deterministically
-    // matching the optimized HLO.
-    debug_options.set_xla_gpu_enable_split_k_autotuning(false);
-    // Enable bf16_6way gemm to compute F32 matmul.
-    debug_options.set_xla_gpu_enable_bf16_6way_gemm(true);
-    return debug_options;
-  }
-};
-
 class BlasAlgorithmTest : public AlgorithmTest {
  public:
   DebugOptions GetDebugOptionsForTest() const override {
@@ -518,47 +499,6 @@ CHECK:          %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf
                                                            /*arel=*/1e-6}));
 }
 
-TEST_F(Triton6xBF16GemmTestWithFlag, Emit6xBF16GemmWhenBothInputsAreF32) {
-  constexpr absl::string_view kHloText = R"(
-    HloModule Emit6xBF16GemmWhenBothInputsAreF32
-
-    triton_dot {
-      p0 = f32[5,7] parameter(0)
-      p1 = f32[7,33] parameter(1)
-      ROOT dot = f32[5,33] dot(p0, p1),
-        lhs_contracting_dims={1}, rhs_contracting_dims={0}
-    }
-
-    ENTRY e {
-      p0 = f32[5,7]{1,0} parameter(0)
-      p1 = f32[7,33]{1,0} parameter(1)
-      ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot,
-        backend_config={"fusion_backend_config": {kind: "__triton_gemm",
-        triton_gemm_config:
-        {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}}
-    }
-  )";
-  TF_ASSERT_OK(
-      CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"(
-CHECK:          %[[INFINITY:.*]] = arith.constant dense<0x7F800000> : tensor<32x32xf32>
-CHECK:          %[[C_MASK:.*]] = arith.constant dense<-65536> : tensor<32x32xi32>
-CHECK:          %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<32x32xf32>
-CHECK:          %[[CAST_I32:.*]] = tt.bitcast %{{.*}} : tensor<32x32xf32> -> tensor<32x32xi32>
-CHECK:          %[[EXTRACT_HI:.*]] = arith.andi %[[CAST_I32]], %[[C_MASK]] : tensor<32x32xi32>
-CHECK:          %[[CAST_HI:.*]] = tt.bitcast %[[EXTRACT_HI]] : tensor<32x32xi32> -> tensor<32x32xf32>
-CHECK:          %[[TRUNC_TO_BF16:.*]] = arith.truncf %[[CAST_HI]] : tensor<32x32xf32> to tensor<32x32xbf16>
-CHECK-COUNT-5:  %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32>
-CHECK:          %[[ABS:.*]] = math.absf
-CHECK:          %[[CMP:.*]] = arith.cmpf ogt, %[[INFINITY]], %[[ABS]] : tensor<32x32xf32>
-CHECK:          %[[SELECT:.*]] = arith.select %[[CMP]], %{{.*}}, %[[C0]] : tensor<32x32xi1>, tensor<32x32xf32>
-CHECK:          %[[DOT_LAST:.*]] = tt.dot %{{.*}}, %{{.*}}, %[[SELECT]] : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32>
-CHECK:          %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf32>
-    )"));
-
-  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/1e-6,
-                                                           /*arel=*/1e-6}));
-}
-
 TEST_F(Triton6xBF16GemmTest, Triton6xBF16GemmWorksForLongContractingDimension) {
   constexpr absl::string_view kHloText = R"(
     HloModule Triton6xBF16GemmWorksForLongContractingDimension
@@ -635,34 +575,6 @@ class Triton3xBF16GemmTest : public AlgorithmTest {
   }
 };
 
-// In these tests, we depend on debug option flags for selecting the 3XBF16
-// algorithm.
-// TODO(b/379905071): Remove this class and the --xla_gpu_enable_bf16_3way_gemm
-// flag after we will support the algorithm values through the entire stack.
-class Triton3xBF16GemmTestWithFlag : public AlgorithmTest {
- public:
-  DebugOptions GetDebugOptionsForTest() const override {
-    DebugOptions debug_options = AlgorithmTest::GetDebugOptionsForTest();
-    // Enable triton fusion for all supported GEMMs.
-    debug_options.set_xla_gpu_triton_gemm_any(true);
-    // Do not fall back to cuBLAS, we are testing Triton.
-    debug_options.set_xla_gpu_cublas_fallback(false);
-    // Do not autotune split-k by default, since this prevents deterministically
-    // matching the optimized HLO.
-    debug_options.set_xla_gpu_enable_split_k_autotuning(false);
-    // Enable bf16_3way gemm to compute F32 matmul.
-    debug_options.set_xla_gpu_enable_bf16_3way_gemm(true);
-    return debug_options;
-  }
-
- protected:
-  void SetUp() override {
-    if (!SupportsBF16(GpuComputeComp())) {
-      GTEST_SKIP() << "BF16 not supported.";
-    }
-  }
-};
-
 TEST_F(Triton3xBF16GemmTest, Emit3xBF16GemmWhenBothInputsAreF32) {
   constexpr absl::string_view kHloText = R"(
     HloModule Emit3xBF16GemmWhenBothInputsAreF32
@@ -705,75 +617,6 @@ CHECK:          %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf
                                                            /*arel=*/1e-5}));
 }
 
-TEST_F(Triton3xBF16GemmTestWithFlag, Emit3xBF16GemmWhenBothInputsAreF32) {
-  constexpr absl::string_view kHloText = R"(
-    HloModule Emit3xBF16GemmWhenBothInputsAreF32
-
-    triton_dot {
-      p0 = f32[5,7] parameter(0)
-      p1 = f32[7,33] parameter(1)
-      ROOT dot = f32[5,33] dot(p0, p1),
-        lhs_contracting_dims={1}, rhs_contracting_dims={0}
-    }
-
-    ENTRY e {
-      p0 = f32[5,7]{1,0} parameter(0)
-      p1 = f32[7,33]{1,0} parameter(1)
-      ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot,
-        backend_config={"fusion_backend_config": {kind: "__triton_gemm",
-        triton_gemm_config:
-        {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}}
-    }
-  )";
-  TF_ASSERT_OK(
-      CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"(
-CHECK:          %[[INFINITY:.*]] = arith.constant dense<0x7F800000> : tensor<32x32xf32>
-CHECK:          %[[C_MASK:.*]] = arith.constant dense<-65536> : tensor<32x32xi32>
-CHECK:          %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<32x32xf32>
-CHECK:          %[[CAST_I32:.*]] = tt.bitcast %{{.*}} : tensor<32x32xf32> -> tensor<32x32xi32>
-CHECK:          %[[EXTRACT_HI:.*]] = arith.andi %[[CAST_I32]], %[[C_MASK]] : tensor<32x32xi32>
-CHECK:          %[[CAST_HI:.*]] = tt.bitcast %[[EXTRACT_HI]] : tensor<32x32xi32> -> tensor<32x32xf32>
-CHECK:          %[[TRUNC_TO_BF16:.*]] = arith.truncf %[[CAST_HI]] : tensor<32x32xf32> to tensor<32x32xbf16>
-CHECK-COUNT-2:  %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32>
-CHECK:          %[[ABS:.*]] = math.absf
-CHECK:          %[[CMP:.*]] = arith.cmpf ogt, %[[INFINITY]], %[[ABS]] : tensor<32x32xf32>
-CHECK:          %[[SELECT:.*]] = arith.select %[[CMP]], %{{.*}}, %[[C0]] : tensor<32x32xi1>, tensor<32x32xf32>
-CHECK:          %[[DOT_LAST:.*]] = tt.dot %{{.*}}, %{{.*}}, %[[SELECT]] : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32>
-CHECK:          %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf32>
-    )"));
-
-  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/1e-5,
-                                                           /*arel=*/1e-5}));
-}
-
-TEST_F(Triton3xBF16GemmTestWithFlag, NoEmit3xBF16GemmWhenBothInputsAreNotF32) {
-  constexpr absl::string_view kHloText = R"(
-    HloModule NoEmit3xBF16GemmWhenBothInputsAreNotF32
-
-    triton_dot {
-      p0 = f16[5,7] parameter(0)
-      p1 = f16[7,33] parameter(1)
-      ROOT dot = f16[5,33] dot(p0, p1),
-        lhs_contracting_dims={1}, rhs_contracting_dims={0}
-    }
-
-    ENTRY e {
-      p0 = f16[5,7]{1,0} parameter(0)
-      p1 = f16[7,33]{1,0} parameter(1)
-      ROOT _ = f16[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot,
-        backend_config={"fusion_backend_config": {kind: "__triton_gemm",
-        triton_gemm_config:
-        {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}}
-    }
-  )";
-  TF_ASSERT_OK(
-      CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"(
-CHECK:      tt.dot
-CHECK-SAME: tensor<32x32xf16> * tensor<32x32xf16> -> tensor<32x32xf32>
-CHECK-NOT:  tt.dot
-    )"));
-}
-
 TEST_F(Triton3xBF16GemmTest, Triton3xBF16GemmWorksForLongContractingDimension) {
   constexpr absl::string_view kHloText = R"(
     HloModule Triton3xBF16GemmWorksForLongContractingDimension
diff --git a/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc b/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc
index bbfcf71f52e9a..5995000fc3df0 100644
--- a/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc
+++ b/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc
@@ -1664,18 +1664,6 @@ bool Is6xBfloat16MatMul(const HloDotInstruction* dot_instr,
   const PrecisionConfig::Algorithm algorithm =
       dot_instr->precision_config().algorithm();
 
-  if (algorithm == PrecisionConfig::ALG_UNSET) {
-    const HloModule* hlo_module = dot_instr->GetModule();
-    Type f32 = b.getF32Type();
-    return hlo_module->config()
-               .debug_options()
-               .xla_gpu_enable_bf16_6way_gemm() &&
-           mlir::cast<ShapedType>(dot_input_lhs.getType()).getElementType() ==
-               f32 &&
-           mlir::cast<ShapedType>(dot_input_rhs.getType()).getElementType() ==
-               f32;
-  }
-
   return algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6;
 }
 
@@ -1686,18 +1674,6 @@ bool Is3xBfloat16MatMul(const HloDotInstruction* dot_instr,
   const PrecisionConfig::Algorithm algorithm =
       dot_instr->precision_config().algorithm();
 
-  if (algorithm == PrecisionConfig::ALG_UNSET) {
-    const HloModule* hlo_module = dot_instr->GetModule();
-    Type f32 = b.getF32Type();
-    return hlo_module->config()
-               .debug_options()
-               .xla_gpu_enable_bf16_3way_gemm() &&
-           mlir::cast<ShapedType>(dot_input_lhs.getType()).getElementType() ==
-               f32 &&
-           mlir::cast<ShapedType>(dot_input_rhs.getType()).getElementType() ==
-               f32;
-  }
-
   return algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3;
 }
 
@@ -2105,13 +2081,6 @@ absl::Status EmitMatMul(EmitterLocOpBuilder& b,
       return;
     }
 
-    const HloModule* hlo_module = dot_instr->GetModule();
-    if (hlo_module->config().debug_options().xla_gpu_enable_bf16_3way_gemm() &&
-        hlo_module->config().debug_options().xla_gpu_enable_bf16_6way_gemm()) {
-      LOG(WARNING) << "Both BF16 6way gemm and 3way gemm are enabled."
-                   << " Fallback to BF16 6way gemm.";
-    }
-
     Value accumulator_next;
     if (Is6xBfloat16MatMul(dot_instr, b, dot_input_lhs, dot_input_rhs,
                            device_info)) {
diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc
index 0b803fc49f589..aed4c49d5e530 100644
--- a/xla/debug_options_flags.cc
+++ b/xla/debug_options_flags.cc
@@ -205,7 +205,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
 
   opts.set_xla_gpu_collective_permute_decomposer_threshold(
       std::numeric_limits<int64_t>::max());
-  opts.set_xla_gpu_enable_experimental_pipeline_parallelism_opt(false);
+  opts.set_xla_gpu_experimental_enable_pipeline_parallelism_opt(false);
 
   opts.set_xla_cpu_enable_mlir_tiling_and_fusion(true);
   opts.set_xla_cpu_enable_custom_matmul_tiling(false);
@@ -268,8 +268,6 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
       stream_executor::IsLibNvPtxCompilerSupported());
   opts.set_xla_gpu_libnvjitlink_mode(DebugOptions::LIB_NV_JIT_LINK_MODE_AUTO);
 
-  opts.set_xla_gpu_enable_bf16_6way_gemm(false);
-  opts.set_xla_gpu_enable_bf16_3way_gemm(false);
   opts.set_xla_gpu_nccl_collective_max_nchannels(0);
   opts.set_xla_gpu_nccl_p2p_max_nchannels(0);
   opts.set_xla_gpu_multi_streamed_windowed_einsum(true);
@@ -1713,11 +1711,11 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
       debug_options->xla_gpu_collective_permute_decomposer_threshold(),
       "Collective permute decomposer threshold."));
   flag_list->push_back(tsl::Flag(
-      "xla_gpu_enable_experimental_pipeline_parallelism_opt",
+      "xla_gpu_experimental_enable_pipeline_parallelism_opt",
       bool_setter_for(
           &DebugOptions::
-              set_xla_gpu_enable_experimental_pipeline_parallelism_opt),
-      debug_options->xla_gpu_enable_experimental_pipeline_parallelism_opt(),
+              set_xla_gpu_experimental_enable_pipeline_parallelism_opt),
+      debug_options->xla_gpu_experimental_enable_pipeline_parallelism_opt(),
       "Experimental optimizations for SPMD-based pipeline parallelism on "
       "GPU."));
   flag_list->push_back(tsl::Flag(
@@ -1993,16 +1991,6 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
   flag_list->push_back(tsl::Flag("xla_gpu_enable_dot_strength_reduction",
                                  noop_flag_setter<bool>, true,
                                  "[Deprecated, do not use]"));
-  flag_list->push_back(tsl::Flag(
-      "xla_gpu_enable_bf16_6way_gemm",
-      bool_setter_for(&DebugOptions::set_xla_gpu_enable_bf16_6way_gemm),
-      debug_options->xla_gpu_enable_bf16_6way_gemm(),
-      "Use BF16 6way gemm to compute F32 gemm."));
-  flag_list->push_back(tsl::Flag(
-      "xla_gpu_enable_bf16_3way_gemm",
-      bool_setter_for(&DebugOptions::set_xla_gpu_enable_bf16_3way_gemm),
-      debug_options->xla_gpu_enable_bf16_3way_gemm(),
-      "Use BF16 3way gemm to compute F32 gemm."));
   flag_list->push_back(
       tsl::Flag("xla_gpu_nccl_collective_max_nchannels",
                 int64_setter_for(
diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc
index 48dce2d678371..30f25de027a97 100755
--- a/xla/service/gpu/gpu_compiler.cc
+++ b/xla/service/gpu/gpu_compiler.cc
@@ -954,7 +954,7 @@ absl::Status RunCollectiveOptimizationPasses(
 
   if (hlo_module->config()
           .debug_options()
-          .xla_gpu_enable_experimental_pipeline_parallelism_opt()) {
+          .xla_gpu_experimental_enable_pipeline_parallelism_opt()) {
     collectives_pipeline.AddPass<CollectiveSelectFolder>();
   }
 
@@ -971,7 +971,7 @@ absl::Status RunCollectiveOptimizationPasses(
         collectives_pipeline,
         hlo_module->config()
             .debug_options()
-            .xla_gpu_enable_experimental_pipeline_parallelism_opt());
+            .xla_gpu_experimental_enable_pipeline_parallelism_opt());
   }
 
   // Run algebraic simplifier to reshape(broadcast) into a broadcast when
@@ -2669,7 +2669,7 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines(
 
     if (!module->config()
              .debug_options()
-             .xla_gpu_enable_experimental_pipeline_parallelism_opt() &&
+             .xla_gpu_experimental_enable_pipeline_parallelism_opt() &&
         (module->config()
              .debug_options()
              .xla_gpu_enable_pipelined_collectives() ||
diff --git a/xla/service/gpu/gpu_latency_hiding_scheduler.cc b/xla/service/gpu/gpu_latency_hiding_scheduler.cc
index 2c50af565ef2b..353c70e3d6f31 100644
--- a/xla/service/gpu/gpu_latency_hiding_scheduler.cc
+++ b/xla/service/gpu/gpu_latency_hiding_scheduler.cc
@@ -272,13 +272,13 @@ void GpuAsyncTrackerBase::PostProcessScheduleGraph(
     // Schedule partially pipelined send/recv instructions late so that they can
     // overlap with compute. Schedule send/recv late and, when unblocked,
     // schedule send-done/recv-done early.
-    if (debug_options.xla_gpu_enable_experimental_pipeline_parallelism_opt() &&
+    if (debug_options.xla_gpu_experimental_enable_pipeline_parallelism_opt() &&
         IsPartiallyPipelinedSendRecv(inst)) {
       HloGraphNode& node = schedule_graph->GetNode(inst);
       node.SetForceDelay(true);
       VLOG(5) << "Setting force delay for instruction: " << inst->ToString();
     }
-    if (debug_options.xla_gpu_enable_experimental_pipeline_parallelism_opt() &&
+    if (debug_options.xla_gpu_experimental_enable_pipeline_parallelism_opt() &&
         IsPartiallyPipelinedSendRecvDone(inst)) {
       HloGraphNode& node = schedule_graph->GetNode(inst);
       node.SetForceEarly(true);
diff --git a/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc b/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc
index 382e6e148e50e..038a6b9ec8dac 100644
--- a/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc
+++ b/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc
@@ -84,7 +84,7 @@ class GpuLatencyHidingSchedulerBaseTest : public HloTestBase {
     HloModuleConfig config;
     DebugOptions debug_options = GetDebugOptionsForTest();
     debug_options.set_xla_gpu_enable_latency_hiding_scheduler(true);
-    debug_options.set_xla_gpu_enable_experimental_pipeline_parallelism_opt(
+    debug_options.set_xla_gpu_experimental_enable_pipeline_parallelism_opt(
         enable_experimental_pipeline_parallelism_opt);
     config.set_debug_options(debug_options);
     config.set_fdo_profile(fdo_profile);
diff --git a/xla/tests/collective_pipeline_parallelism_test.cc b/xla/tests/collective_pipeline_parallelism_test.cc
index f10ab3d181da3..a28f749b653a2 100644
--- a/xla/tests/collective_pipeline_parallelism_test.cc
+++ b/xla/tests/collective_pipeline_parallelism_test.cc
@@ -56,7 +56,7 @@ class CollectivePipelineParallelismTest
 
     // Set debug options.
     DebugOptions debug_options = GetDebugOptionsForTest();
-    debug_options.set_xla_gpu_enable_experimental_pipeline_parallelism_opt(
+    debug_options.set_xla_gpu_experimental_enable_pipeline_parallelism_opt(
         GetParam());
     config.set_debug_options(debug_options);
 
diff --git a/xla/xla.proto b/xla/xla.proto
index 146bbf56d989a..6f71adef68446 100644
--- a/xla/xla.proto
+++ b/xla/xla.proto
@@ -348,12 +348,6 @@ message DebugOptions {
 
   bool xla_gpu_enable_approx_costly_collectives = 305;
 
-  // If enabled, uses bf16_3way gemm to compute F32 gemm.
-  bool xla_gpu_enable_bf16_3way_gemm = 279;
-
-  // If enabled, uses bf16_6way gemm to compute F32 gemm.
-  bool xla_gpu_enable_bf16_6way_gemm = 271;
-
   // Determine the types of commands that are recorded into command buffers.
   repeated CommandBufferCmdType xla_gpu_enable_command_buffer = 258;
 
@@ -386,10 +380,6 @@ message DebugOptions {
   // dynamic-update-slice operations around library calls.
   bool xla_gpu_enable_dynamic_slice_fusion = 105;
 
-  // Experimental optimizations for SPMD-based pipeline parallelism on GPU.
-  // TODO(bchetioui): adjust this name to follow the naming convention.
-  bool xla_gpu_enable_experimental_pipeline_parallelism_opt = 351;
-
   // When true we lower the Minimum and Maximum hlos in the GPU backend such
   // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NotNaN.  In other words, if flag
   // this is true we don't propagate NaNs through Min and Max.
@@ -513,6 +503,9 @@ message DebugOptions {
   // Pre-existing block-level fusions are left unmodified.
   bool xla_gpu_experimental_enable_fusion_block_level_rewriter = 334;
 
+  // Experimental optimizations for SPMD-based pipeline parallelism on GPU.
+  bool xla_gpu_experimental_enable_pipeline_parallelism_opt = 351;
+
   // When enabled, the PriorityFusion pass will try to make Triton fusions first
   // and foremost where it is possible.
   //
@@ -1152,8 +1145,10 @@ message DebugOptions {
   // xla_gpu_enable_heuristic_pass_configuration
   // xla_gpu_enable_dot_strength_reduction
   // xla_gpu_triton_fusion_level
+  // xla_gpu_enable_bf16_3way_gemm
+  // xla_gpu_enable_bf16_6way_gemm
   reserved 5, 117, 133, 139, 176, 178, 180, 193, 214, 194, 221, 242, 206, 320,
-      325, 326, 332, 361, 270, 229;
+      325, 326, 332, 361, 270, 229, 271, 279;
 }
 
 // Contains flags which affects the GPU compilation result.