Skip to content

Commit

Permalink
make sure it works with masking
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 2, 2023
1 parent 6a4106d commit 220f104
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
19 changes: 16 additions & 3 deletions colt5_attention/triton_coor_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,12 +232,17 @@ def forward(
x,
n_iters,
k,
eps
eps,
mask
):
assert x.is_cuda

batch, requires_grad = x.shape[0], x.requires_grad

if exists(mask):
mask_value = -torch.finfo(x.dtype).max
x = x.masked_fill(~mask, mask_value)

x, shape = pack_one(x, '* n')

n_rows, n_cols = x.shape
Expand Down Expand Up @@ -309,6 +314,14 @@ def backward(
BLOCK_SIZE = BLOCK_SIZE
)

return unpack_one(dx, shape, '* n'), None, None, None
return unpack_one(dx, shape, '* n'), None, None, None, None

triton_coor_descent = _coor_descent.apply
def triton_coor_descent(
    s,
    *,
    n_iters,
    k,
    eps = 1e-1,
    mask = None
):
    """Keyword-friendly entry point that dispatches to `_coor_descent.apply`.

    Everything after `s` must be given by keyword. The values are then handed
    to `_coor_descent.apply` positionally, in the fixed order
    `(s, n_iters, k, eps, mask)`.

    Args:
        s: forwarded as the first positional argument.
        n_iters: forwarded unchanged (required).
        k: forwarded unchanged (required).
        eps: forwarded unchanged; defaults to 1e-1.
        mask: forwarded unchanged; defaults to None (no mask supplied).

    Returns:
        Whatever `_coor_descent.apply` returns.
    """
    # NOTE(review): `_coor_descent` is presumably a torch.autograd.Function,
    # whose `apply` takes positional arguments only - confirm against its
    # definition earlier in this module.
    forwarded = (s, n_iters, k, eps, mask)
    return _coor_descent.apply(*forwarded)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'CoLT5-attention',
packages = find_packages(),
version = '0.3.7',
version = '0.3.8',
license='MIT',
description = 'Conditionally Routed Attention',
long_description_content_type = 'text/markdown',
Expand Down

0 comments on commit 220f104

Please sign in to comment.