Skip to content

Commit

Permalink
The number of routed key / value sets (`num_routed_kv`) feature now also works for self-attention
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Apr 19, 2023
1 parent 83677b0 commit 47724ac
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
4 changes: 4 additions & 0 deletions colt5_attention/transformer_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,7 @@ def __init__(
*,
num_heavy_tokens_q,
num_heavy_tokens_kv,
num_routed_kv = 1,
light_dim_head = 64,
light_heads = 8,
light_window_size = 128, # each token sees roughly 64 tokens on either side (left and right)
Expand Down Expand Up @@ -525,6 +526,7 @@ def __init__(

self.kv_router = router_klass(
dim = dim,
num_routing_tokens = num_routed_kv,
straight_through = router_straight_through,
**router_kwargs
)
Expand Down Expand Up @@ -726,6 +728,7 @@ def __init__(
*,
num_heavy_attn_tokens_q,
num_heavy_attn_tokens_kv,
num_routed_kv = 1,
num_heavy_ff_tokens,
light_dim_head = 64,
light_heads = 8,
Expand Down Expand Up @@ -758,6 +761,7 @@ def __init__(
heavy_heads = heavy_heads,
num_heavy_tokens_q = num_heavy_attn_tokens_q,
num_heavy_tokens_kv = num_heavy_attn_tokens_kv,
num_routed_kv = num_routed_kv,
router_straight_through = router_straight_through,
router_type = router_type,
router_kwargs = router_kwargs
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'CoLT5-attention',
packages = find_packages(),
version = '0.2.2',
version = '0.2.3',
license='MIT',
description = 'Conditionally Routed Attention',
long_description_content_type = 'text/markdown',
Expand Down

0 comments on commit 47724ac

Please sign in to comment.