V6e support #912

Merged (2 commits) on Jan 13, 2025
38 changes: 15 additions & 23 deletions axlearn/common/attention.py
@@ -56,17 +56,13 @@
import enum
import functools
import math
import re
from collections.abc import Sequence
from enum import Enum, unique
from typing import Any, Callable, NamedTuple, Optional, Protocol, Union

import einops
import jax
from jax import numpy as jnp
from jax._src.ad_checkpoint import name_p
from jax._src.interpreters import partial_eval as pe
from jax.core import Primitive

from axlearn.common import ops, param_init
from axlearn.common.attention_bias import (
@@ -120,13 +116,16 @@
from axlearn.common.utils import (
Nested,
NestedTensor,
OffloadPolicy,
PartitionSpec,
SavePattern,
Tensor,
TensorSpec,
VDict,
check_numerics,
flatten_items,
get_or_none,
save_and_offload_only_these_names_regex,
shapes,
split_prng_key,
)
@@ -3930,30 +3929,23 @@ def forward(
# TODO(sneha): extend_step


OffloadPolicy = Callable[[Primitive, list[Any], dict[str, Any]], Union[bool, Any]]
_SavePattern = Union[str, re.Pattern, None]


# Adapted from jax source code to support regex. Reference:
# https://github.com/jax-ml/jax/blob/0d36b0b433a93c707f86dac89b0c05d40302775a/jax/_src/ad_checkpoint.py#L120
# TODO(kelvin-zou): deprecated, keep it here to minimize disruption to the golden configs.
# Please use axlearn.common.utils.extended_checkpoint_policies instead.
def _save_and_offload_only_these_names_regex(
*,
names_which_can_be_saved: _SavePattern,
names_which_can_be_offloaded: _SavePattern,
names_which_can_be_saved: SavePattern,
names_which_can_be_offloaded: SavePattern,
offload_src: str,
offload_dst: str,
) -> OffloadPolicy:
def policy(prim, *_, **params):
if prim is name_p:
if names_which_can_be_saved and re.fullmatch(names_which_can_be_saved, params["name"]):
return pe.Saveable
if names_which_can_be_offloaded and re.fullmatch(
names_which_can_be_offloaded, params["name"]
):
return pe.Offloadable(src=offload_src, dst=offload_dst)
return pe.Recompute # not saveable unless it's in the allow-list

return policy
return save_and_offload_only_these_names_regex(
names_which_can_be_saved=names_which_can_be_saved,
names_which_can_be_offloaded=names_which_can_be_offloaded,
offload_src=offload_src,
offload_dst=offload_dst,
)


SELF_ATTENTION_SAVE_PATTERN = ".*([qkvo]_proj|context)"
@@ -3964,8 +3956,8 @@ def build_remat_spec(
stack_cfg: Union[
BaseStackedTransformerLayer.Config, "RepeatedConformerLayer.Config" # type: ignore
],
save_pattern: _SavePattern = SELF_ATTENTION_SAVE_PATTERN,
offload_pattern: _SavePattern = None,
save_pattern: SavePattern = SELF_ATTENTION_SAVE_PATTERN,
offload_pattern: SavePattern = None,
offload_dst: str = "pinned_host",
) -> Optional[RematSpec]:
"""Configures how the Transformer or Conformer stack will save the linearization points.
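Note on the attention.py change: the local `_save_and_offload_only_these_names_regex` helper is now a thin wrapper around `save_and_offload_only_these_names_regex` in `axlearn.common.utils`, kept only so existing golden configs keep resolving. The save/offload patterns are matched with `re.fullmatch` against remat names, so `SELF_ATTENTION_SAVE_PATTERN` keeps the attention projections and context tensors. A minimal sketch of the matching (the example names are hypothetical, not taken from the PR):

```python
import re

SELF_ATTENTION_SAVE_PATTERN = ".*([qkvo]_proj|context)"

# Hypothetical remat names; real names depend on how each layer calls _remat_name.
for name in ["attention.q_proj", "attention.context", "feed_forward.linear1"]:
    print(name, bool(re.fullmatch(SELF_ATTENTION_SAVE_PATTERN, name)))
# attention.q_proj True
# attention.context True
# feed_forward.linear1 False
```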
17 changes: 10 additions & 7 deletions axlearn/common/compiler_options.py
@@ -72,6 +72,8 @@ def default_xla_options(
xla_latency_hiding_scheduler_rerun=2,
# Improved performance for v6e.
xla_tpu_scoped_vmem_limit_kib=98304,
# For megascale performance.
xla_jf_crs_combiner_threshold_count=10,
)
options.update(
# Improved performance for v6e.
@@ -98,13 +100,14 @@
xla_tpu_use_enhanced_launch_barrier="true",
# Sparsecore offloading for all reduce.
# Uncomment below flags to enable it.
# xla_sc_disable_megacore_partitioning="true",
# xla_tpu_use_tc_device_shape_on_sc="true",
# tpu_use_continuations="true",
# xla_jf_crs_combiner_threshold_count=10,
# xla_sc_enable_instruction_fusion="false",
# xla_sc_disjoint_spmem="false",
# xla_tpu_enable_sparse_core_collective_offload_all_reduce="true",
xla_sc_disable_megacore_partitioning="true",
xla_tpu_use_tc_device_shape_on_sc="true",
tpu_use_continuations="true",
xla_sc_enable_instruction_fusion="false",
xla_sc_disjoint_spmem="false",
xla_tpu_enable_sparse_core_collective_offload_all_reduce="true",
# TODO(kelvinzou): temporary workaround to avoid memory leak in megascale.
megascale_grpc_enable_xor_tracer="false",
Reviewer comment (Contributor): current plan is to release a fix in jax 0.4.39, which is planned for Jan 15. The fix is in libtpu.

)
# This flag can be removed after upgrading to Jax 0.4.38.
# Uncomment for sparsecore offloading.
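Note on the compiler_options.py change: the sparsecore offloading flags that were previously commented out appear to be enabled by default here, and `xla_jf_crs_combiner_threshold_count=10` is added for megascale performance. As a rough sketch of how such an options dict is typically consumed (the exact plumbing in axlearn may differ), XLA/libtpu flags are usually serialized into `--key=value` pairs and exported before JAX initializes its backend:

```python
import os

# Illustrative only: a subset of the options above, serialized the way XLA flags
# are commonly passed via the environment. axlearn's own launcher code may differ.
options = {
    "xla_tpu_scoped_vmem_limit_kib": 98304,
    "xla_jf_crs_combiner_threshold_count": 10,
    "xla_sc_disable_megacore_partitioning": "true",
}
os.environ["XLA_FLAGS"] = " ".join(f"--{k}={v}" for k, v in options.items())
```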
46 changes: 44 additions & 2 deletions axlearn/common/utils.py
@@ -29,10 +29,12 @@
import numpy as np
from absl import logging
from jax import numpy as jnp
from jax._src.ad_checkpoint import name_p
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax as lax_internal
from jax._src.mesh import thread_resources
from jax._src.tree_util import KeyEntry, KeyPath
from jax.core import Primitive
from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import PartitionSpec

@@ -103,9 +105,46 @@ def sharding(self) -> jax.sharding.Sharding:


NestedTensorSpec = Optional[Union[TensorSpec, dict[str, Any]]]
SavePattern = Union[str, re.Pattern, None]
OffloadPolicy = Callable[[Primitive, list[Any], dict[str, Any]], Union[bool, Any]]


def offload_dots_saveable(offload_src: str, offload_dst: str) -> Callable[[Any], Any]:
def save_and_offload_only_these_names_regex(
*,
names_which_can_be_saved: SavePattern,
names_which_can_be_offloaded: SavePattern,
offload_src: str,
offload_dst: str,
) -> OffloadPolicy:
"""Adapted from jax source code to support regex.
Reference:
https://github.com/jax-ml/jax/blob/0d36b0b433a93c707f86dac89b0c05d40302775a/jax/_src/ad_checkpoint.py#L120

Args:
names_which_can_be_saved: A regex pattern for names which can be saved.
names_which_can_be_offloaded: A regex pattern for names which can be offloaded.
offload_src: The source device for offloading.
offload_dst: The target device for offloading.

Returns:
A policy function that offloads and saves only the tensors that match the given
regex patterns.
"""

def policy(prim, *_, **params):
if prim is name_p:
if names_which_can_be_saved and re.fullmatch(names_which_can_be_saved, params["name"]):
return pe.Saveable
if names_which_can_be_offloaded and re.fullmatch(
names_which_can_be_offloaded, params["name"]
):
return pe.Offloadable(src=offload_src, dst=offload_dst)
return pe.Recompute # not saveable unless it's in the allow-list

return policy


def offload_dots_saveable(offload_src: str, offload_dst: str) -> OffloadPolicy:
"""Extract from offload_dot_with_no_batch_dims and remove no-batch-dims limit.

https://github.com/google/jax/blob/f4158ace933482844c145a6b919bf5dc86e084ba/jax/_src/ad_checkpoint.py#L81C1-L90C1
@@ -128,7 +167,10 @@ def policy(prim, *_, **params):
return policy


extended_checkpoint_policies = types.SimpleNamespace(offload_dots_saveable=offload_dots_saveable)
extended_checkpoint_policies = types.SimpleNamespace(
offload_dots_saveable=offload_dots_saveable,
save_and_offload_only_these_names_regex=save_and_offload_only_these_names_regex,
)


@contextlib.contextmanager
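Note on the utils.py change: `save_and_offload_only_these_names_regex`, together with the `SavePattern` and `OffloadPolicy` aliases, now lives in `axlearn.common.utils` and is also exposed through `extended_checkpoint_policies`. A minimal sketch of how the policy might be passed to `jax.remat` (the function `f` and the tag `"linear1"` are illustrative, not from the PR):

```python
import jax
from jax import numpy as jnp
from jax.ad_checkpoint import checkpoint_name

from axlearn.common.utils import extended_checkpoint_policies


def f(x, w1, w2):
    # Tag the first activation so the policy can match it by name.
    h = checkpoint_name(jnp.dot(x, w1), "linear1")
    return jnp.dot(h, w2)


policy = extended_checkpoint_policies.save_and_offload_only_these_names_regex(
    names_which_can_be_saved=".*linear1",  # save tensors whose remat name matches
    names_which_can_be_offloaded=None,  # offload nothing to host
    offload_src="device",
    offload_dst="pinned_host",
)
# Saves the tagged activation during the forward pass; everything else is recomputed.
f_remat = jax.remat(f, policy=policy)
```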
82 changes: 82 additions & 0 deletions axlearn/common/utils_test.py
@@ -19,6 +19,7 @@
import torch
from absl.testing import absltest, parameterized
from jax import numpy as jnp
from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies
from jax.experimental import checkify, mesh_utils
from jax.sharding import PartitionSpec

@@ -28,6 +29,7 @@
from axlearn.common.layers import BatchNorm, LayerNorm, Linear
from axlearn.common.metrics import WeightedScalar
from axlearn.common.module import Module
from axlearn.common.module import functional as F
from axlearn.common.repeat import Repeat
from axlearn.common.test_utils import (
Nested,
@@ -72,6 +74,7 @@
pytree_children,
replicate_to_local_data,
runtime_checks,
save_and_offload_only_these_names_regex,
set_data_dir,
set_recursively,
split_prng_key,
@@ -1804,5 +1807,84 @@ def test_basic(self, x: Nested[Tensor], paths: Sequence[str], missing: Optional[
validate_contains_paths(x, paths=paths)


class _TestRematLayer(BaseLayer):
"""A dummy 2 layer feed forward with saved activation."""

@config_class
class Config(BaseLayer.Config):
linear1: Linear.Config = Linear.default_config().set(input_dim=2, output_dim=4)
linear2: Linear.Config = Linear.default_config().set(input_dim=4, output_dim=1)

def __init__(self, cfg: Config, *, parent: Module):
super().__init__(cfg, parent=parent)
cfg = self.config
# Register the two linear children defined in the config.
self._add_child("linear1", cfg.linear1)
self._add_child("linear2", cfg.linear2)

def forward(self, inputs: Tensor) -> Tensor:
x = self.linear1(inputs)
x = self._remat_name(x, "linear1")
x = self.linear2(x)
return x


class TestRematPolicy(TestCase):
"""Test remat policy."""

def test_linear_remat(self):
"""Test remat policy for linear layers."""
batch, dim = 8, 2
layer = _TestRematLayer.default_config().set(name="test").instantiate(parent=None)
layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0))
x = jax.random.normal(jax.random.PRNGKey(1), shape=[batch, dim])

def f(x, layer_params):
y, _ = F(
layer,
inputs=dict(inputs=x),
state=layer_params,
is_training=True,
prng_key=jax.random.PRNGKey(0),
)
return y

_, save_name_backward = jax.linearize(
jax.remat(
f,
policy=save_and_offload_only_these_names_regex(
names_which_can_be_saved=".*linear1",
names_which_can_be_offloaded=None,
offload_src="device",
offload_dst="pinned_host",
),
),
x,
layer_params,
)
_, save_dots_backward = jax.linearize(
jax.remat(f, policy=jax_remat_policies.dots_saveable),
x,
layer_params,
)

_, remat_backward = jax.linearize(
jax.remat(f, policy=jax_remat_policies.nothing_saveable),
x,
layer_params,
)

# We expect 2 forward matmuls and 2 backward matmuls:
# f = matmul(x, l1), g = matmul(f, l2)
# l2' = matmul(f^t, g'), l1' = matmul(x^t, f')
self.assertEqual(str(save_name_backward).count(" dot_general"), 4)
self.assertEqual(
str(save_name_backward).count(" dot_general"),
str(save_dots_backward).count(" dot_general"),
)
# We have one more recompute of f for remat during backward.
self.assertEqual(str(remat_backward).count(" dot_general"), 5)


if __name__ == "__main__":
absltest.main()
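The new test compares the number of `dot_general` primitives in the linearized computation under three remat policies: name-based saving of the tagged `linear1` activation, `dots_saveable`, and `nothing_saveable` (which forces one extra recompute of the first matmul). A minimal sketch of the counting trick itself, on a toy function unrelated to the PR:

```python
import jax
from jax import numpy as jnp


def g(x, w):
    return jnp.dot(jnp.dot(x, w), w.T)


# Lower to a jaxpr and count matmul primitives, as the test does on its linearized functions.
jaxpr = jax.make_jaxpr(g)(jnp.ones((2, 3)), jnp.ones((3, 3)))
print(str(jaxpr).count(" dot_general"))  # 2, one per jnp.dot in g
```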
(Golden trainer config file; file path not shown in this view)
@@ -7,14 +7,14 @@ checkpointer.keep_every_n_steps: 50000
checkpointer.keep_last_n: 3
checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer'
checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy'
checkpointer.save_policy.max_step: 3932160
checkpointer.save_policy.max_step: 983040
checkpointer.save_policy.min_step: 1
checkpointer.save_policy.n: 5000
checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage'
checkpointer.storage.timeout_secs: 3600
evalers['train'].eval_dtype: 'jax.numpy.bfloat16'
evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy'
evalers['train'].eval_policy.max_step: 3932160
evalers['train'].eval_policy.max_step: 983040
evalers['train'].eval_policy.min_step: 1
evalers['train'].eval_policy.n: 5000
evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch'
@@ -41,7 +41,7 @@ evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWri
evalers['train'].summary_writer.write_every_n_steps: 1
evalers['validation'].eval_dtype: 'jax.numpy.bfloat16'
evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy'
evalers['validation'].eval_policy.max_step: 3932160
evalers['validation'].eval_policy.max_step: 983040
evalers['validation'].eval_policy.min_step: 1
evalers['validation'].eval_policy.n: 5000
evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch'
@@ -102,12 +102,12 @@ learner.optimizer.args[1].learning_rate: 0.0003
learner.optimizer.args[1].update_schedule.alpha: 0.1
learner.optimizer.args[1].update_schedule.begin_value: 0.0
learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup'
learner.optimizer.args[1].update_schedule.max_step: 3932160
learner.optimizer.args[1].update_schedule.max_step: 983040
learner.optimizer.args[1].update_schedule.peak_lr: 1.0
learner.optimizer.args[1].update_schedule.warmup_steps: 2000
learner.optimizer.args[1].weight_decay: 0.1
learner.optimizer.fn: 'axlearn.common.optimizers.chain'
max_step: 3932160
max_step: 983040
mesh_axis_names[0]: 'pipeline'
mesh_axis_names[1]: 'data'
mesh_axis_names[2]: 'expert'