From 59e8f81d44ce35193c82a15d7d4db30b1394b39a Mon Sep 17 00:00:00 2001
From: April Yang
Date: Mon, 3 Jun 2024 00:41:44 +0000
Subject: [PATCH] modify constructor and allocate weights

---
 src/ops/inc_multihead_self_attention.cc | 61 +++++++++++++++++++++++++
 1 file changed, 61 insertions(+)

diff --git a/src/ops/inc_multihead_self_attention.cc b/src/ops/inc_multihead_self_attention.cc
index 5d52034575..c34604be53 100644
--- a/src/ops/inc_multihead_self_attention.cc
+++ b/src/ops/inc_multihead_self_attention.cc
@@ -23,6 +23,8 @@
 #endif
 #include "flexflow/utils/hash_utils.h"
 #include "legion/legion_utilities.h"
+#include "flexflow/ops/linear.h"
+#include "flexflow/ops/lora_linear.h"
 
 namespace FlexFlow {
 
@@ -301,6 +303,7 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention(
     DataType _quantization_type,
     bool _offload,
     int _tensor_parallelism_degree,
+    bool use_lora,
     char const *name)
     // Initializer* _bias_initializer)
     : Op(model,
@@ -336,6 +339,20 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention(
   dims[0].size = _embed_dim;
   // Currently require no parallelism along this dim
   assert(dims[0].degree == 1);
+
+
+  // Initialize separate Q, K, V linear layers
+  qProj = model.dense(_input, _kdim * _num_q_heads, AM_LINEAR, _qkv_bias, _input->data_type, nullptr, nullptr, nullptr, RM_NONE, 0.0f, "qProj");
+  kProj = model.dense(_input, _kdim * _num_kv_heads, AM_LINEAR, _qkv_bias, _input->data_type, nullptr, nullptr, nullptr, RM_NONE, 0.0f, "kProj");
+  vProj = model.dense(_input, _vdim * _num_kv_heads, AM_LINEAR, _qkv_bias, _input->data_type, nullptr, nullptr, nullptr, RM_NONE, 0.0f, "vProj");
+
+  if (use_lora) {
+    // Initialize LoRA layers
+    qProj_lora = model.add_lora_layer({"qProj"});
+    kProj_lora = model.add_lora_layer({"kProj"});
+    vProj_lora = model.add_lora_layer({"vProj"});
+  }
+
   if (allocate_weights) {
     // Create weight tensor
     int num_dims = inputs[0]->num_dims;
@@ -384,6 +401,50 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention(
                                                        initializer,
                                                        CHOSEN_SYNC_TYPE);
     }
+
+    if (use_lora) {
+      // Allocate LoRA weights for Q, K, V projections
+      int rank = 4; // Example rank, this should be configurable
+
+      // LoRA Q projection weights
+      ParallelDim lora_q_dims[2];
+      lora_q_dims[0] = dims[0]; // output dimension
+      lora_q_dims[1] = dims[1]; // input dimension
+      lora_q_dims[1].size = rank * this->qProjSize;
+      weights[2] = model.create_parallel_weight<2>(
+          lora_q_dims,
+          this->data_type,
+          nullptr /*owner_op*/,
+          true /*create_grad*/,
+          initializer,
+          CHOSEN_SYNC_TYPE);
+
+      // LoRA K projection weights
+      ParallelDim lora_k_dims[2];
+      lora_k_dims[0] = dims[0];
+      lora_k_dims[1] = dims[1];
+      lora_k_dims[1].size = rank * this->kProjSize;
+      weights[3] = model.create_parallel_weight<2>(
+          lora_k_dims,
+          this->data_type,
+          nullptr /*owner_op*/,
+          true /*create_grad*/,
+          initializer,
+          CHOSEN_SYNC_TYPE);
+
+      // LoRA V projection weights
+      ParallelDim lora_v_dims[2];
+      lora_v_dims[0] = dims[0];
+      lora_v_dims[1] = dims[1];
+      lora_v_dims[1].size = rank * this->vProjSize;
+      weights[4] = model.create_parallel_weight<2>(
+          lora_v_dims,
+          this->data_type,
+          nullptr /*owner_op*/,
+          true /*create_grad*/,
+          initializer,
+          CHOSEN_SYNC_TYPE);
+    }
   }
 
   outputs[0] = model.create_parallel_tensor_legion_ordering(
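
Note (not part of the patch above): the LoRA weight buffers allocated in the last hunk are sized in terms of a low-rank factor "rank". The standalone sketch below is plain C++, not FlexFlow code; all dimensions, the scaling factor alpha, and the helper matvec are invented for illustration. It shows the computation such weights feed into: the frozen projection y = W x plus a low-rank correction alpha * B (A x), where A is [rank x in_dim] and B is [out_dim x rank]. With B initialized to zero, the adapted layer initially matches the base layer.

// Standalone LoRA illustration (hypothetical sizes, not FlexFlow code).
#include <cstdio>
#include <vector>

// y[out_dim] = M[out_dim x in_dim] * x[in_dim], row-major, no bias.
static std::vector<float> matvec(std::vector<float> const &M,
                                 std::vector<float> const &x,
                                 int out_dim, int in_dim) {
  std::vector<float> y(out_dim, 0.0f);
  for (int o = 0; o < out_dim; ++o) {
    for (int i = 0; i < in_dim; ++i) {
      y[o] += M[o * in_dim + i] * x[i];
    }
  }
  return y;
}

int main() {
  int const in_dim = 8, out_dim = 8, rank = 4; // made-up sizes
  float const alpha = 1.0f;                    // LoRA scaling factor

  std::vector<float> W(out_dim * in_dim, 0.01f); // frozen base projection
  std::vector<float> A(rank * in_dim, 0.02f);    // trainable down-projection
  std::vector<float> B(out_dim * rank, 0.0f);    // trainable up-projection (zero init)
  std::vector<float> x(in_dim, 1.0f);            // input activation

  // Base projection, analogous to a dense Q/K/V layer.
  std::vector<float> y = matvec(W, x, out_dim, in_dim);

  // Low-rank update: down-project to rank, up-project back, scale, and add.
  std::vector<float> ax = matvec(A, x, rank, in_dim);
  std::vector<float> bax = matvec(B, ax, out_dim, rank);
  for (int o = 0; o < out_dim; ++o) {
    y[o] += alpha * bax[o];
  }

  std::printf("y[0] = %f\n", y[0]);
  return 0;
}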