diff --git a/colt5_attention/transformer_block.py b/colt5_attention/transformer_block.py
index be601f9..ae078e3 100644
--- a/colt5_attention/transformer_block.py
+++ b/colt5_attention/transformer_block.py
@@ -1,3 +1,4 @@
+import math
 from functools import partial
 
 import torch
@@ -24,6 +25,17 @@ def pack_one(t, pattern):
 def unpack_one(t, ps, pattern):
     return unpack(t, ps, pattern)[0]
 
+def pad_to_multiple(tensor, multiple, dim=-1, value=0):
+    seq_len = tensor.shape[dim]
+    m = seq_len / multiple
+    if m.is_integer():
+        return tensor, seq_len
+
+    remainder = math.ceil(m) * multiple - seq_len
+    pad_offset = (0,) * (-1 - dim) * 2
+    padded_tensor = F.pad(tensor, (*pad_offset, 0, remainder), value = value)
+    return padded_tensor, seq_len
+
 # tensor helpers
 
 def create_batch_range(t):
@@ -108,7 +120,7 @@ def forward(
         if context.ndim == 3:
             context = rearrange(context, 'b n d -> b 1 n d')
 
-        if exists(normalized_scores_kv):
+        if exists(normalized_scores_kv) and isinstance(normalized_scores_kv, torch.Tensor):
             if normalized_scores_kv.ndim == 2:
                 normalized_scores_kv = rearrange(normalized_scores_kv, 'b n -> b 1 n')
 
@@ -121,7 +133,7 @@ def forward(
         q = self.to_q(x)
         q = rearrange(q, 'b n (h d) -> b h n d', h = h)
 
-        if exists(normalized_scores_q):
+        if exists(normalized_scores_q) and isinstance(normalized_scores_q, torch.Tensor):
             q = q * rearrange(normalized_scores_q, 'b n -> b 1 n 1')
 
         # handle key / values, with the routing dimension, dividing the number of heads in between the routes
@@ -642,6 +654,7 @@ def __init__(
         light_dim_head = 64,
         light_heads = 8,
         light_window_size = 128,        # each token would see ~ 64 tokens either way to left or right
+        heavy_window_size = None,
         heavy_dim_head = 64,
         heavy_heads = 8,
         router_straight_through = True, # would make sure all normalized scores are 1., still differentiable
@@ -662,7 +675,7 @@ def __init__(
 
         self.multiply_queries_by_score = multiply_queries_by_score
 
-        self.light_window_size = light_window_size
+        self.heavy_window_size = default(heavy_window_size, light_window_size)
 
         self.light_attn = LocalMHA(
             dim = dim,
@@ -706,27 +719,63 @@ def forward(
         num_heavy_tokens_q = default(num_heavy_tokens_q, self.num_heavy_tokens_q)
         num_heavy_tokens_kv = default(num_heavy_tokens_kv, self.num_heavy_tokens_kv)
 
-        batch_range = create_batch_range(x)
-
         # light local attention sees all tokens in a limited context
 
         light_out = self.light_attn(x)
 
-        # route tokens appropriately for heavy branch
+        # pad sequence to multiple of the heavy window size
+        # routing will take place within each heavy window block size
 
-        normalized_scores_q, indices_q = self.q_router(x, num_tokens = num_heavy_tokens_q, mask = mask)
-        normalized_scores_kv, indices_kv = self.kv_router(x, num_tokens = num_heavy_tokens_kv, mask = mask)
+        window_size = self.heavy_window_size
 
-        # select the tokens to be routed to full attention
+        x, seq_len = pad_to_multiple(x, window_size, dim = -2)
 
-        routed_tokens_q = x[batch_range, indices_q]
-        routed_tokens_kv = x[batch_range, indices_kv]
+        padded_seq_len = x.shape[-2]
 
-        # calculate key padding mask
+        # construct mask, and make sure not to attend to padding
 
-        routed_tokens_kv_mask = None
-        if exists(mask):
-            routed_tokens_kv_mask = mask[batch_range, indices_kv]
+        q_mask = torch.ones((batch, seq_len), dtype = torch.bool, device = device)
+        q_mask = F.pad(q_mask, (0, padded_seq_len - seq_len), value = False)
+
+        # block the sequence and mask into windows for the queries
+
+        q = rearrange(x, 'b (n w) d -> (b n) w d', w = window_size)
+        q_mask = rearrange(q_mask, 'b (n w) -> (b n) w', w = window_size)
+
+        # each block of queries attend to sequences that are causally masked out appropriately
+
+        windows = padded_seq_len // window_size
+
+        kv = repeat(x, 'b n d -> (b m) n d', m = windows)
+
+        kv_mask = torch.ones((windows, windows), dtype = torch.bool, device = device).tril(-1)
+        kv_mask = repeat(kv_mask, 'm n -> (b m) (n w)', b = batch, w = window_size)
+
+        # route tokens appropriately for heavy branch, if need be
+
+        should_route_q = q.shape[-2] > num_heavy_tokens_q
+        should_route_kv = kv.shape[-2] > num_heavy_tokens_kv
+
+        if should_route_q:
+            normalized_scores_q, indices_q = self.q_router(q, num_tokens = num_heavy_tokens_q, mask = q_mask)
+
+            q_batch_range = create_batch_range(q)
+            routed_tokens_q = q[q_batch_range, indices_q]
+        else:
+            normalized_scores_q = 1.
+            routed_tokens_q = q
+
+        if should_route_kv:
+            normalized_scores_kv, indices_kv = self.kv_router(kv, num_tokens = num_heavy_tokens_kv, mask = kv_mask)
+
+            kv_batch_range = create_batch_range(kv)
+
+            routed_tokens_kv = kv[kv_batch_range, indices_kv]
+            routed_tokens_kv_mask = kv_mask[kv_batch_range, indices_kv]
+        else:
+            normalized_scores_kv = 1.
+            routed_tokens_kv = kv
+            routed_tokens_kv_mask = kv_mask
 
         # do the heavier branch with only routed tokens
 
@@ -738,12 +787,20 @@ def forward(
             normalized_scores_q = normalized_scores_q if self.multiply_queries_by_score else None
         )
 
-        routed_tokens_out = routed_tokens_out * rearrange(normalized_scores_q, '... -> ... 1')
+        if should_route_q:
+            routed_tokens_out = routed_tokens_out * rearrange(normalized_scores_q, '... -> ... 1')
 
-        # scatter back the output of the heavy branch
+            # scatter back the output of the heavy branch
 
-        heavy_out = torch.zeros_like(x)
-        heavy_out = self.q_router.route_back(heavy_out, routed_tokens_out, indices_q)
+            heavy_out = torch.zeros_like(q)
+            heavy_out = self.q_router.route_back(heavy_out, routed_tokens_out, indices_q)
+        else:
+            heavy_out = routed_tokens_out
+
+        # un-window and slice out original sequence
+
+        heavy_out = rearrange(heavy_out, '(b n) w d -> b (n w) d', b = batch)
+        heavy_out = heavy_out[:, :seq_len]
 
         # sum light and heavy branches
 
diff --git a/setup.py b/setup.py
index 903f08d..3bed132 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'CoLT5-attention',
   packages = find_packages(),
-  version = '0.2.5',
+  version = '0.3.1',
   license='MIT',
   description = 'Conditionally Routed Attention',
   long_description_content_type = 'text/markdown',
@@ -16,7 +16,7 @@
     'dynamic routing'
   ],
   install_requires=[
-    'einops>=0.6.0',
+    'einops>=0.6.1',
     'local-attention>=1.8.5',
     'torch>=1.10'
   ],
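
Note (appended for illustration, not part of the patch): a minimal sketch of how the new padding and windowing behave, reusing the pad_to_multiple helper introduced above. The batch size, sequence length, and window size below are illustrative only, and the snippet stops short of the routers themselves.

import math
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

# helper copied from the patch above
def pad_to_multiple(tensor, multiple, dim = -1, value = 0):
    seq_len = tensor.shape[dim]
    m = seq_len / multiple
    if m.is_integer():
        return tensor, seq_len

    remainder = math.ceil(m) * multiple - seq_len
    pad_offset = (0,) * (-1 - dim) * 2
    padded_tensor = F.pad(tensor, (*pad_offset, 0, remainder), value = value)
    return padded_tensor, seq_len

batch, seq_len, dim, window_size = 2, 10, 8, 4           # illustrative sizes

x = torch.randn(batch, seq_len, dim)
x, orig_len = pad_to_multiple(x, window_size, dim = -2)  # pads the sequence 10 -> 12
padded_seq_len = x.shape[-2]                             # 12
windows = padded_seq_len // window_size                  # 3

# queries are blocked into windows; keys / values are the full padded sequence, repeated per window
q = rearrange(x, 'b (n w) d -> (b n) w d', w = window_size)  # (batch * windows, window_size, dim)
kv = repeat(x, 'b n d -> (b m) n d', m = windows)            # (batch * windows, padded_seq_len, dim)

# tril(-1) lets each query window attend only to strictly earlier windows
kv_mask = torch.ones((windows, windows), dtype = torch.bool).tril(-1)
kv_mask = repeat(kv_mask, 'm n -> (b m) (n w)', b = batch, w = window_size)  # (batch * windows, padded_seq_len)

The q_router / kv_router in the forward above then select the heavy tokens within these blocked views; the first window's heavy mask is empty under tril(-1), with nearby context presumably covered by the light local attention branch.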