diff --git a/site/en/gemma/docs/agile_classifiers.ipynb b/site/en/gemma/docs/agile_classifiers.ipynb
index e3abdf3ab..991552fca 100644
--- a/site/en/gemma/docs/agile_classifiers.ipynb
+++ b/site/en/gemma/docs/agile_classifiers.ipynb
@@ -646,37 +646,45 @@
     "import numpy as np\n",
     "\n",
     "\n",
-    "def softmax_normalization(arr: np.ndarray) -> np.ndarray:\n",
-    "  \"\"\"Normalizes logits values into probabilities summing to one.\"\"\"\n",
-    "  arr_exp = np.exp(arr - np.max(arr))\n",
-    "  return arr_exp / arr_exp.sum()\n",
-    "\n",
-    "\n",
-    "def compute_token_probability(\n",
+    "def compute_output_probability(\n",
     "    model: keras_nlp.models.GemmaCausalLM,\n",
     "    prompt: str,\n",
-    "    target_tokens: list[str],\n",
+    "    target_classes: list[str],\n",
     ") -> dict[str, float]:\n",
     "  # Shorthands.\n",
     "  preprocessor = model.preprocessor\n",
     "  tokenizer = preprocessor.tokenizer\n",
     "\n",
-    "  # Identify output token offset.\n",
-    "  (padding_mask,) = preprocessor.generate_preprocess([prompt])['padding_mask']\n",
-    "  token_offset = sum(padding_mask.numpy()) - 1\n",
-    "\n",
-    "  # Compute prediction, extract only the next token's logits.\n",
-    "  (logits,) = model.predict([prompt], verbose=0)\n",
-    "  token_logits = logits[token_offset]\n",
+    "  # NOTE: If a token is not found, it will be considered same as \"<unk>\".\n",
+    "  token_unk = tokenizer.token_to_id('<unk>')\n",
     "\n",
     "  # Identify the token indices, which is the same as the ID for this tokenizer.\n",
-    "  # NOTE: If a token is not found, it will be considered same as \"<unk>\".\n",
-    "  token_ids = [tokenizer.token_to_id(token) for token in target_tokens]\n",
+    "  token_ids = [tokenizer.token_to_id(word) for word in target_classes]\n",
+    "\n",
+    "  # Throw an error if one of the classes maps to a token outside the vocabulary.\n",
+    "  if any(token_id == token_unk for token_id in token_ids):\n",
+    "    raise ValueError('One of the target classes is not in the vocabulary.')\n",
+    "\n",
+    "  # Preprocess the prompt in a single batch. This is done one sample at a time\n",
+    "  # for illustration purposes, but it would be more efficient to batch prompts.\n",
+    "  preprocessed = model.preprocessor.generate_preprocess([prompt])\n",
+    "\n",
+    "  # Identify output token offset.\n",
+    "  padding_mask = preprocessed[\"padding_mask\"]\n",
+    "  token_offset = keras.ops.sum(padding_mask) - 1\n",
+    "\n",
+    "  # Score outputs, extract only the next token's logits.\n",
+    "  vocab_logits = model.score(\n",
+    "      token_ids=preprocessed[\"token_ids\"],\n",
+    "      padding_mask=padding_mask,\n",
+    "  )[0][token_offset]\n",
     "\n",
     "  # Compute the relative probability of each of the requested tokens.\n",
-    "  probabilities = softmax_normalization([token_logits[ix] for ix in token_ids])\n",
+    "  token_logits = [vocab_logits[ix] for ix in token_ids]\n",
+    "  logits_tensor = keras.ops.convert_to_tensor(token_logits)\n",
+    "  probabilities = keras.activations.softmax(logits_tensor)\n",
     "\n",
-    "  return dict(zip(target_tokens, probabilities))"
+    "  return dict(zip(target_classes, probabilities.numpy()))"
    ]
   },
   {
@@ -707,10 +715,10 @@
     }
   ],
   "source": [
-    "compute_token_probability(\n",
+    "compute_output_probability(\n",
     "    model=model,\n",
     "    prompt=prompt,\n",
-    "    target_tokens=['Positive', 'Negative'],\n",
+    "    target_classes=['Positive', 'Negative'],\n",
     ")"
    ]
   },
   {
@@ -743,7 +751,7 @@
     "  \"\"\"Agile classifier to be wrapped around a LLM.\"\"\"\n",
     "\n",
     "  # The classes whose probability will be predicted.\n",
-    "  labels: tuple\n",
+    "  labels: tuple[str, ...]\n",
     "\n",
     "  # Provide default instructions and control tokens, can be overridden by user.\n",
     "  instructions: str = 'Classify the following text into one of the following classes'\n",
@@ -771,10 +779,10 @@
     "      x_text: str,\n",
     "  ) -> list[float]:\n",
     "    prompt = self.encode_for_prediction(x_text)\n",
-    "    token_probabilities = compute_token_probability(\n",
+    "    token_probabilities = compute_output_probability(\n",
     "        model=model,\n",
     "        prompt=prompt,\n",
-    "        target_tokens=self.labels,\n",
+    "        target_classes=self.labels,\n",
     "    )\n",
     "    return [token_probabilities[token] for token in self.labels]\n",
     "\n",
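
For reference, a minimal plain-numpy sketch of the restricted-softmax step that the patched compute_output_probability performs: select the target classes' logits from the next-token logits, then softmax only those. All logits, token IDs, and labels below are made-up illustrative values, not Gemma outputs.

import numpy as np

# Hypothetical next-token logits over a 4-entry vocabulary (illustrative only).
vocab_logits = np.array([2.1, -0.3, 5.7, 0.9])

# Hypothetical token IDs for the two target classes (illustrative only).
token_ids = [0, 2]

# Softmax restricted to the selected logits, with the usual max-subtraction
# for numerical stability.
selected = vocab_logits[token_ids]
exp = np.exp(selected - selected.max())
probabilities = exp / exp.sum()

print(dict(zip(['Positive', 'Negative'], probabilities)))
# Prints roughly {'Positive': 0.027, 'Negative': 0.973}.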