
[Proposal] Allow excluding special tokens when applying SAEs in HookedSAETransformer #350

Open
1 task done
chanind opened this issue Oct 24, 2024 · 1 comment

Comments

@chanind
Collaborator

chanind commented Oct 24, 2024

Proposal

We should add an option to exclude special tokens when adding a SAE to HookedSAETransformer. This could take the form of an exclude_special_tokens param for add_sae() / run_with_cache_with_saes() / run_with_saes(). It would skip running the SAE on BOS, EOS, and SEP tokens as specified by the model tokenizer. The user could pass True to skip these standard tokens, or pass a list (or tensor) of token_id values to exclude, to customize the behavior further.
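
For concreteness, here is a rough sketch of how the proposed parameter might look from the user's side (exclude_special_tokens is the proposed, not-yet-implemented param; everything else matches the existing HookedSAETransformer API):

# Hypothetical usage sketch of the proposed param (not yet implemented)
logits = model.run_with_saes(tokens, saes=[sae], exclude_special_tokens=True)

# Or pass explicit token ids to exclude instead of True
logits, cache = model.run_with_cache_with_saes(
    tokens,
    saes=[sae],
    exclude_special_tokens=[model.tokenizer.bos_token_id, model.tokenizer.sep_token_id],
)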

Motivation

It's usually not useful to apply a SAE to special tokens: SAEs are often not trained on them, and it's not particularly interesting to see SAE latents that fire on BOS. Given this is a common use-case, we should make it easy to skip special tokens when running with a SAE via HookedSAETransformer, since that class exists precisely to make common SAE use-cases easy.

Alternatives

We could alternatively let users specify token positions (indices) to skip rather than token ids. This would require more work from users but would support other use-cases where the SAE shouldn't be applied at certain positions. It could also be implemented separately from, or in addition to, an exclude_special_tokens param.
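
As a rough illustration of this positional variant (build_position_mask and its arguments are hypothetical, not part of the library), the mask could be built from positions rather than token ids and then consumed by the same kind of hook as in the workaround further down this thread:

import torch

def build_position_mask(tokens, exclude_positions):
    # tokens: [batch_size, seq_len]; True = run the SAE at this position
    mask = torch.ones_like(tokens, dtype=torch.bool)
    mask[:, exclude_positions] = False
    return mask

# e.g. skip position 0 (BOS) and position 5 in every sequence
position_mask = build_position_mask(tokens, [0, 5])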

Checklist

  • I have checked that there is no similar issue in the repo (required)
@NainaniJatinZ

Here is a temporary fix that worked for me, if someone wants to use it before a PR lands.

import torch

def run_with_saes_filtered(tokens, filtered_ids, model, saes):
    # Ensure tokens are a torch.Tensor
    if not isinstance(tokens, torch.Tensor):
        tokens = torch.tensor(tokens, dtype=torch.long)

    # Create a mask where True marks positions the SAE should run on
    mask = torch.ones_like(tokens, dtype=torch.bool)
    for token_id in filtered_ids:
        if token_id is None:  # e.g. pad_token_id can be None for some tokenizers
            continue
        mask &= tokens != token_id

    # For each SAE, add the appropriate hook
    for sae in saes:
        hook_point = sae.cfg.hook_name

        # Define the modified hook function
        def filtered_hook(act, hook, sae=sae, mask=mask):
            # act shape: [batch_size, seq_len, d_model]
            # Move the mask to the activation's device and broadcast over the hidden dim
            mask_expanded = mask.to(act.device).unsqueeze(-1).expand_as(act)
            # Run the SAE everywhere, but keep the original activations where the mask is False
            act = torch.where(mask_expanded, sae(act), act)
            return act

        # Add the hook to the model
        model.add_hook(hook_point, filtered_hook, dir='fwd')

    # Run the model with the tokens
    logits = model(tokens)

    # Reset the hooks after computation
    model.reset_hooks()
    return logits

filtered_ids = [
    model.tokenizer.bos_token_id,
    model.tokenizer.eos_token_id,
    model.tokenizer.pad_token_id
]

logits = run_with_saes_filtered(tokens, filtered_ids, model, [sae])
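
Note that the hook still runs the SAE on every position and uses torch.where to discard its output wherever the mask is False, so it wastes a little compute on special-token positions, but it keeps the hook simple and shape-safe.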
