Skip to content

Commit

Permalink
examples: demo creating tensor from cudarc slice
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Apr 13, 2024
1 parent 8383879 commit af95987
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 0 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ default-members = [
'examples/yolov8',
'examples/modnet'
]
exclude = [ 'examples/cudarc' ]

[package]
name = "ort"
Expand Down
14 changes: 14 additions & 0 deletions examples/cudarc/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[package]
name = "example-cudarc"
version = "0.1.0"
edition = "2021"

[dependencies]
ort = { path = "../../", features = [ "cuda", "fetch-models" ] }
cudarc = "0.10"
anyhow = "1.0"
ndarray = "0.15"
image = "0.24"
tracing = "0.1"
show-image = { version = "0.13", features = [ "image", "raqote" ] }
tracing-subscriber = { version = "0.3", features = [ "fmt" ] }
1 change: 1 addition & 0 deletions examples/cudarc/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Photo by <a href="https://unsplash.com/@nixcreative?utm_content=creditCopyText&utm_medium=referral&utm_source=unsplash">Tyler Nix</a> on <a href="https://unsplash.com/photos/woman-standing-in-front-of-multicolored-wall-sh3LSNbyj7k?utm_content=creditCopyText&utm_medium=referral&utm_source=unsplash">Unsplash</a>
Binary file added examples/cudarc/data/photo.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
91 changes: 91 additions & 0 deletions examples/cudarc/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
use std::{ops::Mul, path::Path};

use cudarc::driver::{sys::CUdeviceptr, CudaDevice, DevicePtr, DevicePtrMut};
use image::{imageops::FilterType, GenericImageView, ImageBuffer, Rgba};
use ndarray::Array;
use ort::{AllocationDevice, AllocatorType, CUDAExecutionProvider, ExecutionProvider, MemoryInfo, MemoryType, Session, TensorRefMut};
use show_image::{event, AsImageView, WindowOptions};

#[show_image::main]
fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt::init();

ort::init()
.with_execution_providers([CUDAExecutionProvider::default().build()])
.commit()?;

let model =
Session::builder()?.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/modnet_photographic_portrait_matting.onnx")?;

let original_img = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("photo.jpg")).unwrap();
let (img_width, img_height) = (original_img.width(), original_img.height());
let img = original_img.resize_exact(512, 512, FilterType::Triangle);
let mut input = Array::zeros((1, 3, 512, 512));
for pixel in img.pixels() {
let x = pixel.0 as _;
let y = pixel.1 as _;
let [r, g, b, _] = pixel.2.0;
input[[0, 0, y, x]] = (r as f32 - 127.5) / 127.5;
input[[0, 1, y, x]] = (g as f32 - 127.5) / 127.5;
input[[0, 2, y, x]] = (b as f32 - 127.5) / 127.5;
}

let device = CudaDevice::new(0)?;
let device_data = device.htod_sync_copy(&input.into_raw_vec())?;
let tensor: TensorRefMut<'_, f32> = unsafe {
TensorRefMut::from_raw(
MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)?,
(*device_data.device_ptr() as usize as *mut ()).cast(),
vec![1, 3, 512, 512]
)
.unwrap()
};
let outputs = model.run([tensor.into()])?;

let output = outputs["output"].try_extract_tensor::<f32>()?;

// convert to 8-bit
let output = output.mul(255.0).map(|x| *x as u8);
let output = output.into_raw_vec();

// change rgb to rgba
let output_img = ImageBuffer::from_fn(512, 512, |x, y| {
let i = (x + y * 512) as usize;
Rgba([output[i], output[i], output[i], 255])
});

let mut output = image::imageops::resize(&output_img, img_width, img_height, FilterType::Triangle);
output.enumerate_pixels_mut().for_each(|(x, y, pixel)| {
let origin = original_img.get_pixel(x, y);
pixel.0[3] = pixel.0[0];
pixel.0[0] = origin.0[0];
pixel.0[1] = origin.0[1];
pixel.0[2] = origin.0[2];
});

let window = show_image::context()
.run_function_wait(move |context| -> Result<_, String> {
let mut window = context
.create_window(
"ort + modnet",
WindowOptions {
size: Some([img_width, img_height]),
..WindowOptions::default()
}
)
.map_err(|e| e.to_string())?;
window.set_image("photo", &output.as_image_view().map_err(|e| e.to_string())?);
Ok(window.proxy())
})
.unwrap();

for event in window.event_channel().unwrap() {
if let event::WindowEvent::KeyboardInput(event) = event {
if event.input.key_code == Some(event::VirtualKeyCode::Escape) && event.input.state.is_pressed() {
break;
}
}
}

Ok(())
}

0 comments on commit af95987

Please sign in to comment.