use grad to store peft in/output (#1241)
* use grad to store peft in/output

* format

* .
xinhaoc authored Dec 6, 2023
1 parent d9b154f commit 3a34c88
Showing 13 changed files with 157 additions and 284 deletions.
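The recurring change in every peft_bwd() launcher below is that the input/output region requirements no longer attach the forward tensor's part/region but the dedicated gradient instances part_grad/region_grad (several output-gradient requirements also move from READ_ONLY to READ_WRITE); the weight requirements and the IndexLauncher arguments only change formatting, per the squashed "format" commit. A minimal before/after sketch of the pattern, condensed from the hunks below (not a standalone program; batch_inputs[0] stands in for whichever tensor a given operator wires up):

// Before: the PEFT backward launcher mapped the forward activation region.
launcher.add_region_requirement(
    RegionRequirement(batch_inputs[0]->part, 0 /*projection id*/,
                      READ_WRITE, EXCLUSIVE, batch_inputs[0]->region));
launcher.add_field(0, FID_DATA);

// After: it maps the gradient partition/region instead, so the PEFT
// input/output gradients live in the *_grad instances rather than the
// forward buffers.
launcher.add_region_requirement(
    RegionRequirement(batch_inputs[0]->part_grad, 0 /*projection id*/,
                      READ_WRITE, EXCLUSIVE, batch_inputs[0]->region_grad));
launcher.add_field(0, FID_DATA);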
48 changes: 17 additions & 31 deletions src/ops/add_bias_residual_layer_norm.cc
@@ -910,50 +910,36 @@ Legion::FutureMap AddBiasResidualLayerNorm::peft_bwd(
 set_argumentmap_for_inference(ff, argmap, batch_outputs[0]);
 size_t machine_view_hash = view->hash();
 IndexLauncher launcher(ADD_BIAS_RESIDUAL_LAYERNORM_PEFT_BWD_TASK_ID,
-parallel_is,
-TaskArgument(NULL, 0),
-argmap,
-Predicate::TRUE_PRED,
-false /*must*/,
-0 /*mapper_id*/,
+parallel_is, TaskArgument(NULL, 0), argmap,
+Predicate::TRUE_PRED, false /*must*/, 0 /*mapper_id*/,
 machine_view_hash);
 launcher.add_future(bc);
 int field_id = 0;
 // output_grad
-launcher.add_region_requirement(RegionRequirement(batch_outputs[1]->part,
-0 /*projection id*/,
-READ_ONLY,
-EXCLUSIVE,
-batch_outputs[1]->region));
+launcher.add_region_requirement(
+RegionRequirement(batch_outputs[1]->part_grad, 0 /*projection id*/,
+READ_WRITE, EXCLUSIVE, batch_outputs[1]->region_grad));
 launcher.add_field(field_id++, FID_DATA);
 // input grad
-launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part,
-0 /*projection id*/,
-READ_WRITE,
-EXCLUSIVE,
-batch_inputs[0]->region));
+launcher.add_region_requirement(
+RegionRequirement(batch_inputs[0]->part_grad, 0 /*projection id*/,
+READ_WRITE, EXCLUSIVE, batch_inputs[0]->region_grad));
 launcher.add_field(field_id++, FID_DATA);
 // residual grad
-launcher.add_region_requirement(RegionRequirement(batch_inputs[1]->part,
-0 /*projection id*/,
-READ_WRITE,
-EXCLUSIVE,
-batch_inputs[1]->region));
+launcher.add_region_requirement(
+RegionRequirement(batch_inputs[1]->part_grad, 0 /*projection id*/,
+READ_WRITE, EXCLUSIVE, batch_inputs[1]->region_grad));
 launcher.add_field(field_id++, FID_DATA);
 // attn bias grad
-launcher.add_region_requirement(RegionRequirement(batch_inputs[2]->part,
-0 /*projection id*/,
-READ_WRITE,
-EXCLUSIVE,
-batch_inputs[2]->region));
+launcher.add_region_requirement(
+RegionRequirement(batch_inputs[2]->part_grad, 0 /*projection id*/,
+READ_WRITE, EXCLUSIVE, batch_inputs[2]->region_grad));
 launcher.add_field(field_id++, FID_DATA);
 if (elementwise_affine) {
 // gamma
-launcher.add_region_requirement(RegionRequirement(weights[0]->part,
-0 /*projection id*/,
-READ_ONLY,
-EXCLUSIVE,
-weights[0]->region));
+launcher.add_region_requirement(
+RegionRequirement(weights[0]->part, 0 /*projection id*/, READ_ONLY,
+EXCLUSIVE, weights[0]->region));
 launcher.add_field(field_id++, FID_DATA);
 }
 return runtime->execute_index_space(ctx, launcher);
42 changes: 15 additions & 27 deletions src/ops/fused.cc
@@ -487,45 +487,33 @@ FutureMap FusedOp::inference(FFModel const &ff,
 // so we transfer the maximum of them
 // size_t batch_config_size =
 // std::max(sizeof(TreeVerifyBatchConfig), sizeof(BeamSearchBatchConfig));
-IndexLauncher launcher(FUSEDOP_INF_TASK_ID,
-parallel_is,
-TaskArgument(nullptr, 0),
-argmap,
-Predicate::TRUE_PRED,
-false /*must*/,
-0 /*mapper_id*/,
-machine_view_hash);
+IndexLauncher launcher(FUSEDOP_INF_TASK_ID, parallel_is,
+TaskArgument(nullptr, 0), argmap, Predicate::TRUE_PRED,
+false /*must*/, 0 /*mapper_id*/, machine_view_hash);
 launcher.add_future(bc);
 int offset = 0;
 for (int i = 0; i < numInputs; i++) {
 assert(inputs[i]->part != LogicalPartition::NO_PART);
 assert(inputs[i]->region != LogicalRegion::NO_REGION);
-launcher.add_region_requirement(RegionRequirement(batch_inputs[i]->part,
-0 /*projection id*/,
-READ_ONLY,
-EXCLUSIVE,
-batch_inputs[i]->region));
+launcher.add_region_requirement(
+RegionRequirement(batch_inputs[i]->part, 0 /*projection id*/, READ_ONLY,
+EXCLUSIVE, batch_inputs[i]->region));
 launcher.add_field(offset + i, FID_DATA);
 }
 offset += numInputs;
 for (int i = 0; i < numWeights; i++) {
 assert(weights[i]->region != LogicalRegion::NO_REGION);
-launcher.add_region_requirement(RegionRequirement(weights[i]->part,
-0 /*projection id*/,
-READ_ONLY,
-EXCLUSIVE,
-weights[i]->region));
+launcher.add_region_requirement(
+RegionRequirement(weights[i]->part, 0 /*projection id*/, READ_ONLY,
+EXCLUSIVE, weights[i]->region));
 launcher.add_field(offset + i, FID_DATA);
 }
 offset += numWeights;
 for (int i = 0; i < numOutputs; i++) {
 assert(outputs[i]->region != LogicalRegion::NO_REGION);
 launcher.add_region_requirement(
-RegionRequirement(batch_outputs[i]->part,
-0 /*projection id*/,
-WRITE_ONLY,
-EXCLUSIVE,
-batch_outputs[i]->region));
+RegionRequirement(batch_outputs[i]->part, 0 /*projection id*/,
+WRITE_ONLY, EXCLUSIVE, batch_outputs[i]->region));
 launcher.add_field(offset + i, FID_DATA);
 }
 return runtime->execute_index_space(ctx, launcher);
@@ -561,11 +549,11 @@ FutureMap FusedOp::peft_bwd(FFModel const &ff,
 for (int i = 0; i < numInputs; i++) {
 assert(inputs[i]->part != LogicalPartition::NO_PART);
 assert(inputs[i]->region != LogicalRegion::NO_REGION);
-launcher.add_region_requirement(RegionRequirement(batch_inputs[i]->part,
+launcher.add_region_requirement(RegionRequirement(batch_inputs[i]->part_grad,
 0 /*projection id*/,
 READ_WRITE,
 EXCLUSIVE,
-batch_inputs[i]->region));
+batch_inputs[i]->region_grad));
 launcher.add_field(offset + i, FID_DATA);
 }
 offset += numInputs;
@@ -582,11 +570,11 @@ FutureMap FusedOp::peft_bwd(FFModel const &ff,
 for (int i = 0; i < numOutputs; i++) {
 assert(outputs[i]->region != LogicalRegion::NO_REGION);
 launcher.add_region_requirement(
-RegionRequirement(batch_outputs[i]->part,
+RegionRequirement(batch_outputs[i]->part_grad,
 0 /*projection id*/,
 READ_WRITE,
 EXCLUSIVE,
-batch_outputs[i]->region));
+batch_outputs[i]->region_grad));
 launcher.add_field(offset + i, FID_DATA);
 }
 return runtime->execute_index_space(ctx, launcher);
44 changes: 14 additions & 30 deletions src/ops/inc_multihead_self_attention.cc
@@ -891,42 +891,26 @@ FutureMap IncMultiHeadSelfAttention::peft_bwd(
 size_t machine_view_hash = view->hash();
 int idx = 0;
 IndexLauncher launcher(INC_MULTIHEAD_SELF_ATTENTION_PEFT_BWD_TASK_ID,
-parallel_is,
-TaskArgument(nullptr, 0),
-argmap,
-Predicate::TRUE_PRED,
-false /*must*/,
-0 /*mapper_id*/,
+parallel_is, TaskArgument(nullptr, 0), argmap,
+Predicate::TRUE_PRED, false /*must*/, 0 /*mapper_id*/,
 machine_view_hash);
 launcher.add_future(bc);
-launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part,
-0 /*projection id*/,
-READ_WRITE,
-EXCLUSIVE,
-batch_inputs[0]->region));
-launcher.add_field(idx++, FID_DATA);
 launcher.add_region_requirement(
-RegionRequirement(weights[0]->part,
-0 /*projection id*/,
-READ_ONLY,
-EXCLUSIVE,
-weights[0]->region,
-ff.cpu_offload ? MAP_TO_ZC_MEMORY : 0));
+RegionRequirement(batch_inputs[0]->part_grad, 0 /*projection id*/,
+READ_WRITE, EXCLUSIVE, batch_inputs[0]->region_grad));
 launcher.add_field(idx++, FID_DATA);
-launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part,
-0 /*projection id*/,
-READ_ONLY,
-EXCLUSIVE,
-batch_outputs[0]->region));
+launcher.add_region_requirement(RegionRequirement(
+weights[0]->part, 0 /*projection id*/, READ_ONLY, EXCLUSIVE,
+weights[0]->region, ff.cpu_offload ? MAP_TO_ZC_MEMORY : 0));
 launcher.add_field(idx++, FID_DATA);
+launcher.add_region_requirement(
+RegionRequirement(batch_outputs[0]->part_grad, 0 /*projection id*/,
+READ_WRITE, EXCLUSIVE, batch_outputs[0]->region_grad));
+launcher.add_field(idx++, FID_DATA);
 if (qkv_bias || final_bias) {
-launcher.add_region_requirement(
-RegionRequirement(weights[1]->part,
-0 /*projection id*/,
-READ_ONLY,
-EXCLUSIVE,
-weights[1]->region,
-ff.cpu_offload ? MAP_TO_ZC_MEMORY : 0));
+launcher.add_region_requirement(RegionRequirement(
+weights[1]->part, 0 /*projection id*/, READ_ONLY, EXCLUSIVE,
+weights[1]->region, ff.cpu_offload ? MAP_TO_ZC_MEMORY : 0));
 launcher.add_field(idx++, FID_DATA);
 }
 return runtime->execute_index_space(ctx, launcher);
35 changes: 12 additions & 23 deletions src/ops/layer_norm.cc
@@ -661,36 +661,25 @@ Legion::FutureMap
 size_t machine_view_hash = view->hash();
 /* std::cout << "LayerNorm op machine_view: " << *(MachineView const *)mv
 << std::endl; */
-IndexLauncher launcher(LAYERNORM_PEFT_BWD_TASK_ID,
-parallel_is,
-TaskArgument(NULL, 0),
-argmap,
-Predicate::TRUE_PRED,
-false /*must*/,
-0 /*mapper_id*/,
-machine_view_hash);
+IndexLauncher launcher(LAYERNORM_PEFT_BWD_TASK_ID, parallel_is,
+TaskArgument(NULL, 0), argmap, Predicate::TRUE_PRED,
+false /*must*/, 0 /*mapper_id*/, machine_view_hash);
 launcher.add_future(bc);
 // regions[0](I): output_grad
-launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part,
-0 /*projection id*/,
-READ_ONLY,
-EXCLUSIVE,
-batch_outputs[0]->region));
+launcher.add_region_requirement(
+RegionRequirement(batch_outputs[0]->part_grad, 0 /*projection id*/,
+READ_WRITE, EXCLUSIVE, batch_outputs[0]->region_grad));
 launcher.add_field(0, FID_DATA);
 // regions[1](I/O): input_grad
-launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part,
-0 /*projection id*/,
-READ_WRITE,
-EXCLUSIVE,
-batch_inputs[0]->region));
+launcher.add_region_requirement(
+RegionRequirement(batch_inputs[0]->part_grad, 0 /*projection id*/,
+READ_WRITE, EXCLUSIVE, batch_inputs[0]->region_grad));
 launcher.add_field(2, FID_DATA);
 if (elementwise_affine) {
 // regions[2](I): gamma
-launcher.add_region_requirement(RegionRequirement(weights[0]->part,
-0 /*projection id*/,
-READ_ONLY,
-EXCLUSIVE,
-weights[0]->region));
+launcher.add_region_requirement(
+RegionRequirement(weights[0]->part, 0 /*projection id*/, READ_ONLY,
+EXCLUSIVE, weights[0]->region));
 launcher.add_field(3, FID_DATA);
 }
 return runtime->execute_index_space(ctx, launcher);
45 changes: 15 additions & 30 deletions src/ops/linear.cc
@@ -688,41 +688,26 @@ FutureMap Linear::peft_bwd(FFModel const &ff,
 size_t machine_view_hash = view->hash();
 /* std::cout << "Linear op machine_view: " << *(MachineView const *)mv
 << std::endl; */
-IndexLauncher launcher(LINEAR_PEFT_BWD_TASK_ID,
-parallel_is,
-TaskArgument(nullptr, 0),
-argmap,
-Predicate::TRUE_PRED,
-false /*must*/,
-0 /*mapper_id*/,
-machine_view_hash);
+IndexLauncher launcher(LINEAR_PEFT_BWD_TASK_ID, parallel_is,
+TaskArgument(nullptr, 0), argmap, Predicate::TRUE_PRED,
+false /*must*/, 0 /*mapper_id*/, machine_view_hash);
 launcher.add_future(bc);
-launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part,
-0 /*projection id*/,
-READ_WRITE,
-EXCLUSIVE,
-batch_inputs[0]->region));
+launcher.add_region_requirement(
+RegionRequirement(batch_inputs[0]->part_grad, 0 /*projection id*/,
+READ_WRITE, EXCLUSIVE, batch_inputs[0]->region_grad));
 launcher.add_field(0, FID_DATA);
-launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part,
-0 /*projection id*/,
-READ_WRITE,
-EXCLUSIVE,
-batch_outputs[0]->region));
-launcher.add_field(1, FID_DATA);
 launcher.add_region_requirement(
-RegionRequirement(weights[0]->part,
-0 /*projection id*/,
-READ_ONLY,
-EXCLUSIVE,
-weights[0]->region,
-ff.cpu_offload ? MAP_TO_ZC_MEMORY : 0));
+RegionRequirement(batch_outputs[0]->part_grad, 0 /*projection id*/,
+READ_WRITE, EXCLUSIVE, batch_outputs[0]->region_grad));
+launcher.add_field(1, FID_DATA);
+launcher.add_region_requirement(RegionRequirement(
+weights[0]->part, 0 /*projection id*/, READ_ONLY, EXCLUSIVE,
+weights[0]->region, ff.cpu_offload ? MAP_TO_ZC_MEMORY : 0));
 launcher.add_field(2, FID_DATA);
 if (use_bias) {
-launcher.add_region_requirement(RegionRequirement(weights[1]->part,
-0 /*projection id*/,
-READ_ONLY,
-EXCLUSIVE,
-weights[1]->region));
+launcher.add_region_requirement(
+RegionRequirement(weights[1]->part, 0 /*projection id*/, READ_ONLY,
+EXCLUSIVE, weights[1]->region));
 launcher.add_field(3, FID_DATA);
 }
 return runtime->execute_index_space(ctx, launcher);
27 changes: 9 additions & 18 deletions src/ops/lora_linear.cc
@@ -577,26 +577,17 @@ FutureMap LoraLinear::peft_bwd(FFModel const &ff,
 MachineView const *view = mv ? mv : &output_tensor->machine_view;
 set_argumentmap_for_inference(ff, argmap, output_tensor);
 size_t machine_view_hash = view->hash();
-IndexLauncher launcher(LORA_LINEAR_PEFT_BWD_TASK_ID,
-parallel_is,
-TaskArgument(nullptr, 0),
-argmap,
-Predicate::TRUE_PRED,
-false /*must*/,
-0 /*mapper_id*/,
-machine_view_hash);
+IndexLauncher launcher(LORA_LINEAR_PEFT_BWD_TASK_ID, parallel_is,
+TaskArgument(nullptr, 0), argmap, Predicate::TRUE_PRED,
+false /*must*/, 0 /*mapper_id*/, machine_view_hash);
 launcher.add_future(bc);
-launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part,
-0 /*projection id*/,
-READ_WRITE,
-EXCLUSIVE,
-batch_inputs[0]->region));
+launcher.add_region_requirement(
+RegionRequirement(batch_inputs[0]->part_grad, 0 /*projection id*/,
+READ_WRITE, EXCLUSIVE, batch_inputs[0]->region_grad));
 launcher.add_field(0, FID_DATA);
-launcher.add_region_requirement(RegionRequirement(batch_inputs[1]->part,
-0 /*projection id*/,
-READ_ONLY,
-EXCLUSIVE,
-batch_inputs[1]->region));
+launcher.add_region_requirement(
+RegionRequirement(batch_inputs[1]->part_grad, 0 /*projection id*/,
+READ_WRITE, EXCLUSIVE, batch_inputs[1]->region_grad));
 launcher.add_field(1, FID_DATA);
 return runtime->execute_index_space(ctx, launcher);
 }