Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Session::run_async #174

Merged
merged 6 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[workspace]
members = [
'ort-sys',
'examples/async-gpt2-api',
'examples/custom-ops',
'examples/gpt2',
'examples/model-info',
Expand All @@ -9,6 +10,7 @@ members = [
]
default-members = [
'.',
'examples/async-gpt2-api',
'examples/custom-ops',
'examples/gpt2',
'examples/model-info',
Expand Down Expand Up @@ -91,10 +93,12 @@ winapi = { version = "0.3", optional = true, features = [ "std", "libloaderapi"
[dev-dependencies]
anyhow = "1.0"
ureq = "2.1"
image = "0.24"
image = "0.25"
test-log = { version = "0.2", default-features = false, features = [ "trace" ] }
tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] }
glassbench = "0.4"
tokio = { version = "1.36", features = [ "test-util" ] }
tokio-test = "0.4.3"

[[bench]]
name = "squeezenet"
Expand Down
25 changes: 25 additions & 0 deletions examples/async-gpt2-api/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Example crate: HTTP server streaming GPT-2 generations via `Session::run_async`.
[package]
publish = false
name = "example-async-gpt2-api"
version = "0.0.0"
edition = "2021"

[dependencies]
# Parent `ort` crate; `fetch-models` allows downloading the model at runtime.
ort = { path = "../../", features = [ "fetch-models" ] }
ndarray = "0.15"
# Tokenizer with the pure-Rust `onig` regex backend (no system oniguruma needed).
tokenizers = { version = ">=0.13.4", default-features = false, features = [ "onig" ] }
rand = "0.8"
tracing = "0.1"
tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] }
futures = "0.3"
headers = "0.4"
# Web framework + async runtime for the SSE endpoint.
axum = { version = "0.7", features = [ "json" ] }
tokio = { version = "1.36", features = [ "full" ] }
tokio-stream = "0.1"
tower-http = { version = "0.5", features = ["fs", "trace"] }
anyhow = "1.0"
async-stream = "0.3"

[features]
# Pass-through features so the example can exercise ort's dynamic loading / CUDA.
load-dynamic = [ "ort/load-dynamic" ]
cuda = [ "ort/cuda" ]
1 change: 1 addition & 0 deletions examples/async-gpt2-api/data/tokenizer.json

Large diffs are not rendered by default.

109 changes: 109 additions & 0 deletions examples/async-gpt2-api/examples/async-gpt2-api.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
use std::{path::Path, sync::Arc};

use axum::{
extract::{FromRef, State},
response::{
sse::{Event, KeepAlive},
Sse
},
routing::post,
Router
};
use futures::Stream;
use ndarray::{array, concatenate, s, Array1, ArrayViewD, Axis};
use ort::{inputs, CUDAExecutionProvider, GraphOptimizationLevel, Session, Value};
use rand::Rng;
use tokenizers::Tokenizer;
use tokio::net::TcpListener;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
	// Set up tracing first so debug output from `ort` itself is visible;
	// RUST_LOG overrides the default "info,ort=debug" filter.
	tracing_subscriber::registry()
		.with(tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "info,ort=debug".into()))
		.with(tracing_subscriber::fmt::layer())
		.init();

	// Configure the process-wide ONNX Runtime environment; CUDA is registered
	// for every session created afterwards.
	ort::init()
		.with_name("GPT-2")
		.with_execution_providers([CUDAExecutionProvider::default().build()])
		.commit()?;

	// Fetch (or reuse a cached copy of) the GPT-2 model and build the session.
	let session = Session::builder()?
		.with_optimization_level(GraphOptimizationLevel::Level1)?
		.with_intra_threads(4)?
		.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/gpt2.onnx")?;

	// The tokenizer definition ships alongside the example sources.
	let tokenizer = Tokenizer::from_file(Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("tokenizer.json")).unwrap();

	// Both resources are shared across handler invocations via `Arc`s.
	let state = AppState {
		session: Arc::new(session),
		tokenizer: Arc::new(tokenizer)
	};

	let router = Router::new().route("/generate", post(generate)).with_state(state);
	let listener = TcpListener::bind("127.0.0.1:7216").await?;
	tracing::info!("Listening on {}", listener.local_addr()?);

	axum::serve(listener, router.into_make_service()).await?;

	Ok(())
}

/// Shared application state handed to every axum handler; cloning is cheap
/// since both fields are `Arc`s.
#[derive(Clone)]
struct AppState {
	// ONNX Runtime session holding the GPT-2 model.
	session: Arc<Session>,
	// Tokenizer used to decode sampled token ids back into text.
	tokenizer: Arc<Tokenizer>
}

/// Builds an SSE event stream that autoregressively samples up to `gen_tokens`
/// tokens from the model, starting from the prompt `tokens`, and yields each
/// decoded token as its own event.
fn generate_stream(tokenizer: Arc<Tokenizer>, session: Arc<Session>, tokens: Vec<i64>, gen_tokens: usize) -> impl Stream<Item = ort::Result<Event>> + Send {
	async_stream::try_stream! {
		// Owned 1-D buffer of token ids; grows by one sampled token per step.
		let mut tokens = Array1::from_iter(tokens.iter().cloned());

		for _ in 0..gen_tokens {
			// Add two leading axes before feeding the sequence to the model —
			// NOTE(review): presumably [batch, beam, seq]; confirm against the GPT-2 export.
			let array = tokens.view().insert_axis(Axis(0)).insert_axis(Axis(1));
			// `run_async` keeps the executor free while ONNX Runtime executes.
			let outputs = session.run_async(inputs![array]?)?.await?;
			let generated_tokens: ArrayViewD<f32> = outputs["output1"].extract_tensor()?;

			// Collect logits of the final sequence position and sort descending;
			// incomparable values (NaN) fall back to `Less`.
			let probabilities = &mut generated_tokens
				.slice(s![0, 0, -1, ..])
				.insert_axis(Axis(0))
				.to_owned()
				.iter()
				.cloned()
				.enumerate()
				.collect::<Vec<_>>();
			probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less));

			// Sample using top-k sampling: pick uniformly among the 6 most likely tokens.
			let token = {
				let mut rng = rand::thread_rng();
				probabilities[rng.gen_range(0..=5)].0
			};
			tokens = concatenate![Axis(0), tokens, array![token.try_into().unwrap()]];

			// Decode only the newly sampled token and emit it as SSE data.
			let token_str = tokenizer.decode(&[token as _], true).unwrap();
			yield Event::default().data(token_str);
		}
	}
}

// Lets axum handlers extract `State<Arc<Session>>` directly from `AppState`.
impl FromRef<AppState> for Arc<Session> {
	fn from_ref(input: &AppState) -> Self {
		Arc::clone(&input.session)
	}
}
// Lets axum handlers extract `State<Arc<Tokenizer>>` directly from `AppState`.
impl FromRef<AppState> for Arc<Tokenizer> {
	fn from_ref(input: &AppState) -> Self {
		Arc::clone(&input.tokenizer)
	}
}

/// `POST /generate` — streams up to 50 sampled tokens as server-sent events,
/// starting generation from the fixed prompt token `0`.
async fn generate(State(session): State<Arc<Session>>, State(tokenizer): State<Arc<Tokenizer>>) -> Sse<impl Stream<Item = ort::Result<Event>>> {
	let stream = generate_stream(tokenizer, session, vec![0], 50);
	Sse::new(stream).keep_alive(KeepAlive::new())
}
178 changes: 178 additions & 0 deletions src/session/async.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
use std::{
cell::UnsafeCell,
ffi::{c_char, CString},
future::Future,
mem::MaybeUninit,
pin::Pin,
ptr::NonNull,
sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex
},
task::{Context, Poll, Waker}
};

use ort_sys::{c_void, OrtStatus};

use crate::{error::assert_non_null_pointer, Error, Result, RunOptions, SessionInputValue, SessionOutputs, SharedSessionInner, Value};

/// Tri-state outcome of polling the shared result slot.
pub(crate) enum InnerValue<T> {
	/// A value was produced and has now been taken by the caller.
	Present(T),
	/// No value yet; the caller should register a waker and wait.
	Pending,
	/// The channel was closed (see `InferenceFutInner::close`); no value is coming.
	Closed
}

/// Bit flag in `presence`: the result slot holds an initialized value.
const VALUE_PRESENT: usize = 1 << 0;
/// Bit flag in `presence`: the channel was closed by the consumer.
const CHANNEL_CLOSED: usize = 1 << 1;

/// One-shot channel shared between the inference-completion callback (producer)
/// and the `InferenceFut` future (consumer).
#[derive(Debug)]
pub(crate) struct InferenceFutInner<'s> {
	// Atomic bitfield of VALUE_PRESENT / CHANNEL_CLOSED; coordinates access to `value`.
	presence: AtomicUsize,
	// Result slot: written once by the producer, read at most once by the consumer.
	value: UnsafeCell<MaybeUninit<Result<SessionOutputs<'s>>>>,
	// Waker of the task currently awaiting the result, if any.
	waker: Mutex<Option<Waker>>
}

impl<'s> InferenceFutInner<'s> {
pub(crate) fn new() -> Self {

Check warning on line 36 in src/session/async.rs

View check run for this annotation

Codecov / codecov/patch

src/session/async.rs#L36

Added line #L36 was not covered by tests
InferenceFutInner {
presence: AtomicUsize::new(0),
waker: Mutex::new(None),
value: UnsafeCell::new(MaybeUninit::uninit())

Check warning on line 40 in src/session/async.rs

View check run for this annotation

Codecov / codecov/patch

src/session/async.rs#L38-L40

Added lines #L38 - L40 were not covered by tests
}
}

pub(crate) fn try_take(&self) -> InnerValue<Result<SessionOutputs<'s>>> {
let state_snapshot = self.presence.fetch_and(!VALUE_PRESENT, Ordering::Acquire);
if state_snapshot & VALUE_PRESENT == 0 {
if self.presence.load(Ordering::Acquire) & CHANNEL_CLOSED != 0 {
InnerValue::Closed

Check warning on line 48 in src/session/async.rs

View check run for this annotation

Codecov / codecov/patch

src/session/async.rs#L44-L48

Added lines #L44 - L48 were not covered by tests
} else {
InnerValue::Pending

Check warning on line 50 in src/session/async.rs

View check run for this annotation

Codecov / codecov/patch

src/session/async.rs#L50

Added line #L50 was not covered by tests
}
} else {
InnerValue::Present(unsafe { (*self.value.get()).assume_init_read() })

Check warning on line 53 in src/session/async.rs

View check run for this annotation

Codecov / codecov/patch

src/session/async.rs#L53

Added line #L53 was not covered by tests
}
}

pub(crate) fn emplace_value(&self, value: Result<SessionOutputs<'s>>) {
unsafe { (*self.value.get()).write(value) };
self.presence.fetch_or(VALUE_PRESENT, Ordering::Release);

Check warning on line 59 in src/session/async.rs

View check run for this annotation

Codecov / codecov/patch

src/session/async.rs#L57-L59

Added lines #L57 - L59 were not covered by tests
}

pub(crate) fn set_waker(&self, waker: Option<&Waker>) {
*self.waker.lock().unwrap() = waker.map(|c| c.to_owned());

Check warning on line 63 in src/session/async.rs

View check run for this annotation

Codecov / codecov/patch

src/session/async.rs#L62-L63

Added lines #L62 - L63 were not covered by tests
}

pub(crate) fn wake(&self) {
if let Some(waker) = self.waker.lock().unwrap().take() {
waker.wake();

Check warning on line 68 in src/session/async.rs

View check run for this annotation

Codecov / codecov/patch

src/session/async.rs#L66-L68

Added lines #L66 - L68 were not covered by tests
}
}

pub(crate) fn close(&self) -> bool {
self.presence.fetch_or(CHANNEL_CLOSED, Ordering::Acquire) & CHANNEL_CLOSED == 0

Check warning on line 73 in src/session/async.rs

View check run for this annotation

Codecov / codecov/patch

src/session/async.rs#L72-L73

Added lines #L72 - L73 were not covered by tests
}
}

impl<'s> Drop for InferenceFutInner<'s> {
	fn drop(&mut self) {
		// A value that was produced but never taken lives inside a MaybeUninit,
		// so it must be dropped manually to avoid leaking the SessionOutputs.
		let state = self.presence.load(Ordering::Acquire);
		if state & VALUE_PRESENT != 0 {
			// SAFETY: VALUE_PRESENT implies the slot was initialized and not yet read.
			unsafe { (*self.value.get()).assume_init_drop() };
		}
	}
}

// SAFETY: cross-thread access to `value` is coordinated through the `presence`
// atomic (Release on write, Acquire on read) and `waker` sits behind a Mutex.
// NOTE(review): soundness also relies on the contained SessionOutputs being
// safe to move across threads — confirm.
unsafe impl<'s> Send for InferenceFutInner<'s> {}
unsafe impl<'s> Sync for InferenceFutInner<'s> {}

/// Future returned by `Session::run_async`; resolves once the ONNX Runtime
/// completion callback delivers the session outputs.
pub struct InferenceFut<'s> {
	// Shared one-shot slot, also referenced by the in-flight callback context.
	inner: Arc<InferenceFutInner<'s>>,
	// Kept so the run can be terminated if this future is dropped before completion.
	run_options: Arc<RunOptions>,
	// Set once the value has been taken, so Drop does not cancel a finished run.
	did_receive: bool
}

impl<'s> InferenceFut<'s> {
	/// Wraps the shared result slot and the run's options into a pollable future.
	pub(crate) fn new(inner: Arc<InferenceFutInner<'s>>, run_options: Arc<RunOptions>) -> Self {
		Self { inner, run_options, did_receive: false }
	}
}

impl<'s> Future for InferenceFut<'s> {
type Output = Result<SessionOutputs<'s>>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = Pin::into_inner(self);

Check warning on line 108 in src/session/async.rs

View check run for this annotation

Codecov / codecov/patch

src/session/async.rs#L107-L108

Added lines #L107 - L108 were not covered by tests

match this.inner.try_take() {
InnerValue::Present(v) => {
this.did_receive = true;
return Poll::Ready(v);

Check warning on line 113 in src/session/async.rs

View check run for this annotation

Codecov / codecov/patch

src/session/async.rs#L110-L113

Added lines #L110 - L113 were not covered by tests
}
InnerValue::Pending => {}
InnerValue::Closed => panic!()

Check warning on line 116 in src/session/async.rs

View check run for this annotation

Codecov / codecov/patch

src/session/async.rs#L115-L116

Added lines #L115 - L116 were not covered by tests
};

this.inner.set_waker(Some(cx.waker()));

Check warning on line 119 in src/session/async.rs

View check run for this annotation

Codecov / codecov/patch

src/session/async.rs#L119

Added line #L119 was not covered by tests

Poll::Pending

Check warning on line 121 in src/session/async.rs

View check run for this annotation

Codecov / codecov/patch

src/session/async.rs#L121

Added line #L121 was not covered by tests
}
}

impl<'s> Drop for InferenceFut<'s> {
fn drop(&mut self) {
if !self.did_receive && self.inner.close() {
let _ = self.run_options.terminate();
self.inner.set_waker(None);

Check warning on line 129 in src/session/async.rs

View check run for this annotation

Codecov / codecov/patch

src/session/async.rs#L126-L129

Added lines #L126 - L129 were not covered by tests
}
}
}

/// Heap-allocated context passed as `user_data` to the ONNX Runtime async-run
/// callback; keeps inputs and name strings alive until the callback frees them.
pub(crate) struct AsyncInferenceContext<'s> {
	// Shared one-shot slot the callback writes the result into.
	pub(crate) inner: Arc<InferenceFutInner<'s>>,
	// Held only to keep the input values alive for the duration of the run.
	pub(crate) _input_values: Vec<SessionInputValue<'s>>,
	pub(crate) input_ort_values: Vec<*const ort_sys::OrtValue>,
	// Raw CString pointers; reconstituted and dropped inside `async_callback`.
	pub(crate) input_name_ptrs: Vec<*const c_char>,
	pub(crate) output_name_ptrs: Vec<*const c_char>,
	// Session handle used to tie output Values' lifetime to the session.
	pub(crate) session_inner: &'s Arc<SharedSessionInner>,
	pub(crate) output_names: Vec<&'s str>,
	pub(crate) output_value_ptrs: Vec<*mut ort_sys::OrtValue>
}

crate::extern_system_fn! {
pub(crate) fn async_callback(user_data: *mut c_void, _: *mut *mut ort_sys::OrtValue, _: ort_sys::size_t, status: *mut OrtStatus) {
let ctx = unsafe { Box::from_raw(user_data.cast::<AsyncInferenceContext<'_>>()) };

Check warning on line 147 in src/session/async.rs

View check run for this annotation

Codecov / codecov/patch

src/session/async.rs#L147

Added line #L147 was not covered by tests

// Reconvert name ptrs to CString so drop impl is called and memory is freed
drop(
ctx.input_name_ptrs

Check warning on line 151 in src/session/async.rs

View check run for this annotation

Codecov / codecov/patch

src/session/async.rs#L150-L151

Added lines #L150 - L151 were not covered by tests
.into_iter()
.chain(ctx.output_name_ptrs)
.map(|p| {
assert_non_null_pointer(p, "c_char for CString")?;
unsafe { Ok(CString::from_raw(p.cast_mut().cast())) }

Check warning on line 156 in src/session/async.rs

View check run for this annotation

Codecov / codecov/patch

src/session/async.rs#L153-L156

Added lines #L153 - L156 were not covered by tests
})
.collect::<Result<Vec<_>>>()
.unwrap()
);

if let Err(e) = crate::error::status_to_result(status) {
ctx.inner.emplace_value(Err(Error::SessionRun(e)));
ctx.inner.wake();

Check warning on line 164 in src/session/async.rs

View check run for this annotation

Codecov / codecov/patch

src/session/async.rs#L162-L164

Added lines #L162 - L164 were not covered by tests
}

let outputs: Vec<Value> = ctx

Check warning on line 167 in src/session/async.rs

View check run for this annotation

Codecov / codecov/patch

src/session/async.rs#L167

Added line #L167 was not covered by tests
.output_value_ptrs
.into_iter()
.map(|tensor_ptr| unsafe {
Value::from_ptr(NonNull::new(tensor_ptr).expect("OrtValue ptr returned from session Run should not be null"), Some(Arc::clone(ctx.session_inner)))

Check warning on line 171 in src/session/async.rs

View check run for this annotation

Codecov / codecov/patch

src/session/async.rs#L170-L171

Added lines #L170 - L171 were not covered by tests
})
.collect();

Check warning on line 173 in src/session/async.rs

View check run for this annotation

Codecov / codecov/patch

src/session/async.rs#L173

Added line #L173 was not covered by tests

ctx.inner.emplace_value(Ok(SessionOutputs::new(ctx.output_names.into_iter(), outputs)));
ctx.inner.wake();

Check warning on line 176 in src/session/async.rs

View check run for this annotation

Codecov / codecov/patch

src/session/async.rs#L175-L176

Added lines #L175 - L176 were not covered by tests
}
}
Loading
Loading