From 8cf3966be451068c7038db0ee39ca6d2891ac95b Mon Sep 17 00:00:00 2001
From: Nanubala Gnana Sai <45007169+jonpsy@users.noreply.github.com>
Date: Fri, 18 Oct 2024 14:25:28 +0530
Subject: [PATCH 1/7] compile success: set default self extend values in noSSM
 and griffin

---
 gemma/configs.h   | 260 ++++++++++++++++++++++++++++++++++++++++++++++
 gemma/gemma-inl.h |  15 +++
 2 files changed, 275 insertions(+)

diff --git a/gemma/configs.h b/gemma/configs.h
index f7c6ac2..30e9a6a 100644
--- a/gemma/configs.h
+++ b/gemma/configs.h
@@ -192,6 +192,266 @@ ModelConfig ConfigFromModel(Model model);
 
 // Returns the sub-config for the ViT model of the PaliGemma model.
 ModelConfig VitConfig(const ModelConfig& config);
+template <class TConfig, typename = void>
+struct CacheLayerSize {
+  constexpr size_t operator()() const {
+    return TConfig::kKVHeads * TConfig::kQKVDim * 2;
+  }
+};
+
+template <class TConfig, typename = void>
+struct CachePosSize {
+  constexpr size_t operator()() const {
+    return TConfig::kGemmaLayers * CacheLayerSize<TConfig>()();
+  }
+};
+
+struct ConfigNoSSM {
+  static constexpr int kGriffinLayers = 0;
+
+  static constexpr int kConv1dWidth = 0;
+  static constexpr bool kFFBiases = false;
+  static constexpr bool kSoftmaxAttnOutputBiases = false;
+  static constexpr bool kUseHalfRope = false;
+  static constexpr bool kUseLocalAttention = false;
+  static constexpr bool kInterleaveQKV = true;
+  static constexpr int kNumTensorScales = 0;
+
+  static constexpr PostQKType kPostQK = PostQKType::Rope;
+  static constexpr ActivationType kActivation = ActivationType::Gelu;
+  static constexpr ResidualType kResidual = ResidualType::Add;
+
+  // Self-extend parameters with default values
+  static constexpr bool kSelfExtend = false;
+  static constexpr size_t kSelfExtendNgbSize = 0;
+  static constexpr size_t kSelfExtendGrpSize = 1;
+};
+
+struct ConfigBaseGemmaV1 : ConfigNoSSM {
+  static constexpr float kAttCap = 0.0f;
+  static constexpr float kFinalCap = 0.0f;
+  static constexpr PostNormType kPostNorm = PostNormType::None;
+  static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
+};
+
+struct ConfigBaseGemmaV2 : ConfigNoSSM {
+  static constexpr float kAttCap = 50.0f;
+  static constexpr float kFinalCap = 30.0f;
+  static constexpr PostNormType kPostNorm = PostNormType::Scale;
+};
+
+template <typename TWeight>
+struct ConfigGemma27B : public ConfigBaseGemmaV2 {
+  using Weight = TWeight;  // make accessible where we only have a TConfig
+
+  static constexpr int kSeqLen = 8192;
+  static constexpr int kVocabSize = 256000;
+  static constexpr std::array<LayerAttentionType, 46> kLayerConfig =
+      FixedLayerConfig<46>(LayerAttentionType::kGemma);
+  static constexpr std::array<size_t, 46> kAttentionWindowSizes =
+      RepeatedAttentionWindowSizes<46, 2>({4096, kSeqLen});
+  static constexpr int kLayers = kLayerConfig.size();
+  static constexpr int kGemmaLayers = kLayers;
+  static constexpr int kModelDim = 4608;
+  static constexpr int kFFHiddenDim = 16 * 4608 / 2;  // = 36864
+  static constexpr int kHeads = 32;
+  static constexpr int kKVHeads = 16;
+  static constexpr int kQKVDim = 128;  // query size == key size == value size
+  static constexpr int kTopK = gcpp::kTopK;
+  static constexpr bool kAbsolutePE = false;
+  static constexpr QueryScaleType kQueryScale =
+      QueryScaleType::SqrtModelDimDivNumHeads;
+};
+
+template <typename TWeight>
+struct ConfigGemma9B : public ConfigBaseGemmaV2 {
+  using Weight = TWeight;  // make accessible where we only have a TConfig
+
+  static constexpr int kSeqLen = 8192;
+  static constexpr int kVocabSize = 256000;
+  static constexpr std::array<LayerAttentionType, 42> kLayerConfig =
+      FixedLayerConfig<42>(LayerAttentionType::kGemma);
+  static constexpr std::array<size_t, 42> kAttentionWindowSizes =
+      RepeatedAttentionWindowSizes<42, 2>({4096, kSeqLen});
+  static constexpr int kLayers = kLayerConfig.size();
+  static constexpr int kGemmaLayers = kLayers;
+  static constexpr int kModelDim = 3584;
+  static constexpr int kFFHiddenDim = 8 * 3584 / 2;  // = 14336
+  static constexpr int kHeads = 16;
+  static constexpr int kKVHeads = 8;
+  static constexpr int kQKVDim = 256;  // query size == key size == value size
+  static constexpr int kTopK = gcpp::kTopK;
+  static constexpr bool kAbsolutePE = false;
+  static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
+};
+
+template <typename TWeight>
+struct ConfigGemma7B : public ConfigBaseGemmaV1 {
+  using Weight = TWeight;  // make accessible where we only have a TConfig
+
+  static constexpr int kSeqLen = gcpp::kSeqLen;
+  static constexpr int kVocabSize = 256000;
+  static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
+      FixedLayerConfig<28>(LayerAttentionType::kGemma);
+  static constexpr std::array<size_t, 28> kAttentionWindowSizes =
+      FixedAttentionWindowSizes<28>(kSeqLen);
+  static constexpr int kLayers = kLayerConfig.size();
+  static constexpr int kGemmaLayers = kLayers;
+  static constexpr int kModelDim = 3072;
+  static constexpr int kFFHiddenDim = 16 * 3072 / 2;  // = 24576
+  static constexpr int kHeads = 16;
+  static constexpr int kKVHeads = 16;  // standard MHA
+  static constexpr int kQKVDim = 256;  // query size == key size == value size
+  static constexpr int kTopK = gcpp::kTopK;
+  static constexpr bool kAbsolutePE = false;
+};
+
+template <typename TWeight>
+struct ConfigGemma2B : public ConfigBaseGemmaV1 {
+  using Weight = TWeight;  // make accessible where we only have a TConfig
+
+  static constexpr int kSeqLen = gcpp::kSeqLen;
+  static constexpr int kVocabSize = 256000;
+  static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
+      FixedLayerConfig<18>(LayerAttentionType::kGemma);
+  static constexpr std::array<size_t, 18> kAttentionWindowSizes =
+      FixedAttentionWindowSizes<18>(kSeqLen);
+  static constexpr int kLayers = kLayerConfig.size();
+  static constexpr int kGemmaLayers = kLayers;
+  static constexpr int kModelDim = 2048;
+  static constexpr int kFFHiddenDim = 16 * 2048 / 2;  // = 16384
+  static constexpr int kHeads = 8;
+  static constexpr int kKVHeads = 1;
+  static constexpr int kQKVDim = 256;  // query size == key size == value size
+  static constexpr int kTopK = gcpp::kTopK;
+  static constexpr bool kAbsolutePE = false;
+};
+
+template <typename TWeight>
+struct ConfigGemma2_2B : public ConfigBaseGemmaV2 {
+  using Weight = TWeight;  // make accessible where we only have a TConfig
+
+  static constexpr int kSeqLen = 8192;
+  static constexpr int kVocabSize = 256000;
+  static constexpr std::array<LayerAttentionType, 26> kLayerConfig =
+      FixedLayerConfig<26>(LayerAttentionType::kGemma);
+  static constexpr std::array<size_t, 26> kAttentionWindowSizes =
+      RepeatedAttentionWindowSizes<26, 2>({4096, kSeqLen});
+  static constexpr int kLayers = kLayerConfig.size();
+  static constexpr int kGemmaLayers = kLayers;
+  static constexpr int kModelDim = 2304;
+  static constexpr int kFFHiddenDim = 8 * 2304 / 2;  // = 9216
+  static constexpr int kHeads = 8;
+  static constexpr int kKVHeads = 4;
+  static constexpr int kQKVDim = 256;  // query size == key size == value size
+  static constexpr int kTopK = gcpp::kTopK;
+  static constexpr bool kAbsolutePE = false;
+  static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
+};
+
+template <typename TWeight>
+struct ConfigGemmaTiny : public ConfigNoSSM {
+  using Weight = TWeight;  // make accessible where we only have a TConfig
+
+  static constexpr int kSeqLen = 32;
+  static constexpr int kVocabSize = 64;
+  static constexpr std::array<LayerAttentionType, 3> kLayerConfig =
+      FixedLayerConfig<3>(LayerAttentionType::kGemma);
+  static constexpr std::array<size_t, 3> kAttentionWindowSizes =
+      FixedAttentionWindowSizes<3>(kSeqLen);
+  static constexpr int kLayers = kLayerConfig.size();
+  static constexpr int kGemmaLayers = kLayers;
+  static constexpr int kModelDim = 128;
+  static constexpr int kFFHiddenDim = 256;
+  static constexpr int kHeads = 4;
+  static constexpr int kKVHeads = 1;
+  static constexpr int kQKVDim = 16;  // query size == key size == value size
+  static constexpr int kTopK = gcpp::kTopK;
+  static constexpr bool kAbsolutePE = false;
+  static constexpr PostNormType kPostNorm = PostNormType::None;
+  static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
+
+  static constexpr float kAttCap = 0.0f;
+  // This is required for optimize_test to pass.
+  static constexpr float kFinalCap = 30.0f;
+};
+
+template <typename TWeight>
+struct ConfigGriffin2B {
+  using Weight = TWeight;  // make accessible where we only have a TConfig
+
+  // Griffin uses local attention, so kSeqLen is actually the local attention
+  // window.
+  static constexpr int kSeqLen = 2048;
+  static constexpr int kVocabSize = 256000;
+  static constexpr std::array<LayerAttentionType, 26> kLayerConfig = {
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+  };
+  static constexpr std::array<size_t, 26> kAttentionWindowSizes =
+      FixedAttentionWindowSizes<26>(kSeqLen);
+  static constexpr int kLayers = kLayerConfig.size();
+  static constexpr int kGemmaLayers =
+      NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
+  static constexpr int kGriffinLayers =
+      NumLayersOfTypeBefore(kLayerConfig,
+                            LayerAttentionType::kGriffinRecurrentBlock,
+                            kLayers);
+  static constexpr int kModelDim = 2560;
+  static constexpr int kFFHiddenDim = 7680;
+  static constexpr int kHeads = 10;
+  static constexpr int kKVHeads = 1;
+  static constexpr int kQKVDim = 256;  // query size == key size == value size
+  static constexpr int kTopK = gcpp::kTopK;
+  static constexpr bool kAbsolutePE = false;
+  static constexpr PostNormType kPostNorm = PostNormType::None;
+
+  // No SoftCap.
+  static constexpr float kAttCap = 0.0f;
+  static constexpr float kFinalCap = 0.0f;
+
+  // SSM config.
+  static constexpr int kConv1dWidth = 4;
+  static constexpr bool kFFBiases = true;
+  static constexpr bool kSoftmaxAttnOutputBiases = true;
+  static constexpr bool kUseHalfRope = true;
+  static constexpr bool kUseLocalAttention = true;
+  static constexpr bool kInterleaveQKV = false;
+  static constexpr int kNumTensorScales = 140;
+  static constexpr PostQKType kPostQK = PostQKType::Rope;
+  static constexpr ActivationType kActivation = ActivationType::Gelu;
+  static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
+  static constexpr ResidualType kResidual = ResidualType::Add;
+
+  // Self-extend parameters with default values
+  static constexpr bool kSelfExtend = false;
+  static constexpr size_t kSelfExtendNgbSize = 0;
+  static constexpr size_t kSelfExtendGrpSize = 1;
+};
 
 }  // namespace gcpp
 
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index 028949c..196f1a4 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -327,6 +327,13 @@ class GemmaAttention {
                 PositionalEncodingQK(is_mha_ ? mha_kv : kv, pos, layer_, 1.0f,
                                      kv);
 
+              // When embedding position, we will use grouped key position
+              if constexpr (TConfig::kSelfExtend) {
+                if (pos > ngb_size) {
+                  pos /= grp_size;
+                }
+              }
+
                 // If MHA, also copy V into KVCache.
                 if (is_mha_) {
                   hwy::CopyBytes(mha_kv + layer_config_.qkv_dim,
@@ -417,6 +424,14 @@ class GemmaAttention {
 
                 // Apply rope and scaling to Q.
                 const size_t pos = queries_pos_[query_idx] + batch_idx;
+                if constexpr (TConfig::kSelfExtend) {
+                    if (pos > ngb_size) {
+                      const size_t grp_pos = pos / grp_size;
+                      const size_t shift = ngb_size - ngb_size / grp_size;
+                      const size_t shifted_grouped_pos = grp_pos + shift;
+                      pos = shifted_grouped_pos;
+                    }
+                }
                 PositionalEncodingQK(q, pos, layer_, query_scale, q);
 
                 const size_t start_pos = StartPos(pos, layer_);
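
The two hunks above implement the self-extend position remapping from Jin et al., "LLM Maybe LongLM" (arXiv:2401.01325): once a position exceeds the neighbor window ngb_size, key positions are floor-divided by grp_size, while query positions are divided the same way and then shifted by ngb_size - ngb_size / grp_size, so the most recent ngb_size tokens keep their ordinary relative distances. A minimal standalone sketch of that mapping, with illustrative helper names that are not part of gemma.cpp:

  #include <cstddef>
  #include <cstdio>

  // Keys beyond the neighbor window use the grouped (floor-divided) position.
  size_t SelfExtendKeyPos(size_t pos, size_t ngb_size, size_t grp_size) {
    return pos > ngb_size ? pos / grp_size : pos;
  }

  // Queries are grouped the same way, then shifted so positions inside the
  // neighbor window line up with the ungrouped recent keys.
  size_t SelfExtendQueryPos(size_t pos, size_t ngb_size, size_t grp_size) {
    if (pos <= ngb_size) return pos;
    const size_t shift = ngb_size - ngb_size / grp_size;
    return pos / grp_size + shift;
  }

  int main() {
    // Example: ngb_size = 512, grp_size = 4, token at position 2048.
    // Key   -> 2048 / 4 = 512
    // Query -> 2048 / 4 + (512 - 512 / 4) = 512 + 384 = 896
    std::printf("%zu %zu\n", SelfExtendKeyPos(2048, 512, 4),
                SelfExtendQueryPos(2048, 512, 4));
    return 0;
  }

Note that this first patch applies the key-side grouping after PositionalEncodingQK; patch 3 below reorders it so the grouped position is computed before the positional encoding.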

From fbba1972d089881725a57715539082f1cd0ad609 Mon Sep 17 00:00:00 2001
From: Nanubala Gnana Sai <45007169+jonpsy@users.noreply.github.com>
Date: Sat, 19 Oct 2024 11:58:32 +0530
Subject: [PATCH 2/7] remove compile time config

---
 gemma/configs.h | 260 ------------------------------------------------
 1 file changed, 260 deletions(-)

diff --git a/gemma/configs.h b/gemma/configs.h
index 30e9a6a..f7c6ac2 100644
--- a/gemma/configs.h
+++ b/gemma/configs.h
@@ -192,266 +192,6 @@ ModelConfig ConfigFromModel(Model model);
 
 // Returns the sub-config for the ViT model of the PaliGemma model.
 ModelConfig VitConfig(const ModelConfig& config);
-template <class TConfig, typename = void>
-struct CacheLayerSize {
-  constexpr size_t operator()() const {
-    return TConfig::kKVHeads * TConfig::kQKVDim * 2;
-  }
-};
-
-template <class TConfig, typename = void>
-struct CachePosSize {
-  constexpr size_t operator()() const {
-    return TConfig::kGemmaLayers * CacheLayerSize<TConfig>()();
-  }
-};
-
-struct ConfigNoSSM {
-  static constexpr int kGriffinLayers = 0;
-
-  static constexpr int kConv1dWidth = 0;
-  static constexpr bool kFFBiases = false;
-  static constexpr bool kSoftmaxAttnOutputBiases = false;
-  static constexpr bool kUseHalfRope = false;
-  static constexpr bool kUseLocalAttention = false;
-  static constexpr bool kInterleaveQKV = true;
-  static constexpr int kNumTensorScales = 0;
-
-  static constexpr PostQKType kPostQK = PostQKType::Rope;
-  static constexpr ActivationType kActivation = ActivationType::Gelu;
-  static constexpr ResidualType kResidual = ResidualType::Add;
-
-  // Self-extend parameters with default values
-  static constexpr bool kSelfExtend = false;
-  static constexpr size_t kSelfExtendNgbSize = 0;
-  static constexpr size_t kSelfExtendGrpSize = 1;
-};
-
-struct ConfigBaseGemmaV1 : ConfigNoSSM {
-  static constexpr float kAttCap = 0.0f;
-  static constexpr float kFinalCap = 0.0f;
-  static constexpr PostNormType kPostNorm = PostNormType::None;
-  static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
-};
-
-struct ConfigBaseGemmaV2 : ConfigNoSSM {
-  static constexpr float kAttCap = 50.0f;
-  static constexpr float kFinalCap = 30.0f;
-  static constexpr PostNormType kPostNorm = PostNormType::Scale;
-};
-
-template <typename TWeight>
-struct ConfigGemma27B : public ConfigBaseGemmaV2 {
-  using Weight = TWeight;  // make accessible where we only have a TConfig
-
-  static constexpr int kSeqLen = 8192;
-  static constexpr int kVocabSize = 256000;
-  static constexpr std::array<LayerAttentionType, 46> kLayerConfig =
-      FixedLayerConfig<46>(LayerAttentionType::kGemma);
-  static constexpr std::array<size_t, 46> kAttentionWindowSizes =
-      RepeatedAttentionWindowSizes<46, 2>({4096, kSeqLen});
-  static constexpr int kLayers = kLayerConfig.size();
-  static constexpr int kGemmaLayers = kLayers;
-  static constexpr int kModelDim = 4608;
-  static constexpr int kFFHiddenDim = 16 * 4608 / 2;  // = 36864
-  static constexpr int kHeads = 32;
-  static constexpr int kKVHeads = 16;
-  static constexpr int kQKVDim = 128;  // query size == key size == value size
-  static constexpr int kTopK = gcpp::kTopK;
-  static constexpr bool kAbsolutePE = false;
-  static constexpr QueryScaleType kQueryScale =
-      QueryScaleType::SqrtModelDimDivNumHeads;
-};
-
-template <typename TWeight>
-struct ConfigGemma9B : public ConfigBaseGemmaV2 {
-  using Weight = TWeight;  // make accessible where we only have a TConfig
-
-  static constexpr int kSeqLen = 8192;
-  static constexpr int kVocabSize = 256000;
-  static constexpr std::array<LayerAttentionType, 42> kLayerConfig =
-      FixedLayerConfig<42>(LayerAttentionType::kGemma);
-  static constexpr std::array<size_t, 42> kAttentionWindowSizes =
-      RepeatedAttentionWindowSizes<42, 2>({4096, kSeqLen});
-  static constexpr int kLayers = kLayerConfig.size();
-  static constexpr int kGemmaLayers = kLayers;
-  static constexpr int kModelDim = 3584;
-  static constexpr int kFFHiddenDim = 8 * 3584 / 2;  // = 14336
-  static constexpr int kHeads = 16;
-  static constexpr int kKVHeads = 8;
-  static constexpr int kQKVDim = 256;  // query size == key size == value size
-  static constexpr int kTopK = gcpp::kTopK;
-  static constexpr bool kAbsolutePE = false;
-  static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
-};
-
-template <typename TWeight>
-struct ConfigGemma7B : public ConfigBaseGemmaV1 {
-  using Weight = TWeight;  // make accessible where we only have a TConfig
-
-  static constexpr int kSeqLen = gcpp::kSeqLen;
-  static constexpr int kVocabSize = 256000;
-  static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
-      FixedLayerConfig<28>(LayerAttentionType::kGemma);
-  static constexpr std::array<size_t, 28> kAttentionWindowSizes =
-      FixedAttentionWindowSizes<28>(kSeqLen);
-  static constexpr int kLayers = kLayerConfig.size();
-  static constexpr int kGemmaLayers = kLayers;
-  static constexpr int kModelDim = 3072;
-  static constexpr int kFFHiddenDim = 16 * 3072 / 2;  // = 24576
-  static constexpr int kHeads = 16;
-  static constexpr int kKVHeads = 16;  // standard MHA
-  static constexpr int kQKVDim = 256;  // query size == key size == value size
-  static constexpr int kTopK = gcpp::kTopK;
-  static constexpr bool kAbsolutePE = false;
-};
-
-template <typename TWeight>
-struct ConfigGemma2B : public ConfigBaseGemmaV1 {
-  using Weight = TWeight;  // make accessible where we only have a TConfig
-
-  static constexpr int kSeqLen = gcpp::kSeqLen;
-  static constexpr int kVocabSize = 256000;
-  static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
-      FixedLayerConfig<18>(LayerAttentionType::kGemma);
-  static constexpr std::array<size_t, 18> kAttentionWindowSizes =
-      FixedAttentionWindowSizes<18>(kSeqLen);
-  static constexpr int kLayers = kLayerConfig.size();
-  static constexpr int kGemmaLayers = kLayers;
-  static constexpr int kModelDim = 2048;
-  static constexpr int kFFHiddenDim = 16 * 2048 / 2;  // = 16384
-  static constexpr int kHeads = 8;
-  static constexpr int kKVHeads = 1;
-  static constexpr int kQKVDim = 256;  // query size == key size == value size
-  static constexpr int kTopK = gcpp::kTopK;
-  static constexpr bool kAbsolutePE = false;
-};
-
-template <typename TWeight>
-struct ConfigGemma2_2B : public ConfigBaseGemmaV2 {
-  using Weight = TWeight;  // make accessible where we only have a TConfig
-
-  static constexpr int kSeqLen = 8192;
-  static constexpr int kVocabSize = 256000;
-  static constexpr std::array<LayerAttentionType, 26> kLayerConfig =
-      FixedLayerConfig<26>(LayerAttentionType::kGemma);
-  static constexpr std::array<size_t, 26> kAttentionWindowSizes =
-      RepeatedAttentionWindowSizes<26, 2>({4096, kSeqLen});
-  static constexpr int kLayers = kLayerConfig.size();
-  static constexpr int kGemmaLayers = kLayers;
-  static constexpr int kModelDim = 2304;
-  static constexpr int kFFHiddenDim = 8 * 2304 / 2;  // = 9216
-  static constexpr int kHeads = 8;
-  static constexpr int kKVHeads = 4;
-  static constexpr int kQKVDim = 256;  // query size == key size == value size
-  static constexpr int kTopK = gcpp::kTopK;
-  static constexpr bool kAbsolutePE = false;
-  static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
-};
-
-template <typename TWeight>
-struct ConfigGemmaTiny : public ConfigNoSSM {
-  using Weight = TWeight;  // make accessible where we only have a TConfig
-
-  static constexpr int kSeqLen = 32;
-  static constexpr int kVocabSize = 64;
-  static constexpr std::array<LayerAttentionType, 3> kLayerConfig =
-      FixedLayerConfig<3>(LayerAttentionType::kGemma);
-  static constexpr std::array<size_t, 3> kAttentionWindowSizes =
-      FixedAttentionWindowSizes<3>(kSeqLen);
-  static constexpr int kLayers = kLayerConfig.size();
-  static constexpr int kGemmaLayers = kLayers;
-  static constexpr int kModelDim = 128;
-  static constexpr int kFFHiddenDim = 256;
-  static constexpr int kHeads = 4;
-  static constexpr int kKVHeads = 1;
-  static constexpr int kQKVDim = 16;  // query size == key size == value size
-  static constexpr int kTopK = gcpp::kTopK;
-  static constexpr bool kAbsolutePE = false;
-  static constexpr PostNormType kPostNorm = PostNormType::None;
-  static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
-
-  static constexpr float kAttCap = 0.0f;
-  // This is required for optimize_test to pass.
-  static constexpr float kFinalCap = 30.0f;
-};
-
-template <typename TWeight>
-struct ConfigGriffin2B {
-  using Weight = TWeight;  // make accessible where we only have a TConfig
-
-  // Griffin uses local attention, so kSeqLen is actually the local attention
-  // window.
-  static constexpr int kSeqLen = 2048;
-  static constexpr int kVocabSize = 256000;
-  static constexpr std::array<LayerAttentionType, 26> kLayerConfig = {
-      LayerAttentionType::kGriffinRecurrentBlock,
-      LayerAttentionType::kGriffinRecurrentBlock,
-      LayerAttentionType::kGemma,
-      LayerAttentionType::kGriffinRecurrentBlock,
-      LayerAttentionType::kGriffinRecurrentBlock,
-      LayerAttentionType::kGemma,
-      LayerAttentionType::kGriffinRecurrentBlock,
-      LayerAttentionType::kGriffinRecurrentBlock,
-      LayerAttentionType::kGemma,
-      LayerAttentionType::kGriffinRecurrentBlock,
-      LayerAttentionType::kGriffinRecurrentBlock,
-      LayerAttentionType::kGemma,
-      LayerAttentionType::kGriffinRecurrentBlock,
-      LayerAttentionType::kGriffinRecurrentBlock,
-      LayerAttentionType::kGemma,
-      LayerAttentionType::kGriffinRecurrentBlock,
-      LayerAttentionType::kGriffinRecurrentBlock,
-      LayerAttentionType::kGemma,
-      LayerAttentionType::kGriffinRecurrentBlock,
-      LayerAttentionType::kGriffinRecurrentBlock,
-      LayerAttentionType::kGemma,
-      LayerAttentionType::kGriffinRecurrentBlock,
-      LayerAttentionType::kGriffinRecurrentBlock,
-      LayerAttentionType::kGemma,
-      LayerAttentionType::kGriffinRecurrentBlock,
-      LayerAttentionType::kGriffinRecurrentBlock,
-  };
-  static constexpr std::array<size_t, 26> kAttentionWindowSizes =
-      FixedAttentionWindowSizes<26>(kSeqLen);
-  static constexpr int kLayers = kLayerConfig.size();
-  static constexpr int kGemmaLayers =
-      NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
-  static constexpr int kGriffinLayers =
-      NumLayersOfTypeBefore(kLayerConfig,
-                            LayerAttentionType::kGriffinRecurrentBlock,
-                            kLayers);
-  static constexpr int kModelDim = 2560;
-  static constexpr int kFFHiddenDim = 7680;
-  static constexpr int kHeads = 10;
-  static constexpr int kKVHeads = 1;
-  static constexpr int kQKVDim = 256;  // query size == key size == value size
-  static constexpr int kTopK = gcpp::kTopK;
-  static constexpr bool kAbsolutePE = false;
-  static constexpr PostNormType kPostNorm = PostNormType::None;
-
-  // No SoftCap.
-  static constexpr float kAttCap = 0.0f;
-  static constexpr float kFinalCap = 0.0f;
-
-  // SSM config.
-  static constexpr int kConv1dWidth = 4;
-  static constexpr bool kFFBiases = true;
-  static constexpr bool kSoftmaxAttnOutputBiases = true;
-  static constexpr bool kUseHalfRope = true;
-  static constexpr bool kUseLocalAttention = true;
-  static constexpr bool kInterleaveQKV = false;
-  static constexpr int kNumTensorScales = 140;
-  static constexpr PostQKType kPostQK = PostQKType::Rope;
-  static constexpr ActivationType kActivation = ActivationType::Gelu;
-  static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
-  static constexpr ResidualType kResidual = ResidualType::Add;
-
-  // Self-extend parameters with default values
-  static constexpr bool kSelfExtend = false;
-  static constexpr size_t kSelfExtendNgbSize = 0;
-  static constexpr size_t kSelfExtendGrpSize = 1;
-};
 
 }  // namespace gcpp
 

From f77e61e514b1cc0e10c8e292d6faaf6a46e57326 Mon Sep 17 00:00:00 2001
From: Nanubala Gnana Sai <45007169+jonpsy@users.noreply.github.com>
Date: Sat, 19 Oct 2024 13:23:12 +0530
Subject: [PATCH 3/7] Use runtime config to set up self extend

---
 gemma/configs.h   |  4 ++++
 gemma/gemma-inl.h | 37 ++++++++++++++++++++-----------------
 2 files changed, 24 insertions(+), 17 deletions(-)

diff --git a/gemma/configs.h b/gemma/configs.h
index f7c6ac2..58f3446 100644
--- a/gemma/configs.h
+++ b/gemma/configs.h
@@ -127,6 +127,10 @@ struct LayerConfig {
   size_t conv1d_width = 0;
   bool ff_biases = false;
   bool softmax_attn_output_biases = false;
+  bool self_extend = false;
+  size_t ngb_size = 0;
+  size_t grp_size = 1;
+
   PostNormType post_norm = PostNormType::None;
   LayerAttentionType type = LayerAttentionType::kGemma;
   ActivationType activation = ActivationType::Gelu;
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index 196f1a4..ea5aca0 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -312,28 +312,29 @@ class GemmaAttention {
                 const size_t interleaved_idx = task / layer_config_.kv_heads;
                 const size_t query_idx = interleaved_idx % num_queries_;
                 const size_t batch_idx = interleaved_idx / num_queries_;
-                const size_t pos = queries_pos_[query_idx] + batch_idx;
+                size_t pos = queries_pos_[query_idx] + batch_idx;
                 const size_t cache_pos = div_seq_len_.Remainder(pos);
                 const size_t kv_offset = cache_pos * cache_pos_size_ +
                                          layer_ * cache_layer_size_ +
                                          head * layer_config_.qkv_dim * 2;
                 KVCache& kv_cache = kv_caches_[query_idx];
+
+                const size_t grp_size = layer_config_.grp_size;
+                const size_t ngb_size = layer_config_.ngb_size;
+                const bool self_extend = layer_config_.self_extend;
+
                 float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
                 const float* HWY_RESTRICT mha_kv =
                     activations_.q.Batch(interleaved_idx) + head * q_stride_ +
                     layer_config_.qkv_dim;
 
+                // When embedding position, we will use grouped key position
+                if (self_extend && pos > ngb_size) {
+                  pos /= grp_size;
+                }
                 // Copy from `q` if MHA, or apply in-place.
                 PositionalEncodingQK(is_mha_ ? mha_kv : kv, pos, layer_, 1.0f,
                                      kv);
-
-              // When embedding position, we will use grouped key position
-              if constexpr (TConfig::kSelfExtend) {
-                if (pos > ngb_size) {
-                  pos /= grp_size;
-                }
-              }
-
                 // If MHA, also copy V into KVCache.
                 if (is_mha_) {
                   hwy::CopyBytes(mha_kv + layer_config_.qkv_dim,
@@ -418,19 +419,21 @@ class GemmaAttention {
                 const size_t batch_idx = interleaved_idx / num_queries_;
                 const size_t head_offset =
                     (head / kHeadGroups) * layer_config_.qkv_dim * 2;
+
+                const size_t grp_size = layer_config_.grp_size;
+                const size_t ngb_size = layer_config_.ngb_size;
+                const bool self_extend = layer_config_.self_extend;
                 KVCache& kv_cache = kv_caches_[query_idx];
                 float* HWY_RESTRICT q =
                     activations_.q.Batch(interleaved_idx) + head * q_stride_;
 
                 // Apply rope and scaling to Q.
-                const size_t pos = queries_pos_[query_idx] + batch_idx;
-                if constexpr (TConfig::kSelfExtend) {
-                    if (pos > ngb_size) {
-                      const size_t grp_pos = pos / grp_size;
-                      const size_t shift = ngb_size - ngb_size / grp_size;
-                      const size_t shifted_grouped_pos = grp_pos + shift;
-                      pos = shifted_grouped_pos;
-                    }
+                size_t pos = queries_pos_[query_idx] + batch_idx;
+                if (self_extend && pos > ngb_size) {
+                  const size_t grp_pos = pos / grp_size;
+                  const size_t shift = ngb_size - ngb_size / grp_size;
+                  const size_t shifted_grouped_pos = grp_pos + shift;
+                  pos = shifted_grouped_pos;
                 }
                 PositionalEncodingQK(q, pos, layer_, query_scale, q);
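
The hunks above replace the compile-time kSelfExtend constants with per-layer runtime fields (self_extend, ngb_size, grp_size on LayerConfig), so self-extend can be toggled without recompiling. A small sketch of what enabling it could look like, assuming the caller can reach a mutable ModelConfig; this is illustrative only, and patch 7 below adds an actual accessor plus CLI flags and renames the fields to se_neighbor_size / se_group_size:

  #include <cstddef>
  #include "gemma/configs.h"  // gcpp::ModelConfig, gcpp::LayerConfig

  // Hypothetical helper: enables self-extend on every layer with the given
  // neighbor window and group size. Field names follow this patch.
  void EnableSelfExtend(gcpp::ModelConfig& config, size_t ngb, size_t grp) {
    for (gcpp::LayerConfig& layer : config.layer_configs) {
      layer.self_extend = true;
      layer.ngb_size = ngb;  // neighbor window, in tokens
      layer.grp_size = grp;  // grouping factor beyond the window
    }
  }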
 

From 3b270d236fc1ed73493ed24134e6c137f460cab9 Mon Sep 17 00:00:00 2001
From: Nanubala Gnana Sai <45007169+jonpsy@users.noreply.github.com>
Date: Tue, 29 Oct 2024 22:38:05 +0530
Subject: [PATCH 4/7] Use hwy divisor to speed up division

---
 gemma/gemma-inl.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index ea5aca0..9fd739a 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -319,7 +319,7 @@ class GemmaAttention {
                                          head * layer_config_.qkv_dim * 2;
                 KVCache& kv_cache = kv_caches_[query_idx];
 
-                const size_t grp_size = layer_config_.grp_size;
+                const hwy::Divisor& div_grp_size { static_cast<uint32_t>(layer_config_.grp_size) };
                 const size_t ngb_size = layer_config_.ngb_size;
                 const bool self_extend = layer_config_.self_extend;
 
@@ -330,7 +330,7 @@ class GemmaAttention {
 
                 // When embedding position, we will use grouped key position
                 if (self_extend && pos > ngb_size) {
-                  pos /= grp_size;
+                  pos = div_grp_size.Divide(pos);
                 }
                 // Copy from `q` if MHA, or apply in-place.
                 PositionalEncodingQK(is_mha_ ? mha_kv : kv, pos, layer_, 1.0f,
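
Replacing pos /= grp_size with a precomputed hwy::Divisor avoids a hardware integer division in the per-task lambda: the divisor's multiplicative reciprocal is computed once and reused, the same idiom already used for div_seq_len_ (see div_seq_len_.Remainder(pos) in the context above). A minimal standalone sketch, assuming hwy::Divisor is declared in hwy/base.h; Divide() and Remainder() are the calls used elsewhere in gemma-inl.h:

  #include <cstdint>
  #include <cstdio>
  #include "hwy/base.h"  // hwy::Divisor (assumed header location)

  int main() {
    const hwy::Divisor div(7u);              // precompute the reciprocal once
    const uint32_t q = div.Divide(100u);     // 100 / 7 = 14
    const uint32_t r = div.Remainder(100u);  // 100 % 7 = 2
    std::printf("%u %u\n", q, r);
    return 0;
  }

The divisor must be nonzero, which holds here because grp_size defaults to 1; patch 5 below then hoists the construction out of the per-task lambda so it is built once per attention call rather than once per task.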

From 719098fd3e62eea5443e6b4c0ac7634a4fc7e244 Mon Sep 17 00:00:00 2001
From: Nanubala Gnana Sai <45007169+jonpsy@users.noreply.github.com>
Date: Fri, 1 Nov 2024 19:38:58 +0530
Subject: [PATCH 5/7] Move div_grp_size outside

---
 gemma/gemma-inl.h | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index 9fd739a..1faf6a3 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -305,6 +305,9 @@ class GemmaAttention {
       }
     }
 
+    // Self-extension
+    const hwy::Divisor& div_grp_size{
+        static_cast<uint32_t>(layer_config_.grp_size)};
     // Apply positional encodings for K (and copy KV to cache if MHA).
     pool_.Run(0, layer_config_.kv_heads * num_interleaved,
               [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
@@ -319,7 +322,6 @@ class GemmaAttention {
                                          head * layer_config_.qkv_dim * 2;
                 KVCache& kv_cache = kv_caches_[query_idx];
 
-                const hwy::Divisor& div_grp_size { static_cast<uint32_t>(layer_config_.grp_size) };
                 const size_t ngb_size = layer_config_.ngb_size;
                 const bool self_extend = layer_config_.self_extend;
 
@@ -328,7 +330,8 @@ class GemmaAttention {
                     activations_.q.Batch(interleaved_idx) + head * q_stride_ +
                     layer_config_.qkv_dim;
 
-                // When embedding position, we will use grouped key position
+                // In self-extend, when embedding position,
+                // we will use grouped key position
                 if (self_extend && pos > ngb_size) {
                   pos = div_grp_size.Divide(pos);
                 }
@@ -1484,7 +1487,7 @@ void GenerateBatchT(const ModelWeightsStorage& model,
                                              qbatch_size);
     QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
     const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
-                                             qbatch_size);
+                                       qbatch_size);
     const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
     GenerateT<T>(model, activations, runtime_config, qbatch_prompts, qbatch_pos,
                  qbatch_prefix_end, qbatch_start, qbatch_kv, timing_info);

From 28063743fc3d5fe5d5b85ae8fe88db060777c6e0 Mon Sep 17 00:00:00 2001
From: Nanubala Gnana Sai <45007169+jonpsy@users.noreply.github.com>
Date: Tue, 5 Nov 2024 22:43:07 +0530
Subject: [PATCH 6/7] Use explicit ctor for hwy divisor

---
 gemma/gemma-inl.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index 1faf6a3..368a8cd 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -306,8 +306,8 @@ class GemmaAttention {
     }
 
     // Self-extension
-    const hwy::Divisor& div_grp_size{
-        static_cast<uint32_t>(layer_config_.grp_size)};
+    const hwy::Divisor div_grp_size(
+        static_cast<uint32_t>(layer_config_.grp_size));
     // Apply positional encodings for K (and copy KV to cache if MHA).
     pool_.Run(0, layer_config_.kv_heads * num_interleaved,
               [&](uint64_t task, size_t /*thread*/) HWY_ATTR {

From 14d62b0098fd7d897760ebc0304faeb2bb649ceb Mon Sep 17 00:00:00 2001
From: Nanubala Gnana Sai <45007169+jonpsy@users.noreply.github.com>
Date: Tue, 19 Nov 2024 22:33:27 +0530
Subject: [PATCH 7/7] Add mutable ModelConfig support so run.cc can configure
 self extend at runtime

---
 gemma/configs.h   | 11 +++++++++--
 gemma/gemma-inl.h | 23 +++++++++++++----------
 gemma/gemma.h     |  1 +
 gemma/run.cc      | 21 +++++++++++++++++++++
 gemma/weights.h   |  1 +
 util/app.h        | 11 +++++++++++
 6 files changed, 56 insertions(+), 12 deletions(-)

diff --git a/gemma/configs.h b/gemma/configs.h
index be36839..986ff30 100644
--- a/gemma/configs.h
+++ b/gemma/configs.h
@@ -135,9 +135,16 @@ struct LayerConfig {
   size_t conv1d_width = 0;
   bool ff_biases = false;
   bool softmax_attn_output_biases = false;
+
+  /**
+   * Self-extend
+   * Jin, Hongye, et al. "LLM Maybe LongLM: Self-Extend LLM Context Window Without Tuning." arXiv preprint arXiv:2401.01325 (2024).
+   */
   bool self_extend = false;
-  size_t ngb_size = 0;
-  size_t grp_size = 1;
+  // Self-extend neighbor size
+  size_t se_neighbor_size = std::numeric_limits<size_t>::max();
+  // Self-extend group window size
+  size_t se_group_size = 1;
 
   PostNormType post_norm = PostNormType::None;
   LayerAttentionType type = LayerAttentionType::kGemma;
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index 3354ccf..d93c803 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -302,7 +302,7 @@ class GemmaAttention {
 
     // Self-extension
     const hwy::Divisor div_grp_size(
-        static_cast<uint32_t>(layer_config_.grp_size));
+        static_cast<uint32_t>(layer_config_.se_group_size));
     // Apply positional encodings for K (and copy KV to cache if MHA).
     pool_.Run(0, kv_heads * num_interleaved,
               [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
@@ -317,8 +317,8 @@ class GemmaAttention {
                                          head * qkv_dim * 2;
                 KVCache& kv_cache = kv_caches_[query_idx];
 
-                const size_t ngb_size = layer_config_.ngb_size;
-                const bool self_extend = layer_config_.self_extend;
+                const size_t se_neighbor_size = layer_config_.se_neighbor_size;
+                const bool enable_self_extend = layer_config_.self_extend;
 
                 float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
                 const float* HWY_RESTRICT mha_kv =
@@ -327,7 +327,7 @@ class GemmaAttention {
 
                 // In self-extend, when embedding position,
                 // we will use grouped key position
-                if (self_extend && pos > ngb_size) {
+                if (enable_self_extend && pos > se_neighbor_size) {
                   pos = div_grp_size.Divide(pos);
                 }
                 // Copy from `q` if MHA, or apply in-place.
@@ -417,18 +417,21 @@ class GemmaAttention {
                 const size_t head_offset =
                     (head / kHeadGroups) * layer_config_.qkv_dim * 2;
 
-                const size_t grp_size = layer_config_.grp_size;
-                const size_t ngb_size = layer_config_.ngb_size;
-                const bool self_extend = layer_config_.self_extend;
+                const size_t se_group_size = layer_config_.se_group_size;
+                const size_t se_neighbor_size = layer_config_.se_neighbor_size;
+                const bool enable_self_extend =
+                    layer_config_.self_extend;
+
                 KVCache& kv_cache = kv_caches_[query_idx];
                 float* HWY_RESTRICT q =
                     activations_.q.Batch(interleaved_idx) + head * q_stride_;
 
                 // Apply rope and scaling to Q.
                 size_t pos = queries_pos_[query_idx] + batch_idx;
-                if (self_extend && pos > ngb_size) {
-                  const size_t grp_pos = pos / grp_size;
-                  const size_t shift = ngb_size - ngb_size / grp_size;
+                if (enable_self_extend && pos > se_neighbor_size) {
+                  const size_t grp_pos = pos / se_group_size;
+                  const size_t shift =
+                      se_neighbor_size - se_neighbor_size / se_group_size;
                   const size_t shifted_grouped_pos = grp_pos + shift;
                   pos = shifted_grouped_pos;
                 }
diff --git a/gemma/gemma.h b/gemma/gemma.h
index 5df319f..3a685fe 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -194,6 +194,7 @@ class Gemma {
   ~Gemma();
 
   const ModelConfig& GetModelConfig() const { return model_.Config(); }
+  ModelConfig& GetMutableModelConfig() { return model_.MutableConfig(); }
   const ModelInfo& Info() const { return info_; }
   const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
   const ModelWeightsStorage& Weights() const { return model_; }
diff --git a/gemma/run.cc b/gemma/run.cc
index 2c62bdb..6fbf9bd 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -77,6 +77,26 @@ std::string GetPrompt(std::istream& input, int verbosity,
   return prompt_string;
 }
 
+// Extract args from the loader and modify model config
+void ApplySelfExtendIfGiven(Gemma& model, LoaderArgs loader) {
+  ModelConfig& config = model.GetMutableModelConfig();
+  if (loader.self_extend != Tristate::kTrue) {
+    return;
+  }
+
+  // Modify layer config in-place
+  auto& layer_configs = config.layer_configs;
+  std::transform(layer_configs.begin(), layer_configs.end(), layer_configs.begin(),
+                [&loader](LayerConfig& layer_config) {
+                  layer_config.self_extend =
+                      loader.self_extend == Tristate::kTrue;
+                  layer_config.se_group_size = loader.se_group_size;
+                  layer_config.se_neighbor_size = loader.se_neighbor_size;
+
+                  return layer_config;
+                });
+}
+
 // The main Read-Eval-Print Loop.
 void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
                const InferenceArgs& args, const AcceptFunc& accept_token,
@@ -206,6 +226,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   Allocator::Init(pools.Topology());
 
   Gemma model = CreateGemma(loader, pools);
+  ApplySelfExtendIfGiven(model, loader);
   KVCache kv_cache =
       KVCache::Create(model.GetModelConfig(), inference.prefill_tbatch_size);
 
diff --git a/gemma/weights.h b/gemma/weights.h
index ce2df43..71a8186 100644
--- a/gemma/weights.h
+++ b/gemma/weights.h
@@ -550,6 +550,7 @@ class ModelWeightsStorage {
   void CopyWithTranspose(hwy::ThreadPool& pool);
   void LogWeightStats();
   const ModelConfig& Config() const { return config_; }
+  ModelConfig& MutableConfig() { return config_; }
 
   template <typename T>
   ModelWeightsPtrs<T>* GetWeightsOfType() const {
diff --git a/util/app.h b/util/app.h
index ebc16b9..2aa8529 100644
--- a/util/app.h
+++ b/util/app.h
@@ -171,6 +171,11 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
   std::string model_type_str;
   std::string weight_type_str;
 
+  // Self-extend
+  Tristate self_extend;
+  size_t se_group_size;
+  size_t se_neighbor_size;
+
   template <class Visitor>
   void ForEach(const Visitor& visitor) {
     visitor(tokenizer, "tokenizer", Path(),
@@ -189,6 +194,12 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
     visitor(weight_type_str, "weight_type", std::string("sfp"),
             "Weight type\n    f32 = float, bf16 = bfloat16, sfp = 8-bit FP\n"
             "    Required argument.");
+    visitor(self_extend, "self_extend", Tristate::kDefault,
+            "Apply self extend ? -1 = auto, 0 = no, 1 = yes.", 2);
+    visitor(se_group_size, "se_group_size", size_t{1}, "Group size for self extend");
+    visitor(se_neighbor_size, "se_neighbor_size",
+            std::numeric_limits<size_t>::max(),
+            "Neighbor window size for self extend");
   }
 
   // Uninitialized before Validate, must call after that.