forked from open-mmlab/mmrazor
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
mmrotate sdk module (open-mmlab#450)
* support mmrotate * fix name * windows default link to cudart_static.lib, which is not compatible with static build && python_api * python api * fix ci * fix type & remove unused meta info * fix doxygen, add [out] to @param * fix mmrotate-c-api * refactor naming * refactor naming * fix lint * fix lint * move replace_RResize -> get_preprocess * Update cuda.cmake On windows, make static lib and python api build success. * fix ptr * Use unique ptr to prevent memory leaks * move unique_ptr * remove deleter Co-authored-by: chenxin2 <[email protected]> Co-authored-by: cx <[email protected]>
- Loading branch information
1 parent
1a8d7ac
commit 0ce7c83
Showing
18 changed files
with
631 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
_base_ = ['./rotated-detection_static.py', '../_base_/backends/sdk.py'] | ||
|
||
codebase_config = dict(model_type='sdk') | ||
|
||
backend_config = dict(pipeline=[ | ||
dict(type='LoadImageFromFile'), | ||
dict(type='Collect', keys=['img'], meta_keys=['filename', 'ori_shape']) | ||
]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
// Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
#include "rotated_detector.h" | ||
|
||
#include <numeric> | ||
|
||
#include "codebase/mmrotate/mmrotate.h" | ||
#include "core/device.h" | ||
#include "core/graph.h" | ||
#include "core/mat.h" | ||
#include "core/utils/formatter.h" | ||
#include "handle.h" | ||
|
||
using namespace std; | ||
using namespace mmdeploy; | ||
|
||
namespace { | ||
|
||
Value& config_template() { | ||
// clang-format off | ||
static Value v{ | ||
{ | ||
"pipeline", { | ||
{"input", {"image"}}, | ||
{"output", {"det"}}, | ||
{ | ||
"tasks",{ | ||
{ | ||
{"name", "mmrotate"}, | ||
{"type", "Inference"}, | ||
{"params", {{"model", "TBD"}}}, | ||
{"input", {"image"}}, | ||
{"output", {"det"}} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
}; | ||
// clang-format on | ||
return v; | ||
} | ||
|
||
template <class ModelType> | ||
int mmdeploy_rotated_detector_create_impl(ModelType&& m, const char* device_name, int device_id, | ||
mm_handle_t* handle) { | ||
try { | ||
auto value = config_template(); | ||
value["pipeline"]["tasks"][0]["params"]["model"] = std::forward<ModelType>(m); | ||
|
||
auto pose_estimator = std::make_unique<Handle>(device_name, device_id, std::move(value)); | ||
|
||
*handle = pose_estimator.release(); | ||
return MM_SUCCESS; | ||
|
||
} catch (const std::exception& e) { | ||
MMDEPLOY_ERROR("exception caught: {}", e.what()); | ||
} catch (...) { | ||
MMDEPLOY_ERROR("unknown exception caught"); | ||
} | ||
return MM_E_FAIL; | ||
} | ||
|
||
} // namespace | ||
|
||
int mmdeploy_rotated_detector_create(mm_model_t model, const char* device_name, int device_id, | ||
mm_handle_t* handle) { | ||
return mmdeploy_rotated_detector_create_impl(*static_cast<Model*>(model), device_name, device_id, | ||
handle); | ||
} | ||
|
||
int mmdeploy_rotated_detector_create_by_path(const char* model_path, const char* device_name, | ||
int device_id, mm_handle_t* handle) { | ||
return mmdeploy_rotated_detector_create_impl(model_path, device_name, device_id, handle); | ||
} | ||
|
||
int mmdeploy_rotated_detector_apply(mm_handle_t handle, const mm_mat_t* mats, int mat_count, | ||
mm_rotated_detect_t** results, int** result_count) { | ||
if (handle == nullptr || mats == nullptr || mat_count == 0 || results == nullptr || | ||
result_count == nullptr) { | ||
return MM_E_INVALID_ARG; | ||
} | ||
|
||
try { | ||
auto detector = static_cast<Handle*>(handle); | ||
|
||
Value input{Value::kArray}; | ||
for (int i = 0; i < mat_count; ++i) { | ||
mmdeploy::Mat _mat{mats[i].height, mats[i].width, PixelFormat(mats[i].format), | ||
DataType(mats[i].type), mats[i].data, Device{"cpu"}}; | ||
input.front().push_back({{"ori_img", _mat}}); | ||
} | ||
|
||
auto output = detector->Run(std::move(input)).value().front(); | ||
auto detector_outputs = from_value<vector<mmrotate::RotatedDetectorOutput>>(output); | ||
|
||
vector<int> _result_count; | ||
_result_count.reserve(mat_count); | ||
for (const auto& det_output : detector_outputs) { | ||
_result_count.push_back((int)det_output.detections.size()); | ||
} | ||
|
||
auto total = std::accumulate(_result_count.begin(), _result_count.end(), 0); | ||
|
||
std::unique_ptr<int[]> result_count_data(new int[_result_count.size()]{}); | ||
std::copy(_result_count.begin(), _result_count.end(), result_count_data.get()); | ||
|
||
std::unique_ptr<mm_rotated_detect_t[]> result_data(new mm_rotated_detect_t[total]{}); | ||
auto result_ptr = result_data.get(); | ||
|
||
for (const auto& det_output : detector_outputs) { | ||
for (const auto& detection : det_output.detections) { | ||
result_ptr->label_id = detection.label_id; | ||
result_ptr->score = detection.score; | ||
const auto& rbbox = detection.rbbox; | ||
for (int i = 0; i < 5; i++) { | ||
result_ptr->rbbox[i] = rbbox[i]; | ||
} | ||
++result_ptr; | ||
} | ||
} | ||
|
||
*result_count = result_count_data.release(); | ||
*results = result_data.release(); | ||
|
||
return MM_SUCCESS; | ||
|
||
} catch (const std::exception& e) { | ||
MMDEPLOY_ERROR("exception caught: {}", e.what()); | ||
} catch (...) { | ||
MMDEPLOY_ERROR("unknown exception caught"); | ||
} | ||
return MM_E_FAIL; | ||
} | ||
|
||
void mmdeploy_rotated_detector_release_result(mm_rotated_detect_t* results, | ||
const int* result_count) { | ||
delete[] results; | ||
delete[] result_count; | ||
} | ||
|
||
void mmdeploy_rotated_detector_destroy(mm_handle_t handle) { delete static_cast<Handle*>(handle); } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
// Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
/** | ||
* @file rotated_detector.h | ||
* @brief Interface to MMRotate task | ||
*/ | ||
|
||
#ifndef MMDEPLOY_SRC_APIS_C_ROTATED_DETECTOR_H_ | ||
#define MMDEPLOY_SRC_APIS_C_ROTATED_DETECTOR_H_ | ||
|
||
#include "common.h" | ||
|
||
#ifdef __cplusplus | ||
extern "C" { | ||
#endif | ||
|
||
typedef struct mm_rotated_detect_t { | ||
int label_id; | ||
float score; | ||
float rbbox[5]; // cx, cy, w, h, angle | ||
} mm_rotated_detect_t; | ||
|
||
/** | ||
* @brief Create rotated detector's handle | ||
* @param[in] model an instance of mmrotate sdk model created by | ||
* \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h | ||
* @param[in] device_name name of device, such as "cpu", "cuda", etc. | ||
* @param[in] device_id id of device. | ||
* @param[out] handle instance of a rotated detector | ||
* @return status of creating rotated detector's handle | ||
*/ | ||
MMDEPLOY_API int mmdeploy_rotated_detector_create(mm_model_t model, const char* device_name, | ||
int device_id, mm_handle_t* handle); | ||
|
||
/** | ||
* @brief Create rotated detector's handle | ||
* @param[in] model_path path of mmrotate sdk model exported by mmdeploy model converter | ||
* @param[in] device_name name of device, such as "cpu", "cuda", etc. | ||
* @param[in] device_id id of device. | ||
* @param[out] handle instance of a rotated detector | ||
* @return status of creating rotated detector's handle | ||
*/ | ||
MMDEPLOY_API int mmdeploy_rotated_detector_create_by_path(const char* model_path, | ||
const char* device_name, int device_id, | ||
mm_handle_t* handle); | ||
|
||
/** | ||
* @brief Apply rotated detector to batch images and get their inference results | ||
* @param[in] handle rotated detector's handle created by \ref | ||
* mmdeploy_rotated_detector_create_by_path | ||
* @param[in] mats a batch of images | ||
* @param[in] mat_count number of images in the batch | ||
* @param[out] results a linear buffer to save detection results of each image. It must be released | ||
* by \ref mmdeploy_rotated_detector_release_result | ||
* @param[out] result_count a linear buffer with length being \p mat_count to save the number of | ||
* detection results of each image. And it must be released by \ref | ||
* mmdeploy_rotated_detector_release_result | ||
* @return status of inference | ||
*/ | ||
MMDEPLOY_API int mmdeploy_rotated_detector_apply(mm_handle_t handle, const mm_mat_t* mats, | ||
int mat_count, mm_rotated_detect_t** results, | ||
int** result_count); | ||
|
||
/** @brief Release the inference result buffer created by \ref mmdeploy_rotated_detector_apply | ||
* @param[in] results rotated detection results buffer | ||
* @param[in] result_count \p results size buffer | ||
*/ | ||
MMDEPLOY_API void mmdeploy_rotated_detector_release_result(mm_rotated_detect_t* results, | ||
const int* result_count); | ||
|
||
/** | ||
* @brief Destroy rotated detector's handle | ||
* @param[in] handle rotated detector's handle created by \ref | ||
* mmdeploy_rotated_detector_create_by_path or by \ref mmdeploy_rotated_detector_create | ||
*/ | ||
MMDEPLOY_API void mmdeploy_rotated_detector_destroy(mm_handle_t handle); | ||
|
||
#ifdef __cplusplus | ||
} | ||
#endif | ||
|
||
#endif // MMDEPLOY_SRC_APIS_C_ROTATED_DETECTOR_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
// Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
#include "rotated_detector.h" | ||
|
||
#include "common.h" | ||
#include "core/logger.h" | ||
|
||
namespace mmdeploy { | ||
|
||
class PyRotatedDetector { | ||
public: | ||
PyRotatedDetector(const char *model_path, const char *device_name, int device_id) { | ||
MMDEPLOY_INFO("{}, {}, {}", model_path, device_name, device_id); | ||
auto status = | ||
mmdeploy_rotated_detector_create_by_path(model_path, device_name, device_id, &handle_); | ||
if (status != MM_SUCCESS) { | ||
throw std::runtime_error("failed to create rotated detector"); | ||
} | ||
} | ||
py::list Apply(const std::vector<PyImage> &imgs) { | ||
std::vector<mm_mat_t> mats; | ||
mats.reserve(imgs.size()); | ||
for (const auto &img : imgs) { | ||
auto mat = GetMat(img); | ||
mats.push_back(mat); | ||
} | ||
|
||
mm_rotated_detect_t *rbboxes{}; | ||
int *res_count{}; | ||
auto status = mmdeploy_rotated_detector_apply(handle_, mats.data(), (int)mats.size(), &rbboxes, | ||
&res_count); | ||
if (status != MM_SUCCESS) { | ||
throw std::runtime_error("failed to apply rotated detector, code: " + std::to_string(status)); | ||
} | ||
auto output = py::list{}; | ||
auto result = rbboxes; | ||
auto counts = res_count; | ||
for (int i = 0; i < mats.size(); i++) { | ||
auto _dets = py::array_t<float>({*counts, 6}); | ||
auto _labels = py::array_t<int>({*counts}); | ||
auto dets = _dets.mutable_data(); | ||
auto labels = _labels.mutable_data(); | ||
for (int j = 0; j < *counts; j++) { | ||
for (int k = 0; k < 5; k++) { | ||
*dets++ = result->rbbox[k]; | ||
} | ||
*dets++ = result->score; | ||
*labels++ = result->label_id; | ||
result++; | ||
} | ||
counts++; | ||
output.append(py::make_tuple(std::move(_dets), std::move(_labels))); | ||
} | ||
mmdeploy_rotated_detector_release_result(rbboxes, res_count); | ||
return output; | ||
} | ||
~PyRotatedDetector() { | ||
mmdeploy_rotated_detector_destroy(handle_); | ||
handle_ = {}; | ||
} | ||
|
||
private: | ||
mm_handle_t handle_{}; | ||
}; | ||
|
||
static void register_python_rotated_detector(py::module &m) { | ||
py::class_<PyRotatedDetector>(m, "RotatedDetector") | ||
.def(py::init([](const char *model_path, const char *device_name, int device_id) { | ||
return std::make_unique<PyRotatedDetector>(model_path, device_name, device_id); | ||
})) | ||
.def("__call__", &PyRotatedDetector::Apply); | ||
} | ||
|
||
class PythonRotatedDetectorRegisterer { | ||
public: | ||
PythonRotatedDetectorRegisterer() { | ||
gPythonBindings().emplace("rotated_detector", register_python_rotated_detector); | ||
} | ||
}; | ||
|
||
static PythonRotatedDetectorRegisterer python_rotated_detector_registerer; | ||
|
||
} // namespace mmdeploy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
cmake_minimum_required(VERSION 3.14) | ||
project(mmdeploy_mmrotate) | ||
|
||
include(${CMAKE_SOURCE_DIR}/cmake/opencv.cmake) | ||
include(${CMAKE_SOURCE_DIR}/cmake/MMDeploy.cmake) | ||
|
||
file(GLOB_RECURSE SRCS ${CMAKE_CURRENT_SOURCE_DIR} "*.cpp") | ||
mmdeploy_add_module(${PROJECT_NAME} "${SRCS}") | ||
target_link_libraries(${PROJECT_NAME} PRIVATE mmdeploy_opencv_utils) | ||
add_library(mmdeploy::mmrotate ALIAS ${PROJECT_NAME}) |
Oops, something went wrong.