From 0ff84796470f5e6e2bb817f94b4ed09092a5232e Mon Sep 17 00:00:00 2001 From: Mikolaj Klikowicz Date: Thu, 28 Nov 2024 16:35:14 +0100 Subject: [PATCH] [#66550] model: API improvements Signed-off-by: Mikolaj Klikowicz --- demo_app/src/main.c | 35 ++----------- include/kenning_inference_lib/core/model.h | 12 +++-- lib/kenning_inference_lib/core/callbacks.c | 6 +-- .../core/inference_server.c | 17 ++----- lib/kenning_inference_lib/core/model.c | 51 +++++++++++++++++-- 5 files changed, 68 insertions(+), 53 deletions(-) diff --git a/demo_app/src/main.c b/demo_app/src/main.c index 795750d..80c609e 100644 --- a/demo_app/src/main.c +++ b/demo_app/src/main.c @@ -66,16 +66,6 @@ void postprocess_output(uint8_t *data_in, float *data_out, size_t model_output_s */ void format_output(uint8_t *buffer, const size_t buffer_size, float *model_output); -/** - * Initialize main loader table - */ -status_t prepare_main_ldr_table() -{ - static struct msg_loader msg_loader_iospec = MSG_LOADER_BUF((uint8_t *)(&g_model_struct), sizeof(MlModel)); - g_ldr_tables[0][LOADER_TYPE_IOSPEC] = &msg_loader_iospec; - return STATUS_OK; -} - int main(void) { status_t status = STATUS_OK; @@ -88,8 +78,6 @@ int main(void) int64_t timer_start = 0; int64_t timer_end = 0; - prepare_main_ldr_table(); - do { // initialize model @@ -100,25 +88,12 @@ int main(void) break; } - // retrieve loaders - struct msg_loader *msg_loader_model = g_ldr_tables[1][LOADER_TYPE_MODEL]; - struct msg_loader *msg_loader_data = g_ldr_tables[1][LOADER_TYPE_DATA]; - struct msg_loader *msg_loader_iospec = g_ldr_tables[0][LOADER_TYPE_IOSPEC]; - // load model structure - msg_loader_iospec->reset(msg_loader_iospec, 0); - status = msg_loader_iospec->save(msg_loader_iospec, (uint8_t *)(&model_struct), sizeof(MlModel)); - BREAK_ON_ERROR_LOG(status, "iospec loader failed: %d", status); - - status = model_load_struct(); + status = model_load_struct((uint8_t *)&model_struct, sizeof(MlModel)); BREAK_ON_ERROR_LOG(status, "Model struct load error 0x%x (%s)", status, get_status_str(status)); // load model weights - msg_loader_model->reset(msg_loader_model, 0); - status = msg_loader_model->save(msg_loader_model, (uint8_t *)model_data, model_data_len); - BREAK_ON_ERROR_LOG(status, "Model loader failed: %d", status); - - status = model_load_weights(); + status = model_load_weights(model_data, model_data_len); BREAK_ON_ERROR_LOG(status, "Model weights load error 0x%x (%s)", status, get_status_str(status)); // allocate buffer for input @@ -135,11 +110,7 @@ int main(void) { preprocess_input((float *)data[batch_index], model_input, model_input_size); - msg_loader_data->reset(msg_loader_data, 0); - status = msg_loader_data->save(msg_loader_data, model_input, model_input_size); - BREAK_ON_ERROR_LOG(status, "Data loader failed: %d", status); - - status = model_load_input(); + status = model_load_input(model_input, model_input_size); BREAK_ON_ERROR_LOG(status, "Model input load error 0x%x (%s)", status, get_status_str(status)); status = model_run(); diff --git a/include/kenning_inference_lib/core/model.h b/include/kenning_inference_lib/core/model.h index 402f34a..eba6e55 100644 --- a/include/kenning_inference_lib/core/model.h +++ b/include/kenning_inference_lib/core/model.h @@ -57,7 +57,7 @@ status_t model_init(); * * @returns status of the model */ -status_t model_load_struct(); +status_t model_load_struct(const uint8_t *model_struct_data, const size_t data_size); /** * Loads model weights from given buffer @@ -67,7 +67,7 @@ status_t model_load_struct(); * * @returns status of the model */ -status_t model_load_weights(); +status_t model_load_weights(const uint8_t *model_weights_data, const size_t data_size); /** * Calculates model input size based on data from model struct @@ -86,7 +86,13 @@ status_t model_get_input_size(size_t *model_input_size); * * @returns status of the model */ -status_t model_load_input(); +status_t model_load_input(const uint8_t *model_input, const size_t model_input_size); + +status_t model_load_struct_from_loader(); + +status_t model_load_weights_from_loader(); + +status_t model_load_input_from_loader(); /** * Runs model inference diff --git a/lib/kenning_inference_lib/core/callbacks.c b/lib/kenning_inference_lib/core/callbacks.c index ca1c396..c0f6e84 100644 --- a/lib/kenning_inference_lib/core/callbacks.c +++ b/lib/kenning_inference_lib/core/callbacks.c @@ -103,7 +103,7 @@ status_t data_callback(message_hdr_t *hdr, resp_message_t *resp) // TODO VALIDATE_HEADER(MESSAGE_TYPE_DATA, hdr); - status = model_load_input(); + status = model_load_input_from_loader(); CHECK_STATUS_LOG(status, resp, "model_load_input returned 0x%x (%s)", status, get_status_str(status)); @@ -126,7 +126,7 @@ status_t model_callback(message_hdr_t *hdr, resp_message_t *resp) VALIDATE_HEADER(MESSAGE_TYPE_MODEL, hdr); - status = model_load_weights(); + status = model_load_weights_from_loader(); CHECK_STATUS_LOG(status, resp, "model_load_weights returned 0x%x (%s)", status, get_status_str(status)); @@ -222,7 +222,7 @@ status_t iospec_callback(message_hdr_t *hdr, resp_message_t *resp) VALIDATE_HEADER(MESSAGE_TYPE_IOSPEC, hdr); - status = model_load_struct(); + status = model_load_struct_from_loader(); CHECK_STATUS_LOG(status, resp, "model_load_struct returned 0x%x (%s)", status, get_status_str(status)); diff --git a/lib/kenning_inference_lib/core/inference_server.c b/lib/kenning_inference_lib/core/inference_server.c index e6ab052..a5a7b26 100644 --- a/lib/kenning_inference_lib/core/inference_server.c +++ b/lib/kenning_inference_lib/core/inference_server.c @@ -59,11 +59,8 @@ int reset_runtime_alloc(struct msg_loader *ldr, size_t n) return 0; } -#endif // defined(CONFIG_LLEXT) - -status_t prepare_main_ldr_table() +status_t prepare_llext_loader() { -#if defined(CONFIG_LLEXT) static struct msg_loader msg_loader_llext = {.save = buf_save, .save_one = buf_save_one, .reset = reset_runtime_alloc, @@ -71,13 +68,10 @@ status_t prepare_main_ldr_table() .max_size = 0, .addr = NULL}; g_ldr_tables[0][LOADER_TYPE_RUNTIME] = &msg_loader_llext; -#endif - - static struct msg_loader msg_loader_iospec = MSG_LOADER_BUF((uint8_t *)(&g_model_struct), sizeof(MlModel)); - g_ldr_tables[0][LOADER_TYPE_IOSPEC] = &msg_loader_iospec; - return STATUS_OK; } +#endif // defined(CONFIG_LLEXT) + status_t init_server() { status_t status = STATUS_OK; @@ -90,10 +84,9 @@ status_t init_server() #if !defined(CONFIG_LLEXT) status = model_init(); CHECK_INIT_STATUS_RET(status, "model_init returned 0x%x (%s)", status, get_status_str(status)); +#else + prepare_llext_loader(); #endif - - prepare_main_ldr_table(); - LOG_INF("Inference server started"); return STATUS_OK; } diff --git a/lib/kenning_inference_lib/core/model.c b/lib/kenning_inference_lib/core/model.c index f12b3e1..88c10f0 100644 --- a/lib/kenning_inference_lib/core/model.c +++ b/lib/kenning_inference_lib/core/model.c @@ -26,21 +26,30 @@ MODEL_STATE model_get_state() { return g_model_state; } void model_reset_state() { g_model_state = MODEL_STATE_UNINITIALIZED; } +status_t prepare_iospec_loader() +{ + static struct msg_loader msg_loader_iospec = MSG_LOADER_BUF((uint8_t *)(&g_model_struct), sizeof(MlModel)); + g_ldr_tables[0][LOADER_TYPE_IOSPEC] = &msg_loader_iospec; + return STATUS_OK; +} + status_t model_init() { status_t status = STATUS_OK; status = runtime_init(); + RETURN_ON_ERROR(status, status); if (STATUS_OK == status) { g_model_state = MODEL_STATE_INITIALIZED; } + status = prepare_iospec_loader(); return status; } -status_t model_load_struct() +status_t model_load_struct_from_loader() { status_t status = STATUS_OK; @@ -122,7 +131,7 @@ status_t model_load_struct() return status; } -status_t model_load_weights() +status_t model_load_weights_from_loader() { status_t status = STATUS_OK; @@ -164,7 +173,7 @@ status_t model_get_input_size(size_t *model_input_size) return status; } -status_t model_load_input() +status_t model_load_input_from_loader() { status_t status = STATUS_OK; @@ -189,6 +198,42 @@ status_t model_load_input() return status; } +status_t model_load_weights(const uint8_t *model_weights_data, const size_t data_size) +{ + status_t status = STATUS_OK; + struct msg_loader *msg_loader_model = g_ldr_tables[1][LOADER_TYPE_MODEL]; + + msg_loader_model->reset(msg_loader_model, 0); + status = msg_loader_model->save(msg_loader_model, (uint8_t *)model_weights_data, data_size); + RETURN_ON_ERROR_LOG(status, status, "Model loader failed: %d", status); + + return model_load_weights_from_loader(); +} + +status_t model_load_struct(const uint8_t *model_struct_data, const size_t data_size) +{ + status_t status = STATUS_OK; + struct msg_loader *msg_loader_iospec = g_ldr_tables[0][LOADER_TYPE_IOSPEC]; + + msg_loader_iospec->reset(msg_loader_iospec, 0); + status = msg_loader_iospec->save(msg_loader_iospec, model_struct_data, data_size); + RETURN_ON_ERROR_LOG(status, status, "iospec loader failed: %d", status); + + return model_load_struct_from_loader(); +} + +status_t model_load_input(const uint8_t *model_input, const size_t model_input_size) +{ + status_t status = STATUS_OK; + struct msg_loader *msg_loader_data = g_ldr_tables[1][LOADER_TYPE_DATA]; + + msg_loader_data->reset(msg_loader_data, 0); + status = msg_loader_data->save(msg_loader_data, model_input, model_input_size); + RETURN_ON_ERROR_LOG(status, status, "Data loader failed: %d", status); + + return model_load_input_from_loader(); +} + status_t model_run() { status_t status = STATUS_OK;