REDO of dropout support for mem eff pytorch#102038 (pytorch#103704)
This is a new PR with the changes from pytorch#102038 + pytorch#103201, plus namespacing changes to fix a bug.

# Summary
This PR builds off of:
- pytorch#101847
- pytorch#100583

It specifically adds dropout support to the memory-efficient attention kernel. Three main changes were made in the process (a usage sketch follows the list):
- Update SDPA dispatching to allow inputs that require grad to be sent to the efficient attention kernel
- Update how memory-efficient attention passes the RNG state from forward to backward, in order to enable CUDA graph support
- Fix a kernel bug that produced incorrect gradients for num_keys > 64 when dropout and causal masking were both set (facebookresearch/xformers#755)
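
As a rough illustration of what the dispatch change enables (this snippet is not part of the PR, and the shapes and dropout probability are arbitrary), a dropout-enabled SDPA call with grad-requiring inputs can now be routed to the memory-efficient backend; `torch.backends.cuda.sdp_kernel` was the backend-selection context manager available around the time of this commit:

```python
import torch
import torch.nn.functional as F

def rand_qkv():
    # batch=2, heads=8, seq_len=128, head_dim=64; illustrative shapes only
    return torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16, requires_grad=True)

q, k, v = rand_qkv(), rand_qkv(), rand_qkv()

# Restrict dispatch to the memory-efficient kernel so the new dropout path is exercised.
with torch.backends.cuda.sdp_kernel(
    enable_flash=False, enable_math=False, enable_mem_efficient=True
):
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.1, is_causal=True)

# With this PR, gradients flow through the dropout-enabled memory-efficient kernel.
out.sum().backward()
```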

Pull Request resolved: pytorch#103704
Approved by: https://github.com/cpuhrsch
drisspg authored and pytorchmergebot committed Jun 26, 2023
1 parent bfa08a1 commit 4a008d2
Showing 68 changed files with 506 additions and 240 deletions.
aten/src/ATen/cuda/detail/UnpackRaw.cuh: 2 changes (1 addition, 1 deletion)

@@ -15,7 +15,7 @@ namespace philox {
 // Easiest thing that comes to mind is, define a __device__ unpack helper here, in ATen/cuda.
 //
 // The raw definition lives in its own file so jit codegen can easily copy it.
-__device__ __forceinline__ std::tuple<uint64_t, uint64_t>
+__host__ __device__ __forceinline__ std::tuple<uint64_t, uint64_t>
 unpack(at::PhiloxCudaState arg) {
   if (arg.captured_) {
     // static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long".
aten/src/ATen/native/native_functions.yaml: 17 changes (13 additions, 4 deletions)

@@ -14176,14 +14176,17 @@
   dispatch:
     CUDA: _scaled_dot_product_flash_attention_backward_cuda

-- func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp)
+- func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
   dispatch:
     CUDA: _scaled_dot_product_efficient_attention_cuda
     NestedTensorCUDA: _scaled_dot_product_efficient_attention_nestedtensor_cuda
+  tags: nondeterministic_seeded

-- func: _scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, bool is_causal=False, bool chunk_grad_outputs=False, *, float? scale=None) -> (Tensor, Tensor, Tensor)
+- func: _scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor, Tensor)
+  device_check: NoCheck
   dispatch:
     CUDA: _scaled_dot_product_efficient_attention_backward_cuda
+  tags: nondeterministic_seeded

 # THIS FUNCTION iS DEPRECATED AND SHOULD BE REMOVED
 - func: _chunk_grad_outputs_efficient_attention(Tensor query, Tensor key, Tensor value, bool is_causal=False) -> bool
@@ -14203,13 +14206,13 @@
     CUDA: _flash_attention_backward

 # Returns ouput, logsumexp if compute_logsumexp
-- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None) -> (Tensor output, Tensor logsumexp)
+- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset)
   variants: function
   dispatch:
     CUDA: _efficient_attention_forward
   tags: nondeterministic_seeded

-- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int max_seqlen_k, int max_seqlen_q, Tensor logsumexp, float dropout_p, Tensor rng_seed, Tensor rng_offset, int custom_mask_type, *, float? scale=None, int? num_splits_key=None) -> (Tensor, Tensor, Tensor, Tensor)
+- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int max_seqlen_k, int max_seqlen_q, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, *, float? scale=None, int? num_splits_key=None) -> (Tensor, Tensor, Tensor, Tensor)
   device_check: NoCheck
   variants: function
   dispatch:
@@ -14219,7 +14222,13 @@
   variants: function
   dispatch:
     CUDA: triton_scaled_dot_attention
   tags: nondeterministic_seeded
   autogen: _triton_scaled_dot_attention.out

+- func: _fill_mem_eff_dropout_mask_(Tensor(a!) self, float dropout_p, int seed, int offset) -> Tensor(a!)
+  variants: function
+  dispatch:
+    CUDA: _fill_mem_eff_dropout_mask_
+  tags: nondeterministic_seeded
+
 - func: _triton_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None) -> Tensor
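
For orientation (not part of the diff), a hedged sketch of how the updated `_scaled_dot_product_efficient_attention` schema reads from Python. This is an internal ATen op, so calling it through `torch.ops.aten` is shown purely to illustrate the new `dropout_p` argument and the extra `philox_seed`/`philox_offset` outputs; the shapes and dropout probability are arbitrary assumptions:

```python
import torch

q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

# New schema: dropout_p sits before is_causal, and the op now also returns the
# philox seed/offset tensors that let the backward pass regenerate the same
# dropout mask (including under CUDA graph capture).
out, log_sumexp, philox_seed, philox_offset = (
    torch.ops.aten._scaled_dot_product_efficient_attention(
        q, k, v, compute_log_sumexp=True, dropout_p=0.1, is_causal=False
    )
)
```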
Nested-tensor CUDA efficient-attention forward (file name not preserved in this extract):

@@ -738,12 +738,13 @@ _scaled_dot_product_flash_attention_nestedtensor_cuda(
       debug_attn_mask);
 }

-std::tuple<Tensor, Tensor>
+std::tuple<Tensor, Tensor, Tensor, Tensor>
 _scaled_dot_product_efficient_attention_nestedtensor_cuda(
     const Tensor& query,
     const Tensor& key,
     const Tensor& value,
     bool compute_log_sumexp,
+    double dropout_p,
     bool is_causal,
     c10::optional<double> scale) {
   Tensor query_buffer_reshaped, key_buffer_reshaped, value_buffer_reshaped,
@@ -763,23 +764,23 @@ _scaled_dot_product_efficient_attention_nestedtensor_cuda(
       ? sdp::CustomMaskType::CausalFromTopLeft
       : sdp::CustomMaskType::NoCustomMask;

-  Tensor attention, log_sumexp;
-  std::tie(attention, log_sumexp) = at::_efficient_attention_forward(
+  // See Note [Seed and Offset] for description of seed and offset
+  auto [attention, log_sumexp, seed, offset] = at::_efficient_attention_forward(
       query_buffer_reshaped.unsqueeze(0),
       key_buffer_reshaped.unsqueeze(0),
       value_buffer_reshaped.unsqueeze(0),
       c10::nullopt,
       cumulative_sequence_length_q,
       cumulative_sequence_length_kv,
       max_seqlen_batch_q,
-      0.0 /*dropout_p*/,
+      dropout_p,
       static_cast<int64_t>(custom_mask_type),
       compute_log_sumexp,
       scale);

   // Reshape output to convert nnz to batch_size and seq_len
   attention = wrap_buffer(attention.view(-1), output_shape).transpose(1, 2);
-  return std::make_tuple(std::move(attention), std::move(log_sumexp));
+  return std::make_tuple(std::move(attention), std::move(log_sumexp), std::move(seed), std::move(offset));
 }

 } // namespace native
aten/src/ATen/native/transformers/attention.cpp: 9 changes (6 additions, 3 deletions)

@@ -615,7 +615,7 @@ Tensor scaled_dot_product_attention(
           (query_.requires_grad() || key.requires_grad() ||
            value.requires_grad());
       auto out_and_lse = at::_scaled_dot_product_efficient_attention(
-          query_, key, value, compute_logsumexp, is_causal, scale);
+          query_, key, value, compute_logsumexp, dropout_p, is_causal, scale);
       return std::get<0>(out_and_lse);
     }
     case sdp::SDPBackend::math:
@@ -682,9 +682,12 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
   attn = at::softmax(attn, -1);
   if (dropout_p > 0.0) {
     if (dropout_mask.has_value()) {
-      auto attn_dropout_masked = attn.masked_fill(dropout_mask->logical_not(), 0.0);
+      // In order to validate the correctness of the fused kernels, we need to
+      // use the same dropout mask in order to compare the results.
+      TORCH_WARN_ONCE("Dropout mask should only be used for testing purposes.");
+      attn = attn.masked_fill(dropout_mask->logical_not(), 0.0);
       auto dropout_scaling = 1.0 / (1 - dropout_p);
-      return std::make_tuple(at::matmul(attn_dropout_masked, value * dropout_scaling), attn);
+      return std::make_tuple(at::matmul(attn, value * dropout_scaling), attn);
     } else {
       attn = at::dropout(attn, dropout_p, true);
     }
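
The `dropout_mask` branch above is the math-reference path used to validate the fused kernels: dropped attention weights are zeroed with an explicit boolean keep-mask and the survivors are rescaled by 1 / (1 - dropout_p). A short Python sketch of the same computation (function and tensor names here are illustrative, not taken from the PR):

```python
import torch

def math_sdpa_dropout_step(attn, value, dropout_mask, dropout_p):
    """Reference dropout step: zero dropped weights, rescale survivors by 1/(1-p).

    attn:         post-softmax attention weights, e.g. [B, H, S_q, S_k]
    value:        value tensor,                   e.g. [B, H, S_k, D]
    dropout_mask: boolean keep-mask broadcastable to attn's shape
    """
    attn = attn.masked_fill(dropout_mask.logical_not(), 0.0)
    scaling = 1.0 / (1.0 - dropout_p)
    return attn @ (value * scaling)
```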