From 8cf3966be451068c7038db0ee39ca6d2891ac95b Mon Sep 17 00:00:00 2001 From: Nanubala Gnana Sai <45007169+jonpsy@users.noreply.github.com> Date: Fri, 18 Oct 2024 14:25:28 +0530 Subject: [PATCH 1/7] compile success: set default self extend values in noSSM and griffin --- gemma/configs.h | 260 ++++++++++++++++++++++++++++++++++++++++++++++ gemma/gemma-inl.h | 15 +++ 2 files changed, 275 insertions(+) diff --git a/gemma/configs.h b/gemma/configs.h index f7c6ac2..30e9a6a 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -192,6 +192,266 @@ ModelConfig ConfigFromModel(Model model); // Returns the sub-config for the ViT model of the PaliGemma model. ModelConfig VitConfig(const ModelConfig& config); +template <class TConfig, typename = void> +struct CacheLayerSize { + constexpr size_t operator()() const { + return TConfig::kKVHeads * TConfig::kQKVDim * 2; + } +}; + +template <class TConfig, typename = void> +struct CachePosSize { + constexpr size_t operator()() const { + return TConfig::kGemmaLayers * CacheLayerSize<TConfig>()(); + } +}; + +struct ConfigNoSSM { + static constexpr int kGriffinLayers = 0; + + static constexpr int kConv1dWidth = 0; + static constexpr bool kFFBiases = false; + static constexpr bool kSoftmaxAttnOutputBiases = false; + static constexpr bool kUseHalfRope = false; + static constexpr bool kUseLocalAttention = false; + static constexpr bool kInterleaveQKV = true; + static constexpr int kNumTensorScales = 0; + + static constexpr PostQKType kPostQK = PostQKType::Rope; + static constexpr ActivationType kActivation = ActivationType::Gelu; + static constexpr ResidualType kResidual = ResidualType::Add; + + // Self-extend parameters with defaul values + static constexpr bool kSelfExtend = false; + static constexpr size_t kSelfExtendNgbSize = 0; + static constexpr size_t kSelfExtendGrpSize = 1; +}; + +struct ConfigBaseGemmaV1 : ConfigNoSSM { + static constexpr float kAttCap = 0.0f; + static constexpr float kFinalCap = 0.0f; + static constexpr PostNormType kPostNorm = PostNormType::None; + static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize; +}; + +struct ConfigBaseGemmaV2 : ConfigNoSSM { + static constexpr float kAttCap = 50.0f; + static constexpr float kFinalCap = 30.0f; + static constexpr PostNormType kPostNorm = PostNormType::Scale; +}; + +template <typename TWeight> +struct ConfigGemma27B : public ConfigBaseGemmaV2 { + using Weight = TWeight; // make accessible where we only have a TConfig + + static constexpr int kSeqLen = 8192; + static constexpr int kVocabSize = 256000; + static constexpr std::array<LayerAttentionType, 46> kLayerConfig = + FixedLayerConfig<46>(LayerAttentionType::kGemma); + static constexpr std::array<size_t, 46> kAttentionWindowSizes = + RepeatedAttentionWindowSizes<46, 2>({4096, kSeqLen}); + static constexpr int kLayers = kLayerConfig.size(); + static constexpr int kGemmaLayers = kLayers; + static constexpr int kModelDim = 4608; + static constexpr int kFFHiddenDim = 16 * 4608 / 2; // = 36864 + static constexpr int kHeads = 32; + static constexpr int kKVHeads = 16; + static constexpr int kQKVDim = 128; // query size == key size == value size + static constexpr int kTopK = gcpp::kTopK; + static constexpr bool kAbsolutePE = false; + static constexpr QueryScaleType kQueryScale = + QueryScaleType::SqrtModelDimDivNumHeads; +}; + +template <typename TWeight> +struct ConfigGemma9B : public ConfigBaseGemmaV2 { + using Weight = TWeight; // make accessible where we only have a TConfig + + static constexpr int kSeqLen = 8192; + static 
constexpr int kVocabSize = 256000; + static constexpr std::array<LayerAttentionType, 42> kLayerConfig = + FixedLayerConfig<42>(LayerAttentionType::kGemma); + static constexpr std::array<size_t, 42> kAttentionWindowSizes = + RepeatedAttentionWindowSizes<42, 2>({4096, kSeqLen}); + static constexpr int kLayers = kLayerConfig.size(); + static constexpr int kGemmaLayers = kLayers; + static constexpr int kModelDim = 3584; + static constexpr int kFFHiddenDim = 8 * 3584 / 2; // = 14336 + static constexpr int kHeads = 16; + static constexpr int kKVHeads = 8; + static constexpr int kQKVDim = 256; // query size == key size == value size + static constexpr int kTopK = gcpp::kTopK; + static constexpr bool kAbsolutePE = false; + static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize; +}; + +template <typename TWeight> +struct ConfigGemma7B : public ConfigBaseGemmaV1 { + using Weight = TWeight; // make accessible where we only have a TConfig + + static constexpr int kSeqLen = gcpp::kSeqLen; + static constexpr int kVocabSize = 256000; + static constexpr std::array<LayerAttentionType, 28> kLayerConfig = + FixedLayerConfig<28>(LayerAttentionType::kGemma); + static constexpr std::array<size_t, 28> kAttentionWindowSizes = + FixedAttentionWindowSizes<28>(kSeqLen); + static constexpr int kLayers = kLayerConfig.size(); + static constexpr int kGemmaLayers = kLayers; + static constexpr int kModelDim = 3072; + static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576 + static constexpr int kHeads = 16; + static constexpr int kKVHeads = 16; // standard MHA + static constexpr int kQKVDim = 256; // query size == key size == value size + static constexpr int kTopK = gcpp::kTopK; + static constexpr bool kAbsolutePE = false; +}; + +template <typename TWeight> +struct ConfigGemma2B : public ConfigBaseGemmaV1 { + using Weight = TWeight; // make accessible where we only have a TConfig + + static constexpr int kSeqLen = gcpp::kSeqLen; + static constexpr int kVocabSize = 256000; + static constexpr std::array<LayerAttentionType, 18> kLayerConfig = + FixedLayerConfig<18>(LayerAttentionType::kGemma); + static constexpr std::array<size_t, 18> kAttentionWindowSizes = + FixedAttentionWindowSizes<18>(kSeqLen); + static constexpr int kLayers = kLayerConfig.size(); + static constexpr int kGemmaLayers = kLayers; + static constexpr int kModelDim = 2048; + static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384 + static constexpr int kHeads = 8; + static constexpr int kKVHeads = 1; + static constexpr int kQKVDim = 256; // query size == key size == value size + static constexpr int kTopK = gcpp::kTopK; + static constexpr bool kAbsolutePE = false; +}; + +template <typename TWeight> +struct ConfigGemma2_2B : public ConfigBaseGemmaV2 { + using Weight = TWeight; // make accessible where we only have a TConfig + + static constexpr int kSeqLen = 8192; + static constexpr int kVocabSize = 256000; + static constexpr std::array<LayerAttentionType, 26> kLayerConfig = + FixedLayerConfig<26>(LayerAttentionType::kGemma); + static constexpr std::array<size_t, 26> kAttentionWindowSizes = + RepeatedAttentionWindowSizes<26, 2>({4096, kSeqLen}); + static constexpr int kLayers = kLayerConfig.size(); + static constexpr int kGemmaLayers = kLayers; + static constexpr int kModelDim = 2304; + static constexpr int kFFHiddenDim = 8 * 2304 / 2; // = 9216 + static constexpr int kHeads = 8; + static constexpr int kKVHeads = 4; + static constexpr int kQKVDim = 256; // query size == key size == value size + static constexpr int kTopK = 
gcpp::kTopK; + static constexpr bool kAbsolutePE = false; + static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize; +}; + +template <typename TWeight> +struct ConfigGemmaTiny : public ConfigNoSSM { + using Weight = TWeight; // make accessible where we only have a TConfig + + static constexpr int kSeqLen = 32; + static constexpr int kVocabSize = 64; + static constexpr std::array<LayerAttentionType, 3> kLayerConfig = + FixedLayerConfig<3>(LayerAttentionType::kGemma); + static constexpr std::array<size_t, 3> kAttentionWindowSizes = + FixedAttentionWindowSizes<3>(kSeqLen); + static constexpr int kLayers = kLayerConfig.size(); + static constexpr int kGemmaLayers = kLayers; + static constexpr int kModelDim = 128; + static constexpr int kFFHiddenDim = 256; + static constexpr int kHeads = 4; + static constexpr int kKVHeads = 1; + static constexpr int kQKVDim = 16; // query size == key size == value size + static constexpr int kTopK = gcpp::kTopK; + static constexpr bool kAbsolutePE = false; + static constexpr PostNormType kPostNorm = PostNormType::None; + static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize; + + static constexpr float kAttCap = 0.0f; + // This is required for optimize_test to pass. + static constexpr float kFinalCap = 30.0f; +}; + +template <typename TWeight> +struct ConfigGriffin2B { + using Weight = TWeight; // make accessible where we only have a TConfig + + // Griffin uses local attention, so kSeqLen is actually the local attention + // window. + static constexpr int kSeqLen = 2048; + static constexpr int kVocabSize = 256000; + static constexpr std::array<LayerAttentionType, 26> kLayerConfig = { + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGemma, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGemma, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGemma, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGemma, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGemma, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGemma, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGemma, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGemma, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGriffinRecurrentBlock, + }; + static constexpr std::array<size_t, 26> kAttentionWindowSizes = + FixedAttentionWindowSizes<26>(kSeqLen); + static constexpr int kLayers = kLayerConfig.size(); + static constexpr int kGemmaLayers = + NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers); + static constexpr int kGriffinLayers = + NumLayersOfTypeBefore(kLayerConfig, + LayerAttentionType::kGriffinRecurrentBlock, + kLayers); + static constexpr int kModelDim = 2560; + static constexpr int kFFHiddenDim = 7680; + static constexpr int kHeads = 10; + static constexpr int kKVHeads = 1; + static constexpr int kQKVDim = 256; // query size == key size == value size + static constexpr int kTopK = gcpp::kTopK; + static constexpr bool kAbsolutePE = false; + static constexpr PostNormType kPostNorm = PostNormType::None; + + 
// No SoftCap. + static constexpr float kAttCap = 0.0f; + static constexpr float kFinalCap = 0.0f; + + // SSM config. + static constexpr int kConv1dWidth = 4; + static constexpr bool kFFBiases = true; + static constexpr bool kSoftmaxAttnOutputBiases = true; + static constexpr bool kUseHalfRope = true; + static constexpr bool kUseLocalAttention = true; + static constexpr bool kInterleaveQKV = false; + static constexpr int kNumTensorScales = 140; + static constexpr PostQKType kPostQK = PostQKType::Rope; + static constexpr ActivationType kActivation = ActivationType::Gelu; + static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize; + static constexpr ResidualType kResidual = ResidualType::Add; + + // Self-extend parameters with defaul values + static constexpr bool kSelfExtend = false; + static constexpr size_t kSelfExtendNgbSize = 0; + static constexpr size_t kSelfExtendGrpSize = 1; +}; } // namespace gcpp diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 028949c..196f1a4 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -327,6 +327,13 @@ class GemmaAttention { PositionalEncodingQK(is_mha_ ? mha_kv : kv, pos, layer_, 1.0f, kv); + // When embedding position, we will use grouped key position + if constexpr (TConfig::kSelfExtend) { + if (pos > ngb_size) { + pos /= grp_size; + } + } + // If MHA, also copy V into KVCache. if (is_mha_) { hwy::CopyBytes(mha_kv + layer_config_.qkv_dim, @@ -417,6 +424,14 @@ class GemmaAttention { // Apply rope and scaling to Q. const size_t pos = queries_pos_[query_idx] + batch_idx; + if constexpr (TConfig::kSelfExtend) { + if (pos > ngb_size) { + const size_t grp_pos = pos / grp_size; + const size_t shift = ngb_size - ngb_size / grp_size; + const size_t shifted_grouped_pos = grp_pos + shift; + pos = shifted_grouped_pos; + } + } PositionalEncodingQK(q, pos, layer_, query_scale, q); const size_t start_pos = StartPos(pos, layer_); From fbba1972d089881725a57715539082f1cd0ad609 Mon Sep 17 00:00:00 2001 From: Nanubala Gnana Sai <45007169+jonpsy@users.noreply.github.com> Date: Sat, 19 Oct 2024 11:58:32 +0530 Subject: [PATCH 2/7] remove compile time config --- gemma/configs.h | 260 ------------------------------------------------ 1 file changed, 260 deletions(-) diff --git a/gemma/configs.h b/gemma/configs.h index 30e9a6a..f7c6ac2 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -192,266 +192,6 @@ ModelConfig ConfigFromModel(Model model); // Returns the sub-config for the ViT model of the PaliGemma model. 
ModelConfig VitConfig(const ModelConfig& config); -template <class TConfig, typename = void> -struct CacheLayerSize { - constexpr size_t operator()() const { - return TConfig::kKVHeads * TConfig::kQKVDim * 2; - } -}; - -template <class TConfig, typename = void> -struct CachePosSize { - constexpr size_t operator()() const { - return TConfig::kGemmaLayers * CacheLayerSize<TConfig>()(); - } -}; - -struct ConfigNoSSM { - static constexpr int kGriffinLayers = 0; - - static constexpr int kConv1dWidth = 0; - static constexpr bool kFFBiases = false; - static constexpr bool kSoftmaxAttnOutputBiases = false; - static constexpr bool kUseHalfRope = false; - static constexpr bool kUseLocalAttention = false; - static constexpr bool kInterleaveQKV = true; - static constexpr int kNumTensorScales = 0; - - static constexpr PostQKType kPostQK = PostQKType::Rope; - static constexpr ActivationType kActivation = ActivationType::Gelu; - static constexpr ResidualType kResidual = ResidualType::Add; - - // Self-extend parameters with defaul values - static constexpr bool kSelfExtend = false; - static constexpr size_t kSelfExtendNgbSize = 0; - static constexpr size_t kSelfExtendGrpSize = 1; -}; - -struct ConfigBaseGemmaV1 : ConfigNoSSM { - static constexpr float kAttCap = 0.0f; - static constexpr float kFinalCap = 0.0f; - static constexpr PostNormType kPostNorm = PostNormType::None; - static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize; -}; - -struct ConfigBaseGemmaV2 : ConfigNoSSM { - static constexpr float kAttCap = 50.0f; - static constexpr float kFinalCap = 30.0f; - static constexpr PostNormType kPostNorm = PostNormType::Scale; -}; - -template <typename TWeight> -struct ConfigGemma27B : public ConfigBaseGemmaV2 { - using Weight = TWeight; // make accessible where we only have a TConfig - - static constexpr int kSeqLen = 8192; - static constexpr int kVocabSize = 256000; - static constexpr std::array<LayerAttentionType, 46> kLayerConfig = - FixedLayerConfig<46>(LayerAttentionType::kGemma); - static constexpr std::array<size_t, 46> kAttentionWindowSizes = - RepeatedAttentionWindowSizes<46, 2>({4096, kSeqLen}); - static constexpr int kLayers = kLayerConfig.size(); - static constexpr int kGemmaLayers = kLayers; - static constexpr int kModelDim = 4608; - static constexpr int kFFHiddenDim = 16 * 4608 / 2; // = 36864 - static constexpr int kHeads = 32; - static constexpr int kKVHeads = 16; - static constexpr int kQKVDim = 128; // query size == key size == value size - static constexpr int kTopK = gcpp::kTopK; - static constexpr bool kAbsolutePE = false; - static constexpr QueryScaleType kQueryScale = - QueryScaleType::SqrtModelDimDivNumHeads; -}; - -template <typename TWeight> -struct ConfigGemma9B : public ConfigBaseGemmaV2 { - using Weight = TWeight; // make accessible where we only have a TConfig - - static constexpr int kSeqLen = 8192; - static constexpr int kVocabSize = 256000; - static constexpr std::array<LayerAttentionType, 42> kLayerConfig = - FixedLayerConfig<42>(LayerAttentionType::kGemma); - static constexpr std::array<size_t, 42> kAttentionWindowSizes = - RepeatedAttentionWindowSizes<42, 2>({4096, kSeqLen}); - static constexpr int kLayers = kLayerConfig.size(); - static constexpr int kGemmaLayers = kLayers; - static constexpr int kModelDim = 3584; - static constexpr int kFFHiddenDim = 8 * 3584 / 2; // = 14336 - static constexpr int kHeads = 16; - static constexpr int kKVHeads = 8; - static constexpr int kQKVDim = 256; // query size == key size == value size - static constexpr int kTopK 
= gcpp::kTopK; - static constexpr bool kAbsolutePE = false; - static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize; -}; - -template <typename TWeight> -struct ConfigGemma7B : public ConfigBaseGemmaV1 { - using Weight = TWeight; // make accessible where we only have a TConfig - - static constexpr int kSeqLen = gcpp::kSeqLen; - static constexpr int kVocabSize = 256000; - static constexpr std::array<LayerAttentionType, 28> kLayerConfig = - FixedLayerConfig<28>(LayerAttentionType::kGemma); - static constexpr std::array<size_t, 28> kAttentionWindowSizes = - FixedAttentionWindowSizes<28>(kSeqLen); - static constexpr int kLayers = kLayerConfig.size(); - static constexpr int kGemmaLayers = kLayers; - static constexpr int kModelDim = 3072; - static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576 - static constexpr int kHeads = 16; - static constexpr int kKVHeads = 16; // standard MHA - static constexpr int kQKVDim = 256; // query size == key size == value size - static constexpr int kTopK = gcpp::kTopK; - static constexpr bool kAbsolutePE = false; -}; - -template <typename TWeight> -struct ConfigGemma2B : public ConfigBaseGemmaV1 { - using Weight = TWeight; // make accessible where we only have a TConfig - - static constexpr int kSeqLen = gcpp::kSeqLen; - static constexpr int kVocabSize = 256000; - static constexpr std::array<LayerAttentionType, 18> kLayerConfig = - FixedLayerConfig<18>(LayerAttentionType::kGemma); - static constexpr std::array<size_t, 18> kAttentionWindowSizes = - FixedAttentionWindowSizes<18>(kSeqLen); - static constexpr int kLayers = kLayerConfig.size(); - static constexpr int kGemmaLayers = kLayers; - static constexpr int kModelDim = 2048; - static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384 - static constexpr int kHeads = 8; - static constexpr int kKVHeads = 1; - static constexpr int kQKVDim = 256; // query size == key size == value size - static constexpr int kTopK = gcpp::kTopK; - static constexpr bool kAbsolutePE = false; -}; - -template <typename TWeight> -struct ConfigGemma2_2B : public ConfigBaseGemmaV2 { - using Weight = TWeight; // make accessible where we only have a TConfig - - static constexpr int kSeqLen = 8192; - static constexpr int kVocabSize = 256000; - static constexpr std::array<LayerAttentionType, 26> kLayerConfig = - FixedLayerConfig<26>(LayerAttentionType::kGemma); - static constexpr std::array<size_t, 26> kAttentionWindowSizes = - RepeatedAttentionWindowSizes<26, 2>({4096, kSeqLen}); - static constexpr int kLayers = kLayerConfig.size(); - static constexpr int kGemmaLayers = kLayers; - static constexpr int kModelDim = 2304; - static constexpr int kFFHiddenDim = 8 * 2304 / 2; // = 9216 - static constexpr int kHeads = 8; - static constexpr int kKVHeads = 4; - static constexpr int kQKVDim = 256; // query size == key size == value size - static constexpr int kTopK = gcpp::kTopK; - static constexpr bool kAbsolutePE = false; - static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize; -}; - -template <typename TWeight> -struct ConfigGemmaTiny : public ConfigNoSSM { - using Weight = TWeight; // make accessible where we only have a TConfig - - static constexpr int kSeqLen = 32; - static constexpr int kVocabSize = 64; - static constexpr std::array<LayerAttentionType, 3> kLayerConfig = - FixedLayerConfig<3>(LayerAttentionType::kGemma); - static constexpr std::array<size_t, 3> kAttentionWindowSizes = - FixedAttentionWindowSizes<3>(kSeqLen); - static constexpr int kLayers = kLayerConfig.size(); - static constexpr 
int kGemmaLayers = kLayers; - static constexpr int kModelDim = 128; - static constexpr int kFFHiddenDim = 256; - static constexpr int kHeads = 4; - static constexpr int kKVHeads = 1; - static constexpr int kQKVDim = 16; // query size == key size == value size - static constexpr int kTopK = gcpp::kTopK; - static constexpr bool kAbsolutePE = false; - static constexpr PostNormType kPostNorm = PostNormType::None; - static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize; - - static constexpr float kAttCap = 0.0f; - // This is required for optimize_test to pass. - static constexpr float kFinalCap = 30.0f; -}; - -template <typename TWeight> -struct ConfigGriffin2B { - using Weight = TWeight; // make accessible where we only have a TConfig - - // Griffin uses local attention, so kSeqLen is actually the local attention - // window. - static constexpr int kSeqLen = 2048; - static constexpr int kVocabSize = 256000; - static constexpr std::array<LayerAttentionType, 26> kLayerConfig = { - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGemma, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGemma, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGemma, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGemma, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGemma, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGemma, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGemma, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGemma, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGriffinRecurrentBlock, - }; - static constexpr std::array<size_t, 26> kAttentionWindowSizes = - FixedAttentionWindowSizes<26>(kSeqLen); - static constexpr int kLayers = kLayerConfig.size(); - static constexpr int kGemmaLayers = - NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers); - static constexpr int kGriffinLayers = - NumLayersOfTypeBefore(kLayerConfig, - LayerAttentionType::kGriffinRecurrentBlock, - kLayers); - static constexpr int kModelDim = 2560; - static constexpr int kFFHiddenDim = 7680; - static constexpr int kHeads = 10; - static constexpr int kKVHeads = 1; - static constexpr int kQKVDim = 256; // query size == key size == value size - static constexpr int kTopK = gcpp::kTopK; - static constexpr bool kAbsolutePE = false; - static constexpr PostNormType kPostNorm = PostNormType::None; - - // No SoftCap. - static constexpr float kAttCap = 0.0f; - static constexpr float kFinalCap = 0.0f; - - // SSM config. 
- static constexpr int kConv1dWidth = 4; - static constexpr bool kFFBiases = true; - static constexpr bool kSoftmaxAttnOutputBiases = true; - static constexpr bool kUseHalfRope = true; - static constexpr bool kUseLocalAttention = true; - static constexpr bool kInterleaveQKV = false; - static constexpr int kNumTensorScales = 140; - static constexpr PostQKType kPostQK = PostQKType::Rope; - static constexpr ActivationType kActivation = ActivationType::Gelu; - static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize; - static constexpr ResidualType kResidual = ResidualType::Add; - - // Self-extend parameters with defaul values - static constexpr bool kSelfExtend = false; - static constexpr size_t kSelfExtendNgbSize = 0; - static constexpr size_t kSelfExtendGrpSize = 1; -}; } // namespace gcpp From f77e61e514b1cc0e10c8e292d6faaf6a46e57326 Mon Sep 17 00:00:00 2001 From: Nanubala Gnana Sai <45007169+jonpsy@users.noreply.github.com> Date: Sat, 19 Oct 2024 13:23:12 +0530 Subject: [PATCH 3/7] Use runtime config to setup self extend --- gemma/configs.h | 4 ++++ gemma/gemma-inl.h | 37 ++++++++++++++++++++----------------- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/gemma/configs.h b/gemma/configs.h index f7c6ac2..58f3446 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -127,6 +127,10 @@ struct LayerConfig { size_t conv1d_width = 0; bool ff_biases = false; bool softmax_attn_output_biases = false; + bool self_extend = false; + size_t ngb_size = 0; + size_t grp_size = 1; + PostNormType post_norm = PostNormType::None; LayerAttentionType type = LayerAttentionType::kGemma; ActivationType activation = ActivationType::Gelu; diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 196f1a4..ea5aca0 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -312,28 +312,29 @@ class GemmaAttention { const size_t interleaved_idx = task / layer_config_.kv_heads; const size_t query_idx = interleaved_idx % num_queries_; const size_t batch_idx = interleaved_idx / num_queries_; - const size_t pos = queries_pos_[query_idx] + batch_idx; + size_t pos = queries_pos_[query_idx] + batch_idx; const size_t cache_pos = div_seq_len_.Remainder(pos); const size_t kv_offset = cache_pos * cache_pos_size_ + layer_ * cache_layer_size_ + head * layer_config_.qkv_dim * 2; KVCache& kv_cache = kv_caches_[query_idx]; + + const size_t grp_size = layer_config_.grp_size; + const size_t ngb_size = layer_config_.ngb_size; + const bool self_extend = layer_config_.self_extend; + float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; const float* HWY_RESTRICT mha_kv = activations_.q.Batch(interleaved_idx) + head * q_stride_ + layer_config_.qkv_dim; + // When embedding position, we will use grouped key position + if (self_extend && pos > ngb_size) { + pos /= grp_size; + } // Copy from `q` if MHA, or apply in-place. PositionalEncodingQK(is_mha_ ? mha_kv : kv, pos, layer_, 1.0f, kv); - - // When embedding position, we will use grouped key position - if constexpr (TConfig::kSelfExtend) { - if (pos > ngb_size) { - pos /= grp_size; - } - } - // If MHA, also copy V into KVCache. 
if (is_mha_) { hwy::CopyBytes(mha_kv + layer_config_.qkv_dim, @@ -418,19 +419,21 @@ class GemmaAttention { const size_t batch_idx = interleaved_idx / num_queries_; const size_t head_offset = (head / kHeadGroups) * layer_config_.qkv_dim * 2; + + const size_t grp_size = layer_config_.grp_size; + const size_t ngb_size = layer_config_.ngb_size; + const bool self_extend = layer_config_.self_extend; KVCache& kv_cache = kv_caches_[query_idx]; float* HWY_RESTRICT q = activations_.q.Batch(interleaved_idx) + head * q_stride_; // Apply rope and scaling to Q. - const size_t pos = queries_pos_[query_idx] + batch_idx; - if constexpr (TConfig::kSelfExtend) { - if (pos > ngb_size) { - const size_t grp_pos = pos / grp_size; - const size_t shift = ngb_size - ngb_size / grp_size; - const size_t shifted_grouped_pos = grp_pos + shift; - pos = shifted_grouped_pos; - } + size_t pos = queries_pos_[query_idx] + batch_idx; + if (self_extend && pos > ngb_size) { + const size_t grp_pos = pos / grp_size; + const size_t shift = ngb_size - ngb_size / grp_size; + const size_t shifted_grouped_pos = grp_pos + shift; + pos = shifted_grouped_pos; } PositionalEncodingQK(q, pos, layer_, query_scale, q); From 3b270d236fc1ed73493ed24134e6c137f460cab9 Mon Sep 17 00:00:00 2001 From: Nanubala Gnana Sai <45007169+jonpsy@users.noreply.github.com> Date: Tue, 29 Oct 2024 22:38:05 +0530 Subject: [PATCH 4/7] Use hwy divisor to speed up division --- gemma/gemma-inl.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index ea5aca0..9fd739a 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -319,7 +319,7 @@ class GemmaAttention { head * layer_config_.qkv_dim * 2; KVCache& kv_cache = kv_caches_[query_idx]; - const size_t grp_size = layer_config_.grp_size; + const hwy::Divisor& div_grp_size { static_cast<uint32_t>(layer_config_.grp_size) }; const size_t ngb_size = layer_config_.ngb_size; const bool self_extend = layer_config_.self_extend; @@ -330,7 +330,7 @@ class GemmaAttention { // When embedding position, we will use grouped key position if (self_extend && pos > ngb_size) { - pos /= grp_size; + pos = div_grp_size.Divide(pos); } // Copy from `q` if MHA, or apply in-place. PositionalEncodingQK(is_mha_ ? mha_kv : kv, pos, layer_, 1.0f, From 719098fd3e62eea5443e6b4c0ac7634a4fc7e244 Mon Sep 17 00:00:00 2001 From: Nanubala Gnana Sai <45007169+jonpsy@users.noreply.github.com> Date: Fri, 1 Nov 2024 19:38:58 +0530 Subject: [PATCH 5/7] Move div_grp_size outside --- gemma/gemma-inl.h | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 9fd739a..1faf6a3 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -305,6 +305,9 @@ class GemmaAttention { } } + // Self-extension + const hwy::Divisor& div_grp_size{ + static_cast<uint32_t>(layer_config_.grp_size)}; // Apply positional encodings for K (and copy KV to cache if MHA). 
pool_.Run(0, layer_config_.kv_heads * num_interleaved, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { @@ -319,7 +322,6 @@ class GemmaAttention { head * layer_config_.qkv_dim * 2; KVCache& kv_cache = kv_caches_[query_idx]; - const hwy::Divisor& div_grp_size { static_cast<uint32_t>(layer_config_.grp_size) }; const size_t ngb_size = layer_config_.ngb_size; const bool self_extend = layer_config_.self_extend; @@ -328,7 +330,8 @@ class GemmaAttention { activations_.q.Batch(interleaved_idx) + head * q_stride_ + layer_config_.qkv_dim; - // When embedding position, we will use grouped key position + // In self-extend, when embedding position, + // we will use grouped key position if (self_extend && pos > ngb_size) { pos = div_grp_size.Divide(pos); } @@ -1484,7 +1487,7 @@ void GenerateBatchT(const ModelWeightsStorage& model, qbatch_size); QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size); const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start], - qbatch_size); + qbatch_size); const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size); GenerateT<T>(model, activations, runtime_config, qbatch_prompts, qbatch_pos, qbatch_prefix_end, qbatch_start, qbatch_kv, timing_info); From 28063743fc3d5fe5d5b85ae8fe88db060777c6e0 Mon Sep 17 00:00:00 2001 From: Nanubala Gnana Sai <45007169+jonpsy@users.noreply.github.com> Date: Tue, 5 Nov 2024 22:43:07 +0530 Subject: [PATCH 6/7] Use explicit ctor for hwy divisor --- gemma/gemma-inl.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 1faf6a3..368a8cd 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -306,8 +306,8 @@ class GemmaAttention { } // Self-extension - const hwy::Divisor& div_grp_size{ - static_cast<uint32_t>(layer_config_.grp_size)}; + const hwy::Divisor div_grp_size( + static_cast<uint32_t>(layer_config_.grp_size)); // Apply positional encodings for K (and copy KV to cache if MHA). pool_.Run(0, layer_config_.kv_heads * num_interleaved, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { From 14d62b0098fd7d897760ebc0304faeb2bb649ceb Mon Sep 17 00:00:00 2001 From: Nanubala Gnana Sai <45007169+jonpsy@users.noreply.github.com> Date: Tue, 19 Nov 2024 22:33:27 +0530 Subject: [PATCH 7/7] Added support for mutable ModelConfig, run.cc can support runtime self extend config --- gemma/configs.h | 11 +++++++++-- gemma/gemma-inl.h | 23 +++++++++++++---------- gemma/gemma.h | 1 + gemma/run.cc | 21 +++++++++++++++++++++ gemma/weights.h | 1 + util/app.h | 11 +++++++++++ 6 files changed, 56 insertions(+), 12 deletions(-) diff --git a/gemma/configs.h b/gemma/configs.h index be36839..986ff30 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -135,9 +135,16 @@ struct LayerConfig { size_t conv1d_width = 0; bool ff_biases = false; bool softmax_attn_output_biases = false; + + /** + * Self-extend + * Jin, Hongye, et al. "Llm maybe longlm: Self-extend llm context window without tuning." arXiv preprint arXiv:2401.01325 (2024). 
+ */ bool self_extend = false; - size_t ngb_size = 0; - size_t grp_size = 1; + // Self-extend neighbor size + size_t se_neighbor_size = std::numeric_limits<size_t>::max(); + // Self-extend group window size + size_t se_group_size = 1; PostNormType post_norm = PostNormType::None; LayerAttentionType type = LayerAttentionType::kGemma; diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 3354ccf..d93c803 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -302,7 +302,7 @@ class GemmaAttention { // Self-extension const hwy::Divisor div_grp_size( - static_cast<uint32_t>(layer_config_.grp_size)); + static_cast<uint32_t>(layer_config_.se_group_size)); // Apply positional encodings for K (and copy KV to cache if MHA). pool_.Run(0, kv_heads * num_interleaved, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { @@ -317,8 +317,8 @@ class GemmaAttention { head * qkv_dim * 2; KVCache& kv_cache = kv_caches_[query_idx]; - const size_t ngb_size = layer_config_.ngb_size; - const bool self_extend = layer_config_.self_extend; + const size_t se_neighbor_size = layer_config_.se_neighbor_size; + const bool enable_self_extend = layer_config_.self_extend; float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; const float* HWY_RESTRICT mha_kv = @@ -327,7 +327,7 @@ class GemmaAttention { // In self-extend, when embedding position, // we will use grouped key position - if (self_extend && pos > ngb_size) { + if (enable_self_extend && pos > se_neighbor_size) { pos = div_grp_size.Divide(pos); } // Copy from `q` if MHA, or apply in-place. @@ -417,18 +417,21 @@ class GemmaAttention { const size_t head_offset = (head / kHeadGroups) * layer_config_.qkv_dim * 2; - const size_t grp_size = layer_config_.grp_size; - const size_t ngb_size = layer_config_.ngb_size; - const bool self_extend = layer_config_.self_extend; + const size_t se_group_size = layer_config_.se_group_size; + const size_t se_neighbor_size = layer_config_.se_neighbor_size; + const bool enable_self_extend = + layer_config_.self_extend; + KVCache& kv_cache = kv_caches_[query_idx]; float* HWY_RESTRICT q = activations_.q.Batch(interleaved_idx) + head * q_stride_; // Apply rope and scaling to Q. 
size_t pos = queries_pos_[query_idx] + batch_idx; - if (self_extend && pos > ngb_size) { - const size_t grp_pos = pos / grp_size; - const size_t shift = ngb_size - ngb_size / grp_size; + if (enable_self_extend && pos > se_neighbor_size) { + const size_t grp_pos = pos / se_group_size; + const size_t shift = + se_neighbor_size - se_neighbor_size / se_group_size; const size_t shifted_grouped_pos = grp_pos + shift; pos = shifted_grouped_pos; } diff --git a/gemma/gemma.h b/gemma/gemma.h index 5df319f..3a685fe 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -194,6 +194,7 @@ class Gemma { ~Gemma(); const ModelConfig& GetModelConfig() const { return model_.Config(); } + ModelConfig& GetMutableModelConfig() { return model_.MutableConfig(); } const ModelInfo& Info() const { return info_; } const GemmaTokenizer& Tokenizer() const { return tokenizer_; } const ModelWeightsStorage& Weights() const { return model_; } diff --git a/gemma/run.cc b/gemma/run.cc index 2c62bdb..6fbf9bd 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -77,6 +77,26 @@ std::string GetPrompt(std::istream& input, int verbosity, return prompt_string; } +// Extract args from the loader and modify model config +void ApplySelfExtendIfGiven(Gemma& model, LoaderArgs loader) { + ModelConfig& config = model.GetMutableModelConfig(); + if (loader.self_extend != Tristate::kTrue) { + return; + } + + // Modify layer config in-place + auto& layer_configs = config.layer_configs; + std::transform(layer_configs.begin(), layer_configs.end(), layer_configs.begin(), + [&loader](LayerConfig& layer_config) { + layer_config.self_extend = + loader.self_extend == Tristate::kTrue; + layer_config.se_group_size = loader.se_group_size; + layer_config.se_neighbor_size = loader.se_neighbor_size; + + return layer_config; + }); +} + // The main Read-Eval-Print Loop. void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app, const InferenceArgs& args, const AcceptFunc& accept_token, @@ -206,6 +226,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { Allocator::Init(pools.Topology()); Gemma model = CreateGemma(loader, pools); + ApplySelfExtendIfGiven(model, loader); KVCache kv_cache = KVCache::Create(model.GetModelConfig(), inference.prefill_tbatch_size); diff --git a/gemma/weights.h b/gemma/weights.h index ce2df43..71a8186 100644 --- a/gemma/weights.h +++ b/gemma/weights.h @@ -550,6 +550,7 @@ class ModelWeightsStorage { void CopyWithTranspose(hwy::ThreadPool& pool); void LogWeightStats(); const ModelConfig& Config() const { return config_; } + ModelConfig& MutableConfig() { return config_; } template <typename T> ModelWeightsPtrs<T>* GetWeightsOfType() const { diff --git a/util/app.h b/util/app.h index ebc16b9..2aa8529 100644 --- a/util/app.h +++ b/util/app.h @@ -171,6 +171,11 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> { std::string model_type_str; std::string weight_type_str; + // Self-extend + Tristate self_extend; + size_t se_group_size; + size_t se_neighbor_size; + template <class Visitor> void ForEach(const Visitor& visitor) { visitor(tokenizer, "tokenizer", Path(), @@ -189,6 +194,12 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> { visitor(weight_type_str, "weight_type", std::string("sfp"), "Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit FP\n" " Required argument."); + visitor(self_extend, "self_extend", Tristate::kDefault, + "Apply self extend ? 
-1 = auto, 0 = no, 1 = yes.", 2); + visitor(se_group_size, "se_group_size", size_t{1}, "Group size for self extend"); + visitor(se_neighbor_size, "se_neighbor_size", + std::numeric_limits<size_t>::max(), + "Neighbor window size for self extend"); } // Uninitialized before Validate, must call after that.
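
Taken together, the series implements the position remapping from the referenced paper (Jin et al., "LLM Maybe LongLM: Self-Extend LLM Context Window Without Tuning", arXiv:2401.01325): key/value positions beyond the neighbor window are merged into groups, and query positions beyond it are mapped to their grouped position plus a constant shift. The standalone sketch below is not part of the patches; it only mirrors the arithmetic of the two GemmaAttention hunks above, using the final field names (se_group_size, se_neighbor_size) and invented function names.

#include <cstddef>

// Position used when applying RoPE to K/V (first gemma-inl.h hunk):
// positions beyond the neighbor window are floor-divided into groups.
inline size_t SelfExtendKeyPos(size_t pos, size_t se_neighbor_size,
                               size_t se_group_size) {
  if (pos > se_neighbor_size) pos /= se_group_size;
  return pos;
}

// Position used when applying RoPE to Q (second gemma-inl.h hunk): the
// grouped position plus a constant shift. The shift keeps the remapped
// query position at least as large as every remapped key position, so
// relative distances stay non-negative.
inline size_t SelfExtendQueryPos(size_t pos, size_t se_neighbor_size,
                                 size_t se_group_size) {
  if (pos > se_neighbor_size) {
    const size_t grp_pos = pos / se_group_size;
    const size_t shift = se_neighbor_size - se_neighbor_size / se_group_size;
    pos = grp_pos + shift;
  }
  return pos;
}

For example, with se_neighbor_size = 1024 and se_group_size = 8, a token at position 5000 is encoded at key position 5000 / 8 = 625 and query position 625 + (1024 - 128) = 1521. With the defaults (se_neighbor_size = SIZE_MAX, se_group_size = 1, self_extend = false) both mappings are the identity, so existing models are unaffected. From the command line the feature would be enabled with something like --self_extend 1 --se_group_size 8 --se_neighbor_size 1024 (flag names from the LoaderArgs hunk; the numeric values are only illustrative).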
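
The last patch exposes these knobs through LoaderArgs and a mutable ModelConfig, so run.cc can toggle the feature at startup via ApplySelfExtendIfGiven. The same hook can be used programmatically; the helper below is a hypothetical sketch (the function name and the assumption that layer_configs is range-iterable are mine), mirroring ApplySelfExtendIfGiven but driven by explicit arguments instead of command-line flags.

#include <cstddef>

#include "gemma/gemma.h"  // gcpp::Gemma, gcpp::ModelConfig, gcpp::LayerConfig

// Enable the patch's self-extend fields on every layer of an already-created
// model. Call before creating the KVCache, as run.cc does.
void EnableSelfExtend(gcpp::Gemma& model, size_t group_size,
                      size_t neighbor_size) {
  gcpp::ModelConfig& config = model.GetMutableModelConfig();
  for (gcpp::LayerConfig& layer : config.layer_configs) {
    layer.self_extend = true;
    layer.se_group_size = group_size;
    layer.se_neighbor_size = neighbor_size;
  }
}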