Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

triton softmax support multi-batch #152

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ at::Tensor fused_mask_softmax_forward(at::Tensor input, at::Tensor mask, long lo
CHECK_INPUT(input);
CHECK_INPUT(mask);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
int head = input.sizes()[2];
int head = input.sizes().at(input.sizes().size() - 3);
// at::Tensor output = at::empty_like(input);

int grid = (rows + 3) / 4;
Expand Down Expand Up @@ -589,7 +589,7 @@ at::Tensor fused_mask_softmax_backward(at::Tensor d_output, at::Tensor output, a
CHECK_INPUT(output);
CHECK_INPUT(mask);
const at::cuda::OptionalCUDAGuard device_guard(device_of(mask));
int head = output.sizes()[2];
int head = output.sizes().at(output.sizes().size() - 3);
at::Tensor grad_input = at::empty_like(output);

int grid = (rows + 3) / 4;
Expand Down Expand Up @@ -711,7 +711,7 @@ at::Tensor fused_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at
CHECK_INPUT(mask);
CHECK_INPUT(bias);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
int head = input.sizes()[2];
int head = input.sizes().at(input.sizes().size() - 3);
// at::Tensor output = at::empty_like(input);

int grid = (rows + 3) / 4;
Expand Down Expand Up @@ -814,7 +814,7 @@ at::Tensor fused_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor outp
CHECK_INPUT(output);
CHECK_INPUT(mask);
const at::cuda::OptionalCUDAGuard device_guard(device_of(mask));
int head = output.sizes()[2];
int head = output.sizes().at(output.sizes().size() - 3);
at::Tensor grad_input = at::empty_like(output);

int grid = (rows + 3) / 4;
Expand Down
2 changes: 1 addition & 1 deletion fastfold/model/fastnn/kernel/triton/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def softmax_grad_kernel_two_rows(d_output_ptr, output_ptr, d_input_ptr, d_output

def softmax_triton_kernel_wrapper(x, mask, bias, n_rows, n_cols):
y = torch.empty_like(x)
n_heads = x.shape[2]
n_heads = x.shape[-3]

num_warps = 1
BLOCK_SIZE = triton.next_power_of_2(n_cols)
Expand Down
4 changes: 2 additions & 2 deletions fastfold/model/fastnn/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def forward(self, in_data, mask, nonbatched_bias=None):
logits = torch.matmul(q, k.transpose(-1, -2))

if nonbatched_bias is not None:
weights = fused_softmax(logits, mask, bias.unsqueeze(1))
weights = fused_softmax(logits, mask, bias)
else:
weights = fused_softmax(logits, mask)

Expand Down Expand Up @@ -343,7 +343,7 @@ def forward(self, in_data, mask, nonbatched_bias=None):
# logits += bias.unsqueeze(1)
# logits += (1e9 * (mask_part - 1))[..., :, None, None, :]
# weights = torch.nn.functional.softmax(logits, -1)
weights = fused_softmax(logits, mask_part, bias.unsqueeze(1))
weights = fused_softmax(logits, mask_part, bias)
else:
# logits += (1e9 * (mask_part - 1))[..., :, None, None, :]
# weights = torch.nn.functional.softmax(logits, -1)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_fastnn/test_attention_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
from einops import rearrange

TEST_TRITON = False
TEST_TRITON = True
try:
from fastfold.model.fastnn.kernel import fused_attention_core
except:
Expand Down
10 changes: 5 additions & 5 deletions tests/test_fastnn/test_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

def _test_softmax_core():

batch_, chunk_, head_ = 1, 8, 4
batch, batch_, chunk_, head_ = 1, 1, 8, 4
test_seq_ = [31, 32, 128, 129, 256, 259, 512, 700, 1024]
test_dtype = [torch.float32, torch.float16, torch.bfloat16]
test_device = torch.device("cuda")
Expand All @@ -15,19 +15,19 @@ def _test_softmax_core():

for seq_ in test_seq_:
for dtype in test_dtype:
sample_input = torch.rand(batch_, chunk_, head_, seq_,
sample_input = torch.rand(batch, batch_, chunk_, head_, seq_,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, without the batch dimension it would work correctly as well

seq_).to(device=test_device, dtype=dtype).requires_grad_(True)
sample_mask = torch.cuda.FloatTensor(batch_, chunk_, seq_).uniform_() > 0
sample_mask = torch.cuda.FloatTensor(batch, batch_, chunk_, seq_).uniform_() > 0
sample_mask = sample_mask.to(device=test_device, dtype=dtype).requires_grad_(False)
sample_bias = torch.rand(batch_, 1, head_, seq_,
sample_bias = torch.rand(batch, batch_, 1, head_, seq_,
seq_).to(device=test_device, dtype=dtype).requires_grad_(True)

sample_input_fastnn = torch.clone(sample_input.detach()).requires_grad_(True)
sample_mask_fastnn = torch.clone(sample_mask.detach()).requires_grad_(False)
sample_bias_fastnn = torch.clone(sample_bias.detach()).requires_grad_(True)

# Forward
sample_mask_torch = 1e9 * (sample_mask - 1)[:, :, None, None, :]
sample_mask_torch = 1e9 * (sample_mask - 1)[:, :, :, None, None, :]
torch_out = torch.nn.functional.softmax(sample_input + sample_mask_torch + sample_bias,
dim=-1)

Expand Down