From 18d957bcd81f12d32894e84ad635e2278fd62323 Mon Sep 17 00:00:00 2001 From: Xuefei Jiang Date: Thu, 19 Sep 2024 16:54:23 -0700 Subject: [PATCH] PR #16938: Add NANOO FP8 support for collaborative communication unit tests Imported from GitHub PR https://github.com/openxla/xla/pull/16938 This PR adds support for NANOO FP8 data format in the collaborative communication unit tests. - For the context on OCP FP8 and NANOO FP8, please refer to this comment: https://github.com/google/flax/pull/3993#issue-2350000228 - The unit tests in this PR are similar to GEMM unit test introduced in the following PR to be able to deal with both OCP and NANOO fp8 formats: https://github.com/openxla/xla/pull/10488 Copybara import of the project: -- 0fc74ccae6cfcaf4e8627ea338ee03783af0626b by Wen Chen : [AMD] Added NCCL support for fp8e4m3fnuz and fp8e5m2fnuz. -- d247af5cd33fe42698bb55ef1c18f32df8a02a21 by scxfjiang : refactor tests for collective comm ops -- 6f8c418b3052f7c531896bd5f8cbbc7a766ef7fc by scxfjiang : rafactor collective comm e2e tests -- 8ecb6ecf08a1536c5b3f8ba87e0e9f8813b1b359 by scxfjiang : update: replace str -- 338d3af2ca1a32302fdfe9d7abee335d24539ee9 by scxfjiang : get rid of macros Merging this change closes #16938 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/16938 from ROCm:ci_dev_rccl_nanoo_fp8 338d3af2ca1a32302fdfe9d7abee335d24539ee9 PiperOrigin-RevId: 676615012 --- xla/service/gpu/runtime/nccl_api.cc | 2 + .../gpu/runtime/nccl_collective_thunk.cc | 2 + xla/tests/BUILD | 2 + xla/tests/collective_ops_e2e_test.cc | 34 +++- xla/tests/collective_ops_test.cc | 179 ++++++++++-------- 5 files changed, 137 insertions(+), 82 deletions(-) diff --git a/xla/service/gpu/runtime/nccl_api.cc b/xla/service/gpu/runtime/nccl_api.cc index 783bd2ddaddee..683195b2b4146 100644 --- a/xla/service/gpu/runtime/nccl_api.cc +++ b/xla/service/gpu/runtime/nccl_api.cc @@ -112,6 +112,8 @@ static absl::StatusOr ToNcclDataType(PrimitiveType dtype, case S8: case F8E5M2: case F8E4M3FN: + case F8E5M2FNUZ: + case F8E4M3FNUZ: return ncclInt8; case PRED: case U8: diff --git a/xla/service/gpu/runtime/nccl_collective_thunk.cc b/xla/service/gpu/runtime/nccl_collective_thunk.cc index 7582c18c292e7..cdb0a0c38b28e 100644 --- a/xla/service/gpu/runtime/nccl_collective_thunk.cc +++ b/xla/service/gpu/runtime/nccl_collective_thunk.cc @@ -93,6 +93,8 @@ bool IsTypeSupportedByNccl(PrimitiveType element_type, // they involve actual computation and not just data movement. case F8E5M2: case F8E4M3FN: + case F8E5M2FNUZ: + case F8E4M3FNUZ: return !IsReductionCollective(reduction_op); default: return false; diff --git a/xla/tests/BUILD b/xla/tests/BUILD index 2c225af4b8dff..66a7b3c18168d 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -2388,6 +2388,8 @@ xla_test( "//xla/hlo/utils:hlo_matchers", "//xla/service/gpu:backend_configs_cc", "//xla/stream_executor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", ], ) diff --git a/xla/tests/collective_ops_e2e_test.cc b/xla/tests/collective_ops_e2e_test.cc index f1d1c78d28bb6..df9f73c4e68ef 100644 --- a/xla/tests/collective_ops_e2e_test.cc +++ b/xla/tests/collective_ops_e2e_test.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -50,6 +52,13 @@ DeviceAssignment MakeDeviceAssn(int64_t num_replicas) { class CollectiveOpsTestE2E : public HloTestBase { public: + CollectiveOpsTestE2E() { + replacements_[kF8E4M3DatatypePlaceholder] = + IsCuda() ? "f8e4m3fn" : "f8e4m3fnuz"; + replacements_[kF8E5M2DatatypePlaceholder] = + IsCuda() ? "f8e5m2" : "f8e5m2fnuz"; + } + bool IsCuda() { return std::holds_alternative(Capability()); } @@ -79,6 +88,13 @@ class CollectiveOpsTestE2E : public HloTestBase { /*argument_provider*/ [](int64_t, int64_t) { return nullptr; }, num_replicas, /*run_hlo_passes=*/false, &device_assignment); } + + protected: + absl::flat_hash_map replacements_; + + private: + static constexpr const char* kF8E4M3DatatypePlaceholder{"<>"}; + static constexpr const char* kF8E5M2DatatypePlaceholder{"<>"}; }; // E2E tests for collective ops. These will generally verify some HLO transform @@ -770,11 +786,11 @@ ENTRY main.12 { TEST_F(CollectiveOpsTestE2EWindowedNonWindowed, WindowedEinsumE2EAllGatherAndReduceScatterF8) { absl::string_view kModuleReplicatedStr = R"( -HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,16,48]{2,1,0}, f8e4m3fn[48,192]{1,0}, f8e4m3fn[192,48]{1,0}, bf16[], bf16[], bf16[], bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4 +HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(<>[2,16,48]{2,1,0}, <>[48,192]{1,0}, <>[192,48]{1,0}, bf16[], bf16[], bf16[], bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4 ENTRY main.12 { - Arg_0.1 = f8e4m3fn[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} - Arg_1.2 = f8e4m3fn[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} + Arg_0.1 = <>[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} + Arg_1.2 = <>[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} Arg_2.3 = bf16[] parameter(3) Arg_3.4 = bf16[] parameter(4) broadcast = bf16[2,16,48]{2,1,0} broadcast(Arg_2.3), dimensions={} @@ -793,12 +809,12 @@ ENTRY main.12 { constant.1 = bf16[] constant(448.) broadcast.4 = bf16[2,16,192]{2,1,0} broadcast(constant.1), dimensions={} clamp = bf16[2,16,192]{2,1,0} clamp(broadcast.3, divide, broadcast.4) - convert.2 = f8e4m3fn[2,16,192]{2,1,0} convert(clamp) + convert.2 = <>[2,16,192]{2,1,0} convert(clamp) Arg_5.6 = bf16[] parameter(6) broadcast.5 = bf16[2,16,192]{2,1,0} broadcast(Arg_5.6), dimensions={} convert.3 = bf16[2,16,192]{2,1,0} convert(convert.2) multiply.2 = bf16[2,16,192]{2,1,0} multiply(convert.3, broadcast.5) - Arg_6.7 = f8e4m3fn[192,48]{1,0} parameter(2), sharding={devices=[4,1]<=[4]} + Arg_6.7 = <>[192,48]{1,0} parameter(2), sharding={devices=[4,1]<=[4]} Arg_7.8 = bf16[] parameter(7) broadcast.6 = bf16[192,48]{1,0} broadcast(Arg_7.8), dimensions={} convert.4 = bf16[192,48]{1,0} convert(Arg_6.7) @@ -929,7 +945,7 @@ while_body { r = bf16[32,128] bitcast(dynamic-slice.k) a = bf16[32,128] add(r, r), control-predecessors={constant.2559} // A fp8 pattern of quant-dequant before the collective AG. - qa = f8e4m3fn[32,128] convert(a) + qa = <>[32,128] convert(a) dqa = bf16[32,128] convert(qa) a_scale = bf16[] get-tuple-element(param), index=3 a_scales = bf16[32,128] broadcast(a_scale), dimensions={} @@ -937,7 +953,7 @@ while_body { mb = bf16[128,128] all-gather(dqa_unscaled), channel_id=1, use_global_device_ids=true, dimensions={0}, replica_groups={{0,1,2,3}} ma = bf16[128,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561), dynamic_slice_sizes={128,128} - qma = f8e4m3fn[128,128] convert(ma) + qma = <>[128,128] convert(ma) dqma = bf16[128,128] convert(qma) ma_scale = bf16[] get-tuple-element(param), index=4 ma_scales = bf16[128,128] broadcast(ma_scale), dimensions={} @@ -970,7 +986,9 @@ ENTRY entry { config.set_debug_options(opts); config.set_num_partitions(kNumPartitions); TF_ASSERT_OK_AND_ASSIGN( - auto module, ParseAndReturnVerifiedModule(kModuleReplicatedStr, config)); + auto module, + ParseAndReturnVerifiedModule( + absl::StrReplaceAll(kModuleReplicatedStr, replacements_), config)); TF_ASSERT_OK_AND_ASSIGN(auto executable, CreateExecutable(std::move(module), diff --git a/xla/tests/collective_ops_test.cc b/xla/tests/collective_ops_test.cc index 0d8c3062bf223..ae13d8c38f107 100644 --- a/xla/tests/collective_ops_test.cc +++ b/xla/tests/collective_ops_test.cc @@ -1762,80 +1762,6 @@ XLA_TEST_F(CollectiveOpsTest, AllReduceBFloat16Min) { } } -XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllGather_8BitFloat)) { - const char* const kModuleStr = R"( - HloModule test - ENTRY test_computation { - a0 = f8e4m3fn[1,2] constant({{1,2}}) - allgather = f8e4m3fn[2, 2] all-gather(a0), dimensions={0} - p = f8e4m3fn[4] reshape(allgather) - ROOT out = f32[4] convert(p) - } - )"; - const int64_t kNumReplicas = 2; - HloModuleConfig config = - GetModuleConfigForTest(/*replica_count=*/kNumReplicas); - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kModuleStr, config)); - TF_ASSERT_OK_AND_ASSIGN( - std::vector results, - ExecuteReplicated(std::move(module), absl::Span{}, - kNumReplicas, - /*use_threads=*/true, /*run_hlo_passes=*/true)); - ASSERT_EQ(results.size(), kNumReplicas); - for (const Literal& result : results) { - LiteralTestUtil::ExpectR1Equal({1, 2, 1, 2}, result); - } -} - -XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_8BitFloat)) { - const char* const kModuleStr = R"( - HloModule test - ENTRY test_computation { - a0 = f8e4m3fn[2] constant({1,2}) - a2a = f8e4m3fn[2] all-to-all(a0), dimensions={0} - ROOT out = f32[2] convert(a2a) - } - )"; - const int64_t kNumReplicas = 2; - HloModuleConfig config = - GetModuleConfigForTest(/*replica_count=*/kNumReplicas); - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kModuleStr, config)); - TF_ASSERT_OK_AND_ASSIGN( - std::vector results, - ExecuteReplicated(std::move(module), absl::Span{}, - kNumReplicas, - /*use_threads=*/true, /*run_hlo_passes=*/true)); - ASSERT_EQ(results.size(), kNumReplicas); - LiteralTestUtil::ExpectR1Equal({1, 1}, results[0]); - LiteralTestUtil::ExpectR1Equal({2, 2}, results[1]); -} - -XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_8BitFloat)) { - const char* const kModuleStr = R"( - HloModule test - ENTRY test_computation { - a0 = f8e5m2[2] constant({1,2}) - a1 = f8e5m2[2] collective-permute(a0), source_target_pairs={{0,1}, {1,0}} - ROOT out = f32[2] convert(a1) - } - )"; - const int64_t kNumReplicas = 2; - HloModuleConfig config = - GetModuleConfigForTest(/*replica_count=*/kNumReplicas); - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kModuleStr, config)); - TF_ASSERT_OK_AND_ASSIGN( - std::vector results, - ExecuteReplicated(std::move(module), absl::Span{}, - kNumReplicas, - /*use_threads=*/true, /*run_hlo_passes=*/true)); - ASSERT_EQ(results.size(), kNumReplicas); - LiteralTestUtil::ExpectR1Equal({1, 2}, results[0]); - LiteralTestUtil::ExpectR1Equal({1, 2}, results[1]); -} - XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncAllGather)) { const char* const kModuleStr = R"( HloModule test @@ -2282,5 +2208,110 @@ body { results[1])); } +class Fp8CollectiveOpsTest : public CollectiveOpsTest { + public: + Fp8CollectiveOpsTest() { + replacements_[kF8E4M3DatatypePlaceholder] = + IsCuda() ? "f8e4m3fn" : "f8e4m3fnuz"; + replacements_[kF8E5M2DatatypePlaceholder] = + IsCuda() ? "f8e5m2" : "f8e5m2fnuz"; + } + + protected: + bool IsCuda() { + return std::holds_alternative(Capability()); + } + + const se::GpuComputeCapability& Capability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .gpu_compute_capability(); + } + + absl::flat_hash_map replacements_; + + private: + static constexpr const char* kF8E4M3DatatypePlaceholder{"<>"}; + static constexpr const char* kF8E5M2DatatypePlaceholder{"<>"}; +}; + +XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(AllGather_8BitFloat)) { + const char* const kModuleStr = R"( + HloModule test + ENTRY test_computation { + a0 = <>[1,2] constant({{1,2}}) + allgather = <>[2, 2] all-gather(a0), dimensions={0} + p = <>[4] reshape(allgather) + ROOT out = f32[4] convert(p) + } + )"; + const int64_t kNumReplicas = 2; + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule( + absl::StrReplaceAll(kModuleStr, replacements_), config)); + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/true)); + ASSERT_EQ(results.size(), kNumReplicas); + for (const Literal& result : results) { + LiteralTestUtil::ExpectR1Equal({1, 2, 1, 2}, result); + } +} + +XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_8BitFloat)) { + const char* const kModuleStr = R"( + HloModule test + ENTRY test_computation { + a0 = <>[2] constant({1,2}) + a2a = <>[2] all-to-all(a0), dimensions={0} + ROOT out = f32[2] convert(a2a) + } + )"; + const int64_t kNumReplicas = 2; + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule( + absl::StrReplaceAll(kModuleStr, replacements_), config)); + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/true)); + ASSERT_EQ(results.size(), kNumReplicas); + LiteralTestUtil::ExpectR1Equal({1, 1}, results[0]); + LiteralTestUtil::ExpectR1Equal({2, 2}, results[1]); +} + +XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_8BitFloat)) { + const char* const kModuleStr = R"( + HloModule test + ENTRY test_computation { + a0 = <>[2] constant({1,2}) + a1 = <>[2] collective-permute(a0), source_target_pairs={{0,1}, {1,0}} + ROOT out = f32[2] convert(a1) + } + )"; + const int64_t kNumReplicas = 2; + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule( + absl::StrReplaceAll(kModuleStr, replacements_), config)); + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/true)); + ASSERT_EQ(results.size(), kNumReplicas); + LiteralTestUtil::ExpectR1Equal({1, 2}, results[0]); + LiteralTestUtil::ExpectR1Equal({1, 2}, results[1]); +} + } // namespace } // namespace xla