-
-
Notifications
You must be signed in to change notification settings - Fork 113
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
examples: refactor
all-mini-lm-l6
for semantic similarity
- Loading branch information
1 parent
1a10b11
commit c5538f2
Showing
6 changed files
with
65 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
2 changes: 1 addition & 1 deletion
2
examples/all-mini-lm-l6/Cargo.toml → examples/sentence-transformers/Cargo.toml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
61 changes: 61 additions & 0 deletions
61
examples/sentence-transformers/examples/semantic-similarity.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
use std::path::Path; | ||
|
||
use ndarray::{s, Array1, Array2, Axis, Ix2}; | ||
use ort::{CUDAExecutionProvider, GraphOptimizationLevel, Session}; | ||
use tokenizers::Tokenizer; | ||
|
||
/// Example usage of a text embedding model like Sentence Transformers' `all-mini-lm-l6` model for semantic textual similarity. | ||
/// | ||
/// Text embedding models map sentences & paragraphs to an n-dimensional dense vector space, which can then be used for | ||
/// tasks like clustering or semantic search. | ||
fn main() -> ort::Result<()> { | ||
// Initialize tracing to receive debug messages from `ort` | ||
tracing_subscriber::fmt::init(); | ||
|
||
// Create the ONNX Runtime environment, enabling CUDA execution providers for all sessions created in this process. | ||
ort::init() | ||
.with_name("sbert") | ||
.with_execution_providers([CUDAExecutionProvider::default().build()]) | ||
.commit()?; | ||
|
||
// Load our model | ||
let session = Session::builder()? | ||
.with_optimization_level(GraphOptimizationLevel::Level1)? | ||
.with_intra_threads(1)? | ||
.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/all-MiniLM-L6-v2.onnx")?; | ||
|
||
// Load the tokenizer and encode the text. | ||
let tokenizer = Tokenizer::from_file(Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("tokenizer.json")).unwrap(); | ||
|
||
let inputs = vec!["The weather outside is lovely.", "It's so sunny outside!", "She drove to the stadium."]; | ||
|
||
// Encode our input strings. `encode_batch` will pad each input to be the same length. | ||
let encodings = tokenizer.encode_batch(inputs.clone(), false)?; | ||
|
||
// Get the padded length of each encoding. | ||
let padded_token_length = encodings[0].len(); | ||
|
||
// Get our token IDs & mask as a flattened array. | ||
let ids: Vec<i64> = encodings.iter().flat_map(|e| e.get_ids().iter().map(|i| *i as i64)).collect(); | ||
let mask: Vec<i64> = encodings.iter().flat_map(|e| e.get_attention_mask().iter().map(|i| *i as i64)).collect(); | ||
|
||
// Convert our flattened arrays into 2-dimensional tensors of shape [N, L]. | ||
let a_ids = Array2::from_shape_vec([inputs.len(), padded_token_length], ids).unwrap(); | ||
let a_mask = Array2::from_shape_vec([inputs.len(), padded_token_length], mask).unwrap(); | ||
|
||
// Run the model. | ||
let outputs = session.run(ort::inputs![a_ids, a_mask]?)?; | ||
|
||
// Extract our embeddings tensor and convert it to a strongly-typed 2-dimensional array. | ||
let embeddings = outputs[1].try_extract_tensor::<f32>()?.into_dimensionality::<Ix2>().unwrap(); | ||
|
||
println!("Similarity for '{}'", inputs[0]); | ||
let query = embeddings.index_axis(Axis(0), 0); | ||
for (embeddings, sentence) in embeddings.axis_iter(Axis(0)).zip(inputs.iter()).skip(1) { | ||
// Calculate cosine similarity against the 'query' sentence. | ||
let dot_product: f32 = query.iter().zip(embeddings.iter()).map(|(a, b)| a * b).sum(); | ||
println!("\t'{}': {:.1}%", sentence, dot_product * 100.); | ||
} | ||
|
||
Ok(()) | ||
} |