grad detach_ only when it has grad_fn in zero_grad call (pytorch#41283)
Summary:
Pull Request resolved: pytorch#41283

In optimizer.zero_grad(), detach_ is needed to avoid a memory leak only when the grad has a grad_fn, i.e., is still attached to an autograd graph. This change adds a check so that zero_grad() calls grad.detach_ only when the grad has a grad_fn, and otherwise just clears requires_grad (a sketch illustrating both branches follows the commit metadata below).
ghstack-source-id: 108702289

Test Plan: unit test

Reviewed By: mrshenli

Differential Revision: D22487315

fbshipit-source-id: 861909b15c8497f1da57f092d8963d4920c85e38
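
To make the motivation concrete, here is a minimal sketch (not part of the commit; the tensor shape and the use of create_graph=True are illustrative assumptions). A grad carries a grad_fn when it was produced by a differentiable backward pass, and detaching it in place releases the retained graph; a grad without a grad_fn only needs requires_grad cleared before the in-place zero_.

import torch

p = torch.nn.Parameter(torch.ones(3))
loss = (p * p).sum()
# create_graph=True records the backward pass itself, so the accumulated
# p.grad is a non-leaf tensor whose grad_fn keeps that graph alive.
loss.backward(create_graph=True)
assert p.grad.grad_fn is not None

# The pattern this commit adds: detach_ only when a graph is attached.
if p.grad.grad_fn is not None:
    p.grad.detach_()              # drops the reference to the retained graph
else:
    p.grad.requires_grad_(False)  # makes the in-place zero_ below legal
p.grad.zero_()
assert p.grad.grad_fn is None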
zhaojuanmao authored and facebook-github-bot committed Jul 29, 2020
1 parent 4b6e5f4 commit 79cfd85
Showing 3 changed files with 12 additions and 3 deletions.
5 changes: 4 additions & 1 deletion torch/nn/modules/module.py
@@ -1321,7 +1321,10 @@ def zero_grad(self) -> None:
 
         for p in self.parameters():
             if p.grad is not None:
-                p.grad.detach_()
+                if p.grad.grad_fn is not None:
+                    p.grad.detach_()
+                else:
+                    p.grad.requires_grad_(False)
                 p.grad.zero_()
 
     def share_memory(self: T) -> T:
5 changes: 4 additions & 1 deletion torch/nn/parallel/distributed.py
@@ -686,7 +686,10 @@ def _sync_params(self):
                         # to zero the grads on all model replicas as well.
                         # This snippet is copied from torch.optim.Optimizer.
                         if param.grad is not None:
-                            param.grad.detach_()
+                            if param.grad.grad_fn is not None:
+                                param.grad.detach_()
+                            else:
+                                param.grad.requires_grad_(False)
                             param.grad.zero_()
 
             # module buffer sync
5 changes: 4 additions & 1 deletion torch/optim/optimizer.py
@@ -169,7 +169,10 @@ def zero_grad(self):
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is not None:
-                    p.grad.detach_()
+                    if p.grad.grad_fn is not None:
+                        p.grad.detach_()
+                    else:
+                        p.grad.requires_grad_(False)
                     p.grad.zero_()
 
     def step(self, closure):
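
As for why the else branch calls requires_grad_(False) rather than doing nothing, here is a hedged sketch (hand-assigning .grad like this is only for illustration and assumes the running PyTorch version accepts the assignment). A grad that is a leaf with requires_grad=True has no grad_fn to detach, but zero_() on it would raise a leaf/in-place RuntimeError unless the flag is cleared first.

import torch

p = torch.nn.Parameter(torch.ones(3))
# A hand-assigned grad: a leaf tensor with requires_grad=True but no grad_fn.
p.grad = torch.ones(3, requires_grad=True)

# Calling p.grad.zero_() directly here would raise:
#   RuntimeError: a leaf Variable that requires grad is being used in an
#   in-place operation.
if p.grad.grad_fn is not None:
    p.grad.detach_()
else:
    p.grad.requires_grad_(False)  # clear the flag so in-place ops are allowed
p.grad.zero_()

The old unconditional detach_ also covered this case, since detach_ clears requires_grad as a side effect; the else branch preserves that behavior while skipping the unnecessary detach.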
