Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Alternative backends #333

Merged
merged 13 commits into from
Jan 13, 2025
70 changes: 70 additions & 0 deletions .github/workflows/backends.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# CI for the alternative (non-ONNX-Runtime) `ort` backends.
name: 🧩 Backends
on:
  workflow_dispatch:
  push:
    branches:
      - 'main'
    paths:
      - '.github/workflows/backends.yml'
      - 'src/**/*.rs'
      - 'backends/**/*.rs'
      - 'ort-sys/src/lib.rs'
      - 'Cargo.toml'
  pull_request:
    paths:
      - '.github/workflows/backends.yml'
      - 'src/**/*.rs'
      - 'backends/**/*.rs'
      - 'ort-sys/src/lib.rs'
      - 'Cargo.toml'
env:
  RUST_BACKTRACE: 1
  CARGO_INCREMENTAL: 0
  CARGO_PROFILE_DEV_DEBUG: 0
jobs:
  candle:
    name: Candle
    runs-on: ${{ matrix.platform.os }}
    strategy:
      fail-fast: false
      matrix:
        platform:
          - os: ubuntu-latest
          - os: macos-15
    steps:
      - uses: actions/checkout@v4
      # candle-onnx compiles the ONNX protobuf schema at build time and needs protoc.
      - name: Install protoc
        run: |
          if [ "$RUNNER_OS" == "Linux" ]; then
            sudo apt-get update && sudo apt-get install protobuf-compiler -y
          elif [ "$RUNNER_OS" == "macOS" ]; then
            brew install protobuf
          fi
      - name: Install stable Rust toolchain
        uses: dtolnay/rust-toolchain@stable
        with:
          toolchain: stable
      # rust-cache v1 is deprecated/unmaintained; v2 is the supported release.
      - uses: Swatinem/rust-cache@v2
      - name: Run tests
        run: |
          cargo test --manifest-path backends/candle/Cargo.toml --verbose -- --test-threads 1
  tract:
    name: Tract
    runs-on: ${{ matrix.platform.os }}
    strategy:
      fail-fast: false
      matrix:
        platform:
          - os: ubuntu-latest
          - os: windows-latest
          - os: macos-15
    steps:
      - uses: actions/checkout@v4
      - name: Install stable Rust toolchain
        uses: dtolnay/rust-toolchain@stable
        with:
          toolchain: stable
      # rust-cache v1 is deprecated/unmaintained; v2 is the supported release.
      - uses: Swatinem/rust-cache@v2
      - name: Run tests
        run: |
          cargo test --manifest-path backends/tract/Cargo.toml --verbose -- --test-threads 1
2 changes: 1 addition & 1 deletion .github/workflows/code-quality.yml
Original file line number Diff line number Diff line change
@@ -48,7 +48,7 @@ jobs:
- name: Check fmt
run: cargo fmt --all -- --check
- name: Run clippy
run: cargo clippy -p ort --all-targets --workspace --features fetch-models
run: cargo clippy -p ort --all-targets --features fetch-models
coverage:
name: Code coverage
runs-on: ubuntu-24.04
6 changes: 5 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -21,7 +21,11 @@ default-members = [
'examples/modnet',
'examples/sentence-transformers'
]
exclude = [ 'examples/cudarc' ]
exclude = [
'backends/candle',
'backends/tract',
'examples/cudarc'
]

[package]
name = "ort"
36 changes: 36 additions & 0 deletions backends/candle/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
[package]
name = "ort-candle"
description = "ort + candle = 🦀 - An alternative backend for ort, powered by candle."
version = "0.1.0+0.8.1"
edition = "2021"
rust-version = "1.70"
license = "MIT OR Apache-2.0"
repository = "https://github.com/pykeio/ort"
homepage = "https://ort.pyke.io/"
keywords = [ "machine-learning", "ai", "ml", "sys" ]
categories = [ "algorithms", "mathematics", "science" ]
authors = [
"pyke.io <contact@pyke.io>"
]

[lib]
name = "ort_candle"
path = "lib.rs"

[features]

[dependencies]
ort-sys = { version = "=2.0.0-rc.9", path = "../../ort-sys", default-features = false }
candle-core = { version = "0.8.1", default-features = false }
candle-onnx = { version = "0.8.1" }
prost = { version = "0.12.1", default-features = false }

[dev-dependencies]
ort = { version = "=2.0.0-rc.9", path = "../../", default-features = false, features = [ "alternative-backend", "fetch-models" ] }

[[test]]
name = "memory"
path = "tests/memory.rs"
[[test]]
name = "tensor"
path = "tests/tensor.rs"
2,266 changes: 2,266 additions & 0 deletions backends/candle/api.rs

Large diffs are not rendered by default.

42 changes: 42 additions & 0 deletions backends/candle/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
use std::ffi::{CString, c_char};

#[derive(Debug, Clone)]
pub struct Error {
pub code: ort_sys::OrtErrorCode,
message: CString
}

impl Error {
pub fn new(code: ort_sys::OrtErrorCode, message: impl Into<String>) -> Self {

Check warning on line 10 in backends/candle/error.rs

Codecov / codecov/patch

backends/candle/error.rs#L10

Added line #L10 was not covered by tests
Self {
code,
message: CString::new(message.into()).unwrap()

Check warning on line 13 in backends/candle/error.rs

Codecov / codecov/patch

backends/candle/error.rs#L13

Added line #L13 was not covered by tests
}
}

pub fn into_sys(self) -> *mut ort_sys::OrtStatus {
(Box::leak(Box::new(self)) as *mut Error).cast()
}

pub fn new_sys(code: ort_sys::OrtErrorCode, message: impl Into<String>) -> *mut ort_sys::OrtStatus {
Self::new(code, message).into_sys()

Check warning on line 22 in backends/candle/error.rs

Codecov / codecov/patch

backends/candle/error.rs#L21-L22

Added lines #L21 - L22 were not covered by tests
}

#[inline]
pub fn message(&self) -> &str {
self.message.as_c_str().to_str().unwrap()

Check warning on line 27 in backends/candle/error.rs

Codecov / codecov/patch

backends/candle/error.rs#L26-L27

Added lines #L26 - L27 were not covered by tests
}

#[inline]
pub fn message_ptr(&self) -> *const c_char {
self.message.as_ptr()

Check warning on line 32 in backends/candle/error.rs

Codecov / codecov/patch

backends/candle/error.rs#L31-L32

Added lines #L31 - L32 were not covered by tests
}

pub unsafe fn cast_from_sys<'e>(ptr: *const ort_sys::OrtStatus) -> &'e Error {
unsafe { &*ptr.cast::<Error>() }

Check warning on line 36 in backends/candle/error.rs

Codecov / codecov/patch

backends/candle/error.rs#L35-L36

Added lines #L35 - L36 were not covered by tests
}

pub unsafe fn consume_sys(ptr: *mut ort_sys::OrtStatus) -> Box<Error> {
Box::from_raw(ptr.cast::<Error>())
}
}
52 changes: 52 additions & 0 deletions backends/candle/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use candle_core::DType;
use ort_sys::OrtErrorCode;

mod api;
pub(crate) mod error;
mod memory;
mod session;
mod tensor;

pub use self::api::api;
use self::error::Error;

pub(crate) struct Environment {}

impl Environment {
pub fn new_sys() -> *mut ort_sys::OrtEnv {
(Box::leak(Box::new(Self {})) as *mut Environment).cast()
}

pub unsafe fn cast_from_sys<'e>(ptr: *const ort_sys::OrtEnv) -> &'e Environment {
unsafe { &*ptr.cast::<Environment>() }

Check warning on line 21 in backends/candle/lib.rs

Codecov / codecov/patch

backends/candle/lib.rs#L20-L21

Added lines #L20 - L21 were not covered by tests
}

pub unsafe fn consume_sys(ptr: *mut ort_sys::OrtEnv) -> Box<Environment> {
Box::from_raw(ptr.cast::<Environment>())
}
}

/// Maps an ONNX Runtime element data type to the corresponding candle [`DType`].
///
/// Fails with `ORT_FAIL` for any element type not matched below (candle only has
/// u8/u32/i64/bf16/f16/f32/f64 dtypes, so e.g. strings, bools, and the other
/// integer widths have no equivalent).
fn convert_sys_to_dtype(sys: ort_sys::ONNXTensorElementDataType) -> Result<DType, Error> {
	match sys {
		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 => Ok(DType::U8),
		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 => Ok(DType::U32),
		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 => Ok(DType::I64),
		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 => Ok(DType::BF16),
		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 => Ok(DType::F16),
		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT => Ok(DType::F32),
		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE => Ok(DType::F64),
		_ => Err(Error::new(OrtErrorCode::ORT_FAIL, "Element type not supported by candle"))
	}
}

/// Maps a candle [`DType`] to the equivalent ONNX Runtime element data type.
///
/// The match is exhaustive (no catch-all), so adding a new `DType` variant in a
/// future candle release will surface here as a compile error rather than a
/// silent mis-mapping; this conversion cannot fail, unlike [`convert_sys_to_dtype`].
fn convert_dtype_to_sys(dtype: DType) -> ort_sys::ONNXTensorElementDataType {
	match dtype {
		DType::U8 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8,
		DType::U32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32,
		DType::I64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
		DType::BF16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16,
		DType::F16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16,
		DType::F32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
		DType::F64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE
	}
}
109 changes: 109 additions & 0 deletions backends/candle/memory.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
use std::{any::Any, ffi::CString, ptr};

use candle_core::{Device, DeviceLocation, backend::BackendDevice};
use ort_sys::{OrtAllocatorType, OrtErrorCode, OrtMemType, OrtMemoryInfoDeviceType};

use crate::error::Error;

#[repr(transparent)]
pub struct MemoryInfo(pub Device);

impl MemoryInfo {
pub fn new(device_name: impl AsRef<str>, device_id: usize, mem_type: OrtMemType) -> Result<Self, Error> {
match device_name.as_ref() {
"Cpu" | "CudaPinned" => Ok(Self(Device::Cpu)),
"Cuda" => match mem_type {
OrtMemType::OrtMemTypeCPUInput | OrtMemType::OrtMemTypeCPUOutput => Ok(Self(Device::Cpu)),
OrtMemType::OrtMemTypeDefault => Device::new_cuda(device_id)
.map(Self)
.map_err(|e| Error::new(OrtErrorCode::ORT_ENGINE_ERROR, e.to_string()))

Check warning on line 19 in backends/candle/memory.rs

Codecov / codecov/patch

backends/candle/memory.rs#L12-L19

Added lines #L12 - L19 were not covered by tests
},
"Metal" => Device::new_metal(device_id)
.map(Self)
.map_err(|e| Error::new(OrtErrorCode::ORT_ENGINE_ERROR, e.to_string())),
device_name => Err(Error::new(OrtErrorCode::ORT_NOT_IMPLEMENTED, format!("ort-candle does not support the '{device_name}' device")))

Check warning on line 24 in backends/candle/memory.rs

Codecov / codecov/patch

backends/candle/memory.rs#L21-L24

Added lines #L21 - L24 were not covered by tests
}
}

pub fn device(&self) -> &Device {
&self.0
}

pub fn device_type(&self) -> OrtMemoryInfoDeviceType {
match &self.0 {
Device::Cpu => OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU,
Device::Cuda(_) | Device::Metal(_) => OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU
}
}

pub fn device_name(&self) -> &'static str {
let sys_str = self.device_name_sys();
&sys_str[..sys_str.len() - 1]
}

pub fn device_name_sys(&self) -> &'static str {
match &self.0 {
Device::Cpu => "Cpu\0",
Device::Cuda(_) => "Cuda\0",
Device::Metal(_) => "Metal\0"
}
}

pub fn device_id(&self) -> usize {
match self.0.location() {
DeviceLocation::Cpu => 0,
DeviceLocation::Cuda { gpu_id } => gpu_id,
DeviceLocation::Metal { gpu_id } => gpu_id
}
}

pub fn memory_type(&self) -> OrtMemType {
OrtMemType::OrtMemTypeDefault
}
}

impl PartialEq for MemoryInfo {
fn eq(&self, other: &Self) -> bool {
self.0.same_device(&other.0)
}
}

#[repr(C)]
pub struct Allocator<'m> {
_sys_api: ort_sys::OrtAllocator,
pub memory_info: &'m MemoryInfo
}

impl<'m> Allocator<'m> {
pub const fn new(memory_info: &'m MemoryInfo) -> Self {

Check warning on line 78 in backends/candle/memory.rs

Codecov / codecov/patch

backends/candle/memory.rs#L78

Added line #L78 was not covered by tests
Self {
_sys_api: ort_sys::OrtAllocator {

Check warning on line 80 in backends/candle/memory.rs

Codecov / codecov/patch

backends/candle/memory.rs#L80

Added line #L80 was not covered by tests
version: ort_sys::ORT_API_VERSION,
Alloc: Some(sys_allocator_alloc),
Free: Some(sys_allocator_free),
Info: Some(sys_allocator_info),
Reserve: Some(sys_allocator_reserve)
},
memory_info
}
}
}

/// The process-wide default allocator, reporting plain CPU memory.
pub static DEFAULT_CPU_ALLOCATOR: Allocator = Allocator::new(&MemoryInfo(Device::Cpu));

/// `Alloc` implementation: this backend never hands out memory through the
/// `OrtAllocator` interface, so allocation always signals failure by returning null.
unsafe extern "system" fn sys_allocator_alloc(_this: *mut ort_sys::OrtAllocator, _size: usize) -> *mut ::std::os::raw::c_void {
	ptr::null_mut()
}

/// `Free` implementation: pointers released through this allocator are `CString`s
/// allocated by this backend, so we reconstruct and drop the `CString`.
unsafe extern "system" fn sys_allocator_free(_this: *mut ort_sys::OrtAllocator, p: *mut ::std::os::raw::c_void) {
	// `CString::from_raw` on a null pointer is undefined behavior; tolerate null
	// like `free(3)` does, since callers may legitimately pass it.
	if !p.is_null() {
		drop(CString::from_raw(p.cast()));
	}
}

/// `Info` implementation: returns the `MemoryInfo` this allocator was created with,
/// cast to the opaque sys handle type.
unsafe extern "system" fn sys_allocator_info(this_: *const ort_sys::OrtAllocator) -> *const ort_sys::OrtMemoryInfo {
	// SAFETY (caller contract): `this_` points at an `Allocator`, whose first field
	// is the `ort_sys::OrtAllocator` (`#[repr(C)]`), so the cast is valid.
	let allocator = unsafe { &*this_.cast::<Allocator>() };
	(allocator.memory_info as *const MemoryInfo).cast()
}

/// `Reserve` implementation: like `Alloc`, always returns null.
unsafe extern "system" fn sys_allocator_reserve(_this: *const ort_sys::OrtAllocator, _size: usize) -> *mut ::std::os::raw::c_void {
	ptr::null_mut()
}
23 changes: 23 additions & 0 deletions backends/candle/session.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
use std::{collections::HashMap, path::Path};

use candle_core::Tensor;
use candle_onnx::onnx::ModelProto;
use prost::Message;

/// Options controlling session creation; the candle backend currently takes no
/// options, so this is an empty placeholder type.
#[derive(Default, Clone)]
pub struct SessionOptions;

/// A loaded ONNX model, evaluated on demand by candle's ONNX interpreter.
pub struct Session {
	pub model: ModelProto
}

impl Session {
	/// Decodes a protobuf-encoded ONNX model from `data`.
	///
	/// `_options` is currently unused; it is kept in the signature for symmetry with
	/// other backends' session constructors.
	pub fn from_buffer(_options: &SessionOptions, data: &[u8]) -> Result<Session, prost::DecodeError> {
		let model = ModelProto::decode(data)?;
		Ok(Session { model })
	}

	/// Evaluates the model graph with the given named inputs, returning named outputs.
	pub fn run(&self, inputs: HashMap<String, Tensor>) -> candle_core::Result<HashMap<String, Tensor>> {
		candle_onnx::simple_eval(&self.model, inputs)
	}
}
16 changes: 16 additions & 0 deletions backends/candle/tensor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
use candle_core::DType;

/// Element type and shape information backing an opaque `OrtTypeInfo` handle.
pub struct TypeInfo {
	pub dtype: DType,
	pub shape: Vec<i64>
}

impl TypeInfo {
	/// Heap-allocates a `TypeInfo` and returns it as an opaque `OrtTypeInfo` pointer.
	///
	/// Ownership is transferred to the caller; release it with [`TypeInfo::consume_sys`].
	pub fn new_sys(dtype: DType, shape: Vec<i64>) -> *mut ort_sys::OrtTypeInfo {
		Box::into_raw(Box::new(TypeInfo { dtype, shape })).cast()
	}

	/// Reclaims ownership of a `TypeInfo` previously created by [`TypeInfo::new_sys`].
	pub unsafe fn consume_sys(ptr: *mut ort_sys::OrtTypeInfo) -> Box<TypeInfo> {
		Box::from_raw(ptr.cast::<TypeInfo>())
	}
}
16 changes: 16 additions & 0 deletions backends/candle/tests/memory.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
use ort::memory::{AllocationDevice, Allocator, DeviceType};

#[test]
fn test_memory_info_apis() {
	// Route all `ort` API calls through the candle backend before anything else.
	ort::set_api(ort_candle::api());

	let allocator = Allocator::default();

	// The default allocator must report plain CPU memory on device 0.
	let memory_info = allocator.memory_info();
	assert_eq!(memory_info.allocation_device(), AllocationDevice::CPU);
	assert_eq!(memory_info.device_type(), DeviceType::CPU);
	assert_eq!(memory_info.device_id(), 0);

	// A cloned memory info must compare equal to the original.
	let memory_info_clone = memory_info.clone();
	assert_eq!(memory_info, memory_info_clone);
}
17 changes: 17 additions & 0 deletions backends/candle/tests/tensor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
use ort::value::Tensor;

#[test]
fn test_tensors() -> ort::Result<()> {
	// Route all `ort` API calls through the candle backend before anything else.
	ort::set_api(ort_candle::api());

	// Create a 1-D i64 tensor of shape [5] and mutate one element through its raw
	// data pointer, exercising the backend's mutable-data-pointer API.
	let mut tensor = Tensor::<i64>::from_array((vec![5], vec![0, 1, 2, 3, 4]))?;
	let ptr = tensor.data_ptr_mut()?.cast::<i64>();
	// SAFETY: the tensor owns 5 contiguous i64 elements, so index 3 is in bounds.
	unsafe {
		*ptr.add(3) = 42;
	};

	// The raw-pointer write must be visible when the data is extracted again.
	let (_, extracted) = tensor.extract_raw_tensor();
	assert_eq!(&extracted, &[0, 1, 2, 42, 4]);

	Ok(())
}
41 changes: 41 additions & 0 deletions backends/tract/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
[package]
name = "ort-tract"
description = "ort + tract = 🦀 - An alternative backend for ort, powered by tract."
version = "0.1.0+0.8.1"
edition = "2021"
rust-version = "1.70"
license = "MIT OR Apache-2.0"
repository = "https://github.com/pykeio/ort"
homepage = "https://ort.pyke.io/"
keywords = [ "machine-learning", "ai", "ml", "sys" ]
categories = [ "algorithms", "mathematics", "science" ]
authors = [
"pyke.io <contact@pyke.io>"
]

[lib]
name = "ort_tract"
path = "lib.rs"

[features]

[dependencies]
ort-sys = { version = "=2.0.0-rc.9", path = "../../ort-sys", default-features = false }
tract-onnx = "0.21"
parking_lot = "0.12"

[dev-dependencies]
ort = { version = "=2.0.0-rc.9", path = "../../", default-features = false, features = [ "alternative-backend", "fetch-models", "ndarray" ] }
ureq = "2.1"
image = "0.25"
ndarray = "0.16"

[[test]]
name = "memory"
path = "tests/memory.rs"
[[test]]
name = "tensor"
path = "tests/tensor.rs"
[[test]]
name = "session"
path = "tests/session.rs"
2,191 changes: 2,191 additions & 0 deletions backends/tract/api.rs

Large diffs are not rendered by default.

37 changes: 37 additions & 0 deletions backends/tract/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use std::ffi::{CString, c_char};

#[derive(Debug, Clone)]
pub struct Error {
pub code: ort_sys::OrtErrorCode,
message: CString
}

impl Error {
pub fn new(code: ort_sys::OrtErrorCode, message: impl Into<String>) -> Self {

Check warning on line 10 in backends/tract/error.rs

Codecov / codecov/patch

backends/tract/error.rs#L10

Added line #L10 was not covered by tests
Self {
code,
message: CString::new(message.into()).unwrap()

Check warning on line 13 in backends/tract/error.rs

Codecov / codecov/patch

backends/tract/error.rs#L13

Added line #L13 was not covered by tests
}
}

pub fn into_sys(self) -> *mut ort_sys::OrtStatus {
(Box::leak(Box::new(self)) as *mut Error).cast()
}

pub fn new_sys(code: ort_sys::OrtErrorCode, message: impl Into<String>) -> *mut ort_sys::OrtStatus {
Self::new(code, message).into_sys()

Check warning on line 22 in backends/tract/error.rs

Codecov / codecov/patch

backends/tract/error.rs#L21-L22

Added lines #L21 - L22 were not covered by tests
}

#[inline]
pub fn message_ptr(&self) -> *const c_char {
self.message.as_ptr()

Check warning on line 27 in backends/tract/error.rs

Codecov / codecov/patch

backends/tract/error.rs#L26-L27

Added lines #L26 - L27 were not covered by tests
}

pub unsafe fn cast_from_sys<'e>(ptr: *const ort_sys::OrtStatus) -> &'e Error {
unsafe { &*ptr.cast::<Error>() }

Check warning on line 31 in backends/tract/error.rs

Codecov / codecov/patch

backends/tract/error.rs#L30-L31

Added lines #L30 - L31 were not covered by tests
}

pub unsafe fn consume_sys(ptr: *mut ort_sys::OrtStatus) -> Box<Error> {
Box::from_raw(ptr.cast::<Error>())
}
}
63 changes: 63 additions & 0 deletions backends/tract/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use ort_sys::OrtErrorCode;
use tract_onnx::{Onnx, prelude::DatumType};

mod api;
pub(crate) mod error;
mod memory;
mod session;
mod tensor;

pub use self::api::api;
use self::error::Error;

/// Backing state for an opaque `OrtEnv` handle; owns the tract ONNX framework instance.
pub(crate) struct Environment {
	pub onnx: Onnx
}

impl Environment {
	/// Heap-allocates a new environment and returns it as an opaque `OrtEnv` pointer.
	///
	/// Ownership is transferred to the caller; release it with [`Environment::consume_sys`].
	pub fn new_sys() -> *mut ort_sys::OrtEnv {
		Box::into_raw(Box::new(Environment { onnx: tract_onnx::onnx() })).cast()
	}

	/// Reclaims ownership of an environment previously created by [`Environment::new_sys`].
	pub unsafe fn consume_sys(ptr: *mut ort_sys::OrtEnv) -> Box<Environment> {
		Box::from_raw(ptr.cast::<Environment>())
	}
}

/// Maps an ONNX Runtime element data type to the corresponding tract [`DatumType`].
///
/// Fails with `ORT_FAIL` for any element type not matched below (e.g. bfloat16 or
/// the complex types, which tract has no equivalent for).
fn convert_sys_to_datum_type(sys: ort_sys::ONNXTensorElementDataType) -> Result<DatumType, Error> {
	match sys {
		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL => Ok(DatumType::Bool),
		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 => Ok(DatumType::U8),
		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 => Ok(DatumType::U16),
		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 => Ok(DatumType::U32),
		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 => Ok(DatumType::U64),
		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 => Ok(DatumType::I8),
		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 => Ok(DatumType::I16),
		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 => Ok(DatumType::I32),
		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 => Ok(DatumType::I64),
		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 => Ok(DatumType::F16),
		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT => Ok(DatumType::F32),
		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE => Ok(DatumType::F64),
		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING => Ok(DatumType::String),
		_ => Err(Error::new(OrtErrorCode::ORT_FAIL, "Element type not supported by tract"))
	}
}

/// Maps a tract [`DatumType`] to the equivalent ONNX Runtime element data type.
///
/// tract datum types with no ONNX counterpart fall through the `_` arm and map to
/// `UNDEFINED`; this is the inverse of [`convert_sys_to_datum_type`] for all other
/// variants.
fn convert_datum_type_to_sys(dtype: DatumType) -> ort_sys::ONNXTensorElementDataType {
	match dtype {
		DatumType::Bool => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL,
		DatumType::U8 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8,
		DatumType::U16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16,
		DatumType::U32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32,
		DatumType::U64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64,
		DatumType::I8 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8,
		DatumType::I16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16,
		DatumType::I32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32,
		DatumType::I64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
		DatumType::F16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16,
		DatumType::F32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
		DatumType::F64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE,
		DatumType::String => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,
		_ => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
	}
}
39 changes: 39 additions & 0 deletions backends/tract/memory.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
use std::{ffi::CString, ptr};

/// A stub `OrtAllocator` implementation for the tract backend.
///
/// `#[repr(C)]` with the sys struct as the sole field so that a pointer to this
/// type can be handed out directly as a `*mut ort_sys::OrtAllocator`.
#[repr(C)]
pub struct Allocator {
	_sys_api: ort_sys::OrtAllocator
}

impl Allocator {
	/// Builds the allocator vtable; `const` so it can back a `static`
	/// (see `DEFAULT_CPU_ALLOCATOR`).
	pub const fn new() -> Self {
		Self {
			_sys_api: ort_sys::OrtAllocator {
				version: ort_sys::ORT_API_VERSION,
				Alloc: Some(sys_allocator_alloc),
				Free: Some(sys_allocator_free),
				Info: Some(sys_allocator_info),
				Reserve: Some(sys_allocator_reserve)
			}
		}
	}
}

/// The process-wide default allocator handed out through the API.
pub static DEFAULT_CPU_ALLOCATOR: Allocator = Allocator::new();

/// `Alloc` implementation: this backend never hands out memory through the
/// `OrtAllocator` interface, so allocation always signals failure by returning null.
unsafe extern "system" fn sys_allocator_alloc(_this: *mut ort_sys::OrtAllocator, _size: usize) -> *mut ::std::os::raw::c_void {
	ptr::null_mut()
}

/// `Free` implementation: pointers released through this allocator are `CString`s
/// allocated by this backend, so we reconstruct and drop the `CString`.
unsafe extern "system" fn sys_allocator_free(_this: *mut ort_sys::OrtAllocator, p: *mut ::std::os::raw::c_void) {
	// `CString::from_raw` on a null pointer is undefined behavior; tolerate null
	// like `free(3)` does, since callers may legitimately pass it.
	if !p.is_null() {
		drop(CString::from_raw(p.cast()));
	}
}

/// `Info` implementation: this backend has no real `OrtMemoryInfo` behind its
/// allocator, so a dangling (non-null, never-dereferenced) pointer is returned as
/// an opaque placeholder.
unsafe extern "system" fn sys_allocator_info(this_: *const ort_sys::OrtAllocator) -> *const ort_sys::OrtMemoryInfo {
	let _allocator = unsafe { &*this_.cast::<Allocator>() };
	ptr::dangling()
}

/// `Reserve` implementation: like `Alloc`, always returns null.
unsafe extern "system" fn sys_allocator_reserve(_this: *const ort_sys::OrtAllocator, _size: usize) -> *mut ::std::os::raw::c_void {
	ptr::null_mut()
}
117 changes: 117 additions & 0 deletions backends/tract/session.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
use std::{
collections::{HashMap, hash_map::Entry},
hash::{BuildHasher, DefaultHasher, Hasher},
sync::Arc
};

use parking_lot::Mutex;
use tract_onnx::{
pb::ValueInfoProto,
prelude::{Framework, Graph, InferenceModelExt, IntoTensor, SimplePlan, Tensor, TractResult, TypedFact, TypedOp}
};

use crate::Environment;

type OptimizedGraph = Graph<TypedFact, Box<dyn TypedOp>>;
type RunnableGraph = SimplePlan<TypedFact, Box<dyn TypedOp>, OptimizedGraph>;

/// Options controlling session creation for the tract backend.
#[derive(Default, Clone)]
pub struct SessionOptions {
	// When true, `Session::from_buffer` runs tract's graph optimizer on the model;
	// otherwise the model is only converted to a typed graph.
	pub perform_optimizations: bool
}

/// Mutex-guarded session state: the source graph plus a cache of runnable plans.
///
/// A tract `SimplePlan` is specialized to a fixed, ordered set of input names, so a
/// separate plan is built — and cached — per distinct input-name set passed to `run`.
pub struct SessionLockedInner {
	original_graph: Arc<OptimizedGraph>,
	// Keyed by the hash produced by `Session::hash_inputs`. Since the key is already
	// a hash, `PassthroughHashBuilder` uses it directly instead of re-hashing.
	graphs: HashMap<u64, RunnableGraph, PassthroughHashBuilder>
}

impl SessionLockedInner {
	/// Creates the inner state with an empty plan cache.
	pub fn new(original_graph: Arc<OptimizedGraph>) -> Self {
		Self {
			original_graph,
			graphs: HashMap::with_hasher(PassthroughHashBuilder)
		}
	}

	/// Returns the cached runnable plan for this input-name set, building and
	/// caching it on first use by cloning the original graph.
	pub fn get_graph(&mut self, inputs: &[(String, Tensor)]) -> TractResult<&mut RunnableGraph> {
		let input_mark = Session::hash_inputs(inputs);
		match self.graphs.entry(input_mark) {
			Entry::Vacant(entry) => Ok(entry.insert(
				OptimizedGraph::clone(&*self.original_graph)
					.with_input_names(inputs.iter().map(|(n, _)| n))?
					.into_runnable()?
			)),
			Entry::Occupied(entry) => Ok(entry.into_mut())
		}
	}
}

/// A loaded, type-checked tract model plus its per-input-set plan cache.
pub struct Session {
	// Input descriptors taken verbatim from the ONNX graph proto.
	pub inputs: Vec<ValueInfoProto>,
	// Output descriptors taken verbatim from the ONNX graph proto; `run` assigns
	// output names positionally from this list.
	pub outputs: Vec<ValueInfoProto>,
	pub original_graph: Arc<OptimizedGraph>,
	// Interior mutability for the plan cache; see `SessionLockedInner`.
	locked_inner: Mutex<SessionLockedInner>
}

impl Session {
	/// Decodes an ONNX model from `data` and prepares it for execution.
	///
	/// When `options.perform_optimizations` is set, tract's graph optimizer runs;
	/// otherwise the model is only converted to a typed graph.
	pub fn from_buffer(env: &Environment, options: &SessionOptions, mut data: &[u8]) -> TractResult<Session> {
		let proto_model = env.onnx.proto_model_for_read(&mut data)?;
		// Keep the original proto's input/output descriptors so that names and type
		// information remain available after tract has transformed the graph.
		let inputs = proto_model.graph.as_ref().map(|graph| graph.input.clone()).unwrap_or_default();
		let outputs = proto_model.graph.as_ref().map(|graph| graph.output.clone()).unwrap_or_default();

		let model = env.onnx.model_for_proto_model(&proto_model)?;
		let graph = Arc::new(if options.perform_optimizations { model.into_optimized()? } else { model.into_typed()? });
		Ok(Session {
			inputs,
			outputs,
			original_graph: Arc::clone(&graph),
			locked_inner: Mutex::new(SessionLockedInner::new(graph))
		})
	}

	/// Hashes the *names* (not the tensor values) of an ordered input set.
	///
	/// Each name is length-prefixed and NUL-terminated so adjacent names cannot be
	/// ambiguous. Note this is order-sensitive: the same names in a different order
	/// hash differently and therefore get a separate cached plan.
	fn hash_inputs(inputs: &[(String, Tensor)]) -> u64 {
		let mut hasher = DefaultHasher::new();
		for (name, _) in inputs {
			hasher.write_u64(name.len() as _);
			hasher.write(name.as_bytes());
			hasher.write_u8(0);
		}
		hasher.finish()
	}

	/// Runs the model with the given named inputs, returning named outputs.
	///
	/// Output names are assigned positionally from the proto's declared outputs —
	/// this assumes the plan's output order matches the graph proto's output order
	/// (NOTE(review): confirm against tract's `into_runnable` semantics).
	pub fn run(&self, inputs: Vec<(String, Tensor)>) -> TractResult<Vec<(String, Tensor)>> {
		let mut inner = self.locked_inner.lock();
		let graph = inner.get_graph(&inputs)?;
		let outputs = graph.run(inputs.into_iter().map(|(_, v)| tract_onnx::prelude::TValue::from(v)).collect())?;
		Ok(outputs
			.into_iter()
			.enumerate()
			.map(|(i, v)| (self.outputs[i].name.clone(), v.into_tensor()))
			.collect())
	}
}

/// A `Hasher` that passes a single pre-computed `u64` straight through.
///
/// The plan-cache keys are already hashes, so re-hashing them would be wasted work;
/// only `write_u64` is meaningful for this hasher.
struct PassthroughHasher(u64);

impl Hasher for PassthroughHasher {
	fn write(&mut self, _bytes: &[u8]) {
		// Keys are always written via `write_u64`; any byte-wise write is a bug.
		unreachable!()
	}

	fn write_u64(&mut self, value: u64) {
		self.0 = value;
	}

	fn finish(&self) -> u64 {
		self.0
	}
}

/// Builds [`PassthroughHasher`]s for the plan cache's `HashMap`.
struct PassthroughHashBuilder;

impl BuildHasher for PassthroughHashBuilder {
	type Hasher = PassthroughHasher;

	fn build_hasher(&self) -> Self::Hasher {
		PassthroughHasher(0)
	}
}
16 changes: 16 additions & 0 deletions backends/tract/tensor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
use tract_onnx::prelude::DatumType;

/// Element type and shape information backing an opaque `OrtTypeInfo` handle.
pub struct TypeInfo {
	pub dtype: DatumType,
	pub shape: Vec<i64>
}

impl TypeInfo {
	/// Heap-allocates a `TypeInfo` and returns it as an opaque `OrtTypeInfo` pointer.
	///
	/// Ownership is transferred to the caller; release it with [`TypeInfo::consume_sys`].
	pub fn new_sys(dtype: DatumType, shape: Vec<i64>) -> *mut ort_sys::OrtTypeInfo {
		Box::into_raw(Box::new(TypeInfo { dtype, shape })).cast()
	}

	/// Reclaims ownership of a `TypeInfo` previously created by [`TypeInfo::new_sys`].
	pub unsafe fn consume_sys(ptr: *mut ort_sys::OrtTypeInfo) -> Box<TypeInfo> {
		Box::from_raw(ptr.cast::<TypeInfo>())
	}
}
16 changes: 16 additions & 0 deletions backends/tract/tests/memory.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
use ort::memory::{AllocationDevice, Allocator, DeviceType};

#[test]
fn test_memory_info_apis() {
	// Route all `ort` API calls through the tract backend before anything else.
	ort::set_api(ort_tract::api());

	let allocator = Allocator::default();

	// The default allocator must report plain CPU memory on device 0.
	let memory_info = allocator.memory_info();
	assert_eq!(memory_info.allocation_device(), AllocationDevice::CPU);
	assert_eq!(memory_info.device_type(), DeviceType::CPU);
	assert_eq!(memory_info.device_id(), 0);

	// A cloned memory info must compare equal to the original.
	let memory_info_clone = memory_info.clone();
	assert_eq!(memory_info, memory_info_clone);
}
62 changes: 62 additions & 0 deletions backends/tract/tests/session.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
use std::path::Path;

use image::{ImageBuffer, Luma, Pixel, imageops::FilterType};
use ort::{
inputs,
session::{Session, builder::GraphOptimizationLevel},
tensor::ArrayExtensions,
value::TensorRef
};

#[test]
fn mnist_5() -> ort::Result<()> {
	const IMAGE_TO_LOAD: &str = "mnist_5.jpg";

	// Route all `ort` API calls through the tract backend before building a session.
	ort::set_api(ort_tract::api());

	// End-to-end check: download MNIST, run it through tract, expect class '5'.
	let session = Session::builder()?
		.with_optimization_level(GraphOptimizationLevel::Level3)?
		.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/mnist.onnx")
		.expect("Could not download model from file");

	// Load image and resize to model's shape, converting to RGB format
	let image_buffer: ImageBuffer<Luma<u8>, Vec<u8>> = image::open(
		// Test data lives in the repository-level `tests/data` directory, two levels
		// above this crate's manifest (backends/tract).
		Path::new(env!("CARGO_MANIFEST_DIR"))
			.parent()
			.unwrap()
			.parent()
			.unwrap()
			.join("tests")
			.join("data")
			.join(IMAGE_TO_LOAD)
	)
	.unwrap()
	.resize(28, 28, FilterType::Nearest)
	.to_luma8();

	// Build the NCHW (1, 1, 28, 28) f32 input tensor from the grayscale image.
	let array = ndarray::Array::from_shape_fn((1, 1, 28, 28), |(_, c, j, i)| {
		let pixel = image_buffer.get_pixel(i as u32, j as u32);
		let channels = pixel.channels();

		// range [0, 255] -> range [0, 1]
		(channels[c] as f32) / 255.0
	});

	// Perform the inference
	let outputs = session.run(inputs![TensorRef::from_array_view(&array)?])?;

	// Softmax the logits, then pair each class index with its probability.
	let mut probabilities: Vec<(usize, f32)> = outputs[0]
		.try_extract_tensor()?
		.softmax(ndarray::Axis(1))
		.iter()
		.copied()
		.enumerate()
		.collect::<Vec<_>>();

	// Sort probabilities so highest is at beginning of vector.
	probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

	assert_eq!(probabilities[0].0, 5, "Expecting class for {} is '5' (not {})", IMAGE_TO_LOAD, probabilities[0].0);

	Ok(())
}
17 changes: 17 additions & 0 deletions backends/tract/tests/tensor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
use ort::value::Tensor;

#[test]
fn test_tensors() -> ort::Result<()> {
	// Route all `ort` API calls through the tract backend before anything else.
	ort::set_api(ort_tract::api());

	// Create a 1-D i64 tensor of shape [5] and mutate one element through its raw
	// data pointer, exercising the backend's mutable-data-pointer API.
	let mut tensor = Tensor::<i64>::from_array((vec![5], vec![0, 1, 2, 3, 4]))?;
	let ptr = tensor.data_ptr_mut()?.cast::<i64>();
	// SAFETY: the tensor owns 5 contiguous i64 elements, so index 3 is in bounds.
	unsafe {
		*ptr.add(3) = 42;
	};

	// The raw-pointer write must be visible when the data is extracted again.
	let (_, extracted) = tensor.extract_raw_tensor();
	assert_eq!(&extracted, &[0, 1, 2, 42, 4]);

	Ok(())
}