diff --git a/Cargo.toml b/Cargo.toml index fecdee1..cacb045 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,6 +45,9 @@ viuer = { version = "0.9.1", features = ["print-file"], optional = true } sonogram = "0.7.1" image = "0.25.5" rodio = { version = "0.20.1", optional = true } +rayon = "1.10.0" +bytes = { version = "1.9.0", features = ["serde"] } +symphonia = "0.5.4" [features] default = [] diff --git a/src/db.rs b/src/db.rs index 6d1ac64..a7bb814 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,12 +1,12 @@ use crate::distance::DistanceUnit; use crate::model::DatabaseEmbeddingModel; -use crate::EF; +use bytes::Bytes; use fastembed::Embedding; use hnsw::Params; use hnsw::{Hnsw, Searcher}; use pcg_rand::Pcg64; use serde::{Deserialize, Serialize}; -use space::{Metric, Neighbor}; +use space::Metric; use std::collections::HashMap; use std::fs::OpenOptions; use std::io::{self, BufReader, BufWriter}; @@ -135,20 +135,20 @@ where /// # Returns /// /// A tuple containing the number of embeddings inserted and the dimension of the embeddings. - pub fn insert_documents + Send + Sync + Clone, Mod: DatabaseEmbeddingModel>( + pub fn insert_documents( &mut self, model: &Mod, - documents: &[S], + documents: &[Bytes], ) -> Result<(usize, usize), Box> { let new_embeddings: Vec = model.embed_documents(documents.to_vec())?; let length_and_dimension = (new_embeddings.len(), new_embeddings[0].len()); let mut searcher: Searcher = Searcher::default(); + let mut document_map = HashMap::new(); for (document, embedding) in documents.iter().zip(new_embeddings.iter()) { let embedding_index = self.hnsw.insert(embedding.clone(), &mut searcher); - let mut document_map = HashMap::new(); document_map.insert(embedding_index, document.clone()); - self.save_documents_to_disk(&mut document_map)?; } + self.save_documents_to_disk(&mut document_map)?; self.save_database()?; Ok(length_and_dimension) } @@ -161,39 +161,31 @@ where /// /// * `documents` - A vector of documents to be queried. /// - /// * `number_of_results` - An optional positive integer less than or equal to `EF` specifying the number of query results to return. + /// * `number_of_results` - The candidate list size for the HNSW graph. Higher values result in more accurate search results at the expense of slower retrieval speeds. /// /// # Returns /// /// A vector of documents that are most similar to the queried documents. - pub fn query_documents + Send + Sync, Mod: DatabaseEmbeddingModel>( + pub fn query_documents( &mut self, model: &Mod, - documents: Vec, - number_of_results: Option, + documents: Vec, + number_of_results: usize, ) -> Result>, Box> { if self.hnsw.is_empty() { return Ok(Vec::new()); } - let number_of_results = match number_of_results { - None => 1, - Some(number_of_results) => std::cmp::min(number_of_results, EF), - }; let mut searcher: Searcher = Searcher::default(); let mut results = Vec::new(); - // let model = TextEmbedding::try_new(InitOptions { - // model_name: EmbeddingModel::BGESmallENV15, - // show_download_progress: false, - // ..Default::default() - // })?; let query_embeddings = model.embed_documents(documents)?; for query_embedding in query_embeddings.iter() { - let mut neighbours = [Neighbor { - index: !0, - distance: !0, - }; EF]; - self.hnsw - .nearest(query_embedding, EF, &mut searcher, &mut neighbours); + let mut neighbours = Vec::new(); + self.hnsw.nearest( + query_embedding, + number_of_results, + &mut searcher, + &mut neighbours, + ); if neighbours.is_empty() { return Ok(Vec::new()); } @@ -216,14 +208,14 @@ where /// # Arguments /// /// * `documents` - A map of document indices and their corresponding documents. - pub fn save_documents_to_disk + Send + Sync>( + pub fn save_documents_to_disk( &self, - documents: &mut HashMap, + documents: &mut HashMap, ) -> Result<(), Box> { let document_subdirectory = self.document_type.subdirectory_name(); std::fs::create_dir_all(document_subdirectory)?; for document in documents { - let mut reader = BufReader::new(document.1.as_ref().as_bytes()); + let mut reader = BufReader::new(document.1.as_ref()); let file = OpenOptions::new() .read(true) .write(true) diff --git a/src/image.rs b/src/image.rs index 698eaa4..e036150 100644 --- a/src/image.rs +++ b/src/image.rs @@ -1,5 +1,11 @@ use crate::db::{Database, DocumentType}; use crate::distance::{CosineDistance, DefaultImageMetric}; +use bytes::Bytes; +use candle_core::Tensor; +use candle_examples::imagenet::{IMAGENET_MEAN, IMAGENET_STD}; +use image::ImageReader; +use std::error::Error; +use std::io::Cursor; /// A parameter regarding insertion into the HNSW graph. Higher values result in more accurate search results at the expense of slower retrieval speeds. Cannot be changed after database creation. pub const IMAGE_EF_CONSTRUCTION: usize = 400; @@ -21,3 +27,33 @@ pub type ImageDatabase = Database Result> { ImageDatabase::create_or_load_database(CosineDistance, DocumentType::Image) } + +/// Loads an image from raw bytes with ImageNet normalisation applied, returning a tensor with the shape [3 224 224]. +/// +/// # Arguments +/// +/// * `bytes` - The raw bytes of an image. +/// +/// # Returns +/// +/// A tensor with the shape [3 224 224]; ImageNet normalisation is applied. +pub fn load_image224(bytes: Bytes) -> Result> { + let res = 224_usize; + let img = ImageReader::new(Cursor::new(bytes)) + .with_guessed_format()? + .decode()? + .resize_to_fill( + res as u32, + res as u32, + image::imageops::FilterType::Triangle, + ) + .to_rgb8(); + let data = img.into_raw(); + let data = + Tensor::from_vec(data, (res, res, 3), &candle_core::Device::Cpu)?.permute((2, 0, 1))?; + let mean = Tensor::new(&IMAGENET_MEAN, &candle_core::Device::Cpu)?.reshape((3, 1, 1))?; + let std = Tensor::new(&IMAGENET_STD, &candle_core::Device::Cpu)?.reshape((3, 1, 1))?; + Ok((data.to_dtype(candle_core::DType::F32)? / 255.)? + .broadcast_sub(&mean)? + .broadcast_div(&std)?) +} diff --git a/src/lib.rs b/src/lib.rs index ce47723..5c278c5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,3 @@ pub mod image; pub mod model; /// A module for text database operations. pub mod text; - -/// The candidate list size for the HNSW graph. Higher values result in more accurate search results at the expense of slower retrieval speeds. Can be changed after database creation. -pub const EF: usize = 24; diff --git a/src/main.rs b/src/main.rs index e02f16e..5139946 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,4 @@ +use bytes::Bytes; use clap::{command, Parser, Subcommand}; use fastembed::Embedding; use fastembed::TextEmbedding; @@ -5,6 +6,9 @@ use indicatif::HumanCount; use indicatif::ProgressStyle; use indicatif::{ProgressBar, ProgressDrawTarget}; use pretty_duration::pretty_duration; +use rayon::iter::IntoParallelIterator; +use rayon::iter::IntoParallelRefIterator; +use rayon::iter::ParallelIterator; use rodio::{Decoder, OutputStream, Sink}; use space::Metric; use std::error::Error; @@ -71,7 +75,8 @@ enum TextCommands { #[command(about = "Query texts from the database.", arg_required_else_help(true))] Query { texts: Vec, - number_of_results: Option, + #[arg(default_value_t = 1)] + number_of_results: usize, }, #[command(about = "Clear the database.")] Clear, @@ -90,7 +95,8 @@ enum ImageCommands { )] Query { image_path: PathBuf, - number_of_results: Option, + #[arg(default_value_t = 1)] + number_of_results: usize, }, #[command(about = "Clear the database.")] Clear, @@ -109,7 +115,8 @@ enum AudioCommands { )] Query { audio_path: PathBuf, - number_of_results: Option, + #[arg(default_value_t = 1)] + number_of_results: usize, }, #[command(about = "Clear the database.")] Clear, @@ -119,13 +126,14 @@ fn main() -> Result<(), Box> { let cli = Cli::parse(); match cli.commands { Commands::Text(text) => match text.text_commands { - TextCommands::Insert { mut texts } => { + TextCommands::Insert { texts } => { let mut sw = Stopwatch::start_new(); let mut db = zebra::text::create_or_load_database()?; let mut buffer = BufWriter::new(stdout().lock()); let model: TextEmbedding = DatabaseEmbeddingModel::new()?; writeln!(buffer, "Inserting {} text(s).", texts.len())?; - let insertion_results = db.insert_documents(&model, &mut texts)?; + let texts_bytes: Vec<_> = texts.into_par_iter().map(|x| Bytes::from(x)).collect(); + let insertion_results = db.insert_documents(&model, &texts_bytes)?; sw.stop(); writeln!( buffer, @@ -150,10 +158,11 @@ fn main() -> Result<(), Box> { let num_texts = texts.len(); let model: TextEmbedding = DatabaseEmbeddingModel::new()?; writeln!(buffer, "Querying {} text(s).", num_texts)?; - let query_results = db.query_documents(&model, texts, number_of_results)?; - let result_texts: Vec = query_results + let texts_bytes: Vec<_> = texts.into_par_iter().map(|x| Bytes::from(x)).collect(); + let query_results = db.query_documents(&model, texts_bytes, number_of_results)?; + let result_texts: Vec<_> = query_results .iter() - .map(|x| String::from_utf8(x.to_vec()).unwrap()) + .map(|x| String::from_utf8_lossy(x)) .collect(); sw.stop(); writeln!( @@ -201,11 +210,9 @@ fn main() -> Result<(), Box> { }; let model: ImageEmbeddingModel = DatabaseEmbeddingModel::new()?; writeln!(buffer, "Querying image.")?; - let query_results = db.query_documents( - &model, - vec![image_path.to_str().unwrap()], - number_of_results, - )?; + let image_bytes = std::fs::read(image_path).unwrap_or_default().into(); + let query_results = + db.query_documents(&model, vec![image_bytes], number_of_results)?; sw.stop(); writeln!( buffer, @@ -239,11 +246,9 @@ fn main() -> Result<(), Box> { let mut buffer = BufWriter::new(stdout().lock()); let model: AudioEmbeddingModel = DatabaseEmbeddingModel::new()?; writeln!(buffer, "Querying sound.")?; - let query_results = db.query_documents( - &model, - vec![audio_path.to_str().unwrap()], - number_of_results, - )?; + let audio_bytes = std::fs::read(audio_path).unwrap_or_default().into(); + let query_results = + db.query_documents(&model, vec![audio_bytes], number_of_results)?; sw.stop(); writeln!( buffer, @@ -264,7 +269,6 @@ fn main() -> Result<(), Box> { clear_database(DocumentType::Audio)?; } }, - // _ => unreachable!(), } Ok(()) } @@ -315,9 +319,9 @@ where ProgressDrawTarget::hidden(), ); progress_bar.set_style(progress_bar_style()?); - let documents: Vec = file_paths - .into_iter() - .map(|x| x.to_str().unwrap().to_string()) + let documents: Vec<_> = file_paths + .par_iter() + .filter_map(|x| std::fs::read(x).ok().map(|y| y.into())) .collect(); // Insert documents in batches of INSERT_BATCH_SIZE. for document_batch in documents.chunks(INSERT_BATCH_SIZE) { diff --git a/src/model.rs b/src/model.rs index 0b53b8d..e178e74 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,13 +1,25 @@ +use crate::image::load_image224; +use bytes::Bytes; use candle_core::DType; use candle_core::Device; use candle_core::Tensor; use candle_nn::VarBuilder; use candle_transformers::models::vit; use fastembed::{Embedding, EmbeddingModel, InitOptions, TextEmbedding}; +use rayon::iter::IntoParallelIterator; +use rayon::iter::ParallelIterator; use sonogram::ColourGradient; use sonogram::FrequencyScale; use sonogram::SpecOptionsBuilder; -use std::{error::Error, path::PathBuf}; +use std::error::Error; +use std::io::Cursor; +use symphonia::core::audio::Signal; +use symphonia::core::codecs::DecoderOptions; +use symphonia::core::codecs::CODEC_TYPE_NULL; +use symphonia::core::formats::FormatOptions; +use symphonia::core::io::MediaSourceStream; +use symphonia::core::meta::MetadataOptions; +use symphonia::core::probe::Hint; /// A trait for embedding models that can be used with the database. pub trait DatabaseEmbeddingModel { @@ -25,10 +37,7 @@ pub trait DatabaseEmbeddingModel { /// # Returns /// /// A vector of embeddings. - fn embed_documents + Send + Sync>( - &self, - documents: Vec, - ) -> Result, Box>; + fn embed_documents(&self, documents: Vec) -> Result, Box>; /// Embed a single document. /// @@ -39,7 +48,7 @@ pub trait DatabaseEmbeddingModel { /// # Returns /// /// An embedding vector. - fn embed + Send + Sync>(&self, document: S) -> Result>; + fn embed(&self, document: Bytes) -> Result>; } impl DatabaseEmbeddingModel for TextEmbedding { @@ -48,15 +57,23 @@ impl DatabaseEmbeddingModel for TextEmbedding { InitOptions::new(EmbeddingModel::BGESmallENV15).with_show_download_progress(false), )?) } - fn embed_documents + Send + Sync>( - &self, - documents: Vec, - ) -> Result, Box> { - Ok(self.embed(documents, None)?) + fn embed_documents(&self, documents: Vec) -> Result, Box> { + Ok(self.embed( + documents + .into_par_iter() + .map(|x| x.to_vec()) + .filter_map(|x| String::from_utf8(x).ok()) + .collect(), + None, + )?) } - fn embed + Send + Sync>(&self, document: S) -> Result> { - let vec_with_document = vec![document]; + fn embed(&self, document: Bytes) -> Result> { + let vec_with_document = vec![document] + .into_par_iter() + .map(|x| x.to_vec()) + .filter_map(|x| String::from_utf8(x).ok()) + .collect(); let vector_of_embeddings = self.embed(vec_with_document, None)?; Ok(vector_of_embeddings.first().unwrap().to_vec()) } @@ -69,10 +86,7 @@ impl DatabaseEmbeddingModel for ImageEmbeddingModel { fn new() -> Result> { Ok(Self) } - fn embed_documents + Send + Sync>( - &self, - documents: Vec, - ) -> Result, Box> { + fn embed_documents(&self, documents: Vec) -> Result, Box> { let mut result = Vec::new(); let device = candle_examples::device(false)?; let api = hf_hub::api::sync::Api::new()?; @@ -86,18 +100,16 @@ impl DatabaseEmbeddingModel for ImageEmbeddingModel { varbuilder.pp("vit").pp("embeddings"), )?; for document in documents { - let path = PathBuf::from(document.as_ref().to_string()); - let image = candle_examples::imagenet::load_image224(path)?.to_device(&device)?; + let image = load_image224(document)?.to_device(&device)?; let embedding_tensors = model.forward(&image.unsqueeze(0)?, None, false)?; let embedding_vector = embedding_tensors.flatten_all()?.to_vec1::()?; result.push(embedding_vector); } Ok(result) } - fn embed + Send + Sync>(&self, document: S) -> Result> { + fn embed(&self, document: Bytes) -> Result> { let device = candle_examples::device(false)?; - let path = PathBuf::from(document.as_ref().to_string()); - let image = candle_examples::imagenet::load_image224(path)?.to_device(&device)?; + let image = load_image224(document)?.to_device(&device)?; let api = hf_hub::api::sync::Api::new()?; let api = api.model("google/vit-base-patch16-224".into()); let model_file = api.get("model.safetensors")?; @@ -114,18 +126,75 @@ impl DatabaseEmbeddingModel for ImageEmbeddingModel { pub struct AudioEmbeddingModel; impl AudioEmbeddingModel { - /// Convert a waveform audio file into a logarithm-scale spectrogram for use with image embedding models. + /// Decodes the samples of an audio files. + /// + /// # Arguments + /// + /// * `audio` - The raw bytes of an audio file. + /// + /// # Returns + /// + /// An `i16` vector of decoded samples, and the sample rate of the audio. + pub fn audio_to_data(audio: Bytes) -> Result<(Vec, u32), Box> { + let mss = MediaSourceStream::new(Box::new(Cursor::new(audio)), Default::default()); + let meta_opts: MetadataOptions = Default::default(); + let fmt_opts: FormatOptions = Default::default(); + let probed = + symphonia::default::get_probe().format(&Hint::new(), mss, &fmt_opts, &meta_opts)?; + let mut format = probed.format; + let track = format + .tracks() + .into_par_iter() + .find_any(|t| t.codec_params.codec != CODEC_TYPE_NULL) + .unwrap(); + let dec_opts: DecoderOptions = Default::default(); + let mut decoder = symphonia::default::get_codecs().make(&track.codec_params, &dec_opts)?; + let track_id = track.id; + let mut sample_rate = 0; + let mut data = Vec::new(); + + loop { + match format.next_packet() { + Ok(packet) => { + while !format.metadata().is_latest() { + format.metadata().pop(); + } + if packet.track_id() != track_id { + continue; + } + match decoder.decode(&packet) { + Ok(decoded) => { + let decoded = decoded.make_equivalent::(); + sample_rate = decoded.spec().rate; + let number_channels = decoded.spec().channels.count(); + for i in 0..number_channels { + let samples = decoded.chan(i); + data.extend_from_slice(samples); + } + } + Err(_) => continue, + } + } + Err(_) => break, + } + } + + Ok((data, sample_rate)) + } + + /// Convert an audio file into a logarithm-scale spectrogram for use with image embedding models. /// /// # Arguments /// - /// `path` - The path to the waveform audio file. - pub fn audio_to_image_tensor + Send + Sync>( - path: S, - ) -> Result> { - let path = PathBuf::from(path.as_ref().to_string()); + /// `audio` - The raw bytes of an audio file. + /// + /// # Returns + /// + /// A spectrogram of the audio as an ImageNet-normalised tensor with shape [3 224 224]. + pub fn audio_to_image_tensor(audio: Bytes) -> Result> { + let (data, sample_rate) = Self::audio_to_data(audio)?; let mut spectrograph = SpecOptionsBuilder::new(512) - .load_data_from_file(&path) - .unwrap() + .load_data_from_memory(data, sample_rate) .normalise() .build() .unwrap(); @@ -151,10 +220,7 @@ impl DatabaseEmbeddingModel for AudioEmbeddingModel { fn new() -> Result> { Ok(Self) } - fn embed_documents + Send + Sync>( - &self, - documents: Vec, - ) -> Result, Box> { + fn embed_documents(&self, documents: Vec) -> Result, Box> { let mut result = Vec::new(); let device = candle_examples::device(false)?; let api = hf_hub::api::sync::Api::new()?; @@ -175,7 +241,7 @@ impl DatabaseEmbeddingModel for AudioEmbeddingModel { } Ok(result) } - fn embed + Send + Sync>(&self, document: S) -> Result> { + fn embed(&self, document: Bytes) -> Result> { let device = candle_examples::device(false)?; let image = AudioEmbeddingModel::audio_to_image_tensor(document)?.to_device(&device)?; let api = hf_hub::api::sync::Api::new()?;