modify constructor and allocate weights #1428

Closed · wants to merge 1 commit
61 changes: 61 additions & 0 deletions src/ops/inc_multihead_self_attention.cc
@@ -23,6 +23,8 @@
#endif
#include "flexflow/utils/hash_utils.h"
#include "legion/legion_utilities.h"
#include "flexflow/ops/linear.h"
#include "flexflow/ops/lora_linear.h"

namespace FlexFlow {

@@ -301,6 +303,7 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention(
DataType _quantization_type,
bool _offload,
int _tensor_parallelism_degree,
bool use_lora,
char const *name)
// Initializer* _bias_initializer)
: Op(model,
@@ -336,6 +339,20 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention(
dims[0].size = _embed_dim;
// Currently require no parallelism along this dim
assert(dims[0].degree == 1);


// Initialize separate Q, K, V linear layers
qProj = model.dense(_input, _kdim * _num_q_heads, AM_LINEAR, _qkv_bias, _input->data_type, nullptr, nullptr, nullptr, RM_NONE, 0.0f, "qProj");
kProj = model.dense(_input, _kdim * _num_kv_heads, AM_LINEAR, _qkv_bias, _input->data_type, nullptr, nullptr, nullptr, RM_NONE, 0.0f, "kProj");
vProj = model.dense(_input, _vdim * _num_kv_heads, AM_LINEAR, _qkv_bias, _input->data_type, nullptr, nullptr, nullptr, RM_NONE, 0.0f, "vProj");

if (use_lora) {
// Initialize LoRA layers
qProj_lora = model.add_lora_layer({"qProj"});
kProj_lora = model.add_lora_layer({"kProj"});
vProj_lora = model.add_lora_layer({"vProj"});
}

if (allocate_weights) {
// Create weight tensor
int num_dims = inputs[0]->num_dims;
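
As an aside (not part of this diff), the output sizes handed to the three dense projections above follow the usual per-head arithmetic, which is also why kProj and vProj use _num_kv_heads rather than _num_q_heads: with grouped-query attention the K/V projections may serve fewer heads than Q. A minimal, runnable sketch with assumed illustrative shapes (none of these numbers come from this PR):

#include <cstdio>

int main() {
  // Illustrative numbers only; assumed shapes, not values from this PR.
  int embed_dim    = 4096;                    // plays the role of _embed_dim
  int num_q_heads  = 32;                      // plays the role of _num_q_heads
  int num_kv_heads = 8;                       // plays the role of _num_kv_heads (grouped-query attention)
  int kdim         = embed_dim / num_q_heads; // 128, per-head query/key size (_kdim)
  int vdim         = embed_dim / num_q_heads; // 128, per-head value size (_vdim)

  std::printf("qProj outputs: %d\n", kdim * num_q_heads);  // 4096
  std::printf("kProj outputs: %d\n", kdim * num_kv_heads); // 1024
  std::printf("vProj outputs: %d\n", vdim * num_kv_heads); // 1024
  return 0;
}
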
@@ -384,6 +401,50 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention(
initializer,
CHOSEN_SYNC_TYPE);
}

if (use_lora) {
// Allocate LoRA weights for Q, K, V projections
int rank = 4; // Example rank; this should be configurable

// LoRA Q projection weights
ParallelDim lora_q_dims[2];
lora_q_dims[0] = dims[0]; // output dimension
lora_q_dims[1] = dims[1]; // input dimension
lora_q_dims[1].size = rank * this->qProjSize;
weights[2] = model.create_parallel_weight<2>(
lora_q_dims,
this->data_type,
nullptr /*owner_op*/,
true /*create_grad*/,
initializer,
CHOSEN_SYNC_TYPE);

// LoRA K projection weights
ParallelDim lora_k_dims[2];
lora_k_dims[0] = dims[0];
lora_k_dims[1] = dims[1];
lora_k_dims[1].size = rank * this->kProjSize;
weights[3] = model.create_parallel_weight<2>(
lora_k_dims,
this->data_type,
nullptr /*owner_op*/,
true /*create_grad*/,
initializer,
CHOSEN_SYNC_TYPE);

// LoRA V projection weights
ParallelDim lora_v_dims[2];
lora_v_dims[0] = dims[0];
lora_v_dims[1] = dims[1];
lora_v_dims[1].size = rank * this->vProjSize;
weights[4] = model.create_parallel_weight<2>(
lora_v_dims,
this->data_type,
nullptr /*owner_op*/,
true /*create_grad*/,
initializer,
CHOSEN_SYNC_TYPE);
}
}

outputs[0] = model.create_parallel_tensor_legion_ordering(
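
For readers unfamiliar with LoRA, the reason each of the three extra weight tensors allocated above carries a factor of rank in its size is that a rank-r adapter stores two small matrices, A (r x d_in) and B (d_out x r), whose product perturbs the frozen base projection: y = W x + (alpha / r) * B (A x). The sketch below is a conceptual, plain-C++ illustration of that update under assumed row-major layouts; the function and parameter names are hypothetical and this is not FlexFlow kernel code.

#include <vector>

// Apply a base projection plus a rank-r LoRA update:
//   y = W x + (alpha / r) * B (A x)
// W: d_out x d_in, A: r x d_in, B: d_out x r, all row-major.
void lora_project(const std::vector<float> &W,
                  const std::vector<float> &A,
                  const std::vector<float> &B,
                  const std::vector<float> &x,   // length d_in
                  std::vector<float> &y,         // length d_out
                  int d_in, int d_out, int r, float alpha) {
  // Base projection: y = W x
  for (int i = 0; i < d_out; ++i) {
    float acc = 0.0f;
    for (int j = 0; j < d_in; ++j) {
      acc += W[i * d_in + j] * x[j];
    }
    y[i] = acc;
  }
  // Low-rank path, first half: tmp = A x (length r)
  std::vector<float> tmp(r, 0.0f);
  for (int k = 0; k < r; ++k) {
    for (int j = 0; j < d_in; ++j) {
      tmp[k] += A[k * d_in + j] * x[j];
    }
  }
  // Low-rank path, second half: y += (alpha / r) * B tmp
  float scale = alpha / static_cast<float>(r);
  for (int i = 0; i < d_out; ++i) {
    float acc = 0.0f;
    for (int k = 0; k < r; ++k) {
      acc += B[i * r + k] * tmp[k];
    }
    y[i] += scale * acc;
  }
}

In the PR itself, the scaling factor and rank would presumably come from the LoRA configuration rather than the hard-coded rank = 4 placeholder used in the allocation code above.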