Support loading blockwise quantized fp8
EricLBuehler committed Jan 22, 2025
1 parent 21d7461 commit f1a9e63
Showing 2 changed files with 81 additions and 4 deletions.
74 changes: 71 additions & 3 deletions mistralrs-quant/src/fp8/mod.rs
@@ -6,8 +6,8 @@ use std::{
};

use byteorder::{LittleEndian, ReadBytesExt};
-use candle_core::{DType, Device, Result, Tensor, D};
-use candle_nn::{Linear, Module};
+use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
+use candle_nn::{Linear, Module, VarBuilder};
use quantize::QuantizationResult;

mod quantize;
@@ -18,7 +18,8 @@ use crate::{
        deserialize_tensor, read_dtype, serialize_tensor, version_is_compatible, write_dtype,
        HQFF_VERSION,
    },
-    IsqType, QuantMethod, QuantMethodConfig, QuantizedSerde, QuantizedSerdeType,
+    DummyLayer, IsqType, QuantMethod, QuantMethodConfig, QuantizedConfig, QuantizedSerde,
+    QuantizedSerdeType, UnquantLinear,
};

#[derive(Debug)]
@@ -291,3 +292,70 @@ impl QuantizedSerde for FP8Linear {
        }))
    }
}

+pub fn fp8_linear_b(
+    in_dim: usize,
+    out_dim: usize,
+    config: &QuantizedConfig,
+    bias: bool,
+    vb: VarBuilder,
+) -> Result<Arc<dyn QuantMethod>> {
+    // Handle the case where the layer is dummy (no tensors)
+    if !(vb.contains_tensor("weight") && vb.contains_tensor("weight_scale_inv")) {
+        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
+        return Ok(Arc::new(layer) as Arc<dyn QuantMethod>);
+    }
+
+    let weight_block_size = config
+        .weight_block_size
+        .as_ref()
+        .expect("FP8 requires weight_block_size in config");
+    if weight_block_size.len() != 2 {
+        candle_core::bail!("Expected weight_block_size to have length 2, got {weight_block_size:?}")
+    }
+    let weight = vb.get_with_hints_dtype(
+        (out_dim, in_dim),
+        "weight",
+        Default::default(),
+        DType::F8E4M3,
+    )?;
+    let weight_scale_inv = vb.get_with_hints_dtype(
+        (
+            out_dim / weight_block_size[0],
+            in_dim / weight_block_size[1],
+        ),
+        "weight_scale_inv",
+        Default::default(),
+        DType::F32,
+    )?;
+    let bias = if bias {
+        Some(vb.get((out_dim,), "bias")?)
+    } else {
+        None
+    };
+
+    let mut out = unsafe { Tensor::empty((out_dim, in_dim), vb.dtype(), vb.device())? };
+
+    for i in (0..out_dim).step_by(weight_block_size[0]) {
+        for j in (0..in_dim).step_by(weight_block_size[1]) {
+            let scale = weight_scale_inv.i((i / weight_block_size[0], j / weight_block_size[1]))?;
+
+            let dequant_block = (weight
+                .i((i..i + weight_block_size[0], j..j + weight_block_size[1]))?
+                .to_dtype(DType::F32)?
+                * scale)?
+                .to_dtype(out.dtype())?;
+
+            out = out.slice_assign(
+                &[
+                    &(i..i + weight_block_size[0]),
+                    &(j..j + weight_block_size[1]),
+                ],
+                &dequant_block,
+            )?;
+        }
+    }
+
+    let config = QuantMethodConfig::Unquantized(Linear::new(out, bias));
+    Ok(Arc::new(UnquantLinear::new(config)?))
+}
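
The loop above materializes the weight one block at a time: every element of the (weight_block_size[0] x weight_block_size[1]) tile starting at (i, j) is multiplied by the single f32 scale stored at (i / weight_block_size[0], j / weight_block_size[1]) in weight_scale_inv. Below is a minimal standalone sketch of the same indexing over plain slices — synthetic data and a hypothetical dequant_blockwise helper, no candle types — to make the block/scale correspondence concrete:

// Blockwise dequantization sketch: one f32 scale per (bs[0] x bs[1]) block.
// `weight` stands in for the fp8 tensor (modeled as f32 for simplicity).
fn dequant_blockwise(
    weight: &[f32],     // row-major, rows * cols entries
    scale_inv: &[f32],  // row-major, (rows / bs[0]) * (cols / bs[1]) entries
    rows: usize,
    cols: usize,
    bs: [usize; 2],
) -> Vec<f32> {
    let scale_cols = cols / bs[1];
    let mut out = vec![0f32; rows * cols];
    for r in 0..rows {
        for c in 0..cols {
            // The element at (r, c) shares the scale of the block containing it.
            let scale = scale_inv[(r / bs[0]) * scale_cols + (c / bs[1])];
            out[r * cols + c] = weight[r * cols + c] * scale;
        }
    }
    out
}

fn main() {
    // 4 x 4 weights in 2 x 2 blocks => a 2 x 2 grid of scales.
    let weight = vec![1.0f32; 16];
    let scale_inv = vec![0.5, 1.0, 2.0, 4.0];
    let out = dequant_blockwise(&weight, &scale_inv, 4, 4, [2, 2]);
    assert_eq!(out[0], 0.5);  // top-left block scaled by 0.5
    assert_eq!(out[15], 4.0); // bottom-right block scaled by 4.0
}

One trade-off worth noting: the fp8 payload and scales are what live on disk, but fp8_linear_b hands UnquantLinear a fully dequantized matrix (built with Tensor::empty + slice_assign), so at runtime the layer occupies memory at vb.dtype(), just like an unquantized load.
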
11 changes: 10 additions & 1 deletion mistralrs-quant/src/lib.rs
@@ -29,6 +29,7 @@ mod utils;

pub use bitsandbytes::{BnbLinear, BnbQuantParmas, BnbQuantType};
pub use dummy::DummyLayer;
+use fp8::fp8_linear_b;
pub use fp8::FP8Linear;
pub use gguf::GgufMatMul;
use gptq::gptq_linear;
@@ -42,6 +43,8 @@ use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub enum QuantMethodType {
#[serde(rename = "fp8")]
Fp8,
#[serde(rename = "gptq")]
Gptq,
#[serde(rename = "unreachable")]
Expand All @@ -54,7 +57,8 @@ pub enum QuantMethodType {
impl Display for QuantMethodType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
-            Self::Gptq => write!(f, "GPTQ"),
+            Self::Gptq => write!(f, "gptq"),
+            Self::Fp8 => write!(f, "fp8"),
            Self::Bitsandbytes => write!(f, "bnb"),
            Self::Unreachable => write!(f, "unreachable",),
        }
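
The #[serde(rename = ...)] attributes are what bind the quant_method strings found in a checkpoint's quantization_config to these variants. A minimal round-trip sketch using a standalone mirror of the types (serde and serde_json assumed as dependencies; this is a simplified stand-in, not the crate's own definitions):

use serde::Deserialize;

// Stand-in for mistralrs-quant's QuantMethodType, reduced to two variants.
#[derive(Debug, Deserialize, PartialEq)]
enum QuantMethodType {
    #[serde(rename = "fp8")]
    Fp8,
    #[serde(rename = "gptq")]
    Gptq,
}

// Stand-in for the QuantizedConfig fields relevant to blockwise fp8.
#[derive(Debug, Deserialize)]
struct QuantizedConfig {
    quant_method: QuantMethodType,
    weight_block_size: Option<Vec<usize>>,
}

fn main() {
    // Illustrative quantization_config fragment of the shape this loader
    // targets; the values are assumptions, not taken from the commit.
    let raw = r#"{ "quant_method": "fp8", "weight_block_size": [128, 128] }"#;
    let cfg: QuantizedConfig = serde_json::from_str(raw).unwrap();
    assert_eq!(cfg.quant_method, QuantMethodType::Fp8);
    assert_eq!(cfg.weight_block_size, Some(vec![128, 128]));
}
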
Expand All @@ -71,6 +75,9 @@ pub struct QuantizedConfig {
// BNB
pub bnb_4bit_quant_type: Option<String>,

// FP8
pub weight_block_size: Option<Vec<usize>>,

pub quant_method: QuantMethodType,
}
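
The per-block scales cost little: with a hypothetical weight_block_size of [128, 128], a 4096 x 4096 projection carries a 32 x 32 grid of f32 scales — 1024 values (4 KiB) alongside 16 MiB of fp8 weights, roughly 0.02% overhead. Note that the scale-tensor shape used in fp8_linear_b, (out_dim / weight_block_size[0], in_dim / weight_block_size[1]), assumes both dimensions are exact multiples of the block size.
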

@@ -442,6 +449,7 @@ pub fn linear_no_bias(
    let layer = if let Some(quant_conf) = &config {
        match quant_conf.quant_method {
            QuantMethodType::Gptq => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
+            QuantMethodType::Fp8 => fp8_linear_b(in_dim, out_dim, quant_conf, false, vb)?,
            QuantMethodType::Bitsandbytes => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
            }
@@ -471,6 +479,7 @@ pub fn linear(
    let layer = if let Some(quant_conf) = &config {
        match quant_conf.quant_method {
            QuantMethodType::Gptq => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
+            QuantMethodType::Fp8 => fp8_linear_b(in_dim, out_dim, quant_conf, true, vb)?,
            QuantMethodType::Bitsandbytes => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
            }