Skip to content

Commit

Permalink
fix fp8 checkpoint
Browse files Browse the repository at this point in the history
Signed-off-by: Mayank Mishra <[email protected]>
  • Loading branch information
mayank31398 committed Jan 13, 2025
1 parent 01f6d06 commit 9500523
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions dolomite_engine/distributed/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@
precompute_float8_dynamic_scale_for_fsdp,
sync_float8_amax_and_scale_history,
)
from torchao.float8.fsdp_utils import (
WeightWithDelayedFloat8CastTensor,
WeightWithDynamicFloat8CastTensor,
WeightWithStaticFloat8CastTensor,
)

torch.serialization.add_safe_globals(
[WeightWithDynamicFloat8CastTensor, WeightWithDelayedFloat8CastTensor, WeightWithStaticFloat8CastTensor]
)

_PRECOMPUTE_SCALE: bool = False
_DELAYED_SCALING: bool = False
Expand Down

0 comments on commit 9500523

Please sign in to comment.