Shape of attention mask in distilbert example #2667

Open
fbilhaut opened this issue Dec 13, 2024 · 4 comments

Comments

fbilhaut commented Dec 13, 2024

Hi,

I'm trying to adapt the distilbert example to make it process multiple sequences at once (the provided example just processes one prompt).

But I'm having trouble providing the proper attention mask to the DistilBertModel::forward() method.

I noticed, when reading the documentation of the forward() method of the equivalent Python class, that this mask is expected to have the same shape as the input_ids parameter.

This seems sound, and is also consistent with the BERT example in Candle, which does it that way when it comes to processing multiple sequences to compute similarities:

let token_ids = tokens.iter().map(|tokens| {
    let tokens = tokens.get_ids().to_vec();
    Ok(Tensor::new(tokens.as_slice(), device)?)
}).collect::<Result<Vec<_>>>()?;

let attention_mask = tokens.iter().map(|tokens| {
    let tokens = tokens.get_attention_mask().to_vec();
    Ok(Tensor::new(tokens.as_slice(), device)?)
}).collect::<Result<Vec<_>>>()?;

let token_ids = Tensor::stack(&token_ids, 0)?;
let attention_mask = Tensor::stack(&attention_mask, 0)?;

// ...

model.forward(&token_ids, &token_type_ids, Some(&attention_mask))?;

BUT:

In the distilbert example, the tokenizer doesn't add any padding, and there is a rather mysterious function that is supposed to compute the attention mask but returns a different (square) shape:

fn get_mask(size: usize, device: &Device) -> Tensor {
    let mask: Vec<_> = (0..size)
        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
        .collect();
    Tensor::from_slice(&mask, (size, size), device).unwrap()
}

For example for a sequence of 3 tokens this generates the following NxN mask:

[[0, 1, 1], 
[0, 0, 1], 
[0, 0, 0]]
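
For reference, here is a minimal plain-Rust sketch (no Candle, only the standard library) that rebuilds the same pattern as nested vectors. The name `square_mask` is hypothetical. Entry (i, j) is 1 exactly when j > i, so row i flags the positions that come after token i. That is the pattern usually called a causal mask, and it carries no information about padding.

```rust
/// Builds the same size x size pattern as `get_mask`, as nested vectors.
/// `square_mask` is a hypothetical name used only for this illustration.
fn square_mask(size: usize) -> Vec<Vec<u8>> {
    (0..size)
        // Row i: 1 for every column j that lies after position i, else 0.
        .map(|i| (0..size).map(|j| u8::from(j > i)).collect())
        .collect()
}

fn main() {
    // Print one row per line for a sequence of 3 tokens.
    for row in square_mask(3) {
        println!("{row:?}");
    }
}
```

For a size of 3 this reproduces the three rows shown above.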

If I simply replace this by the result of the get_attention_mask() function in a 1xN tensor, it works for one sequence.

But for several sequences, if I pad all the sequences to the same size S and stack the masks obtained for N sequences to get an NxS tensor (as the bert example mentioned earlier does), I get an error like this:

cannot broadcast [2, 32] to [2, 12, 32, 32]
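
The error can be reproduced on paper with a small shape check. The sketch below is plain Rust and assumes that Candle follows the usual right-aligned (NumPy-style) broadcasting rules; `can_broadcast` is a hypothetical helper, not a Candle function. The target shape is read as (batch, heads, query length, key length).

```rust
/// Returns true if shape `from` can be broadcast to shape `to` under
/// right-aligned rules: each trailing dimension must match or be 1.
/// Hypothetical helper for illustration, not part of Candle.
fn can_broadcast(from: &[usize], to: &[usize]) -> bool {
    if from.len() > to.len() {
        return false;
    }
    from.iter()
        .rev()
        .zip(to.iter().rev())
        .all(|(&f, &t)| f == t || f == 1)
}

fn main() {
    // Attention scores for 2 padded sequences of 32 tokens and 12 heads.
    let scores = [2, 12, 32, 32];

    // Stacked NxS mask: the batch dim (2) lines up with query length (32).
    println!("[2, 32]       -> {}", can_broadcast(&[2, 32], &scores));
    // Same mask with two singleton axes inserted.
    println!("[2, 1, 1, 32] -> {}", can_broadcast(&[2, 1, 1, 32], &scores));
    // Square mask as produced by get_mask.
    println!("[32, 32]      -> {}", can_broadcast(&[32, 32], &scores));
    // A 1xN mask against a single-sequence batch.
    println!("[1, 32]       -> {}", can_broadcast(&[1, 32], &[1, 12, 32, 32]));
}
```

Under that assumption only the stacked [2, 32] shape fails, which matches the observation that a 1xN mask works for a single sequence but the stacked mask does not.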

I must admit that I don't understand what DistilBertModel::forward() expects of the provided mask. I also don't understand what the get_mask() function is supposed to do.

Maybe this is due to my lack of knowledge on the matter, but when I refer to the elements mentioned above (the Python equivalent and the similar Candle example with BERT), I'm wondering whether there is something wrong with the distilbert example and/or the model implementation.

fbilhaut commented Dec 13, 2024

@ToluClassics maybe? (It seems you committed this code :-)

sondalex commented Jan 8, 2025

Hi @fbilhaut ,

I recently explored your use case and took a deeper look into both the Candle and Transformers implementations. The PyTorch-based Transformers implementation has some differences in how attention is computed, but these differences primarily reflect API design choices. Despite these differences, you can replicate the same attention mask behavior in Candle.

Specifically, the "eager" method in Transformers uses MultiHeadSelfAttention, which corresponds to the implementation in Candle-transformers.

One notable discrepancy between the Candle-transformers and PyTorch implementations is the absence of a reshape (view in PyTorch) function call in Candle-transformers.

In the PyTorch MultiHeadSelfAttention class, the forward method handles the mask as follows:

class MultiHeadSelfAttention:
    ...
    def forward(..., mask, ...):
        ...
        mask_reshp = (bs, 1, 1, k_length)
        mask = (mask == 0).view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)

To replicate this in Candle, you can use:

let attention_mask = attention_mask
    .eq(0 as u32)?
    .reshape((batch_size, 1, 1, feature_size))?;

Additionally, in the PyTorch implementation, the attention mask is initialized as a matrix of ones. To better match this behavior, consider using Tensor::ones() for initialization instead of mapping over the encodings:

attention_mask = torch.ones(...)

The final solution:

use candle_transformers::models::distilbert::{
    Config as DistilConfig, DistilBertModel, DTYPE as DistilDTYPE,
};

use candle_core::{DType, Device, Error as CandleError, IndexOp, Tensor};
use candle_nn::VarBuilder;

use hf_hub::api::sync::Api;
use std::{
    error::Error,
    fs::{self, File},
};

const WITH_SPECIAL_TOKENS: bool = true;
const TEXT: [&str; 1] = ["hello"];

use tokenizers::{tokenizer::Tokenizer, PaddingParams};

pub fn select_row(tensor: &Tensor, index: usize) -> Result<Tensor, CandleError> {
    assert_eq!(tensor.dims().len(), 2);
    Ok(tensor.i(index)?)
}

fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
    let device = Device::Cpu;
    let model = "distilbert/distilbert-base-uncased";

    let api = Api::new()?;
    let repo = api.model(model.to_string());

    let config_filepath = repo.get("config.json")?;
    let config_reader = File::open(config_filepath)?;
    let safetensors_filepath = repo.get("model.safetensors")?;
    let buffer: Vec<u8> = fs::read(safetensors_filepath)?;

    let mut tokenizer: Tokenizer = Tokenizer::from_pretrained(model.to_string(), None)?;
    let config: DistilConfig = serde_json::from_reader(&config_reader)?;
    let vb = VarBuilder::from_buffered_safetensors(buffer, DistilDTYPE, &device)?;
    let model = DistilBertModel::load(vb, &config)?;

    let tokenizer = tokenizer.with_padding(Some(PaddingParams::default()));

    let encoded = tokenizer.encode_batch(TEXT.to_vec().clone(), WITH_SPECIAL_TOKENS)?;

    let input_ids = encoded
        .iter()
        .map(|v| v.get_ids().to_vec())
        .collect::<Vec<_>>();
    let input_ids = Tensor::new(input_ids, &device)?;
    println!("\nInput Ids\n{input_ids}");
    let input_ids = input_ids.to_dtype(DType::I64)?;
    /*let attention_mask = encoded
        .iter()
        .map(|encoding| encoding.get_attention_mask().to_vec())
        .collect::<Vec<_>>();
    let attention_mask = Tensor::new(attention_mask, &device)?;
    */

    let (batch_size, feature_size) = input_ids.dims2()?;
    let attention_mask = Tensor::ones_like(&input_ids)?;

    let attention_mask = attention_mask
        .eq(0 as u32)?
        .reshape((batch_size, 1, 1, feature_size))?;
    println!("\nAttention Mask\n{attention_mask}");

    let output = model.forward(&input_ids, &attention_mask)?;
    println!("\nLast Hidden State\n{output}");

    Ok(())
}

And the (almost) equivalent implementation in Python:

from transformers import DistilBertModel, DistilBertTokenizer

TEXT = ["Hello"]


if __name__ == "__main__":

    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    model = DistilBertModel.from_pretrained("distilbert-base-uncased", attn_implementation="eager")

    encoded_input = tokenizer(TEXT, padding=True, return_tensors='pt')
    print("Input ids")
    input_ids = encoded_input["input_ids"]
    print(input_ids)
    print(input_ids.dtype)
    output = model(**encoded_input)
    print("\nLast Hidden State")
    last_hidden_state = output["last_hidden_state"]
    print(last_hidden_state)
    print(last_hidden_state.dtype)

Note:

You may observe a difference in numeric precision. I am not entirely sure about the exact causes at play here.

fbilhaut commented Jan 10, 2025

Hi @sondalex,

Thank you very much for investigating this, I really appreciate it.

The reshape() part sounds good.

However I really can't figure out the logic behind initializing the mask with ones and then "negating" the matrix with eq(0):

  • Calling Tensor::ones_like(&input_ids) makes something like [[1, 1, 1, ...], [1, 1, 1, ...]]
  • Then applying eq(0) produces [[0, 0, 0, ...], [0, 0, 0, ...]]
  • Then reshape produces [[[[0, 0, 0, ...]]], [[[0, 0, 0, ...]]]]

The final shape seems to be correct, but I can't see the point of filling the matrix with ones just to "negate" it right after. And in the first place a sequence of zeroes (or even ones) doesn't look like a valid attention mask to me. I'm sorry if I'm missing something, but no matter how much I reread I don't see how it could make sense :-)
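
The first two steps of that list can be traced on plain vectors. This is a minimal sketch without Candle: `eq_zero` is a hypothetical stand-in for an element-wise `eq(0)`, and the padded mask is made-up sample data, not real tokenizer output.

```rust
/// Element-wise stand-in for `mask.eq(0)`: 1 where the input is 0, else 0.
/// Hypothetical helper for illustration only.
fn eq_zero(mask: &[u32]) -> Vec<u8> {
    mask.iter().map(|&m| u8::from(m == 0)).collect()
}

fn main() {
    // Mask filled with ones, as Tensor::ones_like would produce.
    let ones = vec![1u32; 5];
    // Made-up tokenizer-style mask: 3 real tokens followed by 2 padding slots.
    let padded = vec![1u32, 1, 1, 0, 0];

    println!("{:?}", eq_zero(&ones)); // [0, 0, 0, 0, 0]
    println!("{:?}", eq_zero(&padded)); // [0, 0, 0, 1, 1]
}
```

Starting from ones, the comparison always yields zeros, so no position ends up flagged. Starting from a mask that contains padding zeros, the same comparison flips it so that only the padding positions are flagged.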

Anyway, I ended up with the following, just using reshape after generating "regular" masks:

let (batch_size, feature_size) = input_ids.dims2()?;
let attention_masks: Vec<_> = tokens.iter().map(|x| x.get_attention_mask().to_vec()).collect();
let attention_masks = Tensor::new(attention_masks, device)?;
let attention_masks = attention_masks.reshape((batch_size, 1, 1, feature_size))?;

The last part being, indeed, appropriate to get the expected shape.

I'm not sure yet if the overall solution is okay, as I have to test further with the obtained tensors after inference, but at first sight the output looks okay.

sondalex commented Jan 16, 2025

Hi @fbilhaut, I misread the role of the ones initialization when interpreting the PyTorch implementation.
That initialization only occurs in a specific case that my earlier Python example did not reflect, namely when the model is called with input_ids only:

from transformers import DistilBertModel, DistilBertTokenizer

TEXT = ["hello"]


if __name__ == "__main__":

    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    model = DistilBertModel.from_pretrained("distilbert-base-uncased", attn_implementation="eager")

    encoded_input = tokenizer(TEXT, padding=True, return_tensors='pt')
    input_ids = encoded_input["input_ids"]
    
    output = model(input_ids=encoded_input["input_ids"]) # Note the difference here
    print("\nLast Hidden State")
    last_hidden_state = output["last_hidden_state"]

As a result, in general you would not want to do this.

Regarding the operation:

let attention_mask = attention_mask
        .eq(0 as u32)?
        ...

This statement is necessary due to the masked_fill function in the Candle source code.

let scores = masked_fill(&scores.to_dtype(DType::F32)?, &mask, f32::NEG_INFINITY)?;
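
The choice of f32::NEG_INFINITY as the fill value makes sense once the scores are normalized with a softmax, which is the usual next step in attention. The following plain-Rust sketch (no Candle, hypothetical scores) shows that a score of negative infinity ends up with a weight of exactly zero.

```rust
/// Numerically stable softmax over a slice of scores.
fn softmax(scores: &[f32]) -> Vec<f32> {
    // Subtract the maximum before exponentiating to avoid overflow.
    let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|&s| (s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

fn main() {
    // Hypothetical scores for one query over three keys; the last key has
    // been filled with negative infinity, as masked_fill would do.
    let scores = [1.0_f32, 2.0, f32::NEG_INFINITY];
    println!("{:?}", softmax(&scores));
}
```

One caveat of this sketch: if every score in a row is negative infinity, the subtraction produces NaN, so a fully masked row is not handled here.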

If you examine the masked_fill function, you'll see that positions where the mask is non-zero are replaced with the on_true value (here f32::NEG_INFINITY).

fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let shape = mask.shape();
    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
    let m = mask.where_cond(&on_true, on_false)?;
    Ok(m)
}
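
To make the convention concrete, here is a plain-vector sketch of the same selection rule, assuming that where_cond picks on_true wherever the mask is non-zero. It is an illustration on slices, not the Candle function, and the scores and mask are made-up values.

```rust
/// Plain-vector stand-in for masked_fill: where `mask` is non-zero the score
/// is replaced by `on_true`, elsewhere it is kept.
fn masked_fill(scores: &[f32], mask: &[u8], on_true: f32) -> Vec<f32> {
    scores
        .iter()
        .zip(mask)
        .map(|(&s, &m)| if m != 0 { on_true } else { s })
        .collect()
}

fn main() {
    let scores = [0.5_f32, 1.0, 2.0, 3.0];
    // Tokenizer convention: 1 = real token, 0 = padding (made-up example).
    let attention_mask = [1u8, 1, 0, 0];
    // Inverted mask, the equivalent of eq(0): 1 = padding, 0 = real token.
    let inverted: Vec<u8> = attention_mask.iter().map(|&m| u8::from(m == 0)).collect();

    // Passing the tokenizer mask directly blanks out the real tokens.
    println!("{:?}", masked_fill(&scores, &attention_mask, f32::NEG_INFINITY));
    // Passing the inverted mask blanks out the padding instead.
    println!("{:?}", masked_fill(&scores, &inverted, f32::NEG_INFINITY));
}
```

With this rule, a tokenizer-style mask has to be inverted before it is passed in; otherwise the real tokens are the ones that get filled.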

I find the current design of masked_fill quite counterintuitive. Abstracting the inversion of the attention mask away inside the API would be better, and it would only require rewriting the masked_fill function.

I have raised an issue here #2721
