
Commit 3a30279

fix for routed cross attention
lucidrains committed Apr 28, 2023
1 parent adecc5a commit 3a30279
Showing 2 changed files with 4 additions and 3 deletions.
colt5_attention/transformer_block.py (5 changes: 3 additions & 2 deletions)
```diff
@@ -898,11 +898,12 @@ def forward(
         if should_route_kv:
             normalized_scores_kv, indices_kv = self.kv_router(context, num_tokens = num_tokens_kv, mask = context_mask)
 
-            routed_tokens_kv = x[batch_range, indices_kv]
+            kv_batch_range = create_batch_range(x, right_pad_dims = indices_kv.ndim - 1)
+            routed_tokens_kv = x[kv_batch_range, indices_kv]
 
             routed_tokens_kv_mask = None
             if exists(context_mask):
-                routed_tokens_kv_mask = context_mask[batch_range, indices_kv]
+                routed_tokens_kv_mask = context_mask[kv_batch_range, indices_kv]
 
         # do the heavier branch with only routed tokens
```
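Why the fix matters, in a short sketch: the cross-attention path previously reused a `batch_range` built elsewhere in `forward`, whose trailing singleton dimensions did not necessarily match the rank of `indices_kv`, so the advanced indexing could raise a shape error or broadcast to an unintended shape. Building a dedicated `kv_batch_range` with `right_pad_dims = indices_kv.ndim - 1` makes the batch indices broadcast cleanly against the routed key/value indices. The helper below is an assumed, minimal stand-in for the repo's `create_batch_range` (only its call signature is taken from the diff), and the toy shapes are hypothetical.

```python
import torch

def create_batch_range(t, right_pad_dims = 1):
    # assumed stand-in for the repo's helper: arange over the batch dimension,
    # reshaped to (batch, 1, ..., 1) so it broadcasts against an index tensor
    # that has `right_pad_dims` additional trailing dimensions
    batch = t.shape[0]
    return torch.arange(batch, device = t.device).reshape(batch, *((1,) * right_pad_dims))

# toy example: batch of 2, source sequence of 8, model dim 4, 3 routed kv tokens per batch
x = torch.randn(2, 8, 4)
indices_kv = torch.randint(0, 8, (2, 3))   # (batch, num routed kv tokens)

kv_batch_range = create_batch_range(x, right_pad_dims = indices_kv.ndim - 1)   # shape (2, 1)
routed_tokens_kv = x[kv_batch_range, indices_kv]                               # shape (2, 3, 4)
```

With the (2, 1) batch range broadcasting against the (2, 3) indices, each batch element gathers only its own routed tokens; a range padded for a different index rank would either fail to broadcast or silently produce a tensor of the wrong shape.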
setup.py (2 changes: 1 addition & 1 deletion)
```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'CoLT5-attention',
   packages = find_packages(),
-  version = '0.3.2',
+  version = '0.3.3',
   license='MIT',
   description = 'Conditionally Routed Attention',
   long_description_content_type = 'text/markdown',
```
