Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
goliaro committed Nov 9, 2024
1 parent c71c6b3 commit d54fcf2
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 146 deletions.
4 changes: 1 addition & 3 deletions include/flexflow/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,7 @@ enum TaskIDs {
RM_PREPARE_NEXT_BATCH_BEAM_TASK_ID,
RM_PREPARE_NEXT_BATCH_VERIFY_TASK_ID,
RM_BACKGROUND_SERVING_TASK_ID,
LOAD_FLOAT_WEIGHT_TASK_ID,
LOAD_HALF_WEIGHT_TASK_ID,
LOAD_QUANT_WEIGHT_TASK_ID,
LOAD_WEIGHT_TASK_ID,
// Custom tasks
CUSTOM_GPU_TASK_ID_FIRST,
CUSTOM_GPU_TASK_ID_1,
Expand Down
31 changes: 12 additions & 19 deletions include/flexflow/utils/file_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,25 +39,12 @@ class FileDataLoader {
void load_single_weight_tensor(FFModel *ff, Layer *l, int weight_idx);

void load_quantization_weight(FFModel *ff, Layer *l, int weight_idx);
#ifdef DEADCODE
void load_weights(FFModel *ff);
#endif

static void
load_float_weight_task(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
static void
load_half_weight_task(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
static void
load_quant_weight_task(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
load_weight_task(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
void load_weights_parallel(FFModel *ff, Context ctx, Runtime *runtime);

void load_positions(FFModel *ff,
Expand All @@ -79,6 +66,12 @@ struct WeightLoadTaskArgs {
FileDataLoader *loader;
Layer *layer;
int weight_idx;
WeightLoadTaskArgs(FFModel *_ff, FileDataLoader *_loader, Layer *_l, int _idx)
: ff(_ff), loader(_loader), layer(_l), weight_idx(_idx) {}
DataType data_type;
WeightLoadTaskArgs(FFModel *_ff,
FileDataLoader *_loader,
Layer *_l,
int _idx,
DataType _data_type)
: ff(_ff), loader(_loader), layer(_l), weight_idx(_idx),
data_type(_data_type) {}
};
1 change: 0 additions & 1 deletion src/c/flexflow_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2929,7 +2929,6 @@ void flexflow_file_data_loader_load_weights(flexflow_file_data_loader_t handle_,
flexflow_model_t model_handle_) {
FileDataLoader *handle = FFCObjectWrapper::unwrap(handle_);
FFModel *model = FFCObjectWrapper::unwrap(model_handle_);
// handle->load_weights(model);
Context ctx = model->config.lg_ctx;
Runtime *runtime = model->config.lg_hlr;
handle->load_weights_parallel(model, ctx, runtime);
Expand Down
4 changes: 1 addition & 3 deletions src/mapper/mapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,7 @@ void FFMapper::select_task_options(const MapperContext ctx,
output.initial_proc = all_cpus[0];
return;
}
if ((task.task_id == LOAD_FLOAT_WEIGHT_TASK_ID) ||
(task.task_id == LOAD_HALF_WEIGHT_TASK_ID) ||
(task.task_id == LOAD_QUANT_WEIGHT_TASK_ID)) {
if (task.task_id == LOAD_WEIGHT_TASK_ID) {
output.initial_proc = all_cpus[0];
return;
}
Expand Down
115 changes: 30 additions & 85 deletions src/runtime/file_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -852,69 +852,33 @@ void FileDataLoader::load_single_weight_tensor(FFModel *ff,
delete data;
}

#ifdef DEADCODE
void FileDataLoader::load_weights(FFModel *ff) {
for (Layer *l : ff->layers) {
if (l->numWeights < 1 || l->name == NULL || strlen(l->name) < 1) {
continue;
}
for (int i = 0; i < l->numWeights; i++) {
Tensor weight = l->weights[i];
if (weight == NULL) {
continue;
}
// TODO: currently skip Lora layers
if (l->op_type == OP_LORA) {
continue;
}
switch (weight->data_type) {
case DT_HALF:
load_single_weight_tensor<half>(ff, l, i);
break;
case DT_FLOAT:
load_single_weight_tensor<float>(ff, l, i);
break;
case DT_INT4:
case DT_INT8:
// load weights in quantization
load_quantization_weight(ff, l, i);
break;
default:
assert(false && "Unsupported data type");
}
}
}
}
#endif

void FileDataLoader::load_float_weight_task(
Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime) {
WeightLoadTaskArgs const *args = (WeightLoadTaskArgs const *)task->args;
args->loader->load_single_weight_tensor<float>(
args->ff, args->layer, args->weight_idx);
}

void FileDataLoader::load_half_weight_task(
void FileDataLoader::load_weight_task(
Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime) {
WeightLoadTaskArgs const *args = (WeightLoadTaskArgs const *)task->args;
args->loader->load_single_weight_tensor<half>(
args->ff, args->layer, args->weight_idx);
}

void FileDataLoader::load_quant_weight_task(
Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime) {
WeightLoadTaskArgs const *args = (WeightLoadTaskArgs const *)task->args;
args->loader->load_quantization_weight(
args->ff, args->layer, args->weight_idx);
switch (args->data_type) {
case DT_HALF: {
args->loader->load_single_weight_tensor<half>(
args->ff, args->layer, args->weight_idx);
break;
}
case DT_FLOAT: {
args->loader->load_single_weight_tensor<float>(
args->ff, args->layer, args->weight_idx);
break;
}
case DT_INT4:
case DT_INT8: {
args->loader->load_quantization_weight(
args->ff, args->layer, args->weight_idx);
break;
}
default:
assert(false && "Unsupported data type");
}
}

void FileDataLoader::load_weights_parallel(FFModel *ff,
Expand All @@ -937,35 +901,16 @@ void FileDataLoader::load_weights_parallel(FFModel *ff,
continue;
}

// Create task arguments
WeightLoadTaskArgs args(ff, this, l, i);

switch (weight->data_type) {
case DT_HALF: {
TaskLauncher launcher(
LOAD_HALF_WEIGHT_TASK_ID,
TaskArgument(&args, sizeof(WeightLoadTaskArgs)));
futures.push_back(runtime->execute_task(ctx, launcher));
break;
}
case DT_FLOAT: {
TaskLauncher launcher(
LOAD_FLOAT_WEIGHT_TASK_ID,
TaskArgument(&args, sizeof(WeightLoadTaskArgs)));
futures.push_back(runtime->execute_task(ctx, launcher));
break;
}
case DT_INT4:
case DT_INT8: {
TaskLauncher launcher(
LOAD_QUANT_WEIGHT_TASK_ID,
TaskArgument(&args, sizeof(WeightLoadTaskArgs)));
futures.push_back(runtime->execute_task(ctx, launcher));
break;
}
default:
assert(false && "Unsupported data type");
if (weight->data_type != DT_FLOAT && weight->data_type != DT_HALF &&
weight->data_type != DT_INT4 && weight->data_type != DT_INT8) {
assert(false && "Unsupported data type");
}

// Create task arguments
WeightLoadTaskArgs args(ff, this, l, i, weight->data_type);
TaskLauncher launcher(LOAD_WEIGHT_TASK_ID,
TaskArgument(&args, sizeof(WeightLoadTaskArgs)));
futures.push_back(runtime->execute_task(ctx, launcher));
}
}

Expand Down
39 changes: 4 additions & 35 deletions src/runtime/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4801,47 +4801,16 @@ void register_flexflow_internal_tasks(Runtime *runtime,
}
}
{
TaskVariantRegistrar registrar(LOAD_FLOAT_WEIGHT_TASK_ID,
"load_float_weight_task");
TaskVariantRegistrar registrar(LOAD_WEIGHT_TASK_ID, "load_weight_task");
registrar.add_constraint(ProcessorConstraint(Processor::LOC_PROC));
if (pre_register) {
Runtime::preregister_task_variant<FileDataLoader::load_float_weight_task>(
registrar, "load_float_weight_task");
Runtime::preregister_task_variant<FileDataLoader::load_weight_task>(
registrar, "load_weight_task");
} else {
if (enable_control_replication) {
registrar.global_registration = false;
}
runtime->register_task_variant<FileDataLoader::load_float_weight_task>(
registrar);
}
}
{
TaskVariantRegistrar registrar(LOAD_HALF_WEIGHT_TASK_ID,
"load_half_weight_task");
registrar.add_constraint(ProcessorConstraint(Processor::LOC_PROC));
if (pre_register) {
Runtime::preregister_task_variant<FileDataLoader::load_half_weight_task>(
registrar, "load_half_weight_task");
} else {
if (enable_control_replication) {
registrar.global_registration = false;
}
runtime->register_task_variant<FileDataLoader::load_half_weight_task>(
registrar);
}
}
{
TaskVariantRegistrar registrar(LOAD_QUANT_WEIGHT_TASK_ID,
"load_quant_weight_task");
registrar.add_constraint(ProcessorConstraint(Processor::LOC_PROC));
if (pre_register) {
Runtime::preregister_task_variant<FileDataLoader::load_quant_weight_task>(
registrar, "load_quant_weight_task");
} else {
if (enable_control_replication) {
registrar.global_registration = false;
}
runtime->register_task_variant<FileDataLoader::load_quant_weight_task>(
runtime->register_task_variant<FileDataLoader::load_weight_task>(
registrar);
}
}
Expand Down

0 comments on commit d54fcf2

Please sign in to comment.