Skip to content

Commit

Permalink
allow bf16 transfer dtype in bench script
Browse files Browse the repository at this point in the history
  • Loading branch information
Jackmin801 committed Oct 4, 2024
1 parent 5d0e704 commit b290a76
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions scripts/all_reduce_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
class TorchDtype(str, Enum):
FLOAT32 = "float32"
FLOAT16 = "float16"
BFLOAT16 = "bfloat16"
UINT8 = "uint8"


TORCH_DTYPE_MAP = {
None: None,
TorchDtype.FLOAT32: torch.float32,
TorchDtype.FLOAT16: torch.float16,
TorchDtype.BFLOAT16: torch.bfloat16,
TorchDtype.UINT8: torch.uint8,
}

Expand Down

0 comments on commit b290a76

Please sign in to comment.