Skip to content

Commit

Permalink
removed redundant structs (NttTaskStatus, available_tasks_counter...)…
Browse files Browse the repository at this point in the history
… fixed arbitrary_coset+columns_batch bug
  • Loading branch information
ShanieWinitz committed Aug 25, 2024
1 parent 0867076 commit 82fb08e
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 119 deletions.
8 changes: 5 additions & 3 deletions icicle_v3/backend/cpu/include/cpu_ntt.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,16 @@ namespace ntt_cpu {
}
const int logn = int(log2(size));
const int domain_max_size = CpuNttDomain<S>::s_ntt_domain.get_max_size();
if (size > domain_max_size) {
ICICLE_LOG_ERROR << "Size is too large for domain. size = " << size << ", domain_max_size = " << domain_max_size;
return eIcicleError::INVALID_ARGUMENT;
}
const S* twiddles = CpuNttDomain<S>::s_ntt_domain.get_twiddles();
NttCpu<S, E> ntt(logn, direction, config, domain_max_size, twiddles);
NttTaskCordinates ntt_task_cordinates = {0, 0, 0, 0, 0};
NttTasksManager <S, E> ntt_tasks_manager(logn);
auto tasks_manager = new TasksManager<NttTask<S, E>>(std::thread::hardware_concurrency()-1);
// auto tasks_manager = new TasksManager<NttTask<S, E>>(1);
// auto tasks_manager = new TasksManager<NttTask<S, E>>(9);


int coset_stride = 0;
Expand Down Expand Up @@ -83,8 +87,6 @@ namespace ntt_cpu {
ntt.handle_pushed_tasks(tasks_manager, ntt_tasks_manager, 0);
}


// ntt_tasks_manager.wait_for_all_tasks();
ntt.refactor_and_reorder(output, twiddles);
ntt_task_cordinates.h1_layer_idx = 1;
sunbtt_plus_batch_logn = ntt.ntt_sub_logn.h1_layers_sub_logn[1] + int(log2(config.batch_size));
Expand Down
182 changes: 66 additions & 116 deletions icicle_v3/backend/cpu/include/ntt_tasks.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,6 @@ namespace ntt_cpu {
// Constructor that initializes the counters
TasksDependenciesCounters(NttSubLogn ntt_sub_logn, int h1_layer_idx);

// Function to get a counter for a given task
std::shared_ptr<int> get_counter(const NttTaskCordinates& task_c, bool reorder);

// Function to decrement the counter for a given task and check if it is ready to execute. if so, return true
bool decrement_counter(NttTaskCordinates ntt_task_cordinates);
int get_nof_pointing_to_counter(int h0_layer_idx) { return nof_pointing_to_counter[h0_layer_idx]; }
Expand All @@ -143,11 +140,6 @@ namespace ntt_cpu {
std::vector<std::shared_ptr<int>> h1_counters; // [h1_subntt_idx]
};

// Per-task bookkeeping record: completion flag plus the reorder marker.
struct NttTaskStatus {
  // Set to true once the task has been completed.
  bool done{false};
  // True when the task is a reorder task.
  bool reorder{false};
};

template<typename S = scalar_t, typename E = scalar_t>
struct NttTaskParams {
NttCpu<S, E>* ntt_cpu;
Expand All @@ -156,8 +148,6 @@ namespace ntt_cpu {
bool reorder;
};

using NttTasksStatus = std::map<NttTaskCordinates, NttTaskStatus>;

template<typename S = scalar_t, typename E = scalar_t>
class NttTask : public TaskBase {
public:
Expand Down Expand Up @@ -203,10 +193,7 @@ namespace ntt_cpu {
bool is_reorder() const {
return reorder;
}
void set_ntt_cpu(NttCpu<S, E>* cpu) { ntt_cpu = cpu; }
void set_input(E* inp) { input = inp; }
void set_coordinates(const NttTaskCordinates& coordinates) { ntt_task_cordinates = coordinates; }
void set_reorder(bool reorder_val) { reorder = reorder_val; }
void set_params(NttTaskParams<S, E> params) {ntt_cpu = params.ntt_cpu; input = params.input; ntt_task_cordinates = params.task_c; reorder = params.reorder;}

private:
NttCpu<S, E>* ntt_cpu; // Reference to NttCpu instance
Expand All @@ -219,16 +206,8 @@ namespace ntt_cpu {
template<typename S = scalar_t, typename E = scalar_t>
class NttTasksManager {
public:
NttTasksManager(int logn)
: tasks_status(logn > 15 ? 2 : 1),
counters(logn > 15 ? 2 : 1, TasksDependenciesCounters(NttSubLogn(logn), 0)) {
if (logn > 15) {
counters[1] = TasksDependenciesCounters(NttSubLogn(logn), 1);
}
nof_available_tasks = 0;
nof_waiting_tasks = 0;
// ICICLE_LOG_DEBUG << "NttTasksManager constructor";
}
NttTasksManager(int logn);

// Add a new task to the ntt_task_manager
eIcicleError push_task(NttCpu<S, E>* ntt_cpu, E* input, NttTaskCordinates task_c, bool reorder);

Expand All @@ -238,24 +217,31 @@ namespace ntt_cpu {
// Set a task as completed and update dependencies
eIcicleError set_task_as_completed(NttTask<S, E>& completed_task, int nof_subntts_l2);

int nof_available_tasks;
int nof_waiting_tasks;
bool tasks_to_do() {
return !available_tasks_list.empty() || !waiting_tasks_map.empty();
}

private:
// Function to get the counter for a specific task based on its coordinates
std::shared_ptr<int> get_counter_p_for_task(const NttTaskCordinates& task_c, bool reorder) {
return counters[task_c.h1_layer_idx].get_counter(task_c, reorder);
bool available_tasks() {
return !available_tasks_list.empty();
}

NttTaskParams<S, E> get_available_task() {
return available_tasks_list.front();
}

std::vector<NttTasksStatus> tasks_status; // Status of tasks by layer
eIcicleError erase_task_from_available_tasks_list(){
available_tasks_list.erase(available_tasks_list.begin());
return eIcicleError::SUCCESS;
}
private:

std::vector<TasksDependenciesCounters> counters; // Dependencies counters by layer
std::vector<NttTaskParams<S, E>> available_tasks_params_list;
// std::vector<NttTaskParams<S, E>> waiting_tasks_params_list;
std::map<NttTaskCordinates, NttTaskParams<S, E>> waiting_tasks_params_map;
std::vector<NttTaskParams<S, E>> available_tasks_list;
std::map<NttTaskCordinates, NttTaskParams<S, E>> waiting_tasks_map;

};

//////////////////////////// NttTasksManager Implementation ////////////////////////////
//////////////////////////// TasksDependenciesCounters Implementation ////////////////////////////

TasksDependenciesCounters::TasksDependenciesCounters(NttSubLogn ntt_sub_logn, int h1_layer_idx)
: h0_counters(1<<ntt_sub_logn.h1_layers_sub_logn[1-h1_layer_idx]), //nof_h1_subntts = h1_layers_sub_logn[1-h1_layer_idx].
Expand Down Expand Up @@ -317,30 +303,6 @@ namespace ntt_cpu {
}
}

// Returns the shared dependency counter that belongs to the given task.
//
// - Reorder tasks use the counter stored in h1_counters for the task's h1_subntt_idx.
// - Otherwise the counter comes from h0_counters[h1_subntt_idx][h0_layer_idx], and the
//   innermost index depends on the layer:
//     layer 0 -> index 0
//     layer 1 -> h0_block_idx
//     layer 2 -> h0_block_idx / nof_pointing_to_counter[2]
// - Any other h0_layer_idx logs an error and returns nullptr.
std::shared_ptr<int> TasksDependenciesCounters::get_counter(const NttTaskCordinates& task_c, bool reorder) {
  // Reorder tasks are keyed by h1 sub-NTT only.
  if (reorder) { return h1_counters[task_c.h1_subntt_idx]; }

  const auto layer = task_c.h0_layer_idx;
  switch (layer) {
  case 0:
    return h0_counters[task_c.h1_subntt_idx][layer][0];
  case 1:
    return h0_counters[task_c.h1_subntt_idx][layer][task_c.h0_block_idx];
  case 2:
    // Several blocks map onto one counter in this layer, hence the division.
    return h0_counters[task_c.h1_subntt_idx][layer][task_c.h0_block_idx / this->nof_pointing_to_counter[layer]];
  default:
    // Unsupported layer index: report it and hand back an empty pointer.
    ICICLE_LOG_ERROR << "get_counter: return nullptr";
    return nullptr;
  }
}

bool TasksDependenciesCounters::decrement_counter(NttTaskCordinates task_c) {
if (nof_h0_layers==1){
return false;
Expand Down Expand Up @@ -371,32 +333,31 @@ namespace ntt_cpu {

//////////////////////////// NttTasksManager Implementation ////////////////////////////

// Builds the dependency counters for an NTT of size 2^logn.
// The counters vector holds one entry when logn <= 15 and two entries otherwise; every entry
// is first created for h1_layer_idx 0, and the second entry (if any) is then rebuilt for
// h1_layer_idx 1.
template <typename S, typename E>
NttTasksManager<S, E>::NttTasksManager(int logn)
    : counters((logn > 15) ? 2 : 1, TasksDependenciesCounters(NttSubLogn(logn), 0))
{
  // Only one entry exists: nothing else to set up.
  if (!(logn > 15)) { return; }

  // Two entries exist: replace the second one with counters built for h1_layer_idx 1.
  counters[1] = TasksDependenciesCounters(NttSubLogn(logn), 1);
}

template<typename S, typename E>
eIcicleError NttTasksManager<S, E>::push_task(NttCpu<S, E>* ntt_cpu, E* input, NttTaskCordinates task_c, bool reorder) {

if (tasks_status[task_c.h1_layer_idx].find(task_c) == tasks_status[task_c.h1_layer_idx].end()) {
NttTaskStatus status = {false, reorder};
tasks_status[task_c.h1_layer_idx][task_c] = status;

// Create a new NttTaskParams and add it to the available_tasks_params_list
// Create a new NttTaskParams and add it to the available_tasks_list
NttTaskParams<S, E> params = {ntt_cpu, input, task_c, reorder};
if (task_c.h0_layer_idx == 0) {
available_tasks_params_list.push_back(params);
nof_available_tasks++;
available_tasks_list.push_back(params);
} else {
waiting_tasks_params_map[task_c] = params; // Add to map
nof_waiting_tasks++;
waiting_tasks_map[task_c] = params; // Add to map
}
return eIcicleError::SUCCESS;
}
return eIcicleError::INVALID_ARGUMENT;
}

template<typename S, typename E>
bool NttTasksManager<S, E>::get_available_task_to_run(NttTask<S, E>* available_task, int h1_layer) {
if (!available_tasks_params_list.empty()) {
if (!available_tasks_list.empty()) {
// Take the first task from the list
NttTaskParams<S, E> params = available_tasks_params_list.front();
NttTaskParams<S, E> params = available_tasks_list.front();

// Assign the parameters to the available task
available_task->set_ntt_cpu(params.ntt_cpu);
Expand All @@ -405,9 +366,7 @@ namespace ntt_cpu {
available_task->set_reorder(params.reorder);

// Remove the task from the list
available_tasks_params_list.erase(available_tasks_params_list.begin());
nof_available_tasks--;

available_tasks_list.erase(available_tasks_list.begin());
return true;
}
return false;
Expand All @@ -418,9 +377,6 @@ namespace ntt_cpu {
template<typename S, typename E>
eIcicleError NttTasksManager<S, E>::set_task_as_completed(NttTask<S, E>& completed_task, int nof_subntts_l2) {
ntt_cpu::NttTaskCordinates task_c = completed_task.get_coordinates();
auto& status = tasks_status[task_c.h1_layer_idx][task_c];
status.done = true;
// int h1_layer_idx = task_c.h1_layer_idx;
int nof_h0_layers = counters[task_c.h1_layer_idx].get_nof_h0_layers();
// Update dependencies in counters
if(counters[task_c.h1_layer_idx].decrement_counter(task_c)){
Expand All @@ -429,35 +385,27 @@ namespace ntt_cpu {
int nof_pointing_to_counter = (task_c.h0_layer_idx == nof_h0_layers-1) ? 1
: counters[task_c.h1_layer_idx].get_nof_pointing_to_counter(task_c.h0_layer_idx+1);
int stride = nof_subntts_l2/nof_pointing_to_counter;
// int counter_group_idx = task_c.h0_layer_idx==0 ? task_c.h0_block_idx :
// /*task_c.h0_layer_idx==1*/ task_c.h0_subntt_idx;
for (int i = 0; i < nof_pointing_to_counter; i++) { // TODO - improve efficiency using make_move_iterator
NttTaskCordinates next_task_c = task_c.h0_layer_idx==0 ? NttTaskCordinates{task_c.h1_layer_idx, task_c.h1_subntt_idx, task_c.h0_layer_idx+1, task_c.h0_block_idx, i}
/*task_c.h0_layer_idx==1*/: NttTaskCordinates{task_c.h1_layer_idx, task_c.h1_subntt_idx, task_c.h0_layer_idx+1, (task_c.h0_subntt_idx + stride*i), 0};
// /*task_c.h0_layer_idx==1*/: NttTaskCordinates{task_c.h1_layer_idx, task_c.h1_subntt_idx, task_c.h0_layer_idx+1, (task_c.h0_subntt_idx* nof_pointing_to_counter +i), 0};
if (waiting_tasks_params_map.find(next_task_c) != waiting_tasks_params_map.end()) {
available_tasks_params_list.push_back(waiting_tasks_params_map[next_task_c]);
waiting_tasks_params_map.erase(next_task_c);
if (waiting_tasks_map.find(next_task_c) != waiting_tasks_map.end()) {
available_tasks_list.push_back(waiting_tasks_map[next_task_c]);
waiting_tasks_map.erase(next_task_c);
}
else {
ICICLE_LOG_ERROR << "Task not found in waiting_tasks_params_map: h0_layer_idx: " << next_task_c.h0_layer_idx << ", h0_block_idx: " << next_task_c.h0_block_idx << ", h0_subntt_idx: " << next_task_c.h0_subntt_idx;
ICICLE_LOG_ERROR << "Task not found in waiting_tasks_map: h0_layer_idx: " << next_task_c.h0_layer_idx << ", h0_block_idx: " << next_task_c.h0_block_idx << ", h0_subntt_idx: " << next_task_c.h0_subntt_idx;
}
}
nof_available_tasks = nof_available_tasks + nof_pointing_to_counter;
nof_waiting_tasks = nof_waiting_tasks - nof_pointing_to_counter;

} else {
// Reorder the output
NttTaskCordinates next_task_c = {task_c.h1_layer_idx, task_c.h1_subntt_idx, nof_h0_layers, 0, 0};

if (waiting_tasks_params_map.find(next_task_c) != waiting_tasks_params_map.end()) {
available_tasks_params_list.push_back(waiting_tasks_params_map[next_task_c]);
nof_available_tasks++;
waiting_tasks_params_map.erase(next_task_c);
nof_waiting_tasks--;
if (waiting_tasks_map.find(next_task_c) != waiting_tasks_map.end()) {
available_tasks_list.push_back(waiting_tasks_map[next_task_c]);
waiting_tasks_map.erase(next_task_c);
}
else {
ICICLE_LOG_ERROR << "Task not found in waiting_tasks_params_map";
ICICLE_LOG_ERROR << "Task not found in waiting_tasks_map";
}
}
}
Expand Down Expand Up @@ -598,7 +546,7 @@ namespace ntt_cpu {
if (arbitrary_coset) {
for (int i = 1; i < size; ++i) {
idx = this->config.columns_batch ? batch : i;
current_elements[i] = current_elements[i] * arbitrary_coset[idx];
current_elements[batch_stride * i] = current_elements[batch_stride * i] * arbitrary_coset[i];
}
} else if (coset_stride != 0) {
for (int i = 1; i < size; ++i) {
Expand Down Expand Up @@ -781,7 +729,7 @@ namespace ntt_cpu {
}
}
// Sort the output at the end so that elements will be in right order.
// TODO SHANIE - After implementing for different ordering, maybe this should be done in a different place
// TODO SHANIE - After implementing for different ordering, maybe this should be in a different place
// - When implementing real parallelism, consider sorting in parallel and in-place
int nof_h0_subntts = (nof_h0_layers == 1) ? (1 << NttCpu<S, E>::ntt_sub_logn.h0_layers_sub_logn[ntt_task_cordinates.h1_layer_idx][1]) :
(nof_h0_layers == 2) ? (1 << NttCpu<S, E>::ntt_sub_logn.h0_layers_sub_logn[ntt_task_cordinates.h1_layer_idx][0]) : 1;
Expand All @@ -790,7 +738,6 @@ namespace ntt_cpu {
// ICICLE_LOG_DEBUG << "h1_cpu_ntt: PUSH REORDER TASK h0_layer_idx: " << ntt_task_cordinates.h0_layer_idx << ", h0_block_idx: " << ntt_task_cordinates.h0_block_idx << ", h0_subntt_idx: " << ntt_task_cordinates.h0_subntt_idx;
ntt_task_cordinates = {ntt_task_cordinates.h1_layer_idx, ntt_task_cordinates.h1_subntt_idx, nof_h0_layers, 0, 0};
ntt_tasks_manager.push_task(this, input, ntt_task_cordinates, true); //reorder=true
// ICICLE_LOG_DEBUG << "h1_cpu_ntt: PUSH REORDER TASK DONE";
}
// if (nof_h0_layers>1) { // at least 2 layers
// if (this->config.columns_batch) {
Expand Down Expand Up @@ -830,33 +777,36 @@ namespace ntt_cpu {
template <typename S, typename E>
eIcicleError NttCpu<S, E>::handle_pushed_tasks(TasksManager<NttTask<S, E>>* tasks_manager, NttTasksManager<S, E>& ntt_tasks_manager, int h1_layer_idx) {
NttTask<S, E>* task_slot = nullptr;
NttTaskParams<S, E> params;

int nof_subntts_l2 = 1 << ((this->ntt_sub_logn.h0_layers_sub_logn[h1_layer_idx][0]) + (this->ntt_sub_logn.h0_layers_sub_logn[h1_layer_idx][1]));
// ICICLE_LOG_DEBUG << "handle_pushed_tasks: nof_available_tasks: " << ntt_tasks_manager.nof_available_tasks;
// ICICLE_LOG_DEBUG << "handle_pushed_tasks: nof_waiting_tasks: " << ntt_tasks_manager.nof_waiting_tasks;
while (ntt_tasks_manager.nof_available_tasks > 0 || ntt_tasks_manager.nof_waiting_tasks > 0) {
if (ntt_tasks_manager.nof_available_tasks > 0){
while (ntt_tasks_manager.tasks_to_do()) {
// There are tasks that are available or waiting

if (ntt_tasks_manager.available_tasks()){
// Task is available to dispatch
task_slot = tasks_manager->get_idle_or_completed_task();
if (task_slot->is_completed()) {
ntt_tasks_manager.set_task_as_completed(*task_slot, nof_subntts_l2);
}
ntt_tasks_manager.get_available_task_to_run(task_slot, h1_layer_idx);
params = ntt_tasks_manager.get_available_task();
task_slot->set_params(params);
ntt_tasks_manager.erase_task_from_available_tasks_list();
task_slot->dispatch();
} else { // wait for available tasks
while (ntt_tasks_manager.nof_available_tasks == 0 && ntt_tasks_manager.nof_waiting_tasks > 0) {
task_slot = tasks_manager->get_completed_task();
ntt_tasks_manager.set_task_as_completed(*task_slot, nof_subntts_l2);
if (ntt_tasks_manager.nof_available_tasks > 0) {
ICICLE_ASSERT(ntt_tasks_manager.get_available_task_to_run(task_slot, h1_layer_idx));
task_slot->dispatch();
} else {
task_slot->set_idle();
}
} else {
// Wait for available tasks
task_slot = tasks_manager->get_completed_task();
ntt_tasks_manager.set_task_as_completed(*task_slot, nof_subntts_l2);
if (ntt_tasks_manager.available_tasks()) {
params = ntt_tasks_manager.get_available_task();
task_slot->set_params(params);
ntt_tasks_manager.erase_task_from_available_tasks_list();
task_slot->dispatch();
} else {
task_slot->set_idle();
}
}
}
while ((task_slot = tasks_manager->get_completed_task()) != nullptr) { // clean all completed tasks
task_slot->set_idle();
}
return eIcicleError::SUCCESS;
}

Expand Down

0 comments on commit 82fb08e

Please sign in to comment.