Commit a928a57

ENH: accelerate image preprocessing.

The implementation uses a naive downsampling.

oddkiva committed Dec 15, 2023
1 parent 2217fd0 commit a928a57
Showing 4 changed files with 109 additions and 60 deletions.
114 changes: 66 additions & 48 deletions cpp/examples/Shakti/TensorRT/tensorrt_yolov4_tiny_example.cu
@@ -20,23 +20,29 @@
 #include <DO/Sara/NeuralNetworks/Darknet/YoloUtilities.hpp>
 #include <DO/Sara/VideoIO.hpp>
 
+#include <algorithm>
 #include <filesystem>
 
 
 namespace sara = DO::Sara;
 namespace shakti = DO::Shakti;
 namespace fs = std::filesystem;
 namespace trt = DO::Shakti::TensorRT;
 namespace d = sara::Darknet;
 
+using CudaManagedTensor3ub =
+    trt::InferenceExecutor::ManagedTensor<std::uint8_t, 3>;
+using CudaManagedTensor3f = trt::InferenceExecutor::ManagedTensor<float, 3>;
 
-__global__ auto naive_downsample(float* out, const float* in, const int wout,
-                                 const int hout, const int win, const int hin)
+__global__ auto naive_downsample_and_transpose(float* out_chw,
+                                               const std::uint8_t* in_hwc,
+                                               const int wout, const int hout,
+                                               const int win, const int hin)
     -> void
 {
-  const int xout = blockIdx.x * blockDim.x + threadIdx.x;
+  const int c = blockIdx.x * blockDim.x + threadIdx.x;
   const int yout = blockIdx.y * blockDim.y + threadIdx.y;
-  const int c = blockIdx.z * blockDim.z + threadIdx.z;
+  const int xout = blockIdx.z * blockDim.z + threadIdx.z;
 
   if (xout >= wout || yout >= hout || c >= 3)
     return;
@@ -52,41 +58,25 @@ __global__ auto naive_downsample(float* out, const float* in, const int wout,
   if (yin >= hin)
     yin = hin - 1;
 
-  const int gi_out = c * wout * hout + yout * wout + xout;
-  const int gi_in = c * win * hin + yin * win + xin;
-  out[gi_out] = in[gi_in];
+  const int gi_out = c * hout * wout + yout * wout + xout;
+  const int gi_in = yin * win * 3 + xin * 3 + c;
+
+  static constexpr auto normalize_factor = 1 / 255.f;
+  out_chw[gi_out] = static_cast<float>(in_hwc[gi_in]) * normalize_factor;
 }
 
 
 // The API.
 auto detect_objects(
-    const sara::ImageView<sara::Rgb32f>& image,
     const trt::InferenceExecutor& inference_engine,
-    trt::InferenceExecutor::PinnedTensor<float, 3>& cuda_in_tensor,
+    const trt::InferenceExecutor::PinnedTensor<float, 3>& cuda_in_tensor,
     std::array<trt::InferenceExecutor::PinnedTensor<float, 3>, 2>&
        cuda_out_tensors,
     const float iou_thres,  //
    const std::array<std::vector<int>, 2>& anchor_masks,
-    const std::vector<int>& anchors) -> std::vector<d::YoloBox>
+    const std::vector<int>& anchors, const Eigen::Vector2i& image_sizes)
+    -> std::vector<d::YoloBox>
 {
-  // This is the bottleneck.
-  sara::tic();
-  const auto image_resized = sara::resize(image, {416, 416});
-  sara::toc("Image resize");
-
-  sara::tic();
-  const auto image_tensor =
-      sara::tensor_view(image_resized)
-          .reshape(Eigen::Vector4i{1, image_resized.height(),
-                                   image_resized.width(), 3})
-          .transpose({0, 3, 1, 2});
-  sara::toc("Tensor transpose");
-
-  // Copy to the CUDA tensor.
-  sara::tic();
-  std::copy(image_tensor.begin(), image_tensor.end(), cuda_in_tensor.begin());
-  sara::toc("Copy to CUDA tensor");
-
   // Feed the input and outputs to the YOLO v4 tiny network.
   sara::tic();
   inference_engine(cuda_in_tensor, cuda_out_tensors, true);
@@ -95,14 +85,15 @@ auto detect_objects(
   // Accumulate all the detections from each YOLO layer.
   sara::tic();
   auto detections = std::vector<d::YoloBox>{};
+  const auto wr = cuda_in_tensor.sizes()(2);
+  const auto hr = cuda_in_tensor.sizes()(1);
   for (auto i = 0; i < 2; ++i)
   {
     const auto& yolo_out = cuda_out_tensors[i];
     const auto& anchor_mask = anchor_masks[i];
-    const auto dets =
-        d::get_yolo_boxes(yolo_out,              //
-                          anchors, anchor_mask,  //
-                          image_resized.sizes(), image.sizes(), 0.25f);
+    const auto dets = d::get_yolo_boxes(yolo_out,              //
+                                        anchors, anchor_mask,  //
+                                        {wr, hr}, image_sizes, 0.25f);
     detections.insert(detections.end(), dets.begin(), dets.end());
   }
   sara::toc("Postprocess boxes");
@@ -148,15 +139,11 @@ auto test_on_video(int argc, char** argv) -> void
   // Load the network and get the CUDA inference engine ready.
   auto inference_executor = trt::InferenceExecutor{serialized_net};
 
-  using shakti::ManagedMemoryAllocator;
-  using CudaManagedTensor_ = sara::Tensor_<float, 3, ManagedMemoryAllocator>;
-  auto tensor_rgb32f = CudaManagedTensor_{frame.height(), frame.width(), 3};
-  auto frame32f = sara::ImageView<sara::Rgb32f>{
-      reinterpret_cast<sara::Rgb32f*>(tensor_rgb32f.data()), frame.sizes()};
+  auto tensor_hwc_8u = CudaManagedTensor3ub{frame.height(), frame.width(), 3};
+  auto tensor_hwc_32f = CudaManagedTensor3f{frame.height(), frame.width(), 3};
+  auto tensor_chw_resized_32f = CudaManagedTensor3f{{3, 416, 416}};
 
-  // The CUDA tensors.
-  auto cuda_in_tensor =
-      trt::InferenceExecutor::PinnedTensor<float, 3>{3, 416, 416};
+  auto& cuda_in_tensor = tensor_chw_resized_32f;
   auto cuda_out_tensors = std::array{
       trt::InferenceExecutor::PinnedTensor<float, 3>{255, 13, 13},
       trt::InferenceExecutor::PinnedTensor<float, 3>{255, 26, 26}  //
@@ -192,15 +179,46 @@ auto test_on_video(int argc, char** argv) -> void
       continue;
 
     sara::tic();
-    sara::convert(frame, frame32f);
-    sara::toc("Color conversion");
+    std::copy_n(reinterpret_cast<const std::uint8_t*>(frame.data()),
+                sizeof(sara::Rgb8) * frame.size(),  //
+                tensor_hwc_8u.begin());
+    sara::toc("Copy frame data from host to CUDA");
 
+    sara::tic();
+    {
+      // Data order: H W C
+      //             0 1 2
+      const auto in_hwc = tensor_hwc_8u.data();
+      const auto win = tensor_hwc_8u.sizes()(1);
+      const auto hin = tensor_hwc_8u.sizes()(0);
+
+      // Data order: C H W
+      //             0 1 2
+      auto out_chw = tensor_chw_resized_32f.data();
+      const auto hout = tensor_chw_resized_32f.sizes()(1);
+      const auto wout = tensor_chw_resized_32f.sizes()(2);
+
+      const auto threads_per_block = dim3(4, 16, 16);
+      const auto num_blocks = dim3(  //
+          1,                         //
+          (hout + threads_per_block.y - 1) / threads_per_block.y,
+          (wout + threads_per_block.z - 1) / threads_per_block.z  //
+      );
+
+      naive_downsample_and_transpose<<<num_blocks, threads_per_block>>>(
+          out_chw, in_hwc,  //
+          wout, hout,       //
+          win, hin          //
+      );
+    }
+    sara::toc("CUDA downsample+transpose");
 
     sara::tic();
-    auto dets = detect_objects(            //
-        frame32f,                          //
-        inference_executor,                //
-        cuda_in_tensor, cuda_out_tensors,  //
-        iou_thres, yolo_masks, yolo_anchors);
+    auto dets = detect_objects(               //
+        inference_executor,                   //
+        cuda_in_tensor, cuda_out_tensors,     //
+        iou_thres, yolo_masks, yolo_anchors,  //
+        frame.sizes());
     sara::toc("Object detection");
 
     sara::tic();
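For clarity, here is a host-side reference of what `naive_downsample_and_transpose` computes, as a hedged sketch: the scaling from `xout`/`yout` to `xin`/`yin` falls inside the collapsed lines of the diff, so the nearest-neighbor mapping below is an assumption consistent with the visible clamping; the index arithmetic and the 1/255 normalization are taken verbatim from the kernel.

```cpp
#include <algorithm>
#include <cstdint>

// CPU reference for the CUDA kernel: nearest-neighbor downsample of an
// HWC uint8 image into a CHW float tensor normalized to [0, 1].
auto downsample_and_transpose_cpu(const std::uint8_t* in_hwc,  //
                                  const int win, const int hin,
                                  float* out_chw,  //
                                  const int wout, const int hout) -> void
{
  const auto sx = static_cast<float>(win) / static_cast<float>(wout);
  const auto sy = static_cast<float>(hin) / static_cast<float>(hout);

  for (auto c = 0; c < 3; ++c)
    for (auto yout = 0; yout < hout; ++yout)
      for (auto xout = 0; xout < wout; ++xout)
      {
        // Nearest input pixel, clamped to the image bounds (the kernel
        // clamps the same way).
        const auto xin = std::min(static_cast<int>(xout * sx), win - 1);
        const auto yin = std::min(static_cast<int>(yout * sy), hin - 1);

        // Same gi_out / gi_in arithmetic as in the kernel.
        out_chw[c * hout * wout + yout * wout + xout] =
            static_cast<float>(in_hwc[yin * win * 3 + xin * 3 + c]) / 255.f;
      }
}
```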
@@ -300,18 +300,18 @@ namespace DO::Shakti::Cuda::Gaussian {
     timer.restart();
 #endif
     {
-      const auto threadsperBlock = dim3(kernel_max_radius, tile_size);
-      const auto numBlocks = dim3(
-          (d_in.padded_width() + threadsperBlock.x - 1) / threadsperBlock.x,
-          (d_in.height() + threadsperBlock.y - 1) / threadsperBlock.y);
+      const auto threads_per_block = dim3(kernel_max_radius, tile_size);
+      const auto num_blocks = dim3(
+          (d_in.padded_width() + threads_per_block.x - 1) / threads_per_block.x,
+          (d_in.height() + threads_per_block.y - 1) / threads_per_block.y);
 
       // x-convolution.
-      convx<<<numBlocks, threadsperBlock>>>(d_in.data(),          //
-                                            d_convx.data(),       //
-                                            d_in.width(),         //
-                                            d_in.height(),        //
-                                            d_in.padded_width(),  //
-                                            kernel_index);
+      convx<<<num_blocks, threads_per_block>>>(d_in.data(),          //
+                                               d_convx.data(),       //
+                                               d_in.width(),         //
+                                               d_in.height(),        //
+                                               d_in.padded_width(),  //
+                                               kernel_index);
     }
 #ifdef PROFILE_GAUSSIAN_CONVOLUTION
     elapsed = timer.elapsed_ms();
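Both this launch site and the one in the YOLO example size their grids with the same ceil-division idiom. A small sketch of the arithmetic, with the 416×416 case worked out (the helper name `ceil_div` is mine, not the repository's):

```cpp
// Round-up division: the smallest block count b such that
// b * block_size >= n. Threads that land past n are discarded by the
// early-return guards inside the kernels.
constexpr auto ceil_div(const int n, const int block_size) -> int
{
  return (n + block_size - 1) / block_size;
}

// For the 416x416 YOLO input with threads_per_block = dim3(4, 16, 16):
// x covers the 3 color channels within a single block of 4 threads,
// while y and z tile the rows and columns:
//   num_blocks = dim3(1, ceil_div(416, 16), ceil_div(416, 16));  // (1, 26, 26)
```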
23 changes: 23 additions & 0 deletions cpp/src/DO/Shakti/Cuda/TensorRT/InferenceExecutor.cpp
@@ -74,3 +74,26 @@ auto InferenceExecutor::operator()(  //
   if (synchronize)
     cudaStreamSynchronize(*_cuda_stream);
 }
+
+auto InferenceExecutor::operator()(
+    const ManagedTensor<float, 3>& in,
+    std::array<PinnedTensor<float, 3>, 2>& out,  //
+    const bool synchronize) const -> void
+{
+  const auto device_tensors = std::array{
+      const_cast<void*>(reinterpret_cast<const void*>(in.data())),  //
+      reinterpret_cast<void*>(out[0].data()),                       //
+      reinterpret_cast<void*>(out[1].data())                        //
+  };
+
+  // Enqueue the CPU pinned <-> GPU transfers and the convolution task.
+  if (!_context->enqueueV2(device_tensors.data(), *_cuda_stream, nullptr))
+  {
+    SARA_DEBUG << termcolor::red << "Execution failed!" << termcolor::reset
+               << std::endl;
+  }
+
+  // Wait for the completion of GPU operations.
+  if (synchronize)
+    cudaStreamSynchronize(*_cuda_stream);
+}
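A minimal usage sketch for the new overload, assuming the tensor shapes of the YOLO v4 tiny example above; obtaining `serialized_net` (the serialized TensorRT engine) is elided, exactly as in the example program.

```cpp
#include <DO/Shakti/Cuda/TensorRT/InferenceExecutor.hpp>

#include <array>

namespace trt = DO::Shakti::TensorRT;

// `serialized_net` is assumed to hold a serialized engine, obtained as in
// the example program.
auto inference_executor = trt::InferenceExecutor{serialized_net};

// Managed (unified) memory input: the host fills it, the GPU reads it,
// with no explicit host-to-device copy in between.
auto cuda_in_tensor =
    trt::InferenceExecutor::ManagedTensor<float, 3>{{3, 416, 416}};

// Page-locked host tensors for the two YOLO output layers.
auto cuda_out_tensors = std::array{
    trt::InferenceExecutor::PinnedTensor<float, 3>{255, 13, 13},
    trt::InferenceExecutor::PinnedTensor<float, 3>{255, 26, 26}  //
};

// ... fill cuda_in_tensor, e.g. with naive_downsample_and_transpose ...

// Run the network and block until the CUDA stream has drained.
inference_executor(cuda_in_tensor, cuda_out_tensors, /* synchronize = */ true);
```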
12 changes: 10 additions & 2 deletions cpp/src/DO/Shakti/Cuda/TensorRT/InferenceExecutor.hpp
@@ -19,6 +19,7 @@
 
 #include <DO/Sara/Core/Tensor.hpp>
 
+#include <DO/Shakti/Cuda/MultiArray/ManagedMemoryAllocator.hpp>
 #include <DO/Shakti/Cuda/MultiArray/PinnedMemoryAllocator.hpp>
 #include <DO/Shakti/Cuda/TensorRT/Helpers.hpp>
 
@@ -29,7 +30,10 @@ namespace DO::Shakti::TensorRT {
   {
   public:
     template <typename T, int N>
-    using PinnedTensor = Sara::Tensor_<T, N, Shakti::PinnedMemoryAllocator>;
+    using PinnedTensor = Sara::Tensor_<T, N, PinnedMemoryAllocator>;
+
+    template <typename T, int N>
+    using ManagedTensor = Sara::Tensor_<T, N, ManagedMemoryAllocator>;
 
     InferenceExecutor() = default;
 
@@ -43,11 +47,15 @@ namespace DO::Shakti::TensorRT {
                     std::array<PinnedTensor<float, 3>, 2>& out,  //
                     const bool synchronize = true) const -> void;
 
+    auto operator()(const ManagedTensor<float, 3>& in,
+                    std::array<PinnedTensor<float, 3>, 2>& out,  //
+                    const bool synchronize = true) const -> void;
+
     // private:
     CudaStreamUniquePtr _cuda_stream = make_cuda_stream();
     RuntimeUniquePtr _runtime = {nullptr, &runtime_deleter};
     CudaEngineUniquePtr _engine = {nullptr, &engine_deleter};
     ContextUniquePtr _context = {nullptr, &context_deleter};
   };
 
-}  // namespace DO::Sara::TensorRT
+}  // namespace DO::Shakti::TensorRT
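On the design choice: the two alias templates differ only in where their memory lives. A hedged sketch of the distinction, assuming the Shakti allocators wrap the standard CUDA runtime calls:

```cpp
#include <cuda_runtime.h>

auto allocator_sketch() -> void
{
  constexpr auto n = 3 * 416 * 416;

  // Page-locked host memory: enables truly asynchronous cudaMemcpyAsync
  // transfers over a CUDA stream.
  float* pinned = nullptr;
  cudaMallocHost(&pinned, n * sizeof(float));

  // Unified memory: dereferenceable from both host and device, which is
  // what lets the example above skip the explicit input copies.
  float* managed = nullptr;
  cudaMallocManaged(&managed, n * sizeof(float));

  cudaFreeHost(pinned);
  cudaFree(managed);
}
```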
