Skip to content

Commit

Permalink
feat: Session::run_async (#174)
Browse files Browse the repository at this point in the history
* feat: `Session::run_async`

* fix: set intra threads in doctests

* ci(code-quality): cover doctests too

* feat: make `InferenceFut` cancel safe

* ci(code-quality): revert

* minor cleanup
decahedron1 authored Mar 21, 2024
1 parent 3035b07 commit 979a591
Showing 6 changed files with 409 additions and 1 deletion.
6 changes: 5 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[workspace]
members = [
'ort-sys',
'examples/async-gpt2-api',
'examples/custom-ops',
'examples/gpt2',
'examples/model-info',
@@ -9,6 +10,7 @@ members = [
]
default-members = [
'.',
'examples/async-gpt2-api',
'examples/custom-ops',
'examples/gpt2',
'examples/model-info',
@@ -91,10 +93,12 @@ winapi = { version = "0.3", optional = true, features = [ "std", "libloaderapi"
[dev-dependencies]
anyhow = "1.0"
ureq = "2.1"
image = "0.24"
image = "0.25"
test-log = { version = "0.2", default-features = false, features = [ "trace" ] }
tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] }
glassbench = "0.4"
tokio = { version = "1.36", features = [ "test-util" ] }
tokio-test = "0.4.3"

[[bench]]
name = "squeezenet"
25 changes: 25 additions & 0 deletions examples/async-gpt2-api/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Example crate demonstrating `Session::run_async` behind an axum HTTP API.
# Not published; built only as part of the workspace.
[package]
publish = false
name = "example-async-gpt2-api"
version = "0.0.0"
edition = "2021"

[dependencies]
# `fetch-models` lets the example download gpt2.onnx from a URL at startup.
ort = { path = "../../", features = [ "fetch-models" ] }
ndarray = "0.15"
# `onig`-only tokenizers build avoids the default esaxx/native deps.
tokenizers = { version = ">=0.13.4", default-features = false, features = [ "onig" ] }
rand = "0.8"
tracing = "0.1"
tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] }
futures = "0.3"
headers = "0.4"
axum = { version = "0.7", features = [ "json" ] }
tokio = { version = "1.36", features = [ "full" ] }
tokio-stream = "0.1"
tower-http = { version = "0.5", features = ["fs", "trace"] }
anyhow = "1.0"
# Provides the `try_stream!` macro used to yield SSE events.
async-stream = "0.3"

# Pass-through features so the example can opt into ort's dynamic loading / CUDA.
[features]
load-dynamic = [ "ort/load-dynamic" ]
cuda = [ "ort/cuda" ]
1 change: 1 addition & 0 deletions examples/async-gpt2-api/data/tokenizer.json

Large diffs are not rendered by default.

109 changes: 109 additions & 0 deletions examples/async-gpt2-api/examples/async-gpt2-api.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
use std::{path::Path, sync::Arc};

use axum::{
extract::{FromRef, State},
response::{
sse::{Event, KeepAlive},
Sse
},
routing::post,
Router
};
use futures::Stream;
use ndarray::{array, concatenate, s, Array1, ArrayViewD, Axis};
use ort::{inputs, CUDAExecutionProvider, GraphOptimizationLevel, Session, Value};
use rand::Rng;
use tokenizers::Tokenizer;
use tokio::net::TcpListener;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

// Entry point: sets up tracing, the ONNX Runtime environment and the GPT-2
// session, then serves the SSE generation endpoint until the process exits.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
	// Initialize tracing to receive debug messages from `ort`
	let env_filter = tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "info,ort=debug".into());
	tracing_subscriber::registry()
		.with(env_filter)
		.with(tracing_subscriber::fmt::layer())
		.init();

	// Create the ONNX Runtime environment, enabling CUDA execution providers for all sessions created in this process.
	ort::init()
		.with_name("GPT-2")
		.with_execution_providers([CUDAExecutionProvider::default().build()])
		.commit()?;

	// Load our model
	let builder = Session::builder()?
		.with_optimization_level(GraphOptimizationLevel::Level1)?
		.with_intra_threads(4)?;
	let session = builder.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/gpt2.onnx")?;

	// Load the tokenizer and encode the prompt into a sequence of tokens.
	let tokenizer_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("tokenizer.json");
	let tokenizer = Tokenizer::from_file(tokenizer_path).unwrap();

	let state = AppState {
		session: Arc::new(session),
		tokenizer: Arc::new(tokenizer)
	};

	let router = Router::new().route("/generate", post(generate)).with_state(state);
	let listener = TcpListener::bind("127.0.0.1:7216").await?;
	tracing::info!("Listening on {}", listener.local_addr()?);

	axum::serve(listener, router.into_make_service()).await?;

	Ok(())
}

/// Shared application state for axum handlers: the ONNX session and the
/// tokenizer, each behind an `Arc` so `Clone` is a cheap refcount bump.
#[derive(Clone)]
struct AppState {
	session: Arc<Session>,
	tokenizer: Arc<Tokenizer>
}

/// Streams up to `gen_tokens` newly generated tokens as SSE events.
///
/// Each iteration feeds the whole accumulated token sequence back through the
/// model asynchronously, samples the next token from the 6 highest logits,
/// appends it to the sequence, and yields its decoded text.
fn generate_stream(tokenizer: Arc<Tokenizer>, session: Arc<Session>, tokens: Vec<i64>, gen_tokens: usize) -> impl Stream<Item = ort::Result<Event>> + Send {
	async_stream::try_stream! {
		let mut tokens = Array1::from_iter(tokens.iter().cloned());

		for _ in 0..gen_tokens {
			// Insert two leading axes so the input has the rank the model expects.
			let array = tokens.view().insert_axis(Axis(0)).insert_axis(Axis(1));
			let outputs = session.run_async(inputs![array]?)?.await?;
			let generated_tokens: ArrayViewD<f32> = outputs["output1"].extract_tensor()?;

			// Collect the logits for the final position, indexed by token id, and
			// sort descending. (Previously this bound `&mut` to a temporary Vec and
			// made a redundant `insert_axis(..).to_owned()` copy before iterating.)
			let mut probabilities = generated_tokens
				.slice(s![0, 0, -1, ..])
				.iter()
				.copied()
				.enumerate()
				.collect::<Vec<_>>();
			// NaN logits compare as `Less`, sinking to the end instead of panicking.
			probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less));

			// Sample using top-k sampling
			let token = {
				let mut rng = rand::thread_rng();
				// 0..=5 picks uniformly among the top 6 candidates.
				probabilities[rng.gen_range(0..=5)].0
			};
			tokens = concatenate![Axis(0), tokens, array![token.try_into().unwrap()]];

			let token_str = tokenizer.decode(&[token as _], true).unwrap();
			yield Event::default().data(token_str);
		}
	}
}

// Let axum's `State` extractor pull the session directly out of `AppState`.
impl FromRef<AppState> for Arc<Session> {
	fn from_ref(input: &AppState) -> Self {
		Arc::clone(&input.session)
	}
}
// Likewise for the tokenizer.
impl FromRef<AppState> for Arc<Tokenizer> {
	fn from_ref(input: &AppState) -> Self {
		Arc::clone(&input.tokenizer)
	}
}

/// `POST /generate` handler: responds with a server-sent-event stream of 50
/// generated tokens starting from token id 0, with keep-alive pings enabled.
async fn generate(State(session): State<Arc<Session>>, State(tokenizer): State<Arc<Tokenizer>>) -> Sse<impl Stream<Item = ort::Result<Event>>> {
	let stream = generate_stream(tokenizer, session, vec![0], 50);
	Sse::new(stream).keep_alive(KeepAlive::new())
}
178 changes: 178 additions & 0 deletions src/session/async.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
use std::{
cell::UnsafeCell,
ffi::{c_char, CString},
future::Future,
mem::MaybeUninit,
pin::Pin,
ptr::NonNull,
sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex
},
task::{Context, Poll, Waker}
};

use ort_sys::{c_void, OrtStatus};

use crate::{error::assert_non_null_pointer, Error, Result, RunOptions, SessionInputValue, SessionOutputs, SharedSessionInner, Value};

/// Result of a single attempt to take the inference result out of the shared
/// channel: the value is ready, still pending, or the channel has been closed.
pub(crate) enum InnerValue<T> {
	Present(T),
	Pending,
	Closed
}

// Bit flags packed into `InferenceFutInner::presence`.
// Set once a result has been written into `value` and may be taken.
const VALUE_PRESENT: usize = 1 << 0;
// Set when the receiving future is dropped before taking the value (see `InferenceFut::drop`).
const CHANNEL_CLOSED: usize = 1 << 1;

/// Shared one-shot channel state between an in-flight async inference (the
/// producer, written from the ONNX Runtime completion callback) and the
/// `InferenceFut` awaiting it (the consumer).
#[derive(Debug)]
pub(crate) struct InferenceFutInner<'s> {
	// Bit flags (`VALUE_PRESENT` | `CHANNEL_CLOSED`) guarding access to `value`.
	presence: AtomicUsize,
	// Result slot; only initialized while `VALUE_PRESENT` is set in `presence`.
	value: UnsafeCell<MaybeUninit<Result<SessionOutputs<'s>>>>,
	// Waker of the most recent task to poll the future, woken on completion.
	waker: Mutex<Option<Waker>>
}

impl<'s> InferenceFutInner<'s> {
	/// Creates an empty channel: no value present, no waker registered.
	pub(crate) fn new() -> Self {
		InferenceFutInner {
			presence: AtomicUsize::new(0),
			waker: Mutex::new(None),
			value: UnsafeCell::new(MaybeUninit::uninit())
		}
	}

	/// Attempts to take the result out of the channel.
	///
	/// Atomically clears `VALUE_PRESENT`; if it was set, this thread has won
	/// exclusive ownership of the slot and reads the value out. Otherwise
	/// reports `Closed` or `Pending` depending on `CHANNEL_CLOSED`.
	pub(crate) fn try_take(&self) -> InnerValue<Result<SessionOutputs<'s>>> {
		let state_snapshot = self.presence.fetch_and(!VALUE_PRESENT, Ordering::Acquire);
		if state_snapshot & VALUE_PRESENT == 0 {
			if self.presence.load(Ordering::Acquire) & CHANNEL_CLOSED != 0 {
				InnerValue::Closed
			} else {
				InnerValue::Pending
			}
		} else {
			// SAFETY: `VALUE_PRESENT` was set, so the slot was initialized by
			// `emplace_value`; clearing the flag above gave us sole ownership.
			InnerValue::Present(unsafe { (*self.value.get()).assume_init_read() })
		}
	}

	/// Writes the result into the slot, then publishes it with a `Release`
	/// store of `VALUE_PRESENT` so the write is visible to the taker.
	pub(crate) fn emplace_value(&self, value: Result<SessionOutputs<'s>>) {
		unsafe { (*self.value.get()).write(value) };
		self.presence.fetch_or(VALUE_PRESENT, Ordering::Release);
	}

	/// Registers (or clears, with `None`) the waker to notify on completion.
	pub(crate) fn set_waker(&self, waker: Option<&Waker>) {
		*self.waker.lock().unwrap() = waker.map(|c| c.to_owned());
	}

	/// Wakes the registered task, if any, consuming the stored waker.
	pub(crate) fn wake(&self) {
		if let Some(waker) = self.waker.lock().unwrap().take() {
			waker.wake();
		}
	}

	/// Marks the channel closed. Returns `true` only for the first caller to
	/// close it, letting the caller perform one-time cancellation work.
	pub(crate) fn close(&self) -> bool {
		self.presence.fetch_or(CHANNEL_CLOSED, Ordering::Acquire) & CHANNEL_CLOSED == 0
	}
}

impl<'s> Drop for InferenceFutInner<'s> {
	/// Drops a result that was produced but never taken (e.g. the future was
	/// cancelled after the callback emplaced a value).
	fn drop(&mut self) {
		if self.presence.load(Ordering::Acquire) & VALUE_PRESENT != 0 {
			// SAFETY: `VALUE_PRESENT` proves the slot is initialized, and `&mut
			// self` in `drop` guarantees no other reader exists.
			unsafe { (*self.value.get()).assume_init_drop() };
		}
	}
}

// SAFETY: the `UnsafeCell` slot is coordinated through the `presence` atomic —
// the value is written exactly once before `VALUE_PRESENT` is released
// (`emplace_value`) and read at most once after it is acquired (`try_take`);
// the waker is behind a `Mutex`. NOTE(review): soundness relies on there being
// a single producer (the completion callback) — confirm against callers.
unsafe impl<'s> Send for InferenceFutInner<'s> {}
unsafe impl<'s> Sync for InferenceFutInner<'s> {}

/// A [`Future`] resolving to the outputs of an asynchronous inference run.
pub struct InferenceFut<'s> {
	// Channel shared with the ONNX Runtime completion callback.
	inner: Arc<InferenceFutInner<'s>>,
	// Kept so `drop` can request termination of the in-flight run (cancel safety).
	run_options: Arc<RunOptions>,
	// True once `poll` has returned `Ready`; suppresses cancellation in `drop`.
	did_receive: bool
}

impl<'s> InferenceFut<'s> {
	/// Wraps the shared channel state and the run options that allow the
	/// in-flight run to be terminated if this future is dropped early.
	pub(crate) fn new(inner: Arc<InferenceFutInner<'s>>, run_options: Arc<RunOptions>) -> Self {
		Self { inner, run_options, did_receive: false }
	}
}

impl<'s> Future for InferenceFut<'s> {
	type Output = Result<SessionOutputs<'s>>;

	/// Polls the shared channel for the inference result, registering the
	/// task's waker when it is not yet available so the completion callback
	/// can wake this task.
	fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
		let this = Pin::into_inner(self);

		match this.inner.try_take() {
			InnerValue::Present(v) => {
				// Mark received so `Drop` does not terminate a finished run.
				this.did_receive = true;
				return Poll::Ready(v);
			}
			InnerValue::Pending => {}
			// `CHANNEL_CLOSED` is only ever set by this future's own `Drop`, so
			// a live future can never observe it. (Was a bare `panic!()`, which
			// gave no diagnostic and obscured the invariant.)
			InnerValue::Closed => unreachable!("InferenceFut polled after its channel was closed")
		};

		// Re-register on every poll; the executor may have moved the task.
		this.inner.set_waker(Some(cx.waker()));

		Poll::Pending
	}
}

impl<'s> Drop for InferenceFut<'s> {
	/// Cancel safety: if the future is dropped before the result was received,
	/// close the channel and ask ONNX Runtime to terminate the in-flight run.
	fn drop(&mut self) {
		// `close()` returns true only for the first closer, so termination is
		// requested at most once; errors from `terminate()` are deliberately
		// ignored (best-effort cancellation).
		if !self.did_receive && self.inner.close() {
			let _ = self.run_options.terminate();
			self.inner.set_waker(None);
		}
	}
}

/// State boxed and leaked across the FFI boundary for one async run; ownership
/// is reclaimed by `async_callback` when ONNX Runtime completes the run.
pub(crate) struct AsyncInferenceContext<'s> {
	// Channel used to deliver the result to the awaiting `InferenceFut`.
	pub(crate) inner: Arc<InferenceFutInner<'s>>,
	// Held only to keep the input values alive for the duration of the run.
	pub(crate) _input_values: Vec<SessionInputValue<'s>>,
	// Raw pointers handed to the ONNX Runtime C API; must outlive the run.
	pub(crate) input_ort_values: Vec<*const ort_sys::OrtValue>,
	// Leaked `CString` name pointers, freed in the callback.
	pub(crate) input_name_ptrs: Vec<*const c_char>,
	pub(crate) output_name_ptrs: Vec<*const c_char>,
	// Session handle keeping the underlying session alive for output `Value`s.
	pub(crate) session_inner: &'s Arc<SharedSessionInner>,
	pub(crate) output_names: Vec<&'s str>,
	// Output slots filled in by ONNX Runtime before the callback fires.
	pub(crate) output_value_ptrs: Vec<*mut ort_sys::OrtValue>
}

crate::extern_system_fn! {
	// Completion callback invoked by ONNX Runtime when an async run finishes.
	// `user_data` is the boxed `AsyncInferenceContext` leaked by the caller of
	// `RunAsync`; ownership is reclaimed here.
	pub(crate) fn async_callback(user_data: *mut c_void, _: *mut *mut ort_sys::OrtValue, _: ort_sys::size_t, status: *mut OrtStatus) {
		let ctx = unsafe { Box::from_raw(user_data.cast::<AsyncInferenceContext<'_>>()) };

		// Reconvert name ptrs to CString so drop impl is called and memory is freed
		drop(
			ctx.input_name_ptrs
				.into_iter()
				.chain(ctx.output_name_ptrs)
				.map(|p| {
					assert_non_null_pointer(p, "c_char for CString")?;
					unsafe { Ok(CString::from_raw(p.cast_mut().cast())) }
				})
				.collect::<Result<Vec<_>>>()
				.unwrap()
		);

		if let Err(e) = crate::error::status_to_result(status) {
			ctx.inner.emplace_value(Err(Error::SessionRun(e)));
			ctx.inner.wake();
			// Fixed: bail out after delivering the error. Previously execution
			// fell through, where `NonNull::new(..).expect(..)` on the (possibly
			// null) output pointers could panic inside an FFI callback and a
			// second `emplace_value` clobbered the error we just delivered.
			return;
		}

		let outputs: Vec<Value> = ctx
			.output_value_ptrs
			.into_iter()
			.map(|tensor_ptr| unsafe {
				Value::from_ptr(NonNull::new(tensor_ptr).expect("OrtValue ptr returned from session Run should not be null"), Some(Arc::clone(ctx.session_inner)))
			})
			.collect();

		ctx.inner.emplace_value(Ok(SessionOutputs::new(ctx.output_names.into_iter(), outputs)));
		ctx.inner.wake();
	}
}
Loading

0 comments on commit 979a591

Please sign in to comment.