From 852ebc44a0da8f97c33a4c9ce5eefd925ec0c4e4 Mon Sep 17 00:00:00 2001
From: Max Ryabinin
Date: Sun, 9 Jun 2024 17:56:15 +0100
Subject: [PATCH] Update optim.grad_scaler to use torch.amp

Co-authored-by: Luciferian Ink
---
 hivemind/optim/grad_scaler.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/hivemind/optim/grad_scaler.py b/hivemind/optim/grad_scaler.py
index 704f859c9..49562158b 100644
--- a/hivemind/optim/grad_scaler.py
+++ b/hivemind/optim/grad_scaler.py
@@ -4,8 +4,17 @@
 from typing import Dict, Optional
 
 import torch
-from torch.cuda.amp import GradScaler as TorchGradScaler
-from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state
+from packaging import version
+
+torch_version = torch.__version__.split("+")[0]
+
+if version.parse(torch_version) >= version.parse("2.3.0"):
+    from torch.amp import GradScaler as TorchGradScaler
+    from torch.amp.grad_scaler import OptState, _refresh_per_optimizer_state
+else:
+    from torch.cuda.amp import GradScaler as TorchGradScaler
+    from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state
+
 from torch.optim import Optimizer as TorchOptimizer
 
 import hivemind
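
Note: the sketch below restates the version gate from the hunk above as standalone code, followed by an illustrative AMP step. It is not part of the patch; the toy model, optimizer, and tensor shapes are hypothetical, and the 2.3.0 threshold reflects the release in which the device-agnostic GradScaler appeared under torch.amp.

import torch
from packaging import version

# Drop any local build suffix, e.g. "2.4.1+cu121" -> "2.4.1", so version parsing succeeds.
torch_version = torch.__version__.split("+")[0]

if version.parse(torch_version) >= version.parse("2.3.0"):
    # torch >= 2.3 ships a device-agnostic GradScaler under torch.amp.
    from torch.amp import GradScaler
else:
    # Older releases only provide the CUDA-specific scaler.
    from torch.cuda.amp import GradScaler

if torch.cuda.is_available():
    # The scaler API (scale / step / update) is identical under either import path.
    model = torch.nn.Linear(8, 1).cuda()
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = GradScaler()

    with torch.autocast("cuda", dtype=torch.float16):
        loss = model(torch.randn(4, 8, device="cuda")).sum()

    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()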