diff --git a/aten/src/ATen/cuda/detail/UnpackRaw.cuh b/aten/src/ATen/cuda/detail/UnpackRaw.cuh index f8fa4ebbf160ae..96b258fcb6916b 100644 --- a/aten/src/ATen/cuda/detail/UnpackRaw.cuh +++ b/aten/src/ATen/cuda/detail/UnpackRaw.cuh @@ -15,7 +15,7 @@ namespace philox { // Easiest thing that comes to mind is, define a __device__ unpack helper here, in ATen/cuda. // // The raw definition lives in its own file so jit codegen can easily copy it. -__device__ __forceinline__ std::tuple +__host__ __device__ __forceinline__ std::tuple unpack(at::PhiloxCudaState arg) { if (arg.captured_) { // static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long". diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 1f07b6474bc0cb..a63582f9b4e189 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -14176,14 +14176,17 @@ dispatch: CUDA: _scaled_dot_product_flash_attention_backward_cuda -- func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp) +- func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset) dispatch: CUDA: _scaled_dot_product_efficient_attention_cuda NestedTensorCUDA: _scaled_dot_product_efficient_attention_nestedtensor_cuda + tags: nondeterministic_seeded -- func: _scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, bool is_causal=False, bool chunk_grad_outputs=False, *, float? scale=None) -> (Tensor, Tensor, Tensor) +- func: _scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor, Tensor) + device_check: NoCheck dispatch: CUDA: _scaled_dot_product_efficient_attention_backward_cuda + tags: nondeterministic_seeded # THIS FUNCTION iS DEPRECATED AND SHOULD BE REMOVED - func: _chunk_grad_outputs_efficient_attention(Tensor query, Tensor key, Tensor value, bool is_causal=False) -> bool @@ -14203,13 +14206,13 @@ CUDA: _flash_attention_backward # Returns ouput, logsumexp if compute_logsumexp -- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None) -> (Tensor output, Tensor logsumexp) +- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset) variants: function dispatch: CUDA: _efficient_attention_forward tags: nondeterministic_seeded -- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? 
cu_seqlens_k, int max_seqlen_k, int max_seqlen_q, Tensor logsumexp, float dropout_p, Tensor rng_seed, Tensor rng_offset, int custom_mask_type, *, float? scale=None, int? num_splits_key=None) -> (Tensor, Tensor, Tensor, Tensor) +- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int max_seqlen_k, int max_seqlen_q, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, *, float? scale=None, int? num_splits_key=None) -> (Tensor, Tensor, Tensor, Tensor) device_check: NoCheck variants: function dispatch: @@ -14219,7 +14222,13 @@ variants: function dispatch: CUDA: triton_scaled_dot_attention + tags: nondeterministic_seeded autogen: _triton_scaled_dot_attention.out + +- func: _fill_mem_eff_dropout_mask_(Tensor(a!) self, float dropout_p, int seed, int offset) -> Tensor(a!) + variants: function + dispatch: + CUDA: _fill_mem_eff_dropout_mask_ tags: nondeterministic_seeded - func: _triton_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None) -> Tensor diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp index 79f3bbd6f39d97..f30be772f909c6 100644 --- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp +++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp @@ -738,12 +738,13 @@ _scaled_dot_product_flash_attention_nestedtensor_cuda( debug_attn_mask); } -std::tuple +std::tuple _scaled_dot_product_efficient_attention_nestedtensor_cuda( const Tensor& query, const Tensor& key, const Tensor& value, bool compute_log_sumexp, + double dropout_p, bool is_causal, c10::optional scale) { Tensor query_buffer_reshaped, key_buffer_reshaped, value_buffer_reshaped, @@ -763,8 +764,8 @@ _scaled_dot_product_efficient_attention_nestedtensor_cuda( ? 
sdp::CustomMaskType::CausalFromTopLeft : sdp::CustomMaskType::NoCustomMask; - Tensor attention, log_sumexp; - std::tie(attention, log_sumexp) = at::_efficient_attention_forward( + // See Note [Seed and Offset] for description of seed and offset + auto [attention, log_sumexp, seed, offset] = at::_efficient_attention_forward( query_buffer_reshaped.unsqueeze(0), key_buffer_reshaped.unsqueeze(0), value_buffer_reshaped.unsqueeze(0), @@ -772,14 +773,14 @@ _scaled_dot_product_efficient_attention_nestedtensor_cuda( cumulative_sequence_length_q, cumulative_sequence_length_kv, max_seqlen_batch_q, - 0.0 /*dropout_p*/, + dropout_p, static_cast(custom_mask_type), compute_log_sumexp, scale); // Reshape output to convert nnz to batch_size and seq_len attention = wrap_buffer(attention.view(-1), output_shape).transpose(1, 2); - return std::make_tuple(std::move(attention), std::move(log_sumexp)); + return std::make_tuple(std::move(attention), std::move(log_sumexp), std::move(seed), std::move(offset)); } } // namespace native diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index f713c9e6c6054e..6324415f44b852 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -615,7 +615,7 @@ Tensor scaled_dot_product_attention( (query_.requires_grad() || key.requires_grad() || value.requires_grad()); auto out_and_lse = at::_scaled_dot_product_efficient_attention( - query_, key, value, compute_logsumexp, is_causal, scale); + query_, key, value, compute_logsumexp, dropout_p, is_causal, scale); return std::get<0>(out_and_lse); } case sdp::SDPBackend::math: @@ -682,9 +682,12 @@ std::tuple _scaled_dot_product_attention_math( attn = at::softmax(attn, -1); if (dropout_p > 0.0) { if (dropout_mask.has_value()) { - auto attn_dropout_masked = attn.masked_fill(dropout_mask->logical_not(), 0.0); + // In order to validate the correctness of the fused kernels, we need to + // use the same dropout mask in order to compare the results. + TORCH_WARN_ONCE("Dropout mask should only be used for testing purposes."); + attn = attn.masked_fill(dropout_mask->logical_not(), 0.0); auto dropout_scaling = 1.0 / (1 - dropout_p); - return std::make_tuple(at::matmul(attn_dropout_masked, value * dropout_scaling), attn); + return std::make_tuple(at::matmul(attn, value * dropout_scaling), attn); } else { attn = at::dropout(attn, dropout_p, true); } diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index c17547a825b1b7..075366c7b30eec 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -10,6 +10,7 @@ #include #include +#include #include #include #include @@ -18,6 +19,13 @@ #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#endif + #include #include @@ -700,11 +708,12 @@ std::tuple _scaled_dot_product_efficient_attention_cuda( +std::tuple _scaled_dot_product_efficient_attention_cuda( const Tensor& query, const Tensor& key, const Tensor& value, bool compute_log_sumexp, + double dropout_p, bool is_causal, c10::optional scale) { // Used for tracking usage statistics @@ -720,8 +729,7 @@ std::tuple _scaled_dot_product_efficient_attention_cuda( ? 
sdp::CustomMaskType::CausalFromTopLeft : sdp::CustomMaskType::NoCustomMask; - Tensor attention, log_sumexp; - std::tie(attention, log_sumexp) = at::_efficient_attention_forward( + auto [attention, log_sumexp, seed, offset] = at::_efficient_attention_forward( q_t, k_t, v_t, @@ -729,13 +737,13 @@ std::tuple _scaled_dot_product_efficient_attention_cuda( c10::nullopt, c10::nullopt, c10::nullopt, - 0.0 /*dropout_p*/, + dropout_p /*dropout_p*/, static_cast(custom_mask_type), compute_log_sumexp, scale); attention = attention.transpose(1, 2); - return std::make_tuple(std::move(attention), std::move(log_sumexp)); + return std::make_tuple(std::move(attention), std::move(log_sumexp), std::move(seed), std::move(offset)); } int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Tensor& value, @@ -774,8 +782,7 @@ std::tuple _flash_attention_forward( const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); at::Tensor output = at::empty_like(query); - Tensor logsumexp, debug_attn_mask, philox_seed, philox_offset; - std::tie(logsumexp, philox_seed, philox_offset, debug_attn_mask) = pytorch_fmha::mha_fwd( + auto [logsumexp, philox_seed, philox_offset, debug_attn_mask] = pytorch_fmha::mha_fwd( query, key, value, @@ -791,7 +798,8 @@ std::tuple _flash_attention_forward( return_debug_mask, /*return_softmax (this is used for testing)*/ num_splits); - debug_attn_mask = return_debug_mask ? debug_attn_mask : at::empty({0}, query.options()); + debug_attn_mask = + return_debug_mask ? debug_attn_mask : at::empty({0}, query.options()); return std::make_tuple(output, logsumexp, philox_seed, philox_offset, debug_attn_mask); #endif @@ -799,7 +807,7 @@ std::tuple _flash_attention_forward( return std::make_tuple(Tensor(), Tensor(), Tensor(), Tensor(), Tensor()); } -std::tuple _efficient_attention_forward( +std::tuple _efficient_attention_forward( const at::Tensor& query, // [b, seqlen, num_heads, K] const at::Tensor& key, // [b, seqlen, num_heads, K] const at::Tensor& value, // [b, seqlen, num_heads, Kv] @@ -875,18 +883,47 @@ std::tuple _efficient_attention_forward( at::Tensor res; at::Tensor logsumexp; + at::Tensor seed_t, offset_t; const bool use_dropout = std::fpclassify(dropout_p) != FP_ZERO; - at::PhiloxCudaState rng_engine_inputs; + + // Note [Seed and Offset Device] + // If we are currently in graph capture mode, we need to create the seed and offset tensors on the device. + // This is necessary for CUDA graph-safe random number generation, which requires the seed and offset tensors + // to be single element tensors on device. During graph capture, when the seed and offset tensors are passed + // the pointers act as scratch space for storing the RNG state for the backwards pass. + // When calling backwards, we either construct a PhiloxState with the pointers or the actual values. + // For more information on CUDA graph-safe RNG states, see Note [CUDA Graph-safe RNG states]. + + at::PhiloxCudaState philox_state; + const bool in_capture_stream = + at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None; + auto device = in_capture_stream ? 
at::kCUDA : at::kCPU; if (use_dropout) { - at::CUDAGeneratorImpl* gen = - at::get_generator_or_default( - c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + auto gen = at::get_generator_or_default( + c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + // See Note [Acquire lock when using random generators] std::lock_guard lock(gen->mutex_); // if using dropout, we produce 1 random number for each element of the // attention tensor - rng_engine_inputs = gen->philox_cuda_state(B * num_heads * M * N); + philox_state = gen->philox_cuda_state(B * num_heads * M * N); + + if (in_capture_stream) { + // The seed and offset will be populated by the kernel + seed_t = at::empty({}, at::dtype(at::kLong).device(device)); + offset_t = at::empty({}, at::dtype(at::kLong).device(device)); + } else { + auto [seed, offset] = at::cuda::philox::unpack(philox_state); + seed_t = at::scalar_tensor( + at::Scalar(static_cast(seed)), at::dtype(at::kLong)); + offset_t = at::scalar_tensor( + at::Scalar(static_cast(offset)), at::dtype(at::kLong)); + } + } else { + // Not using dropout + seed_t = at::empty({}, at::dtype(at::kLong).device(device)); + offset_t = at::empty({}, at::dtype(at::kLong).device(device)); } cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index()); @@ -940,7 +977,6 @@ std::tuple _efficient_attention_forward( num_heads, compute_logsumexp ? ceil_div(max_seqlen_q, kAlignLSE) * kAlignLSE : 0}, query.options().dtype(at::ScalarType::Float)); - typename Kernel::Params p; p.query_ptr = (scalar_t*)query.data_ptr(); p.key_ptr = (scalar_t*)key.data_ptr(); @@ -1018,8 +1054,10 @@ std::tuple _efficient_attention_forward( p.use_dropout = use_dropout; if (p.use_dropout) { - p.rng_engine_inputs = rng_engine_inputs; + p.rng_engine_inputs = philox_state; p.dropout_prob = dropout_p; + p.seed = seed_t.data_ptr(); + p.extragraph_offset = offset_t.data_ptr(); } if (smem_bytes > 0xc000) { @@ -1043,19 +1081,14 @@ std::tuple _efficient_attention_forward( TORCH_CHECK(kernel_launched, "cutlassF: no kernel found to launch!"); AT_CUDA_CHECK(cudaGetLastError()); - // !!TODO_DRISS: We are throwing this away for now and need to change how its done - // uint64_t -> int64_t bitwise casting as PyTorch don't support uint64_t - // so just fake it as a int64_t - int64_t seed, offset; - if (use_dropout) { - std::memcpy(&seed, &rng_engine_inputs.seed_, sizeof(seed)); - std::memcpy(&offset, &rng_engine_inputs.offset_.val, sizeof(offset)); - } - - return std::make_tuple(res, logsumexp); + return std::make_tuple( + std::move(res), + std::move(logsumexp), + std::move(seed_t), + std::move(offset_t)); #endif TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.") - return std::make_tuple(Tensor{}, Tensor{}); + return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{}); } Tensor triton_scaled_dot_attention(const Tensor& q, const Tensor& k, const Tensor& v, double dropout_p){ @@ -1083,6 +1116,93 @@ bool _chunk_grad_outputs_efficient_attention( return chunk_grad_outputs; } +#ifdef USE_FLASH_ATTENTION +namespace { +/** + * simple kernel that populates a tensor with rand uniform values. + * currently only used for testing purposes, not much attention + * is paid to performance. 
+ * + * problem is partitioned as follows: + * - (batch, head) is given by block coordinates + * - each thread handles a row for a given (batch, head) + */ +template +__global__ void rand_uniform_kernel( + int64_t n_heads, + int64_t n_queries, + int64_t n_keys, + float dropout_prob, + at::PhiloxCudaState rng_engine_inputs, + mask_t* mask_out, + int64_t mask_numel) { + const int64_t batch_id = blockIdx.x; + const int64_t head_id = blockIdx.y; + const int64_t query_idx = threadIdx.x; + + const auto seeds = at::cuda::philox::unpack(rng_engine_inputs); + + const int dropout_seq_start = batch_id * (n_heads * n_queries * n_keys) + + head_id * (n_queries * n_keys); + const int64_t query_start_idx = query_idx * n_keys; + + curandStatePhilox4_32_10_t curand_state; + curand_init( + std::get<0>(seeds), + 0, + std::get<1>(seeds) + dropout_seq_start + query_start_idx, + &curand_state); + + for (int key_start_idx = 0; key_start_idx < n_keys; key_start_idx += 4) { + float4 rand_quad = curand_uniform4(&curand_state); + +#pragma unroll + for (int i = 0; i < 4; ++i) { + const int64_t linear_idx = dropout_seq_start + query_start_idx + key_start_idx + i; + if (linear_idx < mask_numel) { + mask_out[linear_idx] = (&rand_quad.x)[i]; + } + } + } +} +} // namespace +#endif +/** + * fill tensor with random uniform values. only used for testing, not much + * attention is paid to performance + */ +at::Tensor& _fill_mem_eff_dropout_mask_( + Tensor& self, + double dropout_p, + const int64_t seed, + const int64_t offset) { + TORCH_CHECK(self.is_contiguous()); + TORCH_CHECK(self.dtype() == at::ScalarType::Float); + const int64_t batch_sz = self.size(0); + const int64_t n_heads = self.size(1); + const int64_t n_queries = self.size(2); + const int64_t n_keys = self.size(3); +#if defined(USE_FLASH_ATTENTION) + + at::PhiloxCudaState rng_engine_inputs; + rng_engine_inputs = at::PhiloxCudaState(seed, offset); + at::cuda::CUDAGuard device_guard(self.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + rand_uniform_kernel<<>>( + n_heads, + n_queries, + n_keys, + dropout_p, + rng_engine_inputs, + reinterpret_cast(self.data_ptr()), + self.numel()); + + return self; +#endif + TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.") + return self; +} } // namespace native } // namespace at diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index 7ff5dec367f6f4..5f05b8c64bd890 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -112,8 +112,8 @@ _efficient_attention_backward( int64_t max_seqlen_k, const at::Tensor& logsumexp, double dropout_p, // dropout probability - const at::Tensor& rng_seed_tensor, // seed using for generating random numbers for dropout - const at::Tensor& rng_offset_tensor, // offset into random number sequence + const at::Tensor& philox_seed, // seed using for generating random numbers for dropout + const at::Tensor& philox_offset, // offset into random number sequence int64_t custom_mask_type, const c10::optional scale, c10::optional num_splits_key) { @@ -121,11 +121,6 @@ _efficient_attention_backward( if (!grad_out_.defined()) { return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{}); } - // TODO_DRISS utilize these tensor correctly - // Appease the compilier for now.e These values - // will never be used Until we wire up dropout - int64_t rng_seed = *rng_seed_tensor.data_ptr(); - int64_t 
rng_offset = *rng_offset_tensor.data_ptr(); // ndim TORCH_CHECK(query.dim() == grad_out_.dim()); @@ -211,7 +206,22 @@ _efficient_attention_backward( at::Tensor workspace; const bool use_dropout = std::fpclassify(dropout_p) != FP_ZERO; - at::PhiloxCudaState rng_engine_inputs(rng_seed, rng_offset); + + // See Note [Seed and Offset Device] + at::PhiloxCudaState rng_engine_inputs; + if (use_dropout) { + if (at::cuda::currentStreamCaptureStatus() == + at::cuda::CaptureStatus::None) { + rng_engine_inputs = at::PhiloxCudaState( + *philox_seed.data_ptr(), + *philox_offset.data_ptr()); + } else { // dropout + capture + rng_engine_inputs = at::PhiloxCudaState( + philox_seed.data_ptr(), + philox_offset.data_ptr(), + 0); + } + } cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index()); const int computeCapability = p->major * 10 + p->minor; @@ -484,7 +494,6 @@ std::tuple _scaled_dot_product_flash_attenti Tensor k_t = key.transpose(1, 2); Tensor v_t = value.transpose(1, 2); - int64_t Nnz_q{batch_size * max_seqlen_batch_q}; int64_t Nnz_kv{batch_size * max_seqlen_batch_k}; @@ -521,6 +530,7 @@ std::tuple _scaled_dot_product_flash_attenti return std::make_tuple(grad_q, grad_k, grad_v); } + std::tuple _scaled_dot_product_efficient_attention_backward_cuda( const at::Tensor& grad_out_, const at::Tensor& query, @@ -528,8 +538,10 @@ std::tuple _scaled_dot_product_efficient_att const at::Tensor& value, const at::Tensor& out, const at::Tensor& logsumexp, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset, + double dropout_p, bool causal, - bool chunk_grad_outputs, c10::optional scale){ if (!grad_out_.defined()) { return std::make_tuple(Tensor{}, Tensor{}, Tensor{}); @@ -543,17 +555,14 @@ std::tuple _scaled_dot_product_efficient_att Tensor grad_q, grad_k, grad_v, grad_bias; // TODO_DRISS - // These are place holders unitl we add support for dropout and bias + // These are place holders unitl we add support for bias auto bias = c10::nullopt; - Tensor seed_t = at::empty({}, at::dtype(at::kLong)); - Tensor offset_t = at::empty({}, at::dtype(at::kLong)); // Will add with signauter changes for dropout and bias // We are only handiling Dense inputs, but this should be passed // from forward to backward int64_t max_seqlen_q = q_t.size(1); int64_t max_seqlen_k = k_t.size(1); - double dropout_p = 0.0; sdp::CustomMaskType custom_mask_type = causal ? 
sdp::CustomMaskType::CausalFromTopLeft @@ -573,8 +582,8 @@ std::tuple _scaled_dot_product_efficient_att max_seqlen_k, logsumexp, dropout_p, - seed_t, - offset_t, + philox_seed, + philox_offset, static_cast(custom_mask_type), scale, c10::nullopt); // num_split_keys diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h index 50e166d2b03629..2b416dd5a64429 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h @@ -65,6 +65,7 @@ using namespace gemm_kernel_utils; +namespace PyTorchMemEffAttention { namespace { template @@ -1271,8 +1272,7 @@ struct AttentionBackwardKernel { } TORCH_CHECK( kEnableSplitKeys || p.num_splits_key == 1, "SplitKeys is disabled"); - TORCH_CHECK( - p.num_splits_key > 0, "Invalid `num_splits_key` (expected >0)"); + TORCH_CHECK(p.num_splits_key > 0, "Invalid `num_splits_key` (expected >0)"); TORCH_CHECK( p.num_splits_key <= cutlass::ceil_div(p.num_keys, kBlockSizeJ), "Invalid `num_splits_key` (", @@ -1323,6 +1323,7 @@ struct AttentionBackwardKernel { curandStatePhilox4_32_10_t rng_state_init; if (kApplyDropout) { + // See Note [Seed and Offset Device] auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs); // each element of the attention matrix P with shape // (batch_sz, n_heads, n_queries, n_keys) is associated with a single @@ -1338,7 +1339,6 @@ struct AttentionBackwardKernel { std::get<1>(seeds) + p.dropout_batch_head_rng_offset, &rng_state_init); } - CUTLASS_PRAGMA_UNROLL for (; key_start < p.num_keys; key_start += p.num_splits_key_device() * kBlockSizeJ) { @@ -2534,3 +2534,5 @@ __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) template __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) attention_kernel_backward_batched(typename AK::Params params); + +} // namespace PyTorchMemEffAttention \ No newline at end of file diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h index 4eab19aba773cb..06ab1c79c64ad0 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h @@ -54,6 +54,7 @@ using namespace gemm_kernel_utils; +namespace PyTorchMemEffAttention { namespace { template constexpr int getWarpsPerSmFw() { @@ -156,6 +157,7 @@ struct AttentionKernel { int32_t head_dim_value; int32_t num_queries; int32_t num_keys; + int32_t num_keys_absolute; uint8_t custom_mask_type = NoCustomMask; @@ -186,7 +188,8 @@ struct AttentionKernel { unsigned long long dropout_batch_head_rng_offset; float dropout_prob; at::PhiloxCudaState rng_engine_inputs; - + int64_t* extragraph_offset; + int64_t* seed; // Moves pointers to what we should process // Returns "false" if there is no work to do @@ -274,6 +277,9 @@ struct AttentionKernel { if (custom_mask_type == CausalFromBottomRight) { causal_diagonal_offset += num_keys - num_queries; } + // We use num_keys_absolute to index into the rng_state + // We need this index to match between forward and backwards + num_keys_absolute = num_keys; if (custom_mask_type == CausalFromTopLeft || custom_mask_type == CausalFromBottomRight) { // the bottom row of the current block is query_start + kQueriesPerBlock @@ -656,7 +662,16 @@ struct AttentionKernel { 
curandStatePhilox4_32_10_t curand_state_init; if (kSupportsDropout && p.use_dropout) { const auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs); - + if (p.rng_engine_inputs.captured_) { + // See Note [Seed and Offset Device] + // When we are in cuda graph capture mode the seed and offset are stored + // on device We pass in int64_t* seed, and int64_t* offset to act as + // scratch space for storing the rng state during the forward pass and + // saving for backwards. + auto [seed, offset] = seeds; + *p.seed = seed; + *p.extragraph_offset = offset; + } // each element of the attention matrix P with shape // (batch_sz, n_heads, n_queries, n_keys) is associated with a single // offset in RNG sequence. we initialize the RNG state with offset that @@ -862,7 +877,6 @@ struct AttentionKernel { __syncthreads(); - // apply dropout (if applicable) after we've written Pij to smem. // dropout is applied by multiplying each element of Pij by: // - 0 with probability dropout_p @@ -899,7 +913,7 @@ struct AttentionKernel { curandStatePhilox4_32_10_t curand_state = curand_state_init; skipahead( static_cast( - (query_start + thread_i) * p.num_keys + + (query_start + thread_i) * p.num_keys_absolute + (iter_key_start + thread_start_j)), &curand_state); const float dropout_scale = 1.0 / (1.0 - p.dropout_prob); @@ -1269,3 +1283,5 @@ __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) template __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) attention_kernel_batched(typename AK::Params params); + +} // namespace PyTorchMemEffAttention \ No newline at end of file diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB.h index ffffc210bdbcb6..157b04b8a5fff9 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB.h @@ -8,6 +8,7 @@ // This file is auto-generated. See "generate_kernels.py" #pragma once #include +using namespace PyTorchMemEffAttention; // ======== f16 / sm70 ======== __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k128.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k128.cu index ce70d1dfd39c09..1062d04e0ba6be 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k128.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k128.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k128_dropout.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k128_dropout.cu index b1bda738ebb834..0d28c11c045699 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k128_dropout.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k128_dropout.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. 
See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k32.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k32.cu index e20c42f3e11b54..37a5e11f94f25b 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k32.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k32.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k32_dropout.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k32_dropout.cu index 1a3012bc5f8659..bd467e659b064c 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k32_dropout.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k32_dropout.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k64.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k64.cu index eea8ac68042743..6f0d89e1f0c51e 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k64.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k64.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k64_dropout.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k64_dropout.cu index 52be7c916924d3..c0fce90e5c52eb 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k64_dropout.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k64_dropout.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k65536.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k65536.cu index 43084398c7573b..bddb719a35aad3 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k65536.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k65536.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. 
See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k65536_dropout.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k65536_dropout.cu index 3ad57fc2cf2adc..197a833db06534 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k65536_dropout.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k65536_dropout.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k96.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k96.cu index e01b7448f56aef..8b333cded97f17 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k96.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_bf16_aligned_k96.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k128.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k128.cu index dc6991be5ceeb1..f7466d0107b876 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k128.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k128.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k128_dropout.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k128_dropout.cu index c20aacac65fbe2..32c8a832ac824e 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k128_dropout.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k128_dropout.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k32.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k32.cu index f44c7d530527ee..a3b0aa028ddd76 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k32.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k32.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. 
See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k32_dropout.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k32_dropout.cu index d1d22daf2c51e1..7b5033978a166e 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k32_dropout.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k32_dropout.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k64.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k64.cu index 9f7f7df6bffae6..5948e418e3fc65 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k64.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k64.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k64_dropout.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k64_dropout.cu index 12ae95684e99c8..1656037c19da6a 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k64_dropout.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k64_dropout.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k65536.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k65536.cu index 5a082aa45e4a21..7229985438dfbe 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k65536.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k65536.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k65536_dropout.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k65536_dropout.cu index b7d62c877e00a3..2944befa6f6e04 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k65536_dropout.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k65536_dropout.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. 
See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k96.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k96.cu index 46718930a50b31..5bab50aa101937 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k96.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_aligned_k96.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k128.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k128.cu index d33740fa931a9b..4e513594d87e7a 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k128.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k128.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k128_dropout.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k128_dropout.cu index b9104f203f7084..f82db06a789a64 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k128_dropout.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k128_dropout.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k32.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k32.cu index 258f308484ea2e..4b51b045fa1618 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k32.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k32.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k32_dropout.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k32_dropout.cu index ac21091c71d3fb..1cff4d6ee10232 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k32_dropout.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k32_dropout.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. 
See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k64.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k64.cu index 3677e05878728f..45dd860b6d7475 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k64.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k64.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k64_dropout.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k64_dropout.cu index 3c5df60038bf89..40a10272e1e659 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k64_dropout.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k64_dropout.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k65536.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k65536.cu index 51e7cf08d5ce75..87a5e9e59e42d7 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k65536.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k65536.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k65536_dropout.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k65536_dropout.cu index 11a3e06a5b5f4e..5aafa412aa2664 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k65536_dropout.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f16_notaligned_k65536_dropout.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. 
See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k128.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k128.cu index 1558aa09cc350a..3d399d4167a5cf 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k128.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k128.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k128_dropout.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k128_dropout.cu index ab7b672d665409..8d38226fa3a5b3 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k128_dropout.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k128_dropout.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k32.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k32.cu index a9d9afc610303d..1bac00c0926130 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k32.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k32.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k32_dropout.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k32_dropout.cu index 44b15e1f5f7d8d..7250f11f962056 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k32_dropout.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k32_dropout.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k64.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k64.cu index bb9e6060048636..d308ae59339ff5 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k64.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k64.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. 
See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k64_dropout.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k64_dropout.cu index 8a56a1ce6cc528..4db2f760302ff6 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k64_dropout.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k64_dropout.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k65536.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k65536.cu index ce997186e4daf7..facf60a6534fc6 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k65536.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k65536.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k65536_dropout.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k65536_dropout.cu index 99b6b5ec57b1e5..b38b417364613a 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k65536_dropout.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_aligned_k65536_dropout.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k128.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k128.cu index 99ffbbc4a3c3eb..90d3c1831e0051 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k128.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k128.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. 
See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k128_dropout.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k128_dropout.cu index e43fad123a3192..9bb8a6a6c7deb5 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k128_dropout.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k128_dropout.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k32.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k32.cu index 83a0914ff485d6..db1dc62a95f94a 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k32.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k32.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k32_dropout.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k32_dropout.cu index eb16f22c9e7341..d50ba8c5e81d0c 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k32_dropout.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k32_dropout.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k64.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k64.cu index e230f91f128fda..c2b996e2b97413 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k64.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k64.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. 
See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k64_dropout.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k64_dropout.cu index f0519f1e8dd983..890cca641e5f27 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k64_dropout.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k64_dropout.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k65536.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k65536.cu index 24fb9af075ea25..eb6979dc7e9330 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k65536.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k65536.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k65536_dropout.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k65536_dropout.cu index a1d54e26f7400a..eddb78c1b17380 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k65536_dropout.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB_f32_notaligned_k65536_dropout.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionBackwardKernel::kNumThreads, AttentionBackwardKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF.h index c4222a11ef7998..c8e38916501eaf 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF.h @@ -8,6 +8,7 @@ // This file is auto-generated. See "generate_kernels.py" #pragma once #include +using namespace PyTorchMemEffAttention; // ======== bf16 / sm80 ======== __global__ void __launch_bounds__( AttentionKernel::kNumThreads, diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF_bf16_aligned.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF_bf16_aligned.cu index e4415a24379983..0d4a1bd6a35624 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF_bf16_aligned.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF_bf16_aligned.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. 
See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionKernel::kNumThreads, AttentionKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF_f16_aligned.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF_f16_aligned.cu index a1d9086dd33073..2adb226181fe71 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF_f16_aligned.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF_f16_aligned.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionKernel::kNumThreads, AttentionKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF_f16_notaligned.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF_f16_notaligned.cu index 2dcefe61a9d3be..a93fda7d76b2d2 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF_f16_notaligned.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF_f16_notaligned.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionKernel::kNumThreads, AttentionKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF_f32_aligned.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF_f32_aligned.cu index fdaa936f6c15c2..f1a2c1772369db 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF_f32_aligned.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF_f32_aligned.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionKernel::kNumThreads, AttentionKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF_f32_notaligned.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF_f32_notaligned.cu index ca2301db816584..8c13e7560247df 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF_f32_notaligned.cu +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF_f32_notaligned.cu @@ -7,6 +7,7 @@ */ // This file is auto-generated. 
See "generate_kernels.py" #include +using namespace PyTorchMemEffAttention; __global__ void __launch_bounds__( AttentionKernel::kNumThreads, AttentionKernel::kMinBlocksPerSm) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/generate_kernels.py b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/generate_kernels.py index bafca47d00ec4e..3dcbdc35b5511d 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/generate_kernels.py +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/generate_kernels.py @@ -324,6 +324,7 @@ def write_decl_impl( declarations = cpp_file_header + "#pragma once\n" # declarations += f"#ifndef {disable_def}\n" declarations += f"""#include {impl_file}\n""" + declarations += """using namespace PyTorchMemEffAttention;\n""" # Declaration of kernel functions for k in kernels: @@ -365,6 +366,7 @@ def write_decl_impl( impl_cu = cpp_file_header # impl_cu += f"#ifndef {disable_def}\n" impl_cu += f"""#include {impl_file}\n""" + impl_cu += """using namespace PyTorchMemEffAttention;\n""" for k in f_kernels: impl_cu += k.cpp_impl # impl_cu += f"#endif // {disable_def}\n" diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index 9fb7208b15cef4..203553fd44327c 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -114,17 +114,6 @@ bool check_tensor_dtype( return true; } -bool check_for_non_zero_dropout(sdp_params params, bool debug) { - if (params.dropout != 0.0) { - if (debug) { - TORCH_WARN( - "Mem_efficient does not support non_zero dropout. Dropout_p: ", - params.dropout); - } - return false; - } - return true; -} bool try_broadcast_param_size( const c10::SymInt q_size, @@ -595,8 +584,7 @@ bool use_mem_efficient_attention(sdp_params params, bool debug) { check_batch_size_and_num_heads, check_for_attn_mask, check_head_dim_size_mem_efficient, - check_for_seq_len_0_nested_tensor, - check_for_non_zero_dropout); + check_for_seq_len_0_nested_tensor); for (auto& constraint : constraints) { if (!constraint(params, debug)) { return false; diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py index d0a0aabce56634..ec1dd999f916c7 100644 --- a/test/dynamo/test_ctx_manager.py +++ b/test/dynamo/test_ctx_manager.py @@ -6,7 +6,7 @@ import torch._dynamo.test_case import torch._dynamo.testing import torch.onnx.operators -from torch._dynamo.testing import expectedFailureDynamic, same +from torch._dynamo.testing import same from torch.nn import functional as F from torch.testing._internal.common_cuda import ( @@ -291,14 +291,13 @@ def fn(a_float32, b_float32): not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "Can't run fused SDPA on this platform", ) - @expectedFailureDynamic def test_autocast_sdpa(self): class MyModule(torch.nn.Module): def forward(self, query, key, value): with torch.autocast("cpu"): with torch.autocast("cuda", dtype=torch.float32): out = F.scaled_dot_product_attention( - query, key, value, None, 0.5, True + query, key, value, None, 0.0, True ) return out diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 20a189791ab87c..dd32e1a9ef20e6 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -117,6 +117,7 @@ aten::_fft_c2r aten::_fft_c2r.out aten::_fft_r2c aten::_fft_r2c.out 
diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
index 9fb7208b15cef4..203553fd44327c 100644
--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
+++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -114,17 +114,6 @@ bool check_tensor_dtype(
   return true;
 }
 
-bool check_for_non_zero_dropout(sdp_params params, bool debug) {
-  if (params.dropout != 0.0) {
-    if (debug) {
-      TORCH_WARN(
-          "Mem_efficient does not support non_zero dropout. Dropout_p: ",
-          params.dropout);
-    }
-    return false;
-  }
-  return true;
-}
 
 bool try_broadcast_param_size(
     const c10::SymInt q_size,
@@ -595,8 +584,7 @@ bool use_mem_efficient_attention(sdp_params params, bool debug) {
       check_batch_size_and_num_heads,
       check_for_attn_mask,
       check_head_dim_size_mem_efficient,
-      check_for_seq_len_0_nested_tensor,
-      check_for_non_zero_dropout);
+      check_for_seq_len_0_nested_tensor);
   for (auto& constraint : constraints) {
     if (!constraint(params, debug)) {
       return false;
diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py
index d0a0aabce56634..ec1dd999f916c7 100644
--- a/test/dynamo/test_ctx_manager.py
+++ b/test/dynamo/test_ctx_manager.py
@@ -6,7 +6,7 @@
 import torch._dynamo.test_case
 import torch._dynamo.testing
 import torch.onnx.operators
-from torch._dynamo.testing import expectedFailureDynamic, same
+from torch._dynamo.testing import same
 from torch.nn import functional as F
 
 from torch.testing._internal.common_cuda import (
@@ -291,14 +291,13 @@ def fn(a_float32, b_float32):
         not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater,
         "Can't run fused SDPA on this platform",
     )
-    @expectedFailureDynamic
     def test_autocast_sdpa(self):
         class MyModule(torch.nn.Module):
             def forward(self, query, key, value):
                 with torch.autocast("cpu"):
                     with torch.autocast("cuda", dtype=torch.float32):
                         out = F.scaled_dot_product_attention(
-                            query, key, value, None, 0.5, True
+                            query, key, value, None, 0.0, True
                         )
                 return out
 
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index 20a189791ab87c..dd32e1a9ef20e6 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -117,6 +117,7 @@ aten::_fft_c2r
 aten::_fft_c2r.out
 aten::_fft_r2c
 aten::_fft_r2c.out
+aten::_fill_mem_eff_dropout_mask_
 aten::_flash_attention_backward
 aten::_flash_attention_forward
 aten::_foobar
diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
index 1773ea06298ed7..533483475a2e3c 100644
--- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py
+++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
@@ -266,13 +266,14 @@
     ("aten::_upsample_nearest_exact2d_backward", datetime.date(2022, 12, 15)),
     ("aten::_scaled_dot_product_attention", datetime.date(2023, 3, 15)),
     ("aten::_scaled_dot_product_flash_attention", datetime.date(2023, 5, 15)),
-    ("aten::_scaled_dot_product_efficient_attention", datetime.date(2023, 6, 1)),
+    ("aten::_scaled_dot_product_efficient_attention", datetime.date(2023, 7, 1)),
+    ("aten::_scaled_dot_product_efficient_attention_backward", datetime.date(2023, 7, 1)),
     ("aten::_sparse_mask_helper", datetime.date(2023, 3, 15)),
     ("aten::_fused_sdp_choice", datetime.date(2023, 3, 15)),
    ("aten::_flash_attention_forward", datetime.date(2023, 5, 15)),
    ("aten::_flash_attention_backward", datetime.date(2023, 5, 15)),
-    ("aten::_efficient_attention_forward", datetime.date(2023, 6, 1)),
-    ("aten::_efficient_attention_backward", datetime.date(2023, 6, 1)),
+    ("aten::_efficient_attention_forward", datetime.date(2023, 7, 1)),
+    ("aten::_efficient_attention_backward", datetime.date(2023, 7, 1)),
     ("mkldnn::_convolution_pointwise.binary", datetime.date(2022, 12, 15)),
     ("prim::CudaFusionIvalGuard", datetime.date(2023, 2, 1)),
     ("prim::CudaFusionGuard", datetime.date(2023, 2, 1)),
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 184649586d7d28..1d9b997248c071 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -6482,10 +6482,12 @@ def test_scaled_dot_product_efficient_attention(self):
        if self.device == "cpu":
            raise unittest.SkipTest("requires CUDA")
 
+        # The first two values should be the same, attention output
+        # and logsumexp since dropout is not being set
        def fn(q, k, v, compute_log_sumexp):
            return aten._scaled_dot_product_efficient_attention(
                q, k, v, compute_log_sumexp
-            )
+            )[:2]
 
        self.common(
            fn,
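Editor's note: removing `check_for_non_zero_dropout` means the memory-efficient backend is now eligible whenever `dropout_p != 0`, instead of silently falling back. A minimal usage sketch under that assumption (arbitrary shapes and dropout rate; requires a CUDA build where the mem-efficient kernel is available):

```python
import torch
import torch.nn.functional as F

# Arbitrary example shapes for illustration.
q, k, v = (torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.float16) for _ in range(3))

# Force the memory-efficient backend. Before this change the backend-selection
# logic rejected any non-zero dropout_p for this kernel.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.2, is_causal=True)
```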
diff --git a/test/test_transformers.py b/test/test_transformers.py
index c4092c59df6fba..fe800533d0607f 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -14,7 +14,7 @@
 import torch.optim as optim
 from torch.testing._internal.common_dtype import floating_types_and_half
 from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCUDA, onlyCPU
-from typing import List, Tuple, Union
+from typing import List, Tuple, Union, Optional
 from torch.testing._internal.common_nn import NNTestCase
 from torch.testing._internal.common_utils import (
     TEST_FAIRSEQ,
@@ -66,6 +66,30 @@ def get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
     return deviation.max().item()
 
 
+def get_atol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
+    deviation = true_value - computed_value
+    atol = torch.abs(deviation).max().item()
+    return atol
+
+
+def get_tolerances(
+    true_value: torch.Tensor,
+    computed_value: torch.Tensor,
+    fudge_factor: Optional[float] = None,
+) -> Tuple[float, float]:
+    """Returns the absolute and relative tolerances for comparing two tensors."""
+    fudge_factor = fudge_factor if fudge_factor is not None else 1.0
+    atol = get_atol(true_value, computed_value)
+    rtol = get_rtol(true_value, computed_value)
+
+    atol = fudge_factor * max(atol, default_atol[computed_value.dtype])
+    rtol = fudge_factor * max(rtol, default_rtol[computed_value.dtype])
+    # torch.isclose() has weird behavior around see:
+    # https://github.com/pytorch/pytorch/issues/102400
+    if rtol > 1e30:
+        rtol = default_rtol[computed_value.dtype]
+    return atol, rtol
+
 backend_map = {
     SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
     SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
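Editor's note: `get_tolerances` replaces the repeated per-gradient `max(...)` tolerance code. A rough, self-contained illustration of how it is meant to be used when comparing a fused result against a reference; the `default_atol`/`default_rtol` tables live elsewhere in `test_transformers.py`, so the values below are placeholders:

```python
import torch

# Placeholder default tolerance tables for illustration only.
default_atol = {torch.float16: 1e-3, torch.float32: 1e-5}
default_rtol = {torch.float16: 1e-3, torch.float32: 1.3e-6}

def get_tolerances(true_value, computed_value, fudge_factor=None):
    fudge_factor = fudge_factor if fudge_factor is not None else 1.0
    deviation = true_value - computed_value
    atol = torch.abs(deviation).max().item()
    # Relative deviation, guarding against division by exact zeros.
    rtol = (torch.abs(deviation) / torch.abs(true_value).clamp(min=1e-12)).max().item()
    # Never go below the dtype's default tolerance; scale both by the fudge factor.
    atol = fudge_factor * max(atol, default_atol[computed_value.dtype])
    rtol = fudge_factor * max(rtol, default_rtol[computed_value.dtype])
    return atol, rtol

ref = torch.randn(8, 8, dtype=torch.float32)
approx = ref.to(torch.float16).to(torch.float32)  # simulate a lower-precision path
atol, rtol = get_tolerances(ref, approx, fudge_factor=1.5)
torch.testing.assert_close(approx, ref, atol=atol, rtol=rtol)
```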
@@ -815,7 +839,7 @@ def forward(
                                      f"{attn_dim}D_{'causal_' if is_causal else ''}attn_mask" if attn_dim is not None else "no_attn_mask")))
     @parametrize("dropout_p", [0.0, 0.2, 0.5])
-    @sdp_kernel(enable_flash=False)
+    @sdp_kernel(enable_flash=False, enable_mem_efficient=False)
     def test_scaled_dot_product_attention(self, device, input_dim, attn_mask_dim, is_causal, dropout_p):
         def sdp_ref(
                 q,
@@ -1424,6 +1448,14 @@ def _get_block_size(head_dim):
         S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k))
         return S_converted
 
+    def query_key_value_clones(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, dtype: torch.dtype):
+        """ Clones the query, key, and value tensors and moves them to the specified dtype. """
+        query_ref = query.clone().detach().to(dtype).requires_grad_(query.requires_grad)
+        key_ref = key.clone().detach().to(dtype).requires_grad_(key.requires_grad)
+        value_ref = value.clone().detach().to(dtype).requires_grad_(value.requires_grad)
+        return query_ref, key_ref, value_ref
+
+
     @onlyCUDA
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
     @parametrize("type", ["dense", "nested"])
@@ -1828,15 +1860,21 @@ def test_mem_eff_backwards_determinism(self, device):
     @parametrize("seq_len_q", [4, 8, 64, 128, 256, 512, 1024, 2048] if SM80OrLater else [4, 8, 64, 128, 256, 512])
     @parametrize("seq_len_k", [4, 8, 64, 128, 256, 512, 1024, 2048] if SM80OrLater else [4, 8, 64, 128, 256, 512])
     @parametrize("head_dim", [8, 16, 32, 64, 72, 96, 128] if SM80OrLater else [8, 16, 32, 64])
-    @parametrize("is_causal", [True, False])
-    @parametrize("dropout_p", [0.0])  # mem_efficient_attention does not support dropout
+    @parametrize("is_causal", [False, True])
+    @parametrize("dropout_p", [0.0, 0.22])
     @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if SM80OrLater else [torch.float16, torch.float32])
     @parametrize("scale", [None, "l1"])
     def test_mem_efficient_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                        head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
                                                        scale: str):
+        def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, device=device):
+            mask = torch.empty((batch_size, n_heads, q_len, kv_len), device=device, dtype=torch.float32)
+            rand_uniform = torch._fill_mem_eff_dropout_mask_(mask, p, seed, offset)
+            mask = (rand_uniform > p).to(torch.float32)
+            return mask
+        seed = 42
         scale = scale if scale is None else (1 / head_dim)
         n_heads = 4
         query = torch.rand(batch_size, n_heads, seq_len_q, head_dim,
@@ -1847,27 +1885,38 @@ def test_mem_efficient_attention_vs_math_ref_grads(self, device, batch_size: int,
                            device=device, dtype=dtype, requires_grad=True)
 
         # Run the math kernel on low precision references
-        query_ref_lp = query.clone().detach().requires_grad_(True)
-        key_ref_lp = key.clone().detach().requires_grad_(True)
-        value_ref_lp = value.clone().detach().requires_grad_(True)
+        query_ref_lp, key_ref_lp, value_ref_lp = self.query_key_value_clones(query, key, value, dtype=dtype)
 
         higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
-
-        query_ref = query.clone().detach().to(higher_precision_dtype).requires_grad_(True)
-        key_ref = key.clone().detach().to(higher_precision_dtype).requires_grad_(True)
-        value_ref = value.clone().detach().to(higher_precision_dtype).requires_grad_(True)
+        query_ref, key_ref, value_ref = self.query_key_value_clones(query, key, value, dtype=higher_precision_dtype)
 
         # Create real output
         with sdp_kernel(enable_mem_efficient=True, enable_flash=False, enable_math=False):
+            # Set the seed and run the kernel
+            torch.manual_seed(seed)
             out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
 
-        with sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False):
+        if dropout_p == 0.0:
+            with sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False):
+                # High Precision Math Reference
+                out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref,
+                                                         dropout_p=dropout_p, is_causal=is_causal, scale=scale)
+                # Low Precision Math Reference
+                out_lp_ref = F.scaled_dot_product_attention(query_ref_lp, key_ref_lp, value_ref_lp,
+                                                            dropout_p=dropout_p, is_causal=is_causal, scale=scale)
+        else:
+            if seq_len_q > 1024:
+                self.skipTest("Will call _fill_mem_eff_dropout_mask with too many threads!")
+            # Create the dropout_mask
+            torch.manual_seed(seed)
+            dropout_mask = _get_mem_eff_drop_mask(batch_size, n_heads, seq_len_q, seq_len_k, dropout_p, seed, 0, device=device)
             # High Precision Math Reference
-            out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref,
-                                                     dropout_p=dropout_p, is_causal=is_causal, scale=scale)
+            out_ref = torch.ops.aten._scaled_dot_product_attention_math(
+                query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal, scale=scale, dropout_mask=dropout_mask)[0]
             # Low Precision Math Reference
-            out_lp_ref = F.scaled_dot_product_attention(query_ref_lp, key_ref_lp, value_ref_lp,
-                                                        dropout_p=dropout_p, is_causal=is_causal, scale=scale)
+            out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
+                query_ref_lp, key_ref_lp, value_ref_lp, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
+                dropout_mask=dropout_mask)[0]
 
         upstream_grad = torch.rand_like(out, requires_grad=False)
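Editor's note: the pattern the test uses to validate dropout is to rebuild the kernel's dropout pattern on the host with the new, test-only `torch._fill_mem_eff_dropout_mask_` op and feed it to the math reference via `dropout_mask`. A standalone sketch of that comparison, assuming a CUDA build that exposes both private ops and following the test's assumption that `torch.manual_seed(seed)` puts the kernel at Philox state `(seed, 0)` (shapes and the 0.22 rate are arbitrary):

```python
import torch
import torch.nn.functional as F

p, seed = 0.22, 42
q, k, v = (torch.randn(2, 4, 64, 64, device="cuda", dtype=torch.float32) for _ in range(3))

# Run the fused kernel with a known RNG state.
torch.manual_seed(seed)
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=p)

# Rebuild the same dropout pattern: the op fills `mask` with the kernel's uniform
# randoms for (seed, offset); thresholding against p yields the keep-mask.
mask = torch.empty(2, 4, 64, 64, device="cuda", dtype=torch.float32)
rand_uniform = torch._fill_mem_eff_dropout_mask_(mask, p, seed, 0)
dropout_mask = (rand_uniform > p).to(torch.float32)

# Math reference that consumes an explicit dropout_mask instead of sampling its own.
out_ref = torch.ops.aten._scaled_dot_product_attention_math(
    q, k, v, dropout_p=p, dropout_mask=dropout_mask)[0]
```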
@@ -1881,25 +1930,20 @@ def test_mem_efficient_attention_vs_math_ref_grads(self, device, batch_size: int
         # And we use the default rtol for the low precision type.
         # We then provide a fudge factor for gradients respectively to account
         # for the use of the fused kernel rather than the eager implemntation.
-        out_deviation = out_ref - out_lp_ref
-        output_ref_atol = max(torch.abs(out_deviation).max().item(), default_atol[out.dtype])
-        output_ref_rtol = max(get_rtol(out_ref, out_lp_ref), default_rtol[out.dtype])
+        output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)
 
-        grad_q_deviation = query_ref.grad - query_ref_lp.grad
-        grad_q_ref_atol = max(torch.abs(grad_q_deviation).max().item(), default_atol[out.dtype])
-        grad_q_ref_rtol = max(get_rtol(query_ref.grad, query_ref_lp.grad), default_rtol[out.dtype])
+        # Fudge Factor when dropout is enabled
+        dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 1.5
+
+        query_fudge_factor = dropout_fudge_factor
+        grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor)
 
         # TODO: Investigate why grad_k needs larger tolerances
-        grad_k_deviation = key_ref.grad - key_ref_lp.grad
-        fudge_factor = 7 if not isSM86or89Device else 8
-        grad_k_ref_atol = max(fudge_factor * torch.abs(grad_k_deviation).max().item(), fudge_factor * default_atol[out.dtype])
-        grad_k_ref_rtol = max(fudge_factor * get_rtol(key_ref.grad, key_ref_lp.grad), fudge_factor * default_rtol[out.dtype])
+        key_fudge_factor = 8 * dropout_fudge_factor
+        grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor)
 
-        grad_v_deviation = value_ref.grad - value_ref_lp.grad
-        grad_v_ref_atol = max(torch.abs(grad_v_deviation).max().item(), default_atol[out.dtype])
-        grad_v_fudge_factor = 7 if not SM80OrLater and dtype == torch.float16 else 1.0
-        grad_v_ref_rtol = max(grad_v_fudge_factor * get_rtol(value_ref.grad,
-                              value_ref_lp.grad), default_rtol[out.dtype])
+        value_fudge_factor = 7 if not SM80OrLater and dtype == torch.float16 else 1.0
+        grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor)
 
         self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
         self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
@@ -1934,13 +1978,10 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le
                           device=device, dtype=dtype, requires_grad=True)
 
         # Run the math kernel on low precision references
-        query_ref_lp = query.clone().detach().requires_grad_(True)
-        key_ref_lp = key.clone().detach().requires_grad_(True)
-        value_ref_lp = value.clone().detach().requires_grad_(True)
+        query_ref_lp, key_ref_lp, value_ref_lp = self.query_key_value_clones(query, key, value, dtype=dtype)
 
-        query_ref = query.clone().detach().to(torch.float32).requires_grad_(True)
-        key_ref = key.clone().detach().to(torch.float32).requires_grad_(True)
-        value_ref = value.clone().detach().to(torch.float32).requires_grad_(True)
+        higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
+        query_ref, key_ref, value_ref = self.query_key_value_clones(query, key, value, dtype=higher_precision_dtype)
 
         is_dropout = dropout_p > 0.0
 
@@ -1987,22 +2028,15 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le
             out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))
 
         # See [Note] Fused Tolerances above
-        out_deviation = out_ref - out_lp_ref
-        output_ref_atol = max(torch.abs(out_deviation).max().item(), default_atol[out.dtype])
-        output_ref_rtol = max(get_rtol(out_ref, out_lp_ref), default_rtol[out.dtype])
+        output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)
 
         # TODO: Investigate why grad_q needs larger tolerances
-        grad_q_deviation = query_ref.grad - query_ref_lp.grad
-        grad_q_ref_atol = max(4 * torch.abs(grad_q_deviation).max().item(), default_atol[out.dtype])
-        grad_q_ref_rtol = max(get_rtol(query_ref.grad, query_ref_lp.grad), default_rtol[out.dtype])
+        query_fudge_factor = 4
+        grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor)
 
-        grad_k_deviation = key_ref.grad - key_ref_lp.grad
-        grad_k_ref_atol = max(torch.abs(grad_k_deviation).max().item(), default_atol[out.dtype])
-        grad_k_ref_rtol = max(get_rtol(key_ref.grad, key_ref_lp.grad), default_rtol[out.dtype])
+        grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad)
 
-        grad_v_deviation = value_ref.grad - value_ref_lp.grad
-        grad_v_ref_atol = max(torch.abs(grad_v_deviation).max().item(), default_atol[out.dtype])
-        grad_v_ref_rtol = max(get_rtol(value_ref.grad, value_ref_lp.grad), default_rtol[out.dtype])
+        grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad)
 
         self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
         self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
@@ -2012,103 +2046,127 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le
         self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
                          atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
 
+    @onlyCUDA
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater,
                      "Does not support SDPA or pre-SM80 hardware")
     @parametrize("batch_size", [1, 8])
-    @parametrize("seq_len_q", [512, 1024, 2048])
-    @parametrize("seq_len_k", [512, 1024, 2048])
+    @parametrize("seq_len_q", [256, 512, 1024])
+    @parametrize("seq_len_k", [256, 512, 1024])
     @parametrize("head_dim", [32, 64])
     @parametrize("is_causal", [True, False])
     @parametrize("dropout_p", [0.0, 0.22])
     @parametrize("dtype", [torch.float16,])
     @parametrize("scale", [None, "l1"])
-    def test_flash_attention_graph_vs_math_ref_grads(self, batch_size: int, seq_len_q: int, seq_len_k: int,
-                                                     head_dim: int,
-                                                     is_causal: bool,
-                                                     dropout_p: float,
-                                                     dtype: torch.dtype,
-                                                     scale: str):
-
+    @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
+    def test_fused_attention_vs_math_ref_grads_cudagraph(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
+                                                         head_dim: int,
+                                                         is_causal: bool,
+                                                         dropout_p: float,
+                                                         dtype: torch.dtype,
+                                                         scale: str,
+                                                         fused_kernel: SDPBackend):
+        def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, dropout_p, seed, offset, device=device):
+            mask = torch.empty((batch_size, n_heads, q_len, kv_len), device=device, dtype=torch.float32)
+            rand_uniform = torch._fill_mem_eff_dropout_mask_(mask, dropout_p, seed, offset)
+            mask = (rand_uniform > dropout_p).to(torch.float32)
+            return mask
+
+        def get_dropout_mask(output, fused_kernel, batch_size, n_heads, q_len, kv_len, dropout_p, device=device):
+            if fused_kernel == SDPBackend.EFFICIENT_ATTENTION:
+                output_seed, output_offset = output_tuple[2], output_tuple[3]
+                output_seed = output_seed.item()
+                output_offset = output_offset.item()
+                return _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len,
+                                              dropout_p, output_seed, output_offset, device=device)
+            else:
+                dbug_mask = output[-1]
+                query_padding_mask = torch.ones(
+                    1, seq_len_q, device="cuda", dtype=torch.bool)
+                key_padding_mask = torch.ones(
+                    1, seq_len_k, device="cuda", dtype=torch.bool)
+
+                softmax_mask = self.convert_flash_attn_S_to_softmax(
+                    dbug_mask, query_padding_mask, key_padding_mask, head_dim=head_dim, causal=is_causal)
+                dropout_mask = softmax_mask >= 0
+                return dropout_mask
+
+        seed = 42
         scale = scale if scale is None else (1 / head_dim)
         n_heads = 4
         query = torch.rand(batch_size, n_heads, seq_len_q, head_dim,
-                           device="cuda", dtype=dtype, requires_grad=True)
-        key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device="cuda",
+                           device=device, dtype=dtype, requires_grad=True)
+        key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device,
                          dtype=dtype, requires_grad=True)
         value = torch.rand(batch_size, n_heads, seq_len_k, head_dim,
-                           device="cuda", dtype=dtype, requires_grad=True)
+                           device=device, dtype=dtype, requires_grad=True)
+
+        fused_op = (torch.ops.aten._scaled_dot_product_efficient_attention
+                    if fused_kernel == SDPBackend.EFFICIENT_ATTENTION else torch.ops.aten._scaled_dot_product_flash_attention)
 
         # Run the math kernel on low precision references
-        query_ref_lp = query.clone().detach().requires_grad_(True)
-        key_ref_lp = key.clone().detach().requires_grad_(True)
-        value_ref_lp = value.clone().detach().requires_grad_(True)
+        query_ref_lp, key_ref_lp, value_ref_lp = self.query_key_value_clones(query, key, value, dtype=dtype)
 
-        query_ref = query.clone().detach().to(torch.float32).requires_grad_(True)
-        key_ref = key.clone().detach().to(torch.float32).requires_grad_(True)
-        value_ref = value.clone().detach().to(torch.float32).requires_grad_(True)
+        higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
+        query_ref, key_ref, value_ref = self.query_key_value_clones(query, key, value, dtype=higher_precision_dtype)
 
-        is_dropout = dropout_p > 0.0
         # warmup
         s = torch.cuda.Stream()
         s.wait_stream(torch.cuda.current_stream())
+        # Set the global seed before capture
+        torch.manual_seed(seed)
+        kwargs = {"dropout_p": dropout_p, "is_causal": is_causal, "scale": scale}
+        if fused_kernel == SDPBackend.EFFICIENT_ATTENTION:
+            kwargs["compute_log_sumexp"] = True
+        if fused_kernel == SDPBackend.FLASH_ATTENTION:
+            kwargs['return_debug_mask'] = True
         with torch.cuda.stream(s):
-            output_tuple = torch.ops.aten._scaled_dot_product_flash_attention(
-                query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale, return_debug_mask=True)
+            # Create real output
+            output_tuple = fused_op(query, key, value, **kwargs)
+        torch.cuda.current_stream().wait_stream(s)
         out = output_tuple[0]
-        dbug_mask = output_tuple[-1]
         upstream_grad = torch.rand_like(out, requires_grad=False)
         s.wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(s):
             out.backward(upstream_grad)
         for x in (query, key, value):
             x.grad = None
-
         g = torch.cuda.CUDAGraph()
         # Create real output
         with torch.cuda.graph(g):
             tmp = torch.rand_like(query, device=query.device)  # test non-zero intragraph offset
-            output_tuple = torch.ops.aten._scaled_dot_product_flash_attention(
-                query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale, return_debug_mask=True)
+            # Create real output
+            output_tuple = fused_op(query, key, value, **kwargs)
             assert all(not isinstance(o, torch.Tensor) or o.is_cuda for o in output_tuple)
         g.replay()
         out_first = output_tuple[0].clone()
-        dbug_mask_first = output_tuple[-1].clone()
         g.replay()
         out = output_tuple[0]
-        dbug_mask = output_tuple[-1]
-        if not is_dropout:
+        if dropout_p == 0.0:
             self.assertEqual(out_first, out, atol=0, rtol=0)
         else:
             # replays produce different results
             self.assertNotEqual(out_first, out)
 
-            query_padding_mask = torch.ones(
-                1, seq_len_q, device="cuda", dtype=torch.bool)
-            key_padding_mask = torch.ones(
-                1, seq_len_k, device="cuda", dtype=torch.bool)
-
-            softmax_mask = self.convert_flash_attn_S_to_softmax(
-                dbug_mask, query_padding_mask, key_padding_mask, head_dim=head_dim, causal=is_causal)
-            dropout_mask = softmax_mask >= 0
-
-        if not is_dropout:
-            with sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False):
+        with sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False):
+            if dropout_p == 0.0:
                 # High Precision Math Reference
-                out_ref = F.scaled_dot_product_attention(
-                    query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale)
+                out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref,
+                                                         dropout_p=dropout_p, is_causal=is_causal, scale=scale)
                 # Low Precision Math Reference
-                out_lp_ref = F.scaled_dot_product_attention(
-                    query_ref_lp, key_ref_lp, value_ref_lp, is_causal=is_causal, scale=scale)
-        else:
-            # High Precision Math Reference
-            out_ref = torch.ops.aten._scaled_dot_product_attention_math(
-                query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal, scale=scale, dropout_mask=dropout_mask)[0]
-            # Low Precision Math Reference
-            out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
-                query_ref_lp, key_ref_lp, value_ref_lp, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
-                dropout_mask=dropout_mask)[0]
+                out_lp_ref = F.scaled_dot_product_attention(query_ref_lp, key_ref_lp, value_ref_lp,
+                                                            dropout_p=dropout_p, is_causal=is_causal, scale=scale)
+            else:
+                # Create the dropout_mask
+                dropout_mask = get_dropout_mask(output_tuple, fused_kernel, batch_size,
+                                                n_heads, seq_len_q, seq_len_k, dropout_p, device)
+                # High Precision Math Reference
+                out_ref = torch.ops.aten._scaled_dot_product_attention_math(
+                    query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal,
+                    scale=scale, dropout_mask=dropout_mask)[0]
+                # Low Precision Math Reference
+                out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
+                    query_ref_lp, key_ref_lp, value_ref_lp, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
+                    dropout_mask=dropout_mask)[0]
 
-        upstream_grad = torch.rand_like(out, requires_grad=False)
         g1 = torch.cuda.CUDAGraph()
         with torch.cuda.graph(g1):
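Editor's note: the renamed `test_fused_attention_vs_math_ref_grads_cudagraph` exercises both fused kernels under CUDA graph capture: warm up on a side stream, capture the call, then `replay()`; with dropout enabled, consecutive replays advance the Philox offset, so the outputs differ between replays. A trimmed sketch of that capture/replay pattern under the same assumptions (flash-attention variant, arbitrary shapes):

```python
import torch

q, k, v = (torch.randn(2, 4, 256, 64, device="cuda", dtype=torch.float16) for _ in range(3))
kwargs = {"dropout_p": 0.22, "is_causal": False, "return_debug_mask": True}

# Warm up on a side stream so lazy initialization is not recorded into the graph.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    torch.ops.aten._scaled_dot_product_flash_attention(q, k, v, **kwargs)
torch.cuda.current_stream().wait_stream(s)

# Capture a single fused-attention call into a graph, then replay it twice.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    output_tuple = torch.ops.aten._scaled_dot_product_flash_attention(q, k, v, **kwargs)

g.replay()
first = output_tuple[0].clone()
g.replay()
# With dropout_p > 0 the RNG offset advances between replays, so the outputs differ.
assert not torch.equal(first, output_tuple[0])
```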
@@ -2117,23 +2175,26 @@ def test_flash_attention_graph_vs_math_ref_grads(self, batch_size: int, seq_len_
             out_ref.backward(upstream_grad.to(out_ref.dtype))
             out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))
 
-        # # See [Note] Fused Tolerances above
-        out_deviation = out_ref - out_lp_ref
-        output_ref_atol = max(torch.abs(out_deviation).max().item(), default_atol[out.dtype])
-        output_ref_rtol = max(get_rtol(out_ref, out_lp_ref), default_rtol[out.dtype])
+        # [Note] Fused Tolerances
+        # Establish the numerical error between the "true" high precision math output
+        # and the low precision math reference. We use this reference for the atol
+        # And we use the default rtol for the low precision type.
+        # We then provide a fudge factor for gradients respectively to account
+        # for the use of the fused kernel rather than the eager implemntation.
+        output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)
 
-        # # TODO: Investigate why grad_q needs larger tolerances
-        grad_q_deviation = query_ref.grad - query_ref_lp.grad
-        grad_q_ref_atol = max(4 * torch.abs(grad_q_deviation).max().item(), default_atol[out.dtype])
-        grad_q_ref_rtol = max(get_rtol(query_ref.grad, query_ref_lp.grad), default_rtol[out.dtype])
+        # Fudge Factor when dropout is enabled
+        dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 1.5
 
-        grad_k_deviation = key_ref.grad - key_ref_lp.grad
-        grad_k_ref_atol = max(torch.abs(grad_k_deviation).max().item(), default_atol[out.dtype])
-        grad_k_ref_rtol = max(get_rtol(key_ref.grad, key_ref_lp.grad), default_rtol[out.dtype])
+        query_fudge_factor = dropout_fudge_factor
+        grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor)
+
+        # TODO: Investigate why grad_k needs larger tolerances
+        key_fudge_factor = 8 * dropout_fudge_factor
+        grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor)
 
-        grad_v_deviation = value_ref.grad - value_ref_lp.grad
-        grad_v_ref_atol = max(torch.abs(grad_v_deviation).max().item(), default_atol[out.dtype])
-        grad_v_ref_rtol = max(get_rtol(value_ref.grad, value_ref_lp.grad), default_rtol[out.dtype])
+        value_fudge_factor = 7 if not SM80OrLater and dtype == torch.float16 else 1.0
+        grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor)
 
         self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
         self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index f3d4b9be543c06..1549a1b817a457 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -2718,14 +2718,9 @@
   nested_strides: non_differentiable
 
 # Transformers
-- name: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp)
-  output_differentiability: [True, False]
-  query, key, value: _scaled_dot_product_efficient_attention_backward(grad, query, key, value, output, log_sumexp, is_causal, false, scale)
-
-# - name: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None) -> (Tensor, Tensor)
-#   output_differentiability: [True, False]
-#   query, key, value, bias: _efficient_attention_backward(grad, query, key, value, bias, result0, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, key.size(0), result1, custom_mask_type, false, scale)
-# # Returns ouput, softmax_logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, rng_state
+- name: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
+  output_differentiability: [True, False, False, False]
+  query, key, value: _scaled_dot_product_efficient_attention_backward(grad, query, key, value, output, log_sumexp, philox_seed, philox_offset, dropout_p, is_causal, scale)
 
 - name: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor ouput, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
   output_differentiability: [True, False, False, False, False, False, False, False, False]
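Editor's note: with the schema change, the efficient-attention forward takes `dropout_p` and returns the Philox seed/offset tensors alongside the output and logsumexp, and the derivative formula threads those tensors into the backward. A hedged sketch of calling the private op directly (this is an internal ATen op whose signature may change; shapes and the 0.22 rate are arbitrary):

```python
import torch

q, k, v = (torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.float16, requires_grad=True)
           for _ in range(3))

# New signature: (query, key, value, compute_log_sumexp, dropout_p=0.0, is_causal=False, *, scale=None)
out, logsumexp, philox_seed, philox_offset = torch.ops.aten._scaled_dot_product_efficient_attention(
    q, k, v, True, 0.22, False)

# Autograd routes gradients through _scaled_dot_product_efficient_attention_backward,
# which replays the same dropout pattern from (philox_seed, philox_offset).
out.sum().backward()
```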
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 290f47afc976f4..e4295b7811ade5 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -3817,9 +3817,9 @@ def meta__scaled_dot_product_flash(
     else:
         debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)
 
-    # note: device for seed and offset below depends on whether we are
+    # Note [Seed and Offset]: device for seed and offset below depends on whether we are
     # capturing or not, but at the time of tracing we don't know if we
-    # are going to use cudagraphs or not, so we return cpu tensors here
+    # are going to use cudagraphs or not, so we return meta tensors here
     # it's possible we'll need to have some special handling in inductor for sdpa
 
     return (
@@ -3853,8 +3853,8 @@ def meta__scaled_dot_product_flash_backward(
     max_k: int,
     dropout_p: float,
     is_causal: bool,
-    philox_seed: int,
-    philox_offset: int,
+    philox_seed: Tensor,
+    philox_offset: Tensor,
     scale: Optional[float] = None,
 ):
     batch_size = query.size(0)
@@ -3893,6 +3893,7 @@ def meta__scaled_dot_product_efficient(
     key: Tensor,
     value: Tensor,
     compute_log_sumexp: bool,
+    dropout_p=0.0,
     is_causal: bool = False,
     scale: Optional[float] = None,
 ):
@@ -3918,7 +3919,11 @@ def meta__scaled_dot_product_efficient(
 
     res = res.transpose(1, 2)
 
-    return res, logsum_exp
+    # See Note [Seed and Offset]:
+    seed = torch.empty((), dtype=torch.long, device="meta")
+    offset = torch.empty((), dtype=torch.long, device="meta")
+
+    return res, logsum_exp, seed, offset
 
 
 @register_meta(
@@ -3933,40 +3938,40 @@ def meta__scaled_dot_product_efficient_backward(
     value: Tensor,
     out: Tensor,
     logsumexp: Tensor,
+    philox_seed: Tensor,
+    philox_offset: Tensor,
+    dropout_p: float,
     is_causal: bool = False,
-    chunk_grad_outputs=False,
     scale: Optional[float] = None,
 ):
-    grad_out = grad_out.transpose(1, 2)
-    query = query.transpose(1, 2)
-    key = key.transpose(1, 2)
-    value = value.transpose(1, 2)
-
-    B = query.size(0)
-    M = query.size(1)
-    N = key.size(1)
-    nH = query.size(2)
-    K = query.size(3)
+    batch_size = query.size(0)
+    num_heads = query.size(1)
+    max_q = query.size(2)
+    head_dim = query.size(3)
 
-    grad_kv_needs_init = is_causal and N > M
+    max_k = key.size(2)
 
-    grad_q = torch.empty(query.shape, dtype=query.dtype, device=query.device)
-    grad_k = (
-        torch.zeros(key.shape, dtype=key.dtype, device=key.device)
-        if grad_kv_needs_init
-        else torch.empty(key.shape, dtype=key.dtype, device=key.device)
+    grad_q = torch.empty_permuted(
+        (batch_size, num_heads, max_q, head_dim),
+        (0, 2, 1, 3),
+        dtype=query.dtype,
+        device=query.device,
     )
-    grad_v = (
-        torch.zeros(value.shape, dtype=value.dtype, device=value.device)
-        if grad_kv_needs_init
-        else torch.empty(value.shape, dtype=value.dtype, device=value.device)
+    grad_k = torch.empty_permuted(
+        (batch_size, num_heads, max_k, head_dim),
+        (0, 2, 1, 3),
+        dtype=key.dtype,
+        device=key.device,
     )
-    return (
-        grad_q.transpose(1, 2),
-        grad_k.transpose(1, 2),
-        grad_v.transpose(1, 2),
+    grad_v = torch.empty_permuted(
+        (batch_size, num_heads, max_k, head_dim),
+        (0, 2, 1, 3),
+        dtype=value.dtype,
+        device=value.device,
     )
+    return grad_q, grad_k, grad_v
+
 
 @register_meta([aten.scatter_reduce.two, aten.scatter_reduce.two_out])
 @out_wrapper()
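Editor's note: the rewritten meta backward allocates the gradients with `torch.empty_permuted` so fake/meta tensors report the same "transposed" memory layout the CUDA kernel actually produces: logical shape `(B, num_heads, seq, head_dim)` laid out contiguously in `(B, seq, num_heads, head_dim)` order. A small illustration with arbitrary sizes:

```python
import torch

batch_size, num_heads, max_q, head_dim = 2, 4, 128, 64

# Logical shape is (B, H, M, K); the physical layout (0, 2, 1, 3) makes memory
# contiguous in (B, M, H, K) order, matching the kernel's output buffers.
grad_q = torch.empty_permuted(
    (batch_size, num_heads, max_q, head_dim),
    (0, 2, 1, 3),
    dtype=torch.float16,
)

print(grad_q.shape)     # torch.Size([2, 4, 128, 64])
print(grad_q.stride())  # strides reflect the (B, M, H, K) physical ordering
assert grad_q.transpose(1, 2).is_contiguous()
```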
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index d0d93d2dbe7323..8c13eac50c7df2 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -27,8 +27,7 @@
     toleranceOverride, tol)
 from torch.testing._internal.common_cuda import (
     SM53OrLater, SM60OrLater, with_tf32_off, TEST_CUDNN,
-    _get_torch_cuda_version, _get_torch_rocm_version, PLATFORM_SUPPORTS_FUSED_SDPA,
-    SM80OrLater
+    _get_torch_cuda_version, _get_torch_rocm_version, PLATFORM_SUPPORTS_FUSED_SDPA
 )
 from torch.testing._internal.common_utils import (
     make_fullrank_matrices_with_distinct_singular_values,
@@ -13261,9 +13260,11 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
         DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
         # OpInfo was implemented with a lambda
         DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
-        # See [Note] SDPA_flash's meta function returns incorrect Philox seed and offset
+        # See [Note] SDPA returns Philox Offset and Seed as tensors that will live on CPU when not in cuda graph capture
         DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_crossref_backward_amp',
-                     device_type='cuda', dtypes=(torch.float32,), active_if=PLATFORM_SUPPORTS_FUSED_SDPA and SM80OrLater),
+                     device_type='cuda', dtypes=(torch.float32,), active_if=PLATFORM_SUPPORTS_FUSED_SDPA),
+        DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_crossref_backward_no_amp',
+                     device_type='cuda', dtypes=(torch.float32,), active_if=PLATFORM_SUPPORTS_FUSED_SDPA),
         # TODO Need to understand what this is testing and why it doesn't work
         DecorateInfo(unittest.skip("Skipped"), 'TestDecomp', 'test_comprehensive'),
         DecorateInfo(unittest.skip('output is non-deterministic (when dropout_p > 0)'), 'TestCommon', 'test_compare_cpu'),
diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py
index 8f44b9caf57008..1d0805697703e4 100644
--- a/torchgen/gen_functionalization_type.py
+++ b/torchgen/gen_functionalization_type.py
@@ -59,6 +59,8 @@
         # See Note [resize_ in Functionalization]
         "resize_",
         "resize_as_",
+        # This function is used as for testing purposes only.
+        "_fill_mem_eff_dropout_mask_",
     ]
 )