Proposal
We should add an option to exclude special tokens when adding a SAE to HookedSAETransformer. This could take the form of an exclude_special_tokens param for add_sae() / run_with_cache_with_saes() / run_with_saes(). It would skip running the SAE on BOS, EOS, and SEP tokens, as defined by the model's tokenizer. The user could pass True to exclude these standard special tokens, or pass a list (or tensor) of token_id values to customize the behavior further.
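For concreteness, usage might look something like the sketch below. None of this exists yet: exclude_special_tokens is exactly what this issue proposes, and model / sae / tokens stand in for any HookedSAETransformer, loaded SAE, and tokenized prompt.

# Hypothetical usage of the proposed parameter; exclude_special_tokens does not exist yet.
tokens = model.to_tokens("Hello world")

# Skip the tokenizer's standard special tokens (BOS / EOS / SEP)
logits = model.run_with_saes(tokens, saes=[sae], exclude_special_tokens=True)

# Or exclude a custom set of token ids
logits, cache = model.run_with_cache_with_saes(
    tokens,
    saes=[sae],
    exclude_special_tokens=[model.tokenizer.bos_token_id, model.tokenizer.eos_token_id],
)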
Motivation
It's often not useful to apply a SAE to special tokens, since SAEs are typically not trained on them, and it's not particularly interesting to see SAE latents that fire on BOS. Given that this is a common use case, we should make it easy to skip special tokens when running with a SAE via HookedSAETransformer, since this class exists precisely to make common SAE use cases easy.
Alternatives
We could alternatively allow users to specify token positions (indices) at which to skip the SAE, instead of token ids. This would require more work from users, but it could support other use cases where the SAE should not be applied at certain positions. It could also be implemented separately from, or in addition to, an exclude_special_tokens param.
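A rough sketch of what a position-based variant might look like (purely illustrative; exclude_positions is a hypothetical name, not an existing argument):

# Hypothetical position-based variant: the caller supplies sequence positions to skip.
logits = model.run_with_saes(
    tokens,
    saes=[sae],
    exclude_positions=[0],  # e.g. never run the SAE at the BOS position
)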
Checklist
I have checked that there is no similar issue in the repo (required)
Here is a temporary fix that worked for me, in case anyone wants to use it before the PR is made.
import torch

def run_with_saes_filtered(tokens, filtered_ids, model, saes):
    # Ensure tokens are a torch.Tensor
    if not isinstance(tokens, torch.Tensor):
        tokens = torch.tensor(tokens, dtype=torch.long)
    # Build a mask that is True at positions where the SAE should be applied
    mask = torch.ones_like(tokens, dtype=torch.bool)
    for token_id in filtered_ids:
        if token_id is None:  # e.g. the tokenizer has no pad token
            continue
        mask &= tokens != token_id
    # For each SAE, add a hook at its configured hook point
    for sae in saes:
        hook_point = sae.cfg.hook_name
        # Bind sae and mask as default args so each hook keeps its own references
        def filtered_hook(act, hook, sae=sae, mask=mask):
            # act shape: [batch_size, seq_len, hidden_size]
            # Expand the mask to act's shape (and move it to act's device)
            mask_expanded = mask.to(act.device).unsqueeze(-1).expand_as(act)
            # Apply the SAE only where the mask is True; keep the original activations elsewhere
            return torch.where(mask_expanded, sae(act), act)
        # Add the hook to the model
        model.add_hook(hook_point, filtered_hook, dir="fwd")
    # Run the model with the hooks in place
    logits = model(tokens)
    # Reset the hooks after computation
    model.reset_hooks()
    return logits

filtered_ids = [
    model.tokenizer.bos_token_id,
    model.tokenizer.eos_token_id,
    model.tokenizer.pad_token_id,
]
logits = run_with_saes_filtered(tokens, filtered_ids, model, [sae])
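Note that this workaround still computes sae(act) at every position and only uses torch.where to restore the original activations at the excluded tokens, so it trades a bit of redundant SAE compute for simplicity. Since the hooks are registered with model.add_hook and stay active until reset_hooks, the same pattern should also work for caching (e.g. swapping model(tokens) for model.run_with_cache(tokens) inside the function), though I haven't tested that variant.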