diff --git a/site/en/gemma/docs/distributed_tuning.ipynb b/site/en/gemma/docs/distributed_tuning.ipynb index 90801d948..30b8d99fc 100644 --- a/site/en/gemma/docs/distributed_tuning.ipynb +++ b/site/en/gemma/docs/distributed_tuning.ipynb @@ -366,16 +366,16 @@ "layout_map = keras.distribution.LayoutMap(device_mesh)\n", "\n", "# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs\n", - "layout_map[\"token_embedding/embeddings\"] = (None, model_dim)\n", + "layout_map[\"token_embedding/embeddings\"] = (model_dim, None)\n", "# Regex to match against the query, key and value matrices in the decoder\n", "# attention layers\n", "layout_map[\"decoder_block.*attention.*(query|key|value).*kernel\"] = (\n", - " None, model_dim, None)\n", + " model_dim, None, None)\n", "\n", "layout_map[\"decoder_block.*attention_output.*kernel\"] = (\n", - " None, None, model_dim)\n", - "layout_map[\"decoder_block.*ffw_gating.*kernel\"] = (model_dim, None)\n", - "layout_map[\"decoder_block.*ffw_linear.*kernel\"] = (None, model_dim)" + " model_dim, None, None)\n", + "layout_map[\"decoder_block.*ffw_gating.*kernel\"] = (None, model_dim)\n", + "layout_map[\"decoder_block.*ffw_linear.*kernel\"] = (model_dim, None)" ] }, {