diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index f7b725269f726..5899ec0e8d101 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -1321,7 +1321,10 @@ def zero_grad(self) -> None:
         for p in self.parameters():
             if p.grad is not None:
-                p.grad.detach_()
+                if p.grad.grad_fn is not None:
+                    p.grad.detach_()
+                else:
+                    p.grad.requires_grad_(False)
                 p.grad.zero_()

     def share_memory(self: T) -> T:
diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index 3fba268d99a30..40ada58b2641d 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -686,7 +686,10 @@ def _sync_params(self):
                         # to zero the grads on all model replicas as well.
                         # This snippet is copied from torch.optim.Optimizer.
                         if param.grad is not None:
-                            param.grad.detach_()
+                            if param.grad.grad_fn is not None:
+                                param.grad.detach_()
+                            else:
+                                param.grad.requires_grad_(False)
                             param.grad.zero_()

             # module buffer sync
diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py
index 5c222bd38b038..2370730f85fbd 100644
--- a/torch/optim/optimizer.py
+++ b/torch/optim/optimizer.py
@@ -169,7 +169,10 @@ def zero_grad(self):
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is not None:
-                    p.grad.detach_()
+                    if p.grad.grad_fn is not None:
+                        p.grad.detach_()
+                    else:
+                        p.grad.requires_grad_(False)
                     p.grad.zero_()

     def step(self, closure):
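
For context, a minimal standalone sketch of the branching this patch applies in all three places: detach the gradient in place only when it was produced by an autograd op (grad_fn is set), otherwise just drop requires_grad so the in-place zero_() is permitted on a leaf tensor. The helper name zero_grad_patched and the toy parameter below are illustrative only, not part of the PR.

import torch

def zero_grad_patched(params):
    # Same branching as the patch: detach only when the grad is attached to a
    # graph (grad_fn is not None); otherwise clear requires_grad so that the
    # in-place zero_() below is allowed on a leaf tensor.
    for p in params:
        if p.grad is not None:
            if p.grad.grad_fn is not None:
                p.grad.detach_()
            else:
                p.grad.requires_grad_(False)
            p.grad.zero_()

p = torch.nn.Parameter(torch.ones(3))
# A leaf grad that requires grad: grad_fn is None, so the unconditional
# detach_() of the old code is skipped; requires_grad_(False) makes the
# in-place zero_() legal.
p.grad = torch.ones(3, requires_grad=True)
zero_grad_patched([p])
print(p.grad)                # tensor([0., 0., 0.])
print(p.grad.requires_grad)  # False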