improvise an autoregressive version of the routing attention, by doing local window causal attention on the light branch, and then routing tokens farther into the past than the local window for each block for the heavy branch
lucidrains committed Apr 28, 2023
1 parent 654da7c commit e74a985
Showing 2 changed files with 72 additions and 18 deletions.
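The idea, roughly: the light branch already covers local context with windowed causal attention, so the heavy branch can restrict its routing to tokens strictly earlier than each block's window. Below is a minimal, self-contained toy sketch of that data flow, not the repo's implementation: it uses a single head, non-overlapping light windows (the real light branch uses LocalMHA), and a token-norm stand-in for the learned router; all function names here are illustrative.

```python
import torch

def naive_attention(q, k, v, mask = None):
    # plain single-head scaled dot-product attention, for illustration only
    sim = torch.einsum('b i d, b j d -> b i j', q, k) * (q.shape[-1] ** -0.5)
    if mask is not None:
        sim = sim.masked_fill(~mask, float('-inf'))
    return torch.einsum('b i j, b j d -> b i d', sim.softmax(dim = -1), v)

def toy_autoregressive_routed_attention(x, window_size = 4, num_routed = 2):
    b, n, d = x.shape
    assert n % window_size == 0, 'toy assumes the length divides evenly'

    # light branch: causal attention within each window (simplified stand-in for LocalMHA)
    w = x.reshape(b * (n // window_size), window_size, d)
    causal = torch.ones(window_size, window_size, dtype = torch.bool).tril()
    light = naive_attention(w, w, w, causal).reshape(b, n, d)

    # heavy branch: each window routes a few tokens from strictly earlier windows
    heavy = torch.zeros_like(x)
    for i in range(1, n // window_size):
        past = x[:, :i * window_size]                   # everything before this window
        scores = past.norm(dim = -1)                    # stand-in for the learned router score
        idx = scores.topk(min(num_routed, past.shape[1]), dim = -1).indices
        routed = past.gather(1, idx.unsqueeze(-1).expand(-1, -1, d))
        q = x[:, i * window_size:(i + 1) * window_size]
        heavy[:, i * window_size:(i + 1) * window_size] = naive_attention(q, routed, routed)

    # the two branch outputs are summed, as in the module's forward pass below
    return light + heavy

out = toy_autoregressive_routed_attention(torch.randn(2, 16, 8))
```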
86 changes: 70 additions & 16 deletions colt5_attention/transformer_block.py
@@ -1,3 +1,4 @@
import math
from functools import partial

import torch
@@ -24,6 +25,17 @@ def pack_one(t, pattern):
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]

def pad_to_multiple(tensor, multiple, dim=-1, value=0):
    seq_len = tensor.shape[dim]
    m = seq_len / multiple
    if m.is_integer():
        return tensor, seq_len

    remainder = math.ceil(m) * multiple - seq_len
    pad_offset = (0,) * (-1 - dim) * 2
    padded_tensor = F.pad(tensor, (*pad_offset, 0, remainder), value = value)
    return padded_tensor, seq_len
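The helper above, used by the new autoregressive forward pass, pads on the right of the chosen dimension so the length becomes a multiple of the heavy window size, and returns the original length so the padding can be sliced off later. A quick sketch of its behavior, assuming the usual `(batch, seq, dim)` layout:

```python
import torch

# uses pad_to_multiple as defined above
x = torch.randn(2, 1000, 512)
padded, orig_len = pad_to_multiple(x, 128, dim = -2)   # pad the sequence dimension
print(padded.shape, orig_len)                          # torch.Size([2, 1024, 512]) 1000
```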

# tensor helpers

def create_batch_range(t):
@@ -108,7 +120,7 @@ def forward(
if context.ndim == 3:
    context = rearrange(context, 'b n d -> b 1 n d')

if exists(normalized_scores_kv):
if exists(normalized_scores_kv) and isinstance(normalized_scores_kv, torch.Tensor):
    if normalized_scores_kv.ndim == 2:
        normalized_scores_kv = rearrange(normalized_scores_kv, 'b n -> b 1 n')

@@ -121,7 +133,7 @@ def forward(
q = self.to_q(x)
q = rearrange(q, 'b n (h d) -> b h n d', h = h)

if exists(normalized_scores_q):
if exists(normalized_scores_q) and isinstance(normalized_scores_q, torch.Tensor):
    q = q * rearrange(normalized_scores_q, 'b n -> b 1 n 1')

# handle key / values, with the routing dimension, dividing the number of heads in between the routes
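The `isinstance` guards added in the two hunks above matter because, in the new autoregressive path, routing can be skipped for windows that are already small enough, in which case the normalized scores arrive as the plain float `1.` rather than a per-token tensor. A small self-contained sketch of the intended behavior (the helper name here is illustrative, not the repo's API):

```python
import torch
from einops import rearrange

def maybe_scale_queries(q, normalized_scores_q):
    # q: (batch, heads, seq, dim_head); scores: a (batch, seq) tensor, or the float 1.
    # when routing was skipped, in which case scaling should be a no-op
    if isinstance(normalized_scores_q, torch.Tensor):
        q = q * rearrange(normalized_scores_q, 'b n -> b 1 n 1')
    return q

q = torch.randn(2, 8, 16, 64)
assert torch.equal(maybe_scale_queries(q, 1.), q)      # routing skipped -> unchanged
scaled = maybe_scale_queries(q, torch.rand(2, 16))     # routed case -> per-token scaling
```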
@@ -642,6 +654,7 @@ def __init__(
light_dim_head = 64,
light_heads = 8,
light_window_size = 128, # each token would see ~64 tokens to either side (left or right)
heavy_window_size = None,
heavy_dim_head = 64,
heavy_heads = 8,
router_straight_through = True, # would make sure all normalized scores are 1., still differentiable
@@ -662,7 +675,7 @@ def __init__(

self.multiply_queries_by_score = multiply_queries_by_score

self.light_window_size = light_window_size
self.heavy_window_size = default(heavy_window_size, light_window_size)

self.light_attn = LocalMHA(
dim = dim,
@@ -706,27 +719,63 @@ def forward(
num_heavy_tokens_q = default(num_heavy_tokens_q, self.num_heavy_tokens_q)
num_heavy_tokens_kv = default(num_heavy_tokens_kv, self.num_heavy_tokens_kv)

batch_range = create_batch_range(x)

# light local attention sees all tokens in a limited context

light_out = self.light_attn(x)

# route tokens appropriately for heavy branch
# pad sequence to multiple of the heavy window size
# routing will take place within each heavy window block size

normalized_scores_q, indices_q = self.q_router(x, num_tokens = num_heavy_tokens_q, mask = mask)
normalized_scores_kv, indices_kv = self.kv_router(x, num_tokens = num_heavy_tokens_kv, mask = mask)
window_size = self.heavy_window_size

# select the tokens to be routed to full attention
x, seq_len = pad_to_multiple(x, window_size, dim = -2)

routed_tokens_q = x[batch_range, indices_q]
routed_tokens_kv = x[batch_range, indices_kv]
padded_seq_len = x.shape[-2]

# calculate key padding mask
# construct mask, and make sure not to attend to padding

routed_tokens_kv_mask = None
if exists(mask):
routed_tokens_kv_mask = mask[batch_range, indices_kv]
q_mask = torch.ones((batch, seq_len), dtype = torch.bool, device = device)
q_mask = F.pad(q_mask, (0, padded_seq_len - seq_len), value = False)

# block the sequence and mask into windows for the queries

q = rearrange(x, 'b (n w) d -> (b n) w d', w = window_size)
q_mask = rearrange(q_mask, 'b (n w) -> (b n) w', w = window_size)

# each block of queries attends to a sequence that is causally masked out appropriately

windows = padded_seq_len // window_size

kv = repeat(x, 'b n d -> (b m) n d', m = windows)

kv_mask = torch.ones((windows, windows), dtype = torch.bool, device = device).tril(-1)
kv_mask = repeat(kv_mask, 'm n -> (b m) (n w)', b = batch, w = window_size)
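The `tril(-1)` here is the causal constraint: a window's queries may only attend to windows strictly before it, since the current window's local context is already handled by the light branch (the first window's row is all zeros, so it receives no heavy-branch context). A small sketch of the mask being built, with assumed toy sizes:

```python
import torch
from einops import repeat

batch, windows, window_size = 1, 4, 3

kv_mask = torch.ones((windows, windows), dtype = torch.bool).tril(-1)
print(kv_mask.int())
# tensor([[0, 0, 0, 0],
#         [1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0]])

# expand window-level entries to per-token key positions, one mask row per query window
expanded = repeat(kv_mask, 'm n -> (b m) (n w)', b = batch, w = window_size)
print(expanded.shape)   # torch.Size([4, 12])
```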

# route tokens appropriately for heavy branch, if need be

should_route_q = q.shape[-2] > num_heavy_tokens_q
should_route_kv = kv.shape[-2] > num_heavy_tokens_kv

if should_route_q:
    normalized_scores_q, indices_q = self.q_router(q, num_tokens = num_heavy_tokens_q, mask = q_mask)

    q_batch_range = create_batch_range(q)
    routed_tokens_q = q[q_batch_range, indices_q]
else:
    normalized_scores_q = 1.
    routed_tokens_q = q

if should_route_kv:
    normalized_scores_kv, indices_kv = self.kv_router(kv, num_tokens = num_heavy_tokens_kv, mask = kv_mask)

    kv_batch_range = create_batch_range(kv)

    routed_tokens_kv = kv[kv_batch_range, indices_kv]
    routed_tokens_kv_mask = kv_mask[kv_batch_range, indices_kv]
else:
    normalized_scores_kv = 1.
    routed_tokens_kv = kv
    routed_tokens_kv_mask = kv_mask
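The routers themselves (`q_router` / `kv_router`) are the repo's learned routers; what this commit relies on is the gather-then-scatter round trip: select the top-scoring tokens of each window (or of the past, for keys / values), run heavy attention on only those, and scatter the results back to their original positions, leaving everything else zero so the later sum with the light branch is unaffected. When a block is already within the token budget, routing is skipped and the scores fall back to `1.`, which is what the isinstance guards earlier accommodate. A hedged sketch of just the indexing round trip (selection by token norm stands in for the learned scores, and `route_back_sketch` is not the repo's `route_back`):

```python
import torch

def gather_routed(t, indices):
    # t: (b, n, d), indices: (b, k) -> gathered tokens (b, k, d)
    return t.gather(1, indices.unsqueeze(-1).expand(-1, -1, t.shape[-1]))

def route_back_sketch(dest, routed_out, indices):
    # scatter processed routed tokens back to their original positions in dest
    return dest.scatter(1, indices.unsqueeze(-1).expand(-1, -1, dest.shape[-1]), routed_out)

b, n, d, k = 2, 8, 4, 3
x = torch.randn(b, n, d)

indices = x.norm(dim = -1).topk(k, dim = -1).indices   # stand-in for learned routing scores
routed = gather_routed(x, indices)                     # (b, k, d), fed to the heavy attention
processed = routed * 2.                                # placeholder for the heavy attention output

out = route_back_sketch(torch.zeros_like(x), processed, indices)
# positions that were not routed stay zero, so light_out + heavy_out leaves them untouched
```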

# do the heavier branch with only routed tokens

@@ -742,9 +791,14 @@

# scatter back the output of the heavy branch

heavy_out = torch.zeros_like(x)
heavy_out = torch.zeros_like(q)
heavy_out = self.q_router.route_back(heavy_out, routed_tokens_out, indices_q)

# un-window and slice out original sequence

heavy_out = rearrange(heavy_out, '(b n) w d -> b (n w) d', b = batch)
heavy_out = heavy_out[:, :seq_len]

# sum light and heavy branches

return light_out + heavy_out
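Putting it together, usage would look roughly like the sketch below. The class name, import path, and argument defaults are assumptions based on the repo's naming for the non-autoregressive variant (the diff only shows the module's internals), so treat this as illustrative rather than documented API:

```python
import torch
# assumed import path and class name
from colt5_attention import ConditionalRoutedAutoregressiveAttention

attn = ConditionalRoutedAutoregressiveAttention(
    dim = 512,
    light_window_size = 128,     # light branch: windowed causal attention
    heavy_window_size = 128,     # heavy branch: routing block size (defaults to light_window_size)
    num_heavy_tokens_q = 32,     # queries routed per block
    num_heavy_tokens_kv = 1024   # past keys / values routed per block
)

x = torch.randn(1, 8192, 512)    # length need not be a multiple of the window size; pad_to_multiple handles it
out = attn(x)                    # (1, 8192, 512) -- light and heavy branch outputs summed
```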
4 changes: 2 additions & 2 deletions setup.py
@@ -3,7 +3,7 @@
setup(
name = 'CoLT5-attention',
packages = find_packages(),
version = '0.2.5',
version = '0.3.0',
license='MIT',
description = 'Conditionally Routed Attention',
long_description_content_type = 'text/markdown',
@@ -16,7 +16,7 @@
'dynamic routing'
],
install_requires=[
'einops>=0.6.0',
'einops>=0.6.1',
'local-attention>=1.8.5',
'torch>=1.10'
],
