From 3a3027961ee381c1b4c876a40151ad5b57e1a1d8 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Fri, 28 Apr 2023 16:29:17 -0700
Subject: [PATCH] fix for routed cross attention

---
 colt5_attention/transformer_block.py | 5 +++--
 setup.py                             | 2 +-
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/colt5_attention/transformer_block.py b/colt5_attention/transformer_block.py
index 54a43f3..fdcd727 100644
--- a/colt5_attention/transformer_block.py
+++ b/colt5_attention/transformer_block.py
@@ -898,11 +898,12 @@ def forward(
         if should_route_kv:
             normalized_scores_kv, indices_kv = self.kv_router(context, num_tokens = num_tokens_kv, mask = context_mask)

-            routed_tokens_kv = x[batch_range, indices_kv]
+            kv_batch_range = create_batch_range(x, right_pad_dims = indices_kv.ndim - 1)
+            routed_tokens_kv = x[kv_batch_range, indices_kv]

             routed_tokens_kv_mask = None
             if exists(context_mask):
-                routed_tokens_kv_mask = context_mask[batch_range, indices_kv]
+                routed_tokens_kv_mask = context_mask[kv_batch_range, indices_kv]

         # do the heavier branch with only routed tokens

diff --git a/setup.py b/setup.py
index 71a711b..070d77f 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'CoLT5-attention',
   packages = find_packages(),
-  version = '0.3.2',
+  version = '0.3.3',
   license='MIT',
   description = 'Conditionally Routed Attention',
   long_description_content_type = 'text/markdown',
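
The transformer_block.py hunk stops reusing a batch_range built elsewhere in forward and instead builds a kv-specific one from the rank of indices_kv, so the advanced indexing into x and context_mask broadcasts correctly against the routed kv indices. Below is a minimal sketch of that gather pattern, assuming create_batch_range returns a batch arange reshaped with trailing singleton dims (its body is not shown in the patch); the tensor shapes are hypothetical and for illustration only.

    import torch

    def create_batch_range(t, right_pad_dims = 1):
        # assumed behavior: arange over the batch dimension, reshaped with
        # `right_pad_dims` trailing singleton dims so it broadcasts against
        # an index tensor of matching rank
        batch = t.shape[0]
        return torch.arange(batch, device = t.device).reshape(batch, *((1,) * right_pad_dims))

    # hypothetical shapes: batch of 2, 16 context tokens, dim 4, 5 routed kv tokens
    x          = torch.randn(2, 16, 4)
    indices_kv = torch.randint(0, 16, (2, 5))

    # build the batch index from indices_kv's own rank, as the patch does
    kv_batch_range   = create_batch_range(x, right_pad_dims = indices_kv.ndim - 1)  # shape (2, 1)
    routed_tokens_kv = x[kv_batch_range, indices_kv]                                 # shape (2, 5, 4)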