Minor tweaks to JAX Gemma docs #290

Merged 2 commits on Mar 15, 2024
14 changes: 7 additions & 7 deletions site/en/gemma/docs/jax_finetune.ipynb
@@ -102,7 +102,7 @@
"\n",
"After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.\n",
"\n",
"### Set environment variables\n",
"### 2. Set environment variables\n",
"\n",
"Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`. When prompted with the \"Grant access?\" messages, agree to provide secret access."
]
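(For context on this step, a minimal sketch of how the Kaggle environment variables are typically set in these Colab notebooks; the secret names and the `userdata` calls are assumptions, and the actual cell may differ.)

```python
import os

from google.colab import userdata  # Colab-only helper for reading notebook secrets

# Assumed secret names; Colab shows the "Grant access?" prompt when these are first read.
os.environ["KAGGLE_USERNAME"] = userdata.get("KAGGLE_USERNAME")
os.environ["KAGGLE_KEY"] = userdata.get("KAGGLE_KEY")
```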
@@ -128,7 +128,7 @@
"id": "m1UE1CEnE9ql"
},
"source": [
"### 2. Install the `gemma` library\n",
"### 3. Install the `gemma` library\n",
"\n",
"Free Colab hardware acceleration is currently *insufficient* to run this notebook. If you are using [Colab Pay As You Go or Colab Pro](https://colab.research.google.com/signup), click on **Edit** > **Notebook settings** > Select **A100 GPU** > **Save** to enable hardware acceleration.\n",
"\n",
@@ -170,7 +170,7 @@
"id": "-mRkkT-iPYoq"
},
"source": [
"### 3. Import libraries\n",
"### 4. Import libraries\n",
"\n",
"This notebook uses [Flax](https://flax.readthedocs.io) (for neural networks), core [JAX](https://jax.readthedocs.io), [SentencePiece](https://github.com/google/sentencepiece) (for tokenization), [Chex](https://chex.readthedocs.io/en/latest/) (a library of utilities for writing reliable JAX code), and TensorFlow Datasets."
]
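(As a reference for reviewers, a sketch of the import cell this section describes, assuming the `gemma` library was installed from GitHub in the preceding cell; the notebook may import additional submodules or use different aliases.)

```python
# Assumes a prior cell installed the library, e.g.:
#   pip install -q git+https://github.com/google-deepmind/gemma.git
import chex                          # utilities for writing reliable JAX code
import jax
import jax.numpy as jnp
import sentencepiece as spm          # tokenization
import tensorflow_datasets as tfds   # training data
from flax import linen as nn         # neural network layers

# Aliases are an assumption; the gemma modules themselves are linked in the docs.
from gemma import params as params_lib
from gemma import transformer as transformer_lib
```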
@@ -912,9 +912,9 @@
"source": [
"## Configure the model\n",
"\n",
"Before you begin fine-tuning the Gemma model, configure it as follows:\n",
"Before you begin fine-tuning the Gemma model, you need to configure it.\n",
"\n",
"Load and format the Gemma model checkpoint with the [`gemma.params`](https://github.com/google-deepmind/gemma/blob/main/gemma/params.py) method:"
"First, load and format the Gemma model checkpoint with the [`gemma.params.load_and_format_params`](https://github.com/google-deepmind/gemma/blob/c6bd156c246530e1620a7c62de98542a377e3934/gemma/params.py#L27) method:"
]
},
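(A sketch of the checkpoint-loading call the updated text now links to directly; the checkpoint path is a placeholder, since the notebooks obtain it from the Kaggle download step.)

```python
from gemma import params as params_lib

# Placeholder path: in the notebooks this points at the checkpoint downloaded from Kaggle.
CKPT_PATH = "/path/to/gemma/2b-it"

# Load the raw checkpoint and reshape it into the nested dict of arrays the model expects.
params = params_lib.load_and_format_params(CKPT_PATH)
```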
{
@@ -934,7 +934,7 @@
"id": "BtJhJkkZzsy1"
},
"source": [
"To automatically load the correct configuration from the Gemma model checkpoint, use [`gemma.transformer.TransformerConfig`](https://github.com/google-deepmind/gemma/blob/56e501ce147af4ea5c23cc0ddf5a9c4a6b7bd0d0/gemma/transformer.py#L65). The `cache_size` argument is the number of time steps in the Gemma `transformer` cache. Afterwards, instantiate the Gemma model as `transformer` with [`gemma.transformer.Transformer`](https://github.com/google-deepmind/gemma/blob/56e501ce147af4ea5c23cc0ddf5a9c4a6b7bd0d0/gemma/transformer.py#L136) (which inherits from [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html).\n",
"To automatically load the correct configuration from the Gemma model checkpoint, use [`gemma.transformer.TransformerConfig`](https://github.com/google-deepmind/gemma/blob/56e501ce147af4ea5c23cc0ddf5a9c4a6b7bd0d0/gemma/transformer.py#L65). The `cache_size` argument is the number of time steps in the Gemma `Transformer` cache. Afterwards, instantiate the Gemma model as `model_2b` with [`gemma.transformer.Transformer`](https://github.com/google-deepmind/gemma/blob/56e501ce147af4ea5c23cc0ddf5a9c4a6b7bd0d0/gemma/transformer.py#L136) (which inherits from [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html)).\n",
"\n",
"**Note:** The vocabulary size is smaller than the number of input embeddings because of unused tokens in the current Gemma release."
]
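(And a rough sketch of the configuration step described above. The `from_params` classmethod and the `cache_size` value are assumptions based on the linked `transformer.py`; `model_2b` matches the variable name the updated text now uses, and `params` is assumed to come from the previous sketch.)

```python
from gemma import transformer as transformer_lib

# Assumed API per the linked transformer.py: derive the config from the loaded checkpoint.
# cache_size sets the number of time steps kept in the Transformer cache.
config_2b = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=1024,
)

# Instantiate the model (a flax.linen.Module) from that config.
model_2b = transformer_lib.Transformer(config=config_2b)
```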
@@ -1375,7 +1375,7 @@
"source": [
"## Learn more\n",
"\n",
"- You can learn more about the Google DeepMind [`gemma` library on GitHub](https://github.com/google-deepmind/gemma), which contains docstrings of methods you used in this tutorial, such as [`gemma.params`](https://github.com/google-deepmind/gemma/blob/main/gemma/params.py),\n",
"- You can learn more about the Google DeepMind [`gemma` library on GitHub](https://github.com/google-deepmind/gemma), which contains docstrings of modules you used in this tutorial, such as [`gemma.params`](https://github.com/google-deepmind/gemma/blob/main/gemma/params.py),\n",
"[`gemma.transformer`](https://github.com/google-deepmind/gemma/blob/main/gemma/transformer.py), and\n",
"[`gemma.sampler`](https://github.com/google-deepmind/gemma/blob/main/gemma/sampler.py).\n",
"- The following libraries have their own documentation sites: [core JAX](https://jax.readthedocs.io), [Flax](https://flax.readthedocs.io), [Chex](https://chex.readthedocs.io/en/latest/), [Optax](https://optax.readthedocs.io/en/latest/), and [Orbax](https://orbax.readthedocs.io/).\n",
12 changes: 6 additions & 6 deletions site/en/gemma/docs/jax_inference.ipynb
@@ -100,7 +100,7 @@
"\n",
"After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.\n",
"\n",
"### Set environment variables\n",
"### 2. Set environment variables\n",
"\n",
"Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`. When prompted with the \"Grant access?\" messages, agree to provide secret access."
]
@@ -126,7 +126,7 @@
"id": "AO7a1Q4Yyc9Z"
},
"source": [
"### 2. Install the `gemma` library\n",
"### 3. Install the `gemma` library\n",
"\n",
"This notebook focuses on using a free Colab GPU. To enable hardware acceleration, click on **Edit** > **Notebook settings** > Select **T4 GPU** > **Save**.\n",
"\n",
@@ -291,7 +291,7 @@
"id": "aEe3p8geqekV"
},
"source": [
"1. Load and format the Gemma model checkpoint with the [`gemma.params`](https://github.com/google-deepmind/gemma/blob/main/gemma/params.py) method:"
"1. Load and format the Gemma model checkpoint with the [`gemma.params.load_and_format_params`](https://github.com/google-deepmind/gemma/blob/c6bd156c246530e1620a7c62de98542a377e3934/gemma/params.py#L27) method:"
]
},
{
@@ -347,7 +347,7 @@
"id": "IkAf4fkNrY-3"
},
"source": [
"3. To automatically load the correct configuration from the Gemma model checkpoint, use [`gemma.transformer.TransformerConfig`](https://github.com/google-deepmind/gemma/blob/56e501ce147af4ea5c23cc0ddf5a9c4a6b7bd0d0/gemma/transformer.py#L65). The `cache_size` argument is the number of time steps in the Gemma `transformer` cache. Afterwards, instantiate the Gemma model as `transformer` with [`gemma.transformer.Transformer`](https://github.com/google-deepmind/gemma/blob/56e501ce147af4ea5c23cc0ddf5a9c4a6b7bd0d0/gemma/transformer.py#L136) (which inherits from [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html)).\n",
"3. To automatically load the correct configuration from the Gemma model checkpoint, use [`gemma.transformer.TransformerConfig`](https://github.com/google-deepmind/gemma/blob/56e501ce147af4ea5c23cc0ddf5a9c4a6b7bd0d0/gemma/transformer.py#L65). The `cache_size` argument is the number of time steps in the Gemma `Transformer` cache. Afterwards, instantiate the Gemma model as `transformer` with [`gemma.transformer.Transformer`](https://github.com/google-deepmind/gemma/blob/56e501ce147af4ea5c23cc0ddf5a9c4a6b7bd0d0/gemma/transformer.py#L136) (which inherits from [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html)).\n",
"\n",
"**Note:** The vocabulary size is smaller than the number of input embeddings because of unused tokens in the current Gemma release."
]
@@ -452,7 +452,7 @@
"id": "njxRJy3qsBWw"
},
"source": [
"5. (Option) Run this cell to free up memory if you have completed the notebook and want to try another prompt. Afterwards, you can instantiate the `sampler` again in step 3 and customize and run the prompt in step 4."
"5. (Optional) Run this cell to free up memory if you have completed the notebook and want to try another prompt. Afterwards, you can instantiate the `sampler` again in step 3 and customize and run the prompt in step 4."
]
},
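(For the optional clean-up step above, a rough sketch of what freeing the sampler state can look like; the variable names are assumptions carried over from the earlier steps of the inference notebook.)

```python
import gc

# Drop the Python references to the sampler so the backing device arrays can be
# reclaimed, then force a garbage-collection pass.
del sampler
gc.collect()

# After this, re-run the step 3 cell to create a new sampler before trying another prompt.
```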
{
@@ -474,7 +474,7 @@
"source": [
"## Learn more\n",
"\n",
"- You can learn more about the Google DeepMind [`gemma` library on GitHub](https://github.com/google-deepmind/gemma), which contains docstrings of methods you used in this tutorial, such as [`gemma.params`](https://github.com/google-deepmind/gemma/blob/main/gemma/params.py),\n",
"- You can learn more about the Google DeepMind [`gemma` library on GitHub](https://github.com/google-deepmind/gemma), which contains docstrings of modules you used in this tutorial, such as [`gemma.params`](https://github.com/google-deepmind/gemma/blob/main/gemma/params.py),\n",
"[`gemma.transformer`](https://github.com/google-deepmind/gemma/blob/main/gemma/transformer.py), and\n",
"[`gemma.sampler`](https://github.com/google-deepmind/gemma/blob/main/gemma/sampler.py).\n",
"- The following libraries have their own documentation sites: [core JAX](https://jax.readthedocs.io), [Flax](https://flax.readthedocs.io), and [Orbax](https://orbax.readthedocs.io/).\n",