Merge branch 'enh-conversion-to-coreml' into enh-yolo-tensorrt-v7
oddkiva committed Dec 15, 2023
2 parents e1f1be2 + 48e9656 commit a42ba14
Showing 97 changed files with 1,182 additions and 294 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -15,6 +15,7 @@ latex/
CMakeLists.txt.user

# Python
**/__pycache__
*.pyc
*.pyo
*.coverage
2 changes: 1 addition & 1 deletion cpp/examples/Shakti/TensorRT/CMakeLists.txt
@@ -2,7 +2,7 @@ if(NOT CMAKE_CUDA_COMPILER OR NOT TensorRT_FOUND)
return()
endif()

file(GLOB TRT_SOURCE_FILES FILES *.cpp)
file(GLOB TRT_SOURCE_FILES FILES *.cu)

foreach(file ${TRT_SOURCE_FILES})
get_filename_component(filename ${file} NAME_WE)
@@ -9,6 +9,7 @@
// you can obtain one at http://mozilla.org/MPL/2.0/.
// ========================================================================== //

#include <DO/Shakti/Cuda/MultiArray/ManagedMemoryAllocator.hpp>
#include <DO/Shakti/Cuda/TensorRT/DarknetParser.hpp>
#include <DO/Shakti/Cuda/TensorRT/IO.hpp>
#include <DO/Shakti/Cuda/TensorRT/InferenceEngine.hpp>
@@ -20,6 +21,7 @@
#include <DO/Sara/NeuralNetworks/Darknet/YoloUtilities.hpp>
#include <DO/Sara/VideoIO.hpp>

#include <algorithm>
#include <filesystem>

#ifdef _OPENMP
@@ -33,51 +35,81 @@ namespace fs = std::filesystem;
namespace trt = DO::Shakti::TensorRT;
namespace d = sara::Darknet;

using CudaManagedTensor3ub =
trt::InferenceEngine::ManagedTensor<std::uint8_t, 3>;
using CudaManagedTensor3f = trt::InferenceEngine::ManagedTensor<float, 3>;


__global__ auto naive_downsample_and_transpose(float* out_chw,
const std::uint8_t* in_hwc,
const int wout, const int hout,
const int win, const int hin)
-> void
{
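// One thread per output element: the x index selects the colour channel,
// y the output row and z the output column.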
const int c = blockIdx.x * blockDim.x + threadIdx.x;
const int yout = blockIdx.y * blockDim.y + threadIdx.y;
const int xout = blockIdx.z * blockDim.z + threadIdx.z;

if (xout >= wout || yout >= hout || c >= 3)
return;

const float sx = float(win) / float(wout);
const float sy = float(hin) / float(hout);

int xin = int(xout * sx + 0.5f);
int yin = int(yout * sy + 0.5f);

if (xin >= win)
xin = win - 1;
if (yin >= hin)
yin = hin - 1;

const int gi_out = c * hout * wout + yout * wout + xout;
const int gi_in = yin * win * 3 + xin * 3 + c;

static constexpr auto normalize_factor = 1 / 255.f;
out_chw[gi_out] = static_cast<float>(in_hwc[gi_in]) * normalize_factor;
}

auto naive_downsample_and_transpose(CudaManagedTensor3f& tensor_chw_resized_32f,
CudaManagedTensor3ub& tensor_hwc_8u) -> void
{
// Data order: H W C
// 0 1 2
const auto in_hwc = tensor_hwc_8u.data();
const auto win = tensor_hwc_8u.sizes()(1);
const auto hin = tensor_hwc_8u.sizes()(0);

// Data order: C H W
// 0 1 2
auto out_chw = tensor_chw_resized_32f.data();
const auto hout = tensor_chw_resized_32f.sizes()(1);
const auto wout = tensor_chw_resized_32f.sizes()(2);

const auto threads_per_block = dim3(4, 16, 16);
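// A single block along x suffices: its 4 threads already cover the 3 colour
// channels; the y and z grid dimensions below tile the output rows and columns.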
const auto num_blocks = dim3( //
1, //
(hout + threads_per_block.y - 1) / threads_per_block.y,
(wout + threads_per_block.z - 1) / threads_per_block.z //
);

naive_downsample_and_transpose<<<num_blocks, threads_per_block>>>(
out_chw, in_hwc, //
wout, hout, //
win, hin //
);
}
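For reference, a minimal CPU sketch of what the kernel above computes: a nearest-neighbour downsample of an interleaved HWC uint8 image into a planar CHW float tensor, rescaled by 1/255. The function name downsample_and_transpose_cpu is illustrative only and not part of the commit.

#include <algorithm>
#include <cstdint>

inline auto downsample_and_transpose_cpu(float* out_chw,
                                         const std::uint8_t* in_hwc,
                                         const int wout, const int hout,
                                         const int win, const int hin) -> void
{
  const auto sx = float(win) / float(wout);
  const auto sy = float(hin) / float(hout);
  for (auto c = 0; c < 3; ++c)
    for (auto yout = 0; yout < hout; ++yout)
      for (auto xout = 0; xout < wout; ++xout)
      {
        // Nearest input pixel, clamped to the image borders.
        const auto xin = std::min(int(xout * sx + 0.5f), win - 1);
        const auto yin = std::min(int(yout * sy + 0.5f), hin - 1);
        // Interleaved HWC -> planar CHW, normalised to [0, 1].
        out_chw[c * hout * wout + yout * wout + xout] =
            in_hwc[yin * win * 3 + xin * 3 + c] / 255.f;
      }
}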

// The API.
auto detect_objects(
const sara::ImageView<sara::Rgb8>& image,
const trt::InferenceEngine& inference_engine,
trt::InferenceEngine::PinnedTensor<float, 3>& cuda_in_tensor,
const CudaManagedTensor3f& cuda_in_tensor,
std::vector<trt::InferenceEngine::PinnedTensor<float, 3>>& cuda_out_tensors,
const float iou_thres, //
const std::vector<std::vector<int>>& anchor_masks,
const std::vector<int>& anchors) -> std::vector<d::YoloBox>
const std::vector<int>& anchors, //
const Eigen::Vector2i& image_sizes) -> std::vector<d::YoloBox>
{
// N.B.: this would still be unacceptably slow and a GPU implementation is
// still preferable.
// The CPU implementation takes between 5 and 7 ms on a powerful CPU...
sara::tic();
auto rgb_tensor = sara::Tensor_<float, 3>{3, image.height(), image.width()};
const auto rgb = image.data();
auto r = rgb_tensor[0].data();
auto g = rgb_tensor[1].data();
auto b = rgb_tensor[2].data();
const auto size = static_cast<int>(image.size());
if (image.size() != rgb_tensor[0].size())
throw 0;
#ifdef _OPENMP
# pragma omp parallel for
#endif
for (auto i = 0; i < size; ++i)
{
r[i] = rgb[i].channel<s::R>() / 255.f;
g[i] = rgb[i].channel<s::G>() / 255.f;
b[i] = rgb[i].channel<s::B>() / 255.f;
}
sara::toc("Uint8 Interleaved to Float Planar");

sara::tic();
auto rgb_tensor_resized = sara::TensorView_<float, 3>{cuda_in_tensor.data(),
cuda_in_tensor.sizes()};
for (auto channel = 0; channel < 3; ++channel)
{
const auto src = sara::image_view(rgb_tensor[channel]);
auto dst = sara::image_view(rgb_tensor_resized[channel]);
sara::resize_v2(src, dst);
}
sara::toc("Image resize");

// Feed the input and outputs to the YOLO v4 tiny network.
sara::tic();
inference_engine(cuda_in_tensor, cuda_out_tensors, true);
@@ -86,16 +118,15 @@ auto detect_objects(
// Accumulate all the detections from each YOLO layer.
sara::tic();
auto detections = std::vector<d::YoloBox>{};
for (auto i = 0u; i < anchor_masks.size(); ++i)
const auto wr = cuda_in_tensor.sizes()(2);
const auto hr = cuda_in_tensor.sizes()(1);
for (auto i = 0; i < 2; ++i)
{
const auto& yolo_out = cuda_out_tensors[i];
const auto& anchor_mask = anchor_masks[i];
const auto dets = d::get_yolo_boxes( //
yolo_out, //
anchors, anchor_mask, //
{cuda_in_tensor.size(2), cuda_in_tensor.size(1)}, //
image.sizes(), //
0.25f);
const auto dets = d::get_yolo_boxes(yolo_out, //
anchors, anchor_mask, //
{wr, hr}, image_sizes, 0.25f);
detections.insert(detections.end(), dets.begin(), dets.end());
}
sara::toc("Postprocess boxes");
@@ -138,7 +169,7 @@ auto test_on_video(int argc, char** argv) -> void

const auto data_dir_path = fs::canonical(fs::path{src_path("data")});
static constexpr auto yolo_version = 4;
static constexpr auto is_tiny = false;
static constexpr auto is_tiny = true;
auto yolo_model = "yolov" + std::to_string(yolo_version);
if (is_tiny)
yolo_model += "-tiny";
@@ -158,16 +189,21 @@ auto test_on_video(int argc, char** argv) -> void
trt::write_plan(serialized_net, yolo_plan_filepath.string());
}

auto cuda_in_tensor = trt::InferenceEngine::PinnedTensor<float, 3>{};
auto tensor_hwc_8u = CudaManagedTensor3ub{frame.height(), frame.width(), 3};
auto tensor_hwc_32f = CudaManagedTensor3f{frame.height(), frame.width(), 3};
auto tensor_chw_resized_32f = CudaManagedTensor3f{};

auto& cuda_in_tensor = tensor_chw_resized_32f;
auto cuda_out_tensors =
std::vector<trt::InferenceEngine::PinnedTensor<float, 3>>{};

auto yolo_masks = std::vector<std::vector<int>>{};
auto yolo_anchors = std::vector<int>{};
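// The anchors and the masks, which assign a subset of the anchors to each
// YOLO output scale, are filled below depending on the model variant.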

if constexpr (is_tiny)
{
// The CUDA tensors.
cuda_in_tensor = trt::InferenceEngine::PinnedTensor<float, 3>{3, 416, 416};
tensor_chw_resized_32f = CudaManagedTensor3f{{3, 416, 416}};
cuda_out_tensors = std::vector{
trt::InferenceEngine::PinnedTensor<float, 3>{255, 13, 13},
trt::InferenceEngine::PinnedTensor<float, 3>{255, 26, 26} //
@@ -189,19 +225,19 @@ auto test_on_video(int argc, char** argv) -> void
else
{
// The CUDA tensors.
cuda_in_tensor = trt::InferenceEngine::PinnedTensor<float, 3>{3, 608, 608};
tensor_chw_resized_32f = CudaManagedTensor3f{{3, 608, 608}};
cuda_out_tensors = std::vector{
trt::InferenceEngine::PinnedTensor<float, 3>{255, 76, 76},
trt::InferenceEngine::PinnedTensor<float, 3>{255, 38, 38}, //
trt::InferenceEngine::PinnedTensor<float, 3>{255, 19, 19}, //
};

const auto yolo_masks = std::vector{
yolo_masks = std::vector{
std::vector{0, 1, 2}, //
std::vector{3, 4, 5}, //
std::vector{6, 7, 8}, //
};
const auto yolo_anchors = std::vector{
yolo_anchors = std::vector{
12, 16, //
19, 36, //
40, 28, //
@@ -231,11 +267,22 @@ auto test_on_video(int argc, char** argv) -> void
continue;

sara::tic();
auto dets = detect_objects( //
video_stream.frame(), //
std::copy_n(reinterpret_cast<const std::uint8_t*>(frame.data()),
sizeof(sara::Rgb8) * frame.size(), //
tensor_hwc_8u.begin());
sara::toc("Copy frame data from host to CUDA");

sara::tic();
naive_downsample_and_transpose(tensor_chw_resized_32f, tensor_hwc_8u);
sara::toc("CUDA downsample+transpose");

sara::tic();
const auto dets = detect_objects( //
inference_engine, //
cuda_in_tensor, cuda_out_tensors, //
iou_thres, yolo_masks, yolo_anchors);
iou_thres, //
yolo_masks, yolo_anchors, //
frame.sizes());
sara::toc("Object detection");

sara::tic();
@@ -256,14 +303,14 @@ auto test_on_video(int argc, char** argv) -> void
}


int graphics_main(int argc, char** argv)
auto graphics_main(int argc, char** argv) -> int
{
test_on_video(argc, argv);
return 0;
}


int main(int argc, char** argv)
auto main(int argc, char** argv) -> int
{
DO::Sara::GraphicsApplication app(argc, argv);
app.register_user_main(graphics_main);
@@ -23,7 +23,6 @@ add_custom_command(
${GLSLC_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/shader.frag -o
$<TARGET_FILE_DIR:hello_vulkan_image>/hello_vulkan_image_shaders/frag.spv)

# file(GLOB SHADER_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} *.vert *.frag)
add_custom_command(
TARGET hello_vulkan_image
PRE_BUILD
38 changes: 16 additions & 22 deletions cpp/examples/Shakti/Vulkan/hello_vulkan_image/main.cpp
@@ -262,6 +262,11 @@ class VulkanImageRenderer : public kvk::GraphicsBackend
auto h = int{};
glfwGetWindowSize(window, &w, &h);

const auto dynamic_viewport_states = std::vector<VkDynamicState>{
VK_DYNAMIC_STATE_VIEWPORT, //
VK_DYNAMIC_STATE_SCISSOR //
};
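// With these dynamic states, the viewport and scissor are supplied at
// command-recording time (vkCmdSetViewport / vkCmdSetScissor) instead of being
// baked into the pipeline, so a window resize does not force a pipeline rebuild.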

_graphics_pipeline =
VulkanImagePipelineBuilder{_device, _render_pass}
.vertex_shader_path(vertex_shader_path)
@@ -270,6 +275,7 @@
.input_assembly_topology(VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST)
.viewport_sizes(static_cast<float>(w), static_cast<float>(h))
.scissor_sizes(w, h)
.dynamic_states(dynamic_viewport_states)
.create();
}

@@ -631,8 +637,8 @@ class VulkanImageRenderer : public kvk::GraphicsBackend
vkCmdBindPipeline(command_buffer, VK_PIPELINE_BIND_POINT_GRAPHICS,
_graphics_pipeline);

#ifdef ALLOW_DYNAMIC_VIEWPORT_AND_SCISSOR_STATE
VkViewport viewport{};
// Important: reset the viewport.
auto viewport = VkViewport{};
viewport.x = 0.0f;
viewport.y = 0.0f;
viewport.width = static_cast<float>(_swapchain.extent.width);
Expand All @@ -641,11 +647,11 @@ class VulkanImageRenderer : public kvk::GraphicsBackend
viewport.maxDepth = 1.0f;
vkCmdSetViewport(command_buffer, 0, 1, &viewport);

VkRect2D scissor{};
// Important: reset the scissor.
auto scissor = VkRect2D{};
scissor.offset = {0, 0};
scissor.extent = _swapchain.extent;
vkCmdSetScissor(command_buffer, 0, 1, &scissor);
#endif

// Pass the VBO to the graphics pipeline.
static const auto vbos = std::array<VkBuffer, 1>{_vbo};
@@ -837,8 +843,8 @@ class VulkanImageRenderer : public kvk::GraphicsBackend
if (result == VK_ERROR_OUT_OF_DATE_KHR || result == VK_SUBOPTIMAL_KHR ||
_framebuffer_resized)
{
_framebuffer_resized = false;
recreate_swapchain();
_framebuffer_resized = false;
}
else if (result != VK_SUCCESS)
{
@@ -888,24 +894,12 @@ class VulkanImageRenderer : public kvk::GraphicsBackend
init_swapchain(_window);
init_swapchain_fbos();

// // This time only modify the view matrix.
// {
// _mvp.view.setIdentity();
// _mvp.view.scale(static_cast<float>(w) / _vstream.width());
// }

// Recalculate the projection matrix.
{
const auto fb_aspect_ratio = static_cast<float>(w) / h;
_mvp.projection = k::orthographic( //
-fb_aspect_ratio, fb_aspect_ratio, //
-1.f, 1.f, //
-1.f, 1.f);
}

SARA_CHECK(_mvp.model.matrix());
SARA_CHECK(_mvp.view.matrix());
SARA_CHECK(_mvp.projection);
const auto fb_aspect_ratio = static_cast<float>(w) / h;
_mvp.projection = k::orthographic( //
-fb_aspect_ratio, fb_aspect_ratio, //
-1.f, 1.f, //
-1.f, 1.f);
}

private: