AxBench is a scalable benchmark that evaluates interpretability techniques along two axes: concept detection and model steering. This repo includes all benchmarking code: data generation, training, evaluation, and analysis.
We introduce supervised dictionary learning (SDL), trained on synthetic data, as an analogue to SAEs. You can access pretrained SDLs and our training/eval datasets here:
- 🤗 HuggingFace: AxBench Collections
- 🤗 ReFT-r1 Live Demo: Steering ChatLM
- 🤗 ReFT-cr1 Live Demo: Conditional Steering ChatLM
- 📚 Feature Visualizer: Visualize LM Activations
- 🔍 Subspace Gazer: Visualize Subspaces via UMAP
- Tutorial on using our dictionaries via pyvene
- Scalable evaluation harness: Framework for generating synthetic training + eval data from concept lists (e.g., GemmaScope SAE labels).
- Comprehensive implementations: 10+ interpretability methods evaluated, along with finetuning and prompting baselines.
- 16K concept training data: Full-scale datasets for supervised dictionary learning (SDL).
- Two pretrained SDL models: Drop-in replacements for standard SAEs.
- LLM-in-the-loop training: Generate your own datasets for less than $0.01 per concept.
We include exploratory notebooks under `axbench/examples`, such as:
| Experiment | Description |
|---|---|
| `basics.ipynb` | Analyzes basic geometry of learned dictionaries. |
| `subspace_gazer.ipynb` | Visualizes learned subspaces. |
| `lang>subspace.ipynb` | Fine-tunes a hyper-network to map natural language to subspaces or steering vectors. |
| `platonic.ipynb` | Explores the platonic representation hypothesis in subspace learning. |
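To browse these notebooks locally, you can launch Jupyter from the repo root (assuming Jupyter is installed in your environment, e.g. via `uv add jupyter`; any notebook launcher works):

```bash
# Launch JupyterLab pointed at the examples directory.
uv run jupyter lab axbench/examples/
```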
We strongly recommend using `uv` for your Python virtual environment, but any venv manager should work.
```bash
git clone [email protected]:stanfordnlp/axbench.git
cd axbench
uv sync  # if using uv
```
Set up your API keys for OpenAI and Neuronpedia:
```python
import os

os.environ["OPENAI_API_KEY"] = "your_openai_api_key_here"
os.environ["NP_API_KEY"] = "your_neuronpedia_api_key_here"
```
Download the necessary datasets to `axbench/data`:
```bash
uv run axbench/scripts/download-seed-sentences.py
cd data
bash download-2b.sh
bash download-9b.sh
bash download-alpaca.sh
```
To run a complete demo with a single config file:
```bash
bash axbench/demo/demo.sh
```
(If you are using our pre-generated data, you can skip the following generation steps.)
Generate training data:
```bash
uv run axbench/scripts/generate.py --config axbench/demo/sweep/simple.yaml --dump_dir axbench/demo
```
Generate inference data:
```bash
uv run axbench/scripts/generate_latent.py --config axbench/demo/sweep/simple.yaml --dump_dir axbench/demo
```
To modify the data generation process, edit `simple.yaml`.
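One low-risk way to experiment is to copy the demo config and edit the copy (an illustrative workflow; `my_sweep.yaml` is a hypothetical name, and the available fields are those already present in `simple.yaml`):

```bash
# Copy the demo sweep config, edit the copy, and regenerate with it.
cp axbench/demo/sweep/simple.yaml axbench/demo/sweep/my_sweep.yaml
# ... edit my_sweep.yaml ...
uv run axbench/scripts/generate.py --config axbench/demo/sweep/my_sweep.yaml --dump_dir axbench/demo
```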
Train and save your methods:
```bash
uv run torchrun --nproc_per_node=$gpu_count axbench/scripts/train.py \
    --config axbench/demo/sweep/simple.yaml \
    --dump_dir axbench/demo
```
(Replace `$gpu_count` with the number of GPUs to use.)
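On machines with NVIDIA drivers, you can set `$gpu_count` automatically (an illustrative snippet, not part of the repo):

```bash
# Count visible NVIDIA GPUs; requires nvidia-smi.
gpu_count=$(nvidia-smi --list-gpus | wc -l)
```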
To train with a custom config:
```bash
torchrun --nproc_per_node=$gpu_count axbench/scripts/train.py \
    --config axbench/sweep/wuzhengx/2b/l10/no_grad.yaml \
    --dump_dir axbench/results/prod_2b_l10_concept500_no_grad \
    --overwrite_data_dir axbench/concept500/prod_2b_l10_v1/generate
```
where `--dump_dir` is the output directory and `--overwrite_data_dir` is where the training data resides.
Run inference:
```bash
uv run torchrun --nproc_per_node=$gpu_count axbench/scripts/inference.py \
    --config axbench/demo/sweep/simple.yaml \
    --dump_dir axbench/demo \
    --mode latent
```
To run inference with a custom config and custom data directories:
```bash
uv run torchrun --nproc_per_node=$gpu_count axbench/scripts/inference.py \
    --config axbench/sweep/wuzhengx/2b/l10/no_grad.yaml \
    --dump_dir axbench/results/prod_2b_l10_concept500_no_grad \
    --overwrite_metadata_dir axbench/concept500/prod_2b_l10_v1/generate \
    --overwrite_inference_data_dir axbench/concept500/prod_2b_l10_v1/inference \
    --mode latent
```
In real-world scenarios, fewer than 1% of examples may be positive. To simulate this, we upsample negatives to a 100:1 negative-to-positive ratio and re-evaluate (at 100:1, even a detector with perfect recall and a 1% false-positive rate achieves only ~50% precision, so this setting stress-tests precision). Use:
```bash
uv run torchrun --nproc_per_node=$gpu_count axbench/scripts/inference.py \
    --config axbench/sweep/wuzhengx/2b/l10/no_grad.yaml \
    --dump_dir axbench/results/prod_2b_l10_concept500_no_grad \
    --overwrite_metadata_dir axbench/concept500/prod_2b_l10_v1/generate \
    --overwrite_inference_data_dir axbench/concept500/prod_2b_l10_v1/inference \
    --mode latent_imbalance
```
For steering experiments:
```bash
uv run torchrun --nproc_per_node=$gpu_count axbench/scripts/inference.py \
    --config axbench/demo/sweep/simple.yaml \
    --dump_dir axbench/demo \
    --mode steering
```
Or, for a custom run:
```bash
uv run torchrun --nproc_per_node=$gpu_count axbench/scripts/inference.py \
    --config axbench/sweep/wuzhengx/2b/l10/no_grad.yaml \
    --dump_dir axbench/results/prod_2b_l10_concept500_no_grad \
    --overwrite_metadata_dir axbench/concept500/prod_2b_l10_v1/generate \
    --overwrite_inference_data_dir axbench/concept500/prod_2b_l10_v1/inference \
    --mode steering
```
To evaluate concept detection results:
```bash
uv run axbench/scripts/evaluate.py \
    --config axbench/demo/sweep/simple.yaml \
    --dump_dir axbench/demo \
    --mode latent
```
Enable wandb logging:
```bash
uv run axbench/scripts/evaluate.py \
    --config axbench/demo/sweep/simple.yaml \
    --dump_dir axbench/demo \
    --mode latent \
    --report_to wandb \
    --wandb_entity "your_wandb_entity"
```
Or evaluate using your custom config:
```bash
uv run axbench/scripts/evaluate.py \
    --config axbench/sweep/wuzhengx/2b/l10/no_grad.yaml \
    --dump_dir axbench/results/prod_2b_l10_concept500_no_grad \
    --mode latent
```
To evaluate steering:
```bash
uv run axbench/scripts/evaluate.py \
    --config axbench/demo/sweep/simple.yaml \
    --dump_dir axbench/demo \
    --mode steering
```
Or with a custom config:
```bash
uv run axbench/scripts/evaluate.py \
    --config axbench/sweep/wuzhengx/2b/l10/no_grad.yaml \
    --dump_dir axbench/results/prod_2b_l10_concept500_no_grad \
    --mode steering
```
Please see `axbench/experiment_commands.txt` for detailed commands and configurations.