Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

6057guided backprop silu #7070

Draft
wants to merge 9 commits into
base: dev
Choose a base branch
from
31 changes: 31 additions & 0 deletions monai/visualize/gradient_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,37 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return out


class _AutoGradSiLU(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
posmask = (x > 0).type_as(x)
mask = posmask + ((x <= 0).type_as(x) * torch.sigmoid((x <= 0).type_as(x)*x))
output = torch.mul(x, mask)
ctx.save_for_backward(x, output)
return output

@staticmethod
def backward(ctx, grad_output):
x, _ = ctx.saved_tensors
pos_mask_1 = (x > 0).type_as(grad_output)
mask_1 = pos_mask_1 + ((x <= 0).type_as(grad_output) * torch.sigmoid((x <= 0).type_as(grad_output)*x))
pos_mask_2 = (grad_output > 0).type_as(grad_output)
mask_2 = pos_mask_2 + ((grad_output <= 0).type_as(grad_output) * torch.sigmoid((grad_output <= 0).type_as(grad_output)*grad_output))
y = torch.mul(grad_output, mask_1)
grad_input = torch.mul(y, mask_2)
return grad_input


class _GradSiLU(torch.nn.Module):
"""
A customized SiLU with the backward pass imputed for guided backpropagation (https://arxiv.org/abs/1412.6806).
"""

def forward(self, x: torch.Tensor) -> torch.Tensor:
out: torch.Tensor = _AutoGradSiLU.apply(x)
return out


class VanillaGrad:
"""
Given an input image ``x``, calling this class will perform the forward pass, then set to zero
Expand Down