diff --git a/fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu b/fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu
index f5e87c8f..abfa6e68 100644
--- a/fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu
+++ b/fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu
@@ -427,7 +427,7 @@ at::Tensor fused_mask_softmax_forward(at::Tensor input, at::Tensor mask, long lo
     CHECK_INPUT(input);
     CHECK_INPUT(mask);
     const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
-    int head = input.sizes()[2];
+    int head = input.sizes().at(input.sizes().size() - 3);
     // at::Tensor output = at::empty_like(input);
 
     int grid = (rows + 3) / 4;
@@ -589,7 +589,7 @@ at::Tensor fused_mask_softmax_backward(at::Tensor d_output, at::Tensor output, a
     CHECK_INPUT(output);
     CHECK_INPUT(mask);
     const at::cuda::OptionalCUDAGuard device_guard(device_of(mask));
-    int head = output.sizes()[2];
+    int head = output.sizes().at(output.sizes().size() - 3);
     at::Tensor grad_input = at::empty_like(output);
 
     int grid = (rows + 3) / 4;
@@ -711,7 +711,7 @@ at::Tensor fused_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at
     CHECK_INPUT(mask);
     CHECK_INPUT(bias);
     const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
-    int head = input.sizes()[2];
+    int head = input.sizes().at(input.sizes().size() - 3);
     // at::Tensor output = at::empty_like(input);
 
     int grid = (rows + 3) / 4;
@@ -814,7 +814,7 @@ at::Tensor fused_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor outp
     CHECK_INPUT(output);
     CHECK_INPUT(mask);
     const at::cuda::OptionalCUDAGuard device_guard(device_of(mask));
-    int head = output.sizes()[2];
+    int head = output.sizes().at(output.sizes().size() - 3);
     at::Tensor grad_input = at::empty_like(output);
 
     int grid = (rows + 3) / 4;
diff --git a/fastfold/model/fastnn/kernel/triton/softmax.py b/fastfold/model/fastnn/kernel/triton/softmax.py
index 61cf4223..14503987 100644
--- a/fastfold/model/fastnn/kernel/triton/softmax.py
+++ b/fastfold/model/fastnn/kernel/triton/softmax.py
@@ -8,7 +8,7 @@
 def _softmax_core(input_ptrs, output_ptrs, mask_ptrs, bias_ptrs, col_offsets, n_cols,
                   use_mask: tl.constexpr, use_bias: tl.constexpr):
-    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
+    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=float("-inf")).to(tl.float32)
 
     if use_bias:
         bias = tl.load(bias_ptrs, mask=col_offsets < n_cols, other=float("-inf")).to(tl.float32)
@@ -44,7 +44,7 @@ def _softmax_grad_core(output_ptrs, d_output_ptrs, d_input_ptrs, col_offsets, n_
 @triton.jit
 def softmax_mask_bias_kernel(output_ptr, input_ptr, mask_ptr, bias_ptr, input_row_stride,
-                             output_row_stride, n_cols, n_heads, BLOCK_SIZE: tl.constexpr,
+                             output_row_stride, n_cols, n_heads, n_chunks, BLOCK_SIZE: tl.constexpr,
                              use_mask: tl.constexpr, use_bias: tl.constexpr):
     row_idx = tl.program_id(0).to(tl.int64)
     col_offsets = tl.arange(0, BLOCK_SIZE)
@@ -62,7 +62,7 @@ def softmax_mask_bias_kernel(output_ptr, input_ptr, mask_ptr, bias_ptr, input_ro
     bias_ptrs = input_ptrs  # place holder, not use if use_bias == False
     if use_bias:
-        bias_row_ptr = bias_ptr + (row_idx % (n_heads * n_cols)) * n_cols
+        bias_row_ptr = bias_ptr + (((row_idx // (n_chunks * n_heads * n_cols)) * n_heads * n_cols) + (row_idx % (n_heads * n_cols))) * n_cols
         bias_ptrs = bias_row_ptr + col_offsets
 
     _softmax_core(input_ptrs, output_ptrs, mask_ptrs, bias_ptrs, col_offsets, n_cols, use_mask,
@@ -71,7 +71,7 @@ def softmax_mask_bias_kernel(output_ptr, input_ptr, mask_ptr, bias_ptr, input_ro
 @triton.jit
 def softmax_mask_bias_kernel_two_rows(output_ptr, input_ptr, mask_ptr, bias_ptr, input_row_stride,
-                                      output_row_stride, n_cols, n_heads, BLOCK_SIZE: tl.constexpr,
+                                      output_row_stride, n_cols, n_heads, n_chunks, BLOCK_SIZE: tl.constexpr,
                                       use_mask: tl.constexpr, use_bias: tl.constexpr):
     row_idx = tl.program_id(0).to(tl.int64)
     col_offsets = tl.arange(0, BLOCK_SIZE)
@@ -89,7 +89,7 @@ def softmax_mask_bias_kernel_two_rows(output_ptr, input_ptr, mask_ptr, bias_ptr,
     bias_ptrs = input_ptrs  # place holder, not use if use_bias == False
     if use_bias:
-        bias_row_ptr = bias_ptr + ((2 * row_idx) % (n_heads * n_cols)) * n_cols
+        bias_row_ptr = bias_ptr + ((((2 * row_idx) // (n_chunks * n_heads * n_cols)) * n_heads * n_cols) + ((2 * row_idx) % (n_heads * n_cols))) * n_cols
         bias_ptrs = bias_row_ptr + col_offsets
 
     _softmax_core(input_ptrs, output_ptrs, mask_ptrs, bias_ptrs, col_offsets, n_cols, use_mask,
@@ -102,7 +102,7 @@ def softmax_mask_bias_kernel_two_rows(output_ptr, input_ptr, mask_ptr, bias_ptr,
     bias_ptrs = input_ptrs  # place holder, not use if use_bias == False
     if use_bias:
-        bias_row_ptr = bias_ptr + ((2 * row_idx + 1) % (n_heads * n_cols)) * n_cols
+        bias_row_ptr = bias_ptr + ((((2 * row_idx + 1) // (n_chunks * n_heads * n_cols)) * n_heads * n_cols) + ((2 * row_idx + 1) % (n_heads * n_cols))) * n_cols
         bias_ptrs = bias_row_ptr + col_offsets
 
     _softmax_core(input_ptrs + n_cols, output_ptrs + n_cols, mask_ptrs, bias_ptrs, col_offsets,
@@ -152,8 +152,8 @@ def softmax_grad_kernel_two_rows(d_output_ptr, output_ptr, d_input_ptr, d_output
 def softmax_triton_kernel_wrapper(x, mask, bias, n_rows, n_cols):
     y = torch.empty_like(x)
 
-    n_heads = x.shape[2]
-
+    n_heads = x.shape[-3]
+    n_chunks = x.shape[-4]
     num_warps = 1
     BLOCK_SIZE = triton.next_power_of_2(n_cols)
     if BLOCK_SIZE >= 1024:
@@ -178,6 +178,7 @@ def softmax_triton_kernel_wrapper(x, mask, bias, n_rows, n_cols):
             y.stride(-2),
             n_cols,
             n_heads,
+            n_chunks,
             num_warps=num_warps,
             BLOCK_SIZE=BLOCK_SIZE,
             use_mask=(mask != None),
diff --git a/fastfold/model/fastnn/ops.py b/fastfold/model/fastnn/ops.py
index 08d464d6..d0e8e552 100644
--- a/fastfold/model/fastnn/ops.py
+++ b/fastfold/model/fastnn/ops.py
@@ -310,7 +310,7 @@ def forward(self, in_data, mask, nonbatched_bias=None):
         logits = torch.matmul(q, k.transpose(-1, -2))
 
         if nonbatched_bias is not None:
-            weights = fused_softmax(logits, mask, bias.unsqueeze(1))
+            weights = fused_softmax(logits, mask, bias)
         else:
             weights = fused_softmax(logits, mask)
@@ -343,7 +343,7 @@ def forward(self, in_data, mask, nonbatched_bias=None):
                 # logits += bias.unsqueeze(1)
                 # logits += (1e9 * (mask_part - 1))[..., :, None, None, :]
                 # weights = torch.nn.functional.softmax(logits, -1)
-                weights = fused_softmax(logits, mask_part, bias.unsqueeze(1))
+                weights = fused_softmax(logits, mask_part, bias)
             else:
                 # logits += (1e9 * (mask_part - 1))[..., :, None, None, :]
                 # weights = torch.nn.functional.softmax(logits, -1)
diff --git a/tests/test_fastnn/test_attention_core.py b/tests/test_fastnn/test_attention_core.py
index cb1e7cdd..cbe32de2 100644
--- a/tests/test_fastnn/test_attention_core.py
+++ b/tests/test_fastnn/test_attention_core.py
@@ -5,7 +5,7 @@
 import torch
 from einops import rearrange
 
-TEST_TRITON = False
+TEST_TRITON = True
 try:
     from fastfold.model.fastnn.kernel import fused_attention_core
 except:
diff --git a/tests/test_fastnn/test_batch_softmax.py b/tests/test_fastnn/test_batch_softmax.py
new file mode 100644
index 00000000..1f87ca24
--- /dev/null
+++ b/tests/test_fastnn/test_batch_softmax.py
@@ -0,0 +1,71 @@
+# Test the Triton softmax kernel only.
+# The Triton softmax supports an additional leading batch dimension.
+
+import torch
+import pytest
+from fastfold.model.fastnn.kernel import fused_softmax
+
+
+TEST_TRITON = True
+try:
+    import triton
+except:
+    print("Skip triton softmax test!")
+    TEST_TRITON = False
+
+def _test_softmax_core():
+
+    batch, batch_, chunk_, head_ = 3, 1, 8, 4
+    test_seq_ = [31, 32, 128, 129, 256, 259, 512, 700, 1024]
+    test_dtype = [torch.float32, torch.float16, torch.bfloat16]
+    test_device = torch.device("cuda")
+
+    tolerance_eps = {torch.float32: 1e-6, torch.float16: 2e-4, torch.bfloat16: 1e-3}
+
+    for seq_ in test_seq_:
+        for dtype in test_dtype:
+            sample_input = torch.rand(batch, batch_, chunk_, head_, seq_,
+                                      seq_).to(device=test_device, dtype=dtype).requires_grad_(True)
+            sample_mask = torch.cuda.FloatTensor(batch, batch_, chunk_, seq_).uniform_() > 0
+            sample_mask = sample_mask.to(device=test_device, dtype=dtype).requires_grad_(False)
+            sample_bias = torch.rand(batch, batch_, 1, head_, seq_,
+                                     seq_).to(device=test_device, dtype=dtype).requires_grad_(True)
+
+            sample_input_fastnn = torch.clone(sample_input.detach()).requires_grad_(True)
+            sample_mask_fastnn = torch.clone(sample_mask.detach()).requires_grad_(False)
+            sample_bias_fastnn = torch.clone(sample_bias.detach()).requires_grad_(True)
+
+            # Forward
+            sample_mask_torch = 1e9 * (sample_mask - 1)[:, :, :, None, None, :]
+            torch_out = torch.nn.functional.softmax(sample_input + sample_mask_torch + sample_bias,
+                                                    dim=-1)
+
+            fastnn_out = fused_softmax(sample_input_fastnn, sample_mask_fastnn, sample_bias_fastnn)
+            # print(sample_bias_fastnn)
+            # print(fastnn_out)
+            fwd_fastnn_error = torch.max(torch.abs(torch_out - fastnn_out)).cpu().item()
+            assert fwd_fastnn_error < tolerance_eps[
+                dtype], f"fastnn fwd kernel error when {seq_} {dtype}"
+
+            # Backward
+            out_grad = torch.rand_like(torch_out).requires_grad_(False)
+            torch_out.backward(out_grad)
+            fastnn_out.backward(out_grad)
+
+            grad_input_error = torch.max(torch.abs(sample_input.grad -
+                                                   sample_input_fastnn.grad)).cpu().item()
+            assert grad_input_error < tolerance_eps[
+                dtype], f"fastnn bwd kernel error when {seq_} {dtype}"
+
+            grad_bias_error = torch.max(torch.abs(sample_bias.grad -
+                                                  sample_bias_fastnn.grad)).cpu().item()
+            assert grad_bias_error < tolerance_eps[
+                dtype], f"fastnn bwd kernel error when {seq_} {dtype}"
+
+
+@pytest.mark.skipif(TEST_TRITON == False, reason="triton is not available")
+def test_softmax():
+    _test_softmax_core()
+
+if __name__ == "__main__":
+    test_softmax()
diff --git a/tests/test_fastnn/test_softmax.py b/tests/test_fastnn/test_softmax.py
index 62b7b0ba..0f6e5547 100644
--- a/tests/test_fastnn/test_softmax.py
+++ b/tests/test_fastnn/test_softmax.py
@@ -20,7 +20,7 @@ def _test_softmax_core():
         sample_mask = torch.cuda.FloatTensor(batch_, chunk_, seq_).uniform_() > 0
         sample_mask = sample_mask.to(device=test_device, dtype=dtype).requires_grad_(False)
         sample_bias = torch.rand(batch_, 1, head_, seq_,
-                                  seq_).to(device=test_device, dtype=dtype).requires_grad_(True)
+                                 seq_).to(device=test_device, dtype=dtype).requires_grad_(True)
 
         sample_input_fastnn = torch.clone(sample_input.detach()).requires_grad_(True)
         sample_mask_fastnn = torch.clone(sample_mask.detach()).requires_grad_(False)