diff --git a/include/flexflow/request_manager.h b/include/flexflow/request_manager.h
index 3e173cbc24..9baac66869 100644
--- a/include/flexflow/request_manager.h
+++ b/include/flexflow/request_manager.h
@@ -37,9 +37,12 @@ class InferenceManager {
   static InferenceManager *get_inference_manager();
   void compile_model_and_allocate_buffer(FFModel *model);
   void init_operators_inference(FFModel *model);
-  InferenceResultFuture inference(FFModel *model, int index, BatchConfig const &bc);
-  InferenceResultFuture inference(FFModel *model, int index, BatchConfigFuture const &bc);
-  FinetuningBwdFuture peft_bwd(FFModel *model, int index, BatchConfigFuture const &bc);
+  InferenceResultFuture
+      inference(FFModel *model, int index, BatchConfig const &bc);
+  InferenceResultFuture
+      inference(FFModel *model, int index, BatchConfigFuture const &bc);
+  FinetuningBwdFuture
+      peft_bwd(FFModel *model, int index, BatchConfigFuture const &bc);
   void load_input_tokens_from_batch_config(FFModel *model,
                                            BatchConfigFuture const &bc,
                                            ParallelTensor const input,
@@ -67,7 +70,7 @@ struct Request {
   };
   enum FinetuningStatus {
     FORWARD_PHASE = 201,
-    BACKWARD_PHASE = 202, 
+    BACKWARD_PHASE = 202,
   };
   struct PeftFinetuningInfo {
     FinetuningStatus status = FORWARD_PHASE;
@@ -80,8 +83,8 @@ struct Request {
     std::vector<float> finetuning_losses;
     // bwd state
     int last_processed_layer = INT_MAX;
-    // how many gradient accumulation steps to do before updating the weights. if
-    // left as -1, it will be set to the number of entries in the dataset
+    // how many gradient accumulation steps to do before updating the weights.
+    // if left as -1, it will be set to the number of entries in the dataset
     int gradient_accumulation_steps = -1;
     // std::vector<int> finetuning_tokens_per_batch;
   };
@@ -96,12 +99,12 @@ struct Request {
   // inference fields
   std::string prompt;
   std::vector<BatchConfig::TokenId> tokens;
-  
+
   // peft fields
   PEFTModelID peft_model_id = PEFTModelID::NO_ID;
   PeftFinetuningInfo peft_finetuning_info;
   std::vector<std::vector<BatchConfig::TokenId>> dataset;
-  
+
   // speculation fields
   int initial_len = 0;
   int ssm_cache_size = 0;
@@ -109,7 +112,7 @@ struct Request {
   std::vector<BeamTree> beam_trees;
 
   Request() = default;
-  Request(const Request& other);
+  Request(Request const &other);
   void load_token_ids();
 
   friend std::ostream &operator<<(std::ostream &os, Request const &req);
@@ -214,25 +217,40 @@ class RequestManager {
   void add_peft_config_to_request_info(BatchConfig &bc,
                                        int req_idx,
                                        LoraLinearConfig const &peft_config);
-  
+
   // helpers for prepare_next_batch
-  void process_inf_req_progress(BatchConfig const &old_fwd_bc, InferenceResult const &result);
+  void process_inf_req_progress(BatchConfig const &old_fwd_bc,
+                                InferenceResult const &result);
   void handle_completed_inf_req(BatchConfig const &old_bc, int i);
-  void add_continuing_inf_req_to_new_batch(BatchConfig &new_bc, BatchConfig const &old_bc, int &num_active_req, int &num_concurrent_inf_adapters, int i);
-  void add_new_inf_req(BatchConfig &new_bc, int &num_active_req, int &num_concurrent_inf_adapters, int i);
+  void add_continuing_inf_req_to_new_batch(BatchConfig &new_bc,
+                                           BatchConfig const &old_bc,
+                                           int &num_active_req,
+                                           int &num_concurrent_inf_adapters,
+                                           int i);
+  void add_new_inf_req(BatchConfig &new_bc,
+                       int &num_active_req,
+                       int &num_concurrent_inf_adapters,
+                       int i);
   void handle_completed_finetuning_req(BatchConfig const &old_finetuning_bc);
   void add_finetuning_req_fwd_batch(BatchConfig &new_bc);
   void add_finetuning_req_bwd_batch(BatchConfig &new_bc);
   bool finetuning_fwd_work_available();
   bool finetuning_bwd_work_available();
-  void
process_finetuning_req_fwd_progress(BatchConfig const &old_fwd_bc, InferenceResult const &result); + void process_finetuning_req_fwd_progress(BatchConfig const &old_fwd_bc, + InferenceResult const &result); void process_finetuning_req_bwd_progress(BatchConfig const &old_bwd_bc); - void process_work_from_old_batches(BatchConfig const &old_fwd_bc, BatchConfig const &old_bwd_bc, InferenceResult const &result); + void process_work_from_old_batches(BatchConfig const &old_fwd_bc, + BatchConfig const &old_bwd_bc, + InferenceResult const &result); BatchConfig prepare_next_bwd_batch(); - BatchConfig prepare_next_fwd_batch(BatchConfig const &old_fwd_bc, InferenceResult const &result); - BatchConfigPairFuture prepare_next_batch(std::tuple &batch_pipeline_entry, - Context ctx, - Runtime *runtime); + BatchConfig prepare_next_fwd_batch(BatchConfig const &old_fwd_bc, + InferenceResult const &result); + BatchConfigPairFuture + prepare_next_batch(std::tuple &batch_pipeline_entry, + Context ctx, + Runtime *runtime); // BatchConfig prepare_next_batch(BatchConfig const &bc, // InferenceResult const &result); // BatchConfigFuture prepare_next_batch(BatchConfigFuture const &bc, @@ -311,7 +329,7 @@ class RequestManager { Legion::Context ctx, Legion::Runtime *runtime); static std::pair prepare_next_batch_task( - Legion::Task const *task, + Legion::Task const *task, std::vector const ®ions, Legion::Context ctx, Legion::Runtime *runtime); diff --git a/inference/incr_decoding/incr_decoding.cc b/inference/incr_decoding/incr_decoding.cc index d96e278bf5..4dfc2df474 100644 --- a/inference/incr_decoding/incr_decoding.cc +++ b/inference/incr_decoding/incr_decoding.cc @@ -275,10 +275,11 @@ void FlexFlow::top_level_task(Task const *task, using json = nlohmann::json; std::ifstream file_handle(file_paths.prompt_file_path); assert(file_handle.good() && "Prompt file does not exist."); - nlohmann::ordered_json prompt_json = nlohmann::ordered_json::parse(file_handle, - /*parser_callback_t */ nullptr, - /*allow_exceptions */ true, - /*ignore_comments */ true); + nlohmann::ordered_json prompt_json = + nlohmann::ordered_json::parse(file_handle, + /*parser_callback_t */ nullptr, + /*allow_exceptions */ true, + /*ignore_comments */ true); file_handle.close(); auto &metadata = prompt_json["metadata"]; int num_warmup_requests = metadata["num_warmup_requests"]; @@ -289,7 +290,7 @@ void FlexFlow::top_level_task(Task const *task, int response_length = entry["response_length"]; std::string text = entry["prompt"]; bool is_warmup_request = total_requests < num_warmup_requests; - + Request inference_req; inference_req.prompt = text; inference_req.add_special_tokens = false; @@ -302,10 +303,11 @@ void FlexFlow::top_level_task(Task const *task, requests.push_back(inference_req); num_regular_requests++; } - + total_requests++; } - std::vector warmup_result = model.generate(warmup_requests); + std::vector warmup_result = + model.generate(warmup_requests); std::vector result = model.generate(requests); assert(warmup_result.size() == warmup_requests.size()); @@ -313,7 +315,7 @@ void FlexFlow::top_level_task(Task const *task, assert(result.size() + warmup_result.size() == total_requests); int i = 0; for (auto &entry : prompt_json["entries"]) { - if (i result = model.generate(requests); } @@ -377,7 +377,7 @@ void FlexFlow::top_level_task(Task const *task, fine_tuning_req.max_length = lengths[i]; fine_tuning_req.peft_model_id = (peft_model_id != nullptr) ? 
*peft_model_id : PEFTModelID::NO_ID; - fine_tuning_req.max_training_steps = 1; + fine_tuning_req.peft_finetuning_info.max_training_steps = 1; requests.push_back(fine_tuning_req); } std::vector result = model.generate(requests); diff --git a/inference/peft/req_rate_benchmark.cc b/inference/peft/req_rate_benchmark.cc index cecfc4c67d..d249728841 100644 --- a/inference/peft/req_rate_benchmark.cc +++ b/inference/peft/req_rate_benchmark.cc @@ -35,8 +35,8 @@ Legion::Logger log_app("llama"); class ConcurrentQueue { public: - std::queue inf_queue; - std::queue peft_queue; + std::queue inf_queue; + std::queue peft_queue; std::mutex request_queue_mutex; bool producer_finished = false; }; @@ -58,7 +58,7 @@ void consume() { bool queue_is_empty = false; // int i=0; while (!producer_is_finished || !queue_is_empty) { - RequestManager::RequestGuid guid = RequestManager::INVALID_GUID; + BatchConfig::RequestGuid guid = BatchConfig::INVALID_GUID; { const std::lock_guard lock(guids->request_queue_mutex); queue_is_empty = guids->inf_queue.empty(); @@ -68,7 +68,7 @@ void consume() { guids->inf_queue.pop(); } } - if (guid != RequestManager::INVALID_GUID) { + if (guid != BatchConfig::INVALID_GUID) { GenerationResult result = rm->get_generation_result(guid); } else { std::this_thread::sleep_for(std::chrono::milliseconds(nb_millisecs)); @@ -396,7 +396,7 @@ void FlexFlow::top_level_task(Task const *task, fine_tuning_req.warmup = true; fine_tuning_req.peft_model_id = (peft_model_id != nullptr) ? *peft_model_id : PEFTModelID::NO_ID; - fine_tuning_req.max_training_steps = 1; + fine_tuning_req.peft_finetuning_info.max_training_steps = 1; requests.push_back(fine_tuning_req); std::vector result = model.generate(requests); } @@ -459,10 +459,10 @@ void FlexFlow::top_level_task(Task const *task, fine_tuning_req.max_length = 1024; fine_tuning_req.peft_model_id = (peft_model_id != nullptr) ? 
*peft_model_id : PEFTModelID::NO_ID; - fine_tuning_req.max_training_steps = 1000000000; - RequestManager::RequestGuid ft_guid = + fine_tuning_req.peft_finetuning_info.max_training_steps = 1000000000; + BatchConfig::RequestGuid ft_guid = rm->register_new_peft_request(fine_tuning_req); - if (ft_guid != RequestManager::INVALID_GUID) { + if (ft_guid != BatchConfig::INVALID_GUID) { const std::lock_guard lock(guids->request_queue_mutex); guids->peft_queue.push(ft_guid); } @@ -495,9 +495,9 @@ void FlexFlow::top_level_task(Task const *task, { const std::lock_guard lock(guids->request_queue_mutex); for (int i = 0; i < requests.size(); i++) { - RequestManager::RequestGuid guid = + BatchConfig::RequestGuid guid = rm->register_new_request(requests.at(i)); - if (guid != RequestManager::INVALID_GUID) { + if (guid != BatchConfig::INVALID_GUID) { guids->inf_queue.push(guid); } } diff --git a/src/c/flexflow_c.cc b/src/c/flexflow_c.cc index 479a232a53..18f7692ca2 100644 --- a/src/c/flexflow_c.cc +++ b/src/c/flexflow_c.cc @@ -1747,7 +1747,8 @@ void flexflow_model_generate(flexflow_model_t handle_, } std::string const dataset_fp(dataset_filepaths[i]); fine_tuning_req.peft_finetuning_info.dataset_filepath = dataset_fp; - fine_tuning_req.peft_finetuning_info.max_training_steps = training_steps[i]; + fine_tuning_req.peft_finetuning_info.max_training_steps = + training_steps[i]; requests.push_back(fine_tuning_req); DEBUG_PRINT("[Model] finetune[%d] %p %s %i %i %i %i", i, diff --git a/src/ops/inc_multihead_self_attention.cpp b/src/ops/inc_multihead_self_attention.cpp index 90d0ae4398..072ce364d5 100644 --- a/src/ops/inc_multihead_self_attention.cpp +++ b/src/ops/inc_multihead_self_attention.cpp @@ -698,19 +698,19 @@ __global__ void scaling_query_kernel(DT *input_ptr, template __global__ void apply_rotary_embedding_fwd(DT *input_ptr, - hipFloatComplex *complex_input, - BatchConfig::PerTokenInfo const *tokenInfos, - float rope_theta, - bool llama3_rope, - float factor, - float low_freq_factor, - float high_freq_factor, - int original_max_position_embeddings, - int qProjSize, - int kProjSize, - int num_tokens, - size_t q_array_size, - int hidden_size) { + hipFloatComplex *complex_input, + BatchConfig::PerTokenInfo const *tokenInfos, + float rope_theta, + bool llama3_rope, + float factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, + int qProjSize, + int kProjSize, + int num_tokens, + size_t q_array_size, + int hidden_size) { CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { // create complex number bool q_tensor = i < (q_array_size / 2); diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index 4b86bd6a67..cbb5bf638f 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -654,19 +654,19 @@ __global__ void scaling_query_kernel(DT *input_ptr, template __global__ void apply_rotary_embedding_fwd(DT *input_ptr, - cuFloatComplex *complex_input, - BatchConfig::PerTokenInfo const *tokenInfos, - float rope_theta, - bool llama3_rope, - float factor, - float low_freq_factor, - float high_freq_factor, - int original_max_position_embeddings, - int qProjSize, - int kProjSize, - int num_tokens, - size_t q_array_size, - int hidden_size) { + cuFloatComplex *complex_input, + BatchConfig::PerTokenInfo const *tokenInfos, + float rope_theta, + bool llama3_rope, + float factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, + int qProjSize, + int kProjSize, + int 
num_tokens, + size_t q_array_size, + int hidden_size) { CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { // create complex number bool q_tensor = i < (q_array_size / 2); @@ -826,9 +826,9 @@ void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m, /*q&k*/ parallelism = num_tokens * m->hidden_size; apply_rotary_embedding_fwd<<>>( + min(CUDA_NUM_THREADS, parallelism), + 0, + stream>>>( output_ptr, m->complex_input, m->token_infos, diff --git a/src/runtime/inference_manager.cc b/src/runtime/inference_manager.cc index ea14fa1c51..301645e46d 100644 --- a/src/runtime/inference_manager.cc +++ b/src/runtime/inference_manager.cc @@ -380,7 +380,9 @@ void InferenceManager::init_operators_inference(FFModel *model) { } } -InferenceResultFuture InferenceManager::inference(FFModel *model, int index, BatchConfig const &bc) { +InferenceResultFuture InferenceManager::inference(FFModel *model, + int index, + BatchConfig const &bc) { if (bc.get_mode() == INC_DECODING_MODE) { BatchConfigFuture bcf = Future::from_value(bc); return inference(model, index, bcf); @@ -403,7 +405,9 @@ InferenceResultFuture InferenceManager::inference(FFModel *model, int index, Bat } } -InferenceResultFuture InferenceManager::inference(FFModel *model, int index, BatchConfigFuture const &bc) { +InferenceResultFuture InferenceManager::inference(FFModel *model, + int index, + BatchConfigFuture const &bc) { // log_inf_mgr.print("mode(%d) num_active_infr_tokens(%d) // num_active_requests(%d)", // bc.get_mode(), @@ -465,7 +469,9 @@ InferenceResultFuture InferenceManager::inference(FFModel *model, int index, Bat return irf; }; -FinetuningBwdFuture InferenceManager::peft_bwd(FFModel *model, int index, BatchConfigFuture const &bc) { +FinetuningBwdFuture InferenceManager::peft_bwd(FFModel *model, + int index, + BatchConfigFuture const &bc) { int batch_index = index % model->config.data_parallelism_degree; FutureMap fm; bool found_input_operator = false; diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 3ac895ad60..292ed6cab1 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -4690,14 +4690,14 @@ void register_flexflow_internal_tasks(Runtime *runtime, registrar.set_leaf(); if (pre_register) { Runtime::preregister_task_variant< - BatchConfig, + std::pair, RequestManager::prepare_next_batch_task>( registrar, "RequestManager Prepare Next Batch Task"); } else { if (enable_control_replication) { registrar.global_registration = false; } - runtime->register_task_variantregister_task_variant, RequestManager::prepare_next_batch_task>( registrar); } diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index 019f8577f4..143a23aba3 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -54,23 +54,17 @@ RequestGuid RequestManager::assign_next_guid() { return next_available_guid++; } -Request::Request(const Request& other) - : req_type(other.req_type), - max_length(other.max_length), +Request::Request(Request const &other) + : req_type(other.req_type), max_length(other.max_length), max_new_tokens(other.max_new_tokens), benchmarking_tokens(other.benchmarking_tokens), - add_special_tokens(other.add_special_tokens), - warmup(other.warmup), - status(Request::PENDING), - prompt(other.prompt), - tokens(other.tokens), + add_special_tokens(other.add_special_tokens), warmup(other.warmup), + status(Request::PENDING), prompt(other.prompt), tokens(other.tokens), peft_model_id(other.peft_model_id), peft_finetuning_info(other.peft_finetuning_info), - initial_len(other.initial_len), - 
ssm_cache_size(other.ssm_cache_size), - llm_cache_size(other.llm_cache_size), - beam_trees(other.beam_trees) { - + initial_len(other.initial_len), ssm_cache_size(other.ssm_cache_size), + llm_cache_size(other.llm_cache_size), beam_trees(other.beam_trees) { + RequestManager *rm = RequestManager::get_request_manager(); guid = rm->assign_next_guid(); int max_seq_len = rm->get_max_sequence_length(); @@ -82,13 +76,16 @@ Request::Request(const Request& other) // both set if (max_length != -1 && max_new_tokens != -1) { max_length = -1; - std::cout << "Both `max_new_tokens` (=" << max_new_tokens - << ") and `max_length`(=" << max_length - << ") seem to have been set. `max_new_tokens` will take precedence."; + std::cout + << "Both `max_new_tokens` (=" << max_new_tokens + << ") and `max_length`(=" << max_length + << ") seem to have been set. `max_new_tokens` will take precedence."; } } else { if (max_new_tokens != -1) { - std::cerr << "Error: max_new_tokens is not allowed for PEFT finetuning requests" << std::endl; + std::cerr + << "Error: max_new_tokens is not allowed for PEFT finetuning requests" + << std::endl; assert(false); } if (max_length == -1) { @@ -100,11 +97,13 @@ Request::Request(const Request& other) bool RequestManager::load_request_token_ids(Request &request) { if (request.req_type == RequestType::REQ_INFERENCE) { // load prompt token ids - if (bos_token_id >= 0 && model_type != ModelType::FALCON && request.add_special_tokens) { + if (bos_token_id >= 0 && model_type != ModelType::FALCON && + request.add_special_tokens) { request.tokens.push_back(bos_token_id); } if (request.benchmarking_tokens >= 0) { - assert(request.benchmarking_tokens < get_max_sequence_length() && "Benchmarking tokens exceed max sequence length"); + assert(request.benchmarking_tokens < get_max_sequence_length() && + "Benchmarking tokens exceed max sequence length"); request.tokens.insert(request.tokens.end(), request.benchmarking_tokens, 15); // insert random number @@ -139,7 +138,7 @@ bool RequestManager::load_request_token_ids(Request &request) { if (get_num_ssms() == 0) { std::cout << "No small speculative model registered, using incremental " - "decoding." + "decoding." 
<< std::endl; } else { std::cout << "Num of SSMs: " << get_num_ssms() << std::endl; @@ -152,8 +151,8 @@ bool RequestManager::load_request_token_ids(Request &request) { // load dataset token ids if (request.benchmarking_tokens >= 0) { assert(request.benchmarking_tokens <= get_max_sequence_length() && - "Benchmarking tokens exceed max sequence length"); - + "Benchmarking tokens exceed max sequence length"); + std::vector input_tokens; bool bos_added = (bos_token_id >= 0 && request.add_special_tokens && model_type != ModelType::FALCON); @@ -193,23 +192,25 @@ bool RequestManager::load_request_token_ids(Request &request) { } } if (request.peft_finetuning_info.gradient_accumulation_steps == -1) { - request.peft_finetuning_info.gradient_accumulation_steps = request.dataset.size(); + request.peft_finetuning_info.gradient_accumulation_steps = + request.dataset.size(); } assert(request.peft_finetuning_info.gradient_accumulation_steps > 0 && - "Invalid gradient accumulation steps"); - assert(request.peft_finetuning_info.gradient_accumulation_steps <= request.peft_finetuning_info.max_training_steps && - "Gradient accumulation steps should be less than or equal to max " - "training steps"); - } - assert(get_num_ssms() == 0 && "Small speculative models not supported for " + "Invalid gradient accumulation steps"); + assert(request.peft_finetuning_info.gradient_accumulation_steps <= + request.peft_finetuning_info.max_training_steps && + "Gradient accumulation steps should be less than or equal to max " + "training steps"); + } + assert(get_num_ssms() == 0 && "Small speculative models not supported for " "PEFT finetuning requests"); return true; } template -std::ostream& operator<<(std::ostream& os, const std::vector& array) { +std::ostream &operator<<(std::ostream &os, std::vector const &array) { os << "["; - for (const auto& element : array) { + for (auto const &element : array) { os << element << " "; } os << "]"; @@ -233,13 +234,20 @@ std::ostream &operator<<(std::ostream &os, Request const &req) { } else { os << " peft_finetuning_info: {\n"; os << " status: " << req.peft_finetuning_info.status << "\n"; - os << " dataset_entry_processed_tokens: " << req.peft_finetuning_info.dataset_entry_processed_tokens << "\n"; - os << " max_training_steps: " << req.peft_finetuning_info.max_training_steps << "\n"; - os << " gradient_accumulation_steps: " << req.peft_finetuning_info.gradient_accumulation_steps << "\n"; - os << " completed_training_steps: " << req.peft_finetuning_info.completed_training_steps << "\n"; - // os << " finetuning_tokens_per_batch: " << req.peft_finetuning_info.finetuning_tokens_per_batch << "\n"; - os << " finetuning_losses: " << req.peft_finetuning_info.finetuning_losses << "\n"; - os << " dataset_filepath: " << req.peft_finetuning_info.dataset_filepath << "\n"; + os << " dataset_entry_processed_tokens: " + << req.peft_finetuning_info.dataset_entry_processed_tokens << "\n"; + os << " max_training_steps: " + << req.peft_finetuning_info.max_training_steps << "\n"; + os << " gradient_accumulation_steps: " + << req.peft_finetuning_info.gradient_accumulation_steps << "\n"; + os << " completed_training_steps: " + << req.peft_finetuning_info.completed_training_steps << "\n"; + // os << " finetuning_tokens_per_batch: " << + // req.peft_finetuning_info.finetuning_tokens_per_batch << "\n"; + os << " finetuning_losses: " + << req.peft_finetuning_info.finetuning_losses << "\n"; + os << " dataset_filepath: " << req.peft_finetuning_info.dataset_filepath + << "\n"; os << " dataset: " << 
req.dataset.size() << " entries\n"; os << " }\n"; } @@ -459,7 +467,8 @@ int RequestManager::get_num_transformer_layers() { return num_transformer_layers; } -void RequestManager::set_num_layers_per_finetuning_step(int num_layers_per_finetuning_step_) { +void RequestManager::set_num_layers_per_finetuning_step( + int num_layers_per_finetuning_step_) { num_layers_per_finetuning_step = num_layers_per_finetuning_step_; } @@ -500,8 +509,7 @@ PEFTModelID * return peft_model_id; } -RequestGuid - RequestManager::register_new_request(Request const &request_) { +RequestGuid RequestManager::register_new_request(Request const &request_) { const std::lock_guard lock(request_queue_mutex); // Add a new request Request request(request_); @@ -545,7 +553,7 @@ RequestGuid RequestManager::register_new_peft_request(Request const &request_) { const std::lock_guard lock(request_queue_mutex); // Add a new request Request request(request_); - if(!load_request_token_ids(request)) { + if (!load_request_token_ids(request)) { return BatchConfig::INVALID_GUID; } @@ -587,7 +595,8 @@ bool RequestManager::is_request_completed(RequestGuid const &guid) { return request.status == Request::COMPLETED; } -GenerationResult RequestManager::get_generation_result(RequestGuid const &guid) { +GenerationResult + RequestManager::get_generation_result(RequestGuid const &guid) { // First get the future of the request std::future future; { @@ -611,16 +620,18 @@ size_t RequestManager::get_num_processed_requests() { } BatchConfigPairFuture RequestManager::prepare_next_batch( - std::tuple &batch_pipeline_entry, - Context ctx, - Runtime *runtime) { + std::tuple &batch_pipeline_entry, + Context ctx, + Runtime *runtime) { RequestManager *rm = this; TaskLauncher launcher(RM_PREPARE_NEXT_BATCH_TASK_ID, TaskArgument(&rm, sizeof(RequestManager *))); launcher.add_future(std::get<0>(batch_pipeline_entry)); launcher.add_future(std::get<1>(batch_pipeline_entry)); launcher.add_future(std::get<2>(batch_pipeline_entry)); - launcher.add_future(std::get<3>(batch_pipeline_entry)); + // launcher.add_future(std::get<3>(batch_pipeline_entry)); return runtime->execute_task(ctx, launcher); } @@ -632,7 +643,8 @@ std::pair RequestManager::prepare_next_batch_task( RequestManager *rm = *((RequestManager **)task->args); BatchConfig const *old_fwd_bc = BatchConfig::from_future(task->futures[0]); BatchConfig const *old_bwd_bc = BatchConfig::from_future(task->futures[1]); - InferenceResult const &result = Future(task->futures[2]).get_result(); + InferenceResult const &result = + Future(task->futures[2]).get_result(); Future(task->futures[3]).get_void_result(); rm->process_work_from_old_batches(*old_fwd_bc, *old_bwd_bc, result); BatchConfig new_fwd_bc = rm->prepare_next_fwd_batch(*old_fwd_bc, result); @@ -649,8 +661,7 @@ bool RequestManager::is_eos_token(int token_id) { return false; } -bool RequestManager::inf_req_completed(BatchConfig const &old_bc, - int i) { +bool RequestManager::inf_req_completed(BatchConfig const &old_bc, int i) { Request &request = all_requests[old_bc.requestsInfo[i].request_guid]; bool request_completed = false; // printf("model_type = %d\n", this->model_type); @@ -717,21 +728,27 @@ void RequestManager::add_peft_config_to_request_info( // << bc.requestsInfo[req_idx].peft_model_config_str << std::endl; } -void RequestManager::process_inf_req_progress(BatchConfig const &old_fwd_bc, InferenceResult const &result) { +void RequestManager::process_inf_req_progress(BatchConfig const &old_fwd_bc, + InferenceResult const &result) { for (int i = 0; i < 
old_fwd_bc.num_active_tokens(); i++) { - size_t guid = old_fwd_bc.requestsInfo[old_fwd_bc.tokensInfo[i].request_index].request_guid; + size_t guid = + old_fwd_bc.requestsInfo[old_fwd_bc.tokensInfo[i].request_index] + .request_guid; Request &request = all_requests[guid]; if (request.req_type == RequestType::REQ_FINETUNING) { // finetuning requests don't produce any new decoding token continue; } - if (old_fwd_bc.tokensInfo[i].abs_depth_in_request + 1 < request.tokens.size()) { - assert(old_fwd_bc.requestsInfo[old_fwd_bc.tokensInfo[i].request_index].prompt_phase == true); + if (old_fwd_bc.tokensInfo[i].abs_depth_in_request + 1 < + request.tokens.size()) { + assert(old_fwd_bc.requestsInfo[old_fwd_bc.tokensInfo[i].request_index] + .prompt_phase == true); // This is a prompt token continue; } else { // This is a decoding token - assert(old_fwd_bc.tokensInfo[i].abs_depth_in_request + 1 == request.tokens.size()); + assert(old_fwd_bc.tokensInfo[i].abs_depth_in_request + 1 == + request.tokens.size()); request.tokens.push_back(result.token_ids[i]); if (!profiling_requests[guid].first_token_time_set) { profiling_requests[guid].first_token_time = @@ -741,18 +758,22 @@ void RequestManager::process_inf_req_progress(BatchConfig const &old_fwd_bc, Inf // log_req_mgr.print("Output token is: %d", result.token_ids[i]); } } - int inference_batch_size = BatchConfig::max_requests_per_batch() - (int)enable_peft_finetuning; + int inference_batch_size = + BatchConfig::max_requests_per_batch() - (int)enable_peft_finetuning; for (int req_idx = 0; req_idx < inference_batch_size; req_idx++) { - if (!old_fwd_bc.request_completed[req_idx] && inf_req_completed(old_fwd_bc, req_idx)) { + if (!old_fwd_bc.request_completed[req_idx] && + inf_req_completed(old_fwd_bc, req_idx)) { handle_completed_inf_req(old_fwd_bc, req_idx); } } } -void RequestManager::handle_completed_inf_req(BatchConfig const &old_bc, int i) { +void RequestManager::handle_completed_inf_req(BatchConfig const &old_bc, + int i) { Request &request = all_requests[old_bc.requestsInfo[i].request_guid]; assert(old_bc.requestsInfo[i].num_tokens_in_batch > 0); - assert(request.req_type == RequestType::REQ_INFERENCE && "Found misplaced finetuning request"); + assert(request.req_type == RequestType::REQ_INFERENCE && + "Found misplaced finetuning request"); if (is_eos_token(request.tokens.back())) { // remove the EOS token @@ -762,8 +783,7 @@ void RequestManager::handle_completed_inf_req(BatchConfig const &old_bc, int i) // Unlike Huggingface, the sentencepiece C++ library automatically // removes the BOS token if (model_type == ModelType::LLAMA && old_llama_tokenizer && - request.add_special_tokens && - request.tokens.at(0) == bos_token_id) { + request.add_special_tokens && request.tokens.at(0) == bos_token_id) { output = " " + output; } { @@ -782,8 +802,7 @@ void RequestManager::handle_completed_inf_req(BatchConfig const &old_bc, int i) num_processed_requests++; ProfileInfo profile_info = profiling_requests[request.guid]; profile_info.finish_time = Realm::Clock::current_time_in_microseconds(); - total_request_run_time += - profile_info.finish_time - profile_info.start_time; + total_request_run_time += profile_info.finish_time - profile_info.start_time; profiling_requests[request.guid] = profile_info; log_req_mgr.print("[%s] guid(%zu) llm_decoding_steps(%d) start(%.1lf) " "finish(%.1lf) latency(%.1lf) ttft(%.1lf)", @@ -799,15 +818,15 @@ void RequestManager::handle_completed_inf_req(BatchConfig const &old_bc, int i) if (!output_filepath.empty()) { std::ofstream 
outputFile(output_filepath, std::ios::app); if (outputFile.is_open()) { - outputFile << "[" << (request.warmup ? "Warmup" : "Profile") - << "] guid(" << request.guid << ") llm_decoding_steps(" - << profile_info.llm_decoding_steps << ") latency(" - << std::fixed << std::setprecision(3) - << (profile_info.finish_time - profile_info.start_time) - << ") ttft(" << std::fixed << std::setprecision(3) - << (profile_info.first_token_time - - profile_info.registration_time) - << ")\n"; + outputFile << "[" << (request.warmup ? "Warmup" : "Profile") << "] guid(" + << request.guid << ") llm_decoding_steps(" + << profile_info.llm_decoding_steps << ") latency(" + << std::fixed << std::setprecision(3) + << (profile_info.finish_time - profile_info.start_time) + << ") ttft(" << std::fixed << std::setprecision(3) + << (profile_info.first_token_time - + profile_info.registration_time) + << ")\n"; if (request.benchmarking_tokens <= 0) { outputFile << "token IDs: "; for (int i = 0; i < request.tokens.size(); i++) { @@ -828,33 +847,46 @@ void RequestManager::handle_completed_inf_req(BatchConfig const &old_bc, int i) } } -void RequestManager::add_continuing_inf_req_to_new_batch(BatchConfig &new_bc, BatchConfig const &old_bc, int &num_active_req, int &num_concurrent_inf_adapters, int i) { - assert(new_bc.num_tokens < get_max_tokens_per_batch() && "Trying to add a continuing request when the batch is full"); +void RequestManager::add_continuing_inf_req_to_new_batch( + BatchConfig &new_bc, + BatchConfig const &old_bc, + int &num_active_req, + int &num_concurrent_inf_adapters, + int i) { + assert(new_bc.num_tokens < get_max_tokens_per_batch() && + "Trying to add a continuing request when the batch is full"); Request &request = all_requests[old_bc.requestsInfo[i].request_guid]; assert(old_bc.requestsInfo[i].num_tokens_in_batch > 0); - assert(request.req_type == RequestType::REQ_INFERENCE && "Found misplaced finetuning request"); - int processed_tokens = old_bc.requestsInfo[i].first_token_depth_in_request + old_bc.requestsInfo[i].num_tokens_in_batch; - assert(processed_tokens < request.tokens.size() && "Continuing request has already finished"); - int inference_batch_size = BatchConfig::max_requests_per_batch() - (int)enable_peft_finetuning; + assert(request.req_type == RequestType::REQ_INFERENCE && + "Found misplaced finetuning request"); + int processed_tokens = old_bc.requestsInfo[i].first_token_depth_in_request + + old_bc.requestsInfo[i].num_tokens_in_batch; + assert(processed_tokens < request.tokens.size() && + "Continuing request has already finished"); + int inference_batch_size = + BatchConfig::max_requests_per_batch() - (int)enable_peft_finetuning; if (old_bc.requestsInfo[i].peft_model_id != PEFTModelID::NO_ID) { num_concurrent_inf_adapters += 1; } - + // add request to new bc, at the same index new_bc.request_completed[i] = false; new_bc.requestsInfo[i].first_token_depth_in_request = processed_tokens; new_bc.requestsInfo[i].first_token_offset_in_batch = new_bc.num_tokens; new_bc.requestsInfo[i].request_guid = old_bc.requestsInfo[i].request_guid; new_bc.requestsInfo[i].peft_model_id = old_bc.requestsInfo[i].peft_model_id; - std::strcpy(new_bc.requestsInfo[i].peft_model_config_str, old_bc.requestsInfo[i].peft_model_config_str); - new_bc.requestsInfo[i].finetuning_request = old_bc.requestsInfo[i].finetuning_request; + std::strcpy(new_bc.requestsInfo[i].peft_model_config_str, + old_bc.requestsInfo[i].peft_model_config_str); + new_bc.requestsInfo[i].finetuning_request = + old_bc.requestsInfo[i].finetuning_request; 
new_bc.requestsInfo[i].max_length = old_bc.requestsInfo[i].max_length; - + num_active_req++; new_bc.requestsInfo[num_active_req].batch_config_request_id = i; - - if (new_bc.requestsInfo[i].first_token_depth_in_request + 1 == request.tokens.size()) { + + if (new_bc.requestsInfo[i].first_token_depth_in_request + 1 == + request.tokens.size()) { // Incremental phase new_bc.requestsInfo[i].num_tokens_in_batch = 1; new_bc.num_generation_tokens++; @@ -878,9 +910,9 @@ void RequestManager::add_continuing_inf_req_to_new_batch(BatchConfig &new_bc, Ba } new_bc.requestsInfo[i].num_tokens_in_batch = std::min(get_max_tokens_per_batch() - new_bc.num_tokens - - space_for_incr_dec_requests, - (int)request.tokens.size() - - new_bc.requestsInfo[i].first_token_depth_in_request); + space_for_incr_dec_requests, + (int)request.tokens.size() - + new_bc.requestsInfo[i].first_token_depth_in_request); new_bc.requestsInfo[i].prompt_phase = true; } for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) { @@ -895,16 +927,21 @@ void RequestManager::add_continuing_inf_req_to_new_batch(BatchConfig &new_bc, Ba profiling_requests[new_bc.requestsInfo[i].request_guid].llm_decoding_steps++; } -void RequestManager::add_new_inf_req(BatchConfig &new_bc, int &num_active_req, int &num_concurrent_inf_adapters, int i) { - assert(!pending_infr_request_queue.empty() && "Trying to add a new inference request when there are none"); - assert(new_bc.num_tokens < get_max_tokens_per_batch() && "Trying to add a new inference request when the batch is full"); - +void RequestManager::add_new_inf_req(BatchConfig &new_bc, + int &num_active_req, + int &num_concurrent_inf_adapters, + int i) { + assert(!pending_infr_request_queue.empty() && + "Trying to add a new inference request when there are none"); + assert(new_bc.num_tokens < get_max_tokens_per_batch() && + "Trying to add a new inference request when the batch is full"); + Request new_request = pending_infr_request_queue.front(); assert(new_request.req_type == RequestType::REQ_INFERENCE); // if the request has peft adapters and we are at capacity, don't add it yet if (new_request.peft_model_id != PEFTModelID::NO_ID && - num_concurrent_inf_adapters == get_max_concurrent_adapters()) { + num_concurrent_inf_adapters == get_max_concurrent_adapters()) { return; } @@ -913,11 +950,14 @@ void RequestManager::add_new_inf_req(BatchConfig &new_bc, int &num_active_req, i new_bc.requestsInfo[i].first_token_depth_in_request = 0; new_bc.requestsInfo[i].first_token_offset_in_batch = new_bc.num_tokens; new_bc.requestsInfo[i].request_guid = new_request.guid; - new_bc.requestsInfo[i].num_tokens_in_batch = std::min(get_max_tokens_per_batch() - new_bc.num_tokens, (int)new_request.tokens.size()); + new_bc.requestsInfo[i].num_tokens_in_batch = + std::min(get_max_tokens_per_batch() - new_bc.num_tokens, + (int)new_request.tokens.size()); new_bc.requestsInfo[i].max_length = new_request.max_length; new_bc.requestsInfo[i].peft_model_id = new_request.peft_model_id; if (new_request.peft_model_id != PEFTModelID::NO_ID) { - add_peft_config_to_request_info(new_bc, i, get_peft_config(new_request.peft_model_id)); + add_peft_config_to_request_info( + new_bc, i, get_peft_config(new_request.peft_model_id)); } new_bc.requestsInfo[i].finetuning_request = false; new_bc.request_completed[i] = false; @@ -938,16 +978,21 @@ void RequestManager::add_new_inf_req(BatchConfig &new_bc, int &num_active_req, i } } -void RequestManager::handle_completed_finetuning_req(BatchConfig const &old_finetuning_bc) { - 
assert(old_finetuning_bc.num_active_requests() == 1 && "Number of active requests in a finetuning batch should be 1"); - assert(!old_finetuning_bc.request_completed[0] && "Finetuning request not found in new batch"); - +void RequestManager::handle_completed_finetuning_req( + BatchConfig const &old_finetuning_bc) { + assert(old_finetuning_bc.num_active_requests() == 1 && + "Number of active requests in a finetuning batch should be 1"); + assert(!old_finetuning_bc.request_completed[0] && + "Finetuning request not found in new batch"); + // sync metadata with all_requests Request &pq_request = pending_peft_request_queue.front(); Request &request = all_requests[pq_request.guid]; - assert(request.req_type == RequestType::REQ_FINETUNING && "Found misplaced inference request"); + assert(request.req_type == RequestType::REQ_FINETUNING && + "Found misplaced inference request"); assert(request.guid == pq_request.guid && "Request GUID mismatch"); - assert(old_finetuning_bc.requestsInfo[0].request_guid == pq_request.guid && "Request GUID mismatch"); + assert(old_finetuning_bc.requestsInfo[0].request_guid == pq_request.guid && + "Request GUID mismatch"); request.status = Request::COMPLETED; request.peft_finetuning_info = pq_request.peft_finetuning_info; // remove from pending peft queue @@ -962,8 +1007,7 @@ void RequestManager::handle_completed_finetuning_req(BatchConfig const &old_fine ProfileInfo profile_info = profiling_requests[request.guid]; profile_info.finish_time = Realm::Clock::current_time_in_microseconds(); - total_request_run_time += - profile_info.finish_time - profile_info.start_time; + total_request_run_time += profile_info.finish_time - profile_info.start_time; profiling_requests[request.guid] = profile_info; // log_req_mgr.print("[%s] guid(%zu) completed_training_steps(%d) " // "processed_finetuning_tokens(%lu) latency(%.1lf)", @@ -1005,92 +1049,127 @@ void RequestManager::handle_completed_finetuning_req(BatchConfig const &old_fine void RequestManager::add_finetuning_req_fwd_batch(BatchConfig &new_bc) { assert(enable_peft_finetuning && "PEFT finetuning is not enabled"); - assert(!pending_peft_request_queue.empty() && "Trying to add a new finetuning request when there are none"); - assert(new_bc.num_tokens < get_max_tokens_per_batch() && "Trying to add a new finetuning request when the batch is full"); - int inference_batch_size = BatchConfig::max_requests_per_batch() - (int)enable_peft_finetuning; - assert(new_bc.request_completed[inference_batch_size] && "Finetuning request already present in new batch"); + assert(!pending_peft_request_queue.empty() && + "Trying to add a new finetuning request when there are none"); + assert(new_bc.num_tokens < get_max_tokens_per_batch() && + "Trying to add a new finetuning request when the batch is full"); + int inference_batch_size = + BatchConfig::max_requests_per_batch() - (int)enable_peft_finetuning; + assert(new_bc.request_completed[inference_batch_size] && + "Finetuning request already present in new batch"); Request &request = pending_peft_request_queue.front(); - assert(request.req_type == RequestType::REQ_FINETUNING && "Found misplaced inference request"); + assert(request.req_type == RequestType::REQ_FINETUNING && + "Found misplaced inference request"); assert(request.dataset.size() > 0 && "Empty dataset for finetuning request"); - assert(request.peft_finetuning_info.status == Request::FORWARD_PHASE && "Finetuning request is not in forward phase"); + assert(request.peft_finetuning_info.status == Request::FORWARD_PHASE && + "Finetuning 
request is not in forward phase"); - int dataset_entry = request.peft_finetuning_info.completed_training_steps % request.dataset.size(); + int dataset_entry = request.peft_finetuning_info.completed_training_steps % + request.dataset.size(); int num_tokens_left_in_dataset_entry = - (int)request.dataset[dataset_entry].size() - request.peft_finetuning_info.dataset_entry_processed_tokens; - int batch_capacity_left = get_max_tokens_per_batch() - new_bc.num_active_tokens(); - int num_peft_tokens = std::min(num_tokens_left_in_dataset_entry, batch_capacity_left); + (int)request.dataset[dataset_entry].size() - + request.peft_finetuning_info.dataset_entry_processed_tokens; + int batch_capacity_left = + get_max_tokens_per_batch() - new_bc.num_active_tokens(); + int num_peft_tokens = + std::min(num_tokens_left_in_dataset_entry, batch_capacity_left); assert(num_peft_tokens > 0 && "No tokens left to add to the batch"); - + // general fields new_bc.request_completed[inference_batch_size] = false; // request info - new_bc.requestsInfo[inference_batch_size].first_token_depth_in_request = request.peft_finetuning_info.dataset_entry_processed_tokens; - new_bc.requestsInfo[inference_batch_size].first_token_offset_in_batch = new_bc.num_active_tokens(); - new_bc.requestsInfo[inference_batch_size].num_tokens_in_batch = num_peft_tokens; + new_bc.requestsInfo[inference_batch_size].first_token_depth_in_request = + request.peft_finetuning_info.dataset_entry_processed_tokens; + new_bc.requestsInfo[inference_batch_size].first_token_offset_in_batch = + new_bc.num_active_tokens(); + new_bc.requestsInfo[inference_batch_size].num_tokens_in_batch = + num_peft_tokens; new_bc.requestsInfo[inference_batch_size].max_length = request.max_length; new_bc.requestsInfo[inference_batch_size].request_guid = request.guid; - new_bc.requestsInfo[inference_batch_size].peft_model_id = request.peft_model_id; - add_peft_config_to_request_info(new_bc, inference_batch_size, get_peft_config(request.peft_model_id)); + new_bc.requestsInfo[inference_batch_size].peft_model_id = + request.peft_model_id; + add_peft_config_to_request_info( + new_bc, inference_batch_size, get_peft_config(request.peft_model_id)); new_bc.requestsInfo[inference_batch_size].finetuning_request = true; new_bc.requestsInfo[inference_batch_size].finetuning_backward_phase = false; - + // set_optimizer_tasks( // new_bc.requestsInfo[inference_batch_size].optimizer_tasks, // request.peft_finetuning_info.max_training_steps, // request.peft_finetuning_info.completed_training_steps, // request.peft_finetuning_info.gradient_accumulation_steps); - + // tokens info - for (size_t i = request.peft_finetuning_info.dataset_entry_processed_tokens; i < request.peft_finetuning_info.dataset_entry_processed_tokens + num_peft_tokens; i++) { + for (size_t i = request.peft_finetuning_info.dataset_entry_processed_tokens; + i < request.peft_finetuning_info.dataset_entry_processed_tokens + + num_peft_tokens; + i++) { new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = i; new_bc.tokensInfo[new_bc.num_tokens].request_index = inference_batch_size; - new_bc.tokensInfo[new_bc.num_tokens].token_id = request.dataset[dataset_entry][i]; + new_bc.tokensInfo[new_bc.num_tokens].token_id = + request.dataset[dataset_entry][i]; new_bc.num_tokens++; } } void RequestManager::add_finetuning_req_bwd_batch(BatchConfig &new_bc) { assert(enable_peft_finetuning && "PEFT finetuning is not enabled"); - assert(!pending_peft_request_queue.empty() && "Trying to add a new finetuning request when there are none"); - 
assert(new_bc.num_tokens < get_max_tokens_per_batch() && "Trying to add a new finetuning request when the batch is full"); - assert(new_bc.request_completed[0] && "Finetuning request already present in new batch"); + assert(!pending_peft_request_queue.empty() && + "Trying to add a new finetuning request when there are none"); + assert(new_bc.num_tokens < get_max_tokens_per_batch() && + "Trying to add a new finetuning request when the batch is full"); + assert(new_bc.request_completed[0] && + "Finetuning request already present in new batch"); Request &request = pending_peft_request_queue.front(); - assert(request.req_type == RequestType::REQ_FINETUNING && "Found misplaced inference request"); + assert(request.req_type == RequestType::REQ_FINETUNING && + "Found misplaced inference request"); assert(request.dataset.size() > 0 && "Empty dataset for finetuning request"); - assert(request.peft_finetuning_info.status == Request::BACKWARD_PHASE && "Finetuning request is not in backward phase"); + assert(request.peft_finetuning_info.status == Request::BACKWARD_PHASE && + "Finetuning request is not in backward phase"); + + int dataset_entry = request.peft_finetuning_info.completed_training_steps % + request.dataset.size(); + assert(request.dataset[dataset_entry].size() <= get_max_tokens_per_batch() && + "Dataset entry does not fit in the batch size"); - int dataset_entry = request.peft_finetuning_info.completed_training_steps % request.dataset.size(); - assert(request.dataset[dataset_entry].size() <= get_max_tokens_per_batch() && "Dataset entry does not fit in the batch size"); - // general fields - int inference_batch_size = BatchConfig::max_requests_per_batch() - (int)enable_peft_finetuning; + int inference_batch_size = + BatchConfig::max_requests_per_batch() - (int)enable_peft_finetuning; new_bc.request_completed[inference_batch_size] = false; // request info new_bc.requestsInfo[inference_batch_size].first_token_depth_in_request = 0; new_bc.requestsInfo[inference_batch_size].first_token_offset_in_batch = 0; - new_bc.requestsInfo[inference_batch_size].num_tokens_in_batch = request.dataset[dataset_entry].size(); + new_bc.requestsInfo[inference_batch_size].num_tokens_in_batch = + request.dataset[dataset_entry].size(); new_bc.requestsInfo[inference_batch_size].max_length = request.max_length; new_bc.requestsInfo[inference_batch_size].request_guid = request.guid; - new_bc.requestsInfo[inference_batch_size].peft_model_id = request.peft_model_id; - add_peft_config_to_request_info(new_bc, inference_batch_size, get_peft_config(request.peft_model_id)); + new_bc.requestsInfo[inference_batch_size].peft_model_id = + request.peft_model_id; + add_peft_config_to_request_info( + new_bc, inference_batch_size, get_peft_config(request.peft_model_id)); new_bc.requestsInfo[inference_batch_size].finetuning_request = true; new_bc.requestsInfo[inference_batch_size].finetuning_backward_phase = true; - new_bc.requestsInfo[inference_batch_size].peft_bwd_last_layer = min(request.peft_finetuning_info.last_processed_layer-1, get_num_transformer_layers()-1); + new_bc.requestsInfo[inference_batch_size].peft_bwd_last_layer = + min(request.peft_finetuning_info.last_processed_layer - 1, + get_num_transformer_layers() - 1); assert(new_bc.requestsInfo[inference_batch_size].peft_bwd_last_layer >= 0); - new_bc.requestsInfo[inference_batch_size].peft_bwd_first_layer = new_bc.requestsInfo[inference_batch_size].peft_bwd_last_layer - get_num_layers_per_finetuning_step(); + new_bc.requestsInfo[inference_batch_size].peft_bwd_first_layer = + 
new_bc.requestsInfo[inference_batch_size].peft_bwd_last_layer - + get_num_layers_per_finetuning_step(); assert(new_bc.requestsInfo[inference_batch_size].peft_bwd_first_layer >= 0); - + // set_optimizer_tasks( // new_bc.requestsInfo[inference_batch_size].optimizer_tasks, // request.peft_finetuning_info.max_training_steps, // request.peft_finetuning_info.completed_training_steps, // request.peft_finetuning_info.gradient_accumulation_steps); - + // tokens info for (size_t i = 0; i < request.dataset[dataset_entry].size(); i++) { new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = i; new_bc.tokensInfo[new_bc.num_tokens].request_index = 0; - new_bc.tokensInfo[new_bc.num_tokens].token_id = request.dataset[dataset_entry][i]; + new_bc.tokensInfo[new_bc.num_tokens].token_id = + request.dataset[dataset_entry][i]; new_bc.num_tokens++; } } @@ -1111,70 +1190,108 @@ bool RequestManager::finetuning_bwd_work_available() { return request.peft_finetuning_info.status == Request::BACKWARD_PHASE; } -void RequestManager::process_finetuning_req_fwd_progress(BatchConfig const &old_fwd_bc, InferenceResult const &result) { +void RequestManager::process_finetuning_req_fwd_progress( + BatchConfig const &old_fwd_bc, InferenceResult const &result) { if (old_fwd_bc.num_finetuning_requests() == 0) { return; } - int inference_batch_size = BatchConfig::max_requests_per_batch() - (int)enable_peft_finetuning; - assert(!old_fwd_bc.request_completed[inference_batch_size] && "Finetuning request not found in new batch"); - assert(old_fwd_bc.requestsInfo[inference_batch_size].num_tokens_in_batch > 0 && "Trying to continue an empty finetuning request"); + int inference_batch_size = + BatchConfig::max_requests_per_batch() - (int)enable_peft_finetuning; + assert(!old_fwd_bc.request_completed[inference_batch_size] && + "Finetuning request not found in new batch"); + assert(old_fwd_bc.requestsInfo[inference_batch_size].num_tokens_in_batch > + 0 && + "Trying to continue an empty finetuning request"); Request &request = pending_peft_request_queue.front(); - assert(request.req_type == RequestType::REQ_FINETUNING && "Found misplaced inference request"); - assert(request.peft_finetuning_info.completed_training_steps <= request.peft_finetuning_info.max_training_steps); - assert(request.guid == old_fwd_bc.requestsInfo[inference_batch_size].request_guid && "Request GUID mismatch"); - assert(request.peft_finetuning_info.dataset_entry_processed_tokens == old_fwd_bc.requestsInfo[inference_batch_size].first_token_depth_in_request && "Token depth mismatch"); - - request.peft_finetuning_info.dataset_entry_processed_tokens += old_fwd_bc.requestsInfo[inference_batch_size].num_tokens_in_batch; - - int dataset_entry = request.peft_finetuning_info.completed_training_steps % request.dataset.size(); - bool first_fwd_dataset_entry = request.peft_finetuning_info.dataset_entry_processed_tokens == 0; - bool dataset_entry_finished = request.peft_finetuning_info.dataset_entry_processed_tokens + old_fwd_bc.requestsInfo[inference_batch_size].num_tokens_in_batch == request.dataset[dataset_entry].size(); - - float avg_loss = result.finetuning_loss * old_fwd_bc.requestsInfo[inference_batch_size].num_tokens_in_batch / request.dataset[dataset_entry].size(); + assert(request.req_type == RequestType::REQ_FINETUNING && + "Found misplaced inference request"); + assert(request.peft_finetuning_info.completed_training_steps <= + request.peft_finetuning_info.max_training_steps); + assert(request.guid == + old_fwd_bc.requestsInfo[inference_batch_size].request_guid && + 
"Request GUID mismatch"); + assert(request.peft_finetuning_info.dataset_entry_processed_tokens == + old_fwd_bc.requestsInfo[inference_batch_size] + .first_token_depth_in_request && + "Token depth mismatch"); + + request.peft_finetuning_info.dataset_entry_processed_tokens += + old_fwd_bc.requestsInfo[inference_batch_size].num_tokens_in_batch; + + int dataset_entry = request.peft_finetuning_info.completed_training_steps % + request.dataset.size(); + bool first_fwd_dataset_entry = + request.peft_finetuning_info.dataset_entry_processed_tokens == 0; + bool dataset_entry_finished = + request.peft_finetuning_info.dataset_entry_processed_tokens + + old_fwd_bc.requestsInfo[inference_batch_size].num_tokens_in_batch == + request.dataset[dataset_entry].size(); + + float avg_loss = + result.finetuning_loss * + old_fwd_bc.requestsInfo[inference_batch_size].num_tokens_in_batch / + request.dataset[dataset_entry].size(); if (first_fwd_dataset_entry) { request.peft_finetuning_info.finetuning_losses.push_back(avg_loss); } else { request.peft_finetuning_info.finetuning_losses.back() += avg_loss; } - + if (dataset_entry_finished) { request.peft_finetuning_info.dataset_entry_processed_tokens = 0; request.peft_finetuning_info.status = Request::BACKWARD_PHASE; } } -void RequestManager::process_finetuning_req_bwd_progress(BatchConfig const &old_bwd_bc) { - assert(old_bwd_bc.num_active_requests() <= 1 && "More than 1 finetuning request in the batch"); +void RequestManager::process_finetuning_req_bwd_progress( + BatchConfig const &old_bwd_bc) { + assert(old_bwd_bc.num_active_requests() <= 1 && + "More than 1 finetuning request in the batch"); if (old_bwd_bc.num_active_requests() == 0) { return; } - assert(!old_bwd_bc.request_completed[0] && "Finetuning request not found in new batch"); - // check that request in batch is the same as the first one in the pending queue + assert(!old_bwd_bc.request_completed[0] && + "Finetuning request not found in new batch"); + // check that request in batch is the same as the first one in the pending + // queue Request &request = pending_peft_request_queue.front(); - assert(request.guid == old_bwd_bc.requestsInfo[0].request_guid && "Finetuning request in batch does not match the one in the pending queue"); - assert(request.req_type == RequestType::REQ_FINETUNING && "Found misplaced inference request"); - assert(request.peft_finetuning_info.status == Request::BACKWARD_PHASE && "Finetuning request is not in backward phase"); - request.peft_finetuning_info.last_processed_layer = old_bwd_bc.requestsInfo[0].peft_bwd_first_layer; + assert(request.guid == old_bwd_bc.requestsInfo[0].request_guid && + "Finetuning request in batch does not match the one in the pending " + "queue"); + assert(request.req_type == RequestType::REQ_FINETUNING && + "Found misplaced inference request"); + assert(request.peft_finetuning_info.status == Request::BACKWARD_PHASE && + "Finetuning request is not in backward phase"); + request.peft_finetuning_info.last_processed_layer = + old_bwd_bc.requestsInfo[0].peft_bwd_first_layer; assert(request.peft_finetuning_info.last_processed_layer >= 0); if (request.peft_finetuning_info.last_processed_layer == 0) { request.peft_finetuning_info.completed_training_steps += 1; request.peft_finetuning_info.status = Request::FORWARD_PHASE; } - if (request.peft_finetuning_info.completed_training_steps == request.peft_finetuning_info.max_training_steps) { + if (request.peft_finetuning_info.completed_training_steps == + request.peft_finetuning_info.max_training_steps) { 
handle_completed_finetuning_req(old_bwd_bc); } } -void RequestManager::process_work_from_old_batches(BatchConfig const &old_fwd_bc, BatchConfig const &old_bwd_bc, InferenceResult const &result) { +void RequestManager::process_work_from_old_batches( + BatchConfig const &old_fwd_bc, + BatchConfig const &old_bwd_bc, + InferenceResult const &result) { const std::lock_guard lock(request_queue_mutex); - - // Step 1: Inference. Process work from previous fwd iteration: save generated inference tokens and update records of finetuning fwd progress + + // Step 1: Inference. Process work from previous fwd iteration: save generated + // inference tokens and update records of finetuning fwd progress process_inf_req_progress(old_fwd_bc, result); - - // Step 2: Finetuning. Process work from previous bwd iteration: update records of finetuning bwd progress + + // Step 2: Finetuning. Process work from previous bwd iteration: update + // records of finetuning bwd progress if (enable_peft_finetuning) { // check that we did either fwd or bwd, not both - assert((old_bwd_bc.num_finetuning_requests() == 0 || old_fwd_bc.num_finetuning_requests() == 0) && "Both finetuning fwd and bwd requests are present in the batch"); + assert((old_bwd_bc.num_finetuning_requests() == 0 || + old_fwd_bc.num_finetuning_requests() == 0) && + "Both finetuning fwd and bwd requests are present in the batch"); process_finetuning_req_fwd_progress(old_fwd_bc, result); process_finetuning_req_bwd_progress(old_bwd_bc); } @@ -1190,36 +1307,51 @@ BatchConfig RequestManager::prepare_next_bwd_batch() { return new_bc; } -BatchConfig RequestManager::prepare_next_fwd_batch(BatchConfig const &old_fwd_bc, InferenceResult const &result) { +BatchConfig + RequestManager::prepare_next_fwd_batch(BatchConfig const &old_fwd_bc, + InferenceResult const &result) { const std::lock_guard lock(request_queue_mutex); - + // Step 1: Create new batch config BatchConfig new_bc; // params int num_active_req = -1; - // when finetuning is enabled, the last entry in the batch cannot be used for inference - int inference_batch_size = BatchConfig::max_requests_per_batch() - (int)enable_peft_finetuning; + // when finetuning is enabled, the last entry in the batch cannot be used for + // inference + int inference_batch_size = + BatchConfig::max_requests_per_batch() - (int)enable_peft_finetuning; int num_concurrent_inf_adapters = 0; // Step 2: prepare the next batch for existing inference requests for (int req_idx = 0; req_idx < inference_batch_size; req_idx++) { - if (!old_fwd_bc.request_completed[req_idx] && !inf_req_completed(old_fwd_bc, req_idx)) { - add_continuing_inf_req_to_new_batch(new_bc, old_fwd_bc, num_active_req, num_concurrent_inf_adapters, req_idx); + if (!old_fwd_bc.request_completed[req_idx] && + !inf_req_completed(old_fwd_bc, req_idx)) { + add_continuing_inf_req_to_new_batch(new_bc, + old_fwd_bc, + num_active_req, + num_concurrent_inf_adapters, + req_idx); } } - assert(num_concurrent_inf_adapters <= get_max_concurrent_adapters() && "Number of concurrent inference adapters exceeded the limit"); + assert(num_concurrent_inf_adapters <= get_max_concurrent_adapters() && + "Number of concurrent inference adapters exceeded the limit"); - // Step 3: add new inference requests to the next batch if there is space and they are available + // Step 3: add new inference requests to the next batch if there is space and + // they are available if (!pending_infr_request_queue.empty()) { - for (int req_idx = 0; req_idx < inference_batch_size && new_bc.num_tokens < 
get_max_tokens_per_batch(); req_idx++) {
       if (new_bc.request_completed[req_idx]) {
-        add_new_inf_req(new_bc, num_active_req, num_concurrent_inf_adapters, req_idx);
+        add_new_inf_req(
+            new_bc, num_active_req, num_concurrent_inf_adapters, req_idx);
       }
     }
   }
 
   // Step 4: add finetuning fwd tokens, if there is additional space
-  if (finetuning_fwd_work_available() && new_bc.num_tokens < get_max_tokens_per_batch() && !inference_finished) {
+  if (finetuning_fwd_work_available() &&
+      new_bc.num_tokens < get_max_tokens_per_batch() && !inference_finished) {
     add_finetuning_req_fwd_batch(new_bc);
   }
@@ -3124,38 +3256,43 @@ void RequestManager::serve_incr_decoding(FFModel *llm) {
   // init operators
   im->init_operators_inference(llm);
   // Legion futures for inc_decoding and spec_infer
-  BatchConfigFuture last_bcf_fwd, last_bcf_bwd;
+  // BatchConfigFuture last_bcf_fwd, last_bcf_bwd;
+  BatchConfigPairFuture last_bcf;
   InferenceResultFuture last_irf;
   FinetuningBwdFuture last_bwd_f;
   {
     // Initialize futures for incr decoding
     BatchConfig bc_fwd, bc_bwd;
     InferenceResult ir;
-    last_bcf_fwd = Future::from_value<BatchConfig>(bc_fwd);
-    last_bcf_bwd = Future::from_value<BatchConfig>(bc_bwd);
+    // last_bcf_fwd = Future::from_value<BatchConfig>(bc_fwd);
+    // last_bcf_bwd = Future::from_value<BatchConfig>(bc_bwd);
+    last_bcf = Future::from_value<std::pair<BatchConfig, BatchConfig>>(
+        std::make_pair(bc_fwd, bc_bwd));
     last_irf = Future::from_value<InferenceResult>(ir);
    last_bwd_f = Future::from_value<bool>(true);
   }
-  std::queue<std::tuple<BatchConfigFuture, BatchConfigFuture, InferenceResultFuture, FinetuningBwdFuture>> batch_pipeline;
-  // tuple[0]: fwd batch
-  // tuple[1]: bwd batch
-  // tuple[2]: inference result
-  // tuple[3]: bwd future
-  { batch_pipeline.push(std::make_tuple(last_bcf_fwd, last_bcf_bwd, last_irf, last_bwd_f)); }
+  std::queue<std::tuple<BatchConfigPairFuture, InferenceResultFuture, FinetuningBwdFuture>>
+      batch_pipeline;
+  // tuple[0]: std::pair<BatchConfig, BatchConfig>
+  // tuple[1]: inference result
+  // tuple[2]: bwd future
+  { batch_pipeline.push(std::make_tuple(last_bcf, last_irf, last_bwd_f)); }
   while (!is_background_server_terminated()) {
     if (batch_pipeline.size() >= 4) {
       // Block here to avoid launching too many batches
       auto const &batch = batch_pipeline.front();
+      std::get<1>(batch).get_void_result();
       std::get<2>(batch).get_void_result();
-      std::get<3>(batch).get_void_result();
     }
     // deque finished batches
     while (batch_pipeline.size() > 1) {
       auto const &batch = batch_pipeline.front();
-      if (std::get<2>(batch).is_ready() && std::get<3>(batch).is_ready()) {
+      if (std::get<1>(batch).is_ready() && std::get<2>(batch).is_ready()) {
         batch_pipeline.pop();
       } else {
         break;
       }
@@ -3164,14 +3301,16 @@ void RequestManager::serve_incr_decoding(FFModel *llm) {
     runtime->begin_trace(ctx, 12346 /*trace_id*/);
     auto &batch_pipeline_entry = batch_pipeline.back();
-    BatchConfigPairFuture next_batches = prepare_next_batch(batch_pipeline_entry, ctx, runtime);
-    BatchConfigFuture bcf_fwd = next_batches.first;
-    BatchConfigFuture bcf_bwd = next_batches.second;
-    InferenceResultFuture irf = im->inference(llm, 0, bcf_fwd);
-    FinetuningBwdFuture bwd_f = im->peft_bwd(llm, 0, bcf_bwd);
-    batch_pipeline.push(std::make_tuple(bcf_fwd, bcf_bwd, irf, bwd_f));
-    last_bcf_fwd = bcf_fwd;
-    last_bcf_bwd = bcf_bwd;
+    BatchConfigPairFuture bcf =
+        prepare_next_batch(batch_pipeline_entry, ctx, runtime);
+    // BatchConfigFuture bcf_fwd = next_batches.first;
+    // BatchConfigFuture bcf_bwd = next_batches.second;
+    InferenceResultFuture irf = im->inference(llm, 0, bcf);
+    FinetuningBwdFuture bwd_f = im->peft_bwd(llm, 0, bcf);
+    batch_pipeline.push(std::make_tuple(bcf, irf, bwd_f));
+    // last_bcf_fwd = bcf_fwd;
+    // last_bcf_bwd = bcf_bwd;
+    last_bcf = bcf;
     last_irf = irf;
     last_bwd_f = bwd_f;
     runtime->end_trace(ctx, 12346 /*trace_id*/);
   }
@@ -3247,11 +3386,11 @@ void RequestManager::serve_spec_infer(FFModel *llm) {
     for (size_t i = 0; i < get_num_ssms(); i++) {
       for (int depth = 0; depth < BeamSearchBatchConfig::MAX_BEAM_DEPTH;
            depth++) {
-        beam_bcf = beam_bcf_vec[i];
-
-        FutureMap fm = im->inference(get_ssm_model(i), 0, beam_bcf_vec[i]);
-        assert(fm.get_future_map_domain().get_volume() == 1);
-        BeamInferenceResultFuture beam_irf = fm.get_future(0);
+        // FutureMap fm = im->inference(get_ssm_model(i), 0, beam_bcf_vec[i]);
+        // assert(fm.get_future_map_domain().get_volume() == 1);
+        // BeamInferenceResultFuture beam_irf = fm.get_future(0);
+        BeamInferenceResultFuture beam_irf =
+            im->inference(get_ssm_model(i), 0, beam_bcf_vec[i]);
         beam_bcf_vec[i] =
             prepare_next_batch_beam(beam_bcf_vec[i], beam_irf, ctx, runtime);
       }
@@ -3260,9 +3399,10 @@ void RequestManager::serve_spec_infer(FFModel *llm) {
     {
       TreeVerifyBatchConfigFuture tree_bcf =
           prepare_next_batch_verify(beam_bcf_vec, ctx, runtime);
-      FutureMap fm = im->inference(llm, 0, tree_bcf);
-      assert(fm.get_future_map_domain().get_volume() == 1);
-      InferenceResultFuture tree_irf = fm.get_future(0);
+      // FutureMap fm = im->inference(llm, 0, tree_bcf);
+      // assert(fm.get_future_map_domain().get_volume() == 1);
+      // InferenceResultFuture tree_irf = fm.get_future(0);
+      InferenceResultFuture tree_irf = im->inference(llm, 0, tree_bcf);
       batch_pipeline.push(std::make_pair(tree_bcf, tree_irf));
       last_tree_bcf = tree_bcf;
       last_tree_irf = tree_irf;
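
Note (not part of the patch): a minimal caller-side sketch of how a finetuning request is configured after this change, mirroring the updated call sites in inference/peft/*.cc above. The helper name run_finetuning, the using-directive, and the exact generate() signature are assumptions for illustration; only the Request fields shown in the diff are taken from the patch.

// Hypothetical usage sketch -- not part of the patch. Field names follow the
// diff above; the surrounding function and its signature are illustrative.
#include "flexflow/request_manager.h"
#include <string>
#include <vector>

using namespace FlexFlow;

std::vector<GenerationResult> run_finetuning(FFModel &model,
                                             PEFTModelID const *peft_model_id,
                                             std::string const &dataset_fp) {
  std::vector<Request> requests;
  Request fine_tuning_req;
  fine_tuning_req.req_type = RequestType::REQ_FINETUNING;
  fine_tuning_req.peft_model_id =
      (peft_model_id != nullptr) ? *peft_model_id : PEFTModelID::NO_ID;
  // After this patch, finetuning knobs live under peft_finetuning_info
  // rather than directly on Request (e.g. max_training_steps).
  fine_tuning_req.peft_finetuning_info.dataset_filepath = dataset_fp;
  fine_tuning_req.peft_finetuning_info.max_training_steps = 1;
  requests.push_back(fine_tuning_req);
  return model.generate(requests);
}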