[Feature] ISS-60: Implement Self Extend #431

Open
wants to merge 11 commits into base: dev
4 changes: 4 additions & 0 deletions gemma/configs.h
@@ -127,6 +127,10 @@ struct LayerConfig {
size_t conv1d_width = 0;
bool ff_biases = false;
bool softmax_attn_output_biases = false;
bool self_extend = false;
size_t ngb_size = 0;
Member:
Is this an n-gram block? Maybe expand it to block_size for more clarity? We can also move these three new fields into their own section (just a newline before them) with a // Self-extension comment.

Author:
@jan-wassenberg Sorry, I didn't understand. I put these fields here because LayerConfig gets accessed during the attention mechanism.

Member:
Sorry to be unclear, I was suggesting considering renaming this to ngram_block_size.
And it would be good to add a newline plus "// Self-extension" comment for visual separation from the other fields in this struct.

Author:
Oh, ngb is short for neighbour; I see the point of confusion now.

Member:
Ah :) Generally it's good to write out words.

Author:
Understood, I'll make the change.

size_t grp_size = 1;

PostNormType post_norm = PostNormType::None;
LayerAttentionType type = LayerAttentionType::kGemma;
ActivationType activation = ActivationType::Gelu;
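Following the naming discussion above, a minimal sketch of how the self-extension fields in LayerConfig might look after the agreed rename and regrouping; the written-out names (neighbor_size, group_size) and the comments are assumptions, not the final API:

struct LayerConfig {
  // ... existing fields (conv1d_width, ff_biases, ...) ...

  // Self-extension (hypothetical post-review naming).
  bool self_extend = false;    // enable Self-Extend position remapping
  size_t neighbor_size = 0;    // was ngb_size: width of the ungrouped neighbour window
  size_t group_size = 1;       // was grp_size: positions merged into one grouped position

  // ... remaining fields (post_norm, type, activation, ...) ...
};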
29 changes: 25 additions & 4 deletions gemma/gemma-inl.h
@@ -305,28 +305,39 @@ class GemmaAttention {
}
}

// Self-extension
const hwy::Divisor div_grp_size(
static_cast<uint32_t>(layer_config_.grp_size));
// Apply positional encodings for K (and copy KV to cache if MHA).
pool_.Run(0, layer_config_.kv_heads * num_interleaved,
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t head = task % layer_config_.kv_heads;
const size_t interleaved_idx = task / layer_config_.kv_heads;
const size_t query_idx = interleaved_idx % num_queries_;
const size_t batch_idx = interleaved_idx / num_queries_;
const size_t pos = queries_pos_[query_idx] + batch_idx;
size_t pos = queries_pos_[query_idx] + batch_idx;
const size_t cache_pos = div_seq_len_.Remainder(pos);
const size_t kv_offset = cache_pos * cache_pos_size_ +
layer_ * cache_layer_size_ +
head * layer_config_.qkv_dim * 2;
KVCache& kv_cache = kv_caches_[query_idx];

const size_t ngb_size = layer_config_.ngb_size;
const bool self_extend = layer_config_.self_extend;

float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
const float* HWY_RESTRICT mha_kv =
activations_.q.Batch(interleaved_idx) + head * q_stride_ +
layer_config_.qkv_dim;

// With self-extend, keys beyond the neighbour window are embedded
// at their grouped (downscaled) position.
if (self_extend && pos > ngb_size) {
pos = div_grp_size.Divide(pos);
}
// Copy from `q` if MHA, or apply in-place.
PositionalEncodingQK(is_mha_ ? mha_kv : kv, pos, layer_, 1.0f,
kv);

// If MHA, also copy V into KVCache.
if (is_mha_) {
hwy::CopyBytes(mha_kv + layer_config_.qkv_dim,
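As a standalone illustration of the key-side remapping in the hunk above (the helper name and free-function form are ours, not part of the PR): keys whose absolute position exceeds the neighbour window are placed on a coarser grid by integer division with the group size. The PR performs the division with a precomputed hwy::Divisor; plain integer division is shown here for clarity.

#include <cstddef>

// Sketch of the Self-Extend key-position mapping applied before RoPE above.
inline size_t SelfExtendKeyPos(size_t pos, size_t ngb_size, size_t grp_size) {
  // Inside the neighbour window, keys keep their exact position.
  // Beyond it, they share a grouped position with grp_size - 1 neighbours.
  return (pos > ngb_size) ? pos / grp_size : pos;
}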
@@ -411,12 +422,22 @@ class GemmaAttention {
const size_t batch_idx = interleaved_idx / num_queries_;
const size_t head_offset =
(head / kHeadGroups) * layer_config_.qkv_dim * 2;

const size_t grp_size = layer_config_.grp_size;
const size_t ngb_size = layer_config_.ngb_size;
const bool self_extend = layer_config_.self_extend;
KVCache& kv_cache = kv_caches_[query_idx];
float* HWY_RESTRICT q =
activations_.q.Batch(interleaved_idx) + head * q_stride_;

// Apply rope and scaling to Q.
const size_t pos = queries_pos_[query_idx] + batch_idx;
size_t pos = queries_pos_[query_idx] + batch_idx;
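// Self-extend: queries beyond the neighbour window use the grouped
// position plus a constant shift, so the grouped range continues
// from the end of the neighbour window rather than from zero.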
if (self_extend && pos > ngb_size) {
const size_t grp_pos = pos / grp_size;
const size_t shift = ngb_size - ngb_size / grp_size;
const size_t shifted_grouped_pos = grp_pos + shift;
pos = shifted_grouped_pos;
}
PositionalEncodingQK(q, pos, layer_, query_scale, q);

const size_t start_pos = StartPos(pos, layer_);
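The query-side mapping above can likewise be sketched as a standalone helper (again illustrative, not part of the PR). The shift ngb_size - ngb_size / grp_size lines the grouped query positions up with the end of the neighbour window, so they continue where the ungrouped range stops. For example, assuming ngb_size = 512 and grp_size = 4, a query at position 1000 maps to 1000 / 4 + (512 - 128) = 634.

#include <cstddef>

// Sketch of the Self-Extend query-position mapping applied before RoPE above.
inline size_t SelfExtendQueryPos(size_t pos, size_t ngb_size, size_t grp_size) {
  if (pos <= ngb_size) return pos;                      // neighbour window: unchanged
  const size_t grp_pos = pos / grp_size;                // grouped position
  const size_t shift = ngb_size - ngb_size / grp_size;  // align with window end
  return grp_pos + shift;
}

// Example: SelfExtendQueryPos(1000, /*ngb_size=*/512, /*grp_size=*/4) == 634.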
@@ -1466,7 +1487,7 @@ void GenerateBatchT(const ModelWeightsStorage& model,
qbatch_size);
QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
qbatch_size);
qbatch_size);
const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
GenerateT<T>(model, activations, runtime_config, qbatch_prompts, qbatch_pos,
qbatch_prefix_end, qbatch_start, qbatch_kv, timing_info);