From b290a764fa88101c36198a6150f64e35c09e7d69 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Sat, 5 Oct 2024 05:09:25 +0800 Subject: [PATCH] allow bf16 transfer dtype in bench script --- scripts/all_reduce_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scripts/all_reduce_test.py b/scripts/all_reduce_test.py index 3a0e1684..31201b45 100644 --- a/scripts/all_reduce_test.py +++ b/scripts/all_reduce_test.py @@ -13,6 +13,7 @@ class TorchDtype(str, Enum): FLOAT32 = "float32" FLOAT16 = "float16" + BFLOAT16 = "bfloat16" UINT8 = "uint8" @@ -20,6 +21,7 @@ class TorchDtype(str, Enum): None: None, TorchDtype.FLOAT32: torch.float32, TorchDtype.FLOAT16: torch.float16, + TorchDtype.BFLOAT16: torch.bfloat16, TorchDtype.UINT8: torch.uint8, }