Shape of attention mask in distilbert example #2667
Comments
@ToluClassics maybe? (seems like you committed this code :-)
Hi @fbilhaut, I recently explored your use case and took a deeper look into both the Candle and Transformers implementations. The PyTorch-based Transformers implementation differs somewhat in how attention is computed, but these differences primarily reflect API design choices. Despite them, you can replicate the same attention mask behavior in Candle.

Specifically, with the "eager" method in Transformers, one notable discrepancy between the Candle-transformers and PyTorch implementations is the absence of a reshape on the Candle side. In the PyTorch class MultiHeadSelfAttention:

...
def forward(..., mask, ...):
    ...
    mask_reshp = (bs, 1, 1, k_length)
    mask = (mask == 0).view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)

To replicate this in Candle, you can use:

let attention_mask = attention_mask
    .eq(0 as u32)?
    .reshape((batch_size, 1, 1, feature_size))?;

Additionally, in the PyTorch implementation, the attention mask is initialized as a matrix of ones. To better match this behavior, consider using

attention_mask = torch.ones(...)

The final solution:

use candle_transformers::models::distilbert::{
    Config as DistilConfig, DistilBertModel, DTYPE as DistilDTYPE,
};
use candle_core::{DType, Device, Error as CandleError, IndexOp, Tensor};
use candle_nn::VarBuilder;
use hf_hub::api::sync::Api;
use std::{
    error::Error,
    fs::{self, File},
    u32,
};

const WITH_SPECIAL_TOKENS: bool = true;
const TEXT: [&str; 1] = ["hello"];

use tokenizers::{tokenizer::Tokenizer, PaddingParams};

pub fn select_row(tensor: &Tensor, index: usize) -> Result<Tensor, CandleError> {
    assert_eq!(tensor.dims().len(), 2);
    Ok(tensor.i(index)?)
}

fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
    let device = Device::Cpu;
    let model = "distilbert/distilbert-base-uncased";
    let api = Api::new()?;
    let repo = api.model(model.to_string());
    let config_filepath = repo.get("config.json")?;
    let config_reader = File::open(config_filepath)?;
    let safetensors_filepath = repo.get("model.safetensors")?;
    let buffer: Vec<u8> = fs::read(safetensors_filepath)?;
    let mut tokenizer: Tokenizer = Tokenizer::from_pretrained(model.to_string(), None)?;
    let config: DistilConfig = serde_json::from_reader(&config_reader)?;
    let vb = VarBuilder::from_buffered_safetensors(buffer, DistilDTYPE, &device)?;
    let model = DistilBertModel::load(vb, &config)?;
    let tokenizer = tokenizer.with_padding(Some(PaddingParams::default()));
    let encoded = tokenizer.encode_batch(TEXT.to_vec().clone(), WITH_SPECIAL_TOKENS)?;
    let input_ids = encoded
        .iter()
        .map(|v| v.get_ids().to_vec())
        .collect::<Vec<_>>();
    let input_ids = Tensor::new(input_ids, &device)?;
    println!("\nInput Ids\n{input_ids}");
    let input_ids = input_ids.to_dtype(DType::I64)?;
    /*let attention_mask = encoded
        .iter()
        .map(|encoding| encoding.get_attention_mask().to_vec())
        .collect::<Vec<_>>();
    let attention_mask = Tensor::new(attention_mask, &device)?;
    */
    let (batch_size, feature_size) = input_ids.dims2()?;
    let attention_mask = Tensor::ones_like(&input_ids)?;
    let attention_mask = attention_mask
        .eq(0 as u32)?
        .reshape((batch_size, 1, 1, feature_size))?;
    println!("\nAttention Mask\n{attention_mask}");
    let output = model.forward(&input_ids, &attention_mask)?;
    println!("\nLast Hidden State\n{output}");
    Ok(())
}

And the (almost) equivalent implementation in Python:

from transformers import DistilBertModel, DistilBertTokenizer

TEXT = ["Hello"]

if __name__ == "__main__":
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    model = DistilBertModel.from_pretrained("distilbert-base-uncased", attn_implementation="eager")
    encoded_input = tokenizer(TEXT, padding=True, return_tensors='pt')
    print("Input ids")
    input_ids = encoded_input["input_ids"]
    print(input_ids)
    print(input_ids.dtype)
    output = model(**encoded_input)
    print("\nLast Hidden State")
    last_hidden_state = output["last_hidden_state"]
    print(last_hidden_state)
    print(last_hidden_state.dtype)

Note: You may observe a difference in numeric precision. I am not entirely sure about the exact causes at play here.
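As a rough illustration of the invert-and-reshape step described in the comment above, here is a minimal pure-Python sketch that uses nested lists instead of tensors. The helper names `invert_and_reshape` and `shape` are hypothetical and not part of Candle or Transformers; the sketch only mirrors the `(mask == 0)` plus `(bs, 1, 1, k_length)` idea.

```python
from typing import List


def invert_and_reshape(mask: List[List[int]]) -> List[List[List[List[bool]]]]:
    """Turn a (bs, k_length) 1/0 mask into a (bs, 1, 1, k_length) boolean mask.

    True marks a position to hide (the tokenizer's 0, i.e. padding).
    """
    return [[[[value == 0 for value in row]]] for row in mask]


def shape(nested) -> tuple:
    """Return the shape of a regular (non-ragged) nested list."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)


# Two sequences padded to length 4; the second one ends with two padding tokens.
tokenizer_mask = [[1, 1, 1, 1], [1, 1, 0, 0]]
reshaped = invert_and_reshape(tokenizer_mask)

print(shape(reshaped))    # (2, 1, 1, 4)
print(reshaped[1][0][0])  # [False, False, True, True]
```

The two singleton axes stand in for the head and query dimensions, so the same per-key flags can be expanded over every head and every query position.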
Hi @sondalex, thank you very much for investigating this, I really appreciate it. However, I really can't figure out the logic behind initializing the mask with ones and then "negating" the matrix.

The final shape seems to be correct, but I can't see the point of filling the matrix with ones just to "negate" it right after. And in the first place, a sequence of zeroes (or even ones) doesn't look like a valid attention mask to me. I'm sorry if I'm missing something, but no matter how much I reread, I don't see how it could make sense :-)

Anyway, I ended up with the following, just using the attention mask returned by the tokenizer:

let (batch_size, feature_size) = input_ids.dims2()?;
let attention_masks: Vec<_> = tokens.iter().map(|x| x.get_attention_mask().to_vec()).collect();
let attention_masks = Tensor::new(attention_masks, device)?;
let attention_masks = attention_masks.reshape((batch_size, 1, 1, feature_size))?;

The last part is indeed appropriate to get the expected shape. I'm not sure yet whether the overall solution is okay, as I have to test further with the tensors obtained after inference, but at first sight the output looks okay.
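For readers following along, a minimal sketch of what a padded batch and its tokenizer-style mask look like, in plain Python. The function `pad_batch` and the token ids are hypothetical stand-ins for what a padding-enabled tokenizer returns, not actual tokenizer output.

```python
from typing import List, Tuple


def pad_batch(
    sequences: List[List[int]], pad_id: int = 0
) -> Tuple[List[List[int]], List[List[int]]]:
    """Pad token-id sequences to a common length and build the matching 1/0 mask."""
    longest = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        padding = longest - len(seq)
        input_ids.append(seq + [pad_id] * padding)
        # 1 marks a real token, 0 marks padding.
        attention_mask.append([1] * len(seq) + [0] * padding)
    return input_ids, attention_mask


# Hypothetical token ids, chosen only for illustration.
ids, mask = pad_batch([[101, 7592, 102], [101, 7592, 2088, 999, 102]])

print(ids[0])   # [101, 7592, 102, 0, 0]
print(mask[0])  # [1, 1, 1, 0, 0]
print(mask[1])  # [1, 1, 1, 1, 1]
```

Here the mask has shape N x S (2 x 5), which is the layout that then gets reshaped to (N, 1, 1, S) in the snippet above.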
Hi @fbilhaut, I had overlooked a detail. Note the difference in the model call below, which passes only input_ids:

from transformers import DistilBertModel, DistilBertTokenizer

TEXT = ["hello"]

if __name__ == "__main__":
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    model = DistilBertModel.from_pretrained("distilbert-base-uncased", attn_implementation="eager")
    encoded_input = tokenizer(TEXT, padding=True, return_tensors='pt')
    input_ids = encoded_input["input_ids"]
    output = model(input_ids=encoded_input["input_ids"])  # Note the difference here
    print("\nLast Hidden State")
    last_hidden_state = output["last_hidden_state"]

As a result, in general you would not want to do this.

Regarding the operation:

let attention_mask = attention_mask
    .eq(0 as u32)?
    ...

This statement is necessary due to the code in candle/candle-transformers/src/models/distilbert.rs, lines 13 to 18 in efd0e68.

Regarding the current design, I have raised an issue here: #2721
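To see why the inversion can matter, here is a minimal sketch, assuming the attention layer hides every position whose mask entry is true (non-zero) by filling its score with negative infinity before the softmax. That convention is only my reading of this thread, not a statement about Candle's exact code, and `masked_softmax` is a hypothetical helper.

```python
import math
from typing import List


def masked_softmax(scores: List[float], hide: List[bool]) -> List[float]:
    """Softmax over one row of scores, with hidden positions set to -inf first."""
    filled = [-math.inf if h else s for s, h in zip(scores, hide)]
    peak = max(filled)  # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in filled]
    total = sum(exps)
    return [e / total for e in exps]


scores = [1.0, 1.0, 1.0, 1.0]
tokenizer_mask = [1, 1, 0, 0]  # 1 = real token, 0 = padding

# Inverted mask: padding is hidden, real tokens share the attention.
inverted = [m == 0 for m in tokenizer_mask]
print(masked_softmax(scores, inverted))      # [0.5, 0.5, 0.0, 0.0]

# Raw mask used as-is: real tokens are hidden and padding gets the attention.
not_inverted = [m != 0 for m in tokenizer_mask]
print(masked_softmax(scores, not_inverted))  # [0.0, 0.0, 0.5, 0.5]
```

Under that assumption, passing the tokenizer mask without the `.eq(0)` step would still produce a tensor of the right shape, but attention would land on the padding positions.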
Hi,

I'm trying to adapt the distilbert example to make it process multiple sequences at once (the provided example just processes one prompt). But I'm having trouble providing the proper attention mask to the DistilBertModel::forward() method.

I noticed, when reading the documentation of the forward() method of the equivalent Python class, that this mask is expected to have the same shape as the input_ids parameter. This seems sound, and is also consistent with the BERT example in Candle, which does it that way when it comes to processing multiple sequences to compute similarities.

BUT:

In the distilbert example the tokenizer doesn't add any padding, and there is a quite mysterious function that is supposed to compute the attention mask, returning a different (square) shape: for a sequence of 3 tokens, for example, it generates an NxN mask.

If I simply replace this by the result of the get_attention_mask() function in a 1xN tensor, it works for one sequence. But for several sequences, if I pad all the sequences to the same size S and stack the masks obtained for N sequences to get an NxS tensor (as does the bert example mentioned earlier), I get an error.

I must admit that I don't get the expectations of DistilBertModel::forward() regarding the provided mask. I also don't understand what the get_mask() function is supposed to do. Maybe this is due to my lack of knowledge on that matter, but when I refer to the elements mentioned above (Python equivalent and similar Candle example with Bert), I'm wondering if there isn't something wrong with the distilbert example and/or model implementation?
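One possible explanation for why a 1xN mask works while an NxS mask fails is shape broadcasting. The sketch below assumes the mask is broadcast against scores of shape (batch, heads, S, S) using right-aligned, NumPy-style rules; whether the model applies exactly this rule is an assumption on my part, and `broadcast_shape` is a hypothetical helper. The head count of 12 is only illustrative.

```python
from itertools import zip_longest
from typing import Optional, Tuple


def broadcast_shape(
    a: Tuple[int, ...], b: Tuple[int, ...]
) -> Optional[Tuple[int, ...]]:
    """Return the right-aligned broadcast shape of a and b, or None if incompatible."""
    result = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and x != 1 and y != 1:
            return None  # the two sizes differ and neither is 1
        result.append(max(x, y))
    return tuple(reversed(result))


heads, seq_len = 12, 5

# One sequence: a (1, S) mask lines up with (1, heads, S, S) scores.
print(broadcast_shape((1, seq_len), (1, heads, seq_len, seq_len)))        # (1, 12, 5, 5)

# Three sequences: the batch axis of (3, S) collides with the query axis.
print(broadcast_shape((3, seq_len), (3, heads, seq_len, seq_len)))        # None

# Reshaped to (3, 1, 1, S), the mask broadcasts over heads and queries.
print(broadcast_shape((3, 1, 1, seq_len), (3, heads, seq_len, seq_len)))  # (3, 12, 5, 5)
```

Under these rules a two-dimensional mask only happens to work when the batch size is 1 (or equals S), which matches the behavior described in the post.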