Skip to content

Commit

Permalink
examples: refactor all-mini-lm-l6 for semantic similarity
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Jul 23, 2024
1 parent 1a10b11 commit c5538f2
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 45 deletions.
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
[workspace]
members = [
'ort-sys',
'examples/all-mini-lm-l6',
'examples/async-gpt2-api',
'examples/custom-ops',
'examples/gpt2',
'examples/model-info',
'examples/yolov8',
'examples/modnet',
'examples/sentence-transformers',
'examples/training',
'examples/webassembly'
]
default-members = [
'.',
'examples/all-mini-lm-l6',
'examples/async-gpt2-api',
'examples/custom-ops',
'examples/gpt2',
'examples/model-info',
'examples/yolov8',
'examples/modnet'
'examples/modnet',
'examples/sentence-transformers'
]
exclude = [ 'examples/cudarc' ]

Expand Down
41 changes: 0 additions & 41 deletions examples/all-mini-lm-l6/examples/all-mini-lm-l6.rs

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
publish = false
name = "example-all-mini-lm-l6"
name = "sentence-transformers"
version = "0.0.0"
edition = "2021"

Expand Down
File renamed without changes.
File renamed without changes.
61 changes: 61 additions & 0 deletions examples/sentence-transformers/examples/semantic-similarity.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
use std::path::Path;

use ndarray::{s, Array1, Array2, Axis, Ix2};
use ort::{CUDAExecutionProvider, GraphOptimizationLevel, Session};
use tokenizers::Tokenizer;

/// Example usage of a text embedding model like Sentence Transformers' `all-MiniLM-L6-v2` model for semantic
/// textual similarity.
///
/// Text embedding models map sentences & paragraphs to an n-dimensional dense vector space, which can then be used for
/// tasks like clustering or semantic search.
fn main() -> ort::Result<()> {
	// Initialize tracing to receive debug messages from `ort`.
	tracing_subscriber::fmt::init();

	// Create the ONNX Runtime environment, enabling CUDA execution providers for all sessions created in this process.
	ort::init()
		.with_name("sbert")
		.with_execution_providers([CUDAExecutionProvider::default().build()])
		.commit()?;

	// Load our model (downloaded & cached on first run).
	let session = Session::builder()?
		.with_optimization_level(GraphOptimizationLevel::Level1)?
		.with_intra_threads(1)?
		.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/all-MiniLM-L6-v2.onnx")?;

	// Load the tokenizer and encode the text.
	let tokenizer = Tokenizer::from_file(Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("tokenizer.json")).unwrap();

	let inputs = vec!["The weather outside is lovely.", "It's so sunny outside!", "She drove to the stadium."];

	// Encode our input strings. `encode_batch` will pad each input to be the same length.
	let encodings = tokenizer.encode_batch(inputs.clone(), false)?;

	// Get the padded length of each encoding (all encodings in the batch share it).
	let padded_token_length = encodings[0].len();

	// Get our token IDs & attention mask as flattened arrays of i64, as the model expects.
	let ids: Vec<i64> = encodings.iter().flat_map(|e| e.get_ids().iter().map(|i| *i as i64)).collect();
	let mask: Vec<i64> = encodings.iter().flat_map(|e| e.get_attention_mask().iter().map(|i| *i as i64)).collect();

	// Convert our flattened arrays into 2-dimensional tensors of shape [N, L].
	let a_ids = Array2::from_shape_vec([inputs.len(), padded_token_length], ids).unwrap();
	let a_mask = Array2::from_shape_vec([inputs.len(), padded_token_length], mask).unwrap();

	// Run the model.
	let outputs = session.run(ort::inputs![a_ids, a_mask]?)?;

	// Extract our embeddings tensor (the model's second output) and convert it to a strongly-typed 2-dimensional
	// array of shape [N, D].
	let embeddings = outputs[1].try_extract_tensor::<f32>()?.into_dimensionality::<Ix2>().unwrap();

	println!("Similarity for '{}'", inputs[0]);
	let query = embeddings.index_axis(Axis(0), 0);
	// Precompute the query's L2 norm once; it is reused for every comparison below.
	let query_norm: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();
	for (embedding, sentence) in embeddings.axis_iter(Axis(0)).zip(inputs.iter()).skip(1) {
		// Cosine similarity = dot(a, b) / (|a| * |b|). Dividing by the norms makes this correct regardless of
		// whether the model's outputs are L2-normalized; if they are, the denominator is ~1 and this reduces to the
		// plain dot product the original computed. `.max(f32::EPSILON)` guards against a division by zero for a
		// degenerate all-zero embedding.
		let dot_product: f32 = query.iter().zip(embedding.iter()).map(|(a, b)| a * b).sum();
		let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
		let similarity = dot_product / (query_norm * norm).max(f32::EPSILON);
		println!("\t'{}': {:.1}%", sentence, similarity * 100.);
	}

	Ok(())
}

0 comments on commit c5538f2

Please sign in to comment.