diff --git a/examples/gpt2/examples/gpt2.rs b/examples/gpt2/examples/gpt2.rs
index 7071bd22..c976284a 100644
--- a/examples/gpt2/examples/gpt2.rs
+++ b/examples/gpt2/examples/gpt2.rs
@@ -9,15 +9,24 @@ use rand::Rng;
 use tokenizers::Tokenizer;
 
 const PROMPT: &str = "The corsac fox (Vulpes corsac), also known simply as a corsac, is a medium-sized fox found in";
+/// Maximum number of tokens to generate.
 const GEN_TOKENS: i32 = 90;
+/// Top-k sampling: sample from the k most likely next tokens at each step. A lower k focuses on higher-probability tokens.
 const TOP_K: usize = 5;
 
+/// GPT-2 text generation
+///
+/// This program demonstrates text generation with the GPT-2 language model using `ort`.
+/// It initializes the model, tokenizes a prompt, and generates a sequence of tokens,
+/// using top-k sampling for diverse yet contextually relevant output.
 fn main() -> ort::Result<()> {
+	// Initialize tracing to receive debug messages from `ort`.
 	tracing_subscriber::fmt::init();
 
 	let mut stdout = io::stdout();
 	let mut rng = rand::thread_rng();
 
+	// Create the ONNX Runtime environment and session for the GPT-2 model.
 	let environment = Environment::builder()
 		.with_name("GPT-2")
 		.with_execution_providers([CUDAExecutionProvider::default().build()])
@@ -29,6 +38,7 @@ fn main() -> ort::Result<()> {
 		.with_intra_threads(1)?
 		.with_model_downloaded(GPT2::GPT2LmHead)?;
 
+	// Load the tokenizer and encode the prompt into a sequence of tokens.
 	let tokenizer = Tokenizer::from_file(Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("tokenizer.json")).unwrap();
 	let tokens = tokenizer.encode(PROMPT, false).unwrap();
 	let tokens = tokens.get_ids().iter().map(|i| *i as i64).collect::<Vec<i64>>();
@@ -44,6 +54,7 @@ fn main() -> ort::Result<()> {
 		let generated_tokens: Tensor<f32> = outputs["output1"].extract_tensor()?;
 		let generated_tokens = generated_tokens.view();
 
+		// Collect the logits for the last position and sort them by probability.
 		let probabilities = &mut generated_tokens
 			.slice(s![0, 0, -1, ..])
 			.insert_axis(Axis(0))
@@ -54,6 +65,7 @@ fn main() -> ort::Result<()> {
 			.collect::<Vec<_>>();
 		probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less));
 
+		// Sample the next token using top-k sampling.
 		let token = probabilities[rng.gen_range(0..=TOP_K)].0;
 		tokens = concatenate![Axis(0), tokens, array![token.try_into().unwrap()]];
diff --git a/ort-sys/build.rs b/ort-sys/build.rs
index aac4fab1..6132c9f7 100644
--- a/ort-sys/build.rs
+++ b/ort-sys/build.rs
@@ -259,7 +259,7 @@ fn extract_zip(filename: &Path, outpath: &Path) {
 
 fn copy_libraries(lib_dir: &Path, out_dir: &Path) {
 	// get the target directory - we need to place the dlls next to the executable so they can be properly loaded by windows
-	let out_dir = out_dir.parent().unwrap().parent().unwrap().parent().unwrap();
+	let out_dir = out_dir.ancestors().nth(3).unwrap();
 
 	let lib_files = fs::read_dir(lib_dir).unwrap();
 	for lib_file in lib_files.filter(|e| {
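
Note on the ort-sys/build.rs change: `Path::ancestors()` yields the path itself first and then each successive parent, so `.ancestors().nth(3)` resolves to the same directory as the three chained `.parent().unwrap()` calls it replaces. Below is a minimal sketch of that equivalence; the OUT_DIR-shaped path is hypothetical and not taken from the actual build script.

use std::path::Path;

fn main() {
	// Hypothetical OUT_DIR-style path, for illustration only.
	let out_dir = Path::new("target/debug/build/ort-sys-abc123/out");

	// `ancestors()` yields `out_dir` itself first, then each parent in turn,
	// so `nth(3)` is the third parent of `out_dir`.
	let via_ancestors = out_dir.ancestors().nth(3).unwrap();
	let via_parents = out_dir.parent().unwrap().parent().unwrap().parent().unwrap();

	assert_eq!(via_ancestors, via_parents); // both are "target/debug"
}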