Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update agile_classifiers.ipynb #481

Merged
merged 1 commit into from
Jul 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 32 additions & 24 deletions site/en/gemma/docs/agile_classifiers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -646,37 +646,45 @@
"import numpy as np\n",
"\n",
"\n",
"def softmax_normalization(arr: np.ndarray) -> np.ndarray:\n",
" \"\"\"Normalizes logits values into probabilities summing to one.\"\"\"\n",
" arr_exp = np.exp(arr - np.max(arr))\n",
" return arr_exp / arr_exp.sum()\n",
"\n",
"\n",
"def compute_token_probability(\n",
"def compute_output_probability(\n",
" model: keras_nlp.models.GemmaCausalLM,\n",
" prompt: str,\n",
" target_tokens: list[str],\n",
" target_classes: list[str],\n",
") -> dict[str, float]:\n",
" # Shorthands.\n",
" preprocessor = model.preprocessor\n",
" tokenizer = preprocessor.tokenizer\n",
"\n",
" # Identify output token offset.\n",
" (padding_mask,) = preprocessor.generate_preprocess([prompt])['padding_mask']\n",
" token_offset = sum(padding_mask.numpy()) - 1\n",
"\n",
" # Compute prediction, extract only the next token's logits.\n",
" (logits,) = model.predict([prompt], verbose=0)\n",
" token_logits = logits[token_offset]\n",
" # NOTE: If a token is not found, it will be considered same as \"<unk>\".\n",
" token_unk = tokenizer.token_to_id('<unk>')\n",
"\n",
" # Identify the token indices, which is the same as the ID for this tokenizer.\n",
" # NOTE: If a token is not found, it will be considered same as \"<unk>\".\n",
" token_ids = [tokenizer.token_to_id(token) for token in target_tokens]\n",
" token_ids = [tokenizer.token_to_id(word) for word in target_classes]\n",
"\n",
" # Throw an error if one of the classes maps to a token outside the vocabulary.\n",
" if any(token_id == token_unk for token_id in token_ids):\n",
" raise ValueError('One of the target classes is not in the vocabulary.')\n",
"\n",
" # Preprocess the prompt in a single batch. This is done one sample at a time\n",
" # for illustration purposes, but it would be more efficient to batch prompts.\n",
" preprocessed = model.preprocessor.generate_preprocess([prompt])\n",
"\n",
" # Identify output token offset.\n",
" padding_mask = preprocessed[\"padding_mask\"]\n",
" token_offset = keras.ops.sum(padding_mask) - 1\n",
"\n",
" # Score outputs, extract only the next token's logits.\n",
" vocab_logits = model.score(\n",
" token_ids=preprocessed[\"token_ids\"],\n",
" padding_mask=padding_mask,\n",
" )[0][token_offset]\n",
"\n",
" # Compute the relative probability of each of the requested tokens.\n",
" probabilities = softmax_normalization([token_logits[ix] for ix in token_ids])\n",
" token_logits = [vocab_logits[ix] for ix in token_ids]\n",
" logits_tensor = keras.ops.convert_to_tensor(token_logits)\n",
" probabilities = keras.activations.softmax(logits_tensor)\n",
"\n",
" return dict(zip(target_tokens, probabilities))"
" return dict(zip(target_classes, probabilities.numpy()))"
]
},
{
Expand Down Expand Up @@ -707,10 +715,10 @@
}
],
"source": [
"compute_token_probability(\n",
"compute_output_probability(\n",
" model=model,\n",
" prompt=prompt,\n",
" target_tokens=['Positive', 'Negative'],\n",
" target_classes=['Positive', 'Negative'],\n",
")"
]
},
Expand Down Expand Up @@ -743,7 +751,7 @@
" \"\"\"Agile classifier to be wrapped around a LLM.\"\"\"\n",
"\n",
" # The classes whose probability will be predicted.\n",
" labels: tuple\n",
" labels: tuple[str, ...]\n",
"\n",
" # Provide default instructions and control tokens, can be overridden by user.\n",
" instructions: str = 'Classify the following text into one of the following classes'\n",
Expand Down Expand Up @@ -771,10 +779,10 @@
" x_text: str,\n",
" ) -> list[float]:\n",
" prompt = self.encode_for_prediction(x_text)\n",
" token_probabilities = compute_token_probability(\n",
" token_probabilities = compute_output_probability(\n",
" model=model,\n",
" prompt=prompt,\n",
" target_tokens=self.labels,\n",
" target_classes=self.labels,\n",
" )\n",
" return [token_probabilities[token] for token in self.labels]\n",
"\n",
Expand Down
Loading