Optimize attention kernel #1228

Merged (46 commits, Nov 15, 2023)

Commits
f4ca263  change layout (xinhaoc, Oct 11, 2023)
c868e8f  main change (xinhaoc, Oct 12, 2023)
38be2c0  fix (xinhaoc, Oct 12, 2023)
c076719  change spec&tree kernel (xinhaoc, Oct 12, 2023)
d71cf29  fix tp (xinhaoc, Oct 12, 2023)
5ee8587  fix (xinhaoc, Oct 12, 2023)
ffa5168  fix multi requests (xinhaoc, Oct 12, 2023)
552d49f  replicate key&value (xinhaoc, Oct 13, 2023)
23f5891  ci (xinhaoc, Oct 14, 2023)
e31249a  cleanup&hip (xinhaoc, Oct 14, 2023)
c9b4ed3  more fix. (xinhaoc, Oct 14, 2023)
11edf85  ci (xinhaoc, Oct 14, 2023)
6e70280  Merge branch 'flexflow:inference' into optimize_attn_v2 (xinhaoc, Oct 15, 2023)
8ceaf41  new kernel (xinhaoc, Oct 19, 2023)
90ebe10  draft (xinhaoc, Oct 19, 2023)
64710ed  fix (xinhaoc, Oct 20, 2023)
2613ffe  align inc (xinhaoc, Oct 22, 2023)
509a86e  Merge branch 'inference' into optimize_attn_v2 (xinhaoc, Oct 23, 2023)
804c580  fix (xinhaoc, Oct 23, 2023)
c6011f9  . (xinhaoc, Oct 23, 2023)
add285b  multi batch (xinhaoc, Oct 23, 2023)
f572737  fix (xinhaoc, Oct 24, 2023)
0aec1b6  fix (xinhaoc, Oct 24, 2023)
5d2dbbd  fix different thread per key case (xinhaoc, Oct 24, 2023)
20b2b2b  Merge branch 'inference' into optimize_attn_v2 (xinhaoc, Oct 24, 2023)
4dbd31b  fix (xinhaoc, Oct 24, 2023)
a53ff87  . (xinhaoc, Oct 24, 2023)
02e4fad  . (xinhaoc, Oct 29, 2023)
90caa1a  . (xinhaoc, Oct 29, 2023)
23bf953  fix. (xinhaoc, Oct 30, 2023)
09a4fb2  fix. (xinhaoc, Oct 30, 2023)
abbc8eb  . (xinhaoc, Oct 31, 2023)
7b05643  . (xinhaoc, Oct 31, 2023)
305b681  .. (xinhaoc, Nov 1, 2023)
871df6e  opt (xinhaoc, Nov 2, 2023)
87a294e  fix half (xinhaoc, Nov 5, 2023)
a9e75e5  fix. (xinhaoc, Nov 5, 2023)
dd66b85  Merge branch 'inference' into optimize_attn_v2 (xinhaoc, Nov 5, 2023)
59d3dba  Merge branch 'inference' into optimize_attn_v2 (xinhaoc, Nov 6, 2023)
f8ec408  . (xinhaoc, Nov 7, 2023)
251a99c  Merge branch 'optimize_attn_v2' of https://github.com/xinhaoc/FlexFlo… (xinhaoc, Nov 7, 2023)
e112b25  hip (xinhaoc, Nov 8, 2023)
0aab1e3  clean (xinhaoc, Nov 8, 2023)
e6a6b0e  Merge branch 'inference' into optimize_attn_v2 (jiazhihao, Nov 10, 2023)
99bc9b5  Merge branch 'inference' into optimize_attn_v2 (xinhaoc, Nov 10, 2023)
1780965  Merge pull request #1227 from xinhaoc/optimize_attn_v2 (xinhaoc, Nov 11, 2023)

Changes from all commits
include/flexflow/batch_config.h (3 additions, 0 deletions)

@@ -59,6 +59,9 @@ class BatchConfig {

   // Set by update
   int num_tokens;
+  // Number of generation-phase (incremental-decoding) tokens in this batch;
+  // num_tokens - num_prompt_tokens = num_generation_tokens.
+  int num_generation_tokens;

   struct PerRequestInfo {
     int first_token_depth_in_request;
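The new counter records how many tokens in the batch are in the generation (incremental-decoding) phase rather than the prompt phase. A minimal sketch of how a batch builder could populate it, assuming a hypothetical helper and a local num_prompt_tokens count, neither of which is part of this PR:

// Sketch only: fill in the new BatchConfig field when assembling a batch.
// num_prompt_tokens is a hypothetical count of prompt-phase tokens.
void set_generation_token_count(BatchConfig &bc, int num_prompt_tokens) {
  // Everything after the prompt-phase tokens belongs to the generation phase,
  // matching the header comment: num_tokens - num_prompt_tokens = num_generation_tokens.
  bc.num_generation_tokens = bc.num_tokens - num_prompt_tokens;
}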
include/flexflow/ops/inc_multihead_self_attention.h (6 additions, 4 deletions)

@@ -29,7 +29,7 @@ class IncMultiHeadSelfAttention : public Op {

   IncMultiHeadSelfAttention(FFModel &model,
                             LayerID const &layer_guid,
-                            const ParallelTensor _input,
+                            ParallelTensor const _input,
                             int _embed_dim,
                             int _num_q_heads,
                             int _num_kv_heads,
@@ -50,8 +50,8 @@ class IncMultiHeadSelfAttention : public Op {
                             int _tensor_parallelism_degree,
                             char const *name);
   IncMultiHeadSelfAttention(FFModel &model,
-                            const ParallelTensor _input,
-                            const ParallelTensor _weight,
+                            ParallelTensor const _input,
+                            ParallelTensor const _weight,
                             int _embed_dim,
                             int _num_q_heads,
                             int _num_kv_heads,
@@ -73,7 +73,7 @@ class IncMultiHeadSelfAttention : public Op {
                             char const *name);
   IncMultiHeadSelfAttention(FFModel &model,
                             IncMultiHeadSelfAttention const &other,
-                            const ParallelTensor input,
+                            ParallelTensor const input,
                             bool allocate_weights);
   IncMultiHeadSelfAttention(FFModel &model,
                             Params const &params,
@@ -192,9 +192,11 @@ class IncMultiHeadSelfAttentionMeta : public OpMeta {
   void *attn_heads;
   char *quantized_weight_ptr;
   BatchConfig::PerTokenInfo *token_infos;
+  BatchConfig::PerRequestInfo *request_infos;
   DataType quantization_type;
   bool offload;
 #if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
+  // cudaStream_t task_local_stream;
   cudnnTensorDescriptor_t qk_tensor;
   cuFloatComplex *complex_input;
 #elif defined(FF_USE_HIP_ROCM)
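Besides the const-placement cleanup on the constructor signatures, the substantive change here is the new device-side request_infos buffer, which gives the attention kernels per-request metadata alongside the existing per-token metadata. A hedged sketch of how that buffer might be refreshed each step, mirroring how token_infos is typically uploaded; the requestsInfo host array, the MAX_NUM_REQUESTS constant, and the helper name are assumptions, not code from this PR:

// Sketch only: copy host-side per-request metadata to the device buffer that the
// attention kernels read. Names other than request_infos are illustrative.
void upload_request_infos(IncMultiHeadSelfAttentionMeta *m,
                          BatchConfig const *bc,
                          cudaStream_t stream) {
  cudaMemcpyAsync(m->request_infos,
                  bc->requestsInfo, // assumed host-side PerRequestInfo array
                  sizeof(BatchConfig::PerRequestInfo) * BatchConfig::MAX_NUM_REQUESTS,
                  cudaMemcpyHostToDevice,
                  stream);
}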
@@ -14,6 +14,22 @@ namespace FlexFlow {
 namespace Kernels {
 namespace IncMultiHeadAttention {

+template <typename DT>
+void compute_attention_kernel_generation(IncMultiHeadSelfAttentionMeta const *m,
+                                         BatchConfig const *bc,
+                                         DT *output_ptr,
+                                         ffStream_t stream);
+
+template <typename DT>
+void compute_o_prod_bias(IncMultiHeadSelfAttentionMeta const *m,
+                         BatchConfig const *bc,
+                         int shard_id,
+                         DT *output_ptr,
+                         DT const *weight_ptr,
+                         DT const *bias_ptr,
+                         int num_tokens,
+                         ffStream_t stream);
+
 template <typename DT>
 __global__ void apply_position_bias_qkprd(DT *input_ptr,
                                           int num_tokens,
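These two declarations split the decoding forward pass into a fused attention kernel for generation-phase tokens and a separate output projection. A hedged sketch of how a caller might dispatch to them, assuming a hypothetical inference_kernel wrapper around the BatchConfig fields shown earlier; this is not the PR's actual control flow:

// Sketch only: possible dispatch around the new declarations.
template <typename DT>
void inference_kernel(IncMultiHeadSelfAttentionMeta const *m,
                      BatchConfig const *bc,
                      int shard_id,
                      DT const *weight_ptr,
                      DT const *bias_ptr,
                      DT *output_ptr,
                      ffStream_t stream) {
  // ... QKV projection and KV-cache update for prompt-phase tokens omitted ...
  if (bc->num_generation_tokens > 0) {
    // Decode-phase tokens (one query position per request) take the fused
    // generation kernel added by this PR.
    compute_attention_kernel_generation<DT>(m, bc, output_ptr, stream);
  }
  // Project attention heads back to the model dimension and add the bias.
  compute_o_prod_bias<DT>(
      m, bc, shard_id, output_ptr, weight_ptr, bias_ptr, bc->num_tokens, stream);
}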