From 860081fcaabfc0125e8530507f4bf8d802596281 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Fri, 18 Aug 2023 11:47:38 -0700
Subject: [PATCH] prepare for differentiable topk use for another lib

---
 colt5_attention/topk.py | 8 +++++++-
 setup.py                | 2 +-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/colt5_attention/topk.py b/colt5_attention/topk.py
index 03f8286..87df528 100644
--- a/colt5_attention/topk.py
+++ b/colt5_attention/topk.py
@@ -13,11 +13,17 @@ def topk(
     eps_init = None,
     eps_decay = 1.,
     mask = None,
-    fused = False
+    fused = False,
+    non_differentiable = False
 ):
     """ differentiable top-k on last dimension """
 
+    if non_differentiable:
+        values, indices = torch.topk(x, k = k, dim = -1)
+        return TopkReturn(values, indices, None, None)
+
     assert coor_descent_k_ratio >= 1.
     assert k > 0
 
diff --git a/setup.py b/setup.py
index 11a4b0b..3d6efc9 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'CoLT5-attention',
   packages = find_packages(),
-  version = '0.10.14',
+  version = '0.10.15',
   license='MIT',
   description = 'Conditionally Routed Attention',
   long_description_content_type = 'text/markdown',
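
For context, a minimal usage sketch of the new escape hatch; this is not part of the patch. It assumes `topk` is importable from the package root (it lives in colt5_attention/topk.py) and that `TopkReturn` is a namedtuple whose first two fields are the top-k values and indices, with the coordinate-descent fields coming back as None on the new path.

    import torch
    from colt5_attention import topk  # assumption: the package root re-exports topk

    x = torch.randn(2, 1024)

    # default path: differentiable top-k via iterative coordinate descent
    soft = topk(x, k = 256)

    # new path added by this patch: plain torch.topk, skipping coordinate
    # descent entirely; gradients do not flow through the selection, and the
    # trailing coordinate-descent fields of the returned tuple are None
    hard = topk(x, k = 256, non_differentiable = True)

    # unpack positionally, per TopkReturn(values, indices, None, None) above
    values, indices, *_ = hard

The flag is a cheap bypass for callers (per the commit subject, another library) that only need hard selection at inference time and want to avoid paying for the coordinate-descent iterations.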