add a variant of conditionally routed attention for cross attention, without a light branch

lucidrains committed Apr 15, 2023
1 parent 5cd470f commit 19317f6
Showing 4 changed files with 138 additions and 6 deletions.
34 changes: 34 additions & 0 deletions README.md
@@ -82,6 +82,40 @@ block = ConditionalRoutedTransformerBlock(
block_out = block(tokens, mask = mask) # (2, 32768, 512)
```

Also included is a variation of the conditionally routed attention for cross attention, to be tried with long context memories in a transformer-xl.

```python
import torch
from colt5_attention import ConditionalRoutedCrossAttention

# mock input; say a transformer with sequence length 1024 attending to 1 million past context memories

tokens = torch.randn(2, 1024, 512).cuda()
tokens_mask = torch.ones(2, 1024).bool().cuda()

memories = torch.randn(2, int(1e6), 512).cuda()
memories_mask = torch.ones(2, int(1e6)).bool().cuda()

# conditionally routed cross attention

cross_attn = ConditionalRoutedCrossAttention(
dim = 512,
dim_head = 64,
heads = 8,
num_tokens_q = 512, # only 512 routed from 1024
num_tokens_kv = 1024, # only 1024 routed from 1 million
).cuda()

cross_attn_out = cross_attn(
tokens,
context = memories,
mask = tokens_mask,
context_mask = memories_mask
)

cross_attn_out.shape # (2, 1024, 512) - same as tokens
```
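
To make the routing concrete, here is a minimal, self-contained sketch of the gather / scatter-by-indices pattern the routed cross attention relies on. It is not the library's API: the random `scores` stand in for a learned router, and the doubling step stands in for the heavy attention branch.

```python
import torch

batch, seq_len, dim, k = 2, 8, 4, 3

x = torch.randn(batch, seq_len, dim)

# stand-in for a learned router: one relevance score per token
scores = torch.randn(batch, seq_len)

# keep only the k highest scoring tokens per sequence
topk_scores, indices = scores.topk(k, dim = -1)   # both (batch, k)

batch_range = torch.arange(batch).unsqueeze(-1)   # (batch, 1), broadcasts against indices
routed = x[batch_range, indices]                  # (batch, k, dim)

# stand-in for the heavy attention branch, applied only to the routed tokens
routed_out = routed * 2

# weight by normalized routing scores so gradients flow back to the router
routed_out = routed_out * topk_scores.softmax(dim = -1).unsqueeze(-1)

# scatter back to the original positions; unrouted positions stay zero
out = torch.zeros_like(x)
out[batch_range, indices] = routed_out

print(out.shape)  # torch.Size([2, 8, 4])
```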

## Todo

- [x] add the coordinate descent method as another router
1 change: 1 addition & 0 deletions colt5_attention/__init__.py
@@ -1,6 +1,7 @@
from colt5_attention.transformer_block import (
ConditionalRoutedFeedForward,
ConditionalRoutedAttention,
ConditionalRoutedCrossAttention,
ConditionalRoutedTransformerBlock,
DifferentiableTopKRouter,
SinkhornRouter,
107 changes: 102 additions & 5 deletions colt5_attention/transformer_block.py
@@ -395,8 +395,6 @@ def __init__(
look_forward = 1
)

# for now, just do qkv for starters, need to separate to q and kv

self.q_router = router_klass(
dim = dim,
straight_through = router_straight_through,
@@ -431,7 +429,7 @@ def forward(
batch_range = torch.arange(batch, device = device)
batch_range = rearrange(batch_range, 'b -> b 1')

# light local attention sees all tokens
# light local attention sees all tokens in a limited context

light_out = self.light_attn(x, mask = mask)

@@ -440,7 +438,7 @@
normalized_scores_q, indices_q = self.q_router(x, num_tokens = num_heavy_tokens_q, mask = mask)
normalized_scores_kv, indices_kv = self.kv_router(x, num_tokens = num_heavy_tokens_kv, mask = mask)

# select the tokens to be routed to heavier feedforward (hidden dimension is 4 times model dimensions)
# select the tokens to be routed to full attention

routed_tokens_q = x[batch_range, indices_q]
routed_tokens_kv = x[batch_range, indices_kv]
@@ -462,7 +460,7 @@

routed_tokens_out = routed_tokens_out * rearrange(normalized_scores_q, '... -> ... 1')

# scatter back the output of the heavy feedforward branch
# scatter back the output of the heavy branch

heavy_out = torch.zeros_like(x)

@@ -475,6 +473,105 @@

return light_out + heavy_out

# adapting the conditional routed self attention to cross attention

class ConditionalRoutedCrossAttention(nn.Module):
def __init__(
self,
dim,
*,
num_tokens_q,
num_tokens_kv,
dim_head = 64,
heads = 8,
router_straight_through = True, # would make sure all normalized scores are 1., still differentiable
router_type = 'coor_descent',
router_kwargs: dict = {}
):
super().__init__()
assert router_type in ROUTERS.keys()

self.router_type = router_type

router_klass = ROUTERS.get(router_type)

self.num_tokens_q = num_tokens_q
self.num_tokens_kv = num_tokens_kv

self.q_router = router_klass(
dim = dim,
straight_through = router_straight_through,
**router_kwargs
)

self.kv_router = router_klass(
dim = dim,
straight_through = router_straight_through,
**router_kwargs
)

self.heavy_attn = Attention(
dim = dim,
dim_head = dim_head,
heads = heads
)

def forward(
self,
x,
context,
*,
num_tokens_q = None,
num_tokens_kv = None,
mask = None,
context_mask = None
):
batch, device = x.shape[0], x.device

num_tokens_q = default(num_tokens_q, self.num_tokens_q)
num_tokens_kv = default(num_tokens_kv, self.num_tokens_kv)

batch_range = torch.arange(batch, device = device)
batch_range = rearrange(batch_range, 'b -> b 1')

# route tokens appropriately

normalized_scores_q, indices_q = self.q_router(x, num_tokens = num_tokens_q, mask = mask)
normalized_scores_kv, indices_kv = self.kv_router(context, num_tokens = num_tokens_kv, mask = context_mask)

# select the tokens to be routed

routed_tokens_q = x[batch_range, indices_q]
routed_tokens_kv = context[batch_range, indices_kv]

# calculate key padding mask

routed_tokens_kv_mask = None
if exists(context_mask):
    routed_tokens_kv_mask = context_mask[batch_range, indices_kv]

# do the heavier branch with only routed tokens

routed_tokens_out = self.heavy_attn(
routed_tokens_q,
mask = routed_tokens_kv_mask,
context = routed_tokens_kv,
normalized_scores_kv = normalized_scores_kv
)

routed_tokens_out = routed_tokens_out * rearrange(normalized_scores_q, '... -> ... 1')

# scatter back the output

out = torch.zeros_like(x)

if self.router_type == 'sinkhorn':
    out = scatter_mean(out, routed_tokens_out, indices_q, dim = 1)
else:
out[batch_range, indices_q] = routed_tokens_out

return out

# block

class ConditionalRoutedTransformerBlock(nn.Module):
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'CoLT5-attention',
packages = find_packages(),
version = '0.1.6',
version = '0.2.0',
license='MIT',
description = 'Conditionally Routed Attention',
long_description_content_type = 'text/markdown',
