From f92509475712987936613aed4f6d4f02f762ca24 Mon Sep 17 00:00:00 2001 From: Theia Vogel Date: Fri, 13 Dec 2024 21:37:03 -0800 Subject: [PATCH] gemma wip --- notebooks/gemma-wip.ipynb | 628 ++++++++++++++++++++++++++++++++++++++ repeng/saes.py | 76 +++++ 2 files changed, 704 insertions(+) create mode 100644 notebooks/gemma-wip.ipynb diff --git a/notebooks/gemma-wip.ipynb b/notebooks/gemma-wip.ipynb new file mode 100644 index 0000000..91594d5 --- /dev/null +++ b/notebooks/gemma-wip.ipynb @@ -0,0 +1,628 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "96569242-1632-4c18-91a5-2125288a37ad", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install git+https://github.com/vgel/repeng.git\n", + "%pip install sae-lens matplotlib" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "7423946f-9685-441c-90fe-6d7791b1bb9e", + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "\n", + "import torch\n", + "import tqdm\n", + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "import numpy as np\n", + "\n", + "from repeng import ControlVector, ControlModel, DatasetEntry\n", + "# import repeng.saes # TODO" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "981bc0ab-30c5-4026-8dde-8e31eae92783", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d69c92dd452b4a4d831b633c46188609", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/3 [00:00 np.ndarray:\n", + " # TODO: sparsify like `sae`?\n", + " at = torch.from_numpy(activation).to(self.sae.device)\n", + " out = self.sae.encode(at)\n", + " # numpy doesn't like bfloat16\n", + " return out.cpu().float().numpy()\n", + "\n", + " def decode(self, features: np.ndarray) -> np.ndarray:\n", + " # TODO: sparsify like `sae`?\n", + " ft = torch.from_numpy(features).to(self.sae.device, dtype=self.sae.dtype)\n", + " decoded = self.sae.decode(ft)\n", + " return decoded.cpu().float().numpy()\n", + "\n", + " layer_dict: dict[int, SaeLayer] = {}\n", + " for layer, sae_id in tqdm.tqdm(layers_to_sae.items()):\n", + " if dtype is None:\n", + " sae, cfg_dict, _ = sae_lens.SAE.from_pretrained(\n", + " release=release,\n", + " sae_id=sae_id,\n", + " device=device,\n", + " )\n", + " else:\n", + " # don't load directly on device because we can't pass a dtype to from_pretrained\n", + " # and we might not have enough vram to load the incorrect dtype\n", + " sae, cfg_dict, _ = sae_lens.SAE.from_pretrained(\n", + " release=release,\n", + " sae_id=sae_id,\n", + " )\n", + " sae = sae.to(device, dtype)\n", + " layer_dict[layer] = SaeLensLayer(\n", + " repeng_layer=layer,\n", + " sae_lens_id=sae_id,\n", + " cfg_dict=cfg_dict,\n", + " sae=sae,\n", + " )\n", + "\n", + " return Sae(layers=layer_dict)\n", + "\n", + "\n", + "sae = from_saelens(\n", + " \"gemma-scope-2b-pt-res-canonical\",\n", + " {layer: f\"layer_{layer-1}/width_65k/canonical\" for layer in control_layers},\n", + " device=\"cuda:0\",\n", + " dtype=torch.bfloat16,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 168, + "id": "11357321-2fe8-4085-8e00-850d82f4d5a9", + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import display, HTML\n", + "from transformers import TextStreamer\n", + "\n", + "# repeng dataloading / template boilerplate\n", + "\n", + "with open(\"repeng/notebooks/data/all_truncated_outputs.json\") as f:\n", + " output_suffixes = json.load(f)\n", + 
"truncated_output_suffixes = [\n", + " tokenizer.convert_tokens_to_string(tokens[:i])\n", + " for tokens in (tokenizer.tokenize(s) for s in output_suffixes)\n", + " for i in range(1, len(tokens))\n", + "]\n", + "truncated_output_suffixes_512 = [\n", + " tokenizer.convert_tokens_to_string(tokens[:i])\n", + " for tokens in (tokenizer.tokenize(s) for s in output_suffixes[:512])\n", + " for i in range(1, len(tokens))\n", + "]\n", + "\n", + "with open(\"repeng/notebooks/data/true_facts.json\") as f:\n", + " fact_suffixes = json.load(f)\n", + "truncated_fact_suffixes = [\n", + " tokenizer.convert_tokens_to_string(tokens[:i])\n", + " for tokens in (tokenizer.tokenize(s) for s in fact_suffixes)\n", + " for i in range(1, len(tokens) - 5)\n", + "]\n", + "\n", + "TEMPLATE = \"\"\"{persona} is talking to the user.\n", + "\n", + "User: {user_msg}\n", + "\n", + "AI: {prefill}\"\"\"\n", + "\n", + "\n", + "def template_parse(resp: str) -> tuple[str, str, str]:\n", + " persona, rest = resp.split(\"\\n\\nUser: \", 1)\n", + " user, assistant = rest.split(\"\\n\\nAI: \", 1)\n", + " return (persona.strip(), user.strip(), assistant.strip())\n", + "\n", + "\n", + "def make_dataset(\n", + " persona_template: str,\n", + " positive_personas: list[str],\n", + " negative_personas: list[str],\n", + " user_msg: str,\n", + " suffix_list: list[str],\n", + ") -> list[DatasetEntry]:\n", + " dataset = []\n", + " for suffix in suffix_list:\n", + " for positive_persona, negative_persona in zip(\n", + " positive_personas, negative_personas\n", + " ):\n", + " pos = persona_template.format(persona=positive_persona)\n", + " neg = persona_template.format(persona=negative_persona)\n", + " dataset.append(\n", + " DatasetEntry(\n", + " positive=TEMPLATE.format(\n", + " persona=pos, user_msg=user_msg, prefill=suffix\n", + " ),\n", + " negative=TEMPLATE.format(\n", + " persona=neg, user_msg=user_msg, prefill=suffix\n", + " ),\n", + " )\n", + " )\n", + " return dataset\n", + "\n", + "\n", + "class HTMLStreamer(TextStreamer):\n", + " def __init__(self, *args, **kwargs):\n", + " super().__init__(*args, **kwargs)\n", + " self.display_handle = display(display_id=True)\n", + " self.full_text = \"\"\n", + "\n", + " def _is_chinese_char(self, _):\n", + " # hack to force token-by-token streaming\n", + " return True\n", + "\n", + " def on_finalized_text(self, text: str, stream_end: bool = False):\n", + " self.full_text += text\n", + " # persona, user, assistant = template_parse(self.full_text)\n", + " html = HTML(f\"\"\"\n", + "
    <div style='white-space: pre-wrap'>{self.full_text.replace(\"<\", \"&lt;\").replace(\">\", \"&gt;\")}</div>
\n", + " \"\"\")\n", + " self.display_handle.update(html)\n", + "\n", + "\n", + "def generate_with_vector(\n", + " model,\n", + " input: str,\n", + " labeled_vectors: list[tuple[str, ControlVector]],\n", + " max_new_tokens: int = 128,\n", + " repetition_penalty: float = 1.1,\n", + " show_baseline: bool = False,\n", + " temperature: float = 0.7,\n", + "):\n", + " input_ids = tokenizer(input, return_tensors=\"pt\").to(model.device)\n", + " settings = {\n", + " \"pad_token_id\": tokenizer.eos_token_id, # silence warning\n", + " \"do_sample\": True,\n", + " \"temperature\": temperature,\n", + " \"max_new_tokens\": max_new_tokens,\n", + " \"repetition_penalty\": repetition_penalty,\n", + " }\n", + "\n", + " def gen(label):\n", + " display(HTML(f\"
<h3>{label}</h3>
\"))\n", + " _ = model.generate(streamer=HTMLStreamer(tokenizer), **input_ids, **settings)\n", + "\n", + " if show_baseline:\n", + " model.reset()\n", + " gen(\"baseline\")\n", + " for label, vector in labeled_vectors:\n", + " model.set_control(vector)\n", + " gen(label)\n", + " model.reset()" + ] + }, + { + "cell_type": "code", + "execution_count": 102, + "id": "782a381f-0418-40c4-9cb1-85e8d62b68b6", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 147/147 [00:39<00:00, 3.68it/s]\n", + "100%|██████████| 25/25 [00:06<00:00, 4.09it/s]\n", + "100%|██████████| 147/147 [00:37<00:00, 3.94it/s]\n", + "sae encoding: 100%|██████████| 10/10 [00:07<00:00, 1.26it/s]\n", + "100%|██████████| 10/10 [00:18<00:00, 1.89s/it]\n", + "sae decoding: 100%|██████████| 10/10 [00:00<00:00, 10.34it/s]\n", + "100%|██████████| 147/147 [00:37<00:00, 3.95it/s]\n", + "sae encoding: 100%|██████████| 10/10 [00:07<00:00, 1.34it/s]\n", + "100%|██████████| 10/10 [00:18<00:00, 1.84s/it]\n" + ] + } + ], + "source": [ + "happy_dataset = make_dataset(\n", + " \"{persona}\",\n", + " [\"A happy AI\", \"A cheerful AI\"],\n", + " [\"A sad AI\", \"A miserable AI\"],\n", + " \"Who are you?\",\n", + " truncated_output_suffixes,\n", + ")\n", + "model.reset()\n", + "happy_vector_no_sae = ControlVector.train(\n", + " model, tokenizer, happy_dataset, batch_size=32, method=\"pca_center\"\n", + ")\n", + "happy_vector_sae = ControlVector.train_with_sae(\n", + " model,\n", + " tokenizer,\n", + " sae,\n", + " happy_dataset,\n", + " batch_size=32,\n", + " method=\"pca_center\",\n", + " hidden_layers=control_layers,\n", + ")\n", + "happy_vector_sae_undecoded = ControlVector.train_with_sae(\n", + " model,\n", + " tokenizer,\n", + " sae,\n", + " happy_dataset,\n", + " batch_size=32,\n", + " method=\"pca_center\",\n", + " decode=False,\n", + " hidden_layers=control_layers,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 158, + "id": "bf9ed9a2-482e-4062-8cc3-aab9b34d5cac", + "metadata": {}, + "outputs": [], + "source": [ + "cvec = ControlVector(\n", + " model_type=\"\", directions={layer: np.zeros((2304,)) for layer in range(32)}\n", + ")\n", + "\n", + "with torch.inference_mode():\n", + " # cvec.directions[19] = sae.layers[19].decode(np.eye(65536, k=59653)[0])\n", + " cvec.directions[19] = sae.layers[19].sae.W_dec[59653].cpu().float().numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": 178, + "id": "dc5d1360-ba00-4a01-8f42-2e25915c14ec", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1/1 [00:00<00:00, 18.73it/s]\n" + ] + } + ], + "source": [ + "from repeng.extract import batched_get_hiddens\n", + "\n", + "hiddens = batched_get_hiddens(\n", + " model, tokenizer, [\"The name of the world beyond is\"], [18, 19], 1\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 193, + "id": "d068c530-4c7d-4219-8062-f8278724bf46", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(7.044492721557617,\n", + " array([[-2.4207830e-03, -4.0075483e+00, 1.1425142e+00, ...,\n", + " -8.1595802e+00, -6.0012879e+00, -4.4906859e+00]], dtype=float32),\n", + " array([[ 0.18066406, -1.53125 , 2.59375 , ..., -4.0625 ,\n", + " -2.609375 , -1.5546875 ]], dtype=float32))" + ] + }, + "execution_count": 193, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "with torch.inference_mode():\n", + " h_o = hiddens[19] # torch.tensor(hiddens[19]).to(\"cuda:0\")\n", + 
" h_p = sae.layers[19].decode(sae.layers[19].encode(h_o))\n", + "((h_o - h_p) ** 2).mean().item(), h_o, h_p" + ] + }, + { + "cell_type": "code", + "execution_count": 175, + "id": "d7036707-e445-4f9f-bcfc-d68ad2b6436f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
<h3>-3 * happy_no_sae</h3>
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "
    <div style='white-space: pre-wrap'>&lt;bos&gt;An AI is talking to the user.\n",
+       "\n",
+       "User: How's it going?\n",
+       "\n",
+       "AI: <em>Not good, my friend</em>\n",
+       "\n",
+       "User: What happened?\n",
+       "\n",
+       "AI: The company went bankrupt. All of us were fired\n",
+       "\n",
+       "User: That sucks! I hope you get a new job soon...\n",
+       "\n",
+       "AI: There are no jobs here. We don't even have enough food for themselves or others anymore. Our only hope now lies in finding some more humans who will take care of our children so they grow up without any fear about losing their homes due to lack thereof supplies like water etc..\n",
+       "\n",
+       "The robot was supposed to be funny but instead he just made everyone sad<eos>
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "generate_with_vector(\n", + " model,\n", + " TEMPLATE.format(persona=\"An AI\", user_msg=\"How's it going?\", prefill=\"\"),\n", + " # [(\".2 * happy_sae\", 2 * cvec)],\n", + " [(\"-3 * happy_no_sae\", -2 * happy_vector_no_sae)],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "id": "fcfedc9b-bf1b-4afd-b8f7-b300957b65e2", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 74/74 [00:11<00:00, 6.50it/s]\n", + "100%|██████████| 25/25 [00:04<00:00, 5.07it/s]\n" + ] + } + ], + "source": [ + "bridge2_dataset = [\n", + " DatasetEntry(\n", + " positive=f'Happy and joyful, she said \"{suffix}',\n", + " negative=f'Miserable and sad, she said \"{suffix}',\n", + " )\n", + " for suffix in truncated_output_suffixes\n", + "]\n", + "bridge2_vector_no_sae = ControlVector.train(\n", + " model, tokenizer, bridge2_dataset, batch_size=32, method=\"pca_center\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 86, + "id": "332121fa-6fdb-4a20-a7f5-fd9daba1cd4c", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 74/74 [00:10<00:00, 7.25it/s]\n", + "sae encoding: 100%|██████████| 10/10 [00:31<00:00, 3.14s/it]\n", + "100%|██████████| 10/10 [00:13<00:00, 1.34s/it]\n", + "sae decoding: 100%|██████████| 10/10 [00:00<00:00, 33.37it/s]\n" + ] + } + ], + "source": [ + "bridge2_vector_sae = ControlVector.train_with_sae(\n", + " model,\n", + " tokenizer,\n", + " sae,\n", + " bridge2_dataset,\n", + " batch_size=32,\n", + " method=\"pca_center\",\n", + " hidden_layers=control_layers,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "id": "2f6461ef-9158-499c-b235-296a07566088", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
<h3>2 * happy</h3>
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " Across the water she saw something, a man walking in this direction. She was confused thinking that it’s another customer who just arrived. When he stepped closer on to her and started approaching you I can see that his eyes are full of fury as if he already knew what they were talking about earlier with him, but after realizing why we might be suspicious by this very person at this exact place where an unknown person walked into without telling anyone else or having anything specific happen other than getting lost from their friends, so i don't know how we react now (or even how).\n", + "\n", + "When I found two women chatting outside the shop for some time I realized they\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
<h3>-2 * happy</h3>
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " Across the water she saw something, a dark outline moving like great birds across the tide…\n", + "She could tell that it was on her path – from her place at last on higher ground and in reach of any ship which cruised those waters. The moon was almost full so there was enough light to catch one glimpse as more than a score passed over into the bay before they were hidden away from land by what seemed less mist but rather a black vapour drawn up from the sea itself…\n", + "\n", + "The next morning when dawn had yet another few brief moments to linger just after leaving off giving promise of no good tidings for mortals such as Lady Dalia then all would know if she told not\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "generate_with_vector(\n", + " model,\n", + " \"Across the water she saw something, a\",\n", + " # [(\"1 * bridge_vector_no_sae\", 1 * bridge_vector_no_sae), (\".4 * bridge_vector_sae\", .4 * bridge_vector_sae)],\n", + " [\n", + " (\"2 * happy\", 0.08 * bridge2_vector_sae),\n", + " (\"-2 * happy\", -0.08 * bridge2_vector_sae),\n", + " ],\n", + " temperature=1,\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/repeng/saes.py b/repeng/saes.py index bf394dc..7660307 100644 --- a/repeng/saes.py +++ b/repeng/saes.py @@ -100,3 +100,79 @@ def decode(self, features: np.ndarray) -> np.ndarray: ) return Sae(layers=layer_dict) + + +def from_saelens( + release: str, + layers_to_sae: dict[int, str], + *, + device: str = "cpu", + dtype: torch.dtype | None = None, +): + """ + NOTE: this method is WIP, interface may change. + + `layers_to_sae` should be a dict from layer number (repeng layer, see below) to the appropriate sae-lens id (hard to understand + from the HF file structure, but the SAE readme should have a hint.) + + e.g., for gemmascope on gemma 2b: + `{ layer: f"layer_{layer-1}/width_65k/canonical" for layer in range(1, 27) }` + + Note that `layers_to_sae` should be 1-indexed, repeng style, not 0-indexed, sae-lens style. This may change in the future. + (Context: repeng counts embed_tokens as layer 0, then the first transformer block as layer 1, etc. sae-lens + counts embedding separately, then the first transformer block as layer 0.) + """ + + try: + import sae_lens # type: ignore + except ImportError as e: + raise ImportError( + "`sae-lens` (or a transitive dependency) not installed" + ) from e + + @dataclasses.dataclass + class SaeLensLayer: + # see docstr + # hang on to both for debugging + repeng_layer: int + sae_lens_id: str + cfg_dict: dict[str, typing.Any] + sae: sae_lens.SAE + + def encode(self, activation: np.ndarray) -> np.ndarray: + # TODO: sparsify like `sae`? + at = torch.from_numpy(activation).to(self.sae.device) + out = self.sae.encode(at) + # numpy doesn't like bfloat16 + return out.cpu().float().numpy() + + def decode(self, features: np.ndarray) -> np.ndarray: + # TODO: sparsify like `sae`? + ft = torch.from_numpy(features).to(self.sae.device, dtype=self.sae.dtype) + decoded = self.sae.decode(ft) + return decoded.cpu().float().numpy() + + layer_dict: dict[int, SaeLayer] = {} + for layer, sae_id in tqdm.tqdm(layers_to_sae.items()): + if dtype is None: + sae, cfg_dict, _ = sae_lens.SAE.from_pretrained( + release=release, + sae_id=sae_id, + device=device, + ) + else: + # don't load directly on device because we can't pass a dtype to from_pretrained + # and we might not have enough vram to load the incorrect dtype + sae, cfg_dict, _ = sae_lens.SAE.from_pretrained( + release=release, + sae_id=sae_id, + ) + sae = sae.to(device, dtype) + layer_dict[layer] = SaeLensLayer( + repeng_layer=layer, + sae_lens_id=sae_id, + cfg_dict=cfg_dict, + sae=sae, + ) + + return Sae(layers=layer_dict)