diff --git a/xla/backends/gpu/codegen/triton/dot_algorithms_test.cc b/xla/backends/gpu/codegen/triton/dot_algorithms_test.cc
index a64e1153ccf80..4a9d95c15fb31 100644
--- a/xla/backends/gpu/codegen/triton/dot_algorithms_test.cc
+++ b/xla/backends/gpu/codegen/triton/dot_algorithms_test.cc
@@ -137,25 +137,6 @@ class Triton6xBF16GemmTest : public AlgorithmTest {
   }
 };
 
-// In these tests, we depend on debug option flags for selecting the 6XBF16
-// algorithm.
-// TODO(b/379905071): Remove this class and the --xla_gpu_enable_bf16_6way_gemm
-// flag after we will support the algorithm values through the entire stack.
-class Triton6xBF16GemmTestWithFlag : public AlgorithmTest {
- public:
-  DebugOptions GetDebugOptionsForTest() const override {
-    DebugOptions debug_options = AlgorithmTest::GetDebugOptionsForTest();
-    // Do not fall back to cuBLAS, we are testing Triton.
-    debug_options.set_xla_gpu_cublas_fallback(false);
-    // Do not autotune split-k by default, since this prevents deterministically
-    // matching the optimized HLO.
-    debug_options.set_xla_gpu_enable_split_k_autotuning(false);
-    // Enable bf16_6way gemm to compute F32 matmul.
-    debug_options.set_xla_gpu_enable_bf16_6way_gemm(true);
-    return debug_options;
-  }
-};
-
 class BlasAlgorithmTest : public AlgorithmTest {
  public:
   DebugOptions GetDebugOptionsForTest() const override {
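Note on the hunk below: the deleted fixture existed only to flip --xla_gpu_enable_bf16_6way_gemm; the surviving Triton6xBF16GemmTest requests the same decomposition through the dot's algorithm. The FileCheck patterns in the removed test look for the f32-to-bf16 split that the emitter performs: the high bf16 piece of each f32 operand is obtained by masking off the low 16 bits (the dense<-65536> constant, i.e. 0xFFFF0000). A minimal standalone sketch of that split, not XLA code, with illustrative names:

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // Keep only the bits of an f32 value that survive truncation to bf16.
    // 0xFFFF0000 is the two's-complement value -65536, the mask constant the
    // removed FileCheck lines match against.
    float HighBF16Part(float x) {
      uint32_t bits;
      std::memcpy(&bits, &x, sizeof(bits));
      bits &= 0xFFFF0000u;
      float hi;
      std::memcpy(&hi, &bits, sizeof(hi));
      return hi;
    }

    int main() {
      float x = 3.14159265f;
      float hi = HighBF16Part(x);  // exactly representable in bf16
      float lo = x - hi;           // residual; split further for the x6 variant
      std::printf("x = %.9g, hi = %.9g, lo = %.9g\n", x, hi, lo);
      return 0;
    }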
@@ -518,47 +499,6 @@ CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf
                                                            /*arel=*/1e-6}));
 }
 
-TEST_F(Triton6xBF16GemmTestWithFlag, Emit6xBF16GemmWhenBothInputsAreF32) {
-  constexpr absl::string_view kHloText = R"(
-  HloModule Emit6xBF16GemmWhenBothInputsAreF32
-
-  triton_dot {
-    p0 = f32[5,7] parameter(0)
-    p1 = f32[7,33] parameter(1)
-    ROOT dot = f32[5,33] dot(p0, p1),
-      lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  }
-
-  ENTRY e {
-    p0 = f32[5,7]{1,0} parameter(0)
-    p1 = f32[7,33]{1,0} parameter(1)
-    ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot,
-      backend_config={"fusion_backend_config": {kind: "__triton_gemm",
-        triton_gemm_config:
-        {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}}
-  }
-  )";
-  TF_ASSERT_OK(
-      CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"(
-CHECK: %[[INFINITY:.*]] = arith.constant dense<0x7F800000> : tensor<32x32xf32>
-CHECK: %[[C_MASK:.*]] = arith.constant dense<-65536> : tensor<32x32xi32>
-CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<32x32xf32>
-CHECK: %[[CAST_I32:.*]] = tt.bitcast %{{.*}} : tensor<32x32xf32> -> tensor<32x32xi32>
-CHECK: %[[EXTRACT_HI:.*]] = arith.andi %[[CAST_I32]], %[[C_MASK]] : tensor<32x32xi32>
-CHECK: %[[CAST_HI:.*]] = tt.bitcast %[[EXTRACT_HI]] : tensor<32x32xi32> -> tensor<32x32xf32>
-CHECK: %[[TRUNC_TO_BF16:.*]] = arith.truncf %[[CAST_HI]] : tensor<32x32xf32> to tensor<32x32xbf16>
-CHECK-COUNT-5: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32>
-CHECK: %[[ABS:.*]] = math.absf
-CHECK: %[[CMP:.*]] = arith.cmpf ogt, %[[INFINITY]], %[[ABS]] : tensor<32x32xf32>
-CHECK: %[[SELECT:.*]] = arith.select %[[CMP]], %{{.*}}, %[[C0]] : tensor<32x32xi1>, tensor<32x32xf32>
-CHECK: %[[DOT_LAST:.*]] = tt.dot %{{.*}}, %{{.*}}, %[[SELECT]] : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32>
-CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf32>
-  )"));
-
-  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/1e-6,
-                                                           /*arel=*/1e-6}));
-}
-
 TEST_F(Triton6xBF16GemmTest,
        Triton6xBF16GemmWorksForLongContractingDimension) {
   constexpr absl::string_view kHloText = R"(
   HloModule Triton6xBF16GemmWorksForLongContractingDimension
@@ -635,34 +575,6 @@ class Triton3xBF16GemmTest : public AlgorithmTest {
   }
 };
 
-// In these tests, we depend on debug option flags for selecting the 3XBF16
-// algorithm.
-// TODO(b/379905071): Remove this class and the --xla_gpu_enable_bf16_3way_gemm
-// flag after we will support the algorithm values through the entire stack.
-class Triton3xBF16GemmTestWithFlag : public AlgorithmTest {
- public:
-  DebugOptions GetDebugOptionsForTest() const override {
-    DebugOptions debug_options = AlgorithmTest::GetDebugOptionsForTest();
-    // Enable triton fusion for all supported GEMMs.
-    debug_options.set_xla_gpu_triton_gemm_any(true);
-    // Do not fall back to cuBLAS, we are testing Triton.
-    debug_options.set_xla_gpu_cublas_fallback(false);
-    // Do not autotune split-k by default, since this prevents deterministically
-    // matching the optimized HLO.
-    debug_options.set_xla_gpu_enable_split_k_autotuning(false);
-    // Enable bf16_3way gemm to compute F32 matmul.
-    debug_options.set_xla_gpu_enable_bf16_3way_gemm(true);
-    return debug_options;
-  }
-
- protected:
-  void SetUp() override {
-    if (!SupportsBF16(GpuComputeComp())) {
-      GTEST_SKIP() << "BF16 not supported.";
-    }
-  }
-};
-
 TEST_F(Triton3xBF16GemmTest, Emit3xBF16GemmWhenBothInputsAreF32) {
   constexpr absl::string_view kHloText = R"(
   HloModule Emit3xBF16GemmWhenBothInputsAreF32
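With the WithFlag fixtures gone, the surviving tests select the decomposition per dot, through the algorithm field of the precision config, instead of through DebugOptions. A hedged sketch of what such a test's HLO looks like; it assumes the usual HLO text spelling of PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3, and the module and constant names are illustrative:

    // Illustrative fragment only: an HLO snippet requesting the 3xBF16
    // algorithm directly on the dot, so no debug flag is needed.
    constexpr absl::string_view kAlgorithmHloExample = R"(
      HloModule t

      ENTRY e {
        p0 = f32[5,7] parameter(0)
        p1 = f32[7,33] parameter(1)
        ROOT dot = f32[5,33] dot(p0, p1),
          lhs_contracting_dims={1}, rhs_contracting_dims={0},
          algorithm=dot_bf16_bf16_f32_x3
      }
    )";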
@@ -705,75 +617,6 @@ CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf
                                                            /*arel=*/1e-5}));
 }
 
-TEST_F(Triton3xBF16GemmTestWithFlag, Emit3xBF16GemmWhenBothInputsAreF32) {
-  constexpr absl::string_view kHloText = R"(
-  HloModule Emit3xBF16GemmWhenBothInputsAreF32
-
-  triton_dot {
-    p0 = f32[5,7] parameter(0)
-    p1 = f32[7,33] parameter(1)
-    ROOT dot = f32[5,33] dot(p0, p1),
-      lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  }
-
-  ENTRY e {
-    p0 = f32[5,7]{1,0} parameter(0)
-    p1 = f32[7,33]{1,0} parameter(1)
-    ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot,
-      backend_config={"fusion_backend_config": {kind: "__triton_gemm",
-        triton_gemm_config:
-        {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}}
-  }
-  )";
-  TF_ASSERT_OK(
-      CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"(
-CHECK: %[[INFINITY:.*]] = arith.constant dense<0x7F800000> : tensor<32x32xf32>
-CHECK: %[[C_MASK:.*]] = arith.constant dense<-65536> : tensor<32x32xi32>
-CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<32x32xf32>
-CHECK: %[[CAST_I32:.*]] = tt.bitcast %{{.*}} : tensor<32x32xf32> -> tensor<32x32xi32>
-CHECK: %[[EXTRACT_HI:.*]] = arith.andi %[[CAST_I32]], %[[C_MASK]] : tensor<32x32xi32>
-CHECK: %[[CAST_HI:.*]] = tt.bitcast %[[EXTRACT_HI]] : tensor<32x32xi32> -> tensor<32x32xf32>
-CHECK: %[[TRUNC_TO_BF16:.*]] = arith.truncf %[[CAST_HI]] : tensor<32x32xf32> to tensor<32x32xbf16>
-CHECK-COUNT-2: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32>
-CHECK: %[[ABS:.*]] = math.absf
-CHECK: %[[CMP:.*]] = arith.cmpf ogt, %[[INFINITY]], %[[ABS]] : tensor<32x32xf32>
-CHECK: %[[SELECT:.*]] = arith.select %[[CMP]], %{{.*}}, %[[C0]] : tensor<32x32xi1>, tensor<32x32xf32>
-CHECK: %[[DOT_LAST:.*]] = tt.dot %{{.*}}, %{{.*}}, %[[SELECT]] : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32>
-CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf32>
-  )"));
-
-  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/1e-5,
-                                                           /*arel=*/1e-5}));
-}
-
-TEST_F(Triton3xBF16GemmTestWithFlag, NoEmit3xBF16GemmWhenBothInputsAreNotF32) {
-  constexpr absl::string_view kHloText = R"(
-  HloModule NoEmit3xBF16GemmWhenBothInputsAreNotF32
-
-  triton_dot {
-    p0 = f16[5,7] parameter(0)
-    p1 = f16[7,33] parameter(1)
-    ROOT dot = f16[5,33] dot(p0, p1),
-      lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  }
-
-  ENTRY e {
-    p0 = f16[5,7]{1,0} parameter(0)
-    p1 = f16[7,33]{1,0} parameter(1)
-    ROOT _ = f16[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot,
-      backend_config={"fusion_backend_config": {kind: "__triton_gemm",
-        triton_gemm_config:
-        {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}}
-  }
-  )";
-  TF_ASSERT_OK(
-      CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"(
-CHECK: tt.dot
-CHECK-SAME: tensor<32x32xf16> * tensor<32x32xf16> -> tensor<32x32xf32>
-CHECK-NOT: tt.dot
-  )"));
-}
-
 TEST_F(Triton3xBF16GemmTest,
        Triton3xBF16GemmWorksForLongContractingDimension) {
   constexpr absl::string_view kHloText = R"(
   HloModule Triton3xBF16GemmWorksForLongContractingDimension
diff --git a/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc b/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc
index bbfcf71f52e9a..5995000fc3df0 100644
--- a/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc
+++ b/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc
@@ -1664,18 +1664,6 @@ bool Is6xBfloat16MatMul(const HloDotInstruction* dot_instr,
   const PrecisionConfig::Algorithm algorithm =
       dot_instr->precision_config().algorithm();
 
-  if (algorithm == PrecisionConfig::ALG_UNSET) {
-    const HloModule* hlo_module = dot_instr->GetModule();
-    Type f32 = b.getF32Type();
-    return hlo_module->config()
-               .debug_options()
-               .xla_gpu_enable_bf16_6way_gemm() &&
-           mlir::cast<ShapedType>(dot_input_lhs.getType()).getElementType() ==
-               f32 &&
-           mlir::cast<ShapedType>(dot_input_rhs.getType()).getElementType() ==
-               f32;
-  }
-
   return algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6;
 }
 
@@ -1686,18 +1674,6 @@ bool Is3xBfloat16MatMul(const HloDotInstruction* dot_instr,
   const PrecisionConfig::Algorithm algorithm =
       dot_instr->precision_config().algorithm();
 
-  if (algorithm == PrecisionConfig::ALG_UNSET) {
-    const HloModule* hlo_module = dot_instr->GetModule();
-    Type f32 = b.getF32Type();
-    return hlo_module->config()
-               .debug_options()
-               .xla_gpu_enable_bf16_3way_gemm() &&
-           mlir::cast<ShapedType>(dot_input_lhs.getType()).getElementType() ==
-               f32 &&
-           mlir::cast<ShapedType>(dot_input_rhs.getType()).getElementType() ==
-               f32;
-  }
-
   return algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3;
 }
 
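After the two hunks above, both predicates reduce to a plain comparison against the requested algorithm; ALG_UNSET dots no longer opt in via DebugOptions, and the operand element types no longer matter for the decision. A rough sketch of what remains, with illustrative function names rather than the verbatim surviving code:

    // What Is6xBfloat16MatMul / Is3xBfloat16MatMul boil down to once the
    // ALG_UNSET + debug-flag branch is gone: only the dot's precision config
    // decides.
    bool IsBf16x6(const HloDotInstruction* dot_instr) {
      return dot_instr->precision_config().algorithm() ==
             PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6;
    }

    bool IsBf16x3(const HloDotInstruction* dot_instr) {
      return dot_instr->precision_config().algorithm() ==
             PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3;
    }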
- << " Fallback to BF16 6way gemm."; - } - Value accumulator_next; if (Is6xBfloat16MatMul(dot_instr, b, dot_input_lhs, dot_input_rhs, device_info)) { diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc index 0b803fc49f589..aed4c49d5e530 100644 --- a/xla/debug_options_flags.cc +++ b/xla/debug_options_flags.cc @@ -205,7 +205,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_collective_permute_decomposer_threshold( std::numeric_limits::max()); - opts.set_xla_gpu_enable_experimental_pipeline_parallelism_opt(false); + opts.set_xla_gpu_experimental_enable_pipeline_parallelism_opt(false); opts.set_xla_cpu_enable_mlir_tiling_and_fusion(true); opts.set_xla_cpu_enable_custom_matmul_tiling(false); @@ -268,8 +268,6 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { stream_executor::IsLibNvPtxCompilerSupported()); opts.set_xla_gpu_libnvjitlink_mode(DebugOptions::LIB_NV_JIT_LINK_MODE_AUTO); - opts.set_xla_gpu_enable_bf16_6way_gemm(false); - opts.set_xla_gpu_enable_bf16_3way_gemm(false); opts.set_xla_gpu_nccl_collective_max_nchannels(0); opts.set_xla_gpu_nccl_p2p_max_nchannels(0); opts.set_xla_gpu_multi_streamed_windowed_einsum(true); @@ -1713,11 +1711,11 @@ void MakeDebugOptionsFlags(std::vector* flag_list, debug_options->xla_gpu_collective_permute_decomposer_threshold(), "Collective permute decomposer threshold.")); flag_list->push_back(tsl::Flag( - "xla_gpu_enable_experimental_pipeline_parallelism_opt", + "xla_gpu_experimental_enable_pipeline_parallelism_opt", bool_setter_for( &DebugOptions:: - set_xla_gpu_enable_experimental_pipeline_parallelism_opt), - debug_options->xla_gpu_enable_experimental_pipeline_parallelism_opt(), + set_xla_gpu_experimental_enable_pipeline_parallelism_opt), + debug_options->xla_gpu_experimental_enable_pipeline_parallelism_opt(), "Experimental optimizations for SPMD-based pipeline parallelism on " "GPU.")); flag_list->push_back(tsl::Flag( @@ -1993,16 +1991,6 @@ void MakeDebugOptionsFlags(std::vector* flag_list, flag_list->push_back(tsl::Flag("xla_gpu_enable_dot_strength_reduction", noop_flag_setter, true, "[Deprecated, do not use]")); - flag_list->push_back(tsl::Flag( - "xla_gpu_enable_bf16_6way_gemm", - bool_setter_for(&DebugOptions::set_xla_gpu_enable_bf16_6way_gemm), - debug_options->xla_gpu_enable_bf16_6way_gemm(), - "Use BF16 6way gemm to compute F32 gemm.")); - flag_list->push_back(tsl::Flag( - "xla_gpu_enable_bf16_3way_gemm", - bool_setter_for(&DebugOptions::set_xla_gpu_enable_bf16_3way_gemm), - debug_options->xla_gpu_enable_bf16_3way_gemm(), - "Use BF16 3way gemm to compute F32 gemm.")); flag_list->push_back( tsl::Flag("xla_gpu_nccl_collective_max_nchannels", int64_setter_for( diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 48dce2d678371..30f25de027a97 100755 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -954,7 +954,7 @@ absl::Status RunCollectiveOptimizationPasses( if (hlo_module->config() .debug_options() - .xla_gpu_enable_experimental_pipeline_parallelism_opt()) { + .xla_gpu_experimental_enable_pipeline_parallelism_opt()) { collectives_pipeline.AddPass(); } @@ -971,7 +971,7 @@ absl::Status RunCollectiveOptimizationPasses( collectives_pipeline, hlo_module->config() .debug_options() - .xla_gpu_enable_experimental_pipeline_parallelism_opt()); + .xla_gpu_experimental_enable_pipeline_parallelism_opt()); } // Run algebraic simplifier to reshape(broadcast) into a broadcast when @@ -2669,7 +2669,7 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines( if 
diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc
index 48dce2d678371..30f25de027a97 100755
--- a/xla/service/gpu/gpu_compiler.cc
+++ b/xla/service/gpu/gpu_compiler.cc
@@ -954,7 +954,7 @@ absl::Status RunCollectiveOptimizationPasses(
 
   if (hlo_module->config()
           .debug_options()
-          .xla_gpu_enable_experimental_pipeline_parallelism_opt()) {
+          .xla_gpu_experimental_enable_pipeline_parallelism_opt()) {
     collectives_pipeline.AddPass();
   }
 
@@ -971,7 +971,7 @@ absl::Status RunCollectiveOptimizationPasses(
         collectives_pipeline,
         hlo_module->config()
             .debug_options()
-            .xla_gpu_enable_experimental_pipeline_parallelism_opt());
+            .xla_gpu_experimental_enable_pipeline_parallelism_opt());
   }
 
   // Run algebraic simplifier to reshape(broadcast) into a broadcast when
@@ -2669,7 +2669,7 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines(
 
   if (!module->config()
            .debug_options()
-           .xla_gpu_enable_experimental_pipeline_parallelism_opt() &&
+           .xla_gpu_experimental_enable_pipeline_parallelism_opt() &&
       (module->config()
            .debug_options()
            .xla_gpu_enable_pipelined_collectives() ||
diff --git a/xla/service/gpu/gpu_latency_hiding_scheduler.cc b/xla/service/gpu/gpu_latency_hiding_scheduler.cc
index 2c50af565ef2b..353c70e3d6f31 100644
--- a/xla/service/gpu/gpu_latency_hiding_scheduler.cc
+++ b/xla/service/gpu/gpu_latency_hiding_scheduler.cc
@@ -272,13 +272,13 @@ void GpuAsyncTrackerBase::PostProcessScheduleGraph(
   // Schedule partially pipelined send/recv instructions late so that they can
   // overlap with compute. Schedule send/recv late and, when unblocked,
   // schedule send-done/recv-done early.
-  if (debug_options.xla_gpu_enable_experimental_pipeline_parallelism_opt() &&
+  if (debug_options.xla_gpu_experimental_enable_pipeline_parallelism_opt() &&
       IsPartiallyPipelinedSendRecv(inst)) {
     HloGraphNode& node = schedule_graph->GetNode(inst);
     node.SetForceDelay(true);
     VLOG(5) << "Setting force delay for instruction: " << inst->ToString();
   }
-  if (debug_options.xla_gpu_enable_experimental_pipeline_parallelism_opt() &&
+  if (debug_options.xla_gpu_experimental_enable_pipeline_parallelism_opt() &&
       IsPartiallyPipelinedSendRecvDone(inst)) {
     HloGraphNode& node = schedule_graph->GetNode(inst);
     node.SetForceEarly(true);
diff --git a/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc b/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc
index 382e6e148e50e..038a6b9ec8dac 100644
--- a/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc
+++ b/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc
@@ -84,7 +84,7 @@ class GpuLatencyHidingSchedulerBaseTest : public HloTestBase {
     HloModuleConfig config;
     DebugOptions debug_options = GetDebugOptionsForTest();
     debug_options.set_xla_gpu_enable_latency_hiding_scheduler(true);
-    debug_options.set_xla_gpu_enable_experimental_pipeline_parallelism_opt(
+    debug_options.set_xla_gpu_experimental_enable_pipeline_parallelism_opt(
        enable_experimental_pipeline_parallelism_opt);
     config.set_debug_options(debug_options);
     config.set_fdo_profile(fdo_profile);
diff --git a/xla/tests/collective_pipeline_parallelism_test.cc b/xla/tests/collective_pipeline_parallelism_test.cc
index f10ab3d181da3..a28f749b653a2 100644
--- a/xla/tests/collective_pipeline_parallelism_test.cc
+++ b/xla/tests/collective_pipeline_parallelism_test.cc
@@ -56,7 +56,7 @@ class CollectivePipelineParallelismTest
 
     // Set debug options.
    DebugOptions debug_options = GetDebugOptionsForTest();
-    debug_options.set_xla_gpu_enable_experimental_pipeline_parallelism_opt(
+    debug_options.set_xla_gpu_experimental_enable_pipeline_parallelism_opt(
        GetParam());
 
     config.set_debug_options(debug_options);
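The scheduler hunks above only rename the accessor; the behavior they guard is unchanged. When the experimental pipeline-parallelism option is on, partially pipelined Send/Recv ops are pushed late and the matching -done ops pulled early so the transfers overlap with compute. A compressed paraphrase of that logic, using the same XLA types the diff touches (MarkForOverlap is an illustrative name, not a real function; the real code lives in GpuAsyncTrackerBase::PostProcessScheduleGraph):

    // Paraphrase of the guarded scheduling tweak, for orientation only.
    void MarkForOverlap(const DebugOptions& debug_options,
                        const HloInstruction* inst, HloGraphNode& node) {
      if (!debug_options
               .xla_gpu_experimental_enable_pipeline_parallelism_opt()) {
        return;
      }
      if (IsPartiallyPipelinedSendRecv(inst)) {
        node.SetForceDelay(true);  // schedule send/recv as late as possible
      }
      if (IsPartiallyPipelinedSendRecvDone(inst)) {
        node.SetForceEarly(true);  // schedule send-done/recv-done early
      }
    }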
diff --git a/xla/xla.proto b/xla/xla.proto
index 146bbf56d989a..6f71adef68446 100644
--- a/xla/xla.proto
+++ b/xla/xla.proto
@@ -348,12 +348,6 @@ message DebugOptions {
 
   bool xla_gpu_enable_approx_costly_collectives = 305;
 
-  // If enabled, uses bf16_3way gemm to compute F32 gemm.
-  bool xla_gpu_enable_bf16_3way_gemm = 279;
-
-  // If enabled, uses bf16_6way gemm to compute F32 gemm.
-  bool xla_gpu_enable_bf16_6way_gemm = 271;
-
   // Determine the types of commands that are recorded into command buffers.
   repeated CommandBufferCmdType xla_gpu_enable_command_buffer = 258;
 
@@ -386,10 +380,6 @@ message DebugOptions {
   // dynamic-update-slice operations around library calls.
   bool xla_gpu_enable_dynamic_slice_fusion = 105;
 
-  // Experimental optimizations for SPMD-based pipeline parallelism on GPU.
-  // TODO(bchetioui): adjust this name to follow the naming convention.
-  bool xla_gpu_enable_experimental_pipeline_parallelism_opt = 351;
-
   // When true we lower the Minimum and Maximum hlos in the GPU backend such
   // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NotNaN. In other words, if flag
   // this is true we don't propagate NaNs through Min and Max.
@@ -513,6 +503,9 @@ message DebugOptions {
   // Pre-existing block-level fusions are left unmodified.
   bool xla_gpu_experimental_enable_fusion_block_level_rewriter = 334;
 
+  // Experimental optimizations for SPMD-based pipeline parallelism on GPU.
+  bool xla_gpu_experimental_enable_pipeline_parallelism_opt = 351;
+
   // When enabled, the PriorityFusion pass will try to make Triton fusions first
   // and foremost where it is possible.
   //
@@ -1152,8 +1145,10 @@ message DebugOptions {
   // xla_gpu_enable_heuristic_pass_configuration
   // xla_gpu_enable_dot_strength_reduction
   // xla_gpu_triton_fusion_level
+  // xla_gpu_enable_bf16_3way_gemm
+  // xla_gpu_enable_bf16_6way_gemm
   reserved 5, 117, 133, 139, 176, 178, 180, 193, 214, 194, 221, 242, 206, 320,
-      325, 326, 332, 361, 270, 229;
+      325, 326, 332, 361, 270, 229, 271, 279;
 }
 
 // Contains flags which affects the GPU compilation result.
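On the proto side, the rename keeps field number 351, and the two deleted bools hand their numbers (271, 279) to the reserved list so they cannot be reused with a different meaning. Because the wire format records only field numbers, data serialized before the rename still decodes into the renamed field. A minimal check sketch; the include path assumes the generated xla.pb.h:

    #include <string>
    #include "xla/xla.pb.h"

    // The rename is wire-compatible: only tag 351 is on the wire, not the name.
    bool RenamedFieldRoundTrips() {
      xla::DebugOptions original;
      original.set_xla_gpu_experimental_enable_pipeline_parallelism_opt(true);
      const std::string wire = original.SerializeAsString();

      xla::DebugOptions reparsed;
      return reparsed.ParseFromString(wire) &&
             reparsed.xla_gpu_experimental_enable_pipeline_parallelism_opt();
    }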