-
-
Notifications
You must be signed in to change notification settings - Fork 113
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: `Session::run_async` * fix: set intra threads in doctests * ci(code-quality): cover doctests too * feat: make `InferenceFut` cancel safe * ci(code-quality): revert * minor cleanup
1 parent
3035b07
commit 979a591
Showing
6 changed files
with
409 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# Manifest for the async GPT-2 SSE API example.
[package]
publish = false # example crate; never published to crates.io
name = "example-async-gpt2-api"
version = "0.0.0"
edition = "2021"

[dependencies]
# `ort` from the repository root; `fetch-models` enables `commit_from_url`.
ort = { path = "../../", features = [ "fetch-models" ] }
ndarray = "0.15"
tokenizers = { version = ">=0.13.4", default-features = false, features = [ "onig" ] }
rand = "0.8"
tracing = "0.1"
tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] }
futures = "0.3"
headers = "0.4"
axum = { version = "0.7", features = [ "json" ] }
tokio = { version = "1.36", features = [ "full" ] }
tokio-stream = "0.1"
tower-http = { version = "0.5", features = ["fs", "trace"] }
anyhow = "1.0"
async-stream = "0.3"

[features]
# Pass-through feature flags forwarded to the `ort` dependency.
load-dynamic = [ "ort/load-dynamic" ]
cuda = [ "ort/cuda" ]
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
use std::{path::Path, sync::Arc}; | ||
|
||
use axum::{ | ||
extract::{FromRef, State}, | ||
response::{ | ||
sse::{Event, KeepAlive}, | ||
Sse | ||
}, | ||
routing::post, | ||
Router | ||
}; | ||
use futures::Stream; | ||
use ndarray::{array, concatenate, s, Array1, ArrayViewD, Axis}; | ||
use ort::{inputs, CUDAExecutionProvider, GraphOptimizationLevel, Session, Value}; | ||
use rand::Rng; | ||
use tokenizers::Tokenizer; | ||
use tokio::net::TcpListener; | ||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; | ||
|
||
#[tokio::main] | ||
async fn main() -> anyhow::Result<()> { | ||
// Initialize tracing to receive debug messages from `ort` | ||
tracing_subscriber::registry() | ||
.with(tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "info,ort=debug".into())) | ||
.with(tracing_subscriber::fmt::layer()) | ||
.init(); | ||
|
||
// Create the ONNX Runtime environment, enabling CUDA execution providers for all sessions created in this process. | ||
ort::init() | ||
.with_name("GPT-2") | ||
.with_execution_providers([CUDAExecutionProvider::default().build()]) | ||
.commit()?; | ||
|
||
// Load our model | ||
let session = Session::builder()? | ||
.with_optimization_level(GraphOptimizationLevel::Level1)? | ||
.with_intra_threads(4)? | ||
.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/gpt2.onnx")?; | ||
|
||
// Load the tokenizer and encode the prompt into a sequence of tokens. | ||
let tokenizer = Tokenizer::from_file(Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("tokenizer.json")).unwrap(); | ||
|
||
let app_state = AppState { | ||
session: Arc::new(session), | ||
tokenizer: Arc::new(tokenizer) | ||
}; | ||
|
||
let app = Router::new().route("/generate", post(generate)).with_state(app_state).into_make_service(); | ||
let listener = TcpListener::bind("127.0.0.1:7216").await?; | ||
tracing::info!("Listening on {}", listener.local_addr()?); | ||
|
||
axum::serve(listener, app).await?; | ||
|
||
Ok(()) | ||
} | ||
|
||
/// Shared application state handed to every axum request handler.
#[derive(Clone)]
struct AppState {
	// `Arc` so that cloning the state for each request shares one session/tokenizer.
	session: Arc<Session>,
	tokenizer: Arc<Tokenizer>
}
|
||
/// Autoregressively generates `gen_tokens` tokens starting from the `tokens`
/// prompt, yielding each decoded token as a server-sent event.
fn generate_stream(tokenizer: Arc<Tokenizer>, session: Arc<Session>, tokens: Vec<i64>, gen_tokens: usize) -> impl Stream<Item = ort::Result<Event>> + Send {
	async_stream::try_stream! {
		let mut tokens = Array1::from_iter(tokens.iter().cloned());

		for _ in 0..gen_tokens {
			// Two leading axes are inserted in front of the 1-D token sequence
			// (assumes the GPT-2 export's expected input layout — TODO confirm).
			let array = tokens.view().insert_axis(Axis(0)).insert_axis(Axis(1));
			// Run inference without blocking the tokio executor.
			let outputs = session.run_async(inputs![array]?)?.await?;
			let generated_tokens: ArrayViewD<f32> = outputs["output1"].extract_tensor()?;

			// Collect and sort logits
			// NOTE(review): these are raw logits, not normalized probabilities —
			// sufficient for the rank-based top-k sampling below.
			let probabilities = &mut generated_tokens
				.slice(s![0, 0, -1, ..]) // logits for the final sequence position only
				.insert_axis(Axis(0))
				.to_owned()
				.iter()
				.cloned()
				.enumerate()
				.collect::<Vec<_>>();
			// Descending by logit; NaNs fall to the end via the `Less` fallback.
			probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less));

			// Sample using top-k sampling
			let token = {
				let mut rng = rand::thread_rng();
				// Uniformly pick one of the 6 highest-logit token ids (k = 6).
				probabilities[rng.gen_range(0..=5)].0
			};
			// Append the sampled token so the next iteration sees it as context.
			tokens = concatenate![Axis(0), tokens, array![token.try_into().unwrap()]];

			let token_str = tokenizer.decode(&[token as _], true).unwrap();
			yield Event::default().data(token_str);
		}
	}
}
|
||
/// Lets axum handlers extract `State<Arc<Session>>` from the shared `AppState`.
impl FromRef<AppState> for Arc<Session> {
	fn from_ref(input: &AppState) -> Self {
		// Cheap refcount bump; the underlying session stays shared.
		Arc::clone(&input.session)
	}
}
/// Lets axum handlers extract `State<Arc<Tokenizer>>` from the shared `AppState`.
impl FromRef<AppState> for Arc<Tokenizer> {
	fn from_ref(input: &AppState) -> Self {
		// Cheap refcount bump; the underlying tokenizer stays shared.
		Arc::clone(&input.tokenizer)
	}
}
|
||
/// `POST /generate` — streams 50 generated tokens as server-sent events,
/// using token id 0 as the initial prompt.
async fn generate(State(session): State<Arc<Session>>, State(tokenizer): State<Arc<Tokenizer>>) -> Sse<impl Stream<Item = ort::Result<Event>>> {
	// Keep-alive pings stop proxies/clients from dropping the idle stream.
	Sse::new(generate_stream(tokenizer, session, vec![0], 50)).keep_alive(KeepAlive::new())
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
use std::{ | ||
cell::UnsafeCell, | ||
ffi::{c_char, CString}, | ||
future::Future, | ||
mem::MaybeUninit, | ||
pin::Pin, | ||
ptr::NonNull, | ||
sync::{ | ||
atomic::{AtomicUsize, Ordering}, | ||
Arc, Mutex | ||
}, | ||
task::{Context, Poll, Waker} | ||
}; | ||
|
||
use ort_sys::{c_void, OrtStatus}; | ||
|
||
use crate::{error::assert_non_null_pointer, Error, Result, RunOptions, SessionInputValue, SessionOutputs, SharedSessionInner, Value}; | ||
|
||
/// Outcome of polling the one-shot result slot in [`InferenceFutInner`].
pub(crate) enum InnerValue<T> {
	/// A value was produced and has been taken by this call.
	Present(T),
	/// No value yet; the inference run is still in flight.
	Pending,
	/// The channel was closed before a value arrived.
	Closed
}
|
||
// Bit flags packed into `InferenceFutInner::presence`.
const VALUE_PRESENT: usize = 1 << 0;
const CHANNEL_CLOSED: usize = 1 << 1;
|
||
/// One-shot slot shared between an `InferenceFut` and the ONNX Runtime
/// completion callback: holds the run result plus the waker to notify.
#[derive(Debug)]
pub(crate) struct InferenceFutInner<'s> {
	// Bitset of VALUE_PRESENT / CHANNEL_CLOSED; guards access to `value`.
	presence: AtomicUsize,
	// Only initialized while VALUE_PRESENT is set in `presence`.
	value: UnsafeCell<MaybeUninit<Result<SessionOutputs<'s>>>>,
	waker: Mutex<Option<Waker>>
}
|
||
impl<'s> InferenceFutInner<'s> {
	/// Creates an empty slot: no value, no waker, channel open.
	pub(crate) fn new() -> Self {
		InferenceFutInner {
			presence: AtomicUsize::new(0),
			waker: Mutex::new(None),
			value: UnsafeCell::new(MaybeUninit::uninit())
		}
	}

	/// Attempts to take the stored value, clearing VALUE_PRESENT in the process.
	pub(crate) fn try_take(&self) -> InnerValue<Result<SessionOutputs<'s>>> {
		// Atomically clear VALUE_PRESENT and learn whether it was set;
		// Acquire pairs with the Release in `emplace_value`.
		let state_snapshot = self.presence.fetch_and(!VALUE_PRESENT, Ordering::Acquire);
		if state_snapshot & VALUE_PRESENT == 0 {
			// No value available; distinguish "still running" from "closed".
			if self.presence.load(Ordering::Acquire) & CHANNEL_CLOSED != 0 {
				InnerValue::Closed
			} else {
				InnerValue::Pending
			}
		} else {
			// SAFETY: VALUE_PRESENT was set, so `value` is initialized, and we
			// just cleared the flag, making this read-out unique.
			InnerValue::Present(unsafe { (*self.value.get()).assume_init_read() })
		}
	}

	/// Stores the run result and publishes it by setting VALUE_PRESENT.
	pub(crate) fn emplace_value(&self, value: Result<SessionOutputs<'s>>) {
		unsafe { (*self.value.get()).write(value) };
		// Release makes the write above visible to the Acquire in `try_take`.
		self.presence.fetch_or(VALUE_PRESENT, Ordering::Release);
	}

	/// Registers (or clears, with `None`) the waker to notify on completion.
	pub(crate) fn set_waker(&self, waker: Option<&Waker>) {
		*self.waker.lock().unwrap() = waker.map(|c| c.to_owned());
	}

	/// Wakes the registered task, if any, consuming the stored waker.
	pub(crate) fn wake(&self) {
		if let Some(waker) = self.waker.lock().unwrap().take() {
			waker.wake();
		}
	}

	/// Marks the channel closed; returns `true` only for the first closer.
	pub(crate) fn close(&self) -> bool {
		self.presence.fetch_or(CHANNEL_CLOSED, Ordering::Acquire) & CHANNEL_CLOSED == 0
	}
}
|
||
impl<'s> Drop for InferenceFutInner<'s> {
	fn drop(&mut self) {
		// A value that was produced but never taken must be dropped here,
		// or its `SessionOutputs` would leak.
		if self.presence.load(Ordering::Acquire) & VALUE_PRESENT != 0 {
			unsafe { (*self.value.get()).assume_init_drop() };
		}
	}
}
|
||
// SAFETY: cross-thread access to `value` is mediated by the `presence` atomic
// (writer publishes with Release, reader takes with Acquire), and `waker` is
// Mutex-guarded; the UnsafeCell alone is what blocks the auto-impls.
unsafe impl<'s> Send for InferenceFutInner<'s> {}
unsafe impl<'s> Sync for InferenceFutInner<'s> {}
|
||
/// A future resolving to the outputs of an in-flight asynchronous session run.
pub struct InferenceFut<'s> {
	inner: Arc<InferenceFutInner<'s>>,
	// Held so the run can be terminated if this future is dropped early.
	run_options: Arc<RunOptions>,
	// Set once the result has been received; suppresses termination on drop.
	did_receive: bool
}
|
||
impl<'s> InferenceFut<'s> {
	/// Creates a future backed by the shared result slot and run options.
	pub(crate) fn new(inner: Arc<InferenceFutInner<'s>>, run_options: Arc<RunOptions>) -> Self {
		Self {
			inner,
			run_options,
			did_receive: false
		}
	}
}
|
||
impl<'s> Future for InferenceFut<'s> { | ||
type Output = Result<SessionOutputs<'s>>; | ||
|
||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { | ||
let this = Pin::into_inner(self); | ||
|
||
match this.inner.try_take() { | ||
InnerValue::Present(v) => { | ||
this.did_receive = true; | ||
return Poll::Ready(v); | ||
} | ||
InnerValue::Pending => {} | ||
InnerValue::Closed => panic!() | ||
}; | ||
|
||
this.inner.set_waker(Some(cx.waker())); | ||
|
||
Poll::Pending | ||
} | ||
} | ||
|
||
impl<'s> Drop for InferenceFut<'s> {
	fn drop(&mut self) {
		// Cancel safety: if dropped before the result arrived, close the
		// channel; only the first closer asks ONNX Runtime to terminate the
		// in-flight run and discards the stale waker.
		if !self.did_receive && self.inner.close() {
			let _ = self.run_options.terminate();
			self.inner.set_waker(None);
		}
	}
}
|
||
/// Heap-allocated state passed through the C `user_data` pointer to ONNX
/// Runtime's async completion callback, which reclaims and frees it.
pub(crate) struct AsyncInferenceContext<'s> {
	pub(crate) inner: Arc<InferenceFutInner<'s>>,
	// Held only to keep the input values alive for the duration of the run.
	pub(crate) _input_values: Vec<SessionInputValue<'s>>,
	pub(crate) input_ort_values: Vec<*const ort_sys::OrtValue>,
	// Leaked CString pointers; reconstituted and freed in `async_callback`.
	pub(crate) input_name_ptrs: Vec<*const c_char>,
	pub(crate) output_name_ptrs: Vec<*const c_char>,
	pub(crate) session_inner: &'s Arc<SharedSessionInner>,
	pub(crate) output_names: Vec<&'s str>,
	pub(crate) output_value_ptrs: Vec<*mut ort_sys::OrtValue>
}
|
||
crate::extern_system_fn! { | ||
pub(crate) fn async_callback(user_data: *mut c_void, _: *mut *mut ort_sys::OrtValue, _: ort_sys::size_t, status: *mut OrtStatus) { | ||
let ctx = unsafe { Box::from_raw(user_data.cast::<AsyncInferenceContext<'_>>()) }; | ||
|
||
// Reconvert name ptrs to CString so drop impl is called and memory is freed | ||
drop( | ||
ctx.input_name_ptrs | ||
.into_iter() | ||
.chain(ctx.output_name_ptrs) | ||
.map(|p| { | ||
assert_non_null_pointer(p, "c_char for CString")?; | ||
unsafe { Ok(CString::from_raw(p.cast_mut().cast())) } | ||
}) | ||
.collect::<Result<Vec<_>>>() | ||
.unwrap() | ||
); | ||
|
||
if let Err(e) = crate::error::status_to_result(status) { | ||
ctx.inner.emplace_value(Err(Error::SessionRun(e))); | ||
ctx.inner.wake(); | ||
} | ||
|
||
let outputs: Vec<Value> = ctx | ||
.output_value_ptrs | ||
.into_iter() | ||
.map(|tensor_ptr| unsafe { | ||
Value::from_ptr(NonNull::new(tensor_ptr).expect("OrtValue ptr returned from session Run should not be null"), Some(Arc::clone(ctx.session_inner))) | ||
}) | ||
.collect(); | ||
|
||
ctx.inner.emplace_value(Ok(SessionOutputs::new(ctx.output_names.into_iter(), outputs))); | ||
ctx.inner.wake(); | ||
} | ||
} |
Oops, something went wrong.