refactor: create environment in global OnceLock
decahedron1 committed Nov 21, 2023
1 parent 9b1bc6a commit c69064f
Showing 14 changed files with 124 additions and 442 deletions.
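The commit title names the mechanism: rather than building an `Environment`, wrapping it in an `Arc`, and handing it to every `SessionBuilder`, the environment now lives in a process-wide `OnceLock` that `ort::init()` populates. A minimal sketch of that pattern follows; the type and function names are illustrative stand-ins, not `ort`'s actual internals.

use std::sync::{Arc, OnceLock};

// Illustrative stand-in for ort's environment; the real type carries
// execution-provider and logging configuration.
struct Environment;

// Global slot, written at most once for the lifetime of the process.
static ENVIRONMENT: OnceLock<Arc<Environment>> = OnceLock::new();

// Roughly what an `init().commit()`-style API can do: store the configured
// environment on the first call and return the shared handle.
fn commit(env: Environment) -> &'static Arc<Environment> {
	ENVIRONMENT.get_or_init(|| Arc::new(env))
}

// Later callers (e.g. session builders) fetch the same handle, falling back
// to a default environment if nothing was committed explicitly.
fn environment() -> &'static Arc<Environment> {
	ENVIRONMENT.get_or_init(|| Arc::new(Environment))
}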
18 changes: 9 additions & 9 deletions examples/gpt2/examples/gpt2.rs
@@ -4,7 +4,7 @@ use std::{
 };
 
 use ndarray::{array, concatenate, s, Array1, Axis};
-use ort::{download::language::machine_comprehension::GPT2, inputs, CUDAExecutionProvider, Environment, GraphOptimizationLevel, SessionBuilder, Tensor};
+use ort::{download::language::machine_comprehension::GPT2, inputs, CUDAExecutionProvider, GraphOptimizationLevel, Session, Tensor};
 use rand::Rng;
 use tokenizers::Tokenizer;
 
@@ -23,17 +23,17 @@ fn main() -> ort::Result<()> {
 	// Initialize tracing to receive debug messages from `ort`
 	tracing_subscriber::fmt::init();
 
-	let mut stdout = io::stdout();
-	let mut rng = rand::thread_rng();
-
-	// Create the ONNX Runtime environment and session for the GPT-2 model.
-	let environment = Environment::builder()
+	// Create the ONNX Runtime environment, enabling CUDA execution providers for all sessions created in this process.
+	ort::init()
 		.with_name("GPT-2")
 		.with_execution_providers([CUDAExecutionProvider::default().build()])
-		.build()?
-		.into_arc();
+		.commit()?;
+
+	let mut stdout = io::stdout();
+	let mut rng = rand::thread_rng();
 
-	let session = SessionBuilder::new(&environment)?
+	// Load our model
+	let session = Session::builder()?
 		.with_optimization_level(GraphOptimizationLevel::Level1)?
 		.with_intra_threads(1)?
 		.with_model_downloaded(GPT2::GPT2LmHead)?;
14 changes: 7 additions & 7 deletions examples/yolov8/examples/yolov8.rs
@@ -4,7 +4,7 @@ use std::path::Path;
 
 use image::{imageops::FilterType, GenericImageView};
 use ndarray::{s, Array, Axis};
-use ort::{inputs, CUDAExecutionProvider, Environment, SessionBuilder, SessionOutputs};
+use ort::{inputs, CUDAExecutionProvider, Session, SessionOutputs};
 use raqote::{DrawOptions, DrawTarget, LineJoin, PathBuilder, SolidSource, Source, StrokeStyle};
 use show_image::{event, AsImageView, WindowOptions};
 
@@ -42,6 +42,10 @@ const YOLOV8_CLASS_LABELS: [&str; 80] = [
 fn main() -> ort::Result<()> {
 	tracing_subscriber::fmt::init();
 
+	ort::init()
+		.with_execution_providers([CUDAExecutionProvider::default().build()])
+		.commit()?;
+
 	let original_img = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("baseball.jpg")).unwrap();
 	let (img_width, img_height) = (original_img.width(), original_img.height());
 	let img = original_img.resize_exact(640, 640, FilterType::CatmullRom);
@@ -55,14 +59,10 @@ fn main() -> ort::Result<()> {
 		input[[0, 2, y, x]] = (b as f32) / 255.;
 	}
 
-	let env = Environment::builder()
-		.with_execution_providers([CUDAExecutionProvider::default().build()])
-		.build()?
-		.into_arc();
-	let model = SessionBuilder::new(&env).unwrap().with_model_downloaded(YOLOV8M_URL).unwrap();
+	let model = Session::builder()?.with_model_downloaded(YOLOV8M_URL)?;
 
 	// Run YOLOv8 inference
-	let outputs: SessionOutputs = model.run(inputs!["images" => input.view()]?).unwrap();
+	let outputs: SessionOutputs = model.run(inputs!["images" => input.view()]?)?;
 	let output = outputs["output0"].extract_tensor::<f32>().unwrap().view().t().into_owned();
 
 	let mut boxes = Vec::new();
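Taken together, the two examples show the post-refactor usage pattern: one `ort::init(...).commit()` call configures the process-wide environment, and every later `Session::builder()` picks it up without an environment handle being passed around. A minimal sketch combining calls that appear in the diffs above; the model URLs are placeholders, not files from this repository.

use ort::{CUDAExecutionProvider, GraphOptimizationLevel, Session};

fn main() -> ort::Result<()> {
	// Configure the shared environment once per process.
	ort::init()
		.with_execution_providers([CUDAExecutionProvider::default().build()])
		.commit()?;

	// Both sessions reuse the environment committed above.
	// Placeholder URLs; substitute real models.
	let detector = Session::builder()?
		.with_optimization_level(GraphOptimizationLevel::Level1)?
		.with_model_downloaded("https://example.com/detector.onnx")?;
	let classifier = Session::builder()?
		.with_intra_threads(1)?
		.with_model_downloaded("https://example.com/classifier.onnx")?;

	let _ = (detector, classifier);
	Ok(())
}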
(Diffs for the remaining 12 changed files are not shown.)

0 comments on commit c69064f
