Merge branch 'v2' of https://github.com/JewishLewish/ort into v2
JewishLewish committed Nov 13, 2023
2 parents 9a73279 + 0f4eaaa commit 1559569
Showing 2 changed files with 7 additions and 66 deletions.
31 changes: 0 additions & 31 deletions examples/Readme.md

This file was deleted.

42 changes: 7 additions & 35 deletions examples/gpt2/examples/gpt2.rs
@@ -8,56 +8,25 @@ use ort::{download::language::machine_comprehension::GPT2, inputs, CUDAExecution
use rand::Rng;
use tokenizers::Tokenizer;

/// Prompt
const PROMPT: &str = "The corsac fox (Vulpes corsac), also known simply as a corsac, is a medium-sized fox found in";
-/// Max Tokens to Generate
+/// Max tokens to generate
const GEN_TOKENS: i32 = 90;
/// Top-k sampling: sample from the k most likely next tokens at each step. A lower k focuses on higher-probability tokens.
const TOP_K: usize = 5;
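
For illustration, a minimal self-contained sketch of the selection step that `TOP_K` controls; the helper name and the plain `f32` slice are assumptions for this example, not code from the commit:

```rust
/// Illustrative helper: return the indices of the `k` largest scores, best first.
fn top_k_indices(scores: &[f32], k: usize) -> Vec<usize> {
    let mut indexed: Vec<(usize, f32)> = scores.iter().copied().enumerate().collect();
    // Sort descending by score; ties and NaNs are ordered arbitrarily.
    indexed.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less));
    indexed.truncate(k);
    indexed.into_iter().map(|(i, _)| i).collect()
}
```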

/// GPT-2 Text Generation
///
-/// This Rust program demonstrates text generation using the GPT-2 language model with the ONNX Runtime.
+/// This Rust program demonstrates text generation using the GPT-2 language model with `ort`.
/// The program initializes the model, tokenizes a prompt, and generates a sequence of tokens.
/// It utilizes top-k sampling for diverse and contextually relevant text generation.
///
/// # Constants
/// - `PROMPT`: The initial prompt for text generation.
/// - `GEN_TOKENS`: The maximum number of tokens to generate.
/// - `TOP_K`: Parameter for top-k sampling, influencing the diversity of generated text.
///
/// # Usage
/// Ensure that the required dependencies are installed, then run the example to generate text with the GPT-2 model.
///
/// # Main Function
/// The main function initializes dependencies, loads the GPT-2 model, tokenizes the prompt,
/// and iteratively generates text based on the model's output probabilities.
///
/// ## Steps
/// 1. Initialize tracing, stdout, and the random number generator.
/// 2. Create the ONNX Runtime environment and session for the GPT-2 model.
/// 3. Load the tokenizer and encode the prompt into a sequence of tokens.
/// 4. Iteratively generate tokens using the GPT-2 model and top-k sampling.
/// 5. Print the generated text to the console.
///
/// # Panics
/// The program panics if there is an issue with the ONNX Runtime or tokenizer.
///
/// # Errors
/// Returns an `ort::Result` indicating success or an error during ONNX Runtime execution.
///
/// # Examples
/// ```rust
/// fn main() -> ort::Result<()> {
/// // ... (see the main function for the complete example)
/// }
/// ```
fn main() -> ort::Result<()> {
// Initialize tracing to receive debug messages from `ort`
tracing_subscriber::fmt::init();

let mut stdout = io::stdout();
let mut rng = rand::thread_rng();

// Create the ONNX Runtime environment and session for the GPT-2 model.
let environment = Environment::builder()
.with_name("GPT-2")
.with_execution_providers([CUDAExecutionProvider::default().build()])
@@ -69,6 +38,7 @@ fn main() -> ort::Result<()> {
.with_intra_threads(1)?
.with_model_downloaded(GPT2::GPT2LmHead)?;

// Load the tokenizer and encode the prompt into a sequence of tokens.
let tokenizer = Tokenizer::from_file(Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("tokenizer.json")).unwrap();
let tokens = tokenizer.encode(PROMPT, false).unwrap();
let tokens = tokens.get_ids().iter().map(|i| *i as i64).collect::<Vec<_>>();
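
The generated ids eventually have to be mapped back into text for printing. A hedged sketch of that step, assuming the `tokenizers` crate's `decode` method (its exact signature varies between crate versions, some taking `Vec<u32>` and newer ones `&[u32]`):

```rust
use tokenizers::Tokenizer;

/// Illustrative: turn generated token ids back into a string with the same tokenizer.
/// The `true` argument asks the tokenizer to skip special tokens in the output.
fn decode_ids(tokenizer: &Tokenizer, ids: Vec<u32>) -> String {
    tokenizer.decode(ids, true).unwrap_or_default()
}
```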
@@ -84,6 +54,7 @@ fn main() -> ort::Result<()> {
let generated_tokens: Tensor<f32> = outputs["output1"].extract_tensor()?;
let generated_tokens = generated_tokens.view();

// Collect and sort logits
let probabilities = &mut generated_tokens
.slice(s![0, 0, -1, ..])
.insert_axis(Axis(0))
@@ -94,6 +65,7 @@ fn main() -> ort::Result<()> {
.collect::<Vec<_>>();
probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less));
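
The values sorted here are the model's raw scores (logits) rather than a normalized distribution, which is fine for a uniform draw among the top k. If one wanted probability-weighted sampling instead, a softmax could be applied first; a minimal sketch, with the function name being an assumption for this example:

```rust
/// Illustrative softmax over a slice of logits.
fn softmax(logits: &[f32]) -> Vec<f32> {
    // Subtract the maximum before exponentiating for numerical stability.
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}
```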

// Sample using top-k sampling
let token = probabilities[rng.gen_range(0..=TOP_K)].0;
tokens = concatenate![Axis(0), tokens, array![token.try_into().unwrap()]];
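
The line above draws uniformly among the top candidates. As an alternative, the draw could be weighted by the candidates' softmaxed scores; a hedged sketch using `rand`'s `WeightedIndex` distribution (assuming a `rand` version that ships it, and non-negative weights such as softmax outputs), which is not what this commit does:

```rust
use rand::distributions::{Distribution, WeightedIndex};
use rand::Rng;

/// Illustrative: pick one (token_id, weight) pair in proportion to its weight.
fn sample_weighted<R: Rng>(candidates: &[(usize, f32)], rng: &mut R) -> usize {
    let dist = WeightedIndex::new(candidates.iter().map(|(_, w)| *w))
        .expect("weights must be non-negative and not all zero");
    candidates[dist.sample(rng)].0
}
```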
