Skip to content

Commit

Permalink
update models
Browse files Browse the repository at this point in the history
  • Loading branch information
akshayballal95 committed Nov 10, 2024
1 parent 76d456e commit 5149f96
Show file tree
Hide file tree
Showing 10 changed files with 1,375 additions and 7 deletions.
69 changes: 64 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ pdf-extract = "0.7.7"
candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.7.2" }
candle-transformers = { git = "https://github.com/huggingface/candle.git", version = "0.7.2" }
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.7.2" }
candle-flash-attn = { git = "https://github.com/huggingface/candle.git", version = "0.7.2" }

ort = {version = "=2.0.0-rc.8", features = ["cuda"]}
strum = "0.26.1"
strum_macros = "0.26"
Expand Down
3 changes: 3 additions & 0 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ accelerate-src = { version = "0.3.2", optional = true }
indicatif = "0.17.8"
statistical = "1.0.0"
half = "2.4.1"
candle-flash-attn = { workspace = true, optional = true }


[dev-dependencies]
tempdir = "0.3.7"
Expand All @@ -97,3 +99,4 @@ accelerate = ["dep:accelerate-src", "candle-core/accelerate", "candle-nn/acceler
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-core/cuda"]
cudnn = ["candle-core/cudnn"]
load-dynamic = ["ort/load-dynamic"]
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
2 changes: 1 addition & 1 deletion rust/src/embeddings/local/colpali.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use anyhow::Error as E;
use base64::Engine;
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::{colpali::Model, paligemma};
use crate::models::{colpali::Model, paligemma};
use image::{DynamicImage, ImageFormat};

use pdf2image::{Pages, RenderOptionsBuilder, PDF};
Expand Down
2 changes: 1 addition & 1 deletion rust/src/embeddings/local/colpali_ort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{collections::HashMap, path::PathBuf};

use anyhow::Error as E;
use base64::Engine;
use candle_transformers::models::paligemma;
use crate::models::paligemma;
use half::f16;
use image::{DynamicImage, ImageFormat};
use ndarray::prelude::*;
Expand Down
42 changes: 42 additions & 0 deletions rust/src/models/colpali.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
use candle_core::{Module, Result, Tensor};
use candle_nn::VarBuilder;

use super::paligemma;
use candle_nn::{linear, Linear};

/// PaliGemma backbone paired with a linear projection layer.
///
/// The projection is applied to the backbone's hidden states by
/// `forward_images` and `forward_text`, and the projected vectors are then
/// divided by their L2 norm.
pub struct Model {
/// Underlying PaliGemma model; its weights are loaded from the `model` prefix.
pub model: paligemma::Model,
/// Linear layer mapping `text_config.hidden_size` features to 128 outputs;
/// its weights are loaded from the `custom_text_proj` prefix.
pub custom_text_projection: Linear,
}

impl Model {
    /// Number of output features of the projection layer.
    const EMBEDDING_DIM: usize = 128;

    /// Builds the model from `config`, reading weights through `vb`.
    ///
    /// The backbone is loaded from the `model` prefix and the projection layer
    /// from the `custom_text_proj` prefix.
    ///
    /// # Errors
    /// Returns any error raised while constructing the backbone or the linear
    /// layer (for example a missing or mis-shaped weight).
    pub fn new(config: &paligemma::Config, vb: VarBuilder) -> Result<Self> {
        let model = paligemma::Model::new(config, vb.pp("model"))?;
        let custom_text_projection = linear(
            config.text_config.hidden_size,
            Self::EMBEDDING_DIM,
            vb.pp("custom_text_proj"),
        )?;

        Ok(Self {
            model,
            custom_text_projection,
        })
    }

    /// Runs the backbone on `pixel_values` together with `input_ids`, then
    /// projects and L2-normalises the resulting hidden states.
    ///
    /// # Errors
    /// Propagates any tensor or model error from the backbone, the projection,
    /// or the normalisation.
    pub fn forward_images(&mut self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<Tensor> {
        let hidden = self
            .model
            .setup_without_projection(pixel_values, input_ids)?;
        self.project_and_normalize(&hidden)
    }

    /// Runs the backbone on `input_ids` only, then projects and L2-normalises
    /// the resulting hidden states.
    ///
    /// # Errors
    /// Propagates any tensor or model error from the backbone, the projection,
    /// or the normalisation.
    pub fn forward_text(&mut self, input_ids: &Tensor) -> Result<Tensor> {
        let hidden = self.model.forward_without_projection(input_ids)?;
        self.project_and_normalize(&hidden)
    }

    /// Applies the projection layer to `hidden` and divides each projected
    /// vector by its L2 norm taken along dimension 2.
    ///
    /// Shared by `forward_images` and `forward_text` so both paths normalise
    /// identically.
    fn project_and_normalize(&self, hidden: &Tensor) -> Result<Tensor> {
        let projected = self.custom_text_projection.forward(hidden)?;
        // NOTE(review): summing over dimension 2 assumes a rank-3 input,
        // presumably (batch, seq, dim) — confirm against callers. No epsilon is
        // added, so an all-zero projected vector would be divided by zero.
        let norm = projected.sqr()?.sum_keepdim(2)?.sqrt()?;
        projected.broadcast_div(&norm)
    }
}
Loading

0 comments on commit 5149f96

Please sign in to comment.