Hi Christopher,
I hope you're doing well. I'm really glad that the Zennit community is growing, congratulations!
With a growing community, more nn.Modules want to be explained, which is why I'm writing this issue.
A student in our department is trying to explain a LinearAttention module (the implementation is below for reference).
It contains a series of torch.einsum and torch.transpose operations, and it uses the rearrange function of the einops library, which provides an alternative syntax for basic torch operations such as transpose, reshape, etc.
I think Zennit should be able to analyze a series of reshaping and transposing operations, but I am not completely sure.
I'd be glad if you could give your opinion on analyzing such a linear attention module. If you don't know, that's also no problem (: Then it's the beginning of a new research topic.
(The softmax function is also a problem, but maybe Arras et al. have a solution to this which the student could implement...)
Best,
Reduan
import torch
from torch import nn
from einops import rearrange


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, 1),
            nn.GroupNorm(1, dim),
        )

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)
        q = q * self.scale

        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)
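For reference, a quick shape check (assuming the imports above; dim=64 is an arbitrary example value):

attn = LinearAttention(dim=64)
x = torch.randn(1, 64, 16, 16)
out = attn(x)
assert out.shape == x.shape  # the module preserves the input shape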
Thank you for the issue!
You can have a look at this work, where they introduce LRP for Transformers (i.e. also attention heads).
I have talked to @tschnake before about bringing transformers to Zennit, which is still as WIP as it gets.
About the implementation details:
The rearrange operation is just a re-indexing, so the correct relevance propagation for it is simply the gradient, which means it is already supported by Zennit.
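For illustration, a quick check in plain PyTorch (no Zennit rule involved) that propagating relevance through rearrange via the gradient only re-indexes it and conserves its sum:

import torch
from einops import rearrange

x = torch.randn(2, 8, 4, 4, requires_grad=True)        # b (h c) x y
y = rearrange(x, "b (h c) x y -> b h c (x y)", h=4)    # pure re-indexing

relevance_out = torch.randn_like(y)                    # pretend upstream relevance
relevance_in, = torch.autograd.grad(y, x, grad_outputs=relevance_out)

# the gradient of a re-indexing is the inverse re-indexing, so relevance is conserved
assert torch.allclose(relevance_in.sum(), relevance_out.sum())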
The einsum is a linear operation, so it can be handled like a linear layer in LRP.
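As a rough sketch of what that could look like (plain autograd, not an existing Zennit rule; the function name and argument roles are made up for illustration), an epsilon-style LRP step for one of the einsums could be written as:

import torch

def lrp_epsilon_einsum(equation, a, b, relevance_out, eps=1e-6):
    """Distribute relevance_out to `b` for z = einsum(equation, a, b), treating `a` as weights."""
    b = b.detach().requires_grad_(True)
    z = torch.einsum(equation, a.detach(), b)
    # stabilize the denominator away from zero, following the sign of z
    stabilizer = eps * torch.where(z < 0, -torch.ones_like(z), torch.ones_like(z))
    s = relevance_out / (z + stabilizer)
    c, = torch.autograd.grad(z, b, grad_outputs=s)
    return b * c                                       # relevance attributed to b

# example with the shapes used in LinearAttention (b=1, heads=4, dim_head=32, n=16 positions)
k = torch.randn(1, 4, 32, 16)                          # b h d n
v = torch.randn(1, 4, 32, 16)                          # b h e n
context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
rel_v = lrp_epsilon_einsum("b h d n, b h e n -> b h d e", k, v, torch.ones_like(context))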
The softmax is a little tricky. In the work above they handle this by viewing the gating terms as constants.
In code, we may get away with requiring the use of torch.nn.Softmax and implementing a Constant rule, which sets the gradient to zero, although I need to think a little more about whether this would work as intended.
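To illustrate the idea (just a sketch, not an existing Zennit rule): a custom autograd function whose forward is the usual softmax but whose backward returns zero, so no relevance flows through the gating term:

import torch

class ConstantSoftmax(torch.autograd.Function):
    """Softmax whose output is treated as a constant in the backward pass."""

    @staticmethod
    def forward(ctx, x, dim):
        return torch.softmax(x, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        # zero gradient for x, None for the non-tensor `dim` argument
        return torch.zeros_like(grad_output), None

# hypothetical usage inside the attention forward:
# q = ConstantSoftmax.apply(q, -2)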
Otherwise, we could also implement a canonizer (or a meta-rule) for the most popular library implementing attention layers.