Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] ISS-60: Implement Self Extend #431

Open
wants to merge 11 commits into
base: dev
Choose a base branch
from
Next Next commit
compile success: set default self extend values in noSSM and griffin
jonpsy committed Oct 19, 2024
commit 8cf3966be451068c7038db0ee39ca6d2891ac95b
260 changes: 260 additions & 0 deletions gemma/configs.h
Original file line number Diff line number Diff line change
@@ -192,6 +192,266 @@ ModelConfig ConfigFromModel(Model model);

// Returns the sub-config for the ViT model of the PaliGemma model.
ModelConfig VitConfig(const ModelConfig& config);
template <class TConfig, typename = void>
struct CacheLayerSize {
constexpr size_t operator()() const {
return TConfig::kKVHeads * TConfig::kQKVDim * 2;
}
};

template <class TConfig, typename = void>
struct CachePosSize {
constexpr size_t operator()() const {
return TConfig::kGemmaLayers * CacheLayerSize<TConfig>()();
}
};

struct ConfigNoSSM {
static constexpr int kGriffinLayers = 0;

static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr bool kUseHalfRope = false;
static constexpr bool kUseLocalAttention = false;
static constexpr bool kInterleaveQKV = true;
static constexpr int kNumTensorScales = 0;

static constexpr PostQKType kPostQK = PostQKType::Rope;
static constexpr ActivationType kActivation = ActivationType::Gelu;
static constexpr ResidualType kResidual = ResidualType::Add;

// Self-extend parameters with defaul values
static constexpr bool kSelfExtend = false;
static constexpr size_t kSelfExtendNgbSize = 0;
static constexpr size_t kSelfExtendGrpSize = 1;
};

struct ConfigBaseGemmaV1 : ConfigNoSSM {
static constexpr float kAttCap = 0.0f;
static constexpr float kFinalCap = 0.0f;
static constexpr PostNormType kPostNorm = PostNormType::None;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
};

struct ConfigBaseGemmaV2 : ConfigNoSSM {
static constexpr float kAttCap = 50.0f;
static constexpr float kFinalCap = 30.0f;
static constexpr PostNormType kPostNorm = PostNormType::Scale;
};

template <typename TWeight>
struct ConfigGemma27B : public ConfigBaseGemmaV2 {
using Weight = TWeight; // make accessible where we only have a TConfig

static constexpr int kSeqLen = 8192;
static constexpr int kVocabSize = 256000;
static constexpr std::array<LayerAttentionType, 46> kLayerConfig =
FixedLayerConfig<46>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 46> kAttentionWindowSizes =
RepeatedAttentionWindowSizes<46, 2>({4096, kSeqLen});
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 4608;
static constexpr int kFFHiddenDim = 16 * 4608 / 2; // = 36864
static constexpr int kHeads = 32;
static constexpr int kKVHeads = 16;
static constexpr int kQKVDim = 128; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr QueryScaleType kQueryScale =
QueryScaleType::SqrtModelDimDivNumHeads;
};

template <typename TWeight>
struct ConfigGemma9B : public ConfigBaseGemmaV2 {
using Weight = TWeight; // make accessible where we only have a TConfig

static constexpr int kSeqLen = 8192;
static constexpr int kVocabSize = 256000;
static constexpr std::array<LayerAttentionType, 42> kLayerConfig =
FixedLayerConfig<42>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 42> kAttentionWindowSizes =
RepeatedAttentionWindowSizes<42, 2>({4096, kSeqLen});
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 3584;
static constexpr int kFFHiddenDim = 8 * 3584 / 2; // = 14336
static constexpr int kHeads = 16;
static constexpr int kKVHeads = 8;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
};

template <typename TWeight>
struct ConfigGemma7B : public ConfigBaseGemmaV1 {
using Weight = TWeight; // make accessible where we only have a TConfig

static constexpr int kSeqLen = gcpp::kSeqLen;
static constexpr int kVocabSize = 256000;
static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
FixedLayerConfig<28>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 28> kAttentionWindowSizes =
FixedAttentionWindowSizes<28>(kSeqLen);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 3072;
static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
static constexpr int kHeads = 16;
static constexpr int kKVHeads = 16; // standard MHA
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
};

template <typename TWeight>
struct ConfigGemma2B : public ConfigBaseGemmaV1 {
using Weight = TWeight; // make accessible where we only have a TConfig

static constexpr int kSeqLen = gcpp::kSeqLen;
static constexpr int kVocabSize = 256000;
static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
FixedLayerConfig<18>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 18> kAttentionWindowSizes =
FixedAttentionWindowSizes<18>(kSeqLen);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 2048;
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
static constexpr int kHeads = 8;
static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
};

template <typename TWeight>
struct ConfigGemma2_2B : public ConfigBaseGemmaV2 {
using Weight = TWeight; // make accessible where we only have a TConfig

static constexpr int kSeqLen = 8192;
static constexpr int kVocabSize = 256000;
static constexpr std::array<LayerAttentionType, 26> kLayerConfig =
FixedLayerConfig<26>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 26> kAttentionWindowSizes =
RepeatedAttentionWindowSizes<26, 2>({4096, kSeqLen});
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 2304;
static constexpr int kFFHiddenDim = 8 * 2304 / 2; // = 9216
static constexpr int kHeads = 8;
static constexpr int kKVHeads = 4;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
};

template <typename TWeight>
struct ConfigGemmaTiny : public ConfigNoSSM {
using Weight = TWeight; // make accessible where we only have a TConfig

static constexpr int kSeqLen = 32;
static constexpr int kVocabSize = 64;
static constexpr std::array<LayerAttentionType, 3> kLayerConfig =
FixedLayerConfig<3>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 3> kAttentionWindowSizes =
FixedAttentionWindowSizes<3>(kSeqLen);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 128;
static constexpr int kFFHiddenDim = 256;
static constexpr int kHeads = 4;
static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 16; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr PostNormType kPostNorm = PostNormType::None;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;

static constexpr float kAttCap = 0.0f;
// This is required for optimize_test to pass.
static constexpr float kFinalCap = 30.0f;
};

template <typename TWeight>
struct ConfigGriffin2B {
using Weight = TWeight; // make accessible where we only have a TConfig

// Griffin uses local attention, so kSeqLen is actually the local attention
// window.
static constexpr int kSeqLen = 2048;
static constexpr int kVocabSize = 256000;
static constexpr std::array<LayerAttentionType, 26> kLayerConfig = {
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
};
static constexpr std::array<size_t, 26> kAttentionWindowSizes =
FixedAttentionWindowSizes<26>(kSeqLen);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers =
NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
static constexpr int kGriffinLayers =
NumLayersOfTypeBefore(kLayerConfig,
LayerAttentionType::kGriffinRecurrentBlock,
kLayers);
static constexpr int kModelDim = 2560;
static constexpr int kFFHiddenDim = 7680;
static constexpr int kHeads = 10;
static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr PostNormType kPostNorm = PostNormType::None;

// No SoftCap.
static constexpr float kAttCap = 0.0f;
static constexpr float kFinalCap = 0.0f;

// SSM config.
static constexpr int kConv1dWidth = 4;
static constexpr bool kFFBiases = true;
static constexpr bool kSoftmaxAttnOutputBiases = true;
static constexpr bool kUseHalfRope = true;
static constexpr bool kUseLocalAttention = true;
static constexpr bool kInterleaveQKV = false;
static constexpr int kNumTensorScales = 140;
static constexpr PostQKType kPostQK = PostQKType::Rope;
static constexpr ActivationType kActivation = ActivationType::Gelu;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
static constexpr ResidualType kResidual = ResidualType::Add;

// Self-extend parameters with defaul values
static constexpr bool kSelfExtend = false;
static constexpr size_t kSelfExtendNgbSize = 0;
static constexpr size_t kSelfExtendGrpSize = 1;
};

} // namespace gcpp

15 changes: 15 additions & 0 deletions gemma/gemma-inl.h
Original file line number Diff line number Diff line change
@@ -327,6 +327,13 @@ class GemmaAttention {
PositionalEncodingQK(is_mha_ ? mha_kv : kv, pos, layer_, 1.0f,
kv);

// When embedding position, we will use grouped key position
if constexpr (TConfig::kSelfExtend) {
if (pos > ngb_size) {
pos /= grp_size;
}
}

// If MHA, also copy V into KVCache.
if (is_mha_) {
hwy::CopyBytes(mha_kv + layer_config_.qkv_dim,
@@ -417,6 +424,14 @@ class GemmaAttention {

// Apply rope and scaling to Q.
const size_t pos = queries_pos_[query_idx] + batch_idx;
if constexpr (TConfig::kSelfExtend) {
if (pos > ngb_size) {
const size_t grp_pos = pos / grp_size;
const size_t shift = ngb_size - ngb_size / grp_size;
const size_t shifted_grouped_pos = grp_pos + shift;
pos = shifted_grouped_pos;
}
}
PositionalEncodingQK(q, pos, layer_, query_scale, q);

const size_t start_pos = StartPos(pos, layer_);