From 228f34bb637070d33f4b62c36ad56af6bfc19fe5 Mon Sep 17 00:00:00 2001 From: kelvin-zou Date: Wed, 8 Jan 2025 15:05:44 -0800 Subject: [PATCH] snapshot --- axlearn/common/attention.py | 38 ++++----- axlearn/common/compiler_options.py | 17 ++-- axlearn/common/utils.py | 46 +++++++++- axlearn/common/utils_test.py | 82 ++++++++++++++++++ .../fuji-1B-v3-flash-single-host.txt | 10 +-- .../fuji-1B-v3-flash.txt | 16 ++-- .../fuji-1B-v3-single-host.txt | 10 +-- .../fuji-1B-v3.txt | 16 ++-- .../fuji-3B-v3-flash-single-host.txt | 10 +-- .../fuji-3B-v3-flash.txt | 16 ++-- .../fuji-3B-v3-single-host.txt | 10 +-- .../fuji-3B-v3.txt | 16 ++-- .../fuji-70B-v1-flash.txt | 49 +++++++++-- .../fuji-70B-v1.txt | 49 +++++++++-- .../fuji-70B-v2-flash.txt | 49 +++++++++-- .../fuji-70B-v2.txt | 49 +++++++++-- .../fuji-70B-v3-flash.txt | 65 +++++++++++---- .../fuji-70B-v3.txt | 65 +++++++++++---- .../fuji-7B-v1-flash-single-host.txt | 38 ++++++--- .../fuji-7B-v1-flash.txt | 32 +++++-- .../fuji-7B-v1-single-host.txt | 38 ++++++--- .../fuji-7B-v1.txt | 32 +++++-- .../fuji-7B-v2-flash-single-host.txt | 38 ++++++--- .../fuji-7B-v2-flash.txt | 32 +++++-- .../fuji-7B-v2-single-host.txt | 38 ++++++--- .../fuji-7B-v2.txt | 32 +++++-- .../fuji-8B-v3-flash-single-host.txt | 10 +-- .../fuji-8B-v3-flash.txt | 16 ++-- .../fuji-8B-v3-single-host.txt | 10 +-- .../fuji-8B-v3.txt | 16 ++-- axlearn/experiments/text/gpt/fuji.py | 83 ++++++++++++++++++- 31 files changed, 781 insertions(+), 247 deletions(-) diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py index 37baf3d8b..662e28899 100644 --- a/axlearn/common/attention.py +++ b/axlearn/common/attention.py @@ -56,7 +56,6 @@ import enum import functools import math -import re from collections.abc import Sequence from enum import Enum, unique from typing import Any, Callable, NamedTuple, Optional, Protocol, Union @@ -64,9 +63,6 @@ import einops import jax from jax import numpy as jnp -from jax._src.ad_checkpoint import name_p -from jax._src.interpreters import partial_eval as pe -from jax.core import Primitive from axlearn.common import ops, param_init from axlearn.common.attention_bias import ( @@ -120,13 +116,16 @@ from axlearn.common.utils import ( Nested, NestedTensor, + OffloadPolicy, PartitionSpec, + SavePattern, Tensor, TensorSpec, VDict, check_numerics, flatten_items, get_or_none, + save_and_offload_only_these_names_regex, shapes, split_prng_key, ) @@ -3930,30 +3929,23 @@ def forward( # TODO(sneha): extend_step -OffloadPolicy = Callable[[Primitive, list[Any], dict[str, Any]], Union[bool, Any]] -_SavePattern = Union[str, re.Pattern, None] - - # Adapted from jax source code to support regex. Reference: # https://github.com/jax-ml/jax/blob/0d36b0b433a93c707f86dac89b0c05d40302775a/jax/_src/ad_checkpoint.py#L120 +# TODO(kelvin-zou): deprecated, keep it here to minimize disruption to the golden configs. +# Please use axlearn.common.utils.extended_checkpoint_policies instead. 
def _save_and_offload_only_these_names_regex( *, - names_which_can_be_saved: _SavePattern, - names_which_can_be_offloaded: _SavePattern, + names_which_can_be_saved: SavePattern, + names_which_can_be_offloaded: SavePattern, offload_src: str, offload_dst: str, ) -> OffloadPolicy: - def policy(prim, *_, **params): - if prim is name_p: - if names_which_can_be_saved and re.fullmatch(names_which_can_be_saved, params["name"]): - return pe.Saveable - if names_which_can_be_offloaded and re.fullmatch( - names_which_can_be_offloaded, params["name"] - ): - return pe.Offloadable(src=offload_src, dst=offload_dst) - return pe.Recompute # not saveable unless it's in the allow-list - - return policy + return save_and_offload_only_these_names_regex( + names_which_can_be_saved=names_which_can_be_saved, + names_which_can_be_offloaded=names_which_can_be_offloaded, + offload_src=offload_src, + offload_dst=offload_dst, + ) SELF_ATTENTION_SAVE_PATTERN = ".*([qkvo]_proj|context)" @@ -3964,8 +3956,8 @@ def build_remat_spec( stack_cfg: Union[ BaseStackedTransformerLayer.Config, "RepeatedConformerLayer.Config" # type: ignore ], - save_pattern: _SavePattern = SELF_ATTENTION_SAVE_PATTERN, - offload_pattern: _SavePattern = None, + save_pattern: SavePattern = SELF_ATTENTION_SAVE_PATTERN, + offload_pattern: SavePattern = None, offload_dst: str = "pinned_host", ) -> Optional[RematSpec]: """Configures how the Transformer or Conformer stack will save the linearization points. diff --git a/axlearn/common/compiler_options.py b/axlearn/common/compiler_options.py index 2d1bf0396..d4555b116 100644 --- a/axlearn/common/compiler_options.py +++ b/axlearn/common/compiler_options.py @@ -72,6 +72,8 @@ def default_xla_options( xla_latency_hiding_scheduler_rerun=2, # Improved performance for v6e. xla_tpu_scoped_vmem_limit_kib=98304, + # For megascale performance. + xla_jf_crs_combiner_threshold_count=10, ) options.update( # Improved performance for v6e. @@ -98,13 +100,14 @@ def default_xla_options( xla_tpu_use_enhanced_launch_barrier="true", # Sparsecore offloading for all reduce. # Uncomment below flags to enable it. - # xla_sc_disable_megacore_partitioning="true", - # xla_tpu_use_tc_device_shape_on_sc="true", - # tpu_use_continuations="true", - # xla_jf_crs_combiner_threshold_count=10, - # xla_sc_enable_instruction_fusion="false", - # xla_sc_disjoint_spmem="false", - # xla_tpu_enable_sparse_core_collective_offload_all_reduce="true", + xla_sc_disable_megacore_partitioning="true", + xla_tpu_use_tc_device_shape_on_sc="true", + tpu_use_continuations="true", + xla_sc_enable_instruction_fusion="false", + xla_sc_disjoint_spmem="false", + xla_tpu_enable_sparse_core_collective_offload_all_reduce="true", + # TODO(kelvinzou): temporary workaround to avoid memory leak in megascale. + megascale_grpc_enable_xor_tracer="false", ) # This flag can be removed after upgrading to Jax 0.4.38. # Uncomment for sparsecore offloading. 
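Usage sketch (illustrative, not part of the patch): save_and_offload_only_these_names_regex, moved into axlearn.common.utils in the diff below, builds a jax.remat policy that matches tensors tagged through JAX's name primitive (e.g. jax.ad_checkpoint.checkpoint_name). The sketch assumes only the signature and the [qkvo]_proj save pattern visible in this patch; f, x, w, and the "q_proj" tag are hypothetical stand-ins.

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

from axlearn.common.utils import save_and_offload_only_these_names_regex


def f(x, w):
    # Tag the activation so the name-based policy can match it ("q_proj" is a
    # hypothetical tag chosen to match the attention save pattern above).
    y = checkpoint_name(jnp.dot(x, w), "q_proj")
    return jnp.sum(jax.nn.relu(y))


policy = save_and_offload_only_these_names_regex(
    names_which_can_be_saved=".*([qkvo]_proj|context)",  # same regex as SELF_ATTENTION_SAVE_PATTERN.
    names_which_can_be_offloaded=None,
    offload_src="device",
    offload_dst="pinned_host",
)

# Tagged activations matching the save regex are kept for the backward pass;
# everything else is recomputed (or offloaded if it matches the offload regex).
grads = jax.grad(jax.remat(f, policy=policy), argnums=(0, 1))(
    jnp.ones((8, 4)), jnp.ones((4, 2))
)

In the golden configs further down, the same function is referenced through RematSpecModifier (remat_policies[...]['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex') rather than called directly.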
diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py index 13c5e30d4..0846f4d86 100644 --- a/axlearn/common/utils.py +++ b/axlearn/common/utils.py @@ -29,10 +29,12 @@ import numpy as np from absl import logging from jax import numpy as jnp +from jax._src.ad_checkpoint import name_p from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax as lax_internal from jax._src.mesh import thread_resources from jax._src.tree_util import KeyEntry, KeyPath +from jax.core import Primitive from jax.experimental import mesh_utils, multihost_utils from jax.sharding import PartitionSpec @@ -103,9 +105,46 @@ def sharding(self) -> jax.sharding.Sharding: NestedTensorSpec = Optional[Union[TensorSpec, dict[str, Any]]] +SavePattern = Union[str, re.Pattern, None] +OffloadPolicy = Callable[[Primitive, list[Any], dict[str, Any]], Union[bool, Any]] -def offload_dots_saveable(offload_src: str, offload_dst: str) -> Callable[[Any], Any]: +def save_and_offload_only_these_names_regex( + *, + names_which_can_be_saved: SavePattern, + names_which_can_be_offloaded: SavePattern, + offload_src: str, + offload_dst: str, +) -> OffloadPolicy: + """Adapted from jax source code to support regex. + Reference: + https://github.com/jax-ml/jax/blob/0d36b0b433a93c707f86dac89b0c05d40302775a/jax/_src/ad_checkpoint.py#L120 + + Args: + names_which_can_be_saved: A regex pattern for names which can be saved. + names_which_can_be_offloaded: A regex pattern for names which can be offloaded. + offload_src: The source device for offloading. + offload_dst: The target device for offloading. + + Returns: + A policy function that offloads and saves only the tensors that match the given + regex patterns. + """ + + def policy(prim, *_, **params): + if prim is name_p: + if names_which_can_be_saved and re.fullmatch(names_which_can_be_saved, params["name"]): + return pe.Saveable + if names_which_can_be_offloaded and re.fullmatch( + names_which_can_be_offloaded, params["name"] + ): + return pe.Offloadable(src=offload_src, dst=offload_dst) + return pe.Recompute # not saveable unless it's in the allow-list + + return policy + + +def offload_dots_saveable(offload_src: str, offload_dst: str) -> OffloadPolicy: """Extract from offload_dot_with_no_batch_dims and remove no-batch-dims limit. 
https://github.com/google/jax/blob/f4158ace933482844c145a6b919bf5dc86e084ba/jax/_src/ad_checkpoint.py#L81C1-L90C1 @@ -128,7 +167,10 @@ def policy(prim, *_, **params): return policy -extended_checkpoint_policies = types.SimpleNamespace(offload_dots_saveable=offload_dots_saveable) +extended_checkpoint_policies = types.SimpleNamespace( + offload_dots_saveable=offload_dots_saveable, + save_and_offload_only_these_names_regex=save_and_offload_only_these_names_regex, +) @contextlib.contextmanager diff --git a/axlearn/common/utils_test.py b/axlearn/common/utils_test.py index f4c06b47a..71fd4f89d 100644 --- a/axlearn/common/utils_test.py +++ b/axlearn/common/utils_test.py @@ -19,6 +19,7 @@ import torch from absl.testing import absltest, parameterized from jax import numpy as jnp +from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies from jax.experimental import checkify, mesh_utils from jax.sharding import PartitionSpec @@ -28,6 +29,7 @@ from axlearn.common.layers import BatchNorm, LayerNorm, Linear from axlearn.common.metrics import WeightedScalar from axlearn.common.module import Module +from axlearn.common.module import functional as F from axlearn.common.repeat import Repeat from axlearn.common.test_utils import ( Nested, @@ -72,6 +74,7 @@ pytree_children, replicate_to_local_data, runtime_checks, + save_and_offload_only_these_names_regex, set_data_dir, set_recursively, split_prng_key, @@ -1804,5 +1807,84 @@ def test_basic(self, x: Nested[Tensor], paths: Sequence[str], missing: Optional[ validate_contains_paths(x, paths=paths) +class _TestRematLayer(BaseLayer): + """A dummy 2-layer feed-forward layer with a saved activation.""" + + @config_class + class Config(BaseLayer.Config): + linear1: Linear.Config = Linear.default_config().set(input_dim=2, output_dim=4) + linear2: Linear.Config = Linear.default_config().set(input_dim=4, output_dim=1) + + def __init__(self, cfg: Config, *, parent: Module): + super().__init__(cfg, parent=parent) + cfg = self.config + # Add the two linear children. 
+ self._add_child("linear1", cfg.linear1) + self._add_child("linear2", cfg.linear2) + + def forward(self, inputs: Tensor) -> Tensor: + x = self.linear1(inputs) + x = self._remat_name(x, "linear1") + x = self.linear2(x) + return x + + +class TestRematPolicy(TestCase): + """Test remat policy.""" + + def test_linear_remat(self): + """Test remat policy for linear layers.""" + batch, dim = 8, 2 + layer = _TestRematLayer.default_config().set(name="test").instantiate(parent=None) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + x = jax.random.normal(jax.random.PRNGKey(1), shape=[batch, dim]) + + def f(x, layer_params): + y, _ = F( + layer, + inputs=dict(inputs=x), + state=layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(0), + ) + return y + + _, save_name_backward = jax.linearize( + jax.remat( + f, + policy=save_and_offload_only_these_names_regex( + names_which_can_be_saved=".*linear1", + names_which_can_be_offloaded=None, + offload_src="device", + offload_dst="pinned_host", + ), + ), + x, + layer_params, + ) + _, save_dots_backward = jax.linearize( + jax.remat(f, policy=jax_remat_policies.dots_saveable), + x, + layer_params, + ) + + _, remat_backward = jax.linearize( + jax.remat(f, policy=jax_remat_policies.nothing_saveable), + x, + layer_params, + ) + + # We have 2 forward and 2 backward and they are: + # f = matmul(x, l1), g = matmul(f, l2) + # l2' = matmul(f^t, g'), l1' = matmul(x^t, f') + self.assertEqual(str(save_name_backward).count(" dot_general"), 4) + self.assertEqual( + str(save_name_backward).count(" dot_general"), + str(save_dots_backward).count(" dot_general"), + ) + # We have one more recompute of f for remat during backward. + self.assertEqual(str(remat_backward).count(" dot_general"), 5) + + if __name__ == "__main__": absltest.main() diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt index f8e50909a..41971d290 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt @@ -7,14 +7,14 @@ checkpointer.keep_every_n_steps: 50000 checkpointer.keep_last_n: 3 checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' -checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.max_step: 983040 checkpointer.save_policy.min_step: 1 checkpointer.save_policy.n: 5000 checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' checkpointer.storage.timeout_secs: 3600 evalers['train'].eval_dtype: 'jax.numpy.bfloat16' evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.max_step: 983040 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' @@ -41,7 +41,7 @@ evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWri evalers['train'].summary_writer.write_every_n_steps: 1 evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['validation'].eval_policy.max_step: 3932160 
+evalers['validation'].eval_policy.max_step: 983040 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' @@ -102,12 +102,12 @@ learner.optimizer.args[1].learning_rate: 0.0003 learner.optimizer.args[1].update_schedule.alpha: 0.1 learner.optimizer.args[1].update_schedule.begin_value: 0.0 learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' -learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.max_step: 983040 learner.optimizer.args[1].update_schedule.peak_lr: 1.0 learner.optimizer.args[1].update_schedule.warmup_steps: 2000 learner.optimizer.args[1].weight_decay: 0.1 learner.optimizer.fn: 'axlearn.common.optimizers.chain' -max_step: 3932160 +max_step: 983040 mesh_axis_names[0]: 'pipeline' mesh_axis_names[1]: 'data' mesh_axis_names[2]: 'expert' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt index 2a831b28c..a27337377 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt @@ -7,18 +7,18 @@ checkpointer.keep_every_n_steps: 50000 checkpointer.keep_last_n: 3 checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' -checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.max_step: 983040 checkpointer.save_policy.min_step: 1 checkpointer.save_policy.n: 5000 checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' checkpointer.storage.timeout_secs: 3600 evalers['train'].eval_dtype: 'jax.numpy.bfloat16' evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.max_step: 983040 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 512 +evalers['train'].input.batcher.global_batch_size: 2048 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 evalers['train'].input.is_training: False @@ -41,11 +41,11 @@ evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWri evalers['train'].summary_writer.write_every_n_steps: 1 evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.max_step: 983040 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 512 +evalers['validation'].input.batcher.global_batch_size: 2048 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 
'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 512 +input.batcher.global_batch_size: 2048 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 input.is_training: True @@ -102,12 +102,12 @@ learner.optimizer.args[1].learning_rate: 0.0003 learner.optimizer.args[1].update_schedule.alpha: 0.1 learner.optimizer.args[1].update_schedule.begin_value: 0.0 learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' -learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.max_step: 983040 learner.optimizer.args[1].update_schedule.peak_lr: 1.0 learner.optimizer.args[1].update_schedule.warmup_steps: 2000 learner.optimizer.args[1].weight_decay: 0.1 learner.optimizer.fn: 'axlearn.common.optimizers.chain' -max_step: 3932160 +max_step: 983040 mesh_axis_names[0]: 'pipeline' mesh_axis_names[1]: 'data' mesh_axis_names[2]: 'expert' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt index 0a470ccea..5cc38c163 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt @@ -7,14 +7,14 @@ checkpointer.keep_every_n_steps: 50000 checkpointer.keep_last_n: 3 checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' -checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.max_step: 983040 checkpointer.save_policy.min_step: 1 checkpointer.save_policy.n: 5000 checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' checkpointer.storage.timeout_secs: 3600 evalers['train'].eval_dtype: 'jax.numpy.bfloat16' evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.max_step: 983040 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' @@ -41,7 +41,7 @@ evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWri evalers['train'].summary_writer.write_every_n_steps: 1 evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.max_step: 983040 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' @@ -102,12 +102,12 @@ learner.optimizer.args[1].learning_rate: 0.0003 learner.optimizer.args[1].update_schedule.alpha: 0.1 learner.optimizer.args[1].update_schedule.begin_value: 0.0 learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' -learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.max_step: 983040 learner.optimizer.args[1].update_schedule.peak_lr: 1.0 
learner.optimizer.args[1].update_schedule.warmup_steps: 2000 learner.optimizer.args[1].weight_decay: 0.1 learner.optimizer.fn: 'axlearn.common.optimizers.chain' -max_step: 3932160 +max_step: 983040 mesh_axis_names[0]: 'pipeline' mesh_axis_names[1]: 'data' mesh_axis_names[2]: 'expert' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt index c4c6eed38..86c13eb79 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt @@ -7,18 +7,18 @@ checkpointer.keep_every_n_steps: 50000 checkpointer.keep_last_n: 3 checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' -checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.max_step: 983040 checkpointer.save_policy.min_step: 1 checkpointer.save_policy.n: 5000 checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' checkpointer.storage.timeout_secs: 3600 evalers['train'].eval_dtype: 'jax.numpy.bfloat16' evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.max_step: 983040 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 512 +evalers['train'].input.batcher.global_batch_size: 2048 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 evalers['train'].input.is_training: False @@ -41,11 +41,11 @@ evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWri evalers['train'].summary_writer.write_every_n_steps: 1 evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.max_step: 983040 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 512 +evalers['validation'].input.batcher.global_batch_size: 2048 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 512 +input.batcher.global_batch_size: 2048 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 input.is_training: True @@ -102,12 +102,12 @@ learner.optimizer.args[1].learning_rate: 0.0003 learner.optimizer.args[1].update_schedule.alpha: 0.1 learner.optimizer.args[1].update_schedule.begin_value: 0.0 learner.optimizer.args[1].update_schedule.fn: 
'axlearn.common.schedule.cosine_with_linear_warmup' -learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.max_step: 983040 learner.optimizer.args[1].update_schedule.peak_lr: 1.0 learner.optimizer.args[1].update_schedule.warmup_steps: 2000 learner.optimizer.args[1].weight_decay: 0.1 learner.optimizer.fn: 'axlearn.common.optimizers.chain' -max_step: 3932160 +max_step: 983040 mesh_axis_names[0]: 'pipeline' mesh_axis_names[1]: 'data' mesh_axis_names[2]: 'expert' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt index d06cfb3c7..32be1295c 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt @@ -7,14 +7,14 @@ checkpointer.keep_every_n_steps: 50000 checkpointer.keep_last_n: 3 checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' -checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.max_step: 983040 checkpointer.save_policy.min_step: 1 checkpointer.save_policy.n: 5000 checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' checkpointer.storage.timeout_secs: 3600 evalers['train'].eval_dtype: 'jax.numpy.bfloat16' evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.max_step: 983040 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' @@ -41,7 +41,7 @@ evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWri evalers['train'].summary_writer.write_every_n_steps: 1 evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.max_step: 983040 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' @@ -102,12 +102,12 @@ learner.optimizer.args[1].learning_rate: 0.0003 learner.optimizer.args[1].update_schedule.alpha: 0.1 learner.optimizer.args[1].update_schedule.begin_value: 0.0 learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' -learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.max_step: 983040 learner.optimizer.args[1].update_schedule.peak_lr: 1.0 learner.optimizer.args[1].update_schedule.warmup_steps: 2000 learner.optimizer.args[1].weight_decay: 0.1 learner.optimizer.fn: 'axlearn.common.optimizers.chain' -max_step: 3932160 +max_step: 983040 mesh_axis_names[0]: 'pipeline' mesh_axis_names[1]: 'data' mesh_axis_names[2]: 'expert' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt index 8d5dc4e92..3de7d2b95 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt +++ 
b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt @@ -7,18 +7,18 @@ checkpointer.keep_every_n_steps: 50000 checkpointer.keep_last_n: 3 checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' -checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.max_step: 983040 checkpointer.save_policy.min_step: 1 checkpointer.save_policy.n: 5000 checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' checkpointer.storage.timeout_secs: 3600 evalers['train'].eval_dtype: 'jax.numpy.bfloat16' evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.max_step: 983040 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 512 +evalers['train'].input.batcher.global_batch_size: 2048 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 evalers['train'].input.is_training: False @@ -41,11 +41,11 @@ evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWri evalers['train'].summary_writer.write_every_n_steps: 1 evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.max_step: 983040 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 512 +evalers['validation'].input.batcher.global_batch_size: 2048 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 512 +input.batcher.global_batch_size: 2048 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 input.is_training: True @@ -102,12 +102,12 @@ learner.optimizer.args[1].learning_rate: 0.0003 learner.optimizer.args[1].update_schedule.alpha: 0.1 learner.optimizer.args[1].update_schedule.begin_value: 0.0 learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' -learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.max_step: 983040 learner.optimizer.args[1].update_schedule.peak_lr: 1.0 learner.optimizer.args[1].update_schedule.warmup_steps: 2000 learner.optimizer.args[1].weight_decay: 0.1 learner.optimizer.fn: 'axlearn.common.optimizers.chain' -max_step: 3932160 +max_step: 983040 mesh_axis_names[0]: 'pipeline' mesh_axis_names[1]: 'data' mesh_axis_names[2]: 'expert' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt 
b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt index 8d7e8f710..7cc3b4afc 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt @@ -7,14 +7,14 @@ checkpointer.keep_every_n_steps: 50000 checkpointer.keep_last_n: 3 checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' -checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.max_step: 983040 checkpointer.save_policy.min_step: 1 checkpointer.save_policy.n: 5000 checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' checkpointer.storage.timeout_secs: 3600 evalers['train'].eval_dtype: 'jax.numpy.bfloat16' evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.max_step: 983040 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' @@ -41,7 +41,7 @@ evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWri evalers['train'].summary_writer.write_every_n_steps: 1 evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.max_step: 983040 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' @@ -102,12 +102,12 @@ learner.optimizer.args[1].learning_rate: 0.0003 learner.optimizer.args[1].update_schedule.alpha: 0.1 learner.optimizer.args[1].update_schedule.begin_value: 0.0 learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' -learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.max_step: 983040 learner.optimizer.args[1].update_schedule.peak_lr: 1.0 learner.optimizer.args[1].update_schedule.warmup_steps: 2000 learner.optimizer.args[1].weight_decay: 0.1 learner.optimizer.fn: 'axlearn.common.optimizers.chain' -max_step: 3932160 +max_step: 983040 mesh_axis_names[0]: 'pipeline' mesh_axis_names[1]: 'data' mesh_axis_names[2]: 'expert' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt index 53ef5d052..612565b6f 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt @@ -7,18 +7,18 @@ checkpointer.keep_every_n_steps: 50000 checkpointer.keep_last_n: 3 checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' -checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.max_step: 983040 checkpointer.save_policy.min_step: 1 checkpointer.save_policy.n: 5000 checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' checkpointer.storage.timeout_secs: 3600 evalers['train'].eval_dtype: 'jax.numpy.bfloat16' evalers['train'].eval_policy.fn: 
'axlearn.common.evaler.every_n_steps_policy' -evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.max_step: 983040 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 512 +evalers['train'].input.batcher.global_batch_size: 2048 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 evalers['train'].input.is_training: False @@ -41,11 +41,11 @@ evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWri evalers['train'].summary_writer.write_every_n_steps: 1 evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.max_step: 983040 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 512 +evalers['validation'].input.batcher.global_batch_size: 2048 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 512 +input.batcher.global_batch_size: 2048 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 input.is_training: True @@ -102,12 +102,12 @@ learner.optimizer.args[1].learning_rate: 0.0003 learner.optimizer.args[1].update_schedule.alpha: 0.1 learner.optimizer.args[1].update_schedule.begin_value: 0.0 learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' -learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.max_step: 983040 learner.optimizer.args[1].update_schedule.peak_lr: 1.0 learner.optimizer.args[1].update_schedule.warmup_steps: 2000 learner.optimizer.args[1].weight_decay: 0.1 learner.optimizer.fn: 'axlearn.common.optimizers.chain' -max_step: 3932160 +max_step: 983040 mesh_axis_names[0]: 'pipeline' mesh_axis_names[1]: 'data' mesh_axis_names[2]: 'expert' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt index ade5f1af2..278d72e61 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt @@ -128,13 +128,48 @@ mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.l mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' mesh_rules[0][1].klass: 
'axlearn.common.trainer_config_modifier.ChainConfigModifier' -mesh_rules[1][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)' -mesh_rules[1][1][0]: 1 -mesh_rules[1][1][1]: -1 -mesh_rules[1][1][2]: 1 -mesh_rules[1][1][3]: 128 -mesh_rules[1][1][4]: 1 -mesh_rules[1][1][5]: 1 +mesh_rules[1][0]: 'tpu-v6e-256-(4|8)' +mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[1][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[1][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[2][0]: 'tpu-v6e-256' +mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[2].grad_acc_steps: 4 +mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' +mesh_rules[2][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[3][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)' 
+mesh_rules[3][1][0]: 1 +mesh_rules[3][1][1]: -1 +mesh_rules[3][1][2]: 1 +mesh_rules[3][1][3]: 128 +mesh_rules[3][1][4]: 1 +mesh_rules[3][1][5]: 1 mesh_shape[0]: 1 mesh_shape[1]: 1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt index a986f1d08..6ca6030fe 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt @@ -128,13 +128,48 @@ mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.l mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' -mesh_rules[1][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)' -mesh_rules[1][1][0]: 1 -mesh_rules[1][1][1]: -1 -mesh_rules[1][1][2]: 1 -mesh_rules[1][1][3]: 128 -mesh_rules[1][1][4]: 1 -mesh_rules[1][1][5]: 1 +mesh_rules[1][0]: 'tpu-v6e-256-(4|8)' +mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[1][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[1][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[2][0]: 'tpu-v6e-256' +mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 
'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[2].grad_acc_steps: 4 +mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' +mesh_rules[2][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[3][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)' +mesh_rules[3][1][0]: 1 +mesh_rules[3][1][1]: -1 +mesh_rules[3][1][2]: 1 +mesh_rules[3][1][3]: 128 +mesh_rules[3][1][4]: 1 +mesh_rules[3][1][5]: 1 mesh_shape[0]: 1 mesh_shape[1]: 1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt index 03fc3428a..8db672f15 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt @@ -128,13 +128,48 @@ mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.l mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' -mesh_rules[1][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)' -mesh_rules[1][1][0]: 1 -mesh_rules[1][1][1]: -1 -mesh_rules[1][1][2]: 1 -mesh_rules[1][1][3]: 128 -mesh_rules[1][1][4]: 1 -mesh_rules[1][1][5]: 1 +mesh_rules[1][0]: 'tpu-v6e-256-(4|8)' +mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[1][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' 
+mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[1][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[2][0]: 'tpu-v6e-256' +mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[2].grad_acc_steps: 4 +mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' +mesh_rules[2][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[3][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)' +mesh_rules[3][1][0]: 1 +mesh_rules[3][1][1]: -1 +mesh_rules[3][1][2]: 1 +mesh_rules[3][1][3]: 128 +mesh_rules[3][1][4]: 1 +mesh_rules[3][1][5]: 1 mesh_shape[0]: 1 mesh_shape[1]: 1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt index 1ecf7529f..422e651cf 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt @@ -128,13 +128,48 @@ mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.l mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' -mesh_rules[1][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)' -mesh_rules[1][1][0]: 1 -mesh_rules[1][1][1]: -1 -mesh_rules[1][1][2]: 1 -mesh_rules[1][1][3]: 128 -mesh_rules[1][1][4]: 1 -mesh_rules[1][1][5]: 1 +mesh_rules[1][0]: 'tpu-v6e-256-(4|8)' +mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[1][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[1]: -1 
+mesh_rules[1][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[1][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[2][0]: 'tpu-v6e-256' +mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[2].grad_acc_steps: 4 +mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' +mesh_rules[2][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[3][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)' +mesh_rules[3][1][0]: 1 +mesh_rules[3][1][1]: -1 +mesh_rules[3][1][2]: 1 +mesh_rules[3][1][3]: 128 +mesh_rules[3][1][4]: 1 +mesh_rules[3][1][5]: 1 mesh_shape[0]: 1 mesh_shape[1]: 1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt index 76193c0db..e2fd015c3 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt +++ 
b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt @@ -7,18 +7,18 @@ checkpointer.keep_every_n_steps: 50000 checkpointer.keep_last_n: 3 checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' -checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.max_step: 983040 checkpointer.save_policy.min_step: 1 checkpointer.save_policy.n: 5000 checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' checkpointer.storage.timeout_secs: 3600 evalers['train'].eval_dtype: 'jax.numpy.bfloat16' evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.max_step: 983040 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 512 +evalers['train'].input.batcher.global_batch_size: 2048 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 evalers['train'].input.is_training: False @@ -41,11 +41,11 @@ evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWri evalers['train'].summary_writer.write_every_n_steps: 1 evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.max_step: 983040 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 512 +evalers['validation'].input.batcher.global_batch_size: 2048 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 512 +input.batcher.global_batch_size: 2048 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 input.is_training: True @@ -102,12 +102,12 @@ learner.optimizer.args[1].learning_rate: 0.00015 learner.optimizer.args[1].update_schedule.alpha: 0.1 learner.optimizer.args[1].update_schedule.begin_value: 0.0 learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' -learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.max_step: 983040 learner.optimizer.args[1].update_schedule.peak_lr: 1.0 learner.optimizer.args[1].update_schedule.warmup_steps: 2000 learner.optimizer.args[1].weight_decay: 0.1 learner.optimizer.fn: 'axlearn.common.optimizers.chain' -max_step: 3932160 +max_step: 983040 mesh_axis_names[0]: 'pipeline' mesh_axis_names[1]: 'data' mesh_axis_names[2]: 'expert' @@ -128,13 +128,48 @@ mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.l 
mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' -mesh_rules[1][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)' -mesh_rules[1][1][0]: 1 -mesh_rules[1][1][1]: -1 -mesh_rules[1][1][2]: 1 -mesh_rules[1][1][3]: 128 -mesh_rules[1][1][4]: 1 -mesh_rules[1][1][5]: 1 +mesh_rules[1][0]: 'tpu-v6e-256-(4|8)' +mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[1][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[1][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[2][0]: 'tpu-v6e-256' +mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[2].grad_acc_steps: 4 +mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' 
+mesh_rules[2][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[3][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)' +mesh_rules[3][1][0]: 1 +mesh_rules[3][1][1]: -1 +mesh_rules[3][1][2]: 1 +mesh_rules[3][1][3]: 128 +mesh_rules[3][1][4]: 1 +mesh_rules[3][1][5]: 1 mesh_shape[0]: 1 mesh_shape[1]: 1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt index 45bdb8e66..829ba2d34 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt @@ -7,18 +7,18 @@ checkpointer.keep_every_n_steps: 50000 checkpointer.keep_last_n: 3 checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' -checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.max_step: 983040 checkpointer.save_policy.min_step: 1 checkpointer.save_policy.n: 5000 checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' checkpointer.storage.timeout_secs: 3600 evalers['train'].eval_dtype: 'jax.numpy.bfloat16' evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.max_step: 983040 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 512 +evalers['train'].input.batcher.global_batch_size: 2048 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 evalers['train'].input.is_training: False @@ -41,11 +41,11 @@ evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWri evalers['train'].summary_writer.write_every_n_steps: 1 evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.max_step: 983040 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 512 +evalers['validation'].input.batcher.global_batch_size: 2048 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 512 +input.batcher.global_batch_size: 2048 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 input.is_training: True @@ -102,12 +102,12 @@ learner.optimizer.args[1].learning_rate: 0.00015 
learner.optimizer.args[1].update_schedule.alpha: 0.1 learner.optimizer.args[1].update_schedule.begin_value: 0.0 learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' -learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.max_step: 983040 learner.optimizer.args[1].update_schedule.peak_lr: 1.0 learner.optimizer.args[1].update_schedule.warmup_steps: 2000 learner.optimizer.args[1].weight_decay: 0.1 learner.optimizer.fn: 'axlearn.common.optimizers.chain' -max_step: 3932160 +max_step: 983040 mesh_axis_names[0]: 'pipeline' mesh_axis_names[1]: 'data' mesh_axis_names[2]: 'expert' @@ -128,13 +128,48 @@ mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.l mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' -mesh_rules[1][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)' -mesh_rules[1][1][0]: 1 -mesh_rules[1][1][1]: -1 -mesh_rules[1][1][2]: 1 -mesh_rules[1][1][3]: 128 -mesh_rules[1][1][4]: 1 -mesh_rules[1][1][5]: 1 +mesh_rules[1][0]: 'tpu-v6e-256-(4|8)' +mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[1][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[1][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[2][0]: 'tpu-v6e-256' +mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True 
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[2].grad_acc_steps: 4 +mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' +mesh_rules[2][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[3][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)' +mesh_rules[3][1][0]: 1 +mesh_rules[3][1][1]: -1 +mesh_rules[3][1][2]: 1 +mesh_rules[3][1][3]: 128 +mesh_rules[3][1][4]: 1 +mesh_rules[3][1][5]: 1 mesh_shape[0]: 1 mesh_shape[1]: 1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host.txt index 98cd9261c..535234d6e 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host.txt @@ -18,7 +18,7 @@ evalers['train'].eval_policy.max_step: 262144 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 64 +evalers['train'].input.batcher.global_batch_size: 16 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 evalers['train'].input.is_training: False @@ -45,7 +45,7 @@ evalers['validation'].eval_policy.max_step: 262144 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 64 +evalers['validation'].input.batcher.global_batch_size: 16 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 64 +input.batcher.global_batch_size: 16 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 input.is_training: True @@ -164,20 +164,36 @@ mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modif 
mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'jax._src.ad_checkpoint.dots_saveable' mesh_rules[3][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' -mesh_rules[4][0]: 'tpu-v5p-.*' -mesh_rules[4][1][0]: 1 -mesh_rules[4][1][1]: -1 -mesh_rules[4][1][2]: 1 -mesh_rules[4][1][3]: 8 -mesh_rules[4][1][4]: 1 -mesh_rules[4][1][5]: 1 -mesh_rules[5][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)' +mesh_rules[4][0]: 'tpu-v6e-256-(2|4|8)' +mesh_rules[4][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[4][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[5][0]: 'tpu-v5p-.*' mesh_rules[5][1][0]: 1 mesh_rules[5][1][1]: -1 mesh_rules[5][1][2]: 1 mesh_rules[5][1][3]: 8 mesh_rules[5][1][4]: 1 mesh_rules[5][1][5]: 1 +mesh_rules[6][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)' +mesh_rules[6][1][0]: 1 +mesh_rules[6][1][1]: -1 +mesh_rules[6][1][2]: 1 +mesh_rules[6][1][3]: 8 +mesh_rules[6][1][4]: 1 +mesh_rules[6][1][5]: 1 mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash.txt index 0a62cc2b1..cf5ed9a88 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash.txt @@ -164,20 +164,36 @@ mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modif mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'jax._src.ad_checkpoint.dots_saveable' mesh_rules[3][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' -mesh_rules[4][0]: 'tpu-v5p-.*' -mesh_rules[4][1][0]: 1 -mesh_rules[4][1][1]: -1 -mesh_rules[4][1][2]: 1 -mesh_rules[4][1][3]: 8 -mesh_rules[4][1][4]: 1 -mesh_rules[4][1][5]: 1 
-mesh_rules[5][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)' +mesh_rules[4][0]: 'tpu-v6e-256-(2|4|8)' +mesh_rules[4][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[4][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[5][0]: 'tpu-v5p-.*' mesh_rules[5][1][0]: 1 mesh_rules[5][1][1]: -1 mesh_rules[5][1][2]: 1 mesh_rules[5][1][3]: 8 mesh_rules[5][1][4]: 1 mesh_rules[5][1][5]: 1 +mesh_rules[6][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)' +mesh_rules[6][1][0]: 1 +mesh_rules[6][1][1]: -1 +mesh_rules[6][1][2]: 1 +mesh_rules[6][1][3]: 8 +mesh_rules[6][1][4]: 1 +mesh_rules[6][1][5]: 1 mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-single-host.txt index a9e1f38ed..4ec1ad578 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-single-host.txt @@ -18,7 +18,7 @@ evalers['train'].eval_policy.max_step: 262144 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 64 +evalers['train'].input.batcher.global_batch_size: 16 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 evalers['train'].input.is_training: False @@ -45,7 +45,7 @@ evalers['validation'].eval_policy.max_step: 262144 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 64 +evalers['validation'].input.batcher.global_batch_size: 16 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ 
evalers['validation'].metric_calculator.model_method: 'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 64 +input.batcher.global_batch_size: 16 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 input.is_training: True @@ -164,20 +164,36 @@ mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modif mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'jax._src.ad_checkpoint.dots_saveable' mesh_rules[3][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' -mesh_rules[4][0]: 'tpu-v5p-.*' -mesh_rules[4][1][0]: 1 -mesh_rules[4][1][1]: -1 -mesh_rules[4][1][2]: 1 -mesh_rules[4][1][3]: 8 -mesh_rules[4][1][4]: 1 -mesh_rules[4][1][5]: 1 -mesh_rules[5][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)' +mesh_rules[4][0]: 'tpu-v6e-256-(2|4|8)' +mesh_rules[4][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[4][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[5][0]: 'tpu-v5p-.*' mesh_rules[5][1][0]: 1 mesh_rules[5][1][1]: -1 mesh_rules[5][1][2]: 1 mesh_rules[5][1][3]: 8 mesh_rules[5][1][4]: 1 mesh_rules[5][1][5]: 1 +mesh_rules[6][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)' +mesh_rules[6][1][0]: 1 +mesh_rules[6][1][1]: -1 +mesh_rules[6][1][2]: 1 +mesh_rules[6][1][3]: 8 +mesh_rules[6][1][4]: 1 +mesh_rules[6][1][5]: 1 mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1.txt index 87736a6f5..a58c11472 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1.txt @@ -164,20 +164,36 @@ 
mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modif mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'jax._src.ad_checkpoint.dots_saveable' mesh_rules[3][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' -mesh_rules[4][0]: 'tpu-v5p-.*' -mesh_rules[4][1][0]: 1 -mesh_rules[4][1][1]: -1 -mesh_rules[4][1][2]: 1 -mesh_rules[4][1][3]: 8 -mesh_rules[4][1][4]: 1 -mesh_rules[4][1][5]: 1 -mesh_rules[5][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)' +mesh_rules[4][0]: 'tpu-v6e-256-(2|4|8)' +mesh_rules[4][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[4][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[5][0]: 'tpu-v5p-.*' mesh_rules[5][1][0]: 1 mesh_rules[5][1][1]: -1 mesh_rules[5][1][2]: 1 mesh_rules[5][1][3]: 8 mesh_rules[5][1][4]: 1 mesh_rules[5][1][5]: 1 +mesh_rules[6][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)' +mesh_rules[6][1][0]: 1 +mesh_rules[6][1][1]: -1 +mesh_rules[6][1][2]: 1 +mesh_rules[6][1][3]: 8 +mesh_rules[6][1][4]: 1 +mesh_rules[6][1][5]: 1 mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host.txt index e01051cac..087727526 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host.txt @@ -18,7 +18,7 @@ evalers['train'].eval_policy.max_step: 524288 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 32 +evalers['train'].input.batcher.global_batch_size: 8 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 
evalers['train'].input.is_training: False @@ -45,7 +45,7 @@ evalers['validation'].eval_policy.max_step: 524288 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 32 +evalers['validation'].input.batcher.global_batch_size: 8 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 32 +input.batcher.global_batch_size: 8 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 input.is_training: True @@ -164,20 +164,36 @@ mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modif mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'jax._src.ad_checkpoint.dots_saveable' mesh_rules[3][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' -mesh_rules[4][0]: 'tpu-v5p-.*' -mesh_rules[4][1][0]: 1 -mesh_rules[4][1][1]: -1 -mesh_rules[4][1][2]: 1 -mesh_rules[4][1][3]: 8 -mesh_rules[4][1][4]: 1 -mesh_rules[4][1][5]: 1 -mesh_rules[5][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)' +mesh_rules[4][0]: 'tpu-v6e-256-(2|4|8)' +mesh_rules[4][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[4][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[5][0]: 'tpu-v5p-.*' mesh_rules[5][1][0]: 1 mesh_rules[5][1][1]: -1 mesh_rules[5][1][2]: 1 mesh_rules[5][1][3]: 8 mesh_rules[5][1][4]: 1 mesh_rules[5][1][5]: 1 +mesh_rules[6][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)' +mesh_rules[6][1][0]: 1 
+mesh_rules[6][1][1]: -1 +mesh_rules[6][1][2]: 1 +mesh_rules[6][1][3]: 8 +mesh_rules[6][1][4]: 1 +mesh_rules[6][1][5]: 1 mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash.txt index 964f23e23..32f64479e 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash.txt @@ -164,20 +164,36 @@ mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modif mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'jax._src.ad_checkpoint.dots_saveable' mesh_rules[3][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' -mesh_rules[4][0]: 'tpu-v5p-.*' -mesh_rules[4][1][0]: 1 -mesh_rules[4][1][1]: -1 -mesh_rules[4][1][2]: 1 -mesh_rules[4][1][3]: 8 -mesh_rules[4][1][4]: 1 -mesh_rules[4][1][5]: 1 -mesh_rules[5][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)' +mesh_rules[4][0]: 'tpu-v6e-256-(2|4|8)' +mesh_rules[4][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[4][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[5][0]: 'tpu-v5p-.*' mesh_rules[5][1][0]: 1 mesh_rules[5][1][1]: -1 mesh_rules[5][1][2]: 1 mesh_rules[5][1][3]: 8 mesh_rules[5][1][4]: 1 mesh_rules[5][1][5]: 1 +mesh_rules[6][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)' +mesh_rules[6][1][0]: 1 +mesh_rules[6][1][1]: -1 +mesh_rules[6][1][2]: 1 +mesh_rules[6][1][3]: 8 +mesh_rules[6][1][4]: 1 +mesh_rules[6][1][5]: 1 mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host.txt index 17f97ab30..d55e01b42 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host.txt +++ 
b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host.txt @@ -18,7 +18,7 @@ evalers['train'].eval_policy.max_step: 524288 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 32 +evalers['train'].input.batcher.global_batch_size: 8 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 evalers['train'].input.is_training: False @@ -45,7 +45,7 @@ evalers['validation'].eval_policy.max_step: 524288 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 32 +evalers['validation'].input.batcher.global_batch_size: 8 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 32 +input.batcher.global_batch_size: 8 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 input.is_training: True @@ -164,20 +164,36 @@ mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modif mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'jax._src.ad_checkpoint.dots_saveable' mesh_rules[3][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' -mesh_rules[4][0]: 'tpu-v5p-.*' -mesh_rules[4][1][0]: 1 -mesh_rules[4][1][1]: -1 -mesh_rules[4][1][2]: 1 -mesh_rules[4][1][3]: 8 -mesh_rules[4][1][4]: 1 -mesh_rules[4][1][5]: 1 -mesh_rules[5][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)' +mesh_rules[4][0]: 'tpu-v6e-256-(2|4|8)' +mesh_rules[4][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[4][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None 
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[5][0]: 'tpu-v5p-.*' mesh_rules[5][1][0]: 1 mesh_rules[5][1][1]: -1 mesh_rules[5][1][2]: 1 mesh_rules[5][1][3]: 8 mesh_rules[5][1][4]: 1 mesh_rules[5][1][5]: 1 +mesh_rules[6][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)' +mesh_rules[6][1][0]: 1 +mesh_rules[6][1][1]: -1 +mesh_rules[6][1][2]: 1 +mesh_rules[6][1][3]: 8 +mesh_rules[6][1][4]: 1 +mesh_rules[6][1][5]: 1 mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2.txt index 438da62a1..eb4182b28 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2.txt @@ -164,20 +164,36 @@ mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modif mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'jax._src.ad_checkpoint.dots_saveable' mesh_rules[3][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' -mesh_rules[4][0]: 'tpu-v5p-.*' -mesh_rules[4][1][0]: 1 -mesh_rules[4][1][1]: -1 -mesh_rules[4][1][2]: 1 -mesh_rules[4][1][3]: 8 -mesh_rules[4][1][4]: 1 -mesh_rules[4][1][5]: 1 -mesh_rules[5][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)' +mesh_rules[4][0]: 'tpu-v6e-256-(2|4|8)' +mesh_rules[4][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[4][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[5][0]: 'tpu-v5p-.*' mesh_rules[5][1][0]: 1 mesh_rules[5][1][1]: -1 mesh_rules[5][1][2]: 1 mesh_rules[5][1][3]: 8 mesh_rules[5][1][4]: 1 mesh_rules[5][1][5]: 1 +mesh_rules[6][0]: 
'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)' +mesh_rules[6][1][0]: 1 +mesh_rules[6][1][1]: -1 +mesh_rules[6][1][2]: 1 +mesh_rules[6][1][3]: 8 +mesh_rules[6][1][4]: 1 +mesh_rules[6][1][5]: 1 mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host.txt index a5b50a240..a15dfdf0b 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host.txt @@ -7,14 +7,14 @@ checkpointer.keep_every_n_steps: 50000 checkpointer.keep_last_n: 3 checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' -checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.max_step: 983040 checkpointer.save_policy.min_step: 1 checkpointer.save_policy.n: 5000 checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' checkpointer.storage.timeout_secs: 3600 evalers['train'].eval_dtype: 'jax.numpy.bfloat16' evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.max_step: 983040 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' @@ -41,7 +41,7 @@ evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWri evalers['train'].summary_writer.write_every_n_steps: 1 evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.max_step: 983040 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' @@ -102,12 +102,12 @@ learner.optimizer.args[1].learning_rate: 0.0003 learner.optimizer.args[1].update_schedule.alpha: 0.1 learner.optimizer.args[1].update_schedule.begin_value: 0.0 learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' -learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.max_step: 983040 learner.optimizer.args[1].update_schedule.peak_lr: 1.0 learner.optimizer.args[1].update_schedule.warmup_steps: 2000 learner.optimizer.args[1].weight_decay: 0.1 learner.optimizer.fn: 'axlearn.common.optimizers.chain' -max_step: 3932160 +max_step: 983040 mesh_axis_names[0]: 'pipeline' mesh_axis_names[1]: 'data' mesh_axis_names[2]: 'expert' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash.txt index da5826693..7f520cbde 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash.txt @@ -7,18 +7,18 @@ checkpointer.keep_every_n_steps: 50000 checkpointer.keep_last_n: 3 checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' checkpointer.save_policy.fn: 
'axlearn.common.checkpointer.every_n_steps_and_last_policy' -checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.max_step: 983040 checkpointer.save_policy.min_step: 1 checkpointer.save_policy.n: 5000 checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' checkpointer.storage.timeout_secs: 3600 evalers['train'].eval_dtype: 'jax.numpy.bfloat16' evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.max_step: 983040 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 512 +evalers['train'].input.batcher.global_batch_size: 2048 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 evalers['train'].input.is_training: False @@ -41,11 +41,11 @@ evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWri evalers['train'].summary_writer.write_every_n_steps: 1 evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.max_step: 983040 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 512 +evalers['validation'].input.batcher.global_batch_size: 2048 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 512 +input.batcher.global_batch_size: 2048 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 input.is_training: True @@ -102,12 +102,12 @@ learner.optimizer.args[1].learning_rate: 0.0003 learner.optimizer.args[1].update_schedule.alpha: 0.1 learner.optimizer.args[1].update_schedule.begin_value: 0.0 learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' -learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.max_step: 983040 learner.optimizer.args[1].update_schedule.peak_lr: 1.0 learner.optimizer.args[1].update_schedule.warmup_steps: 2000 learner.optimizer.args[1].weight_decay: 0.1 learner.optimizer.fn: 'axlearn.common.optimizers.chain' -max_step: 3932160 +max_step: 983040 mesh_axis_names[0]: 'pipeline' mesh_axis_names[1]: 'data' mesh_axis_names[2]: 'expert' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host.txt index 811b565e5..225299e7b 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host.txt +++ 
b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host.txt @@ -7,14 +7,14 @@ checkpointer.keep_every_n_steps: 50000 checkpointer.keep_last_n: 3 checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' -checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.max_step: 983040 checkpointer.save_policy.min_step: 1 checkpointer.save_policy.n: 5000 checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' checkpointer.storage.timeout_secs: 3600 evalers['train'].eval_dtype: 'jax.numpy.bfloat16' evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.max_step: 983040 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' @@ -41,7 +41,7 @@ evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWri evalers['train'].summary_writer.write_every_n_steps: 1 evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.max_step: 983040 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' @@ -102,12 +102,12 @@ learner.optimizer.args[1].learning_rate: 0.0003 learner.optimizer.args[1].update_schedule.alpha: 0.1 learner.optimizer.args[1].update_schedule.begin_value: 0.0 learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' -learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.max_step: 983040 learner.optimizer.args[1].update_schedule.peak_lr: 1.0 learner.optimizer.args[1].update_schedule.warmup_steps: 2000 learner.optimizer.args[1].weight_decay: 0.1 learner.optimizer.fn: 'axlearn.common.optimizers.chain' -max_step: 3932160 +max_step: 983040 mesh_axis_names[0]: 'pipeline' mesh_axis_names[1]: 'data' mesh_axis_names[2]: 'expert' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3.txt index b71e46c9d..6339517df 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3.txt @@ -7,18 +7,18 @@ checkpointer.keep_every_n_steps: 50000 checkpointer.keep_last_n: 3 checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' -checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.max_step: 983040 checkpointer.save_policy.min_step: 1 checkpointer.save_policy.n: 5000 checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' checkpointer.storage.timeout_secs: 3600 evalers['train'].eval_dtype: 'jax.numpy.bfloat16' evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.max_step: 983040 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 
'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 512 +evalers['train'].input.batcher.global_batch_size: 2048 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 evalers['train'].input.is_training: False @@ -41,11 +41,11 @@ evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWri evalers['train'].summary_writer.write_every_n_steps: 1 evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.max_step: 983040 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 512 +evalers['validation'].input.batcher.global_batch_size: 2048 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 512 +input.batcher.global_batch_size: 2048 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 input.is_training: True @@ -102,12 +102,12 @@ learner.optimizer.args[1].learning_rate: 0.0003 learner.optimizer.args[1].update_schedule.alpha: 0.1 learner.optimizer.args[1].update_schedule.begin_value: 0.0 learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' -learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.max_step: 983040 learner.optimizer.args[1].update_schedule.peak_lr: 1.0 learner.optimizer.args[1].update_schedule.warmup_steps: 2000 learner.optimizer.args[1].weight_decay: 0.1 learner.optimizer.fn: 'axlearn.common.optimizers.chain' -max_step: 3932160 +max_step: 983040 mesh_axis_names[0]: 'pipeline' mesh_axis_names[1]: 'data' mesh_axis_names[2]: 'expert' diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 69f6b1102..6cd498143 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -19,6 +19,7 @@ from axlearn.common import causal_lm, config from axlearn.common.attention import ( + SELF_ATTENTION_SAVE_PATTERN, BaseStackedTransformerLayer, FusedGroupedQKVLinear, FusedQKVLinear, @@ -107,6 +108,14 @@ class Version(enum.Enum): }, } +# Llama3 uses 16m tokens after 2.87T tokens. +# https://arxiv.org/pdf/2407.21783 +TOKENS_PER_BATCH = { + Version.V1: 4 * (1024**2), + Version.V2: 4 * (1024**2), + Version.V3: 16 * (1024**2), +} + def get_trainer_kwargs( model_size: str, @@ -116,7 +125,7 @@ def get_trainer_kwargs( flash_attention: bool = False, ) -> dict[str, Any]: """Construct default trainer kwargs given a model size.""" - tokens_per_batch = 4 * (1024**2) # 4M tokens. 
+ tokens_per_batch = TOKENS_PER_BATCH[version] if model_size not in TOTAL_TOKENS[version]: return {} max_step = TOTAL_TOKENS[version][model_size] // tokens_per_batch @@ -133,6 +142,15 @@ def get_trainer_kwargs( offload_dots_saveable_policy = config_for_function( extended_checkpoint_policies.offload_dots_saveable ).set(offload_src="device", offload_dst="pinned_host") + # To make it work better with v3 8k sequence length. + offload_attention_proj_policy = config_for_function( + extended_checkpoint_policies.save_and_offload_only_these_names_regex + ).set( + names_which_can_be_saved=None, + names_which_can_be_offloaded=SELF_ATTENTION_SAVE_PATTERN, + offload_src="device", + offload_dst="pinned_host", + ) # dict() is more readable here. # pylint: disable=use-dict-literal if model_size == "test": @@ -275,6 +293,24 @@ def get_trainer_kwargs( ], ), ), + ( + "tpu-v6e-256-(2|4|8)", + ChainConfigModifier.default_config().set( + config_modifiers=[ + MeshShapeModifier.default_config().set( + mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256) + ), + RematSpecModifier.default_config().set( + remat_policies={ + "model.decoder.transformer.layer": RematSpec( + prevent_cse=True, + policy=offload_attention_proj_policy, + ), + } + ), + ], + ), + ), # tpu-v5p. ("tpu-v5p-.*", mesh_shape_from_axes(data=-1, fsdp=8)), # H100/A100 80G. @@ -411,6 +447,45 @@ def get_trainer_kwargs( ], ), ), + # V2 on tpu-v6e-256x4, step time: 4.9s. + ( + "tpu-v6e-256-(4|8)", + ChainConfigModifier.default_config().set( + config_modifiers=[ + MeshShapeModifier.default_config().set( + mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256) + ), + RematSpecModifier.default_config().set( + remat_policies={ + "model.decoder.transformer.layer": RematSpec( + prevent_cse=True, + policy=offload_attention_proj_policy, + ), + } + ), + ], + ), + ), + # V2 on tpu-v6e-256, step time: 19.5s. + ( + "tpu-v6e-256", + ChainConfigModifier.default_config().set( + config_modifiers=[ + MeshShapeModifier.default_config().set( + mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256) + ), + RematSpecModifier.default_config().set( + remat_policies={ + "model.decoder.transformer.layer": RematSpec( + prevent_cse=True, + policy=offload_attention_proj_policy, + ), + } + ), + GradientAccumulationModifier.default_config().set(grad_acc_steps=4), + ], + ), + ), # H100/A100 80G. Maximum per-node batch size = 16, hence need >= 64 nodes. # v2 on gpu-p5.48xlarge 8x64, step time: 12.9s. ( @@ -582,9 +657,11 @@ def make_single_host_config(base_config_name: str) -> SpmdTrainer.Config: # pytype: enable=annotation-type-mismatch # The original config was supposed to run on >= 32 machines. - cfg.input.batcher.global_batch_size //= 32 + # pylint: disable=cell-var-from-loop + cfg.input.batcher.global_batch_size //= 128 if version == Version.V3 else 32 for evaler in cfg.evalers.values(): - evaler.input.batcher.global_batch_size //= 32 + evaler.input.batcher.global_batch_size //= 128 if version == Version.V3 else 32 + # pylint: enable=cell-var-from-loop return cfg # Make single-host config
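Note on the numbers above (not part of the patch): the v3 golden-config changes follow directly from the new TOKENS_PER_BATCH map, and the arithmetic can be sanity-checked with the short Python snippet below. The 8192 sequence length is an assumption inferred from the "v3 8k sequence length" comment in fuji.py; the remaining values are taken from the diff itself.

# Sanity check of the v3 numbers implied by this patch (assumes seq_len = 8192).
TOKENS_PER_BATCH_V3 = 16 * (1024**2)  # 16M tokens, from the new TOKENS_PER_BATCH map.
SEQ_LEN_V3 = 8192                     # assumption: the "v3 8k sequence length" comment.

# Global batch size implied by the token budget: 16M / 8K = 2048,
# matching the 512 -> 2048 changes in the v3 golden configs.
assert TOKENS_PER_BATCH_V3 // SEQ_LEN_V3 == 2048

# max_step shrinks 4x because tokens_per_batch grows 4x (4M -> 16M),
# matching the 3932160 -> 983040 changes in the v3 golden configs.
assert 3932160 * (4 * 1024**2) == 983040 * TOKENS_PER_BATCH_V3

# make_single_host_config() now divides by 128 (instead of 32) for v3, so the
# single-host batch size stays at 16, which is why the *-single-host v3 golden
# configs only change max_step.
assert 2048 // 128 == 512 // 32 == 16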
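The new tpu-v6e mesh rules all use the '.*([qkvo]_proj|context)' pattern that appears throughout the golden configs as names_which_can_be_offloaded. Below is a minimal sketch of which checkpoint names that pattern selects, assuming full-string regex matching as in jax-style name policies; the example tensor names are illustrative, not taken from this patch.

import re

OFFLOAD_PATTERN = ".*([qkvo]_proj|context)"

# Attention projections and the attention context match, so they are offloaded
# (offload_src='device' -> offload_dst='pinned_host' in the configs above)...
assert re.fullmatch(OFFLOAD_PATTERN, "decoder/transformer/layer/self_attention/attention/q_proj")
assert re.fullmatch(OFFLOAD_PATTERN, "decoder/transformer/layer/self_attention/attention/context")

# ...while other activations (e.g. feed-forward linears) do not match and fall
# through to the default behavior of the policy.
assert re.fullmatch(OFFLOAD_PATTERN, "decoder/transformer/layer/feed_forward/linear1") is None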