add a learned null query output token, for those query tokens not selected by the router. only for autoregressive flavor for now
lucidrains committed May 3, 2023
1 parent 7ecbf8c commit d3612ca
Showing 2 changed files with 13 additions and 3 deletions.
14 changes: 12 additions & 2 deletions colt5_attention/transformer_block.py
@@ -684,7 +684,8 @@ def __init__(
         router_kwargs: dict = {},
         multiply_keys_by_score = False,
         multiply_queries_by_score = False,
-        use_triton = False
+        use_triton = False,
+        use_null_q_tokens = False
     ):
         super().__init__()
         assert router_type in ROUTERS.keys()
@@ -713,6 +714,10 @@ def __init__(
             use_rotary_pos_emb = False
         )
 
+        self.null_q_token = None
+        if use_null_q_tokens:
+            self.null_q_token = nn.Parameter(torch.randn(dim)) # for query tokens not selected by the router, give them a learned output embedding
+
         self.q_router = router_klass(
             dim = dim,
             straight_through = router_straight_through,
@@ -823,7 +828,12 @@ def forward(
 
             # scatter back the output of the heavy branch
 
-            heavy_out = torch.zeros_like(q)
+            if exists(self.null_q_token):
+                heavy_out = rearrange(self.null_q_token, 'd -> 1 1 d')
+                heavy_out = heavy_out.expand_as(q).clone()
+            else:
+                heavy_out = torch.zeros_like(q)
+
             heavy_out = self.q_router.route_back(heavy_out, routed_tokens_out, indices_q)
         else:
             heavy_out = routed_tokens_out
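
For reference, a minimal self-contained sketch of the mechanic this hunk introduces, with assumed (batch, seq, dim) shapes and a plain torch scatter standing in for the library's route_back helper:

import torch
from torch import nn
from einops import rearrange, repeat

batch, seq, dim = 2, 8, 16
num_routed = 3

# learned output embedding for query tokens the router does not select
null_q_token = nn.Parameter(torch.randn(dim))

q = torch.randn(batch, seq, dim)                         # all query tokens entering the block
routed_tokens_out = torch.randn(batch, num_routed, dim)  # heavy-branch output for the routed queries

# positions the router selected (randperm keeps them unique per batch row)
indices_q = torch.stack([torch.randperm(seq)[:num_routed] for _ in range(batch)])

# every position defaults to the learned null token rather than zeros
heavy_out = rearrange(null_q_token, 'd -> 1 1 d')
heavy_out = heavy_out.expand_as(q).clone()

# scatter the heavy-branch outputs back to their routed positions
heavy_out = heavy_out.scatter(1, repeat(indices_q, 'b n -> b n d', d = dim), routed_tokens_out)

assert heavy_out.shape == q.shape

The design point is the default fill: previously unrouted queries emitted all-zero outputs, whereas after this commit they emit a shared learned embedding, which gives the residual stream a trainable signal for "this token was skipped".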
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'CoLT5-attention',
   packages = find_packages(),
-  version = '0.4.4',
+  version = '0.4.5',
   license='MIT',
   description = 'Conditionally Routed Attention',
   long_description_content_type = 'text/markdown',
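
A hedged usage sketch of the new flag. The class name and routing arguments below are assumptions for illustration; this commit only confirms use_null_q_tokens and that it applies to the autoregressive flavor, so check the repository README for the actual API:

import torch
from colt5_attention import ConditionalRoutedAutoregressiveAttention  # assumed export name

attn = ConditionalRoutedAutoregressiveAttention(
    dim = 512,                 # assumed arguments, for illustration only
    num_heavy_tokens_q = 32,
    num_heavy_tokens_kv = 32,
    use_null_q_tokens = True   # unrouted queries emit the learned null token instead of zeros
)

tokens = torch.randn(1, 1024, 512)
out = attn(tokens)             # (1, 1024, 512)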
