From 38b48c8abd6f03c284a6f0864186865d66f1e8f5 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 15 Jul 2024 17:54:01 -0700 Subject: [PATCH 01/72] Changed version to 1.10.0.dev0 Signed-off-by: Przemek Tredak --- build_tools/VERSION.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt index eb595b2a77..a597619ec0 100644 --- a/build_tools/VERSION.txt +++ b/build_tools/VERSION.txt @@ -1 +1 @@ -1.9.0.dev0 +1.10.0.dev0 From 210e57de151cd19f69ff6ae7b4a8363eae8489eb Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Tue, 16 Jul 2024 10:56:59 -0700 Subject: [PATCH 02/72] Simplify logic for launching CI (#1001) Signed-off-by: Tim Moon --- .github/workflows/blossom-ci.yml | 7 +++++-- .github/workflows/trigger-ci.yml | 20 ++++++++++++++++++-- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/.github/workflows/blossom-ci.yml b/.github/workflows/blossom-ci.yml index 64d6f3f6b6..260adfc6d3 100644 --- a/.github/workflows/blossom-ci.yml +++ b/.github/workflows/blossom-ci.yml @@ -23,9 +23,12 @@ jobs: args: ${{ env.args }} # This job only runs for pull request comments - if: | - contains( ',ptrendx,ksivaman,', format(',{0},', github.actor)) && + if: > github.event.comment.body == '/blossom-ci' + && ( + github.actor == 'ptrendx' + || github.actor == 'ksivaman' + ) steps: - name: Check if comment is issued by authorized person run: blossom-ci diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index 7a6d269573..5091e5d4f6 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -15,9 +15,25 @@ jobs: args: ${{ env.args }} # This job only runs for pull request comments - if: | - contains( ',ptrendx,ksivaman,schetlur-nv,timmoon10,zlsh80826,mingxu1067,cyanguwa,nzmora-nvidia,galagam,nouiz,denera,sudhakarsingh27,Oleg-Goncharov,phu0ngng,nvcforster,', format(',{0},', github.actor)) && + if: > startsWith(github.event.comment.body, '/te-ci') + && ( + github.actor == 'ptrendx' + || github.actor == 'ksivaman' + || github.actor == 'schetlur-nv' + || github.actor == 'timmoon10' + || github.actor == 'zlsh80826' + || github.actor == 'mingxu1067' + || github.actor == 'cyanguwa' + || github.actor == 'nzmora-nvidia' + || github.actor == 'galagam' + || github.actor == 'nouiz' + || github.actor == 'denera' + || github.actor == 'sudhakarsingh27' + || github.actor == 'Oleg-Goncharov' + || github.actor == 'phu0ngng' + || github.actor == 'nvcforster' + ) steps: - name: Check if comment is issued by authorized person run: blossom-ci From 6c57926782228fb49cbe5e52c92b35e1658ae6f5 Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Wed, 17 Jul 2024 08:52:25 +0800 Subject: [PATCH 03/72] [JAX] Allow enabling partial custom calls through the environment variable (#1007) * Add enabled() to BasePrimitive * Add layernorm/rmsnorm fallback * Add cast_fp8 fallback * Add transpose/cast_transpose XLA fall back * Act_lu fallback * Add transpose fallback * Add softmax fallback * Unify the use of _cast_fp8 * Add tests for NVTE_JAX_CUSTOM_CALLS_RE --------- Signed-off-by: Reese Wang Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> --- qa/L0_jax_unittest/test.sh | 3 + tests/jax/test_custom_call_compute.py | 33 +----- tests/jax/test_softmax.py | 6 +- .../jax/cpp_extensions/activation.py | 46 +++++++- transformer_engine/jax/cpp_extensions/base.py | 17 +++ .../jax/cpp_extensions/normalization.py | 102 ++++++++++++++++++ .../jax/cpp_extensions/quantization.py | 23 ++++ .../jax/cpp_extensions/softmax.py | 55 +++++++++- .../jax/cpp_extensions/transpose.py | 93 ++++++++++++++++ transformer_engine/jax/layernorm.py | 17 ++- transformer_engine/jax/layernorm_mlp.py | 11 +- transformer_engine/jax/softmax.py | 10 +- 12 files changed, 369 insertions(+), 47 deletions(-) diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index 4321432a2e..3db1807fe2 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -9,6 +9,9 @@ pip install pytest==8.2.1 pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' +# Test without custom calls +NVTE_CUSTOM_CALLS_RE="" pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py + pip install -r $TE_PATH/examples/jax/mnist/requirements.txt pip install -r $TE_PATH/examples/jax/encoder/requirements.txt diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 8664a03f8d..5006f87a9d 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -19,8 +19,10 @@ from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper, is_fp8_available from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot from transformer_engine.jax.layernorm_mlp import activation_lu, fused_layernorm_fp8_mlp +from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu from transformer_engine.jax import cpp_extensions as tex + GEMM_CASES = [ (256, 256, 512), (32, 32, 32), @@ -34,21 +36,6 @@ is_fp8_supported, reason = is_fp8_available() -def _convert_to_activation_function(fn_or_string): - """Convert a string to an activation function.""" - if fn_or_string == "linear": - return lambda x: x - if fn_or_string == "quick_gelu": - return lambda x: nn.gelu(x, approximate=True) - if fn_or_string == "squared_relu": - return lambda x: functools.reduce(operator.mul, [nn.relu(x), nn.relu(x)]) - if isinstance(fn_or_string, str): - return getattr(nn, fn_or_string) - if callable(fn_or_string): - return fn_or_string - raise ValueError(f"don't know how to convert {fn_or_string} to an activation function") - - class TestFP8Dot: @staticmethod @@ -293,14 +280,7 @@ def layernorm_fp8_mlp_ref( bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape linear_1_out += jnp.reshape(bias_1, bias_1_shape) - x = jnp.split(linear_1_out, len(activation_type), axis=-2) - acts = [] - for idx, act_fn in enumerate(activation_type): - x_i = _convert_to_activation_function(act_fn)(x[idx]) - acts.append(x_i) - x = functools.reduce(operator.mul, acts) - - x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16) + x = _jax_act_lu(linear_1_out, activation_type) fp8_meta_pkg_2 = FP8MetaPackage( amax_list_2[0], @@ -443,12 +423,7 @@ class TestActivationLu: def ref_func(self, x, activation_type): def ref_act_lu(inputs): - x = jnp.split(inputs, len(activation_type), axis=-2) - acts = [] - for idx, act_fn in enumerate(activation_type): - x_i = _convert_to_activation_function(act_fn)(x[idx]) - acts.append(x_i) - x = functools.reduce(operator.mul, acts) + x = _jax_act_lu(inputs, activation_type) return jnp.mean(x) ref_act_func = jit(value_and_grad(ref_act_lu, (0,))) diff --git a/tests/jax/test_softmax.py b/tests/jax/test_softmax.py index 0cff5955fa..49e32e503c 100644 --- a/tests/jax/test_softmax.py +++ b/tests/jax/test_softmax.py @@ -123,14 +123,12 @@ def grad_func(func, *args, **kwargs): # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation jitted_primitive = jit( - value_and_grad( - lambda logits, *args: grad_func(softmax, self.logits, *args, **kwargs), (0,) - ) + value_and_grad(lambda logits, *args: grad_func(softmax, logits, *args, **kwargs), (0,)) ) jitted_reference = jit( value_and_grad( lambda logits, *args: grad_func( - __class__.reference_softmax, self.logits, *args, **kwargs + __class__.reference_softmax, logits, *args, **kwargs ), (0,), ) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index f9b5156847..bdc377cb27 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -4,8 +4,9 @@ """JAX/TE custom ops for activation""" from typing import Tuple, Sequence, Union, Callable import operator -from functools import reduce +from functools import reduce, partial +import jax import jax.numpy as jnp from jax import core, dtypes from jax.interpreters.mlir import ir @@ -22,6 +23,7 @@ jax_dtype_to_ir_dtype, get_padded_spec, ) +from .quantization import _jax_cast_fp8 from ..sharding import all_reduce_max_along_all_axes_except_PP @@ -42,6 +44,35 @@ } +def _convert_to_activation_function(fn_or_string): + """Convert a string to an activation function.""" + if fn_or_string == "linear": + return lambda x: x + if fn_or_string == "quick_gelu": + return lambda x: jax.nn.sigmoid(1.702 * x) * x + if fn_or_string == "squared_relu": + return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)]) + if isinstance(fn_or_string, str): + return getattr(jax.nn, fn_or_string) + if callable(fn_or_string): + return fn_or_string + raise ValueError(f"Unsupported {fn_or_string} to an activation function") + + +def _jax_act_lu(inputs, activation_type): + """ + JAX native activation implementation + """ + x = jnp.split(inputs, len(activation_type), axis=-2) + acts = [] + for idx, act_fn in enumerate(activation_type): + x_i = _convert_to_activation_function(act_fn)(x[idx]) + acts.append(x_i) + x = reduce(operator.mul, acts) + x = jnp.squeeze(x, axis=-2) + return x + + class ActLuPrimitive(BasePrimitive): """ Activation Forward Primitive @@ -155,6 +186,9 @@ def act_lu(inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]) Input shape: (N, 1, H) for non-gated activations (N, 2, H) for gated activations """ + if not ActLuPrimitive.enabled(): + return _jax_act_lu(inputs, activation_type) + act_type_id = ActivationEnum[activation_type] return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_type_id) @@ -286,6 +320,11 @@ def dact_lu( dact_lu fusion wrapper Return dgated_act_lu(inputs) """ + + if not DActLuPrimitive.enabled(): + _, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), act_lu_inputs) + return vjp_func(inputs)[0] + act_type_id = ActivationEnum[activation_type] return DActLuPrimitive.outer_primitive.bind(inputs, act_lu_inputs, act_enum=act_type_id) @@ -443,6 +482,11 @@ def act_lu_fp8( Input shape: (N, 1, H) for non-gated activations (N, 2, H) for gated activations """ + if not ActLuFp8Primitive.enabled(): + act_lu_output = _jax_act_lu(x, activation_type) + casted_output, updated_amax = _jax_cast_fp8(act_lu_output, scale, amax, out_dtype) + return casted_output, updated_amax + act_type_id = ActivationEnum[activation_type] return ActLuFp8Primitive.outer_primitive.bind( x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_type_id diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 88fab695d6..3d88c1f078 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -2,6 +2,8 @@ # # See LICENSE for license information. """JAX/TE base custom ops""" +import os +import re from abc import ABCMeta, abstractmethod from functools import partial @@ -17,6 +19,21 @@ class BasePrimitive(metaclass=ABCMeta): jax primitive """ + name = None + + @classmethod + def enabled(cls): + """ + A custom call is marked as disabled if the `cls.name` does not fully match the + `NVTE_JAX_CUSTOM_CALLS_RE` pattern. + By default, `NVTE_JAX_CUSTOM_CALLS_RE` is set to `.*`, which matches and enables all names. + For example, set `NVTE_JAX_CUSTOM_CALLS_RE='^(?!te_act_lu$).+$'` to disable `te_act_lu`. + """ + pattern = os.getenv("NVTE_JAX_CUSTOM_CALLS_RE", r".*") + pattern = re.compile(pattern) + is_enabled = pattern.fullmatch(cls.name) is not None + return is_enabled + @staticmethod @abstractmethod def abstract(): diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 59468db0da..f1d3a7f28d 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -7,6 +7,7 @@ import os import warnings +import jax import jax.numpy as jnp from jax import core, dtypes from jax.interpreters import mlir @@ -25,6 +26,7 @@ jax_dtype_to_ir_dtype, te_dtype_to_jax_dtype, ) +from .quantization import _jax_cast_fp8 from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp @@ -239,12 +241,77 @@ def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos): register_primitive(LayerNormFwdPrimitive) +def _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps): + """ + JAX native layernorm implementation + """ + x_ = jnp.asarray(x, jnp.float32) + mean = jnp.mean(x_, axis=-1, keepdims=True) + var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True) + normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps) + if zero_centered_gamma: + gamma += 1.0 + return jnp.asarray(normed_input * gamma + beta).astype(x.dtype) + + +def _jax_rmsnorm(x, gamma, zero_centered_gamma, eps): + """ + JAX native rmsnorm implementation + """ + x_ = jnp.asarray(x, jnp.float32) + var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True) + normed_input = x_ * jax.lax.rsqrt(var + eps) + if zero_centered_gamma: + gamma += 1.0 + return jnp.asarray(normed_input * gamma).astype(x.dtype) + + +def _jax_layernorm_fp8(x, gamma, beta, scale, amax, out_dtype, zero_centered_gamma, eps): + """ + JAX native layernorm fp8 implementation + """ + x_ = jnp.asarray(x, jnp.float32) + mean = jnp.mean(x_, axis=-1, keepdims=True) + var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True) + rsigma = jax.lax.rsqrt(var + eps) + normed_input = (x_ - mean) * rsigma + if zero_centered_gamma: + gamma += 1.0 + output = normed_input * gamma + beta + casted_output, updated_amax = _jax_cast_fp8(output, scale, amax, out_dtype=out_dtype) + return casted_output, jnp.squeeze(mean, axis=-1), jnp.squeeze(rsigma, axis=-1), updated_amax + + +def _jax_rmsnorm_fp8(x, gamma, scale, amax, out_dtype, zero_centered_gamma, eps): + """ + JAX native rmsnorm fp8 implementation + """ + x_ = jnp.asarray(x, jnp.float32) + var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True) + rsigma = jax.lax.rsqrt(var + eps) + normed_input = x_ * rsigma + if zero_centered_gamma: + gamma += 1.0 + output = normed_input * gamma + casted_output, updated_amax = _jax_cast_fp8(output, scale, amax, out_dtype=out_dtype) + return casted_output, jnp.squeeze(rsigma, axis=-1), updated_amax + + def layernorm_fwd( x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, zero_centered_gamma: bool, epsilon: float ): """ Wrapper for TE layernorm fwd """ + if not LayerNormFwdPrimitive.enabled(): + x_ = jnp.asarray(x, jnp.float32) + mu = jnp.mean(x_, axis=-1, keepdims=True) + rsigma = jax.lax.rsqrt(jnp.mean(jnp.square(x_ - mu), axis=-1, keepdims=True) + epsilon) + return ( + _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon), + jnp.squeeze(mu, axis=-1), + jnp.squeeze(rsigma, axis=-1), + ) return LayerNormFwdPrimitive.outer_primitive.bind( x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon ) @@ -468,12 +535,21 @@ def layernorm_bwd( mu: jnp.ndarray, rsigma: jnp.ndarray, gamma: jnp.ndarray, + beta: jnp.ndarray, zero_centered_gamma: bool, epsilon: float, ): """ Wrapper for TE layernorm bwd """ + if not LayerNormBwdPrimitive.enabled(): + _, vjp_func = jax.vjp( + partial(_jax_layernorm, zero_centered_gamma=zero_centered_gamma, eps=epsilon), + x, + gamma, + beta, + ) + return vjp_func(dz) return LayerNormBwdPrimitive.outer_primitive.bind( dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon ) @@ -655,6 +731,12 @@ def rmsnorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, epsilon: float): """ Wrapper for TE rmsnorm fwd """ + if not RmsNormFwdPrimitive.enabled(): + x_ = jnp.asarray(x, jnp.float32) + rsigma = jax.lax.rsqrt(jnp.mean(jnp.square(x_), axis=-1, keepdims=True) + epsilon) + return _jax_rmsnorm(x, gamma, zero_centered_gamma=False, eps=epsilon), jnp.squeeze( + rsigma, axis=-1 + ) return RmsNormFwdPrimitive.outer_primitive.bind(x, gamma, epsilon=epsilon) @@ -852,6 +934,11 @@ def rmsnorm_bwd( """ Wrapper for TE layernorm bwd """ + if not RmsNormBwdPrimitive.enabled(): + _, vjp_func = jax.vjp( + partial(_jax_rmsnorm, zero_centered_gamma=False, eps=epsilon), x, gamma + ) + return vjp_func(dz) return RmsNormBwdPrimitive.outer_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon) @@ -1148,6 +1235,17 @@ def layernorm_fwd_fp8( """ Wrapper for TE layernorm fwd (fp8 out) """ + if not LayerNormFwdFp8Primitive.enabled(): + return _jax_layernorm_fp8( + x, + gamma, + beta, + scale, + amax, + out_dtype=out_dtype, + zero_centered_gamma=zero_centered_gamma, + eps=epsilon, + ) return LayerNormFwdFp8Primitive.outer_primitive.bind( x, gamma, @@ -1387,6 +1485,10 @@ def rmsnorm_fwd_fp8( """ Wrapper for TE rmsnorm fwd (fp8 out) """ + if not RmsNormFwdFp8Primitive.enabled(): + return _jax_rmsnorm_fp8( + x, gamma, scale, amax, out_dtype=out_dtype, zero_centered_gamma=False, eps=epsilon + ) return RmsNormFwdFp8Primitive.outer_primitive.bind( x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon ) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 40974b07b9..2c529e71c8 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -4,6 +4,7 @@ """JAX/TE custom ops for quantization""" from typing import Tuple +import jax import jax.numpy as jnp from jax import dtypes from jax.interpreters.mlir import ir @@ -26,6 +27,26 @@ __all__ = ["cast_fp8"] +def _jax_quantize(x, scale, q_dtype): + """ + Quantize with scale + """ + compute_dtype = scale.dtype + dtype_max = (jnp.finfo(q_dtype).max).astype(compute_dtype) + scaled_x = x.astype(compute_dtype) * scale + clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max) + return clipped_scaled_x.astype(q_dtype) + + +def _jax_cast_fp8(inputs, scale, amax, out_dtype): + """ + JAX native fp8 casting implementation + """ + casted_output = _jax_quantize(inputs, scale, q_dtype=out_dtype) + updated_amax = jax.lax.max(amax, jnp.max(jnp.abs(inputs)).astype(amax.dtype)) + return casted_output, updated_amax + + class CastFP8Primitive(BasePrimitive): """ Cast Primitive @@ -157,4 +178,6 @@ def cast_fp8( Cast wrapper Return FP8 tensor """ + if not CastFP8Primitive.enabled(): + return _jax_cast_fp8(x, scale, amax, out_dtype=out_dtype) return CastFP8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype) diff --git a/transformer_engine/jax/cpp_extensions/softmax.py b/transformer_engine/jax/cpp_extensions/softmax.py index c2dfb65e41..bf92c00de3 100644 --- a/transformer_engine/jax/cpp_extensions/softmax.py +++ b/transformer_engine/jax/cpp_extensions/softmax.py @@ -7,6 +7,7 @@ import operator import warnings +import jax import jax.numpy as jnp from jax import core, dtypes from jax.interpreters.mlir import ir @@ -31,6 +32,30 @@ ] +def _jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float): + return jax.nn.softmax(scale_factor * logits) + + +def _jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float): + if mask is not None: + logits += jax.lax.select( + mask > 0, + jnp.full(mask.shape, -1e10).astype(logits.dtype), + jnp.full(mask.shape, 0.0).astype(logits.dtype), + ) + return jax.nn.softmax(logits * scale_factor) + + +def _jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float): + mask = 1 - jnp.tril(jnp.ones_like(logits)) + logits += jax.lax.select( + mask > 0, + jnp.full(mask.shape, -1e10).astype(logits.dtype), + jnp.full(mask.shape, 0.0).astype(logits.dtype), + ) + return jax.nn.softmax(logits * scale_factor) + + def is_softmax_kernel_available( softmax_type: SoftmaxType, batch: int, @@ -395,6 +420,8 @@ def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray: scaled_softmax_forward wrapper Return FP16/BF16 tensor """ + if not ScaledSoftmaxFwdPrimitive.enabled(): + return _jax_scaled_softmax(logits, scale_factor) return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor) @@ -469,12 +496,16 @@ def partition(scale_factor, mesh, arg_infos, result_infos): def scaled_softmax_bwd( - dz: jnp.ndarray, softmax_out: jnp.ndarray, scale_factor: float + dz: jnp.ndarray, softmax_out: jnp.ndarray, logits: jnp.ndarray, scale_factor: float ) -> jnp.ndarray: """ scaled_backward wrapper Return FP16/BF16 tensor """ + if not ScaledSoftmaxBwdPrimitive.enabled(): + _, vjp_func = jax.vjp(partial(_jax_scaled_softmax, scale_factor=scale_factor), logits) + return vjp_func(dz)[0] + return ScaledSoftmaxBwdPrimitive.outer_primitive.bind( dz, softmax_out, scale_factor=scale_factor ) @@ -625,6 +656,8 @@ def scaled_masked_softmax_fwd( scaled_masked_softmax_forward wrapper Return FP16/BF16 tensor """ + if not ScaledMaskedSoftmaxFwdPrimitive.enabled(): + return _jax_scaled_masked_softmax(logits, mask, scale_factor) return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind( logits, mask, scale_factor=scale_factor ) @@ -704,12 +737,21 @@ def partition(scale_factor, mesh, arg_infos, result_infos): def scaled_masked_softmax_bwd( - dz: jnp.ndarray, softmax_out: jnp.ndarray, scale_factor: float + dz: jnp.ndarray, + softmax_out: jnp.ndarray, + logits: jnp.ndarray, + mask: jnp.ndarray, + scale_factor: float, ) -> jnp.ndarray: """ scaled_masked_backward wrapper Return FP16/BF16 tensor """ + if not ScaledMaskedSoftmaxBwdPrimitive.enabled(): + _, vjp_func = jax.vjp( + partial(_jax_scaled_masked_softmax, scale_factor=scale_factor), logits, mask + ) + return vjp_func(dz)[0] return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind( dz, softmax_out, scale_factor=scale_factor ) @@ -806,6 +848,8 @@ def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: fl scaled_upper_triang_masked_softmax_forward wrapper Return FP16/BF16 tensor """ + if not ScaledUpperTriangMaskedSoftmaxFwdPrimitive.enabled(): + return _jax_scaled_upper_triang_masked_softmax(logits, scale_factor) return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind( logits, scale_factor=scale_factor ) @@ -893,12 +937,17 @@ def partition(scale_factor, mesh, arg_infos, result_infos): def scaled_upper_triang_masked_softmax_bwd( - dz: jnp.ndarray, softmax_out: jnp.ndarray, scale_factor: float + dz: jnp.ndarray, softmax_out: jnp.ndarray, logits: jnp.ndarray, scale_factor: float ) -> jnp.ndarray: """ scaled_upper_triang_masked_backward wrapper Return FP16/BF16 tensor """ + if not ScaledUpperTriangMaskedSoftmaxBwdPrimitive.enabled(): + _, vjp_func = jax.vjp( + partial(_jax_scaled_upper_triang_masked_softmax, scale_factor=scale_factor), logits + ) + return vjp_func(dz)[0] return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive.bind( dz, softmax_out, scale_factor=scale_factor ) diff --git a/transformer_engine/jax/cpp_extensions/transpose.py b/transformer_engine/jax/cpp_extensions/transpose.py index 9102b55cae..cc64951a95 100644 --- a/transformer_engine/jax/cpp_extensions/transpose.py +++ b/transformer_engine/jax/cpp_extensions/transpose.py @@ -6,6 +6,7 @@ from typing import Tuple, Sequence, Union, Callable import operator +import jax import jax.numpy as jnp from jax import dtypes from jax.interpreters.mlir import ir @@ -26,6 +27,8 @@ normalize_axis_boundary, ) from .activation import ActivationEnum +from .activation import _jax_act_lu +from .quantization import _jax_cast_fp8 from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp @@ -38,6 +41,27 @@ ] +def _jax_transpose(inputs, static_axis_boundary, transpose_axis_boundary): + """ + JAX native transpose implementation + """ + axes = multidim_transpose(range(inputs.ndim), static_axis_boundary, transpose_axis_boundary) + return jnp.transpose(inputs, axes=axes) + + +def _jax_cast_transpose( + inputs, scale, amax, out_dtype, static_axis_boundary, transpose_axis_boundary +): + """ + JAX native cast_transpose implementation + """ + casted_output, updated_amax = _jax_cast_fp8(inputs, scale, amax, out_dtype=out_dtype) + casted_transposed_output = _jax_transpose( + casted_output, static_axis_boundary, transpose_axis_boundary + ) + return casted_output, casted_transposed_output, updated_amax + + class TransposePrimitive(BasePrimitive): """ Transpose Primitive @@ -176,6 +200,8 @@ def transpose( """ transpose wrapper """ + if not TransposePrimitive.enabled(): + return _jax_transpose(x, static_axis_boundary, transpose_axis_boundary) return TransposePrimitive.outer_primitive.bind( x, static_axis_boundary=static_axis_boundary, @@ -381,6 +407,15 @@ def cast_transpose( cast transpose wrapper Return two tensors, FP8(inputs) and FP8(inputs.T), which are scaled by `scale` """ + if not CastTransposePrimitive.enabled(): + return _jax_cast_transpose( + x, + scale, + amax, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary, + ) return CastTransposePrimitive.outer_primitive.bind( x, amax, @@ -631,6 +666,28 @@ def dbias_cast_transpose( if static_axis_boundary < 0: static_axis_boundary = -1 # means no static axes + if not DBiasCastTransposePrimitive.enabled(): + casted_dz, cast_transposed_dz, updated_amax = _jax_cast_transpose( + dz, + scale, + amax, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary, + ) + dbias = jnp.sum( + dz, + axis=tuple( + range( + transpose_axis_boundary + if transpose_axis_boundary > 0 + else transpose_axis_boundary + dz.ndim + ) + ), + keepdims=False, + ) + return casted_dz, cast_transposed_dz, dbias, updated_amax + return DBiasCastTransposePrimitive.outer_primitive.bind( dz, amax, @@ -947,6 +1004,31 @@ def dact_lu_dbias_cast_transpose( if static_axis_boundary < 0: static_axis_boundary = -1 # means no static axes + if not DActLuDBiasCastTransposePrimitive.enabled(): + _, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), x) + (dx,) = vjp_func(dz) + casted_dx, cast_transposed_dx, updated_amax = _jax_cast_transpose( + dx, + scale, + amax, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary, + ) + dbias = jnp.squeeze( + jnp.sum( + dx, + axis=tuple( + range( + transpose_axis_boundary + if transpose_axis_boundary > 0 + else transpose_axis_boundary + dx.ndim + ) + ), + ) + ) + return casted_dx, cast_transposed_dx, dbias, updated_amax + act_type_id = ActivationEnum[activation_type] return DActLuDBiasCastTransposePrimitive.outer_primitive.bind( dz, @@ -1161,6 +1243,17 @@ def dgated_act_lu_cast_transpose( Return FP8(dgated_act_lu(inputs)) """ act_type_id = ActivationEnum[activation_type] + if not DgatedActLuCastTransposePrimitive.enabled(): + _, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), x) + (dx,) = vjp_func(dz) + return _jax_cast_transpose( + dx, + scale, + amax, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=-2, + ) return DgatedActLuCastTransposePrimitive.outer_primitive.bind( dz, x, diff --git a/transformer_engine/jax/layernorm.py b/transformer_engine/jax/layernorm.py index e7364a13b6..4f2e83d9a2 100644 --- a/transformer_engine/jax/layernorm.py +++ b/transformer_engine/jax/layernorm.py @@ -69,14 +69,14 @@ def _layernorm_fwd_rule( mu = None else: raise ValueError(f"{layernorm_type=} is not supported.") - return output, (x, mu, rsigma, gamma) + return output, (x, mu, rsigma, gamma, beta) def _layernorm_bwd_rule(layernorm_type, zero_centered_gamma, epsilon, ctx, dz): - x, mu, rsigma, gamma = ctx + x, mu, rsigma, gamma, beta = ctx if layernorm_type == "layernorm": dx, dgamma, dbeta = tex.layernorm_bwd( - dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon + dz, x, mu, rsigma, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon ) elif layernorm_type == "rmsnorm": assert ( @@ -267,6 +267,7 @@ def _layernorm_fp8_dot_fwd_rule( rsigma, x, gamma, + beta, x_contracting_dims, k_contracting_dims, maybe_fp32_to_fm32, @@ -300,6 +301,7 @@ def _layernorm_fp8_dot_bwd_rule( rsigma, x, gamma, + beta, x_contracting_dims, k_contracting_dims, maybe_fp32_to_fm32, @@ -352,7 +354,14 @@ def _layernorm_fp8_dot_bwd_rule( dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) if layernorm_type == "layernorm": dx, dgamma, dbeta = tex.layernorm_bwd( - dgrad, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon + dgrad, + x, + mu, + rsigma, + gamma, + beta, + zero_centered_gamma=zero_centered_gamma, + epsilon=epsilon, ) else: assert ( diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 0017acb80c..90504e4c14 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -344,6 +344,7 @@ def _fused_layernorm_fp8_mlp_fwd_rule( mu, rsigma, gamma, + beta, dot_1_output, casted_activation_lu_out, casted_kernel_1, @@ -390,6 +391,7 @@ def _fused_layernorm_fp8_mlp_bwd_rule( mu, rsigma, gamma, + beta, dot_1_output, casted_activation_lu_out, casted_kernel_1, @@ -568,7 +570,14 @@ def _fused_layernorm_fp8_mlp_bwd_rule( if layernorm_type == "layernorm": dx, dgamma, dbeta = tex.layernorm_bwd( - dgrad_1, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon + dgrad_1, + x, + mu, + rsigma, + gamma, + beta, + zero_centered_gamma=zero_centered_gamma, + epsilon=epsilon, ) else: assert ( diff --git a/transformer_engine/jax/softmax.py b/transformer_engine/jax/softmax.py index 0a997776ef..c63ee85e5d 100644 --- a/transformer_engine/jax/softmax.py +++ b/transformer_engine/jax/softmax.py @@ -49,18 +49,18 @@ def _softmax_fwd_rule(logits, mask, scale_factor, softmax_type): else: output = tex.scaled_softmax_fwd(logits, scale_factor) - return output, (output,) + return output, (output, logits, mask) def _softmax_bwd_rule(scale_factor, softmax_type, ctx, dz): - (softmax_output,) = ctx + (softmax_output, logits, mask) = ctx if softmax_type is SoftmaxType.SCALED_MASKED: - dgrad = tex.scaled_masked_softmax_bwd(dz, softmax_output, scale_factor) + dgrad = tex.scaled_masked_softmax_bwd(dz, softmax_output, logits, mask, scale_factor) elif softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED: - dgrad = tex.scaled_upper_triang_masked_softmax_bwd(dz, softmax_output, scale_factor) + dgrad = tex.scaled_upper_triang_masked_softmax_bwd(dz, softmax_output, logits, scale_factor) else: - dgrad = tex.scaled_softmax_bwd(dz, softmax_output, scale_factor) + dgrad = tex.scaled_softmax_bwd(dz, softmax_output, logits, scale_factor) return (dgrad, None) From e39674b99789bbc3d43c354861659bad33306e3e Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 17 Jul 2024 07:02:25 -0700 Subject: [PATCH 04/72] [PyTorch] Add option to pass kwargs to CUDA graph module (#945) * Add option to pass kwargs to CUDA graph module Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Debug unit tests Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Tweak comments Signed-off-by: Tim Moon --------- Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani --- tests/pytorch/test_cuda_graphs.py | 209 +++++++++++++++++++++++++++- transformer_engine/pytorch/graph.py | 172 +++++++++++++++++------ 2 files changed, 331 insertions(+), 50 deletions(-) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 8d3a9dca4f..60a5a1ea99 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -3,7 +3,8 @@ # See LICENSE for license information. from dataclasses import dataclass -from typing import List, Tuple +import itertools +from typing import Iterable, List, Tuple, Union import pytest import torch @@ -88,7 +89,7 @@ def generate_data( dpa: bool = False, warmup: bool = False, return_grad_output: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[List[torch.Tensor], torch.Tensor]: """Generate synthetic data.""" gen_func = torch.ones if warmup else torch.randn if dpa: @@ -129,14 +130,20 @@ def generate_data( return inputs, grad_output -def get_outputs(model, output): +def get_outputs( + model: torch.nn.Module, + output: Union[torch.Tensor, Iterable[torch.Tensor]], +) -> List[torch.Tensor]: """Return grads and params for comparsion.""" values = [] for param in model.parameters(): values.append(param) if param.grad is not None: values.append(param.grad) - values.append(output) + if isinstance(output, torch.Tensor): + values.append(output) + else: + values.extend(output) return values @@ -161,7 +168,7 @@ def _test_cuda_graphs( module: str, graph_mode: str, ) -> List[torch.Tensor]: - """Helper function for test.""" + """Helper function for CUDA graph test.""" reset_rng_states() FP8GlobalStateManager.reset() dpa = module == "dpa" @@ -247,7 +254,7 @@ def _test_cuda_graphs( else: model = modules[0] if dpa else _Sequential(*modules) - # Loss function and optimizer. + # Optimizer. if not dpa: optimizer = torch.optim.SGD(model.parameters(), lr=0.001) @@ -312,3 +319,193 @@ def test_gpt_make_graphed_callables( # Check that results match assert_all_equal(outputs, graph_outputs_mode1) assert_all_equal(outputs, graph_outputs_mode2) + + +def _test_cuda_graphs_with_kwargs( + *, + config: ModelConfig, + dtype: torch.dtype, + with_graph: bool, +) -> List[torch.Tensor]: + """Simulate Megatron-LM interleaved pipeline parallelism.""" + reset_rng_states() + + # Initialize model. + model = TransformerLayer( + config.hidden_size, + config.hidden_size, + config.num_heads, + hidden_dropout=0.0, + attention_dropout=0.0, + self_attn_mask_type="arbitrary", + fuse_qkv_params=True, + params_dtype=dtype, + ) + + # Initialize gradient buffers. + for param in model.parameters(): + param.grad = torch.empty_like(param) + + # Make graphed version of model if needed. + if with_graph: + attn_mask = torch.zeros( + (config.batch_size, 1, config.sequence_length, config.sequence_length), + dtype=torch.bool, + device="cuda", + ) + model = make_graphed_callables( + model, + generate_data(config, dtype, warmup=True), + sample_kwargs=dict(attention_mask=attn_mask), + allow_unused_input=True, + ) + + # Optimizer. + optimizer = torch.optim.SGD(model.parameters(), lr=0.001) + + # Training loop. + for _ in range(3): + optimizer.zero_grad(set_to_none=False) + for grad_accumulation_step in range(2): + inputs, grad_output = generate_data(config, dtype, return_grad_output=True) + attn_mask = torch.randint( + 2, + (config.batch_size, 1, config.sequence_length, config.sequence_length), + dtype=torch.bool, + device="cuda", + ) + output = model(*inputs, attention_mask=attn_mask) + output.backward(grad_output) + optimizer.step() + + return get_outputs(model, output) + + +def test_make_graphed_callables_with_kwargs( + dtype: torch.dtype = torch.float32, + model: str = "small", +) -> None: + """Test CUDA graphs with keyword arguments.""" + config = model_configs[model] + kwargs = dict(config=config, dtype=dtype) + outputs = _test_cuda_graphs_with_kwargs(with_graph=False, **kwargs) + graph_outputs = _test_cuda_graphs_with_kwargs(with_graph=True, **kwargs) + assert_all_equal(outputs, graph_outputs) + + +def _test_cuda_graphs_with_interleaved_pipeline_parallelism( + *, + config: ModelConfig, + dtype: torch.dtype, + with_graph: bool, +) -> List[torch.Tensor]: + """Simulate Megatron-LM interleaved pipeline parallelism.""" + reset_rng_states() + + # Pipeline parallel configuration. + num_layers = 2 + num_microbatches = 3 + layer_order = [1, 2, 1, 2, -2, -1, 1, 2, -2, -1, -2, -1] + + # Initialize model. + model = torch.nn.ModuleList( + [ + Linear( + config.hidden_size, + config.hidden_size, + params_dtype=dtype, + ) + for _ in range(num_layers) + ] + ) + + # Initialize gradient buffers. + for param in model.parameters(): + param.grad = torch.empty_like(param) + + # Make graphed version of model if needed. + layer_forwards = { + (i % num_layers, i // num_layers): model[i % num_layers] + for i in range(num_layers * num_microbatches) + } + if with_graph: + sample_args = tuple( + generate_data(config, dtype, warmup=True) for _ in range(num_layers * num_microbatches) + ) + layer_forwards = make_graphed_callables( + tuple(model), + sample_args, + allow_unused_input=True, + _order=layer_order, + ) + layer_forwards = { + (i // num_microbatches, i % num_microbatches): forward + for i, forward in enumerate(layer_forwards) + } + + # Optimizer. + optimizer = torch.optim.SGD(model.parameters(), lr=0.001) + + # Training loop. + for _ in range(3): + optimizer.zero_grad(set_to_none=False) + + # Generate data. + inputs = {} + grad_outputs = {} + for layer_idx in range(num_layers): + for microbatch_idx in range(num_microbatches): + x, dy = generate_data(config, dtype, return_grad_output=True) + idxs = (layer_idx, microbatch_idx) + inputs[idxs] = x[0] + grad_outputs[idxs] = dy + + # Cache for layer outputs. + outputs = {} + + def forward(layer_idx: int, microbatch_idx: int): + """Helper function for forward steps""" + idxs = (layer_idx, microbatch_idx) + outputs[idxs] = layer_forwards[idxs](inputs[idxs]) + + def backward(layer_idx: int, microbatch_idx: int): + """Helper function for backward steps""" + outputs[layer_idx, microbatch_idx].backward(grad_outputs[layer_idx, microbatch_idx]) + + # Forward and backward steps. + forward(0, 0) + forward(1, 0) + forward(0, 1) + forward(1, 1) + backward(1, 0) + backward(0, 0) + forward(0, 2) + forward(1, 2) + backward(1, 1) + backward(0, 1) + backward(1, 2) + backward(0, 2) + + # Optimizer step. + optimizer.step() + + outputs = [y for _, y in sorted(outputs.items())] + return get_outputs(model, outputs) + + +def test_make_graphed_callables_with_interleaved_pipeline_parallelism( + dtype: torch.dtype = torch.float16, + model: str = "small", +) -> None: + """Test CUDA graphs with Megatron-LM interleaved pipeline parallelism.""" + config = model_configs[model] + kwargs = dict(config=config, dtype=dtype) + outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism( + with_graph=False, + **kwargs, + ) + graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism( + with_graph=True, + **kwargs, + ) + assert_all_equal(outputs, graph_outputs) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index a6f62ac457..f6331c9b2a 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -3,11 +3,14 @@ # See LICENSE for license information. """Functions for CUDA Graphs support in FP8""" +from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union + import torch from torch.utils._pytree import tree_flatten as _tree_flatten from torch.utils._pytree import tree_unflatten as _tree_unflatten from torch._C import _graph_pool_handle +from transformer_engine.common.recipe import DelayedScaling from .fp8 import ( fp8_autocast, FP8GlobalStateManager, @@ -22,6 +25,9 @@ _IS_GRAPH_CAPTURING = False +_T = TypeVar("_T") +SingleOrTuple = Union[_T, Tuple[_T, ...]] + def set_capture_start() -> None: """Record beginning of `make_graphed_callables`.""" @@ -48,13 +54,14 @@ def graph_pool_handle(): def _make_graphed_callables( - callables, - sample_args, - num_warmup_iters=3, - allow_unused_input=False, - fp8_weight_caching=False, - _order=None, -): + callables: SingleOrTuple[Callable], + sample_args: SingleOrTuple[Tuple[torch.Tensor, ...]], + num_warmup_iters: int = 3, + allow_unused_input: bool = False, + fp8_weight_caching: bool = False, + sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None, + _order: Optional[List[int]] = None, +) -> SingleOrTuple[Callable]: """ Helper method for `make_graphed_callables` """ @@ -65,16 +72,38 @@ def _make_graphed_callables( "caching. Please set `cache_enabled=False`." ) - just_one_callable = False + # Default is to pass no kwargs to callables + if sample_kwargs is None: + if isinstance(callables, tuple): + sample_kwargs = tuple({} for _ in range(len(sample_args))) + else: + sample_kwargs = {} + # Canonicalize args as tuples + just_one_callable = False if not isinstance(callables, tuple): just_one_callable = True callables = (callables,) sample_args = (sample_args,) + sample_kwargs = (sample_kwargs,) - flatten_sample_args = [] - if _order is not None: - # order is a list containing 1..model_chunk values in the order of microbatch schedule + # Check sizes of args + if _order is None: + assert len(sample_args) == len(callables) + assert len(sample_kwargs) == len(callables) + else: + # Custom logic for interleaved pipeline parallelism + # Note: This is tightly coupled with the Megatron-core + # implementation of interleaved pipeline parallelism at + # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py. + # Note: The model is assumed to consist of layers + # (corresponding to callables) that are grouped into + # equally-sized model chunks. _order is a list of chunk + # indices (1-indexed) that indicates the order in which the + # layers are evaluated. Positive values indicate forward + # passes and negative values indicate backward passes. Each + # entry in sample_args corresponds to one of the forward + # passes. num_model_chunks = max(_order) num_microbatches = len(_order) // num_model_chunks // 2 assert num_model_chunks * num_microbatches * 2 == len(_order) @@ -90,10 +119,13 @@ def _make_graphed_callables( f"Expected {num_model_chunks * num_microbatches}" + f"args tuple, but got {len(sample_args)}." ) + assert len(sample_kwargs) == len(sample_args) if fp8_weight_caching: + # Initialize flag that controls FP8 weight updates FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False) + # Check callables for c in callables: if isinstance(c, torch.nn.Module): assert ( @@ -110,9 +142,14 @@ def _make_graphed_callables( + ":func:`~make_graphed_callables`, only parameters may be trainable. " + "All buffers must have ``requires_grad=False``." ) - for args in sample_args: + + # Flatten callable arguments + per_callable_kwargs_keys = [list(kwargs.keys()) for kwargs in sample_kwargs] + flatten_sample_args = [] + for args, kwargs, kwargs_keys in zip(sample_args, sample_kwargs, per_callable_kwargs_keys): flatten_arg, _ = _tree_flatten(args) - flatten_sample_args.append(tuple(flatten_arg)) + flatten_kwarg, _ = _tree_flatten([kwargs[key] for key in kwargs_keys]) + flatten_sample_args.append(tuple(flatten_arg + flatten_kwarg)) assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), ( "In the beta API, sample_args " + "for each callable must contain only Tensors. Other types are not allowed." @@ -120,6 +157,10 @@ def _make_graphed_callables( # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly # passes to forward (ie, its sample_args) AND the module's parameter attributes. + # Note: These per_callable_* variables are not actually + # per-callable, but per-forward-pass (see description of _order). + # The names are kept for consistency with + # torch.cuda.make_graphed_callables. per_callable_len_user_args = [len(args) for args in flatten_sample_args] if _order is None: per_callable_module_params = [ @@ -144,6 +185,7 @@ def _make_graphed_callables( fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))] bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))] graph_callables = [None for _ in range(len(flatten_sample_args))] + # For cases with multiple active RNG states, e.g. TP. if graph_safe_rng_available(): for _, state in get_all_rng_states().items(): @@ -158,11 +200,12 @@ def _make_graphed_callables( # from ending up in any captures. torch.cuda.synchronize() with torch.cuda.stream(torch.cuda.Stream()): - for c_i, func in enumerate(callables): - args = sample_args[c_i] - static_input_surface = per_callable_static_input_surfaces[c_i] + for func_idx, func in enumerate(callables): + args = sample_args[func_idx] + kwargs = sample_kwargs[func_idx] + static_input_surface = per_callable_static_input_surfaces[func_idx] for _ in range(num_warmup_iters): - outputs, _ = _tree_flatten(func(*args)) + outputs, _ = _tree_flatten(func(*args, **kwargs)) grad_inputs = torch.autograd.grad( outputs=tuple(o for o in outputs if o.requires_grad), inputs=tuple(i for i in static_input_surface if i.requires_grad), @@ -194,9 +237,10 @@ def _make_graphed_callables( fwd_idx[m_chunk] * num_layers + l_no ) args = sample_args[per_callable_fwd_idx] + kwargs = sample_kwargs[per_callable_fwd_idx] fwd_graph = fwd_graphs[per_callable_fwd_idx] with torch.cuda.graph(fwd_graph, pool=mempool): - outputs = func(*args) + outputs = func(*args, **kwargs) flatten_outputs, spec = _tree_flatten(outputs) per_callable_static_outputs[per_callable_fwd_idx] = tuple(flatten_outputs) per_callable_output_unflatten_spec[per_callable_fwd_idx] = spec @@ -245,9 +289,9 @@ def _make_graphed_callables( per_callable_static_outputs = [] per_callable_output_unflatten_spec = [] graph_id = 0 - for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs): + for func, args, kwargs, fwd_graph in zip(callables, sample_args, sample_kwargs, fwd_graphs): with torch.cuda.graph(fwd_graph, pool=mempool): - outputs = func(*args) + outputs = func(*args, **kwargs) graph_callables[graph_id] = func graph_id += 1 @@ -300,6 +344,7 @@ def make_graphed_autograd_function( fwd_graph, bwd_graph, module_params, + kwargs_keys, len_user_args, output_unflatten_spec, static_input_surface, @@ -312,14 +357,18 @@ class Graphed(torch.autograd.Function): @staticmethod def forward(ctx, skip_fp8_weight_update, *inputs): - # At this stage, only the user args may (potentially) be new tensors. + + # Set flag for whether to update FP8 weight updates ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() if ctx.is_first_module and skip_fp8_weight_update is not None: FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(skip_fp8_weight_update) + # Copy values from new tensors into static tensors for i in range(len_user_args): if static_input_surface[i].data_ptr() != inputs[i].data_ptr(): static_input_surface[i].copy_(inputs[i]) + + # Replay forward graph fwd_graph.replay() assert isinstance(static_outputs, tuple) return tuple(o.detach() for o in static_outputs) @@ -327,6 +376,8 @@ def forward(ctx, skip_fp8_weight_update, *inputs): @staticmethod @torch.autograd.function.once_differentiable def backward(ctx, *grads): + + # Replay backward graph assert len(grads) == len(static_grad_outputs) for g, grad in zip(static_grad_outputs, grads): if g is not None: @@ -336,6 +387,7 @@ def backward(ctx, *grads): g.copy_(grad) bwd_graph.replay() + # Update FP8 scale factors if needed if ctx.is_first_module: FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) @@ -346,10 +398,8 @@ def backward(ctx, *grads): ) def functionalized(*user_args, **user_kwargs): - # Runs the autograd function with inputs == all - # inputs to the graph that might require grad - # (explicit user args + module parameters) - # Assumes module params didn't change since capture. + + # Decide whether to update FP8 weights skip_fp8_weight_update = None if fp8_weight_caching: assert "is_first_microbatch" in user_kwargs and isinstance( @@ -358,8 +408,22 @@ def functionalized(*user_args, **user_kwargs): skip_fp8_weight_update = not user_kwargs["is_first_microbatch"] + # Check that required kwargs are provided + for key in kwargs_keys: + if key not in user_kwargs: + raise TypeError( + f"Graphed callable was initialized with kwarg {key} ," + "but it was not provided in graph replay" + ) + + # Runs the autograd function with inputs == all inputs to + # the graph that might require grad (explicit user args + + # module parameters) + # Assumes module params didn't change since capture. flatten_user_args, _ = _tree_flatten(user_args) - out = Graphed.apply(skip_fp8_weight_update, *(tuple(flatten_user_args) + module_params)) + flatten_user_kwargs, _ = _tree_flatten([user_kwargs[key] for key in kwargs_keys]) + func_args = tuple(flatten_user_args) + tuple(flatten_user_kwargs) + module_params + out = Graphed.apply(skip_fp8_weight_update, *func_args) return _tree_unflatten(out, output_unflatten_spec) return functionalized @@ -371,6 +435,7 @@ def functionalized(*user_args, **user_kwargs): fwd_graphs[i], bwd_graphs[i], per_callable_module_params[i], + per_callable_kwargs_keys[i], per_callable_len_user_args[i], per_callable_output_unflatten_spec[i], per_callable_static_input_surfaces[i], @@ -443,25 +508,42 @@ def restore_fp8_tensors(modules, fp8_tensors): def make_graphed_callables( - modules, - sample_args, - num_warmup_iters=3, - allow_unused_input=False, - fp8_enabled=False, - fp8_calibrating=False, - fp8_recipe=None, - fp8_weight_caching=False, - _order=None, -): + modules: SingleOrTuple[Callable], + sample_args: SingleOrTuple[Tuple[torch.Tensor, ...]], + num_warmup_iters: int = 3, + allow_unused_input: bool = False, + sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None, + fp8_enabled: bool = False, + fp8_calibrating: bool = False, + fp8_recipe: Optional[DelayedScaling] = None, + fp8_weight_caching: bool = False, + _order: Optional[List[int]] = None, +) -> Union[Callable, Tuple[Callable, ...]]: """ - A version of PyTorch's `make_graphed_callables` utility function with support for - TransformerEngine modules and FP8. Please see the original version in upstream PyTorch - `here `_ - for extensive documentation. The documentation for additional parameters which are - specific to FP8 are given below. - - FP8 specific parameters - ----------------------- + Make CUDA graph version of Transformer Engine modules + + A variation of PyTorch's `make_graphed_callables` utility function + with support for Transformer Engine modules and FP8. Please see + the + `original PyTorch implementation `_ + for more documentation. + + Graphing parameters + ------------------- + modules: (tuple of) callable + Callable or callables to graph. + sample_args: (tuple of) tuple of torch.Tensor + Positional arguments to callable(s). + num_warmup_iters: int, default = 3 + Number of warmup iterations. + allow_unused_input: bool, default = `False` + Whether to handle case where callable inputs + and outputs are disconnected in compute graph. + sample_kwargs: (tuple of) dict, optional + Keyword arguments to callable(s) + + FP8-related parameters + ---------------------- fp8_enabled: bool, default = `True` whether or not to enable fp8 fp8_calibrating: bool, default = `False` @@ -478,6 +560,7 @@ def make_graphed_callables( using TE's `fp8_model_init` API and using an FP8 aware optimizer, this arg must be set to `False` if calculating weight transposes' outside TE, e.g., in the optimizer step. + """ set_capture_start() @@ -532,6 +615,7 @@ def forward_func(*args, **kwargs): num_warmup_iters=num_warmup_iters, allow_unused_input=allow_unused_input, fp8_weight_caching=fp8_weight_caching, + sample_kwargs=sample_kwargs, _order=_order, ) From 8c0a0c93444eeb8b6a3702d0b0ef149d3889bc4f Mon Sep 17 00:00:00 2001 From: vasunvidia <108759426+vasunvidia@users.noreply.github.com> Date: Wed, 17 Jul 2024 07:02:59 -0700 Subject: [PATCH 05/72] DGRAD_RS UB overlap Bug fixes (#1004) * DGRAD_RS UB overlap Bug fixes Signed-off-by: Vasudevan Rengasamy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Vasudevan Rengasamy Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani --- .../pytorch/csrc/comm_gemm_overlap.h | 342 +++++++++--------- .../pytorch/csrc/userbuffers/userbuffers.cu | 9 +- transformer_engine/pytorch/module/base.py | 2 + .../pytorch/module/layernorm_linear.py | 2 +- 4 files changed, 187 insertions(+), 168 deletions(-) diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h index d2f8b771db..6612124b30 100644 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -37,6 +37,7 @@ } while (0) using namespace torch::indexing; + namespace ubuf { /* @@ -324,47 +325,48 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms, _num_splits /*m_split*/, 0 /*n_split*/, true /*gemm_producer*/, counter); - for (int i = 0; i < _num_splits; i++) { - const char *env_p = std::getenv("NVTE_RS_STRIDED_ATOMIC"); - if (env_p != nullptr && env_p[0] == '1') { - if (i == _num_splits - 1) { - _ub_comm->sms = UB_MAX_SM; - } - if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>( - rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits, - &counter_ptr[i], _ub_comm, (cudaStream_t)_stream_comm); - } else { - reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, - _num_splits, &counter_ptr[i], _ub_comm, - (cudaStream_t)_stream_comm); - } - } else if (env_p != nullptr && env_p[0] == '2') { - if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e4m3>( - rs_output_ptr, d_scale_inv_ptr, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits, - counter_ptr, _ub_comm, (cudaStream_t)_stream_comm); - } else { - reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, n, - m, _num_splits, counter_ptr, _ub_comm, - (cudaStream_t)_stream_comm); - } - break; - } else { - consumer(counter_ptr, i, (cudaStream_t)_stream_comm); - // if (i == _num_splits-1) { - // _ub_comm->sms = UB_MAX_SM; - // } - reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, - _ub_comm, (cudaStream_t)_stream_comm); - } + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + B_type, fp8_type, for (int i = 0; i < _num_splits; i++) { + const char *env_p = std::getenv("NVTE_RS_STRIDED_ATOMIC"); + if (env_p != nullptr && env_p[0] == '1') { + if (i == _num_splits - 1) { + _ub_comm->sms = UB_MAX_SM; + } + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); + reducescatter2_userbuff_strided_atomic_fp8( + rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, m, + _num_splits, &counter_ptr[i], _ub_comm, (cudaStream_t)_stream_comm); + } else { + reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, + n, m, _num_splits, &counter_ptr[i], _ub_comm, + (cudaStream_t)_stream_comm); + } + } else if (env_p != nullptr && env_p[0] == '2') { + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); + reducescatter2_userbuff_strided_multiatomic_fp8( + rs_output_ptr, d_scale_inv_ptr, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits, + counter_ptr, _ub_comm, (cudaStream_t)_stream_comm); + } else { + reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, + n, m, _num_splits, counter_ptr, _ub_comm, + (cudaStream_t)_stream_comm); + } + break; + } else { + consumer(counter_ptr, i, (cudaStream_t)_stream_comm); + // if (i == _num_splits-1) { + // _ub_comm->sms = UB_MAX_SM; + // } + reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, + _ub_comm, (cudaStream_t)_stream_comm); + } - rs_output_ptr += m_chunk * rs_output.element_size(); - } + rs_output_ptr += m_chunk * rs_output.element_size(); + }); _ub_comm->sms = ori_sms; CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[0])); @@ -422,111 +424,115 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { assert(pre_gelu_out.numel() == 0); - if (gemm_overlap) { - torch::Tensor input_a_chunk = torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); - torch::Tensor output_chunk = - torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[0]); - te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, - output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms); - - for (int i = 1; i < _num_splits; i++) { - input_a_chunk_ptr += input_a_chunk_size * B.element_size(); - output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); - - torch::Tensor input_a_chunk = - torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); - torch::Tensor output_chunk = - torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, - {workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); - te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, - output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, - _math_sms); - - CHECK_CUDA(cudaEventRecord( - _start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()])); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); - - // Communication chunk - if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>( - rs_output_ptr, d_scale_inv_ptr, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m, - _ub_comm, (cudaStream_t)_stream_comm); - } else { - reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size, - m_chunk, n, m, _ub_comm, - (cudaStream_t)_stream_comm); - } - - rs_output_ptr += m_chunk * rs_output.element_size(); - } - int last_compute_stream_id = - (_num_splits + _stream_compute.size() - 1) % _stream_compute.size(); - CHECK_CUDA( - cudaEventRecord(_start_comm, (cudaStream_t)_stream_compute[last_compute_stream_id])); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); - - // Last communication chunk with max SM - _ub_comm->sms = UB_MAX_SM; - if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>( - rs_output_ptr, d_scale_inv_ptr, _ub_reg, (_num_splits - 1) * output_chunk_size, m_chunk, - n, m, _ub_comm, (cudaStream_t)_stream_comm); - } else { - reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, - (_num_splits - 1) * output_chunk_size, m_chunk, n, m, - _ub_comm, (cudaStream_t)_stream_comm); - } - } else { - for (int i = 0; i < _num_splits; i++) { - torch::Tensor input_a_chunk = - torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); - torch::Tensor output_chunk = - torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, - {workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); - te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, - output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, - _math_sms); - - CHECK_CUDA(cudaEventRecord(_start_comm, - (cudaStream_t)_stream_compute[i % _stream_compute.size()])); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + B_type, fp8_type, + if (gemm_overlap) { + torch::Tensor input_a_chunk = + torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); + torch::Tensor output_chunk = + torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); + torch::Tensor workspace_chunk = + torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); + at::cuda::setCurrentCUDAStream(_stream_compute[0]); + te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, + transb, output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, + grad, workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, + _math_sms); + + for (int i = 1; i < _num_splits; i++) { + input_a_chunk_ptr += input_a_chunk_size * B.element_size(); + output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); + + torch::Tensor input_a_chunk = + torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); + torch::Tensor output_chunk = + torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); + torch::Tensor workspace_chunk = torch::from_blob( + workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, + {workspace_size_chunk}, workspace.options()); + at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); + te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, + transb, output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, + grad, workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, + _math_sms); + + CHECK_CUDA(cudaEventRecord( + _start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()])); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); + + // Communication chunk + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); + reducescatter2_userbuff_stridedoutput_fp8( + rs_output_ptr, d_scale_inv_ptr, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, + m, _ub_comm, (cudaStream_t)_stream_comm); + } else { + reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, + (i - 1) * output_chunk_size, m_chunk, n, m, + _ub_comm, (cudaStream_t)_stream_comm); + } + + rs_output_ptr += m_chunk * rs_output.element_size(); + } + int last_compute_stream_id = + (_num_splits + _stream_compute.size() - 1) % _stream_compute.size(); + CHECK_CUDA( + cudaEventRecord(_start_comm, (cudaStream_t)_stream_compute[last_compute_stream_id])); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); - // Communication chunk. Uses MAX_SM at the last chunk - if (i == _num_splits - 1) { + // Last communication chunk with max SM _ub_comm->sms = UB_MAX_SM; - } - if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>( - rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * output_chunk_size, m_chunk, n, m, - _ub_comm, (cudaStream_t)_stream_comm); + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); + reducescatter2_userbuff_stridedoutput_fp8( + rs_output_ptr, d_scale_inv_ptr, _ub_reg, (_num_splits - 1) * output_chunk_size, + m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm); + } else { + reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, + (_num_splits - 1) * output_chunk_size, m_chunk, n, + m, _ub_comm, (cudaStream_t)_stream_comm); + } } else { - reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size, - m_chunk, n, m, _ub_comm, - (cudaStream_t)_stream_comm); - } - rs_output_ptr += m_chunk * rs_output.element_size(); - input_a_chunk_ptr += input_a_chunk_size * B.element_size(); - output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); - } - } + for (int i = 0; i < _num_splits; i++) { + torch::Tensor input_a_chunk = + torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); + torch::Tensor output_chunk = + torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); + torch::Tensor workspace_chunk = torch::from_blob( + workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, + {workspace_size_chunk}, workspace.options()); + at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); + te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, + transb, output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, + grad, workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, + _math_sms); + + CHECK_CUDA(cudaEventRecord(_start_comm, + (cudaStream_t)_stream_compute[i % _stream_compute.size()])); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); + + // Communication chunk. Uses MAX_SM at the last chunk + if (i == _num_splits - 1) { + _ub_comm->sms = UB_MAX_SM; + } + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); + reducescatter2_userbuff_stridedoutput_fp8( + rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * output_chunk_size, m_chunk, n, m, + _ub_comm, (cudaStream_t)_stream_comm); + } else { + reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size, + m_chunk, n, m, _ub_comm, + (cudaStream_t)_stream_comm); + } + rs_output_ptr += m_chunk * rs_output.element_size(); + input_a_chunk_ptr += input_a_chunk_size * B.element_size(); + output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); + } + }); for (size_t i = 0; i < _stream_compute.size(); i++) { CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); @@ -1051,18 +1057,20 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); // Reduce GEMM output chunks - char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].data_ptr()); - if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); - reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, - _tp_size, _ubufs[0].numel(), (cudaStream_t)stream_main); - } else { - torch::Tensor reduce_buf = torch::from_blob( - reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); - torch::sum_out(rs_output, reduce_buf, 0); - } + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + B_type, fp8_type, + char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].data_ptr()); + if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { + assert(_ubuf_scale_inv_initialized); + float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); + char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); + reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size, + _ubufs[0].numel(), (cudaStream_t)stream_main); + } else { + torch::Tensor reduce_buf = torch::from_blob( + reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); + torch::sum_out(rs_output, reduce_buf, 0); + }); } /* @@ -1145,18 +1153,20 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); // Reduce GEMM output chunks - char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].data_ptr()); - if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); - reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, - _tp_size, _ubufs[0].numel(), (cudaStream_t)stream_main); - } else { - torch::Tensor reduce_buf = torch::from_blob( - reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); - torch::sum_out(rs_output, reduce_buf, 0); - } + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + B_type, fp8_type, + char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].data_ptr()); + if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { + assert(_ubuf_scale_inv_initialized); + float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); + char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); + reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size, + _ubufs[0].numel(), (cudaStream_t)stream_main); + } else { + torch::Tensor reduce_buf = torch::from_blob( + reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); + torch::sum_out(rs_output, reduce_buf, 0); + }); for (size_t i = 0; i < _stream_compute.size(); i++) { CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu index cd94835e68..b648561597 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu +++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu @@ -1890,11 +1890,18 @@ template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>( void *output, float *scale, const int handler, const int offset, const int rowelements, const int colelements, const int strideelements_out, const int strideelements_in, const int numchunks, void *counters, communicator *comm, cudaStream_t stream); - +template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e5m2>( + void *output, float *scale, const int handler, const int offset, const int rowelements, + const int colelements, const int strideelements_out, const int strideelements_in, + const int numchunks, void *counters, communicator *comm, cudaStream_t stream); template void reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e4m3>( void *output, float *scale, const int handler, const int offset, const int rowelements, const int colelements, const int strideelements_out, const int strideelements_in, const int numchunks, void *counters, communicator *comm, cudaStream_t stream); +template void reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e5m2>( + void *output, float *scale, const int handler, const int offset, const int rowelements, + const int colelements, const int strideelements_out, const int strideelements_in, + const int numchunks, void *counters, communicator *comm, cudaStream_t stream); __global__ void kuserbuffers_pullsend(int myrank, int peer, int *send_id, int *flagptr) { atomicAdd_system(flagptr, 1); diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 97f373343e..039df99260 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -233,7 +233,9 @@ def free_callback(data: torch.Tensor) -> None: wgrad_name = name.replace("dgrad", "wgrad") assert wgrad_name not in ub_cfgs layers_reduce_scatter_overlap.remove(wgrad_name) + layers_all_gather_overlap.remove(name) layers_reduce_scatter_overlap.append(name) + methods["pipeline"].append(name) for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]: if ub_cfgs is not None and name in ub_cfgs: diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 1a3c0fd4d5..ba975d2758 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -184,7 +184,7 @@ def forward( fp8_dtype_forward, out=ln_out_fp8, ) - ln_out = ln_out_fp8 + ln_out = torch.empty_like(ln_out_fp8) else: ln_out_total = tex.cast_to_fp8( ln_out_total, From c57a81f0788968fd68181af0a7052dc888f54cc7 Mon Sep 17 00:00:00 2001 From: Frank Lin Date: Thu, 18 Jul 2024 03:37:10 +0800 Subject: [PATCH 06/72] [Paddle] Compile with paddlepaddle-gpu 2.6.1 (#1021) fix 261 compile Signed-off-by: Frank Lin (Engrg-Hardware 1) Co-authored-by: Frank Lin (Engrg-Hardware 1) Co-authored-by: Kirthi Shankar Sivamani --- build_tools/paddle.py | 5 +++++ transformer_engine/paddle/csrc/custom_ops.cu | 20 +++++++++++++++++--- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/build_tools/paddle.py b/build_tools/paddle.py index 21a21e3a8a..163f094fce 100644 --- a/build_tools/paddle.py +++ b/build_tools/paddle.py @@ -9,6 +9,10 @@ from .utils import cuda_version +import paddle + +paddle_version = paddle.__version__.replace(".", "") + def setup_paddle_extension( csrc_source_files, @@ -45,6 +49,7 @@ def setup_paddle_extension( "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", "-U__CUDA_NO_BFLOAT162_OPERATORS__", "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + f"-DPADDLE_VERSION={paddle_version}", "--expt-relaxed-constexpr", "--expt-extended-lambda", "--use_fast_math", diff --git a/transformer_engine/paddle/csrc/custom_ops.cu b/transformer_engine/paddle/csrc/custom_ops.cu index 18d380abd1..3204574053 100644 --- a/transformer_engine/paddle/csrc/custom_ops.cu +++ b/transformer_engine/paddle/csrc/custom_ops.cu @@ -595,10 +595,12 @@ void UpdateRandomGenerator(phi::Place place, cudaStream_t stream, int rng_elts_p // extract random number generator seed and offset const phi::DeviceContext *dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + phi::Generator *gen_cuda = dev_ctx->GetGenerator(); auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread); - auto state_index = gen_cuda->GetStateIndex(); int64_t *rng_state_p = static_cast(rng_state.data()); +#if PADDLE_VERSION > 261 + auto state_index = gen_cuda->GetStateIndex(); auto parameterSetter = [gen_cuda, state_index, rng_elts_per_thread](phi::backends::gpu::CUDAKernelParams ¶ms) { @@ -618,6 +620,9 @@ void UpdateRandomGenerator(phi::Place place, cudaStream_t stream, int rng_elts_p }; phi::backends::gpu::CUDAGraphNodeLauncher::Instance().KernelNodeLaunch(parameterSetter, cudaKernelCallback); +#else + set_rng_state<<<1, 1, 0, stream>>>(0, seed_offset, rng_state_p); +#endif } void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor &cu_seqlens, @@ -1005,9 +1010,10 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(Q.place()); auto gen_cuda = dev_ctx->GetGenerator(); auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread); - auto state_index = gen_cuda->GetStateIndex(); - auto rng_state_p = static_cast(rng_state.data()); auto stream = Q.stream(); + auto rng_state_p = static_cast(rng_state.data()); +#if PADDLE_VERSION > 261 + auto state_index = gen_cuda->GetStateIndex(); auto parameterSetter = [gen_cuda, state_index, rng_elts_per_thread](phi::backends::gpu::CUDAKernelParams ¶ms) { // ensure the generator use correct state index @@ -1026,6 +1032,9 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p }; phi::backends::gpu::CUDAGraphNodeLauncher::Instance().KernelNodeLaunch(parameterSetter, cudaKernelCallback); +#else + set_rng_state<<<1, 1, 0, stream>>>(0, seed_offset, rng_state_p); +#endif auto te_rng_state = MakeNvteTensor(rng_state); @@ -1354,6 +1363,7 @@ void amax_and_scale_update_inplace_legacy(paddle::Tensor &amax_history, // NOLI bool update_weight_scale_inv, bool fwd_update, float fp8_max, float margin, const std::string &amax_compute) { +#if PADDLE_VERSION > 261 NVTE_CHECK(amax_compute == "max" || amax_compute == "most_recent"); paddle::Tensor amax; @@ -1401,6 +1411,10 @@ void amax_and_scale_update_inplace_legacy(paddle::Tensor &amax_history, // NOLI }; phi::backends::gpu::CUDAGraphNodeLauncher::Instance().KernelNodeLaunch(parameterSetter, cudaKernelCallback); +#else + NVTE_ERROR( + "amax_and_scale_update_inplace_legacy is not supported in old version of PaddlePaddle\n"); +#endif } void update_latest_amax_history_inplace(paddle::Tensor &history, // NOLINT From a6db82d99c200cd8344dd79f31b65529c5214b0f Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 18 Jul 2024 14:04:57 -0500 Subject: [PATCH 07/72] [C/PyTorch] Fixing incorrect use of TYPE_SWITCH_FP8_ONLY in GEMM + reduce-scatter overlap (#1023) * FP8 type switch macro now wraps only the FP8 kernel to avoid invalid type errors Signed-off-by: Alp Dener * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Alp Dener Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../pytorch/csrc/comm_gemm_overlap.h | 336 +++++++++--------- 1 file changed, 171 insertions(+), 165 deletions(-) diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h index 6612124b30..611de6ec77 100644 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -19,6 +19,7 @@ #include #include +#include "common/common.h" #include "common/util/logging.h" #include "common/util/system.h" #include "extensions.h" @@ -325,48 +326,51 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms, _num_splits /*m_split*/, 0 /*n_split*/, true /*gemm_producer*/, counter); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - B_type, fp8_type, for (int i = 0; i < _num_splits; i++) { - const char *env_p = std::getenv("NVTE_RS_STRIDED_ATOMIC"); - if (env_p != nullptr && env_p[0] == '1') { - if (i == _num_splits - 1) { - _ub_comm->sms = UB_MAX_SM; - } - if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); + for (int i = 0; i < _num_splits; i++) { + const char *env_p = std::getenv("NVTE_RS_STRIDED_ATOMIC"); + if (env_p != nullptr && env_p[0] == '1') { + if (i == _num_splits - 1) { + _ub_comm->sms = UB_MAX_SM; + } + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + B_type, fp8_type, reducescatter2_userbuff_strided_atomic_fp8( rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, m, - _num_splits, &counter_ptr[i], _ub_comm, (cudaStream_t)_stream_comm); - } else { - reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, - n, m, _num_splits, &counter_ptr[i], _ub_comm, - (cudaStream_t)_stream_comm); - } - } else if (env_p != nullptr && env_p[0] == '2') { - if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); + _num_splits, &counter_ptr[i], _ub_comm, (cudaStream_t)_stream_comm);); + } else { + reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, + _num_splits, &counter_ptr[i], _ub_comm, + (cudaStream_t)_stream_comm); + } + } else if (env_p != nullptr && env_p[0] == '2') { + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + B_type, fp8_type, reducescatter2_userbuff_strided_multiatomic_fp8( rs_output_ptr, d_scale_inv_ptr, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits, - counter_ptr, _ub_comm, (cudaStream_t)_stream_comm); - } else { - reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, - n, m, _num_splits, counter_ptr, _ub_comm, - (cudaStream_t)_stream_comm); - } - break; - } else { - consumer(counter_ptr, i, (cudaStream_t)_stream_comm); - // if (i == _num_splits-1) { - // _ub_comm->sms = UB_MAX_SM; - // } - reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, - _ub_comm, (cudaStream_t)_stream_comm); - } + counter_ptr, _ub_comm, (cudaStream_t)_stream_comm);); + } else { + reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, n, + m, _num_splits, counter_ptr, _ub_comm, + (cudaStream_t)_stream_comm); + } + break; + } else { + consumer(counter_ptr, i, (cudaStream_t)_stream_comm); + // if (i == _num_splits-1) { + // _ub_comm->sms = UB_MAX_SM; + // } + reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, + _ub_comm, (cudaStream_t)_stream_comm); + } - rs_output_ptr += m_chunk * rs_output.element_size(); - }); + rs_output_ptr += m_chunk * rs_output.element_size(); + } _ub_comm->sms = ori_sms; CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[0])); @@ -424,115 +428,117 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { assert(pre_gelu_out.numel() == 0); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - B_type, fp8_type, - if (gemm_overlap) { - torch::Tensor input_a_chunk = - torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); - torch::Tensor output_chunk = - torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[0]); - te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, - transb, output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, - grad, workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, - _math_sms); - - for (int i = 1; i < _num_splits; i++) { - input_a_chunk_ptr += input_a_chunk_size * B.element_size(); - output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); - - torch::Tensor input_a_chunk = - torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); - torch::Tensor output_chunk = - torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); - torch::Tensor workspace_chunk = torch::from_blob( - workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, - {workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); - te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, - transb, output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, - grad, workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, - _math_sms); - - CHECK_CUDA(cudaEventRecord( - _start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()])); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); - - // Communication chunk - if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); + if (gemm_overlap) { + torch::Tensor input_a_chunk = torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); + torch::Tensor output_chunk = + torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); + torch::Tensor workspace_chunk = + torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); + at::cuda::setCurrentCUDAStream(_stream_compute[0]); + te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, + output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, + workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms); + + for (int i = 1; i < _num_splits; i++) { + input_a_chunk_ptr += input_a_chunk_size * B.element_size(); + output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); + + torch::Tensor input_a_chunk = + torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); + torch::Tensor output_chunk = + torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); + torch::Tensor workspace_chunk = + torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, + {workspace_size_chunk}, workspace.options()); + at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); + te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, + output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, + workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, + _math_sms); + + CHECK_CUDA(cudaEventRecord( + _start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()])); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); + + // Communication chunk + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + B_type, fp8_type, reducescatter2_userbuff_stridedoutput_fp8( rs_output_ptr, d_scale_inv_ptr, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, - m, _ub_comm, (cudaStream_t)_stream_comm); - } else { - reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, - (i - 1) * output_chunk_size, m_chunk, n, m, - _ub_comm, (cudaStream_t)_stream_comm); - } - - rs_output_ptr += m_chunk * rs_output.element_size(); - } - int last_compute_stream_id = - (_num_splits + _stream_compute.size() - 1) % _stream_compute.size(); - CHECK_CUDA( - cudaEventRecord(_start_comm, (cudaStream_t)_stream_compute[last_compute_stream_id])); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); + m, _ub_comm, (cudaStream_t)_stream_comm);); + } else { + reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size, + m_chunk, n, m, _ub_comm, + (cudaStream_t)_stream_comm); + } - // Last communication chunk with max SM - _ub_comm->sms = UB_MAX_SM; - if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); + rs_output_ptr += m_chunk * rs_output.element_size(); + } + int last_compute_stream_id = + (_num_splits + _stream_compute.size() - 1) % _stream_compute.size(); + CHECK_CUDA( + cudaEventRecord(_start_comm, (cudaStream_t)_stream_compute[last_compute_stream_id])); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); + + // Last communication chunk with max SM + _ub_comm->sms = UB_MAX_SM; + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + B_type, fp8_type, reducescatter2_userbuff_stridedoutput_fp8( rs_output_ptr, d_scale_inv_ptr, _ub_reg, (_num_splits - 1) * output_chunk_size, - m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm); - } else { - reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, - (_num_splits - 1) * output_chunk_size, m_chunk, n, - m, _ub_comm, (cudaStream_t)_stream_comm); - } - } else { - for (int i = 0; i < _num_splits; i++) { - torch::Tensor input_a_chunk = - torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); - torch::Tensor output_chunk = - torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); - torch::Tensor workspace_chunk = torch::from_blob( - workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, - {workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); - te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, - transb, output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, - grad, workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, - _math_sms); - - CHECK_CUDA(cudaEventRecord(_start_comm, - (cudaStream_t)_stream_compute[i % _stream_compute.size()])); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); - - // Communication chunk. Uses MAX_SM at the last chunk - if (i == _num_splits - 1) { - _ub_comm->sms = UB_MAX_SM; - } - if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); + m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm);); + } else { + reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, + (_num_splits - 1) * output_chunk_size, m_chunk, n, m, + _ub_comm, (cudaStream_t)_stream_comm); + } + } else { + for (int i = 0; i < _num_splits; i++) { + torch::Tensor input_a_chunk = + torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); + torch::Tensor output_chunk = + torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); + torch::Tensor workspace_chunk = + torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, + {workspace_size_chunk}, workspace.options()); + at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); + te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, + output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, + workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, + _math_sms); + + CHECK_CUDA(cudaEventRecord(_start_comm, + (cudaStream_t)_stream_compute[i % _stream_compute.size()])); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); + + // Communication chunk. Uses MAX_SM at the last chunk + if (i == _num_splits - 1) { + _ub_comm->sms = UB_MAX_SM; + } + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + B_type, fp8_type, reducescatter2_userbuff_stridedoutput_fp8( rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * output_chunk_size, m_chunk, n, m, - _ub_comm, (cudaStream_t)_stream_comm); - } else { - reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size, - m_chunk, n, m, _ub_comm, - (cudaStream_t)_stream_comm); - } - rs_output_ptr += m_chunk * rs_output.element_size(); - input_a_chunk_ptr += input_a_chunk_size * B.element_size(); - output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); - } - }); + _ub_comm, (cudaStream_t)_stream_comm);); + } else { + reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size, + m_chunk, n, m, _ub_comm, + (cudaStream_t)_stream_comm); + } + rs_output_ptr += m_chunk * rs_output.element_size(); + input_a_chunk_ptr += input_a_chunk_size * B.element_size(); + output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); + } + } for (size_t i = 0; i < _stream_compute.size(); i++) { CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); @@ -1057,20 +1063,20 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); // Reduce GEMM output chunks - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - B_type, fp8_type, - char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].data_ptr()); - if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); + char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].data_ptr()); + if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { + assert(_ubuf_scale_inv_initialized); + float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); + char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + B_type, fp8_type, reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size, - _ubufs[0].numel(), (cudaStream_t)stream_main); - } else { - torch::Tensor reduce_buf = torch::from_blob( - reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); - torch::sum_out(rs_output, reduce_buf, 0); - }); + _ubufs[0].numel(), (cudaStream_t)stream_main);); + } else { + torch::Tensor reduce_buf = torch::from_blob( + reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); + torch::sum_out(rs_output, reduce_buf, 0); + } } /* @@ -1153,20 +1159,20 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); // Reduce GEMM output chunks - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - B_type, fp8_type, - char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].data_ptr()); - if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); + char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].data_ptr()); + if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { + assert(_ubuf_scale_inv_initialized); + float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); + char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + B_type, fp8_type, reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size, - _ubufs[0].numel(), (cudaStream_t)stream_main); - } else { - torch::Tensor reduce_buf = torch::from_blob( - reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); - torch::sum_out(rs_output, reduce_buf, 0); - }); + _ubufs[0].numel(), (cudaStream_t)stream_main);); + } else { + torch::Tensor reduce_buf = torch::from_blob( + reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); + torch::sum_out(rs_output, reduce_buf, 0); + } for (size_t i = 0; i < _stream_compute.size(); i++) { CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); From e1c6d218a354b29d5d5f76636cbc19b89769ced6 Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Fri, 19 Jul 2024 05:29:50 +0800 Subject: [PATCH 08/72] [Common] Use nvtx3 (#1025) Update nvtx header Signed-off-by: Reese Wang --- transformer_engine/common/nvtx.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/nvtx.h b/transformer_engine/common/nvtx.h index 191f3b06fa..4625e0ab9d 100644 --- a/transformer_engine/common/nvtx.h +++ b/transformer_engine/common/nvtx.h @@ -7,7 +7,7 @@ #ifndef TRANSFORMER_ENGINE_COMMON_NVTX_H_ #define TRANSFORMER_ENGINE_COMMON_NVTX_H_ -#include +#include #include From 238df4ce470a8fe0f9e88367e724a97a488e9bb7 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 19 Jul 2024 12:36:23 -0700 Subject: [PATCH 09/72] Initialize output tensors to 0 for THD (temporary) (#1009) * initialize output tensors to 0 for THD while waiting for cuDNN bug fix Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * move fill_() to F16 loop Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix fused_attn_bwd() Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * correct typo in check_set_window_size Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * use nvtx3 instead Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/paddle/cpp_extensions.py | 33 +++++++++++++++++-- transformer_engine/pytorch/attention.py | 2 +- .../pytorch/csrc/extensions/attention.cu | 30 +++++++++++++++++ 3 files changed, 61 insertions(+), 4 deletions(-) diff --git a/transformer_engine/paddle/cpp_extensions.py b/transformer_engine/paddle/cpp_extensions.py index e12c0dd3c4..cd57458c41 100644 --- a/transformer_engine/paddle/cpp_extensions.py +++ b/transformer_engine/paddle/cpp_extensions.py @@ -593,6 +593,9 @@ def fused_attn_fwd_qkvpacked( if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS + qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) + if qkv_format == "thd": + set_zero = True if set_zero: out = paddle.full(shape=[b, max_seqlen, h, d], fill_value=0, dtype=qkv.dtype) else: @@ -676,13 +679,19 @@ def fused_attn_bwd_qkvpacked( fused_attention_backend != FusedAttnBackend["No_Backend"] ), "Fused attention does not support this input combination." + qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) + if qkv_format == "thd": + set_zero = True if set_zero: dqkv = paddle.full(shape=qkv.shape, fill_value=0, dtype=qkv.dtype) else: dqkv = paddle.empty(shape=qkv.shape, dtype=qkv.dtype) if bias_type != "no_bias": - dbias = paddle.empty(shape=[1, h, max_seqlen, max_seqlen], dtype=qkv.dtype) + if qkv_format == "thd": + dbias = paddle.zero(shape=[1, h, max_seqlen, max_seqlen], dtype=qkv.dtype) + else: + dbias = paddle.empty(shape=[1, h, max_seqlen, max_seqlen], dtype=qkv.dtype) else: dbias = None # execute kernel @@ -772,6 +781,9 @@ def fused_attn_fwd_kvpacked( if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS + qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) + if qkv_format == "thd": + set_zero = True if set_zero: out = paddle.full(shape=[b, max_seqlen_q, h, d], fill_value=0, dtype=q.dtype) else: @@ -867,6 +879,9 @@ def fused_attn_bwd_kvpacked( fused_attention_backend != FusedAttnBackend["No_Backend"] ), "Fused attention does not support this input combination." + qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) + if qkv_format == "thd": + set_zero = True if set_zero: dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype) dkv = paddle.full(shape=kv.shape, fill_value=0, dtype=kv.dtype) @@ -874,7 +889,10 @@ def fused_attn_bwd_kvpacked( dq = paddle.empty(shape=q.shape, dtype=q.dtype) dkv = paddle.empty(shape=kv.shape, dtype=kv.dtype) if bias_type != "no_bias": - dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) + if qkv_format == "thd": + dbias = paddle.zero(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) + else: + dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) else: dbias = None # execute kernel @@ -970,6 +988,9 @@ def fused_attn_fwd( if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS + qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) + if qkv_format == "thd": + set_zero = True if set_zero: out = paddle.full(shape=[b, max_seqlen_q, h, d], fill_value=0, dtype=q.dtype) else: @@ -1065,6 +1086,9 @@ def fused_attn_bwd( fused_attention_backend != FusedAttnBackend["No_Backend"] ), "Fused attention does not support this input combination." + qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) + if qkv_format == "thd": + set_zero = True if set_zero: dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype) dk = paddle.full(shape=k.shape, fill_value=0, dtype=k.dtype) @@ -1074,7 +1098,10 @@ def fused_attn_bwd( dk = paddle.empty(shape=k.shape, dtype=k.dtype) dv = paddle.empty(shape=v.shape, dtype=v.dtype) if bias_type != "no_bias": - dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) + if qkv_format == "thd": + dbias = paddle.zero(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) + else: + dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) else: dbias = None # execute kernel diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index f71b469f2d..eda0c136d5 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -3068,7 +3068,7 @@ def check_set_window_size( warnings.warn( "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type ) - elif orig_window_size[0] < 0 or orig_window_size[0] < 0: + elif orig_window_size[0] < 0 or orig_window_size[1] < 0: assert False, ( "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type ) diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index af84054b4c..9f4612f240 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -127,6 +127,9 @@ std::vector fused_attn_fwd_qkvpacked( te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + O.fill_(0); + } // BF16 or FP16 te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, nullptr); @@ -288,6 +291,9 @@ std::vector fused_attn_bwd_qkvpacked( amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + dQKV.fill_(0); + } // BF16 or FP16 te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, nullptr); @@ -328,6 +334,9 @@ std::vector fused_attn_bwd_qkvpacked( options); te_dBias = makeTransformerEngineTensor(dBias); } + if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + dBias.fill_(0); + } } // create cu_seqlens tensorwrappers @@ -427,6 +436,9 @@ std::vector fused_attn_fwd_kvpacked( te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + O.fill_(0); + } // BF16 or FP16 te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); te_KV = @@ -614,6 +626,10 @@ std::vector fused_attn_bwd_kvpacked( amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + dQ.fill_(0); + dKV.fill_(0); + } // BF16 or FP16 te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); te_KV = @@ -684,6 +700,9 @@ std::vector fused_attn_bwd_kvpacked( options); te_dBias = makeTransformerEngineTensor(dBias); } + if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + dBias.fill_(0); + } } // create workspace @@ -774,6 +793,9 @@ std::vector fused_attn_fwd( te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + O.fill_(0); + } // BF16 or FP16 te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, nullptr); @@ -1037,6 +1059,11 @@ std::vector fused_attn_bwd( makeTransformerEngineTensor(dV.data_ptr(), v_shape, dqkv_type, amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + dQ.fill_(0); + dK.fill_(0); + dV.fill_(0); + } // BF16 or FP16 te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, nullptr); @@ -1109,6 +1136,9 @@ std::vector fused_attn_bwd( options); te_dBias = makeTransformerEngineTensor(dBias); } + if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + dBias.fill_(0); + } } // create workspace From 33a3d02f81c56e6f7b542c09bfa86657078d57fb Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Sat, 20 Jul 2024 19:05:41 -0700 Subject: [PATCH 10/72] [PyTorch] Update Sequential container to handle changes in module base class (#1028) * Update sequential container constructor to handle modules in plain dicts Signed-off-by: Tim Moon * Avoid initializing Sequential with dicts Signed-off-by: Tim Moon --------- Signed-off-by: Tim Moon --- transformer_engine/pytorch/ops/sequential.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/ops/sequential.py b/transformer_engine/pytorch/ops/sequential.py index 95499a9e80..cd3c104860 100644 --- a/transformer_engine/pytorch/ops/sequential.py +++ b/transformer_engine/pytorch/ops/sequential.py @@ -5,7 +5,6 @@ """Sequential container for fusible operations.""" from __future__ import annotations -from collections import OrderedDict from collections.abc import Iterable, Iterator from typing import Optional @@ -39,7 +38,7 @@ def __init__( self._module_groups = None # Add modules - if len(args) == 1 and isinstance(args[0], OrderedDict): + if len(args) == 1 and isinstance(args[0], dict): for key, module in args[0].items(): self.add_module(key, module) else: @@ -82,8 +81,9 @@ def __getitem__( ) -> Sequential | torch.nn.Module: keys = self._get_keys_by_idx(idx) if isinstance(idx, slice): - modules = OrderedDict((str(i), self._modules[key]) for i, key in enumerate(keys)) - return self.__class__(modules) + out = Sequential() + out.extend(self._modules[key] for key in keys) + return out return self._modules[keys[0]] def __setitem__(self, idx: int, module: torch.nn.Module) -> None: @@ -129,11 +129,12 @@ def pop(self, idx: slice | int) -> torch.nn.Module: del self[idx] return out - def __iadd__(self, other: Sequential) -> Sequential: - return self.extend(other) + def __iadd__(self, modules: Iterable[torch.nn.Modules]) -> Sequential: + return self.extend(modules) def __add__(self, modules: Iterable[torch.nn.Modules]) -> Sequential: - out = self.__class__(self._modules) + out = Sequential() + out.extend(self) out.extend(modules) return out From 931b44feb6139590bd356283e9e7b0e1e4ad3246 Mon Sep 17 00:00:00 2001 From: Selvaraj Anandaraj Date: Mon, 22 Jul 2024 10:53:35 -0700 Subject: [PATCH 11/72] Fixed convergence issues with CPU offloading (#1026) * Fixed convergence issues Signed-off-by: Selvaraj Anandaraj * Update transformer_engine/pytorch/module/layernorm_linear.py Signed-off-by: Kirthi Shankar Sivamani * Update transformer_engine/pytorch/module/layernorm_mlp.py Signed-off-by: Kirthi Shankar Sivamani * Update transformer_engine/pytorch/module/linear.py Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Selvaraj Anandaraj Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Selvaraj Anandaraj Co-authored-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/module/layernorm_linear.py | 4 +--- transformer_engine/pytorch/module/layernorm_mlp.py | 7 ++----- transformer_engine/pytorch/module/linear.py | 4 +--- 3 files changed, 4 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index ba975d2758..76969a4712 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -289,8 +289,6 @@ def forward( if is_grad_enabled: if cpu_offloading: - if fuse_wgrad_accumulation: - weight.main_grad.weight_offloading = True if fp8 and weight_fp8 is not None: weight_fp8.weight_offloading = True ln_weight.weight_offloading = True @@ -411,7 +409,7 @@ def backward( ) if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: - weight = torch.nn.Parameter(weight, False) + weight = torch.nn.Parameter(weight.requires_grad) weight.main_grad = main_grad if ctx.ub_overlap_rs_dgrad: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 8b971e186b..83dd2ebe03 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -425,9 +425,6 @@ def forward( if is_grad_enabled: if cpu_offloading: - if fuse_wgrad_accumulation: - fc1_weight.main_grad.weight_offloading = True - fc2_weight.main_grad.weight_offloading = True if fp8 and fc1_weight_fp8 is not None: fc1_weight_fp8.weight_offloading = True if fp8 and fc2_weight_fp8 is not None: @@ -570,8 +567,8 @@ def backward( ) if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: - fc1_weight = Parameter(fc1_weight, False) - fc2_weight = Parameter(fc2_weight, False) + fc1_weight = Parameter(fc1_weight.requires_grad) + fc2_weight = Parameter(fc2_weight.requires_grad) fc1_weight.main_grad = fc1_weight_main_grad fc2_weight.main_grad = fc2_weight_main_grad diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 745ee9b72e..a95fa1c33a 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -310,8 +310,6 @@ def forward( saved_inputmat = inputmat_no_fp8 if cpu_offloading: - if fuse_wgrad_accumulation: - weight.main_grad.weight_offloading = True if fp8 and weight_fp8 is not None: weight_fp8.weight_offloading = True weight.weight_offloading = True @@ -403,7 +401,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: - weight = torch.nn.Parameter(weight, False) + weight = torch.nn.Parameter(weight.requires_grad) weight.main_grad = main_grad tp_world_size = get_distributed_world_size(ctx.tp_group) From 71124c31fbf775d7e5355d7652ddf67e831b7095 Mon Sep 17 00:00:00 2001 From: Selvaraj Anandaraj Date: Tue, 23 Jul 2024 06:42:40 -0700 Subject: [PATCH 12/72] Remove unwanted Memory Copies/Fix weight parameters (#1034) * removed unwanted memcpyDtoD/fixed weight parametrisation Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Selvaraj Anandaraj Co-authored-by: Selvaraj Anandaraj Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/pytorch/cpu_offload.py | 68 +++---------------- .../pytorch/module/grouped_linear.py | 5 +- .../pytorch/module/layernorm_linear.py | 2 +- .../pytorch/module/layernorm_mlp.py | 4 +- transformer_engine/pytorch/module/linear.py | 2 +- 5 files changed, 16 insertions(+), 65 deletions(-) diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index b07c6d3508..b42d40d9f3 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -284,11 +284,8 @@ def __init__( debug=debug, ) self.num_prefetch_group = num_prefetch_group - - # prepare for tensor buffer - self.tensor_id_to_tensor_buf_double_bufs = [] - for _ in range(2): - self.tensor_id_to_tensor_buf_double_bufs.append({}) + # Data Structure to maintain reference to activation tensors + self.tensor_tag_to_buf = {} # allocate streams and events for synchronization self.d2h_stream = torch.cuda.Stream() @@ -300,37 +297,6 @@ def __init__( self.compute_stream_bwd_start_events.append(torch.cuda.Event()) self.d2h_final_event = torch.cuda.Event() - def get_tensor_buf_for_offloaded_tensor(self, tensor, tensor_tag): - """Get tensor buffer for offloaded tensor.""" - group_id, tensor_id = tensor_tag - # obtain ping-pong buffer - id_buf_map = self.tensor_id_to_tensor_buf_double_bufs[(group_id % 2)] - - if not tensor_id in id_buf_map: - allocate_new_buf = True - else: - tensor_buf = id_buf_map[tensor_id] - allocate_new_buf = ( - tensor_buf.size() != tensor.size() or tensor_buf.dtype != tensor.dtype - ) - - if allocate_new_buf: - # supposed to only execute once - fp8_offload = isinstance(tensor, Float8Tensor) - buffer = torch.empty( - tensor.size(), - dtype=torch.uint8 if fp8_offload else tensor.dtype, - layout=tensor.layout, - device=tensor.device, - ) - - if isinstance(tensor, Float8Tensor): - id_buf_map[tensor_id] = Float8Tensor.make_like(tensor, data=buffer) - else: - id_buf_map[tensor_id] = buffer - - return id_buf_map[tensor_id] - def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: torch_stray_tensor = isinstance( @@ -347,21 +313,12 @@ def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: self.tensor_count_current_group += 1 assert tensor_tag not in self.tensor_tag_to_state + self.tensor_tag_to_state[tensor_tag] = tensor + if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker( tensor ): - # first copy the tensor to tensorbuf, - # so that the original tensor will not be deleted - tensor_buf = self.get_tensor_buf_for_offloaded_tensor(tensor, tensor_tag) - tensor_buf.copy_(tensor) - if hasattr(tensor, "weight_offloading"): - tensor_buf.weight_offloading = True - if hasattr(tensor, "activation_offloading"): - tensor_buf.activation_offloading = True - # Here we just save it, and at commit, bulk_offload_group will handle it - self.tensor_tag_to_state[tensor_tag] = tensor_buf - else: - self.tensor_tag_to_state[tensor_tag] = tensor + self.tensor_tag_to_buf[tensor_tag] = tensor else: tensor_tag = (-1, self.torch_tensor_count) self.torch_tensor_count += 1 @@ -373,6 +330,7 @@ def tensor_pop(self, tensor_tag, **kwargs): """Tensor pop.""" assert tensor_tag in self.tensor_tag_to_state tensor = self.tensor_tag_to_state.pop(tensor_tag) + self.tensor_tag_to_buf.pop(tensor_tag, None) # the tensor should have been copied back in on_group_commit_backward() # which invokes bulk_reload_group. assert not isinstance(tensor, tuple) @@ -389,10 +347,6 @@ def bulk_offload_group(self, group_to_offload): # if offload, return the reference to cpu copy if self.tensor_need_offloading_checker(tensor_on_device): - if hasattr(tensor_on_device, "weight_offloading"): - delattr(tensor_on_device, "weight_offloading") - if hasattr(tensor_on_device, "activation_offloading"): - delattr(tensor_on_device, "activation_offloading") state = SynchronizedGroupOffloadHandler.offload(tensor_on_device) self.tensor_tag_to_state[tensor_tag] = state @@ -403,12 +357,12 @@ def synchronize_on_group_commit_forward(self, current_group): previous_group = current_group - 1 if previous_group < self.num_offload_group: torch.cuda.synchronize() - # TODO (guyueh): this part is originally designed to reduce the peak memory usage. # pylint: disable=fixme - # however, uncommenting this part will cause illegal access, have not figured out why. - if previous_group + 2 >= self.num_offload_group: - # this buffer is no longer required - self.tensor_id_to_tensor_buf_double_bufs[(previous_group % 2)] = {} + # Have to release the memory held by activations of the previous layer + if previous_group >= 0: + for tensor_tag, _ in self.tensor_tag_to_buf.items(): + if tensor_tag[0] == previous_group: + self.tensor_tag_to_buf[tensor_tag] = None # the copying of this group should wait for the computation stream event if current_group < self.num_offload_group: diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index e598f167fa..050ff6a02e 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -237,9 +237,6 @@ def forward( saved_inputmats = inputmats_no_fp8 if cpu_offloading: - if fuse_wgrad_accumulation: - for w in weights: - w.main_grad.weight_offloading = True if fp8: for w in weights_fp8: if w is not None: @@ -303,7 +300,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], main_grads = saved_tensors[4 * ctx.num_gemms :] if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: for i in ctx.num_gemms: - w = torch.nn.Parameter(weights[i], False) + w = torch.nn.Parameter(weights[i], weights[i].requires_grad) w.main_grad = main_grads[i] weights[i] = w diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 76969a4712..22d7813605 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -409,7 +409,7 @@ def backward( ) if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: - weight = torch.nn.Parameter(weight.requires_grad) + weight = torch.nn.Parameter(weight, weight.requires_grad) weight.main_grad = main_grad if ctx.ub_overlap_rs_dgrad: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 83dd2ebe03..5be8ee9e29 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -567,8 +567,8 @@ def backward( ) if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: - fc1_weight = Parameter(fc1_weight.requires_grad) - fc2_weight = Parameter(fc2_weight.requires_grad) + fc1_weight = Parameter(fc1_weight, fc1_weight.requires_grad) + fc2_weight = Parameter(fc2_weight, fc2_weight.requires_grad) fc1_weight.main_grad = fc1_weight_main_grad fc2_weight.main_grad = fc2_weight_main_grad diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index a95fa1c33a..7510254a9d 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -401,7 +401,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: - weight = torch.nn.Parameter(weight.requires_grad) + weight = torch.nn.Parameter(weight, weight.requires_grad) weight.main_grad = main_grad tp_world_size = get_distributed_world_size(ctx.tp_group) From 5ee98175788d2c3c3945980e0c12fb8dfc6ea94d Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Tue, 23 Jul 2024 17:02:07 -0500 Subject: [PATCH 13/72] [PyTorch] Fixing hang in `initialize_ub()` for multi-node runs after PR901 removal of MPI-dependence (#986) * Re-implementing PR901 (removing MPI-dependence in Userbuffers) with multi-node fixes * passing data-parallel rank/size info from torch.distributed to userbuffers Signed-off-by: Alp Dener * multi-node example working with UB_SKIPMC=1 but not with multicast Signed-off-by: Alp Dener * fixed multi-node hang in initialize_ub(), updated comm+GEMM overlap example to support multi-node mixed tensor/data parallelism, added README Signed-off-by: Alp Dener * fixed use case when Userbuffers is asked to allocate the TP overlap buffer with UB_SKIPMC=1 Signed-off-by: Alp Dener * corrected example problem to set device by local ordinal instead of global process rank Signed-off-by: Alp Dener * double-free fix in userbuffers destructor Signed-off-by: Alp Dener * removed unnecessary and incorrect torch.cuda.set_device(...) Signed-off-by: Alp Dener * corrected inter-node ranks logic Signed-off-by: Alp Dener * generalized node ID logic in initialize_ub to handle arbitrary world rank layouts within node Signed-off-by: Alp Dener * added single-node comm+GEMM overlap unit tests Signed-off-by: Alp Dener * LayerNormMLP example confirmed working with 2 nodes on Eos Signed-off-by: Alp Dener * unit test cleanup Signed-off-by: Alp Dener * corrected DP group ranks logic in LNMLP comm+GEMM overlap example Signed-off-by: Alp Dener * corrected enums in unit test Signed-off-by: Alp Dener * fixed incorrect Ubuf object init signature Signed-off-by: Alp Dener * switched default backend for Userbuffer bootstrapping to Gloo with MPI and NCCL fallbacks, and initialize_ub option to manually select backend Signed-off-by: Alp Dener * fixed all comm+GEMM overlap unit tests Signed-off-by: Alp Dener * corrected all_gather use for Gloo backend Signed-off-by: Alp Dener * changed userbuffers allgather callback to always use all_gather() instead of all_gather_into_tensor() Signed-off-by: Alp Dener * restored and verified old MPI-based bootstrapping via NVTE_UB_WITH_MPI=1 option at compile time Signed-off-by: Alp Dener * disabled scoped GIL release for comm+GEMM overlap algorithms Signed-off-by: Alp Dener * avoid dist.init_device_mesh in comm+GEMM overlap example to support older PyTorch versions Signed-off-by: Alp Dener * applied RS overlap FP8 fix from PR1004 Signed-off-by: Alp Dener * fixed segfault in Userbuffers destructor Signed-off-by: Alp Dener * corrected comm+GEMM overlap unit test arguments Signed-off-by: Alp Dener * fixed unit test run command for when Userbuffers is compiled with MPI Signed-off-by: Alp Dener * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Refactored torch.distributed collectives into pure C++ callbacks Signed-off-by: Alp Dener * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Alp Dener Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- build_tools/pytorch.py | 8 +- examples/pytorch/comm_gemm_overlap/README.md | 158 ++++ .../comm_gemm_overlap/ln_mlp_with_overlap.py | 253 ++++-- .../distributed/run_gemm_with_overlap.py | 810 ++++++++++++++++++ .../distributed/test_comm_gemm_overlap.py | 105 +++ .../pytorch/csrc/comm_gemm_overlap.h | 464 +++++----- .../pytorch/csrc/extensions/pybind.cpp | 22 +- .../pytorch/csrc/userbuffers/ipcsocket.cc | 150 ++-- .../pytorch/csrc/userbuffers/ipcsocket.h | 52 +- .../csrc/userbuffers/userbuffers-host.cpp | 322 +++---- .../pytorch/csrc/userbuffers/userbuffers.cu | 60 +- .../pytorch/csrc/userbuffers/userbuffers.h | 58 +- transformer_engine/pytorch/module/base.py | 206 +++-- 13 files changed, 2013 insertions(+), 655 deletions(-) create mode 100644 examples/pytorch/comm_gemm_overlap/README.md create mode 100644 tests/pytorch/distributed/run_gemm_with_overlap.py create mode 100644 tests/pytorch/distributed/test_comm_gemm_overlap.py diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index a704d40264..e423ffe907 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -77,14 +77,14 @@ def setup_pytorch_extension( # Libraries library_dirs = [] libraries = [] - if os.getenv("UB_MPI_BOOTSTRAP"): + if os.getenv("NVTE_UB_WITH_MPI"): assert ( os.getenv("MPI_HOME") is not None - ), "MPI_HOME must be set when compiling with UB_MPI_BOOTSTRAP=1" + ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1" mpi_home = Path(os.getenv("MPI_HOME")) include_dirs.append(mpi_home / "include") - cxx_flags.append("-DUB_MPI_BOOTSTRAP") - nvcc_flags.append("-DUB_MPI_BOOTSTRAP") + cxx_flags.append("-DNVTE_UB_WITH_MPI") + nvcc_flags.append("-DNVTE_UB_WITH_MPI") library_dirs.append(mpi_home / "lib") libraries.append("mpi") diff --git a/examples/pytorch/comm_gemm_overlap/README.md b/examples/pytorch/comm_gemm_overlap/README.md new file mode 100644 index 0000000000..bb3ba209ed --- /dev/null +++ b/examples/pytorch/comm_gemm_overlap/README.md @@ -0,0 +1,158 @@ +# Overlapping Communication with GEMM in TransformerEngine Modules + +## Requirements + +- Tensor-parallel GPUs must be on a single node, and connected over NVLink/NVSwitch. +- `CUDA_DEVICE_MAX_CONNECTIONS=1` must be enabled in the environment. +- For best performance, point-to-point communication via _CUDA Multicast_ needs CUDA Toolkit 12.0+ + and CUDA driver 535+ on devices with compute capability 9.0 or newer. +- Devices older than compute capability 9.0 require `UB_SKIPMC=1` in the environment in order fall + back on a less performant implementation based on CUDA Inter-Process Communication (IPC) handles. + +## Examples + +### Single node, tensor-parallel LayerNormMLP: + +Forward and backward passes with layer weights distributed over all GPUs in a single node. + +```bash +$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) ln_mlp_with_overlap.py + +# Sample output on 8x H100s: +# [rank0:node0] |-- Created tensor-parallel group: [0, 1, 2, 3, 4, 5, 6, 7] +# !!! [UB] Create UbufP2PCommOverlap Communicator +# UB_TIMEOUT is set to 110 sec, 217800000000 cycles, freq: 1980000khz +# MC initialized succesfully, window size = 549755813888 +# !!! [UBP2P] Register UBuf 1 +# !!! [UBP2P] Register UBuf 2 +# !!! [UBP2P] Register UBuf 3 +# !!! [UBP2P] Register UBuf 4 +# !!! [UB] Register UBuf 5 +# !!! [UBP2P] Register UBuf 6 +# !!! [UB] Register UBuf 7 +# !!! [UB] Register UBuf 8 +# !!! [UBP2P] Register UBuf 9 +# !!! [UB] Register UBuf 10 +# [rank0:node0] Iter 1 +# [rank0:node0] |-- Generate random input batch +# [rank0:node0] |-- Forward pass +# [rank0:node0] |-- Compute loss +# [rank0:node0] |-- Backward pass +# [rank0:node0] |-- Optimizer step +# [rank0:node0] Iter 2 +# [rank0:node0] |-- Generate random input batch +# [rank0:node0] |-- Forward pass +# [rank0:node0] |-- Compute loss +# [rank0:node0] |-- Backward pass +# [rank0:node0] |-- Optimizer step +# [rank0:node0] Iter 3 +# [rank0:node0] |-- Generate random input batch +# [rank0:node0] |-- Forward pass +# [rank0:node0] |-- Compute loss +# [rank0:node0] |-- Backward pass +# [rank0:node0] |-- Optimizer step +# [rank0:node0] Iter 4 +# [rank0:node0] |-- Generate random input batch +# [rank0:node0] |-- Forward pass +# [rank0:node0] |-- Compute loss +# [rank0:node0] |-- Backward pass +# [rank0:node0] |-- Optimizer step +# [rank0:node0] Iter 5 +# [rank0:node0] |-- Generate random input batch +# [rank0:node0] |-- Forward pass +# [rank0:node0] |-- Compute loss +# [rank0:node0] |-- Backward pass +# [rank0:node0] |-- Optimizer step +``` +### Single node, mixed data- and tensor-parallel LayerNormMLP: + +Uses `torch.nn.parallel.DistributedDataParallel` for replicatin the model across 2 tensor-parallel +groups in a single node. + +```bash +$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) ln_mlp_overlap.py --num-replicas 2 + +# Sample output on 8x H100s: +# [rank0:node0] |-- Created tensor-parallel group: [0, 1, 2, 3] +# [rank4:node1] |-- Created tensor-parallel group: [4, 5, 6, 7] +# [rank0:node0] |-- Created data-parallel group: [0, 4] +# [rank3:node1] |-- Created data-parallel group: [3, 7] +# [rank1:node1] |-- Created data-parallel group: [1, 5] +# [rank2:node0] |-- Created data-parallel group: [2, 6] +# !!! [UB] Create UbufP2PCommOverlap Communicator +# UB_TIMEOUT is set to 110 sec, 217800000000 cycles, freq: 1980000khz +# MC initialized succesfully, window size = 549755813888 +# !!! [UBP2P] Register UBuf 1 +# !!! [UBP2P] Register UBuf 2 +# !!! [UBP2P] Register UBuf 3 +# !!! [UBP2P] Register UBuf 4 +# !!! [UB] Register UBuf 5 +# !!! [UBP2P] Register UBuf 6 +# !!! [UB] Register UBuf 7 +# !!! [UB] Register UBuf 8 +# !!! [UBP2P] Register UBuf 9 +# !!! [UB] Register UBuf 10 +# [rank4:node1] Iter 1 +# [rank0:node0] Iter 1 +# [rank0:node0] |-- Generate random input batch +# [rank4:node1] |-- Generate random input batch +# [rank0:node0] |-- Forward pass +# [rank4:node1] |-- Forward pass +# [rank4:node1] |-- Compute loss +# [rank0:node0] |-- Compute loss +# [rank0:node0] |-- Backward pass +# [rank4:node1] |-- Backward pass +# [rank4:node1] |-- Optimizer step +# [rank0:node0] |-- Optimizer step +# [rank4:node1] Iter 2 +# [rank0:node0] Iter 2 +# [rank0:node0] |-- Generate random input batch +# [rank4:node1] |-- Generate random input batch +# [rank4:node1] |-- Forward pass +# [rank0:node0] |-- Forward pass +# [rank4:node1] |-- Compute loss +# [rank0:node0] |-- Compute loss +# [rank4:node1] |-- Backward pass +# [rank0:node0] |-- Backward pass +# [rank4:node1] |-- Optimizer step +# [rank0:node0] |-- Optimizer step +# [rank4:node1] Iter 3 +# [rank0:node0] Iter 3 +# [rank0:node0] |-- Generate random input batch +# [rank4:node1] |-- Generate random input batch +# [rank0:node0] |-- Forward pass +# [rank4:node1] |-- Forward pass +# [rank4:node1] |-- Compute loss +# [rank0:node0] |-- Compute loss +# [rank4:node1] |-- Backward pass +# [rank0:node0] |-- Backward pass +# [rank0:node0] |-- Optimizer step +# [rank4:node1] |-- Optimizer step +# [rank0:node0] Iter 4 +# [rank4:node1] Iter 4 +# [rank0:node0] |-- Generate random input batch +# [rank4:node1] |-- Generate random input batch +# [rank0:node0] |-- Forward pass +# [rank4:node1] |-- Forward pass +# [rank0:node0] |-- Compute loss +# [rank4:node1] |-- Compute loss +# [rank4:node1] |-- Backward pass +# [rank0:node0] |-- Backward pass +# [rank4:node1] |-- Optimizer step +# [rank0:node0] |-- Optimizer step +# [rank4:node1] Iter 5 +# [rank0:node0] Iter 5 +# [rank0:node0] |-- Generate random input batch +# [rank4:node1] |-- Generate random input batch +# [rank0:node0] |-- Forward pass +# [rank4:node1] |-- Forward pass +# [rank0:node0] |-- Compute loss +# [rank4:node1] |-- Compute loss +# [rank0:node0] |-- Backward pass +# [rank4:node1] |-- Backward pass +# [rank4:node1] |-- Optimizer step +# [rank0:node0] |-- Optimizer step +``` + +**NOTE:** To run with Fp8 compute on supporting hardware, add the `--fp8` flag to the commands +shown above. diff --git a/examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py b/examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py index 619dbaf9d7..412c948a83 100644 --- a/examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py +++ b/examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py @@ -6,17 +6,22 @@ import os import sys -import subprocess +import socket import argparse +import warnings import torch import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel import transformer_engine.pytorch as te from transformer_engine.common.recipe import Format, DelayedScaling +warnings.filterwarnings("ignore", category=FutureWarning) +warnings.filterwarnings("ignore", category=UserWarning) -def parse_args(argv=None, namespace=None): + +def _parse_args(argv=None, namespace=None): parser = argparse.ArgumentParser( description="Test a te.LayerNormMLP module with GEMM+comm overlap via Userbuffers." ) @@ -47,63 +52,182 @@ def parse_args(argv=None, namespace=None): default=False, help="Disable the comm+GEMM overlap.", ) - parser.add_argument("-v", "--verbose", action="store_true", default=False) - return parser.parse_args(argv, namespace) + parser.add_argument( + "--num-replicas", type=int, default=1, help="Number of data-parallel model replicas." + ) + parser.add_argument( + "--tcp-init", + action="store_true", + default=False, + help="Initialize torch.distributed with TcpStore.", + ) + parser.add_argument( + "--bind-to-device", + action="store_true", + default=False, + help="Initialize torch.distributed with `device_id` to bind each rank to a single device.", + ) + parser.add_argument( + "--bootstrap-backend", + type=str.lower, + default="nccl", + choices=["gloo", "mpi", "nccl"], + help="Communications backend for host tensor collectives during Userbuffers bootstrapping.", + ) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + default=False, + help="Print out from every rank instead of just the root rank of relevant process groups.", + ) + parser.add_argument( + "--debug", + action="store_true", + default=False, + help="Print out additional debug information.", + ) + args = parser.parse_args(argv, namespace) + return args -def train(opts): - WORLD_RANK = int(os.getenv("RANK")) - WORLD_SIZE = int(os.getenv("WORLD_SIZE")) +def _train(opts): + if "OMPI_COMM_WORLD_SIZE" in os.environ: + # Execution with `mpirun -np N` + WORLD_RANK = int(os.getenv("OMPI_COMM_WORLD_RANK", "0")) + WORLD_SIZE = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE", "1")) + opts.tcp_init = True + opts.bind_to_device = True + opts.bootstrap_backend = "mpi" + elif "TORCHELASTIC_RUN_ID" in os.environ: + WORLD_RANK = int(os.getenv("RANK", "0")) + WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + else: + raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!") + NUM_NODES = WORLD_SIZE // LOCAL_SIZE + + def dist_print(msg, group=None, end="\n", debug=False): + if debug and not opts.debug: + return + group = dist.new_group() if group is None else group + group_rank = dist.get_rank(group) + group_size = dist.get_world_size(group) + all_ranks = dist.get_process_group_ranks(group) + ranks_skip = all_ranks[1] - all_ranks[0] > 1 + group_id = WORLD_RANK % group_size if ranks_skip else WORLD_RANK // group_size + if group_rank == 0 or opts.verbose: + print(f"[rank{WORLD_RANK}:node{group_id}] {msg}{end}", end="", flush=True) + dist.barrier(group) + + # Initialize torch.distributed global process group and get DP/TP groups + torch.cuda.set_device(LOCAL_RANK) + dist_init_kwargs = { + "backend": "nccl", + "rank": WORLD_RANK, + "world_size": WORLD_SIZE, + } + if opts.tcp_init or NUM_NODES > 1: + if NUM_NODES > 1: + assert ( + "MASTER_ADDR" in os.environ + ), "Multi-node run requires MASTER_ADDR to be set in the environment." + MASTER_ADDR = os.getenv("MASTER_ADDR", socket.gethostbyname(socket.gethostname())) + MASTER_PORT = os.getenv("MASTER_PORT", "1234") + dist_init_kwargs["init_method"] = f"tcp://{MASTER_ADDR}:{MASTER_PORT}" + if opts.bind_to_device or opts.bootstrap_backend == "nccl": + dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") + assert dist.is_nccl_available() + dist.init_process_group(**dist_init_kwargs) + nccl_world = dist.new_group(backend="nccl") + dist_print(f"Initialized default NCCL process group with {WORLD_RANK} GPUs", nccl_world) + + # Figure out process groups for tensor- and data-parallelism (if any) + if NUM_NODES > 1: + # Create a list of world ranks on this node + hostnames = [None for _ in range(WORLD_SIZE)] + hostname = socket.gethostname() + dist.all_gather_object(hostnames, hostname) + node_ranks = [] + for i, host in enumerate(hostnames): + if host == hostname: + node_ranks.append(i) + + if opts.num_replicas > 1: + # Split node ranks into multiple replicas + assert len(node_ranks) % opts.num_replicas == 0 + tp_size = len(node_ranks) // opts.num_replicas + found_replica = False + for replica in range(opts.num_replicas): + start = replica * tp_size + end = start + tp_size + tp_ranks = node_ranks[start:end] + if WORLD_RANK in tp_ranks: + found_replica = True + break + assert found_replica + else: + # The entire node is the tensor-parallel group + tp_ranks = node_ranks + + tp_group = dist.new_group(backend="nccl", ranks=tp_ranks) + tp_size = dist.get_world_size(tp_group) + tp_rank = dist.get_rank(tp_group) + + # Data-parallelism across TP groups + dp_start = tp_rank + dp_end = dp_start + WORLD_SIZE + dp_ranks = list(range(dp_start, dp_end, tp_size)) + dp_group = dist.new_group(backend="nccl", ranks=dp_ranks) + + else: + if opts.num_replicas > 1: + # Mixed data- and tensor-parallelism on a single node + # NOTE: Avoid dist.init_device_mesh() to support older PyTorch versions + all_ranks = torch.tensor(list(range(LOCAL_SIZE)), dtype=torch.uint8, device="cpu") + mesh2d = all_ranks.reshape((opts.num_replicas, LOCAL_SIZE // opts.num_replicas)) + node_idx = (mesh2d == LOCAL_RANK).nonzero().squeeze().tolist() - def dist_print(msg, end="\n", all_ranks=False): - if WORLD_RANK == 0 or all_ranks: - print(f"[RANK-{WORLD_RANK}] {msg}", end=end) + tp_ranks = mesh2d[node_idx[0], :].tolist() + tp_group = dist.new_group(backend="nccl", ranks=tp_ranks) - # Seed RNG - torch.cuda.set_device(WORLD_RANK) - torch.manual_seed(opts.seed + WORLD_RANK) - torch.cuda.manual_seed(opts.seed + WORLD_RANK) + dp_ranks = mesh2d[:, node_idx[1]].tolist() + dp_group = dist.new_group(backend="nccl", ranks=dp_ranks) + else: + dp_group = None + tp_group = nccl_world - # Initialize torch.distributed global process group and get TP group - dist.init_process_group( - backend="nccl", - rank=WORLD_RANK, - world_size=WORLD_SIZE, - device_id=torch.device(f"cuda:{WORLD_RANK}"), + tp_rank = dist.get_rank(tp_group) + tp_size = dist.get_world_size(tp_group) + + dist_print( + f"Created tensor-parallel group: {dist.get_process_group_ranks(tp_group)}", + group=tp_group, ) - tp_group = dist.new_group(backend="nccl") - tp_size = dist.get_world_size(tp_group) + if dp_group is not None: + dist_print( + f"Created data-parallel group: {dist.get_process_group_ranks(dp_group)}", + group=dp_group, + ) # Intialize userbuffers - ag_cfg = { # Ring-exchange All-Gather overlap for fc1_fprop and fc2_dgrad - "method": "ring_exchange", - "num_splits": 8, - "num_sm": 1, - "set_sm_margin": False, - } - rs_cfg = { # Reduce-scatter overlap for fc1_dgrad and fc2_fprop - "method": "ring_exchange", - "num_splits": 4, - "num_sm": 1, - "set_sm_margin": True, - } hidden_size = opts.num_heads * opts.head_dim batched_size = opts.seq_length * opts.batch_size if not opts.no_comm_overlap: - te.initialize_ub( + te.module.base.initialize_ub( [batched_size, hidden_size], - tp_group, + tp_size, use_fp8=opts.fp8, dtype=torch.bfloat16, - ub_cfgs={ - "fc1_fprop": ag_cfg, - "fc1_dgrad": rs_cfg, - "fc2_fprop": rs_cfg, - "fc2_dgrad": ag_cfg, - }, + bootstrap_backend=opts.bootstrap_backend, ) - # + # Initialize the fused LayerNorm + Multi-layer Perceptron module + torch.manual_seed(opts.seed + tp_rank) + torch.cuda.manual_seed(opts.seed + tp_rank) model = te.LayerNormMLP( hidden_size, opts.mlp_expansion_factor * hidden_size, @@ -114,11 +238,14 @@ def dist_print(msg, end="\n", all_ranks=False): set_parallel_mode=True, sequence_parallel=True, # this is required for comm+GEMM overlap seq_length=opts.seq_length, - micro_batch_size=opts.batch_size, - ub_overlap_rs_dgrad=not opts.no_comm_overlap, ub_overlap_rs=not opts.no_comm_overlap, ub_overlap_ag=not opts.no_comm_overlap, + ub_overlap_rs_dgrad=not opts.no_comm_overlap, + ub_bulk_dgrad=False, + ub_bulk_wgrad=not opts.no_comm_overlap, ) + if dp_group is not None: + model = DistributedDataParallel(model, process_group=dp_group) # Initialize optimizer with model parameters optim = torch.optim.Adam(model.parameters(), lr=0.0001) @@ -128,10 +255,11 @@ def dist_print(msg, end="\n", all_ranks=False): fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") # Start dummy "training" iterations + dist_print("Starting training iterations...", nccl_world) for i in range(opts.num_iters): - dist_print(f"Iter {i+1}", all_ranks=opts.verbose) + dist_print(f" Iter {i+1}", tp_group, debug=True) - dist_print("|-- Generate random input batch", all_ranks=opts.verbose) + dist_print(" |-- Generate random input batch", tp_group, debug=True) x = torch.rand( (opts.seq_length // tp_size, opts.batch_size, hidden_size), dtype=torch.bfloat16, @@ -139,30 +267,29 @@ def dist_print(msg, end="\n", all_ranks=False): requires_grad=True, ) - dist_print("|-- Forward pass", all_ranks=opts.verbose) - with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=tp_group): + dist_print(" |-- Forward pass", tp_group, debug=True) + with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world): y = model(x) - dist_print("|-- Compute loss", all_ranks=opts.verbose) + dist_print(" |-- Compute loss", tp_group, debug=True) loss = y.flatten().sum() - dist_print("|-- Backward pass", all_ranks=opts.verbose) + dist_print(" |-- Backward pass", tp_group, debug=True) loss.backward() - dist_print("|-- Optimizer step", all_ranks=opts.verbose) + dist_print(" |-- Optimizer step", tp_group, debug=True) optim.step() - te.destroy_ub() + torch.cuda.synchronize() + dist_print("Finished training!") + te.module.base.destroy_ub() + + dist_print("Destroying all process groups...", debug=True) dist.destroy_process_group() + if opts.debug and WORLD_RANK == 0: + print("Exiting...\n", end="", flush=True) + + return 0 if __name__ == "__main__": - if "TORCHELASTIC_RUN_ID" in os.environ.keys(): - args = parse_args() - train(args) - else: - subprocess.run( - ["torchrun", f"--nproc-per-node={torch.cuda.device_count()}", *sys.argv], - env=os.environ, - check=True, - ) - os._exit(0) + sys.exit(_train(_parse_args())) diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py new file mode 100644 index 0000000000..d7dc3e1ce1 --- /dev/null +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -0,0 +1,810 @@ +#!/usr/bin/python3 + +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os +import sys +import socket +import warnings +import subprocess +import argparse +import operator +from functools import partial, reduce + +import torch +import torch.distributed as dist +from torch.distributed.elastic.multiprocessing.errors import record + +import transformer_engine.pytorch as te +import transformer_engine.pytorch.cpp_extensions as tex +from transformer_engine.common.recipe import Format +from transformer_engine.pytorch.fp8 import _default_sf_compute + +warnings.filterwarnings("ignore", category=FutureWarning) +warnings.filterwarnings("ignore", category=UserWarning) + +torch_dtypes = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + +nvte_comm_types = { + "rs": 0, + "ag": 1, +} + + +def _mapped_argtype(opt, typemap): + if str(opt).lower() not in typemap.keys(): + raise TypeError(f"Unrecognized option! Please choose from: {typemap.keys()}") + return typemap[str(opt).lower()] + + +def _parse_args(argv=None, namespace=None): + parser = argparse.ArgumentParser(description="Test comm+GEMM overlap with Userbuffers.") + parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.") + parser.add_argument("-s", "--seq-length", type=int, default=2048, help="Input sequence length.") + parser.add_argument( + "-n", "--num-heads", type=int, default=64, help="Number of attention heads." + ) + parser.add_argument( + "-d", "--head-dim", type=int, default=128, help="Dimension of each attention head." + ) + parser.add_argument("--seed", type=int, default=1234, help="RNG seed.") + parser.add_argument( + "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context." + ) + parser.add_argument( + "--p2p", action="store_true", default=False, help="Test overlap with P2P comms." + ) + parser.add_argument( + "--atomic", action="store_true", default=False, help="Test overlap with atomic GEMM." + ) + parser.add_argument( + "--aggregate", + action="store_true", + default=False, + help="Aggregate 2X chunks for P2P split pipelined all-gather.", + ) + parser.add_argument( + "--comm-type", + type=partial(_mapped_argtype, typemap=nvte_comm_types), + default=0, + help="Comm type to overlap.", + ) + parser.add_argument( + "--bulk-overlap", + action="store_true", + default=False, + help="Enable bulk AG or RS overlap for a tensor that is not involved in the GEMM compute.", + ) + parser.add_argument( + "--check-numerics", + action="store_true", + default=False, + help="Test numerical result against torch.matmul(...)", + ) + parser.add_argument( + "--warmup-iters", + type=int, + default=0, + help="Run some warmup iterations of the comm+GEMM overlap before " + "the timing runs.", + ) + parser.add_argument( + "--timing-iters", + type=int, + default=1, + help="Benchmark the comm+GEMM overlap as an average of many iterations.", + ) + parser.add_argument( + "--clock-speed", + type=int, + default=-1, + help="Set device clock speed to a fixed value via `nvidia-smi`.", + ) + parser.add_argument( + "--scale", type=float, default=1e-2, help="Set scaling factor for input and weight tensors." + ) + parser.add_argument( + "--tcp-init", + action="store_true", + default=False, + help="Initialize torch.distributed with TcpStore.", + ) + parser.add_argument( + "--init-method", type=str, default=None, help="Set the torch.distributed init method." + ) + parser.add_argument( + "--bind-to-device", + action="store_true", + default=False, + help=( + "Initialize torch.distributed with 'device_id' argument to bind each rank to 1 device." + ), + ) + parser.add_argument( + "--bootstrap-backend", + type=str.lower, + default="nccl", + choices=["gloo", "mpi", "nccl"], + help=( + "PyTorch distributed backend for host tensor collectives during comm+GEMM overlap " + + "initialization." + ), + ) + parser.add_argument( + "-v", "--verbose", action="store_true", default=False, help="Verbose info messages." + ) + opts = parser.parse_args(argv, namespace) + + if opts.bulk_overlap: + if opts.p2p: + warnings.warn("Point-2-point comms are not supported with bulk overlap.") + opts.p2p = False + if opts.atomic: + warnings.warn("Atomic GEMM is not supported with bulk overlap.") + opts.atomic = False + if opts.fp8: + warnings.warn("Bulk overlap is supported in FP8 but only tested in BF16.") + opts.fp8 = False + elif opts.comm_type == 1 and not opts.p2p: + warnings.warn("All-gather overlap is only supported with point-2-point comms.") + opts.p2p = True + + if opts.atomic: + if not te.fp8.check_fp8_support(): + assert not opts.fp8, "Atomic GEMM is only supported in FP8." + opts.fp8 = True + + return opts + + +@record +def _main(opts): + if "OMPI_COMM_WORLD_SIZE" in os.environ: + # Execution with `mpirun -np N` + WORLD_RANK = int(os.getenv("OMPI_COMM_WORLD_RANK", "0")) + WORLD_SIZE = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE", "1")) + opts.tcp_init = True + opts.bootstrap_backend = "mpi" + elif "TORCHELASTIC_RUN_ID" in os.environ: + WORLD_RANK = int(os.getenv("RANK", "0")) + WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + else: + raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!") + assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node + assert LOCAL_SIZE <= torch.cuda.device_count() + + # Fix clock speed + torch.cuda.set_device(LOCAL_RANK) + if opts.clock_speed > 0: + subprocess.run( + ["nvidia-smi", "-pm", "ENABLED", "-i", str(LOCAL_RANK)], + env=os.environ, + check=False, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + result = subprocess.run( + ["nvidia-smi", "-lgc", str(opts.clock_speed), "-i", str(LOCAL_RANK)], + env=os.environ, + check=False, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + msg = result.stdout.decode("utf-8").splitlines()[0] + print(f"[rank:{LOCAL_RANK}] {msg}\n", end="", flush=True) + + # Info printout + def dist_print(msg, src=None, info=False, section=False, group=None): + group = dist.new_group() if group is None else group + rank = dist.get_rank(group) + if info or opts.verbose: + if section: + if rank == (0 if src is None else src): + print("\n", end="", flush=True) + dist.barrier(group) + if src is None or rank == src: + prefix = "[GLOBAL] " if src is not None else f"[rank:{rank}] " + lines = msg.splitlines() + msg = "\n".join( + [prefix + lines[0]] + [(" " * len(prefix)) + line for line in lines[1:]] + ) + print(msg + "\n", end="", flush=True) + dist.barrier(group) + + # Initialize torch.distributed global process group and get TP group + dist_init_kwargs = { + "backend": "nccl", + "rank": WORLD_RANK, + "world_size": WORLD_SIZE, + } + if opts.tcp_init: + if opts.init_method is not None: + assert opts.init_method.startswith("tcp://") + init_method = opts.init_method + else: + MASTER_ADDR = os.getenv("MASTER_ADDR", socket.gethostbyname(socket.gethostname())) + MASTER_PORT = os.getenv("MASTER_PORT", "1234") + init_method = f"tcp://{MASTER_ADDR}:{MASTER_PORT}" + dist_init_kwargs["init_method"] = init_method + elif opts.init_method is not None: + assert ( + opts.init_method.startswith("env://") + or opts.init_method.startswith("file://") + or opts.init_method.startswith("tcp://") + ) + dist_init_kwargs["init_method"] = opts.init_method + if opts.bind_to_device or opts.bootstrap_backend == "nccl": + dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") + assert dist.is_nccl_available() + dist.init_process_group(**dist_init_kwargs) + tp_group = dist.new_group(backend="nccl") + tp_rank = dist.get_rank(tp_group) + tp_size = dist.get_world_size(tp_group) + dist_print( + f"Initialized default NCCL process group with {tp_size} GPUs", + src=0, + section=True, + info=True, + group=tp_group, + ) + + # Initialize backend used in bootstrapping Userbuffers + if opts.bootstrap_backend == "gloo": + assert dist.is_gloo_available() + elif opts.bootstrap_backend == "mpi": + assert dist.is_mpi_available() + bootstrap_pg = dist.new_group(backend=opts.bootstrap_backend) + dist_print( + f'Bootstrapping comm+GEMM overlap with backend="{opts.bootstrap_backend}"', + src=0, + section=True, + info=True, + group=bootstrap_pg, + ) + if WORLD_RANK == 0: + print("\n", end="", flush=True) + + ub_callbacks = ( + tex.UbufBootstrapCallbacks() + if tex.ubuf_built_with_mpi() + else tex.UbufBootstrapCallbacks(bootstrap_pg, bootstrap_pg) + ) + + if opts.comm_type == 0: + if opts.bulk_overlap: + ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_RS + elif opts.p2p: + ub_algo = ( + tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + if opts.atomic + else tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + ) + else: + ub_algo = ( + tex.UbufOverlapAlgo.ATOMIC_GEMM_RS + if opts.atomic + else tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + ) + elif opts.comm_type == 1: + if opts.bulk_overlap: + ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG + else: + ub_algo = ( + tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + if opts.atomic + else tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P + ) + else: + raise TypeError("Invalid comm+GEMM overlap type!") + + # Initialize userbuffers with (M, N) buffer + # M = sequence * batch + # N = hidden size + hidden_size = opts.num_heads * opts.head_dim + inp_shape = (opts.seq_length, opts.batch_size, hidden_size) + outer_size = reduce(operator.mul, inp_shape[:-1], 1) + ubuf_dtype = torch.uint8 if opts.fp8 and opts.comm_type == 1 else torch.bfloat16 + sample_buffer = torch.empty((outer_size, hidden_size), dtype=ubuf_dtype, device="cuda") + ub_obj = ub_obj = ( + tex.UbufP2PCommOverlap( + sample_buffer, # Sample userbuffer + WORLD_RANK, # World rank + WORLD_SIZE, # World size + LOCAL_RANK, # Rank within the node + LOCAL_SIZE, # Number of ranks/GPUs per node + 0, # Node ID + 1, # Number of nodes + tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) + 1, # Number of communication SMs + 1, # CGA cluster size + opts.comm_type == 0 or opts.atomic, # Set SM margin + opts.aggregate, # Aggregate 2X GEMM chunks + 3, # Max concurrent GEMM streams + opts.comm_type == 0, # overlap with reduce scatter + opts.atomic, # use a single GEMM with atomic-counters + True, # Use copy engine for P2P communications + ub_callbacks, + ) + if opts.p2p + else tex.UbufCommOverlap( + sample_buffer, # Sample userbuffer + WORLD_RANK, # World rank + WORLD_SIZE, # World size + LOCAL_RANK, # Rank within the node + LOCAL_SIZE, # Number of ranks/GPUs per node + 0, # Node ID + 1, # Number of nodes + tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) + 16, # Number of communication SMs + 2, # CGA cluster size + 4, # Number of communication splits + True, # Set SM margin + 3, # Max concurrent GEMM streams + opts.atomic, # uUe a single GEMM with atomic-counters + ub_callbacks, + ) + ) + + # Numerical check on AG + atomic GEMM requires testing an AG+RS pair + ub_obj2 = None + if opts.atomic and opts.comm_type == 1 and opts.check_numerics: + sample_buffer2 = torch.empty((outer_size, hidden_size), dtype=torch.bfloat16, device="cuda") + ub_obj2 = tex.UbufP2PCommOverlap( + sample_buffer2, # Sample userbuffer + WORLD_RANK, # World rank + WORLD_SIZE, # World size + LOCAL_RANK, # Rank within the node + LOCAL_SIZE, # Number of ranks/GPUs per node + 0, # Node ID + 1, # Number of nodes + tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) + 1, # Number of communication SMs + 1, # CGA cluster size + True, # Set SM margin + False, # Aggregate 2X GEMM chunks + 3, # Max concurrent GEMM streams + True, # overlap with reduce scatter + True, # use a single GEMM with atomic-counters + True, # use copy engine for P2P communications + ub_callbacks, + ) + + # Figure out problem sizing: + # M = sequence * batch + # N = hidden size + # K = MLP intermediate size (usually 4x hidden size) + # P = number of devices for sequence/tensor parallelism + # NOTE: TE-GEMM is set up to work with a transposed kernels and non-transposed inputs. + ffn_hidden_size = 4 * hidden_size + if opts.bulk_overlap: + # Bulk overlap weight and input tensors are not relevant so they're globally sized + local_kernel_t_shape = (ffn_hidden_size, hidden_size) + local_inp_shape = (outer_size, hidden_size) + # Bulk overlap comm tensor is distributed for AG overlap only + if opts.comm_type == 1: + bulk_inp_shape = (outer_size // tp_size, hidden_size) + else: + bulk_inp_shape = (outer_size, hidden_size) + else: + if opts.comm_type == 1: + # (M/P, N) -> overlapped AG -> (M, N) x (K/P, N)^T = (M, K/P) + local_kernel_t_shape = (ffn_hidden_size // tp_size, hidden_size) + local_inp_shape = (outer_size // tp_size, hidden_size) + if ub_obj2 is not None: + local_kernel2_t_shape = (hidden_size, ffn_hidden_size // tp_size) + else: + # (M, K/P) x (N, K/P)^T = (M, N) -> overlapped RS -> (M/P, N) + local_kernel_t_shape = (hidden_size, ffn_hidden_size // tp_size) + local_inp_shape = (outer_size, ffn_hidden_size // tp_size) + + # Initialize distributed input tensor and GEMM kernels + torch.manual_seed(opts.seed + tp_rank) + torch.cuda.manual_seed(opts.seed + tp_rank) + inp = torch.mul(torch.rand(local_inp_shape, dtype=torch.bfloat16, device="cuda"), opts.scale) + kernel_t = torch.mul( + torch.rand(local_kernel_t_shape, dtype=torch.bfloat16, device="cuda"), opts.scale + ) + if ub_obj2 is not None: + kernel2_t = torch.mul( + torch.rand(local_kernel2_t_shape, dtype=torch.bfloat16, device="cuda"), opts.scale + ) + + # Gather global tensors and calculate reference result (need these first for Fp8 scales) + if opts.bulk_overlap: + ker_g = torch.transpose(kernel_t, 0, 1) + inp_g = inp + bulk_inp = torch.mul( + torch.rand(bulk_inp_shape, dtype=torch.bfloat16, device="cuda"), opts.scale + ) + else: + if opts.comm_type == 1: + # AG Kernel: (K/P, N) -> gather -> (K, N) -> T -> (N, K) + ker_g = torch.transpose( + te.distributed.gather_along_first_dim(kernel_t, tp_group)[0], 0, 1 + ) + # AG Input: (M/P, N) -> gather -> (M, N) + inp_g = te.distributed.gather_along_first_dim(inp, tp_group)[0] + if ub_obj2 is not None: + ker2_g = te.distributed.gather_along_first_dim( + torch.transpose(kernel2_t, 0, 1), tp_group + )[0] + else: + # RS Kernel: (N, K/P) -> T -> (K/P, N) -> gather -> (K, N) + ker_g = te.distributed.gather_along_first_dim( + torch.transpose(kernel_t, 0, 1), tp_group + )[0] + # RS Input: (M, K/P) -> T -> (K/P, M) -> gather -> (K, M) -> T -> (M, K) + inp_g = torch.transpose( + te.distributed.gather_along_first_dim(torch.transpose(inp, 0, 1), tp_group)[0], 0, 1 + ) + + if opts.bulk_overlap: + if opts.comm_type == 1: + ref_g = te.distributed.gather_along_first_dim(bulk_inp, tp_group)[0] + else: + # First all-gather all the bulk inputs into a list + bulk_inp_list = [torch.zeros_like(bulk_inp) for _ in range(tp_size)] + dist.all_gather(bulk_inp_list, bulk_inp, tp_group) + # Sum the list together for final global result + ref_g = torch.stack(bulk_inp_list).sum(dim=0) + else: + ref_g = torch.matmul(inp_g, ker_g) + if ub_obj2 is not None: + inp2_g = torch.mul(ref_g, opts.scale) + ref2_g = torch.matmul(inp2_g, ker2_g) + + if opts.fp8: + fp8_formats = { + tex.DType.kFloat8E4M3: Format.E4M3, + tex.DType.kFloat8E5M2: Format.E5M2, + } + + # Structure to maintain amax and scale/scale_inv information for the kernel and input + fp8_dtype = tex.DType.kFloat8E4M3 + fp8_meta = tex.FP8TensorMeta() + num_gemms = 6 if ub_obj2 is not None else 3 + fp8_meta.amax_history = torch.zeros((2, num_gemms), dtype=torch.float, device="cuda") + fp8_meta.scale = torch.ones(num_gemms, dtype=torch.float, device="cuda") + fp8_meta.scale_inv = torch.ones(num_gemms, dtype=torch.float, device="cuda") + + # Compute initial amaxes and scales + inp_amax = torch.max(torch.abs(inp_g)) + fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_INPUT].copy_(inp_amax) + ker_amax = torch.max(torch.abs(ker_g)) + fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_WEIGHT].copy_(ker_amax) + ref_amax = torch.max(torch.abs(ref_g)) + fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_OUTPUT].copy_(ref_amax) + if ub_obj2 is not None: + inp2_amax = torch.max(torch.abs(inp2_g)) + fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_INPUT].copy_(inp2_amax) + ker2_amax = torch.max(torch.abs(ker2_g)) + fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_WEIGHT].copy_(ker2_amax) + ref2_amax = torch.max(torch.abs(ref2_g)) + fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_OUTPUT].copy_(ref2_amax) + fp8_meta.scale = _default_sf_compute( + fp8_meta.amax_history[1], fp8_meta.scale, fp8_formats[fp8_dtype].value.max_fwd, 1 + ) + fp8_meta.scale_inv = torch.reciprocal(fp8_meta.scale) + + # Cast input to Float8Tensor + inp_fp8 = tex.cast_to_fp8(inp, fp8_meta, tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype) + + # Cast kernel to Float8Tensor + kernel_t_fp8 = tex.cast_to_fp8( + kernel_t, fp8_meta, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype + ) + if ub_obj2 is not None: + kernel2_t_fp8 = tex.cast_to_fp8( + kernel2_t, fp8_meta, tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype + ) + + # Make sure the inputs are cast correctly + if opts.check_numerics: + torch.allclose( + inp.to(dtype=torch.float32), + inp_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_INPUT], + rtol=0.125, + atol=0.0675, + ) + torch.allclose( + kernel_t.to(dtype=torch.float32), + kernel_t_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_WEIGHT], + rtol=0.125, + atol=0.0675, + ) + if ub_obj2 is not None: + torch.allclose( + kernel2_t.to(dtype=torch.float32), + kernel2_t_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_WEIGHT], + rtol=0.125, + atol=0.0675, + ) + + # Set Fp8 scales for userbuffers + if opts.comm_type == 1: + ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_INPUT]) + if ub_obj2 is not None: + ub_obj2.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT]) + else: + ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_OUTPUT]) + + # Set up comm/compute buffers + ubuf_out2 = None + rs_out2 = None + if opts.comm_type == 1: + if opts.bulk_overlap: + ub_obj.copy_input_to_ubuf(bulk_inp, 1) + gemm_inp = inp + else: + ub_obj.copy_input_to_ubuf(inp_fp8 if opts.fp8 else inp, 1) + gemm_inp = ub_obj.get_ubuf_output(1) + ubuf_out = None + rs_out = None + if ub_obj2 is not None: + ubuf_out2 = ub_obj2.get_ubuf_output(1) + rs_out2 = torch.empty( + (outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda" + ) + else: + if opts.bulk_overlap: + ub_obj.copy_input_to_ubuf(bulk_inp, 0) + ubuf_out = None + else: + ubuf_out = ub_obj.get_ubuf_output(1) + gemm_inp = inp_fp8 if opts.fp8 else inp + rs_out = torch.empty( + (outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda" + ) + + # Trigger GEMM + total_iters = opts.warmup_iters + opts.timing_iters + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(total_iters)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(total_iters)] + torch.cuda.synchronize() + + if opts.fp8: + for i in range(total_iters): + start_events[i].record() + all_outputs = tex.fp8_gemm( + kernel_t_fp8, + fp8_meta.scale_inv, + tex.FP8FwdTensors.GEMM1_WEIGHT, + fp8_dtype, + gemm_inp, + fp8_meta.scale_inv, + tex.FP8FwdTensors.GEMM1_INPUT, + fp8_dtype, + torch.bfloat16, + te.module.base.get_workspace(), + bias=None, + use_bias=False, + gelu=False, + use_split_accumulator=te.module.base._2X_ACC_FPROP, + ub_algo=ub_algo, + ub=ub_obj, + extra_output_tensor=rs_out, + out=ubuf_out, + ) + end_events[i].record() + if ub_obj2 is not None: + gemm2_inp = tex.cast_to_fp8( + torch.mul(all_outputs[0], opts.scale), + fp8_meta, + tex.FP8FwdTensors.GEMM2_INPUT, + fp8_dtype, + ) + all_outputs = tex.fp8_gemm( + kernel2_t_fp8, + fp8_meta.scale_inv, + tex.FP8FwdTensors.GEMM2_WEIGHT, + fp8_dtype, + gemm2_inp, + fp8_meta.scale_inv, + tex.FP8FwdTensors.GEMM2_INPUT, + fp8_dtype, + torch.bfloat16, + te.module.base.get_workspace(), + bias=None, + use_bias=False, + gelu=False, + use_split_accumulator=te.module.base._2X_ACC_FPROP, + ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P, + ub=ub_obj2, + extra_output_tensor=rs_out2, + out=ubuf_out2, + ) + else: + for i in range(total_iters): + start_events[i].record() + all_outputs = tex.gemm( + kernel_t, + gemm_inp, + torch.bfloat16, + te.module.base.get_workspace(), + bias=None, + use_bias=False, + gelu=False, + ub_algo=ub_algo, + ub=ub_obj, + extra_output_tensor=rs_out, + out=ubuf_out, + ) + end_events[i].record() + + torch.cuda.synchronize() + gpu_times = [ + s.elapsed_time(e) + for s, e in zip(start_events[opts.warmup_iters :], end_events[opts.warmup_iters :]) + ] + + avg_gpu_time = sum(gpu_times) / opts.timing_iters + gemm_name = "".join( + [ + "p2p all-gather + " if opts.comm_type == 1 else "", + "atomic " if opts.atomic else "", + "GEMM", + (f" + {'p2p ' if opts.p2p else ''}reduce-scatter" if opts.comm_type == 0 else ""), + ] + ) + timing_info = ( + f"Avg. GPU time for {gemm_name}: {avg_gpu_time} ms " + + f"({opts.warmup_iters} warmup + {opts.timing_iters} timing runs)" + ) + dist_print(timing_info, section=True, info=True, group=tp_group) + + # Compare against standard GEMM + numerics_failed = False + if opts.check_numerics: + torch.cuda.synchronize() + dist.barrier(tp_group) + if opts.bulk_overlap: + output_info = "" + if opts.comm_type == 1: + # Bulk overlap AG output is already gathered + test_out = ub_obj.get_ubuf_output(1) + else: + # Bulk overlap RS output needs to be gathered + out_local = ub_obj.get_ubuf_output(0) + output_info += f"rs_output: {list(out_local.shape)} | " + test_out = te.distributed.gather_along_first_dim(out_local, tp_group)[0] + + ref_out = ref_g + output_info += f"output: {list(test_out.shape)} | reference: {list(ref_out.shape)}" + dist_print(output_info, src=0 if opts.comm_type == 0 else None, section=True) + + test_nonzeros = torch.count_nonzero(test_out) + ref_nonzeros = torch.count_nonzero(ref_out) + nonzero_info = ( + f"output nonzeros = {test_nonzeros} " + f"| reference count = {ref_nonzeros}" + ) + dist_print(nonzero_info, src=0, section=True, group=tp_group) + else: + if opts.comm_type == 1: + if ub_obj2 is not None: + # AG+RS Output: (M/P, N) -> gather -> (M, N) + output = rs_out2 + test_out = te.distributed.gather_along_first_dim(output, tp_group)[0] + else: + # AG Output: (M, K/P) -> T -> (K/P, M) -> gather -> (K, M) -> T -> (M, K) + output = all_outputs[0] + test_out = torch.transpose( + te.distributed.gather_along_first_dim( + torch.transpose(output, 0, 1), tp_group + )[0], + 0, + 1, + ) + else: + # RS Output: (M/P, N) -> gather -> (M, N) + output = rs_out + test_out = te.distributed.gather_along_first_dim(output, tp_group)[0] + + if opts.fp8: + dist_print("GEMM1 FP8 metas = [INPUT, WEIGHT, OUTPUT]", src=0, section=True) + fp8_meta_info = ( + f"amax_reference = {fp8_meta.amax_history[1][:3].tolist()}\n" + + f"amax_history = {fp8_meta.amax_history[0][:3].tolist()}\n" + + f"scale = {fp8_meta.scale[:3].tolist()}\n" + + f"scale_inv = {fp8_meta.scale_inv[:3].tolist()}" + ) + dist_print(fp8_meta_info, src=0, group=tp_group) + if ub_obj2 is not None: + dist_print("GEMM2 FP8 metas = [INPUT, WEIGHT, OUTPUT]", src=0, section=True) + fp8_meta_info = ( + f"amax_reference = {fp8_meta.amax_history[1][3:].tolist()}\n" + + f"amax_history = {fp8_meta.amax_history[0][3:].tolist()}\n" + + f"scale = {fp8_meta.scale[3:].tolist()}\n" + + f"scale_inv = {fp8_meta.scale_inv[3:].tolist()}" + ) + dist_print(fp8_meta_info, src=0, group=tp_group) + + ref_out = ref2_g if ub_obj2 is not None else ref_g + test_nonzeros = torch.count_nonzero(test_out) + ref_nonzeros = torch.count_nonzero(ref_out) + nonzero_info = ( + f"output nonzeros = {test_nonzeros} " + f"| reference count = {ref_nonzeros}" + ) + dist_print(nonzero_info, src=0, section=True, group=tp_group) + + sizing_info = ( + f"input: {list(inp.shape)} " + f"| GEMM1 weights: {list(kernel_t.shape)[::-1]} " + ) + if ub_obj2 is not None: + sizing_info += f"| GEMM2 weights: {list(kernel2_t.shape)[::-1]} " + sizing_info += f"| output: {list(output.shape)}\n" + dist_print(sizing_info, section=True, group=tp_group) + + sizing_info_g = ( + f"input: {list(inp_g.shape)} " + f"| GEMM1 weights: {list(ker_g.shape)} " + ) + if ub_obj2 is not None: + sizing_info_g += f"| GEMM2 weights: {list(ker2_g.shape)} " + sizing_info_g += ( + f"| output: {list(test_out.shape)} " + f"| reference: {list(ref_out.shape)}\n" + ) + dist_print(sizing_info_g, src=0, group=tp_group) + + torch.cuda.synchronize() + dist.barrier(tp_group) + test_out = test_out.to(dtype=torch.float32) + ref_out = ref_out.to(dtype=torch.float32) + error_below_tol = torch.allclose( + test_out, + ref_out, + rtol=0.125 if opts.fp8 else 0.02, + atol=0.0675 if opts.fp8 else 0.001, + ) + diff = torch.abs(test_out - ref_out).flatten() + m = torch.argmax(diff) + abs_err = diff[m].item() + rel_err = abs_err / (ref_out.flatten()[m].item() + 1e-5) + if not error_below_tol: + numerics_failed = True + numerics_info = ( + "NUMERICAL CHECK FAILED: " + + f"Outputs not close enough at index {m.item()} " + + f"with {test_out.flatten()[m].item()} vs {ref_out.flatten()[m].item()} " + + f"(abs error = {abs_err} | rel error = {rel_err})." + ) + else: + numerics_info = f"NUMERICAL CHECK PASSED: abs error = {abs_err} | rel error = {rel_err}" + + dist_print(numerics_info, src=0, section=True, info=True, group=tp_group) + + dist.barrier(tp_group) + if LOCAL_RANK == 0: + print("\n", end="", flush=True) + + dist.destroy_process_group() + + # Reset clock speeds + if opts.clock_speed > 0: + subprocess.run( + ["nvidia-smi", "-pm", "ENABLED", "-i", str(LOCAL_RANK)], + env=os.environ, + check=False, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + result = subprocess.run( + ["nvidia-smi", "-rgc", "-i", str(LOCAL_RANK)], + env=os.environ, + check=False, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + + return int(numerics_failed) + + +if __name__ == "__main__": + sys.exit(_main(_parse_args())) diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py new file mode 100644 index 0000000000..d0745aebf6 --- /dev/null +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -0,0 +1,105 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +import os +import subprocess +from pathlib import Path + +import pytest +import torch +import transformer_engine.pytorch.cpp_extensions as tex +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager + +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + +RNG_SEED: int = 1234 +SEQ_LENGTH: int = 2024 +BATCH_SIZE: int = 2 +NUM_HEADS: int = 64 +HEAD_DIM: int = 128 + +TEST_ROOT = Path(__file__).parent.resolve() +NUM_PROCS: int = min(torch.cuda.device_count(), 4) +LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] +if tex.ubuf_built_with_mpi(): + LAUNCH_CMD = ["mpirun", "-np", str(NUM_PROCS), "--oversubscribe", "--quiet", "python"] + +# Fall back on CUDA IPC if the platform does not support CUDA multicast +if not tex.device_supports_multicast(): + os.environ["UB_SKIPMC"] = "1" + +# Force GPU kernels to launch in the order they're executed by the host CPU +os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + + +@pytest.mark.skipif(NUM_PROCS < 2, reason="Comm+GEMM overlap requires at least 2 GPUs.") +@pytest.mark.parametrize( + "fp8,p2p,comm_type,aggregate,atomic,bulk", + [ + # FP8, P2P, Type, Aggregate, Atomic, Bulk + (False, True, "AG", False, False, False), + (False, True, "AG", True, False, False), + (True, True, "AG", False, False, False), + (True, True, "AG", True, False, False), + (False, False, "RS", False, False, False), + (False, True, "RS", False, False, False), + (True, False, "RS", False, False, False), + (True, True, "RS", False, False, False), + (True, False, "RS", False, True, False), + (True, True, "RS", False, True, False), + (False, False, "AG", False, False, True), + (False, False, "RS", False, False, True), + ], + ids=[ + " AG -> SPLIT GEMM | BF16 | RING-EXCHANGE ", + " AG -> SPLIT GEMM | BF16 | RING-EXCHANGE (2X AGGREGATED) ", + " AG -> SPLIT GEMM | FP8 | RING-EXCHANGE ", + " AG -> SPLIT GEMM | FP8 | RING-EXCHANGE (2X AGGREGATED) ", + " SPLIT GEMM -> RS | BF16 | PIPELINE ", + " SPLIT GEMM -> RS | BF16 | RING-EXCHANGE ", + " SPLIT GEMM -> RS | FP8 | PIPELINE ", + " SPLIT GEMM -> RS | FP8 | RING-EXCHANGE ", + " ATOMIC GEMM -> RS | FP8 | PIPELINE ", + " ATOMIC GEMM -> RS | FP8 | RING-EXCHANGE ", + " BULK AG & GEMM | BF16 | PIPELINE ", + " BULK RS & GEMM | BF16 | PIPELINE ", + ], +) +def test_gemm_with_overlap(fp8, p2p, comm_type, aggregate, atomic, bulk): + """ + Test comm+GEMM overlap algorithms with direct calls to + te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm + """ + test_path = TEST_ROOT / "run_gemm_with_overlap.py" + test_cmd = ( + LAUNCH_CMD + + [str(test_path)] + + [ + "--check-numerics", + f"--seed={RNG_SEED}", + f"--seq-length={SEQ_LENGTH}", + f"--batch-size={BATCH_SIZE}", + f"--num-heads={NUM_HEADS}", + f"--head-dim={HEAD_DIM}", + f"--comm-type={comm_type}", + ] + ) + + if bulk: + test_cmd.append("--bulk-overlap") + else: + if fp8: + if not fp8_available: + pytest.skip(reason_for_no_fp8) + test_cmd.append("--fp8") + if p2p: + test_cmd.append("--p2p") + if aggregate: + test_cmd.append("--aggregate") + if atomic: + if torch.cuda.get_device_properties(0).major < 9: + pytest.skip("Device compute capability 9.0 or higher required for Atomic GEMM.") + test_cmd.append("--atomic") + + output = subprocess.run(test_cmd, env=os.environ, text=True, capture_output=True, check=False) + assert "NUMERICAL CHECK PASSED" in str(output) diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h index 611de6ec77..0d70c9dc45 100644 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -19,7 +19,10 @@ #include #include +#include + #include "common/common.h" +#include "common/util/cuda_driver.h" #include "common/util/logging.h" #include "common/util/system.h" #include "extensions.h" @@ -28,76 +31,97 @@ #define HALF_BYTES 2 #define UB_MAX_SM 32 -#define CHECK_CUDA(call) \ - do { \ - cudaError_t status_ = call; \ - if (status_ != cudaSuccess) { \ - fprintf(stderr, "CUDA Error at line %d: %s\n", __LINE__, cudaGetErrorString(status_)); \ - exit(1); \ - } \ - } while (0) - using namespace torch::indexing; +using namespace std::placeholders; namespace ubuf { -/* -** Static container for Python callbacks to torch.distributed collectives -*/ -static struct TorchCallbacks : torch::CustomClassHolder { - bool initialized{false}; - std::unordered_map gathered_tensors; - std::function allgather; - std::function barrier; - std::function free; -} torch_callbacks; - -/* -** Helper function for setting Python callbacks to torch.distributed collectives. -*/ -void set_ubuf_bootstrap_callbacks( - std::function allgather, - std::function barrier, std::function free) { - torch_callbacks.allgather = allgather; - torch_callbacks.barrier = barrier; - torch_callbacks.free = free; - torch_callbacks.initialized = true; -} +bool device_supports_multicast() { + int dev, supports_multicast; + CUdevice cudev; -/* -** Python callback for globaldata = torch.distributed.all_gather(localdata, tp_group). -** This *creates* a new tensor, which Userbuffers later frees with a separate callback. -*/ -void ub_alloc_copy_allgather(void **globaldata, void *localdata, size_t localbytes, char *group) { - assert(torch_callbacks.initialized); - auto localtensor = - torch::from_blob(localdata, {static_cast(localbytes / sizeof(uint8_t))}, - at::device(torch::kCPU).dtype(torch::kUInt8)); - auto globaltensor = torch_callbacks.allgather(localtensor, group); - *globaldata = globaltensor.data_ptr(); - torch_callbacks.gathered_tensors[*globaldata] = globaltensor; -} + NVTE_CHECK_CUDA(cudaGetDevice(&dev)); + NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &cudev, dev); + NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &supports_multicast, + CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cudev); -/* -** Python callback for torch.distributed.barrier(tp_group). -*/ -void ub_barrier(char *group) { - assert(torch_callbacks.initialized); - torch_callbacks.barrier(group); + return static_cast(supports_multicast); } -/* -** Python callback for freeing up tensors created in the ub_alloc_copy_allgather(...) callback. -*/ -void ub_free(void *ptr) { - assert(torch_callbacks.initialized); - auto i = torch_callbacks.gathered_tensors.find(ptr); - if (i == torch_callbacks.gathered_tensors.end()) return; - auto tensor = std::move(i->second); - torch_callbacks.gathered_tensors.erase(i); - torch_callbacks.free(tensor); +bool ubuf_built_with_mpi() { +#ifdef NVTE_UB_WITH_MPI + return true; +#else + return false; +#endif } +class UbufBootstrapCallbacks : torch::CustomClassHolder { + private: + bool initialized{false}; + bool backend_is_nccl{false}; + std::map pgs; + + public: + UbufBootstrapCallbacks() { +#ifndef NVTE_UB_WITH_MPI + NVTE_ERROR("Internal TE error: Dummy UbufBootstrapCallbacks init without NVTE_UB_WITH_MPI=1!"); +#endif + }; // empty constructor for NVTE_UB_WITH_MPI=1 + + UbufBootstrapCallbacks(c10d::ProcessGroup *world_group, c10d::ProcessGroup *intra_node_group) { + pgs.insert({"world", world_group}); + c10d::ProcessGroup::BackendType backend = world_group->getBackendType(); + backend_is_nccl = (backend == c10d::ProcessGroup::BackendType::NCCL); + + NVTE_CHECK(intra_node_group->getBackendType() == backend, + "Internal TE error: Intra-node group must be on the same backend (%s) as the world ", + "group!", world_group->getBackendName()); + pgs.insert({"intra", intra_node_group}); + + initialized = true; + } + + ~UbufBootstrapCallbacks() { + for (auto &pg : pgs) pg.second = nullptr; + backend_is_nccl = false; + initialized = false; + } + + void ub_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes, + char *group) { + NVTE_CHECK(initialized, "Internal TE error: tex.UbufBootstrapCallbacks() is not initialized ", + "with valid process groups!"); + + auto localtensor = + torch::from_blob(localdata, {static_cast(localbytes / sizeof(uint8_t))}, + at::device(torch::kCPU).dtype(torch::kUInt8)); + auto localtmp = (backend_is_nccl) ? localtensor.cuda() : localtensor; + auto globaltensor = + torch::from_blob(globaldata, {static_cast(globalbytes / sizeof(uint8_t))}, + at::device(torch::kCPU).dtype(torch::kUInt8)); + auto globaltmp = (backend_is_nccl) ? globaltensor.cuda() : globaltensor; + + std::vector> globalchunks = {globaltmp.chunk(pgs[group]->getSize())}; + std::vector localchunk = {localtmp}; + auto work = pgs[group]->allgather(globalchunks, localchunk); + work->wait(); + + if (backend_is_nccl) { + globaltensor.copy_(globaltmp.cpu()); + globaltmp = torch::Tensor(); + localtmp = torch::Tensor(); + } + } + + void ub_barrier(char *group) { + NVTE_CHECK(initialized, "Internal TE error: tex.UbufBootstrapCallbacks() is not initialized ", + "with valid process groups!"); + auto work = pgs[group]->barrier(); + work->wait(); + } +}; + enum class COMM_TYPE { RS = 0, AG = 1 }; enum class UBOverlapAlgo { @@ -127,7 +151,6 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { torch::Tensor _ubuf_scale_inv; bool _ubuf_scale_inv_initialized; torch::Tensor counter; - torch::Tensor _empty_tensor; at::cuda::CUDAStream _stream_comm = at::cuda::getStreamFromPool(true); std::vector _stream_compute; cudaEvent_t _start_compute, _stop_compute, _start_d2dcopy, _start_comm, _stop_comm; @@ -136,36 +159,45 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { int _use_ce; bool _atomic_gemm; - UbufCommOverlap(torch::Tensor sample, int rank, int world_size, int tp_rank, int tp_size, - int num_comm_sm, int comm_cga_size, int num_splits, bool set_sm_margin, - int num_max_streams, bool atomic_gemm, torch::Tensor empty_tensor) { + UbufCommOverlap(torch::Tensor sample, int myrank, int numranks, int mylocal, int numlocal, + int mynode, int numnodes, int tp_size, int num_comm_sm, int comm_cga_size, + int num_splits, bool set_sm_margin, int num_max_streams, bool atomic_gemm, + UbufBootstrapCallbacks &callbacks) { // Initialize userbuf communicator if (!comm_created) { - if (rank == 0) { + if (myrank == 0) { printf("!!! [UB] Create UbufCommOverlap Communicator\n"); } - if (transformer_engine::getenv("UB_MPI_BOOTSTRAP")) { - create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); - } else { - create_communicator_grouped2(&_ub_comm, rank, world_size, tp_rank, tp_size, 1, 1, - &ub_alloc_copy_allgather, &ub_barrier, &ub_free, 1, 1, tp_size, - 1); - } +#ifdef NVTE_UB_WITH_MPI + create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); +#else + create_communicator_grouped2( + &_ub_comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, + std::bind(&UbufBootstrapCallbacks::ub_allgather, callbacks, _1, _2, _3, _4, _5), + std::bind(&UbufBootstrapCallbacks::ub_barrier, callbacks, _1), 1, 1, tp_size, 1); +#endif comm_created = true; } _use_ce = 0; _num_comm_sm = num_comm_sm; _cga_size = comm_cga_size; - _empty_tensor = empty_tensor; // Allocate and register extra userbuffers int ubuf_bytes = sample.numel() * sample.element_size(); - _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, - _ub_comm, true); - if (rank == 0) { + if (transformer_engine::getenv("UB_SKIPMC")) { + _ubuf = torch::zeros_like(sample); + _ubuf_ptr = _ubuf.data_ptr(); + _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, + _ub_comm, false); + } else { + _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, + _ub_comm, true); + _ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options()); + } + + if (_ub_comm->myrank == 0) { printf("!!! [UB] Register UBuf %d\n", _ub_reg); } - _ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options()); at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); for (int i = 0; i < std::min(num_max_streams, num_splits); i++) { @@ -177,7 +209,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { _num_splits = num_splits; _tp_size = tp_size; - _tp_id = (rank % tp_size); + _tp_id = (_ub_comm->myrank % _tp_size); _ubuf_scale_inv_initialized = false; // Set the number of SMs for GEMM with margin @@ -201,6 +233,25 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { cudaEventCreateWithFlags(&_stop_comm, 0); } + ~UbufCommOverlap() { + cudaEventDestroy(_stop_comm); + cudaEventDestroy(_start_comm); + cudaEventDestroy(_start_d2dcopy); + cudaEventDestroy(_stop_compute); + cudaEventDestroy(_start_compute); + + for (size_t i = 0; i < _stream_compute.size(); i++) cudaStreamDestroy(_stream_compute[i]); + + if (comm_created) { +#ifdef NVTE_UB_WITH_MPI + destroy_communicator_mpi(_ub_comm); +#else + destroy_communicator(_ub_comm); +#endif + comm_created = false; + } + } + /* ** Bulk GEMM + COMM ** This function assumes the communication input is pre-copied to _ubuf @@ -226,8 +277,8 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { // Catch up the default torch stream at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - CHECK_CUDA(cudaEventRecord(_start_comm, (cudaStream_t)stream_main)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, (cudaStream_t)stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); // Communication: AG and RS if (_comm_type == COMM_TYPE::AG) { @@ -261,8 +312,8 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { D_type, D_amax, bias, bias_type, pre_gelu_out, grad, workspace, workspaceSize, accumulate, use_split_accumulator, _math_sms); - CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0)); // Generate output tensor from userbuf data pointer int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; @@ -305,9 +356,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { // Catch up the default torch stream at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0)); if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; @@ -326,6 +377,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms, _num_splits /*m_split*/, 0 /*n_split*/, true /*gemm_producer*/, counter); + for (int i = 0; i < _num_splits; i++) { const char *env_p = std::getenv("NVTE_RS_STRIDED_ATOMIC"); if (env_p != nullptr && env_p[0] == '1') { @@ -373,10 +425,10 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { } _ub_comm->sms = ori_sms; - CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[0])); - CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[0])); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0)); at::cuda::setCurrentCUDAStream(stream_main); return; @@ -416,11 +468,11 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { // Catch up the default torch stream at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); for (size_t i = 0; i < _stream_compute.size(); i++) { - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0)); } - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0)); if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; @@ -456,9 +508,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms); - CHECK_CUDA(cudaEventRecord( + NVTE_CHECK_CUDA(cudaEventRecord( _start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()])); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); // Communication chunk if (_ubuf.element_size() == 1) { @@ -479,9 +531,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { } int last_compute_stream_id = (_num_splits + _stream_compute.size() - 1) % _stream_compute.size(); - CHECK_CUDA( + NVTE_CHECK_CUDA( cudaEventRecord(_start_comm, (cudaStream_t)_stream_compute[last_compute_stream_id])); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); // Last communication chunk with max SM _ub_comm->sms = UB_MAX_SM; @@ -513,9 +565,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms); - CHECK_CUDA(cudaEventRecord(_start_comm, - (cudaStream_t)_stream_compute[i % _stream_compute.size()])); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, + (cudaStream_t)_stream_compute[i % _stream_compute.size()])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); // Communication chunk. Uses MAX_SM at the last chunk if (i == _num_splits - 1) { @@ -540,12 +592,12 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { } } for (size_t i = 0; i < _stream_compute.size(); i++) { - CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); } _ub_comm->sms = ori_sms; - CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0)); at::cuda::setCurrentCUDAStream(stream_main); return; @@ -576,10 +628,11 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { } at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0)); - CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.data_ptr(), input.numel() * input.element_size(), - cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm)); + NVTE_CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.data_ptr(), + input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, + (cudaStream_t)_stream_comm)); } torch::Tensor &get_ubuf_output(int comm_type) { @@ -609,7 +662,6 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { void *_ubuf_ptr; torch::Tensor _ubuf; torch::Tensor counter; - torch::Tensor _empty_tensor; torch::Tensor _ubuf_scale_inv; bool _ubuf_scale_inv_initialized; std::vector _ubufs; @@ -622,29 +674,30 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { int _cga_size; bool _atomic_gemm; - UbufP2PCommOverlap(torch::Tensor sample, int rank, int world_size, int tp_rank, int tp_size, - int num_comm_sm, int comm_cga_size, bool set_sm_margin, bool aggregate2, - int num_max_streams, bool is_reduce_scatter, bool atomic_gemm, bool use_ce, - torch::Tensor empty_tensor) { + UbufP2PCommOverlap(torch::Tensor sample, int myrank, int numranks, int mylocal, int numlocal, + int mynode, int numnodes, int tp_size, int num_comm_sm, int comm_cga_size, + bool set_sm_margin, bool aggregate2, int num_max_streams, + bool is_reduce_scatter, bool atomic_gemm, bool use_ce, + UbufBootstrapCallbacks &callbacks) { // Initialize userbuf communicator if (!comm_created) { - if (rank == 0) { + if (myrank == 0) { printf("!!! [UB] Create UbufP2PCommOverlap Communicator\n"); } - if (transformer_engine::getenv("UB_MPI_BOOTSTRAP")) { - create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); - } else { - create_communicator_grouped2(&_ub_comm, rank, world_size, tp_rank, tp_size, 1, 1, - &ub_alloc_copy_allgather, &ub_barrier, &ub_free, 1, 1, tp_size, - 1); - } +#ifdef NVTE_UB_WITH_MPI + create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); +#else + create_communicator_grouped2( + &_ub_comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, + std::bind(&UbufBootstrapCallbacks::ub_allgather, callbacks, _1, _2, _3, _4, _5), + std::bind(&UbufBootstrapCallbacks::ub_barrier, callbacks, _1), 1, 1, tp_size, 1); +#endif comm_created = true; } _use_ce = use_ce; _num_comm_sm = num_comm_sm; _cga_size = comm_cga_size; - _empty_tensor = empty_tensor; // Create workspace tensor with userbuffer int ubuf_bytes = sample.numel() * sample.element_size(); int ubuf_chunk_bytes = ubuf_bytes / tp_size; @@ -655,15 +708,23 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ubuf_bytes = static_cast(ubuf_bytes / tp_size * (tp_size * 2 - 1)); num_ubuf_chunks = static_cast(tp_size * 2 - 1); } - _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, - _ub_comm, true); - if (rank == 0) { + if (transformer_engine::getenv("UB_SKIPMC")) { + _ubuf = torch::zeros({sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)}, + sample.options()); + _ubuf_ptr = _ubuf.data_ptr(); + _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, + _ub_comm, false); + } else { + _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, + _ub_comm, true); + _ubuf = + torch::from_blob(_ubuf_ptr, {sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)}, + sample.options()); + } + if (_ub_comm->myrank == 0) { printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg); } - _ubuf = torch::from_blob( - _ubuf_ptr, {sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)}, sample.options()); - // Create tensor chunks for easy management char *ubuf_byte_ptr = reinterpret_cast(_ubuf.data_ptr()); for (int i = 0; i < num_ubuf_chunks; i++) { @@ -690,23 +751,23 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { _tp_size = tp_size; _aggregate2 = aggregate2; - _rank = rank; - _tp_id = (rank % tp_size); - _rank_round_tp = (rank / tp_size) * tp_size; - _next_rank = (tp_size + rank + 1) % tp_size + _rank_round_tp; - _prev_rank = (tp_size + rank + -1) % tp_size + _rank_round_tp; + _rank = _ub_comm->myrank; + _tp_id = (_rank % _tp_size); + _rank_round_tp = (_rank / _tp_size) * _tp_size; + _next_rank = (_tp_size + _rank + 1) % _tp_size + _rank_round_tp; + _prev_rank = (_tp_size + _rank + -1) % _tp_size + _rank_round_tp; _ubuf_scale_inv_initialized = false; _atomic_gemm = atomic_gemm; _self_chunk_id = _tp_id; if (_atomic_gemm) { auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA); - counter = torch::zeros({tp_size * 2}, counter_options); - counter.index_put_({Slice(None, tp_size)}, 1); + counter = torch::zeros({_tp_size * 2}, counter_options); + counter.index_put_({Slice(None, _tp_size)}, 1); if (!is_reduce_scatter) { const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC"); - if (rank == 0 && env_p != nullptr) { + if (_rank == 0 && env_p != nullptr) { if (env_p[0] == '1') { printf("!!userbuffers_sendrecv_multi_atomic_shuffle\n"); } @@ -724,6 +785,25 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { cudaEventCreateWithFlags(&_stop_recv, 0); } + ~UbufP2PCommOverlap() { + cudaEventDestroy(_stop_recv); + cudaEventDestroy(_stop_send); + cudaEventDestroy(_start_comm); + cudaEventDestroy(_stop_compute); + cudaEventDestroy(_start_compute); + + for (size_t i = 0; i < _stream_compute.size(); i++) cudaStreamDestroy(_stream_compute[i]); + + if (comm_created) { +#ifdef NVTE_UB_WITH_MPI + destroy_communicator_mpi(_ub_comm); +#else + destroy_communicator(_ub_comm); +#endif + comm_created = false; + } + } + /* ** Split AllGather + AtomicGEMM using P2P communication ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is @@ -766,9 +846,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { // Catch up the default torch stream at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); torch::Tensor workspace_chunk = torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); @@ -809,12 +889,12 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { if (B_copy.numel() > 0) { assert(B_copy.numel() == _ubufs[_self_chunk_id].numel()); assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size()); - CHECK_CUDA( + NVTE_CHECK_CUDA( cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_self_chunk_id].data_ptr(), _ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(), cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); - CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); } // Reset atomic counters @@ -822,9 +902,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { // Copy the first GEMM output chunk to the end chunk position of D_buffer char *src_ptr = reinterpret_cast(D_buffer.data_ptr()); - CHECK_CUDA(cudaMemcpyAsync(src_ptr + (D.numel() * D.element_size()), src_ptr, - n_chunk * m * D.element_size(), cudaMemcpyDeviceToDevice, - (cudaStream_t)stream_main)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + (D.numel() * D.element_size()), src_ptr, + n_chunk * m * D.element_size(), cudaMemcpyDeviceToDevice, + (cudaStream_t)stream_main)); // Return the last N rows of D_buffer torch::Tensor D_return = D_buffer.narrow(0, n_chunk, n); return D_return; @@ -871,12 +951,12 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); for (size_t i = 0; i < _stream_compute.size(); i++) { - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0)); } if (_aggregate2) { const int num_steps = _tp_size / 2; @@ -892,9 +972,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { (cudaStream_t)_stream_send); userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank, (cudaStream_t)_stream_recv); - CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _stop_recv, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _stop_recv, 0)); int local_rank_round2 = (_tp_id % 2 == 0) ? _tp_id : _tp_id - 1; const int next_rank = (_tp_size + _tp_id + 2) % _tp_size + _rank_round_tp; @@ -931,16 +1011,16 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { next_rank, (cudaStream_t)_stream_send); userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm, prev_rank, (cudaStream_t)_stream_recv); - CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0)); - CHECK_CUDA(cudaStreamWaitEvent( + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent( (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); } else if (B_copy.numel() > 0) { assert(B_copy.numel() == _ubufs[_tp_id].numel()); assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); - CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(), - _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), - cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(), + _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), + cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); } } } else { @@ -976,27 +1056,27 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { _next_rank, (cudaStream_t)_stream_send); userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, _prev_rank, (cudaStream_t)_stream_recv); - CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0)); - CHECK_CUDA(cudaStreamWaitEvent( + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent( (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); } else if (B_copy.numel() > 0) { assert(B_copy.numel() == _ubufs[_tp_id].numel()); assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); - CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(), - _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), - cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(), + _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), + cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); } } } for (size_t i = 0; i < _stream_compute.size(); i++) { - CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); } - CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); - CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); at::cuda::setCurrentCUDAStream(stream_main); return D; @@ -1032,8 +1112,8 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { // Catch up the main stream at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); // Atomic GEMM // Process GEMM chunks in the order that AG+GEMM places the output chunks. @@ -1059,8 +1139,8 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank, (cudaStream_t)_stream_recv); } - CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); // Reduce GEMM output chunks char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].data_ptr()); @@ -1113,11 +1193,11 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { // Catch up the main stream at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); for (size_t i = 0; i < _stream_compute.size(); i++) { - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0)); } // GEMM and send/recv chunks @@ -1145,18 +1225,18 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { int recv_offset = comm_bytes * (i - 1 + _tp_size); int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp; int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; - CHECK_CUDA(cudaEventRecord( + NVTE_CHECK_CUDA(cudaEventRecord( _start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()])); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_comm, 0)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_comm, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_comm, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_comm, 0)); userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank, (cudaStream_t)_stream_send); userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank, (cudaStream_t)_stream_recv); } } - CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); // Reduce GEMM output chunks char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].data_ptr()); @@ -1174,11 +1254,11 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { torch::sum_out(rs_output, reduce_buf, 0); } for (size_t i = 0; i < _stream_compute.size(); i++) { - CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); } - CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); } /* @@ -1191,16 +1271,16 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { if (input.numel() != _ubufs[0].numel() || input.element_size() != _ubufs[0].element_size()) { NVTE_ERROR("input and ubuf size do not match!"); } - CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].data_ptr(), input.data_ptr(), - input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, - (cudaStream_t)stream_main)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].data_ptr(), input.data_ptr(), + input.numel() * input.element_size(), + cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main)); } else { if (input.numel() != _ubuf.numel() || input.element_size() != _ubuf.element_size()) { NVTE_ERROR("input and ubuf size do not match!"); } - CHECK_CUDA(cudaMemcpyAsync(_ubuf.data_ptr(), input.data_ptr(), - input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, - (cudaStream_t)stream_main)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.data_ptr(), input.data_ptr(), + input.numel() * input.element_size(), + cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main)); } } diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index d97dcc73f6..f568f4659d 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -206,11 +206,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv) .def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history); - // Communication functions to initialize Userbuffers communicators - // Note: Callbacks are not called, so safe to release GIL. - m.def("set_ubuf_bootstrap_callbacks", &ubuf::set_ubuf_bootstrap_callbacks, + m.def("device_supports_multicast", &ubuf::device_supports_multicast, py::call_guard()); + m.def("ubuf_built_with_mpi", &ubuf::ubuf_built_with_mpi, + py::call_guard()); + + py::class_(m, "UbufBootstrapCallbacks") + .def(py::init<>(), py::call_guard()) + .def(py::init(), + py::call_guard()); + py::enum_(m, "UbufOverlapAlgo") .value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG) .value("BULK_OVERLAP_RS", ubuf::UBOverlapAlgo::BULK_OVERLAP_RS) @@ -225,8 +231,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // communicator with Python functions (e.g. PyTorch distributed // communication) py::class_(m, "UbufCommOverlap") - .def(py::init()) + .def(py::init(), + py::call_guard()) .def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap, py::call_guard()) .def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs, @@ -250,8 +257,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // communicator with Python functions (e.g. PyTorch distributed // communication) py::class_(m, "UbufP2PCommOverlap") - .def(py::init()) + .def(py::init(), + py::call_guard()) .def("split_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_ag, py::call_guard()) .def("split_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_rs, diff --git a/transformer_engine/pytorch/csrc/userbuffers/ipcsocket.cc b/transformer_engine/pytorch/csrc/userbuffers/ipcsocket.cc index c80709a7e7..2fc6ffbdf9 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/ipcsocket.cc +++ b/transformer_engine/pytorch/csrc/userbuffers/ipcsocket.cc @@ -7,66 +7,82 @@ #include "ipcsocket.h" #include +#include #include #include -#define WARN(...) \ - {} -#define TRACE(...) \ - {} -#define SYSCHECK(...) \ - {} -#define EQCHECK(...) \ - {} +#define IPC_MAX_MSGLEN 4096 -// Enable Linux abstract socket naming -#define USE_ABSTRACT_SOCKET +void ipc_warn(const char *format, ...) { + char buffer[IPC_MAX_MSGLEN]; -#define NCCL_IPC_SOCKNAME_STR "/tmp/nccl-socket-%d-%lx" + va_list args; + va_start(args, format); + + vsnprintf(buffer, IPC_MAX_MSGLEN - 1, format, args); + snprintf(buffer + strlen(buffer), IPC_MAX_MSGLEN - strlen(buffer) - 1, " : %s (%d)\n", + strerror(errno), errno); + + fflush(stdout); + fputs(buffer, stderr); + fflush(NULL); + + va_end(args); +} + +static const char *ipcSocketResultStrings[static_cast(ipcSocketNumResults)] = { + "Success", "Unhandled CUDA error", "System error", "Internal error", + "Invalid argument", "Invalid usage", "Remote error", "In progress", +}; + +const char *ipcSocketGetErrorString(ipcSocketResult_t res) { + return ipcSocketResultStrings[static_cast(res)]; +} + +#define USE_ABSTRACT_SOCKET // Enable Linux abstract socket naming + +#define IPC_SOCKNAME_STR "/tmp/ub-ipc-socket-%d-%lx" /* * Create a Unix Domain Socket */ -ncclResult_t ncclIpcSocketInit(ncclIpcSocket *handle, int rank, uint64_t hash, - volatile uint32_t *abortFlag) { +ipcSocketResult_t ipcSocketInit(IpcSocketHandle *handle, int rank, uint64_t hash, + volatile uint32_t *abortFlag) { int fd = -1; struct sockaddr_un cliaddr; - char temp[NCCL_IPC_SOCKNAME_LEN] = ""; + char temp[IPC_SOCKNAME_LEN] = ""; if (handle == NULL) { - return ncclInternalError; + return ipcSocketInternalError; } handle->fd = -1; handle->socketName[0] = '\0'; if ((fd = socket(AF_UNIX, SOCK_DGRAM, 0)) < 0) { - WARN("UDS: Socket creation error : %s (%d)", strerror(errno), errno); - return ncclSystemError; + ipc_warn("UDS: Socket creation error"); + return ipcSocketSystemError; } bzero(&cliaddr, sizeof(cliaddr)); cliaddr.sun_family = AF_UNIX; // Create unique name for the socket. - size_t len = snprintf(temp, NCCL_IPC_SOCKNAME_LEN, NCCL_IPC_SOCKNAME_STR, rank, hash); + size_t len = snprintf(temp, IPC_SOCKNAME_LEN, IPC_SOCKNAME_STR, rank, hash); if (len > (sizeof(cliaddr.sun_path) - 1)) { - WARN("UDS: Cannot bind provided name to socket. Name too large"); - return ncclInternalError; + errno = ENAMETOOLONG; + ipc_warn("UDS: Cannot bind provided name to socket. Name too large"); + return ipcSocketInternalError; } -#ifndef USE_ABSTRACT_SOCKET - unlink(temp); -#endif - - TRACE(NCCL_INIT, "UDS: Creating socket %s", temp); - strncpy(cliaddr.sun_path, temp, len); #ifdef USE_ABSTRACT_SOCKET cliaddr.sun_path[0] = '\0'; // Linux abstract socket trick +#else + unlink(temp); #endif if (bind(fd, (struct sockaddr *)&cliaddr, sizeof(cliaddr)) < 0) { - WARN("UDS: Binding to socket %s failed : %s (%d)", temp, strerror(errno), errno); + ipc_warn("UDS: Binding to socket %s failed", temp); close(fd); - return ncclSystemError; + return ipcSocketSystemError; } handle->fd = fd; @@ -79,24 +95,25 @@ ncclResult_t ncclIpcSocketInit(ncclIpcSocket *handle, int rank, uint64_t hash, fcntl(fd, F_SETFL, flags | O_NONBLOCK); } - return ncclSuccess; + return ipcSocketSuccess; } -ncclResult_t ncclIpcSocketGetFd(struct ncclIpcSocket *handle, int *fd) { +ipcSocketResult_t ipcSocketGetFd(struct IpcSocketHandle *handle, int *fd) { if (handle == NULL) { - WARN("ncclSocketGetFd: pass NULL socket"); - return ncclInvalidArgument; + errno = EINVAL; + ipc_warn("ipcSocketSocketGetFd: pass NULL socket"); + return ipcSocketInvalidArgument; } if (fd) *fd = handle->fd; - return ncclSuccess; + return ipcSocketSuccess; } -ncclResult_t ncclIpcSocketClose(ncclIpcSocket *handle) { +ipcSocketResult_t ipcSocketClose(IpcSocketHandle *handle) { if (handle == NULL) { - return ncclInternalError; + return ipcSocketInternalError; } if (handle->fd <= 0) { - return ncclSuccess; + return ipcSocketSuccess; } #ifndef USE_ABSTRACT_SOCKET if (handle->socketName[0] != '\0') { @@ -105,10 +122,10 @@ ncclResult_t ncclIpcSocketClose(ncclIpcSocket *handle) { #endif close(handle->fd); - return ncclSuccess; + return ipcSocketSuccess; } -ncclResult_t ncclIpcSocketRecvMsg(ncclIpcSocket *handle, void *hdr, int hdrLen, int *recvFd) { +ipcSocketResult_t ipcSocketRecvMsg(IpcSocketHandle *handle, void *hdr, int hdrLen, int *recvFd) { struct msghdr msg = {0, 0, 0, 0, 0, 0, 0}; struct iovec iov[1]; @@ -138,39 +155,44 @@ ncclResult_t ncclIpcSocketRecvMsg(ncclIpcSocket *handle, void *hdr, int hdrLen, while ((ret = recvmsg(handle->fd, &msg, 0)) <= 0) { if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) { - WARN("UDS: Receiving data over socket failed : %d", errno); - return ncclSystemError; + ipc_warn("UDS: Receiving data over socket failed"); + return ipcSocketSystemError; } - if (handle->abortFlag && *handle->abortFlag) return ncclInternalError; + if (handle->abortFlag && *handle->abortFlag) return ipcSocketInternalError; } if (recvFd != NULL) { if (((cmptr = CMSG_FIRSTHDR(&msg)) != NULL) && (cmptr->cmsg_len == CMSG_LEN(sizeof(int)))) { if ((cmptr->cmsg_level != SOL_SOCKET) || (cmptr->cmsg_type != SCM_RIGHTS)) { - WARN("UDS: Receiving data over socket failed"); - return ncclSystemError; + errno = EBADMSG; + ipc_warn("UDS: Receiving data over socket %s failed", handle->socketName); + return ipcSocketSystemError; } memmove(recvFd, CMSG_DATA(cmptr), sizeof(*recvFd)); } else { - WARN("UDS: Receiving data over socket %s failed", handle->socketName); - return ncclSystemError; + errno = ENOMSG; + ipc_warn("UDS: Receiving data over socket %s failed", handle->socketName); + return ipcSocketSystemError; } - TRACE(NCCL_INIT | NCCL_P2P, "UDS: Got recvFd %d from socket %s", *recvFd, handle->socketName); + } else { + errno = EINVAL; + ipc_warn("UDS: File descriptor pointer cannot be NULL"); + return ipcSocketInvalidArgument; } - return ncclSuccess; + return ipcSocketSuccess; } -ncclResult_t ncclIpcSocketRecvFd(ncclIpcSocket *handle, int *recvFd) { - return ncclIpcSocketRecvMsg(handle, NULL, 0, recvFd); +ipcSocketResult_t ipcSocketRecvFd(IpcSocketHandle *handle, int *recvFd) { + return ipcSocketRecvMsg(handle, NULL, 0, recvFd); } -ncclResult_t ncclIpcSocketSendMsg(ncclIpcSocket *handle, void *hdr, int hdrLen, const int sendFd, - int rank, uint64_t hash) { +ipcSocketResult_t ipcSocketSendMsg(IpcSocketHandle *handle, void *hdr, int hdrLen, const int sendFd, + int rank, uint64_t hash) { struct msghdr msg = {0, 0, 0, 0, 0, 0, 0}; struct iovec iov[1]; - char temp[NCCL_IPC_SOCKNAME_LEN]; + char temp[IPC_SOCKNAME_LEN]; union { struct cmsghdr cm; @@ -185,10 +207,11 @@ ncclResult_t ncclIpcSocketSendMsg(ncclIpcSocket *handle, void *hdr, int hdrLen, bzero(&cliaddr, sizeof(cliaddr)); cliaddr.sun_family = AF_UNIX; - size_t len = snprintf(temp, NCCL_IPC_SOCKNAME_LEN, NCCL_IPC_SOCKNAME_STR, rank, hash); + size_t len = snprintf(temp, IPC_SOCKNAME_LEN, IPC_SOCKNAME_STR, rank, hash); if (len > (sizeof(cliaddr.sun_path) - 1)) { - WARN("UDS: Cannot connect to provided name for socket. Name too large"); - return ncclInternalError; + errno = ENAMETOOLONG; + ipc_warn("UDS: Cannot connect to provided name for socket. Name too large"); + return ipcSocketInternalError; } (void)strncpy(cliaddr.sun_path, temp, len); @@ -196,11 +219,7 @@ ncclResult_t ncclIpcSocketSendMsg(ncclIpcSocket *handle, void *hdr, int hdrLen, cliaddr.sun_path[0] = '\0'; // Linux abstract socket trick #endif - TRACE(NCCL_INIT, "UDS: Sending hdr %p len %d to UDS socket %s", hdr, hdrLen, temp); - if (sendFd != -1) { - TRACE(NCCL_INIT, "UDS: Sending fd %d to UDS socket %s", sendFd, temp); - msg.msg_control = control_un.control; msg.msg_controllen = sizeof(control_un.control); @@ -228,15 +247,16 @@ ncclResult_t ncclIpcSocketSendMsg(ncclIpcSocket *handle, void *hdr, int hdrLen, ssize_t sendResult; while ((sendResult = sendmsg(handle->fd, &msg, 0)) < 0) { if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) { - WARN("UDS: Sending data over socket %s failed : %s (%d)", temp, strerror(errno), errno); - return ncclSystemError; + ipc_warn("UDS: Sending data over socket %s failed", temp); + return ipcSocketSystemError; } - if (handle->abortFlag && *handle->abortFlag) return ncclInternalError; + if (handle->abortFlag && *handle->abortFlag) return ipcSocketInternalError; } - return ncclSuccess; + return ipcSocketSuccess; } -ncclResult_t ncclIpcSocketSendFd(ncclIpcSocket *handle, const int sendFd, int rank, uint64_t hash) { - return ncclIpcSocketSendMsg(handle, NULL, 0, sendFd, rank, hash); +ipcSocketResult_t ipcSocketSendFd(IpcSocketHandle *handle, const int sendFd, int rank, + uint64_t hash) { + return ipcSocketSendMsg(handle, NULL, 0, sendFd, rank, hash); } diff --git a/transformer_engine/pytorch/csrc/userbuffers/ipcsocket.h b/transformer_engine/pytorch/csrc/userbuffers/ipcsocket.h index cc1e45febf..979df384a8 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/ipcsocket.h +++ b/transformer_engine/pytorch/csrc/userbuffers/ipcsocket.h @@ -4,10 +4,9 @@ * See LICENSE for license information. ************************************************************************/ -#ifndef NCCL_IPCSOCKET_H -#define NCCL_IPCSOCKET_H +#ifndef TRANSFORMER_ENGINE_USERBUFFERS_IPCSOCKET_H +#define TRANSFORMER_ENGINE_USERBUFFERS_IPCSOCKET_H -// #include "nccl.h" #include #include #include @@ -21,32 +20,33 @@ #include typedef enum { - ncclSuccess = 0, - ncclUnhandledCudaError = 1, - ncclSystemError = 2, - ncclInternalError = 3, - ncclInvalidArgument = 4, - ncclInvalidUsage = 5, - ncclRemoteError = 6, - ncclInProgress = 7, - ncclNumResults = 8 -} ncclResult_t; - -#define NCCL_IPC_SOCKNAME_LEN 64 - -struct ncclIpcSocket { + ipcSocketSuccess = 0, + ipcSocketUnhandledCudaError = 1, + ipcSocketSystemError = 2, + ipcSocketInternalError = 3, + ipcSocketInvalidArgument = 4, + ipcSocketInvalidUsage = 5, + ipcSocketRemoteError = 6, + ipcSocketInProgress = 7, + ipcSocketNumResults = 8 +} ipcSocketResult_t; + +const char *ipcSocketGetErrorString(ipcSocketResult_t res); + +#define IPC_SOCKNAME_LEN 64 + +struct IpcSocketHandle { int fd; - char socketName[NCCL_IPC_SOCKNAME_LEN]; + char socketName[IPC_SOCKNAME_LEN]; volatile uint32_t *abortFlag; }; -ncclResult_t ncclIpcSocketInit(struct ncclIpcSocket *handle, int rank, uint64_t hash, - volatile uint32_t *abortFlag); -ncclResult_t ncclIpcSocketClose(struct ncclIpcSocket *handle); -ncclResult_t ncclIpcSocketGetFd(struct ncclIpcSocket *handle, int *fd); +ipcSocketResult_t ipcSocketInit(IpcSocketHandle *handle, int rank, uint64_t hash, + volatile uint32_t *abortFlag); +ipcSocketResult_t ipcSocketClose(IpcSocketHandle *handle); +ipcSocketResult_t ipcSocketGetFd(IpcSocketHandle *handle, int *fd); -ncclResult_t ncclIpcSocketRecvFd(struct ncclIpcSocket *handle, int *fd); -ncclResult_t ncclIpcSocketSendFd(struct ncclIpcSocket *handle, const int fd, int rank, - uint64_t hash); +ipcSocketResult_t ipcSocketRecvFd(IpcSocketHandle *handle, int *fd); +ipcSocketResult_t ipcSocketSendFd(IpcSocketHandle *handle, const int fd, int rank, uint64_t hash); -#endif /* NCCL_IPCSOCKET_H */ +#endif /* TRANSFORMER_ENGINE_USERBUFFERS_IPCSOCKET_H */ diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp b/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp index 60ae6198ee..982da28d33 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp @@ -19,15 +19,52 @@ #include #include -#include "../util/cuda_driver.h" +#include "common/util/cuda_driver.h" +#include "common/util/logging.h" #include "ipcsocket.h" #include "userbuffers.h" -#ifdef UB_MPI_BOOTSTRAP -#include +#ifdef NVTE_UB_WITH_MPI static MPI_Comm EXT_COMM_WORLD = MPI_COMM_WORLD; static MPI_Comm EXT_COMM_INTRA; static MPI_Comm EXT_COMM_INTER; + +#define UB_MPI_CHECK(expr) \ + do { \ + const int mpicode = (expr); \ + if (mpicode != MPI_SUCCESS) { \ + char mpimsg[MPI_MAX_ERROR_STRING]; \ + int mpilen; \ + MPI_Error_string(mpicode, mpimsg, &mpilen); \ + std::vector errmsg(1024); \ + snprintf(errmsg.data(), errmsg.size(), "%s:%d in function %s: %s", __FILE__, __LINE__, \ + __func__, mpimsg); \ + throw std::runtime_error(errmsg.data()); \ + } \ + } while (false) + +void ub_mpi_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes, + ExtComm group) { + // UB_MPI_CHECK(MPI_Allgather(localdata, localbytes, MPI_BYTE, + // globaldata, globalbytes, MPI_BYTE, + // static_cast(group))); + MPI_Comm comm = static_cast(group); + int numranks; + UB_MPI_CHECK(MPI_Comm_size(comm, &numranks)); + assert(globalbytes == numranks * localbytes); + + int myrank; + UB_MPI_CHECK(MPI_Comm_rank(comm, &myrank)); + char *globaltarget = reinterpret_cast(globaldata) + (myrank * localbytes); + memcpy(globaltarget, localdata, localbytes); + + for (int n = 0; n < numranks; n++) { + globaltarget = reinterpret_cast(globaldata) + (n * localbytes); + UB_MPI_CHECK(MPI_Bcast(globaltarget, localbytes, MPI_BYTE, n, comm)); + } +} + +void ub_mpi_barrier(ExtComm group) { UB_MPI_CHECK(MPI_Barrier(static_cast(group))); } #else static char EXT_COMM_WORLD[] = "world"; static char EXT_COMM_INTRA[] = "intra"; @@ -38,35 +75,21 @@ static char EXT_COMM_INTER[] = "inter"; int stringCmp(const void *a, const void *b) { return strcmp((const char *)a, (const char *)b); } -#define CUDACHECK(cmd) \ - do { \ - cudaError_t e = cmd; \ - if (e != cudaSuccess) { \ - printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ - exit(EXIT_FAILURE); \ - } \ +#define IPCCHECK(cmd) \ + do { \ + ipcSocketResult_t r = cmd; \ + if (r != ipcSocketSuccess) { \ + printf("Failed, UDS error %s:%d '%s'\n", __FILE__, __LINE__, ipcSocketGetErrorString(r)); \ + exit(EXIT_FAILURE); \ + } \ } while (0) -#define NVTE_UB_ERROR(x) \ - do { \ - throw std::runtime_error(std::string(__FILE__ ":") + std::to_string(__LINE__) + \ - " in function " + __func__ + ": " + x); \ - } while (false) -#define NCCLCHECK(cmd) \ - do { \ - ncclResult_t r = cmd; \ - if (r != ncclSuccess) { \ - printf("Failed, NCCL error %s:%d ''\n", __FILE__, __LINE__ /*,ncclGetErrorString(r)*/); \ - exit(EXIT_FAILURE); \ - } \ - } while (0) - -#define NCCLCHECKGOTO(call, RES, label) \ - do { \ - RES = call; \ - if (RES != ncclSuccess && RES != ncclInProgress) { \ - goto label; \ - } \ +#define IPCCHECKGOTO(call, RES, label) \ + do { \ + RES = call; \ + if (RES != ipcSocketSuccess && RES != ipcSocketInProgress) { \ + goto label; \ + } \ } while (0); int pipe_rank(communicator *comm, int step) { @@ -85,15 +108,14 @@ int pipe_rank(communicator *comm, int step) { int create_communicator_grouped2( communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, - int numnodes, std::function ext_alloc_copy_allgather, - std::function ext_barrier, std::function ext_free, int pipegpus, - int pipenodes, int tensorgpus, int tensornodes) { + int numnodes, std::function ext_allgather, + std::function ext_barrier, int pipegpus, int pipenodes, int tensorgpus, + int tensornodes) { *comm = new communicator(); (*comm)->comm_world = EXT_COMM_WORLD; - (*comm)->_alloc_copy_allgather = ext_alloc_copy_allgather; + (*comm)->_allgather = ext_allgather; (*comm)->_barrier = ext_barrier; - (*comm)->_free = ext_free; (*comm)->nranks = numranks; (*comm)->myrank = myrank; (*comm)->free_region = 0; @@ -101,9 +123,9 @@ int create_communicator_grouped2( int cur_dev, ndev; cudaDeviceProp device_prop; - CUDACHECK(cudaGetDevice(&cur_dev)); - CUDACHECK(cudaGetDeviceCount(&ndev)); - CUDACHECK(cudaGetDeviceProperties(&device_prop, cur_dev)); + NVTE_CHECK_CUDA(cudaGetDevice(&cur_dev)); + NVTE_CHECK_CUDA(cudaGetDeviceCount(&ndev)); + NVTE_CHECK_CUDA(cudaGetDeviceProperties(&device_prop, cur_dev)); (*comm)->sm_arch = device_prop.major; // (*comm)->use_rr_kernel = device_prop.major == 8; (*comm)->use_rr_kernel = 0; @@ -119,7 +141,7 @@ int create_communicator_grouped2( int device_clock = 0; // 110 sec wait time by default int sec_timeout = getenv("UB_TIMEOUT") ? atoi(getenv("UB_TIMEOUT")) : 110; - CUDACHECK(cudaDeviceGetAttribute(&device_clock, cudaDevAttrClockRate, cur_dev)); + NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&device_clock, cudaDevAttrClockRate, cur_dev)); (*comm)->ub_timeout = 1000ull * device_clock * sec_timeout; if ((*comm)->myrank == 0) { printf("UB_TIMEOUT is set to %d sec, %" PRIu64 " cycles, freq: %dkhz\n", sec_timeout, @@ -154,7 +176,7 @@ int create_communicator_grouped2( if (ndev == numlocal) { // all visible devices if (cur_dev != mylocal) printf("%d: device used %d[%d] ,resetting device to %d\n", myrank, cur_dev, ndev, mylocal); - CUDACHECK(cudaSetDevice(mylocal)); + NVTE_CHECK_CUDA(cudaSetDevice(mylocal)); } (*comm)->mydev = cur_dev; // FIXME need to check that numlocal is multiple of pipegpus x tensorgpus @@ -213,14 +235,14 @@ int create_communicator_grouped2( // Broadcast the a POSIX file descriptor from the local root rank to other local ranks. // NOTE: This cannot be done via MPI_Bcast or other external comm libraries. They mangle the // file descriptor and prevent cuMemImportFromShareableHandle() from correctly - // interpreting the file. Instead, we use system socket to send/recv the file handle - // without mangling. + // interpreting the file. Instead, we use Unix domain sockets for the kernel to + // recreate the correct file descriptor on every receiving rank. int fd; volatile uint32_t abortFlag = 0; - struct ncclIpcSocket ipcSock = {0}; + IpcSocketHandle ipcSock = {0}; uint64_t opId = 0xdeadcafeb000 + (*comm)->ar2_firstgpu; - ncclResult_t ret = ncclSuccess; - NCCLCHECK(ncclIpcSocketInit(&ipcSock, (*comm)->ar2_nvrank, (uint64_t)opId, &abortFlag)); + ipcSocketResult_t ret = ipcSocketSuccess; + IPCCHECK(ipcSocketInit(&ipcSock, (*comm)->ar2_nvrank, (uint64_t)opId, &abortFlag)); (*comm)->_barrier((*comm)->comm_world); if ((*comm)->ar2_nvrank == 0) { @@ -232,19 +254,22 @@ int create_communicator_grouped2( for (int p = 1; p < (*comm)->ar2_nvsize; p++) { (*comm)->_barrier((*comm)->comm_intra); - NCCLCHECKGOTO(ncclIpcSocketSendFd(&ipcSock, fd, p, (uint64_t)opId), ret, error); + IPCCHECKGOTO(ipcSocketSendFd(&ipcSock, fd, p, (uint64_t)opId), ret, error); } } else { - for (int i = 0; i < (*comm)->ar2_nvrank; i++) (*comm)->_barrier((*comm)->comm_intra); - NCCLCHECKGOTO(ncclIpcSocketRecvFd(&ipcSock, &fd), ret, error); - for (int i = 0; i < (*comm)->ar2_nvsize - (*comm)->ar2_nvrank - 1; i++) + for (int p = 1; p < (*comm)->ar2_nvsize; p++) { (*comm)->_barrier((*comm)->comm_intra); + if ((*comm)->ar2_nvrank == p) IPCCHECKGOTO(ipcSocketRecvFd(&ipcSock, &fd), ret, error); + } + } + + error: + if ((*comm)->ar2_nvrank != 0) { NVTE_CALL_CHECK_CUDA_DRIVER( cuMemImportFromShareableHandle, &(*comm)->mc_handle, reinterpret_cast(fd), static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); } - error: - NCCLCHECK(ncclIpcSocketClose(&ipcSock)); + IPCCHECK(ipcSocketClose(&ipcSock)); close(fd); NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastAddDevice, (*comm)->mc_handle, (CUdeviceptr)(*comm)->mydev); @@ -275,14 +300,16 @@ int create_communicator_grouped2( #define LOCALSIZE 4 * (NVTE_REG0_OFFSET(*comm) + NVTE_REG0_FLAGS + NVTE_REG0_COMMBUFFER * NBUF) // peer pointers + op flags + comm buffer - CUDACHECK(cudaMalloc(&(*comm)->gpu_ptrs, LOCALSIZE)); // flags and pointers, no block data yet - CUDACHECK(cudaMemset((*comm)->gpu_ptrs, 0, LOCALSIZE)); - CUDACHECK(cudaDeviceSynchronize()); + NVTE_CHECK_CUDA( + cudaMalloc(&(*comm)->gpu_ptrs, LOCALSIZE)); // flags and pointers, no block data yet + NVTE_CHECK_CUDA(cudaMemset((*comm)->gpu_ptrs, 0, LOCALSIZE)); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm, false); - CUDACHECK(cudaMalloc(&(*comm)->send_id, (*comm)->nranks * sizeof(int))); - CUDACHECK(cudaMalloc(&(*comm)->recv_id, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int))); - CUDACHECK(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * sizeof(int))); - CUDACHECK(cudaMemset((*comm)->recv_id, 0, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int))); + NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->send_id, (*comm)->nranks * sizeof(int))); + NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->recv_id, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int))); + NVTE_CHECK_CUDA(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * sizeof(int))); + NVTE_CHECK_CUDA( + cudaMemset((*comm)->recv_id, 0, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int))); (*comm)->sms = 16; (*comm)->threads = 1024; @@ -291,8 +318,8 @@ int create_communicator_grouped2( #define GPU_PAGE_OFFSET (GPU_PAGE_SIZE - 1) #define GPU_PAGE_MASK (~GPU_PAGE_OFFSET) - CUDACHECK(cudaMalloc(&(*comm)->flags, 2 * GPU_PAGE_SIZE)); - CUDACHECK(cudaMemset((*comm)->flags, 0, 2 * GPU_PAGE_SIZE)); + NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->flags, 2 * GPU_PAGE_SIZE)); + NVTE_CHECK_CUDA(cudaMemset((*comm)->flags, 0, 2 * GPU_PAGE_SIZE)); (*comm)->flags = reinterpret_cast(((CUdeviceptr)(*comm)->flags + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK); @@ -321,75 +348,73 @@ int create_communicator_grouped2( int create_communicator_grouped( communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, - int numnodes, std::function ext_alloc_copy_allgather, - std::function ext_barrier, std::function ext_free, int pipegpus, - int pipenodes) { + int numnodes, std::function ext_allgather, + std::function ext_barrier, int pipegpus, int pipenodes) { return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, - ext_alloc_copy_allgather, ext_barrier, ext_free, pipegpus, - pipenodes, 1, 1); + ext_allgather, ext_barrier, pipegpus, pipenodes, 1, 1); } -int create_communicator( - communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, - int numnodes, std::function ext_alloc_copy_allgather, - std::function ext_barrier, std::function ext_free) { +int create_communicator(communicator **comm, int myrank, int numranks, int mylocal, int numlocal, + int mynode, int numnodes, + std::function ext_allgather, + std::function ext_barrier) { return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, - ext_alloc_copy_allgather, ext_barrier, ext_free, 1, 1, 1, 1); + ext_allgather, ext_barrier, 1, 1, 1, 1); } int create_communicator_grouped2_mpi(communicator **comm, int pipegpus, int pipenodes, int tensorgpus, int tensornodes) { -#ifdef UB_MPI_BOOTSTRAP +#ifdef NVTE_UB_WITH_MPI // get global numbers int myrank, numranks; - MPI_Comm_rank(EXT_COMM_WORLD, &myrank); - MPI_Comm_size(EXT_COMM_WORLD, &numranks); + UB_MPI_CHECK(MPI_Comm_rank(EXT_COMM_WORLD, &myrank)); + UB_MPI_CHECK(MPI_Comm_size(EXT_COMM_WORLD, &numranks)); // find intranode numbers and make internode communicator - char host_name[MPI_MAX_PROCESSOR_NAME]; - char(*host_names)[MPI_MAX_PROCESSOR_NAME]; - int namelen, bytes, color; - int rank = (*comm)->myrank, size = (*comm)->nranks; - MPI_Get_processor_name(host_name, &namelen); - bytes = size * sizeof(char[MPI_MAX_PROCESSOR_NAME]); - host_names = (char(*)[MPI_MAX_PROCESSOR_NAME])malloc(bytes); - strcpy(host_names[rank], host_name); // NOLINT(*) - for (int n = 0; n < size; n++) - MPI_Bcast(&(host_names[n]), MPI_MAX_PROCESSOR_NAME, MPI_CHAR, n, EXT_COMM_WORLD); - qsort(host_names, size, sizeof(char[MPI_MAX_PROCESSOR_NAME]), stringCmp); - - color = 0; - for (int n = 0; n < size; n++) { - if (n > 0 && strcmp(host_names[n - 1], host_names[n])) color++; - if (strcmp(host_name, host_names[n]) == 0) break; + char hostname[MPI_MAX_PROCESSOR_NAME]; + int namelen; + UB_MPI_CHECK(MPI_Get_processor_name(hostname, &namelen)); + + char(*hostnames)[MPI_MAX_PROCESSOR_NAME] = + static_cast(malloc(numranks * MPI_MAX_PROCESSOR_NAME)); + strcpy(hostnames[myrank], hostname); + for (int n = 0; n < numranks; n++) + UB_MPI_CHECK(MPI_Bcast(&(hostnames[n]), MPI_MAX_PROCESSOR_NAME, MPI_CHAR, n, EXT_COMM_WORLD)); + qsort(hostnames, numranks, MPI_MAX_PROCESSOR_NAME, stringCmp); + + int color = 0; + for (int n = 0; n < numranks; n++) { + if (n > 0 && strcmp(hostnames[n - 1], hostnames[n])) color++; + if (strcmp(hostname, hostnames[n]) == 0) break; } - free(host_names); + free(hostnames); int mylocal, numlocal; - MPI_Comm_split(EXT_COMM_WORLD, color, rank, &EXT_COMM_INTRA); - MPI_Comm_rank(EXT_COMM_INTRA, &mylocal); - MPI_Comm_size(EXT_COMM_INTRA, &numlocal); + UB_MPI_CHECK(MPI_Comm_split(EXT_COMM_WORLD, color, myrank, &EXT_COMM_INTRA)); + UB_MPI_CHECK(MPI_Comm_rank(EXT_COMM_INTRA, &mylocal)); + UB_MPI_CHECK(MPI_Comm_size(EXT_COMM_INTRA, &numlocal)); // find internode numbers and make internode communicator - CUDACHECK(cudaFree(0)); + NVTE_CHECK_CUDA(cudaFree(0)); int allnodes = numranks / numlocal; int datanodes = allnodes / pipenodes / tensornodes; // data reduction group node belongs, equals 0 for all if both pipenodes=1 and tensornodes=1 int datanodegroup_id = myrank / numlocal / datanodes; // mpi communicator only needed for SHARP which is always allreduce1/data-parallel - MPI_Comm_split(EXT_COMM_WORLD, mylocal + numlocal * datanodegroup_id, rank, &EXT_COMM_INTER); + UB_MPI_CHECK(MPI_Comm_split(EXT_COMM_WORLD, mylocal + numlocal * datanodegroup_id, myrank, + &EXT_COMM_INTER)); // different rails from same group are in different subcommunicators int mynode, numnodes; - MPI_Comm_size(EXT_COMM_INTER, &numnodes); - MPI_Comm_rank(EXT_COMM_INTER, &mynode); + UB_MPI_CHECK(MPI_Comm_size(EXT_COMM_INTER, &numnodes)); + UB_MPI_CHECK(MPI_Comm_rank(EXT_COMM_INTER, &mynode)); // finally call the abstracted constructor with MPI info return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, - &ub_alloc_copy_allgather, &ub_barrier, &ub_free, pipegpus, - pipenodes, tensorgpus, tensornodes); + &ub_mpi_allgather, &ub_mpi_barrier, pipegpus, pipenodes, + tensorgpus, tensornodes); #else - NVTE_UB_ERROR(std::string("Bootstrapping Userbuffers with MPI requires ") + - std::string("building Transformer Engine with UB_MPI_BOOTSTRAP=1")); + NVTE_ERROR(std::string("Bootstrapping Userbuffers with MPI requires building") + + std::string("Transformer Engine with NVTE_UB_WITH_MPI=1 and MPI_HOME=/path/to/mpi")); #endif } @@ -403,49 +428,46 @@ int create_communicator_mpi(communicator **comm) { void destroy_communicator(communicator *comm) { for (int hndl = 0; hndl < comm->free_region; hndl++) { - if (comm->mem_dealloc[hndl]) { - NVTE_CALL_CHECK_CUDA_DRIVER(cuMemAddressFree, - reinterpret_cast(comm->ucbase_ptr[hndl]), - comm->mem_size[hndl] * comm->nvsize); + if (hndl > 0 && comm->use_mc && comm->mem_dealloc[hndl]) { for (int rank = 0; rank < comm->nvsize; rank++) { - NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->uchandles[hndl][rank]); + if (rank == comm->nvrank) { + NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->uchandles[hndl][rank]); + } else { + comm->uchandles[hndl][rank] = 0; + } } free(reinterpret_cast(comm->uchandles[hndl])); } else { for (int rank = 0; rank < comm->nvsize; rank++) { if (rank != comm->nvrank) { cudaIpcCloseMemHandle(comm->peer_ptr[hndl][rank]); + } else if (comm->mem_dealloc[hndl]) { + NVTE_CHECK_CUDA(cudaFree(comm->peer_ptr[hndl][rank])); } else { comm->peer_ptr[hndl][rank] = nullptr; // remove reference to external buffer } } - free(comm->peer_ptr[hndl]); } + free(comm->peer_ptr[hndl]); comm->mem_ptr[hndl] = nullptr; } - cudaFree(reinterpret_cast(comm->flags)); cudaFree(reinterpret_cast(comm->recv_id)); cudaFree(reinterpret_cast(comm->send_id)); if (comm->use_mc) { - NVTE_CALL_CHECK_CUDA_DRIVER(cuMemAddressFree, reinterpret_cast(comm->mc_baseptr), - comm->mc_maxsize); NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->mc_handle); } - if (comm->mem_dealloc[0]) { - cudaFree(comm->gpu_ptrs); - } free(comm->fifo); delete comm; } void destroy_communicator_mpi(communicator *comm) { -#ifdef UB_MPI_BOOTSTRAP - MPI_Comm_free(comm->comm_inter); - MPI_Comm_free(comm->comm_intra); +#ifdef NVTE_UB_WITH_MPI + MPI_Comm_free(static_cast(&(comm->comm_inter))); + MPI_Comm_free(static_cast(&(comm->comm_intra))); destroy_communicator(comm); #else - NVTE_UB_ERROR(std::string("Communicator is not bootstrapped with MPI and ") + - std::string("can only be deallocated with destroy_communicator().")); + NVTE_ERROR(std::string("Communicator is not bootstrapped with MPI and ") + + std::string("can only be deallocated with destroy_communicator().")); #endif } @@ -457,7 +479,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * comm->memflags[hndl] = 0; comm->mem_dealloc[hndl] = alloc; - if (alloc) { + if (comm->use_mc && alloc) { int nranks = comm->nvsize; // total GPUs in NVLINK domain int myrank = comm->nvrank; void **remptrs = reinterpret_cast(malloc(nranks * sizeof(void *))); @@ -501,26 +523,22 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * (uint64_t)0); volatile uint32_t abortFlag = 0; - struct ncclIpcSocket ipcSock = {0}; + IpcSocketHandle ipcSock = {0}; uint64_t opId = 0xdeadcafebeef; - ncclResult_t ret = ncclSuccess; - - // All-gather POSIX file descriptors across local ranks. - // NOTE: This cannot be done via MPI_Allgather or other external comm libraries. They mangle - // the file descriptor and prevent cuMemImportFromShareableHandle() from correctly - // interpreting the file. Instead, we use system socket to send/recv the file handle - // without mangling. - NCCLCHECK(ncclIpcSocketInit(&ipcSock, myrank, (uint64_t)opId, &abortFlag)); + ipcSocketResult_t ret = ipcSocketSuccess; + + // All-gather POSIX file descriptors across local ranks + IPCCHECK(ipcSocketInit(&ipcSock, myrank, (uint64_t)opId, &abortFlag)); for (int p = 1; p < nranks; p++) { + int send_to = (myrank + p) % nranks; + int recv_from = (myrank + nranks - p) % nranks; comm->_barrier(comm->comm_intra); - NCCLCHECKGOTO( - ncclIpcSocketSendFd(&ipcSock, peerfd[myrank], (myrank + p) % nranks, (uint64_t)opId), ret, - error); - NCCLCHECKGOTO(ncclIpcSocketRecvFd(&ipcSock, &peerfd[(myrank + nranks - p) % nranks]), ret, - error); + IPCCHECKGOTO(ipcSocketSendFd(&ipcSock, peerfd[myrank], send_to, (uint64_t)opId), ret, error); + IPCCHECKGOTO(ipcSocketRecvFd(&ipcSock, &peerfd[recv_from]), ret, error); } + error: - NCCLCHECK(ncclIpcSocketClose(&ipcSock)); + IPCCHECK(ipcSocketClose(&ipcSock)); for (int p = 0; p < nranks; p++) { if (p != myrank) @@ -530,6 +548,8 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); close(peerfd[p]); } + free(peerfd); + CUdeviceptr ptr; NVTE_CALL_CHECK_CUDA_DRIVER(cuMemAddressReserve, &ptr, (size_t)(aligned_size * nranks), (size_t)0, (CUdeviceptr)0, (uint64_t)0); @@ -554,12 +574,11 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * NVTE_CALL_CHECK_CUDA_DRIVER(cuMemSetAccess, ptr, (size_t)(aligned_size * nranks), const_cast(&accessDesc), (size_t)1); - if (hndl == 0) CUDACHECK(cudaMemset(comm->gpu_ptrs, 0, aligned_size)); - CUDACHECK( + if (hndl == 0) NVTE_CHECK_CUDA(cudaMemset(comm->gpu_ptrs, 0, aligned_size)); + NVTE_CHECK_CUDA( cudaMemcpy((reinterpret_cast(comm->gpu_ptrs)) + (hndl * nranks * sizeof(void *)), remptrs, nranks * sizeof(void *), cudaMemcpyHostToDevice)); free(remptrs); - free(peerfd); comm->memflags[hndl] = UB_MEM_UC_CONTIG | UB_MEM_ALLOCATED; if (comm->use_mc && comm->mc_maxsize >= comm->mc_offset + aligned_size) { @@ -575,29 +594,36 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * } } else { - assert(comm->nvsize <= 8); + if (alloc) { + NVTE_CHECK_CUDA(cudaMalloc(gpubuff, bytes)); + NVTE_CHECK_CUDA(cudaMemset(*gpubuff, 0, bytes)); + } + + NVTE_CHECK(comm->nvsize <= 8, "CUDA IPC supports only up to 8 GPUs in an NVLink domain."); cudaIpcMemHandle_t memhndl; - CUDACHECK(cudaIpcGetMemHandle(&memhndl, *gpubuff)); + NVTE_CHECK_CUDA(cudaIpcGetMemHandle(&memhndl, *gpubuff)); - cudaIpcMemHandle_t *tmp; - comm->_alloc_copy_allgather(reinterpret_cast(&tmp), reinterpret_cast(&memhndl), - sizeof(cudaIpcMemHandle_t), comm->comm_intra); + cudaIpcMemHandle_t *tmp = + reinterpret_cast(malloc(comm->nvsize * sizeof(cudaIpcMemHandle_t))); + comm->_allgather(reinterpret_cast(tmp), comm->nvsize * sizeof(cudaIpcMemHandle_t), + reinterpret_cast(&memhndl), sizeof(cudaIpcMemHandle_t), + comm->comm_intra); for (int i = 0; i < comm->nvsize; i++) { if (i != comm->nvrank) { - CUDACHECK(cudaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i], // NOLINT(*) - cudaIpcMemLazyEnablePeerAccess)); + NVTE_CHECK_CUDA(cudaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i], // NOLINT(*) + cudaIpcMemLazyEnablePeerAccess)); } } comm->peer_ptr[hndl][comm->nvrank] = *gpubuff; - CUDACHECK(cudaDeviceSynchronize()); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); - CUDACHECK(cudaMemcpy( + NVTE_CHECK_CUDA(cudaMemcpy( reinterpret_cast(comm->gpu_ptrs) + (hndl * comm->nvsize * sizeof(void *)), comm->peer_ptr[hndl], comm->nvsize * sizeof(void *), cudaMemcpyHostToDevice)); - CUDACHECK(cudaDeviceSynchronize()); - comm->_free(tmp); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + free(tmp); } comm->mem_size[hndl] = aligned_size; diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu index b648561597..03a1a6a3df 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu +++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu @@ -23,15 +23,6 @@ #define MAX_THREADS 1024 -#define CUDACHECK(cmd) \ - do { \ - cudaError_t e = cmd; \ - if (e != cudaSuccess) { \ - printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ - exit(EXIT_FAILURE); \ - } \ - } while (0) - #define ATOMIC_CONSUMER(chunk) \ if (counters) { \ if (threadIdx.x == 0 && blockIdx.x == 0) { \ @@ -1391,7 +1382,7 @@ __global__ void __launch_bounds__(MAX_THREADS) reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ reinterpret_cast(&arg9), reinterpret_cast(&arg10)}; \ - CUDACHECK(cudaLaunchKernelExC( \ + NVTE_CHECK_CUDA(cudaLaunchKernelExC( \ &cfg, \ reinterpret_cast(comm->use_rr_kernel ? userbuffers_fp16_sum_inplace_gpu_rr_ag \ : userbuffers_fp16_sum_inplace_gpu_rw_ag), \ @@ -1416,7 +1407,7 @@ __global__ void __launch_bounds__(MAX_THREADS) reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ reinterpret_cast(&arg11)}; \ - CUDACHECK(cudaLaunchKernelExC( \ + NVTE_CHECK_CUDA(cudaLaunchKernelExC( \ &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_mc_ag), kernelArgs)); \ } @@ -1436,7 +1427,7 @@ __global__ void __launch_bounds__(MAX_THREADS) reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ reinterpret_cast(&arg9), reinterpret_cast(&arg10)}; \ - CUDACHECK(cudaLaunchKernelExC( \ + NVTE_CHECK_CUDA(cudaLaunchKernelExC( \ &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs), kernelArgs)); \ } @@ -1458,7 +1449,7 @@ __global__ void __launch_bounds__(MAX_THREADS) reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ reinterpret_cast(&arg11)}; \ - CUDACHECK(cudaLaunchKernelExC( \ + NVTE_CHECK_CUDA(cudaLaunchKernelExC( \ &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_mc_rs), kernelArgs)); \ } @@ -1481,7 +1472,7 @@ __global__ void __launch_bounds__(MAX_THREADS) reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ reinterpret_cast(&arg13)}; \ - CUDACHECK(cudaLaunchKernelExC( \ + NVTE_CHECK_CUDA(cudaLaunchKernelExC( \ &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop), \ kernelArgs)); \ } @@ -1506,7 +1497,7 @@ __global__ void __launch_bounds__(MAX_THREADS) reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ reinterpret_cast(&arg13), reinterpret_cast(&arg14)}; \ - CUDACHECK(cudaLaunchKernelExC( \ + NVTE_CHECK_CUDA(cudaLaunchKernelExC( \ &cfg, \ reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_fp8), \ kernelArgs)); \ @@ -1532,7 +1523,7 @@ __global__ void __launch_bounds__(MAX_THREADS) reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ reinterpret_cast(&arg13), reinterpret_cast(&arg14)}; \ - CUDACHECK(cudaLaunchKernelExC( \ + NVTE_CHECK_CUDA(cudaLaunchKernelExC( \ &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_mc_rs_oop), \ kernelArgs)); \ } @@ -1562,7 +1553,7 @@ __global__ void __launch_bounds__(MAX_THREADS) reinterpret_cast(&arg13), reinterpret_cast(&arg14), \ reinterpret_cast(&arg15), reinterpret_cast(&arg16), \ reinterpret_cast(&arg17), reinterpret_cast(&arg18)}; \ - CUDACHECK(cudaLaunchKernelExC( \ + NVTE_CHECK_CUDA(cudaLaunchKernelExC( \ &cfg, \ reinterpret_cast( \ userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_atomic_fp8), \ @@ -1588,7 +1579,7 @@ __global__ void __launch_bounds__(MAX_THREADS) reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ reinterpret_cast(&arg13)}; \ - CUDACHECK(cudaLaunchKernelExC( \ + NVTE_CHECK_CUDA(cudaLaunchKernelExC( \ &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride), \ kernelArgs)); \ } @@ -1614,7 +1605,7 @@ __global__ void __launch_bounds__(MAX_THREADS) reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ reinterpret_cast(&arg13), reinterpret_cast(&arg14), \ reinterpret_cast(&arg15)}; \ - CUDACHECK(cudaLaunchKernelExC( \ + NVTE_CHECK_CUDA(cudaLaunchKernelExC( \ &cfg, \ reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_atomic), \ kernelArgs)); \ @@ -1641,7 +1632,7 @@ __global__ void __launch_bounds__(MAX_THREADS) reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ reinterpret_cast(&arg13), reinterpret_cast(&arg14), \ reinterpret_cast(&arg15)}; \ - CUDACHECK( \ + NVTE_CHECK_CUDA( \ cudaLaunchKernelExC(&cfg, \ reinterpret_cast( \ userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_multiatomic), \ @@ -2206,15 +2197,6 @@ __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiat } } -#define CUDACHECK(cmd) \ - do { \ - cudaError_t e = cmd; \ - if (e != cudaSuccess) { \ - printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ - exit(EXIT_FAILURE); \ - } \ - } while (0) - // Return TRUE if two ranks share the same NV domain #define INTRANODE(peer) ((peer / comm->nvsize) == (comm->myrank / comm->nvsize)) @@ -2259,7 +2241,7 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds if (comm->use_ce) { // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast(ce_send_start_ptr)); - CUDACHECK(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream)); // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast(ce_send_end_ptr)); } SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream); @@ -2269,7 +2251,7 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), reinterpret_cast(&arg3), reinterpret_cast(&arg4), reinterpret_cast(&arg5)}; - CUDACHECK( + NVTE_CHECK_CUDA( cudaLaunchKernelExC(&cfg, reinterpret_cast(kuserbuffers_pushsend), kernelArgs)); } } @@ -2291,7 +2273,8 @@ void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size if (comm->use_ce) { // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast(ce_send_start_ptr)); - CUDACHECK(cudaMemcpyAsync(send_dstptr, send_srcptr, bytes, cudaMemcpyDeviceToDevice, stream)); + NVTE_CHECK_CUDA( + cudaMemcpyAsync(send_dstptr, send_srcptr, bytes, cudaMemcpyDeviceToDevice, stream)); // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast(ce_send_end_ptr)); } SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream); @@ -2323,7 +2306,7 @@ void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size reinterpret_cast(&arg11), reinterpret_cast(&arg12), reinterpret_cast(&arg13), reinterpret_cast(&arg14), reinterpret_cast(&arg15)}; - CUDACHECK( + NVTE_CHECK_CUDA( cudaLaunchKernelExC(&cfg, reinterpret_cast(kuserbuffers_pushsendrecv), kernelArgs)); } @@ -2346,7 +2329,8 @@ void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler, reinterpret_cast(comm->peer_ptr[dsthandler][send_peerlocal]) + send_offset; if (comm->use_ce) { // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast(ce_send_start_ptr)); - CUDACHECK(cudaMemcpyAsync(send_dstptr, send_srcptr, bytes, cudaMemcpyDeviceToDevice, stream)); + NVTE_CHECK_CUDA( + cudaMemcpyAsync(send_dstptr, send_srcptr, bytes, cudaMemcpyDeviceToDevice, stream)); // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast(ce_send_end_ptr)); } SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream); @@ -2379,8 +2363,8 @@ void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler, reinterpret_cast(&arg11), reinterpret_cast(&arg12), reinterpret_cast(&arg13), reinterpret_cast(&arg14), reinterpret_cast(&arg15), reinterpret_cast(&arg16)}; - CUDACHECK(cudaLaunchKernelExC(&cfg, reinterpret_cast(kuserbuffers_pushsendrecv_atomic), - kernelArgs)); + NVTE_CHECK_CUDA(cudaLaunchKernelExC( + &cfg, reinterpret_cast(kuserbuffers_pushsendrecv_atomic), kernelArgs)); } void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler, @@ -2425,7 +2409,7 @@ void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler reinterpret_cast(&arg13), reinterpret_cast(&arg14), reinterpret_cast(&arg15), reinterpret_cast(&arg16), reinterpret_cast(&arg17), reinterpret_cast(&arg18)}; - CUDACHECK(cudaLaunchKernelExC( + NVTE_CHECK_CUDA(cudaLaunchKernelExC( &cfg, reinterpret_cast(kuserbuffers_pushsendrecv_multiatomic), kernelArgs)); } @@ -2451,7 +2435,7 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds if (!signalonly) kuserbuffers_inc<<<1, 1, 0, stream>>>(&(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler])); if (comm->use_ce) { - CUDACHECK(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream)); } } else { kuserbuffers_pushrecv<<<1, 1, 0, stream>>>( diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h index e8dbf97823..371932f446 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h +++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h @@ -15,39 +15,11 @@ #include #include -#ifdef UB_MPI_BOOTSTRAP -#include - -#include - -#define UB_MPI_CHECK(expr) \ - do { \ - const int mpicode = (expr); \ - if (mpicode != MPI_SUCCESS) { \ - char mpimsg[MPI_MAX_ERROR_STRING]; \ - int mpilen; \ - MPI_Error_string(mpicode, mpimsg, &mpilen); \ - std::vector errmsg(1024); \ - snprintf(errmsg.data(), errmsg.size(), "%s:%s in function %s: %s", __FILE__, __LINE__, \ - __func__, mpimsg); \ - throw std::runtime_error(errmsg.data()); \ - } \ - } while (false) +#include "common/util/logging.h" +#ifdef NVTE_UB_WITH_MPI +#include typedef MPI_Comm ExtComm; - -void ub_alloc_copy_allgather(void **globaldata, void *localdata, size_t localbytes, ExtComm comm) { - int myrank, nranks; - UB_MPI_CHECK(MPI_Comm_rank(comm, &myrank)); - UB_MPI_CHECK(MPI_Comm_size(comm, &nranks)); - *globaldata = malloc(nranks * localbytes); - UB_MPI_CHECK(MPI_Allgather(localdata, localbytes, MPI_BYTE, *globaldata, nranks * localbytes, - MPI_BYTE, comm)); -} - -void ub_barrier(ExtComm comm) { UB_MPI_CHECK(MPI_Barrier(comm)); } - -void ub_free(void *ptr) { free(ptr); } #else typedef char *ExtComm; #endif @@ -170,14 +142,13 @@ struct communicator { volatile int tail; // Abstract communication callbacks to support external bootstrapping (e.g. DL frameworks) - std::function _alloc_copy_allgather; + std::function _allgather; std::function _barrier; - std::function _free; ExtComm comm_world, comm_inter, // reduction group communicator (subset of the nodes) along GPU rail comm_intra; // full intranode (all ndev GPUS) -#ifdef UB_MPI_BOOTSTRAP +#ifdef NVTE_UB_WITH_MPI MPI_Request mpihndl[NVTE_MAX_SHARP]; #endif @@ -194,20 +165,19 @@ void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStr /* creates communicator, allocates all internal buffers if necessary */ int create_communicator_grouped2( communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, - int numnodes, std::function ext_alloc_copy_allgather, - std::function ext_barrier, std::function ext_free, int pipegpus, - int pipenodes, int tensorgpus, int tensornodes); + int numnodes, std::function ext_allgather, + std::function ext_barrier, int pipegpus, int pipenodes, int tensorgpus, + int tensornodes); int create_communicator_grouped( communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, - int numnodes, std::function ext_alloc_copy_allgather, - std::function ext_barrier, std::function ext_free, int pipegpus, - int pipenodes); + int numnodes, std::function ext_allgather, + std::function ext_barrier, int pipegpus, int pipenodes); -int create_communicator( - communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, - int numnodes, std::function ext_alloc_copy_allgather, - std::function ext_barrier, std::function ext_free); +int create_communicator(communicator **comm, int myrank, int numranks, int mylocal, int numlocal, + int mynode, int numnodes, + std::function ext_allgather, + std::function ext_barrier); int create_communicator_grouped2_mpi(communicator **comm, int pipegpus, int pipenodes, int tensorgpus, int tensornodes); diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 039df99260..6feda77c70 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -7,6 +7,9 @@ import os import pickle import warnings +import socket +import fcntl +import struct from abc import ABC, abstractmethod from typing import Dict, Generator, List, Optional, Tuple, Union from contextlib import contextmanager @@ -79,19 +82,109 @@ def get_multi_stream_cublas_workspace() -> List[torch.Tensor]: def initialize_ub( shape: list, - tp_group: dist_group_type, + tp_size: int, use_fp8: bool = False, dtype: torch.dtype = torch.bfloat16, ub_cfgs: Optional[dict] = None, + bootstrap_backend: Union[str, torch.distributed.Backend] = None, ) -> None: """Initialize communicators for TP comm overlap using userbuffers.""" + if not tex.device_supports_multicast(): + assert bool(os.getenv("UB_SKIPMC", "0")), ( + "CUDA device, driver and/or toolkit version does not support comm+GEMM overlap with " + + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead." + ) + global _ub_communicators assert _ub_communicators is None, "UB communicators are already initialized." _ub_communicators = {} - rank_id = torch.distributed.get_rank() - world_size = torch.distributed.get_world_size() - tp_id = torch.distributed.get_rank(tp_group) - tp_size = torch.distributed.get_world_size(tp_group) + + if tex.ubuf_built_with_mpi(): + # Userbuffers will ignore all these values when it is built with MPI, so these are just + # placeholders based on an assumption that tp_size covers all devices in a physical node. + assert torch.distributed.is_mpi_available() + mpi_group = torch.distributed.new_group(backend="mpi") + world_rank = torch.distributed.get_rank(mpi_group) + world_size = torch.distributed.get_world_size(mpi_group) + local_rank = world_rank % tp_size + local_size = tp_size + node_id = world_rank // tp_size + num_nodes = world_size // tp_size + ub_callbacks = tex.UbufBootstrapCallbacks() + else: + assert ( + torch.distributed.is_initialized() + ), "torch.distributed must be initialized before Userbuffers" + if bootstrap_backend is None: + bootstrap_backend = "nccl" + if torch.distributed.is_gloo_available(): + bootstrap_backend = "gloo" + elif torch.distributed.is_mpi_available(): + bootstrap_backend = "mpi" + else: + assert bootstrap_backend in ["gloo", "mpi", "nccl"] + + world_group = torch.distributed.new_group(backend=bootstrap_backend) + world_rank = torch.distributed.get_rank(world_group) + world_size = torch.distributed.get_world_size(world_group) + + if world_rank == 0: + print( + f'!!! [NVTE] Bootstrapping Userbuffers with backend="{bootstrap_backend}"\n', + end="", + flush=True, + ) + + # Construct an intra-node communicator based on global ranks that share the same hostname + # NOTE: If the user specified a valid network interface for NCCL or GLOO, use the host + # address on that interface instead of the hostname. This can help avoid issues when + # different hosts have the same hostname on Kubernetes clusters. + hostname = socket.gethostname() + ifname = os.getenv( + "NVTE_UB_SOCKET_IFNAME", + os.getenv("NCCL_SOCKET_IFNAME", os.getenv("GLOO_SOCKET_IFNAME")), + ) + + if ifname is not None: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + hostname = socket.inet_ntoa( + fcntl.ioctl( + s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8")) + )[20:24] + ) + except OSError as err: + raise OSError(f"Invalid network interface: {ifname}") from err + + hostnames = [None for _ in range(world_size)] + torch.distributed.all_gather_object(hostnames, hostname, world_group) + intra_node_ranks = [] + for i, host in enumerate(hostnames): + if host == hostname: + intra_node_ranks.append(i) + if len(intra_node_ranks) == world_size: + intra_node_group = world_group + local_rank = world_rank + local_size = world_size + intra_node_ranks = list(range(world_size)) + else: + intra_node_group = torch.distributed.new_group( + backend=bootstrap_backend, ranks=intra_node_ranks + ) + local_rank = torch.distributed.get_rank(intra_node_group) + local_size = torch.distributed.get_world_size(intra_node_group) + + node_id = world_rank // local_size + num_nodes = world_size // local_size + if local_rank == 0: + print( + f"!!! [NVTE] Number of physical nodes: {num_nodes}\n" + + f"!!! [NVTE] Global ranks on node {node_id}: {intra_node_ranks}\n", + end="", + flush=True, + ) + + ub_callbacks = tex.UbufBootstrapCallbacks(world_group, intra_node_group) # Increase the workspace by the number of maximum concurrent streams global _cublas_workspace @@ -127,6 +220,23 @@ def get_method(name): return method raise KeyError(f"Given layer name {name} does not exist.") + def get_default_config(name): + method = get_method(name) + is_reduce_scatter = name in layers_reduce_scatter_overlap + default_cfg = { + "method": method, + "is_reduce_scatter": is_reduce_scatter, + "num_sm": 1 if method == "ring_exchange" else 16, + "cga_size": 1 if method == "ring_exchange" else 2, + "set_sm_margin": False, + "num_splits": 4 if method == "pipeline" else tp_size, + "aggregate": False, + "atomic_gemm": False, + "use_ce": True, + "fp8_buf": name in layers_all_gather_overlap, + } + return default_cfg + def add_ub( name: str, method: str, @@ -180,53 +290,43 @@ def add_ub( if method == "ring_exchange": ub_obj = tex.UbufP2PCommOverlap( sample_buffer, # Sample userbuffer - rank_id, # Rank id + world_rank, # World rank world_size, # World size - tp_id, # TP id - tp_size, # TP size + local_rank, # Rank within the node + local_size, # Number of ranks/GPUs per node + node_id, # Node ID + num_nodes, # Number of nodes + tp_size, # Tensor-parallel group size (may be different than local_size) num_sm, # Number of communication SMs cga_size, # CGA cluster size set_sm_margin, # Set SM margin aggregate, # Aggregate 2X GEMM chunks _NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams - is_reduce_scatter, # overlap with reduce scatter - atomic_gemm, # use a single GEMM with atomic-counters - use_ce, # use copy engine for P2P communications - torch.Tensor(), # empty tensor to pass to counters + is_reduce_scatter, # Overlap with reduce scatter + atomic_gemm, # Use a single GEMM with atomic-counters + use_ce, # Use copy engine for P2P communications + ub_callbacks, ) else: ub_obj = tex.UbufCommOverlap( sample_buffer, # Sample userbuffer - rank_id, # Rank id + world_rank, # World rank world_size, # World size - tp_id, # TP id - tp_size, # TP size + local_rank, # Rank within the node + local_size, # Number of ranks/GPUs per node + node_id, # Node ID + num_nodes, # Number of nodes + tp_size, # Tensor-parallel group size (may be different than local_size) num_sm, # Number of communication SMs cga_size, # CGA cluster size num_splits, # Number of communication splits set_sm_margin, # Set SM margin _NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams - atomic_gemm, # use a single GEMM with atomic-counters - torch.Tensor(), # empty tensor to pass to counters + atomic_gemm, # Use a single GEMM with atomic-counters + ub_callbacks, ) _ub_communicators[name] = ub_obj - def alloc_copy_allgather_callback(local_data: torch.Tensor, group: str) -> torch.Tensor: - pg = None if group == "world" else tp_group - global_size = local_data.numel() * torch.distributed.get_world_size(pg) - global_data = torch.zeros(global_size, dtype=local_data.dtype, device="cuda") - torch.distributed.all_gather_into_tensor(global_data, local_data.cuda(), group=pg) - return global_data.cpu() - - def barrier_callback(group: str) -> None: - pg = None if group == "world" else tp_group - torch.distributed.barrier(group=pg) - - def free_callback(data: torch.Tensor) -> None: - data.data = torch.Tensor() - - tex.set_ubuf_bootstrap_callbacks(alloc_copy_allgather_callback, barrier_callback, free_callback) - if ub_cfgs is not None: for name in dgrad_reduce_scatter_overlap: if name in ub_cfgs and "method" in ub_cfgs[name] and ub_cfgs[name]["method"] != "bulk": @@ -238,48 +338,18 @@ def free_callback(data: torch.Tensor) -> None: methods["pipeline"].append(name) for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]: + ub_cfg = get_default_config(name) if ub_cfgs is not None and name in ub_cfgs: - ub_cfg = ub_cfgs[name] - method = ub_cfg.get("method", get_method(name)) - num_sm = ub_cfg.get("num_sm", 1 if method == "ring_exchange" else 16) - cga_size = ub_cfg.get("cga_size", 1 if method == "ring_exchange" else 2) - num_splits = ub_cfg.get("num_splits", 4 if method == "pipeline" else 0) - set_sm_margin = ub_cfg.get("set_sm_margin", 0) - aggregate = ub_cfg.get("aggregate", 0) - atomic_gemm = ub_cfg.get("atomic_gemm", 0) - use_ce = ub_cfg.get("use_ce", True) - is_reduce_scatter = 1 if name in layers_reduce_scatter_overlap else 0 - # Support FP8 userbuffer when (1) AllGather and (2) FP8-GEMM output ReduceScatter fp8_buf = (name in layers_all_gather_overlap) or ( - ub_cfg.get("fp8_buf", False) and name in methods["pipeline"] - ) - add_ub( - name, - method, - is_reduce_scatter, - num_sm, - cga_size, - set_sm_margin, - num_splits, - aggregate, - atomic_gemm, - use_ce, - fp8_buf, - ) - else: - method = get_method(name) - add_ub( - name, - method=method, - is_reduce_scatter=1 if name in layers_reduce_scatter_overlap else 0, - num_splits=4 if method == "pipeline" else 0, - fp8_buf=name in layers_all_gather_overlap, + ub_cfgs[name].get("fp8_buf", False) and name in methods["pipeline"] ) + ub_cfg.update(ub_cfgs[name]) + ub_cfg["fp8_buf"] = fp8_buf + add_ub(name, **ub_cfg) def get_ub(name: str): """Get userbuffer communicator corresponding to give key.""" - global _ub_communicators assert _ub_communicators is not None, "UB manager is not initialized." assert name in _ub_communicators, f"UB for {name} is not registered." return _ub_communicators[name] From 9edcaf0e17458afbbd6a9f26a41ae6d367799305 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 24 Jul 2024 10:13:38 -0700 Subject: [PATCH 14/72] Update minimum CMake version (#1037) * Set minimum CMake version to 3.21 Stop linking to nvtx. Signed-off-by: Tim Moon * Update .github/workflows/build.yml Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Revert Python version to 3.9 Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --------- Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani --- .github/workflows/build.yml | 3 ++- setup.py | 2 +- transformer_engine/common/CMakeLists.txt | 5 ++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 8df4b5179e..acec20b566 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -18,7 +18,8 @@ jobs: - name: 'Dependencies' run: | apt-get update - apt-get install -y git python3.9 pip cmake ninja-build cudnn9-cuda-12 + apt-get install -y git python3.9 pip ninja-build cudnn9-cuda-12 + pip install cmake==3.21.0 - name: 'Checkout' uses: actions/checkout@v3 with: diff --git a/setup.py b/setup.py index d2cc91d65a..41521418ba 100644 --- a/setup.py +++ b/setup.py @@ -68,7 +68,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: # Requirements that may be installed outside of Python if not found_cmake(): - setup_reqs.append("cmake>=3.18") + setup_reqs.append("cmake>=3.21") if not found_ninja(): setup_reqs.append("ninja") if not found_pybind11(): diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 0cf48f37f2..e22e8dbbc8 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -2,7 +2,7 @@ # # See LICENSE for license information. -cmake_minimum_required(VERSION 3.18) +cmake_minimum_required(VERSION 3.21) if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90) @@ -18,7 +18,7 @@ if (CMAKE_BUILD_TYPE STREQUAL "Debug") set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -G") endif() -find_package(CUDAToolkit REQUIRED cublas nvToolsExt) +find_package(CUDAToolkit REQUIRED) # Check for cuDNN frontend API set(CUDNN_FRONTEND_INCLUDE_DIR @@ -79,7 +79,6 @@ target_link_libraries(transformer_engine PUBLIC CUDA::cuda_driver CUDA::cudart CUDA::nvrtc - CUDA::nvToolsExt CUDNN::cudnn) target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) From 08b4997605e99c911664539b33657d0c03fe27a6 Mon Sep 17 00:00:00 2001 From: Tian Zheng Date: Thu, 25 Jul 2024 04:07:22 +0800 Subject: [PATCH 15/72] [Paddle] Fix device memory leak (#1029) * i Signed-off-by: Tian Zheng (Engrg-Hardware 1) * . Signed-off-by: Tian Zheng (Engrg-Hardware 1) --------- Signed-off-by: Tian Zheng (Engrg-Hardware 1) --- transformer_engine/paddle/csrc/custom_ops.cu | 21 +++++++------ transformer_engine/paddle/layer/base.py | 31 ++++++++++++++------ 2 files changed, 32 insertions(+), 20 deletions(-) diff --git a/transformer_engine/paddle/csrc/custom_ops.cu b/transformer_engine/paddle/csrc/custom_ops.cu index 3204574053..7d401c348e 100644 --- a/transformer_engine/paddle/csrc/custom_ops.cu +++ b/transformer_engine/paddle/csrc/custom_ops.cu @@ -1355,14 +1355,13 @@ void amax_and_scale_update_inplace(paddle::Tensor &amax_history, // NOLINT static_cast(fp8_dtype), margin, amax_history.stream()); } -void amax_and_scale_update_inplace_legacy(paddle::Tensor &amax_history, // NOLINT - paddle::Tensor &scale, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - const paddle::Tensor &non_weight_mask, - const paddle::Tensor ¤t_step_id_tensor, - bool update_weight_scale_inv, bool fwd_update, - float fp8_max, float margin, - const std::string &amax_compute) { +void amax_and_scale_update_inplace_legacy( + paddle::Tensor &amax_history, // NOLINT + paddle::Tensor &scale, // NOLINT + paddle::Tensor &scale_inv, // NOLINT + const paddle::Tensor &non_weight_mask, + const paddle::optional ¤t_step_id_tensor, bool update_weight_scale_inv, + bool fwd_update, float fp8_max, float margin, const std::string &amax_compute) { #if PADDLE_VERSION > 261 NVTE_CHECK(amax_compute == "max" || amax_compute == "most_recent"); @@ -1380,8 +1379,7 @@ void amax_and_scale_update_inplace_legacy(paddle::Tensor &amax_history, // NOLI auto amax_numel = amax.numel(); size_t num_blocks = (amax_history_numel + BLOCK_SIZE - 1) / BLOCK_SIZE; - const int *current_step_id_ptr = nullptr; - if (fwd_update) current_step_id_ptr = current_step_id_tensor.data(); + const int *current_step_id_ptr = GetOptionalDataPtr(current_step_id_tensor); auto parameterSetter = [current_step_id_ptr, fwd_update](phi::backends::gpu::CUDAKernelParams ¶ms) { if (fwd_update) { @@ -1758,7 +1756,8 @@ PD_BUILD_OP(te_scaled_upper_triang_masked_softmax_backward) PD_KERNEL(transformer_engine::paddle_ext::te_scaled_upper_triang_masked_softmax_backward)); PD_BUILD_OP(amax_and_scale_update_inplace_legacy) - .Inputs({"_amax_history", "_scale", "_scale_inv", "non_weight_mask", "current_step_id_tensor"}) + .Inputs({"_amax_history", "_scale", "_scale_inv", "non_weight_mask", + paddle::Optional("current_step_id_tensor")}) .Outputs({"amax_history", "scale", "scale_inv"}) .SetInplaceMap({{"_amax_history", "amax_history"}, {"_scale", "scale"}, diff --git a/transformer_engine/paddle/layer/base.py b/transformer_engine/paddle/layer/base.py index 86d8ff37fb..adbd1ce269 100644 --- a/transformer_engine/paddle/layer/base.py +++ b/transformer_engine/paddle/layer/base.py @@ -84,15 +84,7 @@ def __init__(self) -> None: self.fp8_weights = [] self.fp8_weight_cache = {} self.registered_pp_start_callback = False - - self.current_step_id = paddle.to_tensor([1], dtype=paddle.int32, place=paddle.CPUPlace()) - - def current_step_id_callback(step_id=None, **kwargs): # pylint: disable=unused-argument - self.current_step_id.copy_( - paddle.to_tensor([step_id], dtype=paddle.int32, place=paddle.CPUPlace()), True - ) - - register_pp_fwd_begin_hook(current_step_id_callback) + self.current_step_id = None def set_activation_dtype(self, inp: paddle.Tensor) -> None: """Get activation data type for AMP.""" @@ -301,6 +293,27 @@ def prepare_forward( if self.fp8_meta.get("update_amax_and_scale_fwd", False): global_fp8_fwd_buffer = get_global_fp8_state().get_fp8_fwd_buffer() global_fp8_fwd_buffer.wait() + # Register PP forward begin hook when CUDAGraph is enabled. + # NOTE(tizheng): register_pp_fwd_begin_hook prevents layer parameters from being freed + # when the layer object is deleted. Need to find a better way. + if get_global_fp8_state().is_cudagraph_enabled() and self.current_step_id is None: + self.current_step_id = paddle.to_tensor( + [1], dtype=paddle.int32, place=paddle.CPUPlace() + ) + + def current_step_id_callback( + step_id=None, **kwargs + ): # pylint: disable=unused-argument + self.current_step_id.copy_( + paddle.to_tensor( + [step_id], dtype=paddle.int32, place=paddle.CPUPlace() + ), + True, + ) + + if is_pp_enabled(): + register_pp_fwd_begin_hook(current_step_id_callback) + if self.fp8_meta["recipe"].reduce_amax: global_fp8_fwd_buffer.copy_amax_from_buffer(self.fp8_meta) amax_and_scale_update( From e1e83598ebb520b5d05c39f31a54793d65e1bb5a Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 24 Jul 2024 16:34:11 -0700 Subject: [PATCH 16/72] [JAX] Debug distributed attention tests (#1038) * Remove extra args to fused attention func Signed-off-by: Tim Moon * Add missing arg to fused attention func Signed-off-by: Tim Moon --------- Signed-off-by: Tim Moon --- tests/jax/test_distributed_fused_attn.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 40e9e74733..15676dd270 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -124,12 +124,9 @@ def target_func(qkv, bias, mask): bias, mask, None, - None, - None, - None, - None, attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, + qkv_layout=QKVLayout.BS3HD, scaling_factor=scaling_factor, dropout_probability=dropout_prob, is_training=is_training, @@ -260,12 +257,9 @@ def target_func(q, kv, mask): None, mask, None, - None, - None, - None, - None, attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, + qkv_layout=QKVLayout.BSHD_BS2HD, scaling_factor=scaling_factor, dropout_probability=dropout_prob, is_training=is_training, From 4b6c07d430713332c779fda14b65f8b86c8150da Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 24 Jul 2024 18:45:36 -0700 Subject: [PATCH 17/72] Fix build error with Paddle >2.6.1 (#1040) * Fix build error with Paddle >2.6.1 Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/paddle/csrc/custom_ops.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/paddle/csrc/custom_ops.cu b/transformer_engine/paddle/csrc/custom_ops.cu index 7d401c348e..69569d5584 100644 --- a/transformer_engine/paddle/csrc/custom_ops.cu +++ b/transformer_engine/paddle/csrc/custom_ops.cu @@ -1379,7 +1379,8 @@ void amax_and_scale_update_inplace_legacy( auto amax_numel = amax.numel(); size_t num_blocks = (amax_history_numel + BLOCK_SIZE - 1) / BLOCK_SIZE; - const int *current_step_id_ptr = GetOptionalDataPtr(current_step_id_tensor); + const int *current_step_id_ptr = + reinterpret_cast(GetOptionalDataPtr(current_step_id_tensor)); auto parameterSetter = [current_step_id_ptr, fwd_update](phi::backends::gpu::CUDAKernelParams ¶ms) { if (fwd_update) { From 6ae584dddbdb53933c41fdac13994fd1bd17b33f Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 24 Jul 2024 19:01:47 -0700 Subject: [PATCH 18/72] [PyTorch] Fix linter warnings (#1041) Fix linter warnings Signed-off-by: Tim Moon --- transformer_engine/pytorch/csrc/comm_gemm_overlap.h | 2 +- .../pytorch/csrc/userbuffers/userbuffers-host.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h index 0d70c9dc45..88609b6ddb 100644 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -67,7 +67,7 @@ class UbufBootstrapCallbacks : torch::CustomClassHolder { #ifndef NVTE_UB_WITH_MPI NVTE_ERROR("Internal TE error: Dummy UbufBootstrapCallbacks init without NVTE_UB_WITH_MPI=1!"); #endif - }; // empty constructor for NVTE_UB_WITH_MPI=1 + } // empty constructor for NVTE_UB_WITH_MPI=1 UbufBootstrapCallbacks(c10d::ProcessGroup *world_group, c10d::ProcessGroup *intra_node_group) { pgs.insert({"world", world_group}); diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp b/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp index 982da28d33..e2628f6a31 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp @@ -377,7 +377,7 @@ int create_communicator_grouped2_mpi(communicator **comm, int pipegpus, int pipe char(*hostnames)[MPI_MAX_PROCESSOR_NAME] = static_cast(malloc(numranks * MPI_MAX_PROCESSOR_NAME)); - strcpy(hostnames[myrank], hostname); + strcpy(hostnames[myrank], hostname); // NOLINT(*) for (int n = 0; n < numranks; n++) UB_MPI_CHECK(MPI_Bcast(&(hostnames[n]), MPI_MAX_PROCESSOR_NAME, MPI_CHAR, n, EXT_COMM_WORLD)); qsort(hostnames, numranks, MPI_MAX_PROCESSOR_NAME, stringCmp); From 098135785777e6ec1cb3d57764f75326bffa0e41 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 24 Jul 2024 20:26:58 -0700 Subject: [PATCH 19/72] Build scripts for pip wheels (#1036) * Specify python version Signed-off-by: Kirthi Shankar Sivamani * Add classifiers for python Signed-off-by: Kirthi Shankar Sivamani * Add utils to build wheels Signed-off-by: Kirthi Shankar Sivamani * make wheel scripts Signed-off-by: Kirthi Shankar Sivamani * Add aarch Signed-off-by: Kirthi Shankar Sivamani * Fixes Signed-off-by: Kirthi Shankar Sivamani * Fix paddle wheel Signed-off-by: Kirthi Shankar Sivamani * PaddlePaddle only builds for x86 Signed-off-by: Kirthi Shankar Sivamani * Add optional fwk deps Signed-off-by: Kirthi Shankar Sivamani * Python3.8; catch install error Signed-off-by: Kirthi Shankar Sivamani * [wip] cudnn9 compile with paddle support Signed-off-by: Kirthi Shankar Sivamani * [wip] dont link cudnn Signed-off-by: Kirthi Shankar Sivamani * dlopen cudnn Signed-off-by: Kirthi Shankar Sivamani * fixes Signed-off-by: Kirthi Shankar Sivamani * fix Signed-off-by: Kirthi Shankar Sivamani * Fixes Signed-off-by: Kirthi Shankar Sivamani * fix Signed-off-by: Kirthi Shankar Sivamani * dynamically load nvrtc Signed-off-by: Kirthi Shankar Sivamani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Lint Signed-off-by: Kirthi Shankar Sivamani * remove residual packages; exclude stub from nvrtc .so search Signed-off-by: Kirthi Shankar Sivamani * Exclude builtins from nvrtc .so search Signed-off-by: Kirthi Shankar Sivamani * properly include files for sdist Signed-off-by: Kirthi Shankar Sivamani * paddle wheel tie to python version Signed-off-by: Kirthi Shankar Sivamani * Fix paddle build from src [wip] Signed-off-by: Kirthi Shankar Sivamani * Fix workflow paddle build Signed-off-by: Kirthi Shankar Sivamani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix paddle Signed-off-by: Kirthi Shankar Sivamani * Fix paddle Signed-off-by: Kirthi Shankar Sivamani * fix lint from pr986 Signed-off-by: Kirthi Shankar Sivamani * fix Signed-off-by: Kirthi Shankar Sivamani * Add sanity wheel test Signed-off-by: Kirthi Shankar Sivamani * Add sanity import to wheel test Signed-off-by: Kirthi Shankar Sivamani * remove upper limit on paddlepaddle version Signed-off-by: Kirthi Shankar Sivamani * Remove unused imports Signed-off-by: Kirthi Shankar Sivamani * Remove pybind11 dependency Signed-off-by: Kirthi Shankar Sivamani * Fix cpp tests Signed-off-by: Kirthi Shankar Sivamani * Search .sos in cuda home Signed-off-by: Kirthi Shankar Sivamani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes Signed-off-by: Kirthi Shankar Sivamani * CLeanup, remove residual code Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/workflows/build.yml | 5 +- build_tools/build_ext.py | 8 +- build_tools/utils.py | 36 +++++---- build_tools/wheel_utils/Dockerfile.aarch | 36 +++++++++ build_tools/wheel_utils/Dockerfile.x86 | 36 +++++++++ build_tools/wheel_utils/build_wheels.sh | 79 +++++++++++++++++++ build_tools/wheel_utils/launch_aarch.sh | 8 ++ build_tools/wheel_utils/launch_x86.sh | 8 ++ qa/L0_jax_wheel/test.sh | 21 +++++ qa/L0_paddle_wheel/test.sh | 21 +++++ qa/L0_pytorch_wheel/test.sh | 21 +++++ setup.py | 21 ++++- tests/cpp/CMakeLists.txt | 1 + tests/cpp/operator/CMakeLists.txt | 2 +- tests/cpp/util/CMakeLists.txt | 3 +- transformer_engine/common/CMakeLists.txt | 6 +- transformer_engine/common/__init__.py | 63 ++++++++++++++- transformer_engine/jax/MANIFEST.in | 3 + .../jax/csrc/extensions/activation.cpp | 2 +- .../jax/csrc/extensions/attention.cpp | 2 +- .../jax/csrc/extensions/misc.cpp | 2 +- .../jax/csrc/extensions/normalization.cpp | 2 +- .../jax/csrc/extensions/packing.cpp | 2 +- .../jax/csrc/extensions/pybind.cpp | 2 +- .../jax/csrc/extensions/quantization.cpp | 2 +- .../jax/csrc/extensions/softmax.cpp | 2 +- .../jax/csrc/extensions/transpose.cpp | 2 +- transformer_engine/jax/setup.py | 11 +-- transformer_engine/paddle/MANIFEST.in | 3 + transformer_engine/paddle/setup.py | 22 ++---- transformer_engine/pytorch/MANIFEST.in | 3 + transformer_engine/pytorch/setup.py | 11 +-- 32 files changed, 378 insertions(+), 68 deletions(-) create mode 100644 build_tools/wheel_utils/Dockerfile.aarch create mode 100644 build_tools/wheel_utils/Dockerfile.x86 create mode 100644 build_tools/wheel_utils/build_wheels.sh create mode 100644 build_tools/wheel_utils/launch_aarch.sh create mode 100644 build_tools/wheel_utils/launch_x86.sh create mode 100644 qa/L0_jax_wheel/test.sh create mode 100644 qa/L0_paddle_wheel/test.sh create mode 100644 qa/L0_pytorch_wheel/test.sh create mode 100644 transformer_engine/jax/MANIFEST.in create mode 100644 transformer_engine/paddle/MANIFEST.in create mode 100644 transformer_engine/pytorch/MANIFEST.in diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index acec20b566..2770919947 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -78,7 +78,10 @@ jobs: with: submodules: recursive - name: 'Build' - run: pip install . -v + run: | + apt-get update + apt-get install -y libgoogle-glog-dev + pip install . -v env: NVTE_FRAMEWORK: paddle - name: 'Sanity check' diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py index 61c82f6fcc..631b2b3627 100644 --- a/build_tools/build_ext.py +++ b/build_tools/build_ext.py @@ -135,8 +135,14 @@ def run(self) -> None: search_paths = list(Path(__file__).resolve().parent.parent.iterdir()) # Source compilation from top-level search_paths.extend(list(Path(self.build_lib).iterdir())) + + # Dynamically load required_libs. + from transformer_engine.common import _load_cudnn, _load_nvrtc + + _load_cudnn() + _load_nvrtc() else: - # Only during release sdist build. + # Only during release bdist build for paddlepaddle. import transformer_engine search_paths = list(Path(transformer_engine.__path__[0]).iterdir()) diff --git a/build_tools/utils.py b/build_tools/utils.py index cf1a0bb445..3230ad35bf 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -11,6 +11,7 @@ import shutil import subprocess import sys +import importlib from pathlib import Path from subprocess import CalledProcessError from typing import List, Optional, Tuple @@ -253,15 +254,6 @@ def get_frameworks() -> List[str]: return _frameworks -def package_files(directory): - paths = [] - for path, _, filenames in os.walk(directory): - path = Path(path) - for filename in filenames: - paths.append(str(path / filename).replace(f"{directory}/", "")) - return paths - - def copy_common_headers(te_src, dst): headers = te_src / "common" for file_path in glob.glob(os.path.join(str(headers), "**", "*.h"), recursive=True): @@ -272,11 +264,21 @@ def copy_common_headers(te_src, dst): def install_and_import(package): """Install a package via pip (if not already installed) and import into globals.""" - import importlib - - try: - importlib.import_module(package) - except ImportError: - subprocess.check_call([sys.executable, "-m", "pip", "install", package]) - finally: - globals()[package] = importlib.import_module(package) + main_package = package.split("[")[0] + subprocess.check_call([sys.executable, "-m", "pip", "install", package]) + globals()[main_package] = importlib.import_module(main_package) + + +def uninstall_te_fw_packages(): + subprocess.check_call( + [ + sys.executable, + "-m", + "pip", + "uninstall", + "-y", + "transformer_engine_torch", + "transformer_engine_paddle", + "transformer_engine_jax", + ] + ) diff --git a/build_tools/wheel_utils/Dockerfile.aarch b/build_tools/wheel_utils/Dockerfile.aarch new file mode 100644 index 0000000000..a0bcd80347 --- /dev/null +++ b/build_tools/wheel_utils/Dockerfile.aarch @@ -0,0 +1,36 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +FROM quay.io/pypa/manylinux_2_28_aarch64 + +WORKDIR /TransformerEngine/ +COPY ../.. /TransformerEngine/ + +ARG VER="12-3" +ARG ARCH="aarch64" +RUN dnf -y install vim + +# Cuda toolkit, cudnn, driver. +RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo +RUN dnf -y install epel-release +RUN dnf -y install cuda-compiler-${VER}.${ARCH} \ + cuda-libraries-${VER}.${ARCH} \ + cuda-libraries-devel-${VER}.${ARCH} +RUN dnf -y install --allowerasing cudnn9-cuda-12 +RUN dnf clean all +RUN rm -rf /var/cache/dnf/* +RUN echo "/usr/local/cuda/lib64" >> /etc/ld.so.conf.d/999_nvidia_cuda.conf +RUN dnf -y install cuda-toolkit +RUN dnf clean all +RUN dnf -y install glog.aarch64 glog-devel.aarch64 + +ENV PATH="/usr/local/cuda/bin:${PATH}" +ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}" +ENV CUDA_HOME=/usr/local/cuda +ENV CUDA_ROOT=/usr/local/cuda +ENV CUDA_PATH=/usr/local/cuda +ENV CUDADIR=/usr/local/cuda +ENV NVTE_RELEASE_BUILD=1 + +CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_aarch64", "true", "false", "false", "true"] diff --git a/build_tools/wheel_utils/Dockerfile.x86 b/build_tools/wheel_utils/Dockerfile.x86 new file mode 100644 index 0000000000..602d99ed4d --- /dev/null +++ b/build_tools/wheel_utils/Dockerfile.x86 @@ -0,0 +1,36 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +FROM quay.io/pypa/manylinux_2_28_x86_64 + +WORKDIR /TransformerEngine/ +COPY ../.. /TransformerEngine/ + +ARG VER="12-3" +ARG ARCH="x86_64" +RUN dnf -y install vim + +# Cuda toolkit, cudnn, driver. +RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo +RUN dnf -y install epel-release +RUN dnf -y install cuda-compiler-${VER}.${ARCH} \ + cuda-libraries-${VER}.${ARCH} \ + cuda-libraries-devel-${VER}.${ARCH} +RUN dnf -y install --allowerasing cudnn9-cuda-12 +RUN dnf clean all +RUN rm -rf /var/cache/dnf/* +RUN echo "/usr/local/cuda/lib64" >> /etc/ld.so.conf.d/999_nvidia_cuda.conf +RUN dnf -y install cuda-toolkit +RUN dnf clean all +RUN dnf -y install glog.x86_64 glog-devel.x86_64 + +ENV PATH="/usr/local/cuda/bin:${PATH}" +ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}" +ENV CUDA_HOME=/usr/local/cuda +ENV CUDA_ROOT=/usr/local/cuda +ENV CUDA_PATH=/usr/local/cuda +ENV CUDADIR=/usr/local/cuda +ENV NVTE_RELEASE_BUILD=1 + +CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_x86_64", "true", "true", "true", "true"] diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh new file mode 100644 index 0000000000..3c616613d3 --- /dev/null +++ b/build_tools/wheel_utils/build_wheels.sh @@ -0,0 +1,79 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +set -e + +PLATFORM=${1:-manylinux_2_28_x86_64} +BUILD_COMMON=${2:-true} +BUILD_JAX=${3:-true} +BUILD_PYTORCH=${4:-true} +BUILD_PADDLE=${5:-true} + +export NVTE_RELEASE_BUILD=1 +export TARGET_BRANCH=${TARGET_BRANCH:-wheels} +mkdir /wheelhouse +mkdir /wheelhouse/logs + +# Generate wheels for common library. +git config --global --add safe.directory /TransformerEngine +cd /TransformerEngine +git checkout $TARGET_BRANCH +git submodule update --init --recursive + +if $BUILD_COMMON ; then + /opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt + whl_name=$(basename dist/*) + IFS='-' read -ra whl_parts <<< "$whl_name" + whl_name_target="${whl_parts[0]}-${whl_parts[1]}-py3-none-${whl_parts[4]}" + mv dist/"$whl_name" /wheelhouse/"$whl_name_target" +fi + +if $BUILD_PYTORCH ; then + cd /TransformerEngine/transformer_engine/pytorch + /opt/python/cp38-cp38/bin/pip install torch + /opt/python/cp38-cp38/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt + cp dist/* /wheelhouse/ +fi + +if $BUILD_JAX ; then + cd /TransformerEngine/transformer_engine/jax + /opt/python/cp38-cp38/bin/pip install jax jaxlib + /opt/python/cp38-cp38/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt + cp dist/* /wheelhouse/ +fi + +if $BUILD_PADDLE ; then + if [ "$PLATFORM" == "manylinux_2_28_x86_64" ] ; then + dnf -y remove --allowerasing cudnn9-cuda-12 + dnf -y install libcudnn8-devel.x86_64 libcudnn8.x86_64 + cd /TransformerEngine/transformer_engine/paddle + + /opt/python/cp38-cp38/bin/pip install /wheelhouse/*.whl + /opt/python/cp38-cp38/bin/pip install paddlepaddle-gpu==2.6.1 + /opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp38.txt + /opt/python/cp38-cp38/bin/pip uninstall -y transformer-engine paddlepaddle-gpu + + /opt/python/cp39-cp39/bin/pip install /wheelhouse/*.whl + /opt/python/cp39-cp39/bin/pip install paddlepaddle-gpu==2.6.1 + /opt/python/cp39-cp39/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp39.txt + /opt/python/cp39-cp39/bin/pip uninstall -y transformer-engine paddlepaddle-gpu + + /opt/python/cp310-cp310/bin/pip install /wheelhouse/*.whl + /opt/python/cp310-cp310/bin/pip install paddlepaddle-gpu==2.6.1 + /opt/python/cp310-cp310/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp310.txt + /opt/python/cp310-cp310/bin/pip uninstall -y transformer-engine paddlepaddle-gpu + + /opt/python/cp311-cp311/bin/pip install /wheelhouse/*.whl + /opt/python/cp311-cp311/bin/pip install paddlepaddle-gpu==2.6.1 + /opt/python/cp311-cp311/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp311.txt + /opt/python/cp311-cp311/bin/pip uninstall -y transformer-engine paddlepaddle-gpu + + /opt/python/cp312-cp312/bin/pip install /wheelhouse/*.whl + /opt/python/cp312-cp312/bin/pip install paddlepaddle-gpu==2.6.1 + /opt/python/cp312-cp312/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp312.txt + /opt/python/cp312-cp312/bin/pip uninstall -y transformer-engine paddlepaddle-gpu + + mv dist/* /wheelhouse/ + fi +fi diff --git a/build_tools/wheel_utils/launch_aarch.sh b/build_tools/wheel_utils/launch_aarch.sh new file mode 100644 index 0000000000..9a8d796119 --- /dev/null +++ b/build_tools/wheel_utils/launch_aarch.sh @@ -0,0 +1,8 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +docker build --no-cache -t "aarch_wheel" -f build_tools/wheel_utils/Dockerfile.aarch . +docker run --runtime=nvidia --gpus=all --ipc=host "aarch_wheel" +rm -rf aarch_wheelhouse +docker cp $(docker ps -aq | head -1):/wheelhouse/ aarch_wheelhouse diff --git a/build_tools/wheel_utils/launch_x86.sh b/build_tools/wheel_utils/launch_x86.sh new file mode 100644 index 0000000000..7b5649a642 --- /dev/null +++ b/build_tools/wheel_utils/launch_x86.sh @@ -0,0 +1,8 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +docker build --no-cache -t "x86_wheel" -f build_tools/wheel_utils/Dockerfile.x86 . +docker run --runtime=nvidia --gpus=all --ipc=host "x86_wheel" +rm -rf x86_wheelhouse +docker cp $(docker ps -aq | head -1):/wheelhouse x86_wheelhouse diff --git a/qa/L0_jax_wheel/test.sh b/qa/L0_jax_wheel/test.sh new file mode 100644 index 0000000000..109633495b --- /dev/null +++ b/qa/L0_jax_wheel/test.sh @@ -0,0 +1,21 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +set -e + +: "${TE_PATH:=/opt/transformerengine}" + +cd $TE_PATH +pip uninstall -y transformer-engine +export NVTE_RELEASE_BUILD=1 +python setup.py bdist_wheel +cd transformer_engine/jax +python setup.py sdist + +export NVTE_RELEASE_BUILD=0 +pip install dist/* +cd $TE_PATH +pip install dist/* + +python $TE_PATH/tests/jax/test_sanity_import.py diff --git a/qa/L0_paddle_wheel/test.sh b/qa/L0_paddle_wheel/test.sh new file mode 100644 index 0000000000..c1e9a95615 --- /dev/null +++ b/qa/L0_paddle_wheel/test.sh @@ -0,0 +1,21 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +set -e + +: "${TE_PATH:=/opt/transformerengine}" + +cd $TE_PATH +pip uninstall -y transformer-engine +export NVTE_RELEASE_BUILD=1 +python setup.py bdist_wheel +pip install dist/* +cd transformer_engine/paddle +python setup.py bdist_wheel + +export NVTE_RELEASE_BUILD=0 +cd $TE_PATH +pip install dist/* + +python $TE_PATH/tests/paddle/test_sanity_import.py diff --git a/qa/L0_pytorch_wheel/test.sh b/qa/L0_pytorch_wheel/test.sh new file mode 100644 index 0000000000..e108e93cdb --- /dev/null +++ b/qa/L0_pytorch_wheel/test.sh @@ -0,0 +1,21 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +set -e + +: "${TE_PATH:=/opt/transformerengine}" + +cd $TE_PATH +pip uninstall -y transformer-engine +export NVTE_RELEASE_BUILD=1 +python setup.py bdist_wheel +cd transformer_engine/pytorch +python setup.py sdist + +export NVTE_RELEASE_BUILD=0 +pip install dist/* +cd $TE_PATH +pip install dist/* + +python $TE_PATH/tests/pytorch/test_sanity_import.py diff --git a/setup.py b/setup.py index 41521418ba..6a8bae2793 100644 --- a/setup.py +++ b/setup.py @@ -18,6 +18,7 @@ remove_dups, get_frameworks, install_and_import, + uninstall_te_fw_packages, ) from build_tools.te_version import te_version @@ -28,12 +29,14 @@ from setuptools.command.build_ext import build_ext as BuildExtension +os.environ["NVTE_PROJECT_BUILDING"] = "1" + if "pytorch" in frameworks: from torch.utils.cpp_extension import BuildExtension elif "paddle" in frameworks: from paddle.utils.cpp_extension import BuildExtension elif "jax" in frameworks: - install_and_import("pybind11") + install_and_import("pybind11[global]") from pybind11.setup_helpers import build_ext as BuildExtension @@ -61,7 +64,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: setup_reqs: List[str] = [] install_reqs: List[str] = [ "pydantic", - "importlib-metadata>=1.0; python_version<'3.8'", + "importlib-metadata>=1.0", "packaging", ] test_reqs: List[str] = ["pytest>=8.2.1"] @@ -85,6 +88,9 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: ext_modules = [setup_common_extension()] if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): + # Remove residual FW packages since compiling from source + # results in a single binary with FW extensions included. + uninstall_te_fw_packages() if "pytorch" in frameworks: from build_tools.pytorch import setup_pytorch_extension @@ -129,10 +135,21 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: ), extras_require={ "test": test_requires, + "pytorch": [f"transformer_engine_torch=={__version__}"], + "jax": [f"transformer_engine_jax=={__version__}"], + "paddle": [f"transformer_engine_paddle=={__version__}"], }, description="Transformer acceleration library", ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension}, + python_requires=">=3.8, <3.13", + classifiers=[ + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + ], setup_requires=setup_requires, install_requires=install_requires, license_files=("LICENSE",), diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 9eb50a4c7d..3bef457c43 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -34,6 +34,7 @@ include_directories(../../transformer_engine/common) include_directories(${CMAKE_SOURCE_DIR}) find_package(CUDAToolkit REQUIRED) +include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake) add_subdirectory(operator) add_subdirectory(util) diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 0dd2a6d8e2..9dd02d4181 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -18,7 +18,7 @@ add_executable(test_operator test_causal_softmax.cu ../test_common.cu) -list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB}) +list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn) target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS}) target_compile_options(test_operator PRIVATE -O2) diff --git a/tests/cpp/util/CMakeLists.txt b/tests/cpp/util/CMakeLists.txt index 42a41b06af..d93be956b0 100644 --- a/tests/cpp/util/CMakeLists.txt +++ b/tests/cpp/util/CMakeLists.txt @@ -7,7 +7,8 @@ add_executable(test_util test_string.cpp ../test_common.cu) -target_link_libraries(test_util PUBLIC CUDA::cudart GTest::gtest_main ${TE_LIB}) + +target_link_libraries(test_util PUBLIC CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn) target_compile_options(test_util PRIVATE -O2) include(GoogleTest) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index e22e8dbbc8..242689f990 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -32,7 +32,6 @@ endif() include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake) find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) - include_directories(${PROJECT_SOURCE_DIR}/..) # Configure Transformer Engine library @@ -77,9 +76,7 @@ target_include_directories(transformer_engine PUBLIC target_link_libraries(transformer_engine PUBLIC CUDA::cublas CUDA::cuda_driver - CUDA::cudart - CUDA::nvrtc - CUDNN::cudnn) + CUDA::cudart) target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") @@ -125,3 +122,4 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3") # Install library install(TARGETS transformer_engine DESTINATION .) + diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 66be4b1baa..f4eb2c419f 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -4,6 +4,9 @@ """FW agnostic user-end APIs""" +import glob +import sysconfig +import subprocess import ctypes import os import platform @@ -31,6 +34,39 @@ def _get_sys_extension(): return extension +def _load_cudnn(): + """Load CUDNN shared library.""" + + lib_path = glob.glob( + os.path.join( + sysconfig.get_path("purelib"), + f"nvidia/cudnn/lib/libcudnn.{_get_sys_extension()}.*[0-9]", + ) + ) + + if lib_path: + assert ( + len(lib_path) == 1 + ), f"Found {len(lib_path)} libcudnn.{_get_sys_extension()}.x in nvidia-cudnn-cuXX." + return ctypes.CDLL(lib_path[0], mode=ctypes.RTLD_GLOBAL) + + cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH") + if cudnn_home: + libs = glob.glob(f"{cudnn_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True) + libs.sort(reverse=True, key=os.path.basename) + if libs: + return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) + + cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") + if cuda_home: + libs = glob.glob(f"{cuda_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True) + libs.sort(reverse=True, key=os.path.basename) + if libs: + return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) + + return ctypes.CDLL(f"libcudnn.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) + + def _load_library(): """Load shared library with Transformer Engine C extensions""" @@ -42,5 +78,30 @@ def _load_library(): return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL) -if "NVTE_PROJECT_BUILDING" not in os.environ: +def _load_nvrtc(): + """Load NVRTC shared library.""" + cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") + if cuda_home: + libs = glob.glob(f"{cuda_home}/**/libnvrtc.{_get_sys_extension()}*", recursive=True) + libs = list(filter(lambda x: not ("stub" in x or "libnvrtc-builtins" in x), libs)) + libs.sort(reverse=True, key=os.path.basename) + if libs: + return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) + + libs = subprocess.check_output("ldconfig -p | grep 'libnvrtc'", shell=True) + libs = libs.decode("utf-8").split("\n") + sos = [] + for lib in libs: + if "stub" in lib or "libnvrtc-builtins" in lib: + continue + if "libnvrtc" in lib and "=>" in lib: + sos.append(lib.split(">")[1].strip()) + if sos: + return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL) + return ctypes.CDLL(f"libnvrtc.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) + + +if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): + _CUDNN_LIB_CTYPES = _load_cudnn() + _NVRTC_LIB_CTYPES = _load_nvrtc() _TE_LIB_CTYPES = _load_library() diff --git a/transformer_engine/jax/MANIFEST.in b/transformer_engine/jax/MANIFEST.in new file mode 100644 index 0000000000..0c814f95da --- /dev/null +++ b/transformer_engine/jax/MANIFEST.in @@ -0,0 +1,3 @@ +recursive-include build_tools *.* +recursive-include common_headers *.* +recursive-include csrc *.* diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index f291aaecef..51563a8ccd 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -6,7 +6,7 @@ #include "transformer_engine/activation.h" -#include "jax/csrc/extensions.h" +#include "extensions.h" #include "transformer_engine/transpose.h" namespace transformer_engine { diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index bcc49b92c1..640869ac36 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -4,7 +4,7 @@ * See LICENSE for license information. ************************************************************************/ -#include "jax/csrc/extensions.h" +#include "extensions.h" #include "transformer_engine/fused_attn.h" namespace transformer_engine { diff --git a/transformer_engine/jax/csrc/extensions/misc.cpp b/transformer_engine/jax/csrc/extensions/misc.cpp index c40e899e62..357a5679db 100644 --- a/transformer_engine/jax/csrc/extensions/misc.cpp +++ b/transformer_engine/jax/csrc/extensions/misc.cpp @@ -4,7 +4,7 @@ * See LICENSE for license information. ************************************************************************/ -#include "jax/csrc/extensions.h" +#include "extensions.h" namespace transformer_engine { namespace jax { diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp b/transformer_engine/jax/csrc/extensions/normalization.cpp index c93bd13c25..9585e2edf1 100644 --- a/transformer_engine/jax/csrc/extensions/normalization.cpp +++ b/transformer_engine/jax/csrc/extensions/normalization.cpp @@ -4,7 +4,7 @@ * See LICENSE for license information. ************************************************************************/ -#include "jax/csrc/extensions.h" +#include "extensions.h" #include "transformer_engine/layer_norm.h" #include "transformer_engine/rmsnorm.h" diff --git a/transformer_engine/jax/csrc/extensions/packing.cpp b/transformer_engine/jax/csrc/extensions/packing.cpp index 89d8596ce0..8c948d0a8f 100644 --- a/transformer_engine/jax/csrc/extensions/packing.cpp +++ b/transformer_engine/jax/csrc/extensions/packing.cpp @@ -4,7 +4,7 @@ * See LICENSE for license information. ************************************************************************/ -#include "jax/csrc/extensions.h" +#include "extensions.h" namespace transformer_engine { namespace jax { diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 3302b2e3c0..95fe3101c9 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -4,7 +4,7 @@ * See LICENSE for license information. ************************************************************************/ -#include "jax/csrc/extensions.h" +#include "extensions.h" namespace transformer_engine { namespace jax { diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index 67a2519788..ba376c6238 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -4,7 +4,7 @@ * See LICENSE for license information. ************************************************************************/ -#include "jax/csrc/extensions.h" +#include "extensions.h" #include "transformer_engine/cast.h" namespace transformer_engine { diff --git a/transformer_engine/jax/csrc/extensions/softmax.cpp b/transformer_engine/jax/csrc/extensions/softmax.cpp index 18d59667a9..3af32d1d84 100644 --- a/transformer_engine/jax/csrc/extensions/softmax.cpp +++ b/transformer_engine/jax/csrc/extensions/softmax.cpp @@ -6,7 +6,7 @@ #include "transformer_engine/softmax.h" -#include "jax/csrc/extensions.h" +#include "extensions.h" namespace transformer_engine { namespace jax { diff --git a/transformer_engine/jax/csrc/extensions/transpose.cpp b/transformer_engine/jax/csrc/extensions/transpose.cpp index 3de1856043..3e53b7521f 100644 --- a/transformer_engine/jax/csrc/extensions/transpose.cpp +++ b/transformer_engine/jax/csrc/extensions/transpose.cpp @@ -6,7 +6,7 @@ #include "transformer_engine/transpose.h" -#include "jax/csrc/extensions.h" +#include "extensions.h" namespace transformer_engine { namespace jax { diff --git a/transformer_engine/jax/setup.py b/transformer_engine/jax/setup.py index 19656ced94..c2219e3ba9 100644 --- a/transformer_engine/jax/setup.py +++ b/transformer_engine/jax/setup.py @@ -29,13 +29,14 @@ from build_tools.build_ext import get_build_ext -from build_tools.utils import package_files, copy_common_headers, install_and_import +from build_tools.utils import copy_common_headers, install_and_import from build_tools.te_version import te_version from build_tools.jax import setup_jax_extension install_and_import("pybind11") from pybind11.setup_helpers import build_ext as BuildExtension +os.environ["NVTE_PROJECT_BUILDING"] = "1" CMakeBuildExtension = get_build_ext(BuildExtension) @@ -53,18 +54,12 @@ setuptools.setup( name="transformer_engine_jax", version=te_version(), - packages=["csrc", common_headers_dir, "build_tools"], description="Transformer acceleration library - Jax Lib", ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension}, install_requires=["jax", "flax>=0.7.1"], tests_require=["numpy", "praxis"], - include_package_data=True, - package_data={ - "csrc": package_files("csrc"), - common_headers_dir: package_files(common_headers_dir), - "build_tools": package_files("build_tools"), - }, ) if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): shutil.rmtree(common_headers_dir) + shutil.rmtree("build_tools") diff --git a/transformer_engine/paddle/MANIFEST.in b/transformer_engine/paddle/MANIFEST.in new file mode 100644 index 0000000000..0c814f95da --- /dev/null +++ b/transformer_engine/paddle/MANIFEST.in @@ -0,0 +1,3 @@ +recursive-include build_tools *.* +recursive-include common_headers *.* +recursive-include csrc *.* diff --git a/transformer_engine/paddle/setup.py b/transformer_engine/paddle/setup.py index 3ab8420fe7..5b1d1a1e04 100644 --- a/transformer_engine/paddle/setup.py +++ b/transformer_engine/paddle/setup.py @@ -29,15 +29,13 @@ shutil.copytree(build_tools_dir, build_tools_copy) -from build_tools.build_ext import get_build_ext # pylint: disable=wrong-import-position -from build_tools.utils import ( - package_files, - copy_common_headers, -) # pylint: disable=wrong-import-position -from build_tools.te_version import te_version # pylint: disable=wrong-import-position -from build_tools.paddle import setup_paddle_extension # pylint: disable=wrong-import-position +from build_tools.build_ext import get_build_ext +from build_tools.utils import copy_common_headers +from build_tools.te_version import te_version +from build_tools.paddle import setup_paddle_extension +os.environ["NVTE_PROJECT_BUILDING"] = "1" CMakeBuildExtension = get_build_ext(BuildExtension) @@ -55,18 +53,12 @@ setuptools.setup( name="transformer_engine_paddle", version=te_version(), - packages=["csrc", common_headers_dir, "build_tools"], description="Transformer acceleration library - Paddle Paddle Lib", ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension}, - install_requires=["paddlepaddle-gpu"], + install_requires=["paddlepaddle-gpu>=2.6.1"], tests_require=["numpy"], - include_package_data=True, - package_data={ - "csrc": package_files("csrc"), - common_headers_dir: package_files(common_headers_dir), - "build_tools": package_files("build_tools"), - }, ) if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): shutil.rmtree(common_headers_dir) + shutil.rmtree("build_tools") diff --git a/transformer_engine/pytorch/MANIFEST.in b/transformer_engine/pytorch/MANIFEST.in new file mode 100644 index 0000000000..0c814f95da --- /dev/null +++ b/transformer_engine/pytorch/MANIFEST.in @@ -0,0 +1,3 @@ +recursive-include build_tools *.* +recursive-include common_headers *.* +recursive-include csrc *.* diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index 9d0f24b478..e2f15d5d89 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -30,11 +30,12 @@ from build_tools.build_ext import get_build_ext -from build_tools.utils import package_files, copy_common_headers +from build_tools.utils import copy_common_headers from build_tools.te_version import te_version from build_tools.pytorch import setup_pytorch_extension +os.environ["NVTE_PROJECT_BUILDING"] = "1" CMakeBuildExtension = get_build_ext(BuildExtension) @@ -52,18 +53,12 @@ setuptools.setup( name="transformer_engine_torch", version=te_version(), - packages=["csrc", common_headers_dir, "build_tools"], description="Transformer acceleration library - Torch Lib", ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension}, install_requires=["torch", "flash-attn>=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0"], tests_require=["numpy", "onnxruntime", "torchvision"], - include_package_data=True, - package_data={ - "csrc": package_files("csrc"), - common_headers_dir: package_files(common_headers_dir), - "build_tools": package_files("build_tools"), - }, ) if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): shutil.rmtree(common_headers_dir) + shutil.rmtree("build_tools") From 1aaf1cc8b570f317205a6c38acdcaa150506b1cd Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 25 Jul 2024 09:50:48 -0700 Subject: [PATCH 20/72] Fixes for pip wheels (#1042) * Fixes for wheels Signed-off-by: Kirthi Shankar Sivamani * Fix paddle wheel test Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- build_tools/wheel_utils/build_wheels.sh | 5 ++--- qa/L0_paddle_wheel/test.sh | 1 - transformer_engine/common/CMakeLists.txt | 3 +++ transformer_engine/common/pycudnn.cpp | 14 ++++++++++++++ 4 files changed, 19 insertions(+), 4 deletions(-) create mode 100644 transformer_engine/common/pycudnn.cpp diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh index 3c616613d3..1896fc4e42 100644 --- a/build_tools/wheel_utils/build_wheels.sh +++ b/build_tools/wheel_utils/build_wheels.sh @@ -11,9 +11,8 @@ BUILD_PYTORCH=${4:-true} BUILD_PADDLE=${5:-true} export NVTE_RELEASE_BUILD=1 -export TARGET_BRANCH=${TARGET_BRANCH:-wheels} -mkdir /wheelhouse -mkdir /wheelhouse/logs +export TARGET_BRANCH=${TARGET_BRANCH:-} +mkdir -p /wheelhouse/logs # Generate wheels for common library. git config --global --add safe.directory /TransformerEngine diff --git a/qa/L0_paddle_wheel/test.sh b/qa/L0_paddle_wheel/test.sh index c1e9a95615..e2d6d38dd4 100644 --- a/qa/L0_paddle_wheel/test.sh +++ b/qa/L0_paddle_wheel/test.sh @@ -15,7 +15,6 @@ cd transformer_engine/paddle python setup.py bdist_wheel export NVTE_RELEASE_BUILD=0 -cd $TE_PATH pip install dist/* python $TE_PATH/tests/paddle/test_sanity_import.py diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 242689f990..b814ef5974 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -37,6 +37,7 @@ include_directories(${PROJECT_SOURCE_DIR}/..) # Configure Transformer Engine library set(transformer_engine_SOURCES) list(APPEND transformer_engine_SOURCES + pycudnn.cpp transformer_engine.cpp transpose/cast_transpose.cu transpose/transpose.cu @@ -72,6 +73,8 @@ add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") +target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING) + # Configure dependencies target_link_libraries(transformer_engine PUBLIC CUDA::cublas diff --git a/transformer_engine/common/pycudnn.cpp b/transformer_engine/common/pycudnn.cpp new file mode 100644 index 0000000000..7d06f332cb --- /dev/null +++ b/transformer_engine/common/pycudnn.cpp @@ -0,0 +1,14 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +namespace cudnn_frontend { + +// This is needed to define the symbol `cudnn_dlhandle` +// When using the flag NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING +// to enable dynamic loading. +void *cudnn_dlhandle = nullptr; + +} // namespace cudnn_frontend From 4cc220c9b0be8fc1777d6f19ab67620aa1dc23c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E9=87=91=E6=A2=81?= <975761915@qq.com> Date: Fri, 26 Jul 2024 00:56:42 +0800 Subject: [PATCH 21/72] fix bug of attn backward in non-casual model with context parallel open. (#1031) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This bug will cause bug [ERROR] failed (exitcode: -11) local_rank: 0 (pid: 1761020) of binary: ~/megatron/bin/python. That is because we miss the rng_states that is required in attention recompute (for dropout), but no hint is provided. It is very very very difficult to trace and cost me two weeks. ```python before the start of training step] datetime: 2024-07-22 18:26:45 [2024-07-22 18:27:00,941] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: -11) local_rank: 0 (pid: 1761020) of binary: /home//miniconda3/envs/megatron/bin/python Traceback (most recent call last): File "/home//miniconda3/envs/megatron/bin/torchrun", line 33, in sys.exit(load_entry_point('torch==2.2.1+cu121', 'console_scripts', 'torchrun')()) File "/home//miniconda3/envs/megatron/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper return f(*args, **kwargs) File "/home//miniconda3/envs/megatron/lib/python3.10/site-packages/torch/distributed/run.py", line 812, in main run(args) File "/home//miniconda3/envs/megatron/lib/python3.10/site-packages/torch/distributed/run.py", line 803, in run elastic_launch( File "/home//miniconda3/envs/megatron/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 135, in __call__ return launch_agent(self._config, self._entrypoint, list(args)) File "/home//miniconda3/envs/megatron/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 268, in launch_agent raise ChildFailedError( torch.distributed.elastic.multiprocessing.errors.ChildFailedError: ``` Signed-off-by: 李金梁 <975761915@qq.com> --- transformer_engine/pytorch/attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index eda0c136d5..44f0f633a0 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -2132,6 +2132,7 @@ def backward(ctx, dout): ctx.dropout_p, ctx.softmax_scale, False, + rng_state=rng_states[cp_size - i - 1], **fa_optional_backward_kwargs, ) From 0b303dad4c66693351dda93e11c7b61b873751b0 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 26 Jul 2024 07:59:46 -0700 Subject: [PATCH 22/72] [PyTorch] Fix tp_size for MQA/GQA (#1044) fix tp_size for GQA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 44f0f633a0..fa72ecfa33 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -5125,7 +5125,7 @@ def __init__( self.hidden_size_per_attention_head = kv_channels self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups - self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size) + self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size) assert ( num_attention_heads % self.num_gqa_groups == 0 From 81dd6ad0a097bbb9f6690ef3203218ae1b30987e Mon Sep 17 00:00:00 2001 From: Tian Zheng Date: Tue, 30 Jul 2024 08:17:47 +0800 Subject: [PATCH 23/72] [Paddle] Update Paddle image (#1053) Update Paddle image Signed-off-by: Tian Zheng --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 2770919947..fb7ab345d1 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -70,7 +70,7 @@ jobs: name: 'PaddlePaddle' runs-on: ubuntu-latest container: - image: nvcr.io/nvidia/paddlepaddle:24.05-py3 + image: nvcr.io/nvidia/paddlepaddle:24.07-py3 options: --user root steps: - name: 'Checkout' From d793ca17ac3b1206f420f273c1747e648cde96f6 Mon Sep 17 00:00:00 2001 From: Shijie Date: Wed, 31 Jul 2024 00:15:25 +0800 Subject: [PATCH 24/72] [Paddle] Add deterministic option in DotProductAttention (#956) add deterministic option Signed-off-by: Shijie Wang Co-authored-by: Kirthi Shankar Sivamani --- tests/paddle/test_layers.py | 16 ++++++- .../fused_attn_f16_arbitrary_seqlen.cu | 2 +- transformer_engine/paddle/cpp_extensions.py | 6 +++ transformer_engine/paddle/csrc/custom_ops.cu | 43 ++++++++++--------- transformer_engine/paddle/layer/attention.py | 40 +++++++++++++++++ 5 files changed, 85 insertions(+), 22 deletions(-) diff --git a/tests/paddle/test_layers.py b/tests/paddle/test_layers.py index 6a985d7e86..b519fc0a0f 100644 --- a/tests/paddle/test_layers.py +++ b/tests/paddle/test_layers.py @@ -872,8 +872,9 @@ def test_layernorm_mlp_fp8_microbatch( @pytest.mark.parametrize("attn_type", ["self", "cross"]) @pytest.mark.parametrize("mask_type", ["causal", "padding"]) @pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"]) +@pytest.mark.parametrize("deterministic", [True, False]) def test_dot_product_attention( - bs, hidden_size, num_heads, q_seqlen, kv_seqlen, attn_type, mask_type, math_dtype + bs, hidden_size, num_heads, q_seqlen, kv_seqlen, attn_type, mask_type, math_dtype, deterministic ): """ Test DotProductAttention Layer @@ -927,6 +928,10 @@ def test_dot_product_attention( attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False head_size = hidden_size // num_heads + + if deterministic: + os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" + layer_te = te.DotProductAttention( num_heads, head_size, @@ -981,6 +986,15 @@ def calc_attn_output_and_grad(layer, q, k, v, mask, dout): assert_allclose(q_grad, valid_q_grad_ref, rtol=rtol, atol=atol) assert_allclose(k_grad, valid_k_grad_ref, rtol=rtol, atol=atol) assert_allclose(v_grad, valid_v_grad_ref, rtol=rtol, atol=atol) + if deterministic: + out2, q_grad2, k_grad2, v_grad2 = calc_attn_output_and_grad( + layer_te, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out + ) + assert_allclose(out, out2, rtol=1e-12, atol=1e-12) + assert_allclose(q_grad, q_grad2, rtol=1e-12, atol=1e-12) + assert_allclose(k_grad, k_grad2, rtol=1e-12, atol=1e-12) + assert_allclose(v_grad, v_grad2, rtol=1e-12, atol=1e-12) + os.environ.pop("NVTE_ALLOW_NONDETERMINISTIC_ALGO", None) @pytest.mark.parametrize("bs", [1, 2]) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 525fd3330d..7ee7ba33bd 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -586,7 +586,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( sdpa_backward_options.set_sliding_window_length(window_size_left); } - if (cudnn_runtime_version >= 90000 && sm_arch_ >= 90) { + if (cudnn_runtime_version >= 90000) { sdpa_backward_options.set_deterministic_algorithm(deterministic); } diff --git a/transformer_engine/paddle/cpp_extensions.py b/transformer_engine/paddle/cpp_extensions.py index cd57458c41..7860da2496 100644 --- a/transformer_engine/paddle/cpp_extensions.py +++ b/transformer_engine/paddle/cpp_extensions.py @@ -659,6 +659,7 @@ def fused_attn_bwd_qkvpacked( qkv_layout: str = "bs3hd", bias_type: str = "no_bias", attn_mask_type: str = "padding", + deterministic: bool = False, ) -> Tuple[paddle.Tensor, paddle.Tensor]: """Fused Attention BWD for packed QKV input""" @@ -715,6 +716,7 @@ def fused_attn_bwd_qkvpacked( bias_type, attn_mask_type, int(qkv_dtype), + deterministic, ) return dqkv, dbias @@ -855,6 +857,7 @@ def fused_attn_bwd_kvpacked( qkv_layout: str = "bshd_bs2hd", bias_type: str = "no_bias", attn_mask_type: str = "padding", + deterministic: bool = False, ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: """Fused Attention BWD for packed KV input""" @@ -921,6 +924,7 @@ def fused_attn_bwd_kvpacked( bias_type, attn_mask_type, int(qkv_dtype), + deterministic, ) return dq, dkv, dbias @@ -1061,6 +1065,7 @@ def fused_attn_bwd( qkv_layout: str = "bshd_bshd_bshd", bias_type: str = "no_bias", attn_mask_type: str = "padding", + deterministic: bool = False, ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: """Fused Attention BWD for packed KV input""" @@ -1130,6 +1135,7 @@ def fused_attn_bwd( bias_type, attn_mask_type, int(qkv_dtype), + deterministic, ) return dq, dk, dv, dbias diff --git a/transformer_engine/paddle/csrc/custom_ops.cu b/transformer_engine/paddle/csrc/custom_ops.cu index 69569d5584..904d979b8e 100644 --- a/transformer_engine/paddle/csrc/custom_ops.cu +++ b/transformer_engine/paddle/csrc/custom_ops.cu @@ -708,7 +708,8 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor int64_t b, int64_t h, int64_t d, int64_t total_seqs, int64_t max_seqlen, float attn_scale, float p_dropout, const std::string &qkv_layout, const std::string &bias_type, - const std::string &attn_mask_type, int64_t qkv_type) { + const std::string &attn_mask_type, int64_t qkv_type, + bool deterministic) { TensorWrapper te_dBias; if (bias_type != "no_bias" && dBias) { auto bias_shape = dBias->shape(); @@ -759,22 +760,22 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_bwd_qkvpacked(te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), - &nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(), - te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen, - attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, - attn_mask_type_enum, -1, -1, true, workspace.data(), QKV.stream()); + nvte_fused_attn_bwd_qkvpacked( + te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, + te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen, + attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1, + deterministic, workspace.data(), QKV.stream()); // allocate memory for workspace auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place()); workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); // execute kernel - nvte_fused_attn_bwd_qkvpacked(te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), - &nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(), - te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen, - attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, - attn_mask_type_enum, -1, -1, true, workspace.data(), QKV.stream()); + nvte_fused_attn_bwd_qkvpacked( + te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, + te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen, + attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1, + deterministic, workspace.data(), QKV.stream()); // destroy tensor wrappers nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); @@ -884,7 +885,7 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K int64_t total_seqs_kv, int64_t max_seqlen_q, int64_t max_seqlen_kv, float attn_scale, float p_dropout, const std::string &qkv_layout, const std::string &bias_type, const std::string &attn_mask_type, - int64_t qkv_type) { + int64_t qkv_type, bool deterministic) { TensorWrapper te_dBias; if (bias_type != "no_bias" && dBias) { auto bias_shape = dBias->shape(); @@ -945,7 +946,7 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, - -1, -1, true, workspace.data(), Q.stream()); + -1, -1, deterministic, workspace.data(), Q.stream()); // allocate memory for workspace auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); @@ -957,7 +958,7 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, - -1, -1, true, workspace.data(), Q.stream()); + -1, -1, deterministic, workspace.data(), Q.stream()); // destroy tensor wrappers nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); @@ -1086,7 +1087,7 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p int64_t b, int64_t h, int64_t d, int64_t max_seqlen_q, int64_t max_seqlen_kv, float attn_scale, float p_dropout, const std::string &qkv_layout, const std::string &bias_type, const std::string &attn_mask_type, - int64_t qkv_type) { + int64_t qkv_type, bool deterministic) { TensorWrapper te_dBias; if (bias_type != "no_bias" && dBias) { auto bias_shape = dBias->shape(); @@ -1149,7 +1150,7 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, - attn_mask_type_enum, -1, -1, true, workspace.data(), Q.stream()); + attn_mask_type_enum, -1, -1, deterministic, workspace.data(), Q.stream()); // allocate memory for workspace auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); @@ -1161,7 +1162,7 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, - attn_mask_type_enum, -1, -1, true, workspace.data(), Q.stream()); + attn_mask_type_enum, -1, -1, deterministic, workspace.data(), Q.stream()); // destroy tensor wrappers nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); @@ -1657,7 +1658,8 @@ PD_BUILD_OP(te_fused_attn_bwd_qkvpacked) .Outputs({"dQKV", paddle::Optional("dBias")}) .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs: int64_t", "max_seqlen: int64_t", "attn_scale: float", "p_dropout: float", "qkv_layout: std::string", - "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t"}) + "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t", + "deterministic: bool"}) .SetInplaceMap({{"_dQKV", "dQKV"}, {paddle::Optional("_dBias"), paddle::Optional("dBias")}}) .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_bwd_qkvpacked)); @@ -1682,7 +1684,8 @@ PD_BUILD_OP(te_fused_attn_bwd_kvpacked) .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs_q: int64_t", "total_seqs_kv: int64_t", "max_seqlen_q: int64_t", "max_seqlen_kv: int64_t", "attn_scale: float", "p_dropout: float", "qkv_layout: std::string", - "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t"}) + "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t", + "deterministic: bool"}) .SetInplaceMap({{"_dQ", "dQ"}, {"_dKV", "dKV"}, {paddle::Optional("_dBias"), paddle::Optional("dBias")}}) @@ -1708,7 +1711,7 @@ PD_BUILD_OP(te_fused_attn_bwd) .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "max_seqlen_q: int64_t", "max_seqlen_kv: int64_t", "attn_scale: float", "p_dropout: float", "qkv_layout: std::string", "bias_type: std::string", "attn_mask_type: std::string", - "qkv_type: int64_t"}) + "qkv_type: int64_t", "deterministic: bool"}) .SetInplaceMap({{"_dQ", "dQ"}, {"_dK", "dK"}, {"_dV", "dV"}, diff --git a/transformer_engine/paddle/layer/attention.py b/transformer_engine/paddle/layer/attention.py index 98e50b9e04..75a3513d14 100644 --- a/transformer_engine/paddle/layer/attention.py +++ b/transformer_engine/paddle/layer/attention.py @@ -152,6 +152,7 @@ def forward( attn_bias_type, attn_mask_type, is_training, + deterministic, fused_attention_backend, ): """Forward function for FusedAttention with packed QKV input""" @@ -180,6 +181,7 @@ def forward( ctx.qkv_layout = qkv_layout ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type + ctx.deterministic = deterministic ctx.fused_attention_backend = fused_attention_backend return out @@ -204,6 +206,7 @@ def backward(ctx, d_out): ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, + ctx.deterministic, ) # if no_bias, return dqkv @@ -234,6 +237,7 @@ def forward( attn_bias_type, attn_mask_type, is_training, + deterministic, fused_attention_backend, ): """Forward function for FusedAttention with packed KV input""" @@ -266,6 +270,7 @@ def forward( ctx.qkv_layout = qkv_layout ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type + ctx.deterministic = deterministic ctx.fused_attention_backend = fused_attention_backend return out @@ -293,6 +298,7 @@ def backward(ctx, d_out): ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, + ctx.deterministic, ) # if no_bias, return dq, dkv @@ -324,6 +330,7 @@ def forward( attn_bias_type, attn_mask_type, is_training, + deterministic, fused_attention_backend, ): """Forward function for FusedAttention with separate Q, K, V tensors""" @@ -357,6 +364,7 @@ def forward( ctx.qkv_layout = qkv_layout ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type + ctx.deterministic = deterministic ctx.fused_attention_backend = fused_attention_backend return out @@ -385,6 +393,7 @@ def backward(ctx, d_out): ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, + ctx.deterministic, ) # if no_bias, return dq, dk, dv if ctx.attn_bias_type == "no_bias": @@ -404,6 +413,12 @@ class DotProductAttention(paddle.nn.Layer): Argument :attr:`attention_mask` will be ignored in the `forward` call when :attr:`attn_mask_type` is set to `"causal"`. + .. warning:: + + Fused attention backward uses a non-deterministic algorithm when workspace + optimization is not enabled. To use a deterministic algorithm, set the + environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0` + Parameters ---------- num_attention_heads: int @@ -458,6 +473,29 @@ def __init__( self.use_fused_attention = bool(int(os.getenv("NVTE_FUSED_ATTN", "1"))) + self.deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) + + # To use the workspace optimization path for determinism, please + # set NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT=1 for cuDNN >=8.9.5 and <9.0.0, + # and set NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 for cuDNN >=9.0.0. + cudnn_version = paddle.get_cudnn_version() + if 8905 <= cudnn_version < 9000: + if self.deterministic: + # workspace optimization path is deterministic + os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" + + # CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT + # - unset: enables workspace optimization when required workspace is <= 256MB + # or when bias gradient needs to be computed + # - n: enables workspace optimization when required workspace is <= n bytes + # - -1: enables workspace optimization always + # - 0: disables workspace optimization always + if "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT" in os.environ: + if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "0": + os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0" + if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1": + os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1" + if not self.use_fused_attention and backend == "transformer_engine": warnings.warn("Fused attention is not enabled, falling back to Paddle backend") self.backend = "paddle" @@ -603,6 +641,7 @@ def _te_forward( core_attention_bias_type, self.attn_mask_type, self.training, + self.deterministic, self.fused_attention_backend, ) elif self.attention_type == "cross": @@ -637,6 +676,7 @@ def _te_forward( core_attention_bias_type, self.attn_mask_type, self.training, + self.deterministic, self.fused_attention_backend, ) else: From 54c1cfad3eb7b6a54075c6a161bb6213e7265f3b Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 30 Jul 2024 09:44:33 -0700 Subject: [PATCH 25/72] [pytorch] removed unused import causing CI failures in fused attention (#1058) Rm unused import causing CI failures Signed-off-by: Kirthi Shankar Sivamani --- tests/pytorch/fused_attn/test_fused_attn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 760624d8c9..73dfa23d9a 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -11,7 +11,6 @@ import pytest import torch -from pkg_resources import packaging from transformer_engine.common import recipe from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init From c8c05f38b773d7509c43dbdbb52cddf58aac6962 Mon Sep 17 00:00:00 2001 From: Selvaraj Anandaraj Date: Tue, 30 Jul 2024 16:02:07 -0700 Subject: [PATCH 26/72] Load balanced offloading algorithm (#1057) * Load balanced offloading algorithm Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Selvaraj Anandaraj Co-authored-by: Selvaraj Anandaraj Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/cpu_offload.py | 117 +++++++++++----------- 1 file changed, 61 insertions(+), 56 deletions(-) diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index b42d40d9f3..4e9c74d396 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -274,7 +274,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): def __init__( self, num_offload_group, # must be <= actual number of groups (number of commits) - num_prefetch_group=1, + num_model_group, tensor_need_offloading_checker=(lambda t: True), debug=False, ) -> None: @@ -283,19 +283,29 @@ def __init__( tensor_need_offloading_checker=tensor_need_offloading_checker, debug=debug, ) - self.num_prefetch_group = num_prefetch_group + # Number of layers in the model + self.num_layers = num_model_group # Data Structure to maintain reference to activation tensors self.tensor_tag_to_buf = {} + # Tracking the number of layers offloaded + self.offloaded_group_count = 0 + # Core data structure that decides the window for offloading + self.layer_window_map = {} + + # Logic to make offloading load balance across computation + # for optimal CPU/GPU interconnect usage + constant = 0 + for i in range(self.num_offload_group): + self.layer_window_map[i] = ((self.num_layers // self.num_offload_group) * (i + 1)) - 1 + if i < (self.num_layers % self.num_offload_group): + self.layer_window_map[i] += i + 1 + constant = i + 1 + else: + self.layer_window_map[i] += constant # allocate streams and events for synchronization self.d2h_stream = torch.cuda.Stream() self.h2d_stream = torch.cuda.Stream() - self.h2d_finish_events = [] - self.compute_stream_bwd_start_events = [] - for _ in range(self.num_offload_group): - self.h2d_finish_events.append(torch.cuda.Event()) - self.compute_stream_bwd_start_events.append(torch.cuda.Event()) - self.d2h_final_event = torch.cuda.Event() def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: @@ -352,41 +362,44 @@ def bulk_offload_group(self, group_to_offload): def synchronize_on_group_commit_forward(self, current_group): """Synchronize on group commit forward.""" - # the host should wait for the copying of previous group - # to avoid overwriting buffer - previous_group = current_group - 1 - if previous_group < self.num_offload_group: - torch.cuda.synchronize() - - # Have to release the memory held by activations of the previous layer - if previous_group >= 0: - for tensor_tag, _ in self.tensor_tag_to_buf.items(): - if tensor_tag[0] == previous_group: - self.tensor_tag_to_buf[tensor_tag] = None - - # the copying of this group should wait for the computation stream event - if current_group < self.num_offload_group: - # perform bulk offloading + + # For the first group, kickstart the offload after we have + # the first compute completion + if current_group == 0: + self.d2h_stream.wait_stream(torch.cuda.current_stream()) self.bulk_offload_group(current_group) - if current_group == self.num_offload_group - 1: - self.d2h_stream.record_event(self.d2h_final_event) + + # Window map data structure helps us synchronize based on number + # of layers offloaded + if self.layer_window_map[self.offloaded_group_count] == current_group: + + # Stream synchronization both ways + self.d2h_stream.wait_stream(torch.cuda.current_stream()) + torch.cuda.current_stream().wait_stream(self.d2h_stream) + + # Time to free the activation memory after usage + for tensor_tag, _ in self.tensor_tag_to_buf.items(): + if tensor_tag[0] == self.offloaded_group_count: + self.tensor_tag_to_buf[tensor_tag] = None + + # Time to offload the next group + if self.offloaded_group_count < (self.num_offload_group - 1): + self.bulk_offload_group(self.offloaded_group_count + 1) + + # Increment the offload group count to keep track + self.offloaded_group_count += 1 def on_group_commit_forward(self): """This function will cause host device synchronization""" # handle synchronization events self.synchronize_on_group_commit_forward(self.current_group) - # during forward, the next_group_to_fetch always points to the min of - # the last commited group, and the last offloaded group - self.next_group_to_fetch = min(self.current_group, self.num_offload_group - 1) - super().on_group_commit_forward() def bulk_reload_group(self, group_to_reload): """Bulk reload group.""" assert group_to_reload < self.num_offload_group - if group_to_reload == self.num_offload_group - 1: - self.h2d_stream.wait_event(self.d2h_final_event) + with torch.cuda.stream(self.h2d_stream): # move back tensors for tensor_label, state in self.tensor_tag_to_state.items(): @@ -403,39 +416,29 @@ def on_group_commit_backward(self): self.current_group -= 1 assert self.current_group >= 0 - # decide the range of group to prefetch - should_prefetch_until_group = self.current_group - self.num_prefetch_group - should_prefetch_until_group = max(should_prefetch_until_group, 0) - - # do prefetch - for group_num_to_prefetch in range( - self.next_group_to_fetch, should_prefetch_until_group - 1, -1 - ): - # record the event in the compute stream, for h2d to wait - torch.cuda.current_stream().record_event( - self.compute_stream_bwd_start_events[group_num_to_prefetch] - ) - - # start of h2d should wait for the compute and the d2h - self.h2d_stream.wait_event(self.compute_stream_bwd_start_events[group_num_to_prefetch]) + # Layer window data structure helps us to reload at right times + if self.layer_window_map[self.offloaded_group_count - 1] == self.current_group: - # recover tensors (copy back from host) - self.bulk_reload_group(group_num_to_prefetch) + # Stream synchronization both ways + self.h2d_stream.wait_stream(torch.cuda.current_stream()) + torch.cuda.current_stream().wait_stream(self.h2d_stream) - # record an event for the backward of this layer to wait - self.h2d_stream.record_event(self.h2d_finish_events[group_num_to_prefetch]) + # Time to reload the next group + self.bulk_reload_group(self.offloaded_group_count - 1) - # always is set to -1 at the end of the backward - self.next_group_to_fetch = min(self.num_offload_group - 1, should_prefetch_until_group - 1) + # Decrease the offloading group counter + self.offloaded_group_count -= 1 if self.offloaded_group_count > 1 else 0 - # wait for the current group - if self.current_group < self.num_offload_group: - torch.cuda.current_stream().wait_event(self.h2d_finish_events[self.current_group]) + # Last group computation needs to wait till all the reloads complete + if self.current_group == 0: + torch.cuda.current_stream().wait_stream(self.h2d_stream) + self.offloaded_group_count = 0 def get_cpu_offload_context( enabled: bool = False, num_layers: int = 1, + model_layers: int = 1, offload_activations: bool = True, offload_weights: bool = True, ): @@ -460,6 +463,8 @@ def get_cpu_offload_context( num_layers: int, default = 1 Determines the number of transformer layers you want to offload activations/weights for. + model_layers: int, default = 1 + Number of layers in the model that will be used under this context. offload_activations: bool, default = `True` When set to `True`, offloads the activations for the TE layer. offload_weights: bool, default = `True` @@ -491,7 +496,7 @@ def tensor_need_offloading_checker_all(tensor): cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler( num_offload_group=num_layers, - num_prefetch_group=1, + num_model_group=model_layers, tensor_need_offloading_checker=tensor_need_offloading_checker, ) From e113bf84bc1b6127400234d3b2a9eab692199283 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Wed, 31 Jul 2024 09:59:04 -0700 Subject: [PATCH 27/72] [pyTorch] Fix wrong results for noncontiguous input (#1017) * Ensure that the inputs to custom calls are contiguous Signed-off-by: Przemek Tredak * Fixes Signed-off-by: Przemek Tredak * Added test Signed-off-by: Przemek Tredak * Fixes Signed-off-by: Przemek Tredak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixes from review Signed-off-by: Przemek Tredak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix typo Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Przemek Tredak Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani --- tests/pytorch/test_numerics.py | 49 ++++++++++++ .../pytorch/cpp_extensions/gemm.py | 15 +++- .../pytorch/csrc/extensions/gemm.cu | 4 + .../pytorch/csrc/extensions/normalization.cu | 74 ++++++++++++------- .../pytorch/module/layernorm_linear.py | 6 +- .../pytorch/module/layernorm_mlp.py | 4 +- 6 files changed, 122 insertions(+), 30 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 56c6de0333..6c967d78e9 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1816,3 +1816,52 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate): # should be bit-wise match for o, o_ref in zip(out, out_ref): torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + + +def test_noncontiguous(): + def _create2modules(m, params): + mod1 = m(*params) + mod2 = m(*params) + for p1, p2 in zip(mod1.parameters(), mod2.parameters()): + p2.data = p1.data.clone() + + return mod1, mod2 + + def _run_module(m, inp): + out = m(inp) + out.sum().backward() + ret = [out] + if inp.grad is not None: + ret.append(inp.grad) + + for p in m.parameters(): + if p.requires_grad: + ret.append(p.grad) + return ret + + a = torch.randn((128, 256), device="cuda", requires_grad=True) + a = a.T + assert not a.is_contiguous(), "The test is supposed to test noncontiguous input." + + b = a.contiguous() + + # LayerNorm + ln1, ln2 = _create2modules(LayerNorm, [128]) + outT = _run_module(ln1, a) + out = _run_module(ln2, b) + + assert_allclose(out, outT, 1e-7) + + # RMSNorm + ln1, ln2 = _create2modules(RMSNorm, [128]) + outT = _run_module(ln1, a) + out = _run_module(ln2, b) + + assert_allclose(out, outT, 1e-7) + + # GEMM + g1, g2 = _create2modules(Linear, [128, 128]) + outT = _run_module(g1, a) + out = _run_module(g2, b) + + assert_allclose(out, outT, 1e-7) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index d810cf8478..38392a5795 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -54,6 +54,9 @@ def fp8_gemm( dtype=out_dtype, device="cuda", ) + else: + if not out.is_contiguous(): + raise ValueError("Output tensor is not contiguous.") # Use bfloat16 as default bias_dtype bias_dtype = torch.bfloat16 if bias is None else bias.dtype @@ -202,6 +205,9 @@ def gemm( dtype=dtype, device="cuda", ) + else: + if not out.is_contiguous(): + raise ValueError("Output tensor is not contiguous.") if gelu and not grad: gelu_input = torch.empty_like(out, dtype=dtype) @@ -311,7 +317,9 @@ def grouped_gemm( empty_tensors = [torch.Tensor()] * num_gemms if gelu and not grad: - gelu_input = [torch.empty_like(o, dtype=dtype) for o in out] + gelu_input = [ + torch.empty_like(o, dtype=dtype, memory_format=torch.contiguous_format) for o in out + ] elif not gelu: gelu_input = empty_tensors @@ -406,7 +414,10 @@ def fp8_grouped_gemm( # Use bfloat16 as default bias_dtype bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype if gelu: - gelu_input = [torch.empty_like(o, dtype=bias_dtype) for o in out] + gelu_input = [ + torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format) + for o in out + ] else: gelu_input = empty_tensors bias_dtype = TE_DType[bias_dtype] diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cu b/transformer_engine/pytorch/csrc/extensions/gemm.cu index 720fc146d1..bd698ded27 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cu +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cu @@ -21,6 +21,10 @@ void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType if (pre_gelu_out.data_ptr() != nullptr) pre_gelu_out.zero_(); return; } + + A = A.contiguous(); + B = B.contiguous(); + auto te_A = makeTransformerEngineTensor( A.data_ptr(), {static_cast(A.size(0)), static_cast(A.size(1))}, A_type, nullptr, nullptr, A_scale_inverse.data_ptr()); diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cu b/transformer_engine/pytorch/csrc/extensions/normalization.cu index 77bbcbc9d6..04274ae2ef 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cu +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cu @@ -10,16 +10,22 @@ std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, const at::Tensor &mu, const at::Tensor &rsigma, const at::Tensor &gamma, const int sm_margin, const bool zero_centered_gamma) { - auto dx = at::empty_like(x); - auto dgamma = at::empty_like(gamma); - auto dbeta = at::empty_like(gamma); + const auto &dz_ = dz.contiguous(); + const auto &x_ = x.contiguous(); + const auto &mu_ = mu.contiguous(); + const auto &rsigma_ = rsigma.contiguous(); + const auto &gamma_ = gamma.contiguous(); + + auto dx = at::empty_like(x_); + auto dgamma = at::empty_like(gamma_); + auto dbeta = at::empty_like(gamma_); transformer_engine::TensorWrapper workspace, barrier, dgamma_part, dbeta_part; - auto dz_cu = makeTransformerEngineTensor(dz); - auto x_cu = makeTransformerEngineTensor(x); - auto mu_cu = makeTransformerEngineTensor(mu); - auto rsigma_cu = makeTransformerEngineTensor(rsigma); - auto gamma_cu = makeTransformerEngineTensor(gamma); + auto dz_cu = makeTransformerEngineTensor(dz_); + auto x_cu = makeTransformerEngineTensor(x_); + auto mu_cu = makeTransformerEngineTensor(mu_); + auto rsigma_cu = makeTransformerEngineTensor(rsigma_); + auto gamma_cu = makeTransformerEngineTensor(gamma_); auto dx_cu = makeTransformerEngineTensor(dx); auto dgamma_cu = makeTransformerEngineTensor(dgamma); auto dbeta_cu = makeTransformerEngineTensor(dbeta); @@ -63,8 +69,10 @@ std::vector layernorm_fwd_fp8(const at::Tensor &input, const at::Ten const int amax_offset, const int scale_inv_offset) { using namespace transformer_engine; - auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(otype))); - return layernorm_fwd_fp8_noalloc(input, weight, bias, eps, scale, ln_out, amax, scale_inv, otype, + const auto &input_ = input.contiguous(); + + auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(otype))); + return layernorm_fwd_fp8_noalloc(input_, weight, bias, eps, scale, ln_out, amax, scale_inv, otype, sm_margin, zero_centered_gamma, scale_offset, amax_offset, scale_inv_offset); } @@ -76,6 +84,10 @@ std::vector layernorm_fwd_fp8_noalloc( const int scale_offset, const int amax_offset, const int scale_inv_offset) { using namespace transformer_engine; + const auto &input_ = input.contiguous(); + const auto &weight_ = weight.contiguous(); + const auto &bias_ = bias.contiguous(); + // Choose kernel implementation const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; @@ -92,9 +104,9 @@ std::vector layernorm_fwd_fp8_noalloc( DType itype = GetTransformerEngineDType(input.scalar_type()); auto mu = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); auto rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - auto input_cu = makeTransformerEngineTensor(input); - auto gamma_cu = makeTransformerEngineTensor(weight); - auto beta_cu = makeTransformerEngineTensor(bias); + auto input_cu = makeTransformerEngineTensor(input_); + auto gamma_cu = makeTransformerEngineTensor(weight_); + auto beta_cu = makeTransformerEngineTensor(bias_); auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype, amax_dptr, scale_dptr, scale_inv_dptr); auto mu_cu = makeTransformerEngineTensor(mu); @@ -145,9 +157,10 @@ std::vector layernorm_fwd(const at::Tensor &input, const at::Tensor using namespace transformer_engine; DType itype = GetTransformerEngineDType(input.scalar_type()); - auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(itype))); + const auto &input_ = input.contiguous(); + auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(itype))); - return layernorm_fwd_noalloc(input, weight, bias, ln_out, eps, sm_margin, zero_centered_gamma); + return layernorm_fwd_noalloc(input_, weight, bias, ln_out, eps, sm_margin, zero_centered_gamma); } std::vector layernorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight, @@ -174,14 +187,19 @@ at::Tensor layernorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, const at::Tensor &rsigma, const at::Tensor &gamma, const int sm_margin, const bool zero_centered_gamma) { - auto dx = at::empty_like(x); - auto dgamma = at::empty_like(gamma); + const auto &dz_ = dz.contiguous(); + const auto &x_ = x.contiguous(); + const auto &rsigma_ = rsigma.contiguous(); + const auto &gamma_ = gamma.contiguous(); + + auto dx = at::empty_like(x_); + auto dgamma = at::empty_like(gamma_); transformer_engine::TensorWrapper workspace, barrier, dgamma_part; - auto dz_cu = makeTransformerEngineTensor(dz); - auto x_cu = makeTransformerEngineTensor(x); - auto rsigma_cu = makeTransformerEngineTensor(rsigma); - auto gamma_cu = makeTransformerEngineTensor(gamma); + auto dz_cu = makeTransformerEngineTensor(dz_); + auto x_cu = makeTransformerEngineTensor(x_); + auto rsigma_cu = makeTransformerEngineTensor(rsigma_); + auto gamma_cu = makeTransformerEngineTensor(gamma_); auto dx_cu = makeTransformerEngineTensor(dx); auto dgamma_cu = makeTransformerEngineTensor(dgamma); @@ -219,8 +237,11 @@ std::vector rmsnorm_fwd_fp8(const at::Tensor &input, const at::Tenso const int scale_inv_offset) { using namespace transformer_engine; - auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(otype))); - return rmsnorm_fwd_fp8_noalloc(input, weight, eps, scale, ln_out, amax, scale_inv, otype, + const auto &input_ = input.contiguous(); + const auto &weight_ = weight.contiguous(); + + auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(otype))); + return rmsnorm_fwd_fp8_noalloc(input_, weight_, eps, scale, ln_out, amax, scale_inv, otype, sm_margin, zero_centered_gamma, scale_offset, amax_offset, scale_inv_offset); } @@ -295,10 +316,13 @@ std::vector rmsnorm_fwd(const at::Tensor &input, const at::Tensor &w const int sm_margin, const bool zero_centered_gamma) { using namespace transformer_engine; + const auto &input_ = input.contiguous(); + const auto &weight_ = weight.contiguous(); + DType itype = GetTransformerEngineDType(input.scalar_type()); - auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(itype))); + auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(itype))); - return rmsnorm_fwd_noalloc(input, weight, ln_out, eps, sm_margin, zero_centered_gamma); + return rmsnorm_fwd_noalloc(input_, weight_, ln_out, eps, sm_margin, zero_centered_gamma); } std::vector rmsnorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 22d7813605..262c6f8d16 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -130,12 +130,14 @@ def forward( if return_layernorm_output: # First prepare LN output in higher precision, # which will be later copied to a FP8 UB - ln_out = torch.empty_like(inputmat) + ln_out = torch.empty_like(inputmat, memory_format=torch.contiguous_format) else: ln_out = ub_obj_lnout.get_ubuf_output(0) else: ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype - ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype) + ln_out = torch.empty_like( + inputmat, dtype=ln_out_dtype, memory_format=torch.contiguous_format + ) fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 5be8ee9e29..be6df21322 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -149,7 +149,9 @@ def forward( ln_out = ub_obj_lnout.get_ubuf_output(0) else: ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype - ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype) + ln_out = torch.empty_like( + inputmat, dtype=ln_out_dtype, memory_format=torch.contiguous_format + ) ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) From 91a16a3f9c8c1aedd4cd0b1da7a0da58e977412c Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Wed, 31 Jul 2024 16:44:00 -0700 Subject: [PATCH 28/72] Add more C++ tests for activations (#1049) * Added tests for silu/relu/swiglu/reglu Signed-off-by: Przemek Tredak * Fixes Signed-off-by: Przemek Tredak * Added other activations/backwards and fixed dqgelu Signed-off-by: Przemek Tredak * Fix Signed-off-by: Przemek Tredak * Fix 2 Signed-off-by: Przemek Tredak * Actually adding srelu and qgelu tests Signed-off-by: Przemek Tredak * Fix glu backward test Signed-off-by: Przemek Tredak * Pruning unnecessary test configurations Signed-off-by: Przemek Tredak --------- Signed-off-by: Przemek Tredak --- tests/cpp/operator/CMakeLists.txt | 3 +- tests/cpp/operator/test_act.cu | 456 ++++++++++++++++++++++++++ tests/cpp/operator/test_geglu.cu | 115 ------- tests/cpp/operator/test_gelu.cu | 123 ------- transformer_engine/common/util/math.h | 3 +- 5 files changed, 459 insertions(+), 241 deletions(-) create mode 100644 tests/cpp/operator/test_act.cu delete mode 100644 tests/cpp/operator/test_geglu.cu delete mode 100644 tests/cpp/operator/test_gelu.cu diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 9dd02d4181..e302be57bd 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -9,8 +9,7 @@ add_executable(test_operator test_cast_transpose_dbias.cu test_cast_transpose_dbias_dgelu.cu test_cast_transpose_dgeglu.cu - test_gelu.cu - test_geglu.cu + test_act.cu test_dgeglu.cu test_layernorm.cu test_rmsnorm.cu diff --git a/tests/cpp/operator/test_act.cu b/tests/cpp/operator/test_act.cu new file mode 100644 index 0000000000..7d03e41271 --- /dev/null +++ b/tests/cpp/operator/test_act.cu @@ -0,0 +1,456 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include "../test_common.h" + +using namespace transformer_engine; + +namespace { + +// forward + +float gelu(const float x) { + return 0.5f * x * (1.0f + tanhf(0.79788456F * x * (1.0f + 0.044715f * x * x))); +} + +float silu(const float x) { + return x / (1 + expf(-x)); +} + +float relu(const float x) { + return x > 0 ? x : 0; +} + +float srelu(const float x) { + return x > 0 ? x * x : 0; +} + +float qgelu(const float x) { + return x / (1 + expf(-1.702f * x)); +} + +// backward + +float dgelu(const float x) { + const float tanh_out = tanhf(0.79788456f * x * (1.f + 0.044715f * x * x)); + return 0.5f * x * ((1.f - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * x * x)) + + 0.5f * (1.f + tanh_out); +} + +float dsilu(const float x) { + const float sigmoid = 1.f / (1 + expf(-x)); + return x * sigmoid * (1.f - sigmoid) + sigmoid; +} + +float drelu(const float x) { + return x > 0.f ? 1.f : 0.f; +} + +float dsrelu(const float x) { + return fmaxf(2.f * x, 0.f); +} + +float dqgelu(const float x) { + const float sigmoid = 1.f / (1 + expf(-1.702f * x)); + return 1.702f * x * sigmoid * (1.f - sigmoid) + sigmoid; +} + +} // namespace + +template +void compute_ref_act_cast(const IT *input_h, + OT *output_h, + const CT scale, + CT *amax_h, + const size_t N, + const size_t H) { + CT amax = 0.; + + for (size_t i = 0; i < N; i++) { + for (size_t j = 0; j < H; j++) { + CT elt = static_cast(input_h[i * H + j]); + elt = act(elt); + output_h[i * H + j] = static_cast(scale * elt); + amax = std::abs(elt) > amax ? std::abs(elt) : amax; + } + } + + *amax_h = amax; +} + +template +void compute_ref_dact_cast(const IT *input_h, + const IT *grad_h, + OT *output_h, + const size_t N, + const size_t H) { + using CT = float; + for (size_t i = 0; i < N; i++) { + for (size_t j = 0; j < H; j++) { + CT elt = static_cast(input_h[i * H + j]); + elt = dact(elt); + CT grad = static_cast(grad_h[i * H + j]); + output_h[i * H + j] = static_cast(grad * elt); + } + } +} + +template +void compute_ref_glu_act_cast(const IT *input_h, OT *output_h, const CT scale, CT *amax_h, + const size_t N, const size_t H) { + CT amax = 0.; + + const int col = H * 2; + + for (size_t i = 0; i < N; i++) { + for (size_t j = 0; j < H; j++) { + CT gelu_elt = static_cast(input_h[i * col + j]); + gelu_elt = act(gelu_elt); + CT gate_elt = static_cast(input_h[i * col + H + j]); + CT elt = gelu_elt * gate_elt; + output_h[i * H + j] = static_cast(scale * elt); + amax = std::abs(elt) > amax ? std::abs(elt) : amax; + } + } + + *amax_h = amax; +} + +template +void compute_ref_dglu_act_cast(const IT *input_h, const IT *grad_h, OT *output_h, + const size_t N, const size_t H) { + const int col = H * 2; + using CT = float; + + for (size_t i = 0; i < N; i++) { + for (size_t j = 0; j < H; j++) { + CT grad = static_cast(grad_h[i * H + j]); + CT gelu_elt = static_cast(input_h[i * col + j]); + CT gate_elt = static_cast(input_h[i * col + H + j]); + output_h[i * col + H + j] = static_cast(grad * act(gelu_elt)); + gelu_elt = dact(gelu_elt); + CT elt = gelu_elt * gate_elt; + output_h[i * col + j] = static_cast(grad * elt); + } + } +} + + +template +void performTest(const size_t N, const size_t H) { + using namespace test; + + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + Tensor input({ N, H }, itype); + Tensor output({ N, H }, otype); + Tensor igrad({ N, H }, itype); + Tensor ograd({ N, H }, itype); + + fillUniform(&input); + fillUniform(&ograd); + setRandomScale(&output); + + std::unique_ptr ref_output = std::make_unique(N*H); + std::unique_ptr ref_igrad = std::make_unique(N*H); + + nvte_act(input.data(), output.data(), 0); + + float ref_amax; + compute_ref_act_cast(input.cpu_dptr(), ref_output.get(), + output.scale(), &ref_amax, N, H); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) { + auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); + compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax); + } + auto [atol, rtol] = getTolerances(otype); + compareResults("output_act", output, ref_output.get(), atol, rtol); + + nvte_dact(ograd.data(), input.data(), igrad.data(), 0); + + compute_ref_dact_cast(input.cpu_dptr(), ograd.cpu_dptr(), + ref_igrad.get(), N, H); + + cudaDeviceSynchronize(); + err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + { + auto [atol, rtol] = getTolerances(otype); + compareResults("igrad_act", igrad, ref_igrad.get(), atol, rtol); + } +} + +template +void performTestGLU(const size_t N, const size_t H) { + using namespace test; + + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + Tensor input({N, H * 2}, itype); + Tensor output({N, H}, otype); + Tensor igrad({ N, H * 2 }, itype); + Tensor ograd({ N, H }, itype); + + fillUniform(&input); + fillUniform(&ograd); + setRandomScale(&output); + + std::unique_ptr ref_output = std::make_unique(N * H); + std::unique_ptr ref_igrad = std::make_unique(2 * N * H); + + nvte_act(input.data(), output.data(), 0); + + float ref_amax; + compute_ref_glu_act_cast(input.cpu_dptr(), ref_output.get(), + output.scale(), &ref_amax, N, H); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) { + auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); + compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax); + } + auto [atol, rtol] = getTolerances(otype); + compareResults("output_gelu", output, ref_output.get(), atol, rtol); + + nvte_dact(ograd.data(), input.data(), igrad.data(), 0); + + compute_ref_dglu_act_cast(input.cpu_dptr(), ograd.cpu_dptr(), + ref_igrad.get(), N, H); + + cudaDeviceSynchronize(); + err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + { + auto [atol, rtol] = getTolerances(otype); + compareResults("igrad_act", igrad, ref_igrad.get(), atol, rtol); + } +} + + +class ActTestSuite : public ::testing::TestWithParam>> {}; + +TEST_P(ActTestSuite, TestGELU) { + using namespace transformer_engine; + using namespace test; + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + performTest(size.first, size.second); + ); + ); +} + +TEST_P(ActTestSuite, TestSILU) { + using namespace transformer_engine; + using namespace test; + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + performTest(size.first, size.second); + ); + ); +} + +TEST_P(ActTestSuite, TestRELU) { + using namespace transformer_engine; + using namespace test; + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + performTest(size.first, size.second); + ); + ); +} + +TEST_P(ActTestSuite, TestQGELU) { + using namespace transformer_engine; + using namespace test; + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + performTest(size.first, size.second); + ); + ); +} + +TEST_P(ActTestSuite, TestSRELU) { + using namespace transformer_engine; + using namespace test; + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + performTest(size.first, size.second); + ); + ); +} + +TEST_P(ActTestSuite, TestGeGLU) { + using namespace transformer_engine; + using namespace test; + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + output_type, OutputType, + performTestGLU(size.first, size.second););); +} + +TEST_P(ActTestSuite, TestReGLU) { + using namespace transformer_engine; + using namespace test; + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + output_type, OutputType, + performTestGLU(size.first, size.second););); +} + +TEST_P(ActTestSuite, TestSwiGLU) { + using namespace transformer_engine; + using namespace test; + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + output_type, OutputType, + performTestGLU(size.first, size.second););); +} + +TEST_P(ActTestSuite, TestQGeGLU) { + using namespace transformer_engine; + using namespace test; + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + output_type, OutputType, + performTestGLU(size.first, size.second););); +} + +TEST_P(ActTestSuite, TestSReGLU) { + using namespace transformer_engine; + using namespace test; + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + output_type, OutputType, + performTestGLU(size.first, size.second););); +} + +namespace { + +std::vector> act_test_cases = {{2048, 12288}, + {768, 2816}, + {256, 65536}, + {65536, 128}, + {256, 256}, + {257, 259}, + {128, 128+1}}; + +} // namespace + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + ActTestSuite, + ::testing::Combine( + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::ValuesIn(test::all_fp_types), + ::testing::ValuesIn(act_test_cases)), + [](const testing::TestParamInfo& info) { + std::string name = test::typeName(std::get<0>(info.param)) + "X" + + test::typeName(std::get<1>(info.param)) + "X" + + std::to_string(std::get<2>(info.param).first) + "X" + + std::to_string(std::get<2>(info.param).second); + return name; + }); diff --git a/tests/cpp/operator/test_geglu.cu b/tests/cpp/operator/test_geglu.cu deleted file mode 100644 index f25c2e1d23..0000000000 --- a/tests/cpp/operator/test_geglu.cu +++ /dev/null @@ -1,115 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#include -#include "../test_common.h" - -using namespace transformer_engine; - -template -void compute_ref_geglu_cast(const IT *input_h, OT *output_h, const CT scale, CT *amax_h, - const size_t N, const size_t H) { - CT amax = 0.; - - const int col = H * 2; - - for (size_t i = 0; i < N; i++) { - for (size_t j = 0; j < H; j++) { - CT gelu_elt = CT(input_h[i * col + j]); - gelu_elt = 0.5f * gelu_elt * - (1.0f + tanhf(0.79788456F * gelu_elt * (1.0f + 0.044715f * gelu_elt * gelu_elt))); - CT gate_elt = CT(input_h[i * col + H + j]); - CT elt = gelu_elt * gate_elt; - output_h[i * H + j] = OT(scale * elt); - amax = std::abs(elt) > amax ? std::abs(elt) : amax; - } - } - - *amax_h = amax; -} - -template -void performTestGEGLU(const size_t N, const size_t H) { - using namespace test; - - DType itype = TypeInfo::dtype; - DType otype = TypeInfo::dtype; - - Tensor input({N, H * 2}, itype); - Tensor output({N, H}, otype); - - fillUniform(&input); - setRandomScale(&output); - - std::unique_ptr ref_output = std::make_unique(N * H); - - nvte_geglu(input.data(), output.data(), 0); - - float ref_amax; - compute_ref_geglu_cast(input.cpu_dptr(), ref_output.get(), output.scale(), &ref_amax, N, - H); - - cudaDeviceSynchronize(); - auto err = cudaGetLastError(); - ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); - - if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) { - auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax); - } - auto [atol, rtol] = getTolerances(otype); - compareResults("output_gelu", output, ref_output.get(), atol, rtol); -} - -class GeGLUTestSuite - : public ::testing::TestWithParam>> {}; - -TEST_P(GeGLUTestSuite, TestGeGLU) { - using namespace transformer_engine; - using namespace test; - - const DType input_type = std::get<0>(GetParam()); - const DType output_type = std::get<1>(GetParam()); - const auto size = std::get<2>(GetParam()); - - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( - input_type, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( - output_type, OutputType, - performTestGEGLU(size.first, size.second););); -} - -namespace { - -std::vector> test_cases = { - {4096, 2048}, {768, 2816}, {256, 5120}, {128, 10240}, {256, 256}, {257, 259}, {128, 128 + 1}}; - -} // namespace - -INSTANTIATE_TEST_SUITE_P( - OperatorTest, GeGLUTestSuite, - ::testing::Combine(::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), - ::testing::ValuesIn(test::all_fp_types), ::testing::ValuesIn(test_cases)), - [](const testing::TestParamInfo &info) { - std::string name = test::typeName(std::get<0>(info.param)) + "X" + - test::typeName(std::get<1>(info.param)) + "X" + - std::to_string(std::get<2>(info.param).first) + "X" + - std::to_string(std::get<2>(info.param).second); - return name; - }); diff --git a/tests/cpp/operator/test_gelu.cu b/tests/cpp/operator/test_gelu.cu deleted file mode 100644 index d759aa4315..0000000000 --- a/tests/cpp/operator/test_gelu.cu +++ /dev/null @@ -1,123 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#include -#include "../test_common.h" - -using namespace transformer_engine; - -template -void compute_ref_gelu_cast(const IT *input_h, - OT *output_h, - const CT scale, - CT *amax_h, - const size_t N, - const size_t H) { - CT amax = 0.; - - for (size_t i = 0; i < N; i++) { - for (size_t j = 0; j < H; j++) { - CT elt = CT(input_h[i * H + j]); - elt = 0.5f * elt * (1.0f + tanhf(0.79788456F * elt * - (1.0f + 0.044715f * elt * elt))); - output_h[i * H + j] = OT(scale * elt); - amax = std::abs(elt) > amax ? std::abs(elt) : amax; - } - } - - *amax_h = amax; -} - -template -void performTestGelu(const size_t N, const size_t H) { - using namespace test; - - DType itype = TypeInfo::dtype; - DType otype = TypeInfo::dtype; - - Tensor input({ N, H }, itype); - Tensor output({ N, H }, otype); - - fillUniform(&input); - setRandomScale(&output); - - std::unique_ptr ref_output = std::make_unique(N*H); - - nvte_gelu(input.data(), output.data(), 0); - - float ref_amax; - compute_ref_gelu_cast(input.cpu_dptr(), ref_output.get(), - output.scale(), &ref_amax, N, H); - - cudaDeviceSynchronize(); - auto err = cudaGetLastError(); - ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); - - if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) { - auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax); - } - auto [atol, rtol] = getTolerances(otype); - compareResults("output_gelu", output, ref_output.get(), atol, rtol); -} - -class GELUTestSuite : public ::testing::TestWithParam>> {}; - -TEST_P(GELUTestSuite, TestGELU) { - using namespace transformer_engine; - using namespace test; - - const DType input_type = std::get<0>(GetParam()); - const DType output_type = std::get<1>(GetParam()); - const auto size = std::get<2>(GetParam()); - - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, - performTestGelu(size.first, size.second); - ); - ); -} - -namespace { - -std::vector> gelu_test_cases = {{2048, 12288}, - {768, 1024}, - {256, 65536}, - {65536, 128}, - {256, 256}, - {257, 259}, - {128, 128+1}}; - -} // namespace - -INSTANTIATE_TEST_SUITE_P( - OperatorTest, - GELUTestSuite, - ::testing::Combine( - ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), - ::testing::ValuesIn(test::all_fp_types), - ::testing::ValuesIn(gelu_test_cases)), - [](const testing::TestParamInfo& info) { - std::string name = test::typeName(std::get<0>(info.param)) + "X" + - test::typeName(std::get<1>(info.param)) + "X" + - std::to_string(std::get<2>(info.param).first) + "X" + - std::to_string(std::get<2>(info.param).second); - return name; - }); diff --git a/transformer_engine/common/util/math.h b/transformer_engine/common/util/math.h index 2625c97e79..26204cddb8 100644 --- a/transformer_engine/common/util/math.h +++ b/transformer_engine/common/util/math.h @@ -47,7 +47,8 @@ __device__ inline OType qgelu(const IType val, const Empty& e) { template __device__ inline OType dqgelu(const IType val, const Empty& e) { const float cval = val; - return cval * dsigmoid(1.702f * cval, e) + sigmoid(1.702f * cval, e); + return 1.702f * cval * dsigmoid(1.702f * cval, e) + + sigmoid(1.702f * cval, e); } template From 701173062ee2b35658274eb57dd19afe1b7831a5 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Thu, 1 Aug 2024 23:02:42 +0800 Subject: [PATCH 29/72] [Bugfix] Fixes for multi-stream cuBLAS (#1045) * fix workspaces and unfused bias in multi-stream cuBLAS * Expose num_streams via pybind * Fix C-compatibility * rm importing packaging in test_fused_attn.py --------- Signed-off-by: Xin Yao Co-authored-by: Phuong Nguyen --- tests/pytorch/test_numerics.py | 20 ++++++++++++++++++- .../common/gemm/cublaslt_gemm.cu | 12 +++++------ .../common/include/transformer_engine/gemm.h | 9 +++++---- .../pytorch/csrc/extensions/gemm.cu | 13 +++++++----- .../pytorch/csrc/extensions/pybind.cpp | 1 + transformer_engine/pytorch/module/base.py | 3 +-- .../pytorch/module/grouped_linear.py | 10 +++++++++- 7 files changed, 49 insertions(+), 19 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 6c967d78e9..7eed97a0ca 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1261,7 +1261,9 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False @pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("fp8", all_boolean) @pytest.mark.parametrize("fp8_model_params", all_boolean) -def test_grouped_linear_accuracy(dtype, num_gemms, bs, model, fp8, fp8_model_params): +def test_grouped_linear_accuracy( + dtype, num_gemms, bs, model, fp8, fp8_model_params, parallel_mode=None +): if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) @@ -1276,6 +1278,7 @@ def test_grouped_linear_accuracy(dtype, num_gemms, bs, model, fp8, fp8_model_par 4 * config.hidden_size, bias=True, params_dtype=dtype, + parallel_mode=parallel_mode, device="cuda", ).eval() sequential_linear = torch.nn.ModuleList( @@ -1285,6 +1288,7 @@ def test_grouped_linear_accuracy(dtype, num_gemms, bs, model, fp8, fp8_model_par 4 * config.hidden_size, bias=True, params_dtype=dtype, + parallel_mode=parallel_mode, device="cuda", ).eval() for _ in range(num_gemms) @@ -1307,6 +1311,20 @@ def test_grouped_linear_accuracy(dtype, num_gemms, bs, model, fp8, fp8_model_par torch.testing.assert_close(o, o_ref, rtol=0, atol=0) +@pytest.mark.parametrize("parallel_mode", ["column", "row"]) +def test_grouped_linear_accuracy_parallel_mode(parallel_mode): + """Split the tests to reduce CI time""" + test_grouped_linear_accuracy( + dtype=torch.float32, + num_gemms=6, + bs=2, + model=list(model_configs.keys())[0], + fp8=True, + fp8_model_params=True, + parallel_mode=parallel_mode, + ) + + def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph): reset_rng_states() diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 30161b68c0..c9b57752e2 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -378,10 +378,10 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream); } -void nvte_multi_stream_cublas_gemm(std::vector A, std::vector B, - std::vector D, std::vector bias, - std::vector pre_gelu_out, bool transa, bool transb, - bool grad, std::vector workspace, bool accumulate, +void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, + const NVTETensor *bias, NVTETensor *pre_gelu_out, + const int num_gemms, bool transa, bool transb, bool grad, + NVTETensor *workspace, bool accumulate, bool use_split_accumulator, int math_sm_count, cudaStream_t stream) { NVTE_API_CALL(nvte_multi_stream_cublas_gemm); @@ -389,14 +389,14 @@ void nvte_multi_stream_cublas_gemm(std::vector A, std::vector(A.size())); + int num_stream_used = std::min(num_streams, num_gemms); // wait for current stream to finish NVTE_CHECK_CUDA(cudaEventRecord(cublas_event[0], stream)); for (int s = 0; s < num_stream_used; s++) { NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams[s], cublas_event[0])); } - for (size_t i = 0; i < A.size(); i++) { + for (int i = 0; i < num_gemms; i++) { nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad, workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count, compute_streams[i % num_streams]); diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 28534dafd4..1cdbfd2eb5 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -92,6 +92,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor * \param[in,out] D List of output matrices. * \param[in] bias List of bias tensors. * \param[in,out] pre_gelu_out List of output matrix before GELU activation. + * \param[in] num_gemms Number of GEMMs to compute. * \param[in] transa Whether A matrix is transposed. * \param[in] transb Whether B matrix is transposed. * \param[in] grad Whether this operation is part of the @@ -102,10 +103,10 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor * \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics) * \param[in] stream CUDA stream to wait on. */ -void nvte_multi_stream_cublas_gemm(std::vector A, std::vector B, - std::vector D, std::vector bias, - std::vector pre_gelu_out, bool transa, bool transb, - bool grad, std::vector workspace, bool accumulate, +void nvte_multi_stream_cublas_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, + const NVTETensor* bias, NVTETensor* pre_gelu_out, + const int num_gemms, bool transa, bool transb, bool grad, + NVTETensor* workspace, bool accumulate, bool use_split_accumulator, int math_sm_count, cudaStream_t stream); #ifdef __cplusplus diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cu b/transformer_engine/pytorch/csrc/extensions/gemm.cu index bd698ded27..01fb94cab4 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cu +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cu @@ -134,12 +134,15 @@ void te_grouped_gemm(std::vector A, at::Tensor A_scale_inverse, int te_pre_gelu_out.emplace_back(make_tensor( pre_gelu_out[i].data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out[i].scalar_type()), nullptr, nullptr, nullptr)); - te_workspace.emplace_back(make_tensor(workspace[i % num_streams].data_ptr(), {workspaceSize}, - DType::kByte, nullptr, nullptr, nullptr)); + } + for (size_t i = 0; i < workspace.size(); i++) { + te_workspace.emplace_back(make_tensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte, + nullptr, nullptr, nullptr)); } // For now, we only have multi-stream cublas backend. - nvte_multi_stream_cublas_gemm(te_A, te_B, te_D, te_bias, te_pre_gelu_out, transa, transb, grad, - te_workspace, accumulate, use_split_accumulator, math_sm_count, - at::cuda::getCurrentCUDAStream()); + nvte_multi_stream_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), + te_pre_gelu_out.data(), te_A.size(), transa, transb, grad, + te_workspace.data(), accumulate, use_split_accumulator, + math_sm_count, at::cuda::getCurrentCUDAStream()); } diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index f568f4659d..89bce77ded 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -153,6 +153,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::call_guard()); m.def("get_cudnn_version", &get_cudnn_version, "Get cuDNN version", py::call_guard()); + m.attr("_num_cublas_streams") = py::int_(transformer_engine::num_streams); // Support THD format for Context Parallel m.def("thd_read_half_tensor", &thd_read_half_tensor, diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 6feda77c70..cbcda20fe8 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -48,7 +48,6 @@ _cublas_workspace = None _ub_communicators = None _NUM_MAX_UB_STREAMS = 3 -_NUM_MAX_CUBLAS_STREAMS = 4 layers_atomic_ring_exchange = [] @@ -73,7 +72,7 @@ def get_multi_stream_cublas_workspace() -> List[torch.Tensor]: """Returns workspace for multi-stream cublas.""" global _multi_stream_cublas_workspace if not _multi_stream_cublas_workspace: - for _ in range(_NUM_MAX_CUBLAS_STREAMS): + for _ in range(tex._num_cublas_streams): _multi_stream_cublas_workspace.append( torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda") ) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 050ff6a02e..8aeb068412 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -829,7 +829,15 @@ def forward( out = linear_fn(*args) if self.gemm_bias_unfused_add: - out = [o + cast_if_needed(b, self.activation_dtype) for o, b in zip(out, bias_tensors)] + out_shape = out.shape + out = torch.cat( + [ + o + cast_if_needed(b, self.activation_dtype) + for o, b in zip( + torch.split(out.view(-1, self.out_features), m_splits), bias_tensors + ) + ] + ).view(out_shape) if self.return_bias: return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors] From 9c127ef5b1ee10d466acf633f4a0ad3c8914cf1b Mon Sep 17 00:00:00 2001 From: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com> Date: Thu, 1 Aug 2024 11:24:37 -0700 Subject: [PATCH 30/72] Fix context parallelism implementation with THD format (#1012) * use 2hd layout Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change qkv_format check Signed-off-by: Xiaowei Ren * add a code comment Signed-off-by: Xiaowei Ren * tensor shape bug fix Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * tensor shape fix Signed-off-by: Xiaowei Ren * add function to compute cu_seqlens of a cp rank Signed-off-by: Xiaowei Ren * add cu_seqlens and cu_seqlens_padded to context parallelism Signed-off-by: Xiaowei Ren * typo fix Signed-off-by: Xiaowei Ren * minor change Signed-off-by: Xiaowei Ren * fix FlashAttention output sequence length Signed-off-by: Xiaowei Ren * fix cu_seqlens_kv_per_step calculation Signed-off-by: Xiaowei Ren * zero dQKV for ending padded tokens Signed-off-by: Xiaowei Ren * zero dQKV tensors of FlashAttention Signed-off-by: Xiaowei Ren * fix softmax_lse correction Signed-off-by: Xiaowei Ren * remove padded tokens of KV to save comounication Signed-off-by: Xiaowei Ren * do not need to zero dkv for FlashAttention any mroe Signed-off-by: Xiaowei Ren * zero out tensors Signed-off-by: Xiaowei Ren * remove redundant code Signed-off-by: Xiaowei Ren * fix CP unit test Signed-off-by: Xiaowei Ren * fix kv shape of cp test with thd format Signed-off-by: Xiaowei Ren * update cp unit test Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove redundant code Signed-off-by: Xiaowei Ren --------- Signed-off-by: Xiaowei Ren Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: Xiaowei Ren --- .../fused_attn/run_fused_attn_with_cp.py | 101 +++- transformer_engine/pytorch/attention.py | 480 ++++++++++++------ .../pytorch/csrc/extensions/attention.cu | 2 +- 3 files changed, 413 insertions(+), 170 deletions(-) diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py index 9b9b7686c2..c8f3c8c458 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py @@ -6,6 +6,7 @@ import torch import torch.distributed as dist from transformer_engine.pytorch.attention import DotProductAttention +from transformer_engine.pytorch.attention import get_cu_seqlens_on_cp_rank import transformer_engine_torch as tex from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn @@ -86,6 +87,8 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend= ) cu_seqlens_q = None cu_seqlens_kv = None + cu_seqlens_q_padded = None + cu_seqlens_kv_padded = None elif qkv_format == "sbhd": q_input_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim) kv_input_shape = ( @@ -101,18 +104,36 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend= ) cu_seqlens_q = None cu_seqlens_kv = None + cu_seqlens_q_padded = None + cu_seqlens_kv_padded = None elif qkv_format == "thd": - seqlens_q = torch.randint(world_size * 2, config.max_seqlen_q + 1, [config.batch_size]).to( - torch.int32 + q_input_shape = (config.batch_size * config.max_seqlen_q, config.num_heads, config.head_dim) + kv_input_shape = ( + config.batch_size * config.max_seqlen_q, + config.num_gqa_groups, + config.head_dim, + ) + attn_output_shape = ( + config.batch_size * config.max_seqlen_q, + config.num_heads * config.head_dim, ) - seqlens_q = seqlens_q - seqlens_q % (world_size * 2) - cu_seqlens_q = torch.cat([torch.zeros([1], dtype=torch.int32), seqlens_q.cumsum(0)]) + seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32) + seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2) + cu_seqlens_q_padded = torch.cat( + [ + torch.zeros([1], dtype=torch.int32), + seqlens_q_padded.cumsum(0, dtype=torch.int32), + torch.tensor([q_input_shape[0]], dtype=torch.int32), + ] + ).cuda() + if kernel_backend == "FlashAttention": + cu_seqlens_q = cu_seqlens_q_padded[:-1] + else: + cu_seqlens_q = torch.cat( + [torch.zeros([1], dtype=torch.int32), seqlens_q.cumsum(0, dtype=torch.int32)] + ).cuda() cu_seqlens_kv = cu_seqlens_q - q_input_shape = (cu_seqlens_q[-1], config.num_heads, config.head_dim) - kv_input_shape = (cu_seqlens_kv[-1], config.num_gqa_groups, config.head_dim) - attn_output_shape = (cu_seqlens_q[-1], config.num_heads * config.head_dim) - cu_seqlens_q = cu_seqlens_q.to(torch.int32).cuda() - cu_seqlens_kv = cu_seqlens_kv.to(torch.int32).cuda() + cu_seqlens_kv_padded = cu_seqlens_q_padded else: assert False, f"{qkv_format} is an unsupported qkv_format!" @@ -132,7 +153,7 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend= for x in [q, k, v, dout] + ([] if bias is None else [bias]): dist.broadcast(x, 0, group=cp_comm_group) if qkv_format == "thd": - for x in [cu_seqlens_q, cu_seqlens_kv]: + for x in [cu_seqlens_q, cu_seqlens_q_padded, cu_seqlens_kv, cu_seqlens_kv_padded]: dist.broadcast(x, 0, group=cp_comm_group) # run core_attn without CP @@ -146,6 +167,8 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend= core_attention_bias=bias, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1], + cu_seqlens_kv_padded=None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1], ) out.backward(dout) @@ -171,12 +194,14 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend= x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) for x in [q_, k_, v_, dout_] ] elif qkv_format == "thd": - seq_idx_q = tex.thd_get_partitioned_indices(cu_seqlens_q, q_.size(0), world_size, rank) - seq_idx_kv = tex.thd_get_partitioned_indices(cu_seqlens_kv, k_.size(0), world_size, rank) + seq_idx_q = tex.thd_get_partitioned_indices( + cu_seqlens_q_padded, q_.shape[0], world_size, rank + ) + seq_idx_kv = tex.thd_get_partitioned_indices( + cu_seqlens_kv_padded, k_.shape[0], world_size, rank + ) q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]] k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]] - cu_seqlens_q = cu_seqlens_q // world_size - cu_seqlens_kv = cu_seqlens_kv // world_size else: assert False, f"{qkv_format} is an unsupported qkv_format!" q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] @@ -187,8 +212,6 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend= bias_ = bias_.index_select(2, seq_idx) bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1]) core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks, torch.cuda.Stream()) - max_seqlen_q = config.max_seqlen_q - max_seqlen_kv = config.max_seqlen_kv out_ = core_attn( q_, k_, @@ -197,8 +220,8 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend= core_attention_bias=bias_, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=max_seqlen_kv, + cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1], + cu_seqlens_kv_padded=None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1], ) out_.backward(dout_) @@ -230,9 +253,45 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend= for x in [q_.grad, k_.grad, v_.grad, out_] ] elif qkv_format == "thd": - dq, out = [x.index_select(0, seq_idx_q).contiguous().view(-1) for x in [q.grad, out]] - dk, dv = [x.index_select(0, seq_idx_kv).contiguous().view(-1) for x in [k.grad, v.grad]] - dq_, dk_, dv_, out_ = [x.view(-1) for x in [q_.grad, k_.grad, v_.grad, out_]] + dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [q.grad, out]] + dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [k.grad, v.grad]] + dq_, dk_, dv_, out_ = [q_.grad, k_.grad, v_.grad, out_] + cu_seqlens_q_padded = cu_seqlens_q_padded[:-1] // world_size + cu_seqlens_q = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True + ) + cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q + num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1] + for x in [dq, out, dq_, out_]: + assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0 + for b in range(config.batch_size): + assert ( + num_pads_q[b] == 0 + or torch.count_nonzero( + x[(cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[b + 1]] + ).item() + == 0 + ) + cu_seqlens_kv_padded = cu_seqlens_kv_padded[:-1] // world_size + cu_seqlens_kv = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True + ) + cu_pads_kv = cu_seqlens_kv_padded - cu_seqlens_kv + num_pads_kv = cu_pads_kv[1:] - cu_pads_kv[:-1] + for x in [dk, dv, dk_, dv_]: + assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0 + for b in range(config.batch_size): + assert ( + num_pads_kv[b] == 0 + or torch.count_nonzero( + x[ + (cu_seqlens_kv_padded[b + 1] - num_pads_kv[b]) : cu_seqlens_kv_padded[ + b + 1 + ] + ] + ).item() + == 0 + ) else: assert False, f"{qkv_format} is an unsupported qkv_format!" diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index fa72ecfa33..8aaa76a177 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1180,6 +1180,27 @@ def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step): softmax_lse.copy_(new_scale) +@jit_fuser +def get_cu_seqlens_on_cp_rank( + cu_seqlens, cu_seqlens_padded_on_cp_rank, cp_size, cp_rank, first_half, second_half +): + """Compute cu_seqlens of a context parallelism rank""" + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + seqlens_padded = (cu_seqlens_padded_on_cp_rank[1:] - cu_seqlens_padded_on_cp_rank[:-1]) // 2 + zeros = torch.zeros_like(seqlens) + cu_seqlens_on_cp_rank = torch.zeros_like(cu_seqlens) + if first_half: + seqlens_1 = seqlens - cp_rank * seqlens_padded + seqlens_1 = seqlens_1.clamp(zeros, seqlens_padded) + cu_seqlens_on_cp_rank[1:].add_(seqlens_1) + if second_half: + seqlens_2 = seqlens - (2 * cp_size - cp_rank - 1) * seqlens_padded + seqlens_2 = seqlens_2.clamp(zeros, seqlens_padded) + cu_seqlens_on_cp_rank[1:].add_(seqlens_2) + cu_seqlens_on_cp_rank.cumsum_(dim=0) + return cu_seqlens_on_cp_rank + + class AttnFuncWithCP(torch.autograd.Function): """ Attention implementation with context parallelism. @@ -1195,9 +1216,9 @@ def forward( k, v, cu_seqlens_q, - cu_seqlens_k, + cu_seqlens_kv, max_seqlen_q, - max_seqlen_k, + max_seqlen_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded, dropout_p, @@ -1224,7 +1245,19 @@ def forward( causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type - qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + if qkv_format in ["bshd", "sbhd"]: + qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:] + else: + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + + pad_between_seqs_q = not torch.equal(cu_seqlens_q_padded, cu_seqlens_q) + pad_between_seqs_kv = not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv) + max_seqlen_q = max_seqlen_q // cp_size + max_seqlen_kv = max_seqlen_kv // cp_size + cu_seqlens_q_padded = cu_seqlens_q_padded // cp_size + cu_seqlens_kv_padded = cu_seqlens_kv_padded // cp_size + cu_seqlens_q_per_step = [None for _ in range(cp_size)] + cu_seqlens_kv_per_step = [None for _ in range(cp_size)] if causal: if qkv_format == "bshd": @@ -1233,6 +1266,9 @@ def forward( elif qkv_format == "sbhd": # [s, b, np, hn] -> [2, s//2, b, np, hn] q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]] + total_tokens_kv = None if qkv_format != "thd" else k.shape[0] + # remove padded tokens at the end + k, v = [x if qkv_format != "thd" else x[: cu_seqlens_kv_padded[-1]] for x in [k, v]] if attn_bias is not None: assert len(attn_bias.shape) == 4, ( "Only support bias shape of [b, h, sq, sk] for forward, " @@ -1273,7 +1309,10 @@ def forward( fwd_results_correction_done = torch.cuda.Event() p2p_comm_buffers = [None for _ in range(cp_size)] - p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) + if use_fused_attention and qkv_format in ["bshd", "sbhd"]: + p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3) + else: + p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) send_recv_reqs = [[], []] for i in range(cp_size + 1): @@ -1298,19 +1337,33 @@ def forward( kv_inputs[i % 2] = p2p_comm_buffers[i] if causal: if i == 0: + if pad_between_seqs_q: + cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True + ) + else: + cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size + if pad_between_seqs_kv: + cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank, True, True + ) + else: + cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size if use_fused_attention: if qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) - # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] kv_inputs[i % 2] = kv_inputs[i % 2].view( - 2, k.shape[0], -1, *k.shape[-2:] + k.shape[0], -1, 2, *k.shape[-2:] ) elif qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq, b, np, hn] q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) - # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-3:]) + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + -1, k.shape[2], 2, *k.shape[-2:] + ) elif qkv_format == "thd": q_inputs[i % 2] = q if attn_bias is not None: @@ -1326,12 +1379,20 @@ def forward( fused_attn_fwd( is_training, max_seqlen_q, - max_seqlen_k, - cu_seqlens_q, - cu_seqlens_k, + max_seqlen_kv, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], q_inputs[i % 2], - kv_inputs[i % 2][0], - kv_inputs[i % 2][1], + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), TE_DType[q.dtype], tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, attn_scale=softmax_scale, @@ -1364,10 +1425,10 @@ def forward( q_inputs[i % 2], kv_inputs[i % 2][0], kv_inputs[i % 2][1], - cu_seqlens_q, - cu_seqlens_k, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], max_seqlen_q, - max_seqlen_k, + max_seqlen_kv, dropout_p, softmax_scale, causal=True, @@ -1375,22 +1436,39 @@ def forward( **fa_optional_forward_kwargs, ) elif i <= rank: + if pad_between_seqs_q: + cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True + ) + else: + cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size + if pad_between_seqs_kv: + cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, + cu_seqlens_kv_padded, + cp_size, + (rank - i) % cp_size, + True, + False, + ) + else: + cu_seqlens_kv_per_step[i] = cu_seqlens_kv // (cp_size * 2) if use_fused_attention: if qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) - # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2][:, :, 0, ...].contiguous() + # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...].contiguous() elif qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq, b, np, hn] q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) - # [2, 2, sk//2, b, np, hn] -> [2, sk//2, b, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...].contiguous() + # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2][0].contiguous() elif qkv_format == "thd": q_inputs[i % 2] = q # [2, t, np, hn] -> [2, t/2, np, hn] kv_inputs[i % 2] = tex.thd_read_half_tensor( - kv_inputs[i % 2], cu_seqlens_k, 0 + kv_inputs[i % 2], cu_seqlens_kv_padded, 0 ) if attn_bias is not None: idx = (rank - i) % cp_size @@ -1399,12 +1477,20 @@ def forward( fused_attn_fwd( is_training, max_seqlen_q, - max_seqlen_k // 2, - cu_seqlens_q, - cu_seqlens_k // 2, + max_seqlen_kv // 2, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], q_inputs[i % 2], - kv_inputs[i % 2][0], - kv_inputs[i % 2][1], + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), TE_DType[q.dtype], tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, attn_scale=softmax_scale, @@ -1429,7 +1515,7 @@ def forward( if qkv_format == "thd": # [2, t, np, hn] -> [2, t/2, np, hn] kv_inputs[i % 2] = tex.thd_read_half_tensor( - kv_inputs[i % 2], cu_seqlens_k, 0 + kv_inputs[i % 2], cu_seqlens_kv_padded, 0 ) else: # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] @@ -1451,10 +1537,10 @@ def forward( q_inputs[i % 2], kv_inputs[i % 2][0], kv_inputs[i % 2][1], - cu_seqlens_q, - cu_seqlens_k // 2, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], max_seqlen_q, - max_seqlen_k // 2, + max_seqlen_kv // 2, dropout_p, softmax_scale, causal=False, @@ -1462,22 +1548,43 @@ def forward( **fa_optional_forward_kwargs, ) else: + if pad_between_seqs_q: + cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True + ) + else: + cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2) + if pad_between_seqs_kv: + cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, + cu_seqlens_kv_padded, + cp_size, + (rank - i) % cp_size, + True, + True, + ) + else: + cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size if use_fused_attention: if qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] q_inputs[i % 2] = q[:, 1, ...].contiguous() - # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] kv_inputs[i % 2] = kv_inputs[i % 2].view( - 2, k.shape[0], -1, *k.shape[-2:] + k.shape[0], -1, 2, *k.shape[-2:] ) elif qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] q_inputs[i % 2] = q[1].contiguous() - # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-3:]) + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + -1, k.shape[2], 2, *k.shape[-2:] + ) elif qkv_format == "thd": # [t, np, hn] -> [t/2, np, hn] - q_inputs[i % 2] = tex.thd_read_half_tensor(q, cu_seqlens_q, 1) + q_inputs[i % 2] = tex.thd_read_half_tensor( + q, cu_seqlens_q_padded, 1 + ) if attn_bias is not None: idx = (rank - i) % cp_size attn_bias_inputs[i % 2] = torch.cat( @@ -1491,12 +1598,20 @@ def forward( fused_attn_fwd( is_training, max_seqlen_q // 2, - max_seqlen_k, - cu_seqlens_q // 2, - cu_seqlens_k, + max_seqlen_kv, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], q_inputs[i % 2], - kv_inputs[i % 2][0], - kv_inputs[i % 2][1], + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), TE_DType[q.dtype], tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, attn_scale=softmax_scale, @@ -1518,7 +1633,9 @@ def forward( else: if qkv_format == "thd": # [t, np, hn] -> [t/2, np, hn] - q_inputs[i % 2] = tex.thd_read_half_tensor(q, cu_seqlens_q, 1) + q_inputs[i % 2] = tex.thd_read_half_tensor( + q, cu_seqlens_q_padded, 1 + ) else: # [b, 2, sq//2, np, hn]->[b, sq//2, np, hn]->[b*sq//2, np, hn] q_inputs[i % 2] = ( @@ -1541,10 +1658,10 @@ def forward( q_inputs[i % 2], kv_inputs[i % 2][0], kv_inputs[i % 2][1], - cu_seqlens_q // 2, - cu_seqlens_k, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], max_seqlen_q // 2, - max_seqlen_k, + max_seqlen_kv, dropout_p, softmax_scale, causal=False, @@ -1552,6 +1669,23 @@ def forward( **fa_optional_forward_kwargs, ) else: + if pad_between_seqs_q: + cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True + ) + else: + cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size + if pad_between_seqs_kv: + cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, + cu_seqlens_kv_padded, + cp_size, + (rank - i) % cp_size, + True, + True, + ) + else: + cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size if use_fused_attention: if attn_bias is not None: idx = (rank - i) % cp_size @@ -1566,12 +1700,20 @@ def forward( fused_attn_fwd( is_training, max_seqlen_q, - max_seqlen_k, - cu_seqlens_q, - cu_seqlens_k, + max_seqlen_kv, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], q, - kv_inputs[i % 2][0], - kv_inputs[i % 2][1], + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), TE_DType[q.dtype], tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, attn_scale=softmax_scale, @@ -1604,10 +1746,10 @@ def forward( q_inputs[i % 2], kv_inputs[i % 2][0], kv_inputs[i % 2][1], - cu_seqlens_q, - cu_seqlens_k, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], max_seqlen_q, - max_seqlen_k, + max_seqlen_kv, dropout_p, softmax_scale, causal=False, @@ -1626,7 +1768,7 @@ def forward( with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): if i == 1: - out = torch.empty_like(q).zero_() + out = torch.zeros_like(q) softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double) if causal and qkv_format != "thd": # [b, np, sq] -> [b, np, 2, sq//2] @@ -1640,7 +1782,10 @@ def forward( else: if qkv_format == "thd": tex.thd_second_half_lse_correction( - softmax_lse, softmax_lse_per_step[i - 1], cu_seqlens_q, q.size(0) + softmax_lse, + softmax_lse_per_step[i - 1], + cu_seqlens_q_padded, + max_seqlen_q, ) else: flash_attn_fwd_softmax_lse_correction( @@ -1678,7 +1823,7 @@ def forward( out_per_step[i], softmax_lse, softmax_lse_per_step[i], - cu_seqlens_q, + cu_seqlens_q_padded, False, ) else: @@ -1698,7 +1843,7 @@ def forward( out_per_step[i], softmax_lse, softmax_lse_per_step[i], - cu_seqlens_q, + cu_seqlens_q_padded, True, ) else: @@ -1718,18 +1863,19 @@ def forward( kv, out, softmax_lse, - cu_seqlens_q, - cu_seqlens_k, cu_seqlens_q_padded, cu_seqlens_kv_padded, + *cu_seqlens_q_per_step, + *cu_seqlens_kv_per_step, *rng_states, *attn_biases, ) ctx.cp_group = cp_group ctx.cp_global_ranks = cp_global_ranks ctx.dropout_p = dropout_p + ctx.total_tokens_kv = total_tokens_kv ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k + ctx.max_seqlen_kv = max_seqlen_kv ctx.softmax_scale = softmax_scale ctx.qkv_format = qkv_format ctx.attn_mask_type = attn_mask_type @@ -1741,20 +1887,24 @@ def forward( @staticmethod def backward(ctx, dout): - (q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k) = ctx.saved_tensors[:6] - (cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[6:8] cp_size = get_distributed_world_size(ctx.cp_group) - rng_states = ctx.saved_tensors[8 : 8 + cp_size] - attn_biases = ctx.saved_tensors[8 + cp_size : 8 + cp_size * 2] - rank = get_distributed_rank(ctx.cp_group) send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size] recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size] batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) + (q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[:6] + cu_seqlens_q_per_step = ctx.saved_tensors[6 : 6 + cp_size] + cu_seqlens_kv_per_step = ctx.saved_tensors[6 + cp_size : 6 + cp_size * 2] + rng_states = ctx.saved_tensors[6 + cp_size * 2 : 6 + cp_size * 3] + attn_biases = ctx.saved_tensors[6 + cp_size * 3 : 6 + cp_size * 4] + causal = "causal" in ctx.attn_mask_type padding = "padding" in ctx.attn_mask_type - qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format + if ctx.qkv_format in ["bshd", "sbhd"]: + qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:] + else: + qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format if attn_biases[0] is not None: # [b, np, sq, 2*cp, sk//(2*cp)] @@ -1770,7 +1920,9 @@ def backward(ctx, dout): if causal: if ctx.qkv_format == "thd": - softmax_lse_ = tex.thd_read_second_half_lse(softmax_lse, cu_seqlens_q, q.size(0)) + softmax_lse_ = tex.thd_read_second_half_lse( + softmax_lse, cu_seqlens_q_padded, ctx.max_seqlen_q + ) else: # [b, np, sq] -> [b, np, 2, sq//2] softmax_lse_ = softmax_lse.view( @@ -1788,6 +1940,8 @@ def backward(ctx, dout): dout = dout.view(*q.shape) # Flash Attn outputs dq = torch.empty_like(q) + if ctx.qkv_format == "thd" and causal: + dq[cu_seqlens_q_padded[-1] :].fill_(0) p2p_comm_buffers = [ torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), @@ -1828,16 +1982,16 @@ def backward(ctx, dout): if ctx.qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] q_ = q.view(q.shape[0], -1, *q.shape[-2:]) - # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] - kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:]) + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] out_ = out.view(out.shape[0], -1, *out.shape[-2:]) dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:]) elif ctx.qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq, b, np, hn] q_ = q.view(-1, *q.shape[-3:]) - # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] - kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_ = kv.view(-1, *kv.shape[-4:]) # [2, sq//2, b, np, hn] -> [sq, b, np, hn] out_ = out.view(-1, *out.shape[-3:]) dout_ = dout.view(-1, *dout.shape[-3:]) @@ -1848,12 +2002,12 @@ def backward(ctx, dout): aux_ctx_tensors += [attn_biases[cp_size - i - 1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, - ctx.max_seqlen_k, - cu_seqlens_q, - cu_seqlens_k, + ctx.max_seqlen_kv, + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], q_, - kv_[0], - kv_[1], + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], out_, dout_, TE_DType[q.dtype], @@ -1871,7 +2025,7 @@ def backward(ctx, dout): else: # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] q_ = q.view(-1, *q.shape[-2:]) - dq_ = torch.empty_like(q_) + dq_ = torch.zeros_like(q_) # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] kv_ = kv.view(2, -1, *kv.shape[-2:]) dkv_ = torch.empty_like(kv_) @@ -1890,10 +2044,10 @@ def backward(ctx, dout): dq_, dkv_[0], dkv_[1], - cu_seqlens_q, - cu_seqlens_k, + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], ctx.max_seqlen_q, - ctx.max_seqlen_k, + ctx.max_seqlen_kv, ctx.dropout_p, ctx.softmax_scale, True, @@ -1905,34 +2059,34 @@ def backward(ctx, dout): if ctx.qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] q_ = q.view(q.shape[0], -1, *q.shape[-2:]) - # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] - kv_ = kv[:, :, 0, ...].contiguous() + # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] + kv_ = kv[:, 0, ...].contiguous() # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] out_ = out.view(out.shape[0], -1, *out.shape[-2:]) dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:]) elif ctx.qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq, b, np, hn] q_ = q.view(-1, *q.shape[-3:]) - # [2, 2, sk//2, b, np, hn] -> [2, sk//2, b, np, hn] - kv_ = kv[:, 0, ...].contiguous() + # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] + kv_ = kv[0].contiguous() # [2, sq//2, b, np, hn] -> [sq, b, np, hn] out_ = out.view(-1, *out.shape[-3:]) dout_ = dout.view(-1, *dout.shape[-3:]) elif ctx.qkv_format == "thd": q_, out_, dout_ = q, out, dout # [2, t, np, hn] -> [2, t/2, np, hn] - kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_k, 0) + kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0) aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, - ctx.max_seqlen_k // 2, - cu_seqlens_q, - cu_seqlens_k // 2, + ctx.max_seqlen_kv // 2, + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], q_, - kv_[0], - kv_[1], + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], out_, dout_, TE_DType[q.dtype], @@ -1952,10 +2106,10 @@ def backward(ctx, dout): else: # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] q_ = q.view(-1, *q.shape[-2:]) - dq_ = torch.empty_like(q_) + dq_ = torch.zeros_like(q_) if ctx.qkv_format == "thd": # [2, t, np, hn] -> [2, t/2, np, hn] - kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_k, 0) + kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0) else: # [2, b, 2, sk//2, np, hn]->[2, b, sk//2, np, hn]->[2, b*sk//2, np, hn] kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:]) @@ -1975,10 +2129,10 @@ def backward(ctx, dout): dq_, dkv_[0], dkv_[1], - cu_seqlens_q, - cu_seqlens_k // 2, + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], ctx.max_seqlen_q, - ctx.max_seqlen_k // 2, + ctx.max_seqlen_kv // 2, ctx.dropout_p, ctx.softmax_scale, False, @@ -1990,36 +2144,36 @@ def backward(ctx, dout): if ctx.qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] q_ = q[:, 1, ...].contiguous() - # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] - kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:]) + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] out_ = out[:, 1, ...].contiguous() dout_ = dout[:, 1, ...].contiguous() elif ctx.qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] q_ = q[1].contiguous() - # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] - kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_ = kv.view(-1, *kv.shape[-4:]) # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] out_ = out[1].contiguous() dout_ = dout[1].contiguous() elif ctx.qkv_format == "thd": # [t, np, hn] -> [t/2, np, hn] - q_ = tex.thd_read_half_tensor(q, cu_seqlens_q, 1) - out_ = tex.thd_read_half_tensor(out, cu_seqlens_q, 1) - dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q, 1) + q_ = tex.thd_read_half_tensor(q, cu_seqlens_q_padded, 1) + out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1) + dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1) kv_ = kv aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q // 2, - ctx.max_seqlen_k, - cu_seqlens_q // 2, - cu_seqlens_k, + ctx.max_seqlen_kv, + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], q_, - kv_[0], - kv_[1], + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], out_, dout_, TE_DType[q.dtype], @@ -2039,17 +2193,17 @@ def backward(ctx, dout): else: if ctx.qkv_format == "thd": # [t, np, hn] -> [t/2, np, hn] - q_ = tex.thd_read_half_tensor(q, cu_seqlens_q, 1) + q_ = tex.thd_read_half_tensor(q, cu_seqlens_q_padded, 1) else: # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn] q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:]) - dq_ = torch.empty_like(q_) + dq_ = torch.zeros_like(q_) # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] kv_ = kv.view(2, -1, *kv.shape[-2:]) dkv_ = torch.empty_like(kv_) if ctx.qkv_format == "thd": - out_ = tex.thd_read_half_tensor(out, cu_seqlens_q, 1) - dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q, 1) + out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1) + dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1) else: # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn] out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:]) @@ -2066,10 +2220,10 @@ def backward(ctx, dout): dq_, dkv_[0], dkv_[1], - cu_seqlens_q // 2, - cu_seqlens_k, + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], ctx.max_seqlen_q // 2, - ctx.max_seqlen_k, + ctx.max_seqlen_kv, ctx.dropout_p, ctx.softmax_scale, False, @@ -2083,12 +2237,12 @@ def backward(ctx, dout): aux_ctx_tensors += [attn_biases[cp_size - i - 1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, - ctx.max_seqlen_k, - cu_seqlens_q, - cu_seqlens_k, + ctx.max_seqlen_kv, + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], q, - kv[0], - kv[1], + kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0], + kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1], out, dout, TE_DType[q.dtype], @@ -2106,7 +2260,7 @@ def backward(ctx, dout): else: # [b, sq, np, hn] -> [b*sq, np, hn] q_ = q.view(-1, *q.shape[-2:]) - dq_ = torch.empty_like(q_) + dq_ = torch.zeros_like(q_) # [2, b, sk, np, hn] -> [2, b*sk, np, hn] kv_ = kv.view(2, -1, *kv.shape[-2:]) dkv_ = torch.empty_like(kv_) @@ -2125,10 +2279,10 @@ def backward(ctx, dout): dq_, dkv_[0], dkv_[1], - cu_seqlens_q, - cu_seqlens_k, + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], ctx.max_seqlen_q, - ctx.max_seqlen_k, + ctx.max_seqlen_kv, ctx.dropout_p, ctx.softmax_scale, False, @@ -2162,21 +2316,21 @@ def backward(ctx, dout): dq[0].copy_(dq_[0]) dq[1].add_(dq_[1]) elif ctx.qkv_format == "thd": - tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "copy", "add") + tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "copy", "add") elif i > 0: if ctx.qkv_format == "bshd": dq[:, 1, ...].add_(dq_) elif ctx.qkv_format == "sbhd": dq[1].add_(dq_) elif ctx.qkv_format == "thd": - tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "none", "add") + tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "add") else: if ctx.qkv_format == "bshd": dq[:, 1, ...].copy_(dq_) elif ctx.qkv_format == "sbhd": dq[1].copy_(dq_) elif ctx.qkv_format == "thd": - tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "none", "copy") + tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "copy") else: if i == 0: dq.copy_(dq_) @@ -2206,6 +2360,10 @@ def backward(ctx, dout): dkv = p2p_comm_buffers[(i + 1) % 2][1] if ctx.use_fused_attention: dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0) + if ctx.qkv_format in ["bshd", "sbhd"]: + # [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or + # [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn] + dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:]) if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): if ctx.qkv_format == "bshd": # [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn] @@ -2228,7 +2386,7 @@ def backward(ctx, dout): dkv[:, 0, ...].add_(dkv_[:, 0, ...]) dkv[:, 1, ...].copy_(dkv_[:, 1, ...]) elif ctx.qkv_format == "thd": - tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "add", "copy") + tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "copy") else: dkv.add_(dkv_) elif i >= (cp_size - rank - 1): @@ -2238,14 +2396,14 @@ def backward(ctx, dout): elif ctx.qkv_format == "sbhd": dkv[:, 0, ...].copy_(dkv_) elif ctx.qkv_format == "thd": - tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "copy", "none") + tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "copy", "none") else: if ctx.qkv_format == "bshd": dkv[:, :, 0, ...].add_(dkv_) elif ctx.qkv_format == "sbhd": dkv[:, 0, ...].add_(dkv_) elif ctx.qkv_format == "thd": - tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "add", "none") + tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "none") elif i > 0: dkv.add_(dkv_) else: @@ -2259,14 +2417,22 @@ def backward(ctx, dout): if causal: if ctx.qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - dq = dq.view(q.shape[0], -1, *q.shape[-2:]) + dq = dq.view(dq.shape[0], -1, *dq.shape[-2:]) # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] - dkv = dkv.view(*kv.shape[0:2], -1, *kv.shape[-2:]) + dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:]) elif ctx.qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - dq = dq.view(-1, *q.shape[-3:]) + dq = dq.view(-1, *dq.shape[-3:]) # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] - dkv = dkv.view(kv.shape[0], -1, *kv.shape[-3:]) + dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:]) + + if ctx.qkv_format == "thd": + dkv_ = torch.empty( + 2, ctx.total_tokens_kv, *dkv.shape[-2:], dtype=dkv.dtype, device=dkv.device + ) + dkv_[:, : cu_seqlens_kv_padded[-1]].copy_(dkv) + dkv_[:, cu_seqlens_kv_padded[-1] :].fill_(0) + dkv = dkv_ if attn_dbias is not None: # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk] @@ -2303,9 +2469,9 @@ def attn_forward_func_with_cp( k, v, cu_seqlens_q, - cu_seqlens_k, + cu_seqlens_kv, max_seqlen_q, - max_seqlen_k, + max_seqlen_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded, dropout_p, @@ -2341,15 +2507,18 @@ def attn_forward_func_with_cp( """Attention bias is only supported with FusedAttention and "causal" """ """or "no_mask" mask types!""" ) + assert ( + cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None + ), "cu_seqlens_q_padded and cu_seqlens_kv_padded cannot be None with context parallelism!" out = AttnFuncWithCP.apply( is_training, q, k, v, cu_seqlens_q, - cu_seqlens_k, + cu_seqlens_kv, max_seqlen_q, - max_seqlen_k, + max_seqlen_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded, dropout_p, @@ -3140,7 +3309,8 @@ def forward( qkv_layout in QKVLayouts ), f"FlashAttention does not support qkv_layout = {qkv_layout}!" - context_parallel = (cp_group is not None) and (get_distributed_world_size(cp_group) != 1) + cp_size = 1 if cp_group is None else get_distributed_world_size(cp_group) + context_parallel = cp_size > 1 qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) @@ -3167,6 +3337,8 @@ def forward( if qkv_format in ["sbhd", "bshd"]: max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1] + max_seqlen_q *= cp_size + max_seqlen_kv *= cp_size if not context_parallel: # [b * s, h, d] query_layer, key_layer, value_layer = [ @@ -3247,8 +3419,8 @@ def forward( cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, - None, - None, + cu_seqlens_q, + cu_seqlens_kv, self.attention_dropout if self.training else 0.0, cp_group, cp_global_ranks, @@ -3295,10 +3467,12 @@ def forward( if qkv_format == "sbhd": # (bs)hd -> bs(hd) -> sb(hd) - output = output.view(batch_size, max_seqlen_q, -1).transpose(0, 1).contiguous() + output = ( + output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1).contiguous() + ) elif qkv_format == "bshd": # (bs)hd -> bs(hd) - output = output.view(batch_size, max_seqlen_q, -1).contiguous() + output = output.view(batch_size, max_seqlen_q // cp_size, -1).contiguous() elif qkv_format == "thd": # thd -> t(hd) output = output.view(output.shape[0], -1).contiguous() @@ -4835,7 +5009,8 @@ def forward( qkv_layout in QKVLayouts ), f"FusedAttention does not support qkv_layout = {qkv_layout}!" - context_parallel = (cp_group is not None) and (get_distributed_world_size(cp_group) != 1) + cp_size = 1 if cp_group is None else get_distributed_world_size(cp_group) + context_parallel = cp_size > 1 qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) @@ -4852,6 +5027,8 @@ def forward( query_layer.shape[1], key_layer.shape[1], ) + max_seqlen_q *= cp_size + max_seqlen_kv *= cp_size if "padding" in attn_mask_type: assert not context_parallel, "Padding mask not supported with context parallelism!" @@ -5540,13 +5717,22 @@ def forward( cu_seqlens_q.dtype == torch.int32 and cu_seqlens_kv.dtype == torch.int32 ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!" if max_seqlen_q is None: - seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + if cu_seqlens_q_padded is not None: + seqlens_q = cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1] + else: + seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] max_seqlen_q = pow(2, math.ceil(math.log2(seqlens_q.max().item()))) if max_seqlen_kv is None: - seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + if cu_seqlens_kv_padded is not None: + seqlens_kv = cu_seqlens_kv_padded[1:] - cu_seqlens_kv_padded[:-1] + else: + seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] max_seqlen_kv = pow(2, math.ceil(math.log2(seqlens_kv.max().item()))) batch_size = len(cu_seqlens_q) - 1 + cp_size = 1 if self.cp_group is None else get_distributed_world_size(self.cp_group) + context_parallel = cp_size > 1 + if qkv_format in ["sbhd", "bshd"]: assert all( len(x.shape) == 4 for x in (query_layer, key_layer, value_layer) @@ -5557,6 +5743,8 @@ def forward( if qkv_format == "bshd": max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1]) batch_size = query_layer.shape[0] + max_seqlen_q *= cp_size + max_seqlen_kv *= cp_size if cu_seqlens_q is not None: seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] assert all( @@ -5628,10 +5816,6 @@ def forward( _alibi_cache["_alibi_slopes_require_update"] = True _alibi_cache["_alibi_bias_require_update"] = True - context_parallel = ( - self.cp_group is not None and get_distributed_world_size(self.cp_group) != 1 - ) - core_attention_bias_shape = None if core_attention_bias is not None: if ( diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 9f4612f240..9cdc79ed64 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -1565,7 +1565,7 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float dtype *p_per_step = reinterpret_cast(&data_per_step); dtype *p = reinterpret_cast(&data); for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) { - p[k] += p_per_step[k] * lse_corrected_exp; + p[k] += (p_per_step[k] == 0 ? 0 : p_per_step[k] * lse_corrected_exp); } reinterpret_cast(cur_out)[j] = data; } From 098e3006065d806b32f2b403d4cf4ffd434dc78e Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Thu, 1 Aug 2024 17:26:17 -0700 Subject: [PATCH 31/72] Link attention docs to the main docs and fix errors reported by Sphinx (#1062) * Link attention docs to the main docs and fix errors reported by Sphinx Signed-off-by: Przemek Tredak * Lower the version of nbsphinx Signed-off-by: Przemek Tredak * More fixes Signed-off-by: Przemek Tredak * Change the URL of example_attention.py to GitHub Signed-off-by: Przemek Tredak * More fixes in the attention tutorial Signed-off-by: Przemek Tredak --------- Signed-off-by: Przemek Tredak --- .github/workflows/docs.yml | 4 +- docs/_templates/layout.html | 2 +- docs/conf.py | 2 + docs/examples/attention/attention.ipynb | 304 ++++++++++++++---- docs/index.rst | 1 + transformer_engine/jax/flax/module.py | 4 +- transformer_engine/jax/flax/transformer.py | 11 +- transformer_engine/jax/fp8.py | 4 +- .../paddle/layer/transformer.py | 8 +- transformer_engine/pytorch/ops/sequential.py | 2 +- 10 files changed, 264 insertions(+), 78 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 581ff1e935..b4eeefa70b 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -17,8 +17,8 @@ jobs: uses: actions/checkout@v3 - name: 'Install dependencies' run: | - pip install sphinx==7.1.2 sphinx_rtd_theme==2.0.0 nbsphinx==0.9.4 IPython ipython_genutils==0.2.0 ipywidgets==8.1.3 astroid==3.2.2 - pip install breathe==4.35.0 sphinx-autoapi==3.1.1 + pip install sphinx==5.1.1 sphinx_rtd_theme==1.0.0 nbsphinx==0.8.10 IPython ipython_genutils==0.2.0 ipywidgets==8.0.2 astroid==2.15.7 + pip install breathe==4.34.0 sphinx-autoapi==2.0.1 sudo apt-get install -y pandoc graphviz doxygen export GIT_SHA=$(git show-ref --hash HEAD) - name: 'Build docs' diff --git a/docs/_templates/layout.html b/docs/_templates/layout.html index cb372b3a72..a68b4531e3 100644 --- a/docs/_templates/layout.html +++ b/docs/_templates/layout.html @@ -70,7 +70,7 @@ color: #8c0; } - html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple)>dt { + html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.citation):not(.glossary):not(.simple)>dt { background: rgba(118, 185, 0, 0.1); color: rgba(59,93,0,1); border-top: solid 3px rgba(59,93,0,1); diff --git a/docs/conf.py b/docs/conf.py index 695546a9ba..77751994d8 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -109,6 +109,8 @@ ("Parallelism parameters", "params_style"), ("Optimization parameters", "params_style"), ("Values", "params_style"), + ("Graphing parameters", "params_style"), + ("FP8-related parameters", "params_style"), ] breathe_projects = {"TransformerEngine": os.path.abspath("doxygen/xml/")} diff --git a/docs/examples/attention/attention.ipynb b/docs/examples/attention/attention.ipynb index 53d56532b9..515f420790 100644 --- a/docs/examples/attention/attention.ipynb +++ b/docs/examples/attention/attention.ipynb @@ -14,7 +14,8 @@ "
Figure 1: Dot product attention.
\n", "\n", "\n", - "[Transformer Engine](https://github.com/NVIDIA/TransformerEngine.git) supports the calculation of dot product attention in three frameworks, [PyTorch](https://github.com/pytorch/pytorch), [JAX](https://github.com/google/jax) and [PaddlePaddle](https://github.com/PaddlePaddle/Paddle). The API for each framework is,\n", + "[Transformer Engine](https://github.com/NVIDIA/TransformerEngine.git) supports the calculation of dot product attention in three frameworks, [PyTorch](https://github.com/pytorch/pytorch), [JAX](https://github.com/google/jax) and [PaddlePaddle](https://github.com/PaddlePaddle/Paddle). The API for each framework is\n", + "\n", "- [transformer_engine.pytorch.DotProductAttention](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention)\n", "- [transformer_engine.jax.flax.DotProductAttention](../../api/jax.rst#transformer_engine.jax.flax.DotProductAttention)\n", "- [transformer_engine.paddle.DotProductAttention](../../api/paddle.rst#transformer_engine.paddle.DotProductAttention)" @@ -28,12 +29,44 @@ "## 1. Attention Backends\n", "\n", "Transformer Engine provides multiple attention backends for each supported framework. The framework-native backends provide a robust baseline, while the fused, GPU-optimized implementations offer more performance. For example, the flash-attention and cuDNN attention backends in PyTorch. The framework-native backends are often named with \"unfused\", while the more optimized backends are \"fused\" or \"flash\".\n", - "\n", - "| Framework | Backend (Module Name) | Module Location |\n", - "| :-------- | :-------------------- | :-------------- |\n", - "| PyTorch | cuDNN attention (`FusedAttention`)
flash-attention (`FlashAttention`)
PyTorch-native attention (`UnfusedDotProductAttention`) | [transformer_engine.pytorch.attention](../../transformer_engine/pytorch/attention.py) |\n", - "| JAX | cuDNN attention (`_FusedDotProductAttention`)
JAX-native attention (`_UnfusedDotProductAttention`) | [transformer_engine.jax.flax.transformer](../../transformer_engine/jax/flax/transformer.py) |\n", - "| PaddlePaddle | cuDNN attention (`_te_forward`)
PaddlePaddle-native attention (`_pd_forward`) | [transformer_engine.paddle.layer.attention](../../transformer_engine/paddle/layer/attention.py) |\n" + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
FrameworkBackend (Module Name)Module Location
PyTorchcuDNN attention (`FusedAttention`) [transformer_engine.pytorch.attention](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py)
flash-attention (`FlashAttention`)
\n", + " PyTorch-native attention (`UnfusedDotProductAttention`)\n", + "
JAXcuDNN attention (`_FusedDotProductAttention`)[transformer_engine.jax.flax.transformer](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/flax/transformer.py)
JAX-native attention (`_UnfusedDotProductAttention`)
PaddlePaddle cuDNN attention (`_te_forward`) [transformer_engine.paddle.layer.attention](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/paddle/layer/attention.py)\n", + "
PaddlePaddle-native attention (`_pd_forward`)
" ] }, { @@ -52,7 +85,9 @@ "- **Recomputation:** The non-flash algorithm stores the softmax matrix (quadratic to sequence length) to global memory for the backward pass, while the flash algorithm only saves the softmax normalization factors (linear to sequence length). This reduces the amount of memory required as well as the bandwidth utilization between global memory and shared memory. Even though there is extra computation incurred in order to recalculate the attention in the backward pass, the bandwidth savings still provide significant improvement in efficiency.\n", "\n", "
\n", - "Note: Transformer Engine's flash-attention backend, available in PyTorch, and cuDNN attention backend (sub-backends 1 and 2), in PyTorch, JAX and PaddlePaddle, are both based on the flash algorithm.\n", + "Note \n", + " \n", + "Transformer Engine's flash-attention backend, available in PyTorch, and cuDNN attention backend (sub-backends 1 and 2), in PyTorch, JAX and PaddlePaddle, are both based on the flash algorithm.\n", "
\n" ] }, @@ -67,19 +102,56 @@ "\n", "The flash-attention backend supports `flash-attn`'s features as they are released, and to facilitate the use of `flash-attn`, flash-attention also offers a few functionalities such as converting the `attention_mask` to cumulative sequence lengths `cu_seqlens` for `padding` mask. Please see `transformer_engine.pytorch.attention.FlashAttention` for more details.\n", "\n", - "The `flash-attn` dependency is regularly updated in Transformer Engine. As of v1.7, Transformer Engine supports `flash-attn` 2.0.6+ (see [setup.py](../../setup.py)).\n", + "The `flash-attn` dependency is regularly updated in Transformer Engine. As of v1.7, Transformer Engine supports `flash-attn` 2.0.6+ (see [setup.py](https://github.com/NVIDIA/TransformerEngine/blob/main/setup.py)).\n", "\n", "To understand `flash-attn`'s performance, please refer to their [benchmarks](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#performance).\n", "\n", "### 1.3 cuDNN Attention\n", "\n", - "The cuDNN attention backend, available in PyTorch, JAX and PaddlePaddle, offers another high-performance solution to the attention calculation. It requires [cuDNN](https://developer.nvidia.com/cudnn) and [cudnn-frontend](../../3rdparty/cudnn-frontend) to run, and has several sub-backends to support the different precisions and sequence lengths. Out of the three, sub-backends 1 and 2 are based on the flash algorithm, as `flash-attn` is.\n", - "\n", - "| Sub-Backend | Algorithm | Precision | Sequence Length | Architecture | Docs |\n", - "| :---------- | :--------- | :-------- | :-------------- | :----------- | :--- |\n", - "| 0 | Non-Flash | BF16/FP16 | <=512 | sm80, 90 | [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#fused-attention-fprop) |\n", - "| 1 | Flash | BF16/FP16 | Any | sm80+ | [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#fused-flash-attention-fprop),
[cudnn-frontend](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention) |\n", - "| 2 | Flash | FP8 | cuDNN pre-9.0: <=512
cuDNN 9.0+: Any | cuDNN pre-9.0: sm90
cuDNN 9.0+: sm90+ | cuDNN 9.0+: [cudnn-frontend](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention-fp8) |\n", + "The cuDNN attention backend, available in PyTorch, JAX and PaddlePaddle, offers another high-performance solution to the attention calculation. It requires [cuDNN](https://developer.nvidia.com/cudnn) to run, and has several sub-backends to support the different precisions and sequence lengths. Out of the three, sub-backends 1 and 2 are based on the flash algorithm, as `flash-attn` is.\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Sub-BackendAlgorithmPrecisionSequence LengthArchitectureAdditional info
0Non-FlashBF16/FP16 ≤512 sm80, 90 [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#fused-attention-fprop)
1FlashBF16/FP16 Any sm80+ [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#fused-flash-attention-fprop),\n", + " [cudnn-frontend](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention)\n", + "
2FlashFP8 cuDNN pre-9.0: ≤512 cuDNN pre-9.0: sm90
cuDNN 9.0+: Any cuDNN 9.0+: sm90+ cuDNN 9.0+: [cudnn-frontend](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention-fp8)\n", + "
\n", "\n", "The cuDNN attention backend and flash-attention backend have several notable differences. As of Transformer Engine 1.7, cuDNN 9.0 and `flash-attn` 2.4.2,\n", "\n", @@ -91,7 +163,7 @@ "- flash-attention uses bottom right diagonal for `causal` mask in cross attention, and cuDNN attention uses top left (see `flash-attn`'s [change log](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag)).\n", "- flash-attention outperforms cuDNN attention on Ampere architectures, and cuDNN attention has 20-50% advantages on Hopper architectures, based on our benchmarks for a number of commonly-used model configurations.\n", "\n", - "To compare cuDNN attention and flash-attention, users can modify the `model_configs` dictionary in [benchmarks/attention/benchmark_attention.py](../../benchmarks/attention/benchmark_attention.py) to collect performance numbers. The script runs each entry in `model_configs` for `num_iters` times, each time with one forward pass and one backward pass. Both backends are tried, and if one backend does not have support for the specific user input, the runtimes and speedups in the final table would be 0." + "To compare cuDNN attention and flash-attention, users can modify the `model_configs` dictionary in [benchmarks/attention/benchmark_attention.py](https://github.com/NVIDIA/TransformerEngine/blob/main/benchmarks/attention/benchmark_attention.py) to collect performance numbers. The script runs each entry in `model_configs` for `num_iters` times, each time with one forward pass and one backward pass. Both backends are tried, and if one backend does not have support for the specific user input, the runtimes and speedups in the final table would be 0." ] }, { @@ -151,11 +223,32 @@ "\n", "When there are multiple backends available, Transformer Engine makes backend selection based on performance. In general, there are a few rules being followed in our selection logic (see table below). As we monitor the performance of different backends, the selection logic may change.\n", "\n", - "| Framework | Selection Order |\n", - "| :-------- | :--------------------- |\n", - "| PyTorch | sm90: cuDNN attention > flash-attention > PyTorch-native attention
sm80: flash-attention > cuDNN attention > PyTorch-native attention
cuDNN attention: sub-backend 1 > sub-backend 0 |\n", - "| JAX | cuDNN attention > JAX-native attention |\n", - "| PaddlePaddle | cuDNN attention > PaddlePaddle-native attention |\n" + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
FrameworkSelection Order
PyTorchsm90: cuDNN attention > flash-attention > PyTorch-native attention
sm80: flash-attention > cuDNN attention > PyTorch-native attention
\n", + " cuDNN attention: sub-backend 1 > sub-backend 0\n", + "
JAXcuDNN attention > JAX-native attention
PaddlePaddle cuDNN attention > PaddlePaddle-native attention
" ] }, { @@ -171,7 +264,9 @@ "NVTE_DEBUG_LEVEL = 0/1/2 # enables logging.WARNING/INFO/DEBUG-level messages\n", "```\n", "
\n", - "Note: These flags are supported in PyTorch only as of Transformer Engine 1.7. JAX and PaddlePaddle support is expected to be added in the future.\n", + "Note\n", + " \n", + "These flags are supported in PyTorch only as of Transformer Engine 1.7. JAX and PaddlePaddle support is expected to be added in the future.\n", "
" ] }, @@ -180,7 +275,7 @@ "id": "7e3b7981", "metadata": {}, "source": [ - "The [example_attention.py](./example_attention.py) script runs a very basic model with two attention backends, cuDNN attention and flash-attention. Here `NVTE_DEBUG_LEVEL=1` allows us to find out which backend/sub-backend was actually used during runtime." + "The [example_attention.py](https://raw.githubusercontent.com/NVIDIA/TransformerEngine/main/docs/examples/attention/example_attention.py) script runs a very basic model with two attention backends, cuDNN attention and flash-attention. Here `NVTE_DEBUG_LEVEL=1` allows us to find out which backend/sub-backend was actually used during runtime." ] }, { @@ -283,14 +378,16 @@ " NVTE_ALLOW_NONDETERMINISTIC_ALGO = 0 # enables workspace optimization path\n", "```\n", "
\n", - "Note: Environment variables NVTE_FLASH_ATTN, NVTE_FUSED_ATTN, NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT and NVTE_ALLOW_NONDETERMINISTIC_ALGO are only supported in PyTorch, and will be added to JAX and PaddlePaddle in the future.\n", + "Note\n", + " \n", + "Environment variables NVTE_FLASH_ATTN, NVTE_FUSED_ATTN, NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT and NVTE_ALLOW_NONDETERMINISTIC_ALGO are only supported in PyTorch, and will be added to JAX and PaddlePaddle in the future.\n", "
\n", "\n", "### 2.3 Example Tests\n", "\n", - "Our [unit tests](../../tests/) demonstrate the use of Transformer Engine dot product attention APIs. Users are encouraged to use them as a template when integrating Transformer Engine to their ML workflows.\n", + "Our [unit tests](https://github.com/NVIDIA/TransformerEngine/tree/main/tests) demonstrate the use of Transformer Engine dot product attention APIs. Users are encouraged to use them as a template when integrating Transformer Engine to their ML workflows.\n", "\n", - "For example, in PyTorch, [test_dot_product_attention](../../tests/pytorch/fused_attention/test_fused_attn.py) offers a variety of use cases of `pytorch.DotProductAttention`, from data types, model configs, checkpointing, to QKV layouts." + "For example, in PyTorch, [test_dot_product_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) offers a variety of use cases of `pytorch.DotProductAttention`, from data types, model configs, checkpointing, to QKV layouts." ] }, { @@ -302,16 +399,16 @@ "\n", "Transformer Engine supports commonly-used features such as self and cross attention, FP16/BF16 precisions, dropout, and checkpointing. But it also offers a range of other features. As of v1.7, Transformer Engine's attention backends have the following support matrix.\n", "\n", - "| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Context Parallelism | Deterministic |\n", + "| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Context Parallelism | Determinism Possible |\n", "| :---------------- | :-------- | :----------- | :----------------------- | :------ | :------------------ | :------------ |\n", - "| cuDNN attention
(PyTorch, JAX, PaddlePaddle) | PyTorch: BF16, FP16, FP8
JAX, PaddlePaddle: BF16, FP16 | sm80+ | No | Yes | `bshd`,`sbhd`: Yes
`thd`: No | Sub-backend 0, 2: Yes
Sub-backend 1: Yes, if workspace optimization path |\n", - "| flash-attention
(PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | `bshd`,`thd`: Yes
`sbhd`: No | Yes, if `deterministic=True` |\n", - "| Framework-native attention
(PyTorch, JAX, PaddlePaddle) | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | No | Yes |\n", + "| cuDNN attention (all frameworks) | BF16, FP16, FP8 (PyTorch only) | sm80+ | No | Yes | Yes (only for `bshd`,`sbhd`) | Yes |\n", + "| flash-attention (PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | Yes (only for `bshd`,`thd`) | Yes |\n", + "| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | No | Yes |\n", "\n", "Some unit tests are provided to serve as a starting point for integrating such features into users' models. For example,\n", - "- sliding window attention: [test_dpa_swa](../../tests/pytorch/fused_attention/test_fused_attn.py)\n", - "- MQA/GQA: [test_te_layer_mqa_gqa](../../tests/pytorch/fused_attention/test_fused_attn.py)\n", - "- context parallelism: [test_cp_with_fused_attention](../../tests/pytorch/fused_attention/test_fused_attn_with_cp.py), [test_cp_with_flash_attention](../../tests/pytorch/fused_attention/test_fused_attn_with_cp.py)" + "- sliding window attention: [test_dpa_swa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n", + "- MQA/GQA: [test_te_layer_mqa_gqa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n", + "- context parallelism: [test_cp_with_fused_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py), [test_cp_with_flash_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py)" ] }, { @@ -331,29 +428,53 @@ "\n", "The notation system is that `b` stands for the batch size, `s` sequence length, `h` number of attention heads, `d` head dimension, and `t` the total number of tokens in the batch, i.e. `t = sum(s_i) for i in 0,...,b-1`. Here are a few examples of the layouts and their explanations to help clarify the definition.\n", "\n", - "**`qkv_layout`=`sb3hd`:**\n", + "**qkv_layout=sb3hd:**\n", "`q`, `k`, `v` are sequence first, i.e. `s` is the leading dimension in each tensor. They are different slices of one tensor `qkv`: `q, k, v = [qkv[:,:,i,:,:] for i in range(3)]`. They are interleaved at the `h * d` dimension.\n", "\n", - "**`qkv_layout`=`bshd_bsh2d`:**\n", + "**qkv_layout=bshd_bsh2d:**\n", "`q`, `k`, `v` are batch first, i.e. `b` is the leading dimension in each tensor. `q` is contiguous, and `k`, `v` are different slices of tensor `kv`: `k, v = [kv[:,:,:,i,:] for i in range(2)]`. `k`, `v` are interleaved at the `d` dimension.\n", "\n", "The `s` and `h` in `bsh2d` are the max sequence length and number of heads for `k`, `v`, which can be different from the `s` and `h` in `bshd` for `q`. We denoted them as the same for brevity reasons. Transformer Engine does differentiate their values for actual execution.\n", "\n", - "**`qkv_layout`=`thd_thd_thd`:**\n", + "**qkv_layout=thd_thd_thd:**\n", "`q`, `k`, `v` have variable sequence lengths in a batch. They are all contiguous and have no interleaving.\n", "\n", "As of v1.7, Transformer Engine has the following support matrix.\n", "\n", - "| Backend | Supported QKV Formats | Notes |\n", - "| :--------------- | :-------------------- | :------ |\n", - "| flash-attention | `bshd`, `sbhd`, `thd`
(`sbhd` requires transpose operations) | PyTorch: 3 formats, i.e. 15 layouts|\n", - "| cuDNN attention | `bshd`, `sbhd`, `thd` | PyTorch: 3 formats, i.e. 15 layouts
JAX, PaddlePaddle: `bs3hd`, `bshd_bs2hd`, `bshd_bshd_bshd` layouts |\n", - "| Framework-native attention | `bshd`, `sbhd`
(`sbhd` requires transpose operations) | PyTorch, JAX, PaddlePaddle: 2 formats, i.e. 10 layouts |\n", - "\n", - "Some example usage of the different layouts can be found at [test_dpa_qkv_layout](../../tests/pytorch/fused_attention/test_fused_attn.py) and [test_dpa_qkv_layout_thd](../../tests/pytorch/fused_attention/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.get_qkv_layout](../../transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
BackendSupported QKV FormatsNotes
flash-attention`bshd`, `sbhd`, `thd`PyTorch: 3 formats, i.e. 15 layouts
cuDNN attention`bshd`, `sbhd`, `thd`PyTorch: 3 formats, i.e. 15 layouts
\n", + " JAX, PaddlePaddle: `bs3hd`, `bshd_bs2hd`, `bshd_bshd_bshd` layouts\n", + "
Framework-native attention`bshd`, `sbhd`PyTorch, JAX, PaddlePaddle: 2 formats, i.e. 10 layouts
\n", + "\n", + "Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n", "\n", "
\n", - "Note: When RoPE is employed, the qkv_layout may change in Transformer Engine PyTorch through [get_qkv_layout](../../transformer_engine/pytorch/attention.py). This is due to the in-place nature of our RoPE implementations. We convert `q`, `k`, `v` tensors from their initial layout to the corresponding hd_hd_hd layout. For example, from sbh3d in pytorch.MultiHeadAttention before RoPE, to sbhd_sbhd_sbhd in pytorch.DotProductAttention after RoPE.\n", + "Note\n", + " \n", + "When RoPE is employed, the qkv_layout may change in Transformer Engine PyTorch through [get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py). This is due to the in-place nature of our RoPE implementations. We convert `q`, `k`, `v` tensors from their initial layout to the corresponding hd_hd_hd layout. For example, from sbh3d in pytorch.MultiHeadAttention before RoPE, to sbhd_sbhd_sbhd in pytorch.DotProductAttention after RoPE.\n", "
\n" ] }, @@ -365,17 +486,46 @@ "### 3.2 Attention Mask\n", "\n", "Transformer Engine supports 5 mask types, and all the masks are defined as `True` masking out the corresponding element and `False` including the corresponding element in attention calculation.\n", + "\n", "- `no_mask`, `padding`, `causal`, `padding_causal` (equivalent to `causal_padding`), `arbitrary`\n", "\n", "Different backends offer different support for attention mask. As of Transformer Engine 1.7,\n", "\n", - "| Backend | Supported Mask Types | Requires `attention_mask` |\n", - "| :--------------- | :-------------------- | :------------------ |\n", - "| flash-attention | `no_mask`, `causal`, `padding`, `padding_causal` | `no_mask`, `causal`: No
`padding`, `padding_causal`: Yes if `cu_seqlens` not provided|\n", - "| cuDNN attention | `no_mask`, `causal`, `padding`, `padding_causal` | `no_mask`, `causal`: No
`padding`, `padding_causal`: Yes if `cu_seqlens` not provided|\n", - "| Framework-native attention | `no_mask`, `causal`, `arbitrary` | `no_mask`, `causal`: No
`arbitrary`: Yes |\n", - "\n", - "**`padding` and `padding_causal`:** For these two mask types, users need to provide sequence length information to help Transformer Engine figure out where each sequence ends in a batch. As of Transformer Engine 1.7, there are two options to do so in PyTorch and one in JAX and PaddlePaddle.\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
BackendSupported Mask TypesRequires `attention_mask`
flash-attention`no_mask`, `causal`, `padding`, `padding_causal``no_mask`, `causal`: No
`padding`, `padding_causal`: Yes if `cu_seqlens` not provided
cuDNN attention`no_mask`, `causal`, `padding`, `padding_causal``no_mask`, `causal`: No
\n", + " `padding`, `padding_causal`: Yes if `cu_seqlens` not provided\n", + "
Framework-native attention`no_mask`, `causal`, `arbitrary``no_mask`, `causal`: No
`arbitrary`: Yes
\n", + "\n", + "**padding and padding_causal:** For these two mask types, users need to provide sequence length information to help Transformer Engine figure out where each sequence ends in a batch. As of Transformer Engine 1.7, there are two options to do so in PyTorch and one in JAX and PaddlePaddle.\n", "\n", "* PyTorch: When both options are provided by the user, `cu_seqlens` is preferred as there is no extra conversion needed.\n", " - `cu_seqlens`: Users can provide cumulative sequence length tensors `cu_seqlens_q` and `cu_seqlens_kv` for `q` and `k`/`v` to the flash-attention or cuDNN attention backend. An example of `cu_seqlens` is `[0, 2, 6, 7]` for a batch of 3 `[aa000, bbbb0, c0000]`.\n", @@ -384,9 +534,9 @@ "\n", "* JAX and PaddlePaddle: Users should provide the `attention_mask` tensor in shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n", "\n", - "**`qkv_format`=`thd`:** Transformer Engine extracts the max sequence length information from `q`, `k`, `v` if `max_seqlen_q` and `max_seqlen_kv` are not provided. This requires GPU-CPU copy and synchronization operations. For performance reasons, please set `max_seqlen_q` and `max_seqlen_kv` to their appropriate values for `thd` QKV format.\n", + "**qkv_format=thd:** Transformer Engine extracts the max sequence length information from `q`, `k`, `v` if `max_seqlen_q` and `max_seqlen_kv` are not provided. This requires GPU-CPU copy and synchronization operations. For performance reasons, please set `max_seqlen_q` and `max_seqlen_kv` to their appropriate values for `thd` QKV format.\n", "\n", - "**`Arbitrary` mask:** cuDNN does not support `Arbitrary` mask type as of v9.0. However, users can convert the mask to a regular `post_scale_bias` bias and achieve the same functionality. An example script for this conversion is [arbitrary_mask_to_post_scale_bias.py](./arbitrary_mask_to_post_scale_bias.py).\n" + "**Arbitrary mask:** cuDNN does not support `Arbitrary` mask type as of v9.0. However, users can convert the mask to a regular `post_scale_bias` bias and achieve the same functionality. An example script for this conversion is [arbitrary_mask_to_post_scale_bias.py](https://raw.githubusercontent.com/NVIDIA/TransformerEngine/main/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py).\n" ] }, { @@ -416,23 +566,53 @@ "id": "e045c284", "metadata": {}, "source": [ - "Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](../../tests/pytorch/fused_attention/test_fused_attn.py).\n", + "Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py).\n", "\n", "### 3.3 Attention Bias\n", "\n", "Transformer Engine supports 4 attention bias types, `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes). As of Transformer Engine 1.7, their support matrix is as follows.\n", "\n", - "| Backend | Bias Type | Bias Shape | Bias Data Type | Architecture |\n", - "| :------ | :-------- | :--------- | :--------- | :----------- |\n", - "| flash-attention | `no_bias`, `ALiBi` (with slopes) | N/A | ALiBi slopes: FP32 | sm80+ |\n", - "| cuDNN attention | PyTorch: `no_bias`, `post_scale_bias`, `ALiBi` (without slopes)
JAX, PaddlePaddle: `no_bias`, `post_scale_bias` | `post_scale_bias`: BHSS, 1HSS, B1SS, 11SS for forward, 1HSS for backward | `post_scale_bias`: same as QKV type
ALiBi slopes: FP32 | cuDNN 8.9.6+: sm90
cuDNN 9.0+: sm80+ |\n", - "| Framework-native attention | `no_bias`, `pre_scale_bias`, `post_scale_bias` | `post_scale_bias`: BHSS, 1HSS, B1SS, 11SS | `post_scale_bias`: same as QKV type | sm80+ |\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
BackendBias TypeBias ShapeBias Data TypeArchitecture
flash-attention`no_bias`, `ALiBi` (with slopes)N/AALiBi slopes: FP32sm80+
cuDNN attentionPyTorch: `no_bias`, `post_scale_bias`, `ALiBi` (without slopes)`post_scale_bias`: BHSS, 1HSS, B1SS, 11SS for forward, 1HSS for backward`post_scale_bias`: same as QKV typecuDNN 8.9.6+: sm90
JAX, PaddlePaddle: `no_bias`, `post_scale_bias`ALiBi slopes: FP32cuDNN 9.0+: sm80+
Framework-native attention`no_bias`, `pre_scale_bias`, `post_scale_bias``post_scale_bias`: BHSS, 1HSS, B1SS, 11SS `post_scale_bias`: same as QKV typesm80+
\n", "\n", "The flash-attention backend enables `ALiBi` by asking user to pass in an `alibi_slopes` tensor, which can be the default slopes of vanilla ALiBi, or user-defined slopes. On the other hand, cuDNN attention supports `ALiBi` by taking in a `Boolean` flag, and it only supports vanilla ALiBi as of cuDNN 9.0.\n", "\n", "The framework-native backends do not explicitly support `ALiBi`, but users can convert `ALiBi` to a regular `post_scale_bias` bias to achieve the same effect. In PyTorch, this utility function, `transformer_engine.pytorch.attention.get_alibi`, can be used to help with the conversion.\n", "\n", - "More examples of how to use the various attention biases are at [test_dpa_bias](../../tests/pytorch/fused_attention/test_fused_attn.py)." + "More examples of how to use the various attention biases are at [test_dpa_bias](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)." ] }, { @@ -450,7 +630,7 @@ "\n", "- `DelayedScaling.fp8_mha=True (default=False)`: This option, on top of `fp8_dpa=True`, removes the casting operations at the beginning and end of the `FusedAttention` module. This feature is experimental. \n", "\n", - "Examples of using the two features are available at [test_dpa_fp8_vs_f16](../../tests/pytorch/fused_attention/test_fused_attn.py) and [test_mha_fp8_vs_f16](../../tests/pytorch/fused_attention/test_fused_attn.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`. This should result in the following print when the debug flags are turned on, `NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2`.\n", + "Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`. This should result in the following print when the debug flags are turned on, `NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2`.\n", "```\n", "[DEBUG | DotProductAttention]: Running with fp8_recipe.fp8_mha=False, fp8_recipe.fp8_dpa=True and NVTE_FP8_DPA_BWD=0\n", "[DEBUG | FusedAttnFunc ]: Running forward in FP8\n", diff --git a/docs/index.rst b/docs/index.rst index d64cebbfa2..47b8388dd2 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -51,3 +51,4 @@ Transformer Engine documentation :caption: Advanced api/c/index + examples/attention/attention.ipynb diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index e7388c20e0..8b13c47cd4 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -366,8 +366,8 @@ def generate_a_set(target_postfix): class DenseGeneral(TransformerEngineBase): - """ - Applies a linear transformation to the incoming data :math:`y = xA^T + b` + r""" + Applies a linear transformation to the incoming data :math:`y = xA^T + b`. Parameters ---------- diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 05c4ed7c42..d53a4e5202 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -1531,19 +1531,20 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods Indicate the min and max time-scales of rotary position embedding, only used when :attr:`enable_rotary_pos_emb=True` rotary_pos_emb_group_method: str, default = 'consecutive' - Indicate the method to coupled the coordinates. It should be one of - ['consecutive', 'alternate']. 'alternate' is to pair index :math:`i` with :math:`i + d/2` - , d is the hidden dimension. 'consecutive' pairs index :math:`i` with :math:`i + 1`. + Indicate the method to couple the coordinates. It should be one of + ['consecutive', 'alternate']. 'alternate' is to pair index :math:`i` with :math:`i + d/2`, + where :math:`d` is the hidden dimension. 'consecutive' pairs index :math:`i` with + :math:`i + 1`. low_rank_adaptation_scope: str, default = 'none' Indicate the scope to apply low rank adaptation. It should be one of ['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj', - 'exclude_output_proj', 'exclude_mlp'] + 'exclude_output_proj', 'exclude_mlp'] low_rank_adaptation_dim: int, default = 32 The dimension for low rank adaptation, only used when :attr:`enable_low_rank_adaptation=True` low_rank_adaptation_alpha: float, default = None The alpha for computing the scaling factor of LoRA output. - :math:`\frac{alpha}{rank} * lora_output`. None means no scaling. + :math:`\frac{alpha}{rank} * lora\_output`. None means no scaling. enable_sequence_parallel: bool, default = False Whether to enable sequence parallelism to operations except dot. diff --git a/transformer_engine/jax/fp8.py b/transformer_engine/jax/fp8.py index 4766203f69..5df8ce4386 100644 --- a/transformer_engine/jax/fp8.py +++ b/transformer_engine/jax/fp8.py @@ -328,8 +328,8 @@ def fp8_autocast( pjit(transformer.init, ...)(...) .. note:: - We only support :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len` - , and :attr:`amax_compute_algo`(with value 'max' and 'most_recent') in + We only support :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len`, + and :attr:`amax_compute_algo` (with value 'max' and 'most_recent') in recipe.DelayedScaling currently. Other parameters in recipe.DelayedScaling will trigger an assertion. diff --git a/transformer_engine/paddle/layer/transformer.py b/transformer_engine/paddle/layer/transformer.py index c2835a3160..4a9c2c38dc 100644 --- a/transformer_engine/paddle/layer/transformer.py +++ b/transformer_engine/paddle/layer/transformer.py @@ -9,9 +9,11 @@ import paddle from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd -from transformer_engine.paddle.layer import LayerNormMLP, LayerNorm, MultiHeadAttention -from transformer_engine.paddle.constants import AttnMaskTypes, LayerTypes, dist_group_type -from transformer_engine.paddle.distributed import get_tp_group_and_world_size, track_rng_state +from .layernorm_mlp import LayerNormMLP +from .layernorm import LayerNorm +from .attention import MultiHeadAttention +from ..constants import AttnMaskTypes, LayerTypes, dist_group_type +from ..distributed import get_tp_group_and_world_size, track_rng_state class TransformerLayer(paddle.nn.Layer): diff --git a/transformer_engine/pytorch/ops/sequential.py b/transformer_engine/pytorch/ops/sequential.py index cd3c104860..57b4036bba 100644 --- a/transformer_engine/pytorch/ops/sequential.py +++ b/transformer_engine/pytorch/ops/sequential.py @@ -10,7 +10,7 @@ import torch -from transformer_engine.pytorch.ops import FusibleOperation +from transformer_engine.pytorch.ops.op import FusibleOperation from transformer_engine.pytorch.ops.fuser import OperationFuser From 27c6342ea8ad88034bf04b587dd13cb6088d2474 Mon Sep 17 00:00:00 2001 From: Li Tao Date: Sat, 3 Aug 2024 02:42:28 +0800 Subject: [PATCH 32/72] Fix an argument issue when flash_attn>=2.5.7 (#1068) fix an argument issue when flash_attn>=2.5.7 Signed-off-by: Li Tao Co-authored-by: Li Tao Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 8aaa76a177..d899934d76 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -79,6 +79,7 @@ _flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3") _flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4") _flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1") +_flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7") if _flash_attn_version >= _flash_attn_version_required: from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func @@ -1292,6 +1293,8 @@ def forward( fa_optional_forward_kwargs["window_size"] = [-1, 0] if causal else [-1, -1] if _flash_attn_2_4_plus: fa_optional_forward_kwargs["alibi_slopes"] = None + if _flash_attn_2_5_7_plus: + fa_optional_forward_kwargs["block_table"] = None # Flash Attn inputs q_inputs = [None, None] @@ -3448,6 +3451,8 @@ def forward( fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes if _flash_attn_2_4_1_plus: fa_optional_forward_kwargs["deterministic"] = self.deterministic + if _flash_attn_2_5_7_plus: + fa_optional_forward_kwargs["block_table"] = None output = flash_attn_forward_func( query_layer, key_layer, From 87939be1e3fc1b59c422b78500dad6f98957a33b Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 6 Aug 2024 09:40:46 -0700 Subject: [PATCH 33/72] [C/PyTorch] Add support for multi-latent attention (MLA) (#1039) * add multi-latent attention for DPA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Jax/Paddle API Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix typo in test script Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix too-many-boolean lint error Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Revert "fix lint" This reverts commit 67399a3a6f45bb4ce9e5eaa6bcce40b28e347e5b. Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix stride check in get_qkv_layout Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * WIP: fix layout_thd tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * WIP: debug info Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix merge conflict Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix thd pad_between_seqs=False/True tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/fused_attn/test_fused_attn.py | 137 +++++++++----- tests/pytorch/test_onnx_export.py | 2 +- .../common/fused_attn/fused_attn.cpp | 73 ++++---- .../fused_attn_f16_arbitrary_seqlen.cu | 175 +++++++++--------- .../fused_attn_f16_arbitrary_seqlen.h | 6 +- .../common/fused_attn/fused_attn_fp8.cu | 2 + transformer_engine/common/fused_attn/utils.cu | 17 +- transformer_engine/common/fused_attn/utils.h | 12 +- .../include/transformer_engine/fused_attn.h | 7 +- .../jax/csrc/extensions/attention.cpp | 18 +- transformer_engine/paddle/csrc/common.h | 8 +- transformer_engine/pytorch/attention.py | 78 ++++---- .../pytorch/cpp_extensions/fused_attn.py | 12 +- transformer_engine/pytorch/csrc/extensions.h | 13 +- .../pytorch/csrc/extensions/attention.cu | 33 ++-- 15 files changed, 343 insertions(+), 250 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 73dfa23d9a..afc2081752 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -77,12 +77,13 @@ def __init__( batch_size: int, num_heads: int, num_gqa_groups: int, - head_dim: int, + head_dim_qk: int, max_seqlen_q: int, max_seqlen_kv: int, dropout_p: float, attn_mask_type: str, attn_bias_type: str, + head_dim_v: int = None, alibi_type: str = "none", num_layers: int = 1, bias_shape: str = "1hss", @@ -91,9 +92,10 @@ def __init__( self.batch_size = batch_size self.num_heads = num_heads self.num_gqa_groups = num_gqa_groups - self.head_dim = head_dim - self.hidden_size = num_heads * head_dim - self.hidden_size_kv = num_gqa_groups * head_dim + self.head_dim_qk = head_dim_qk + self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v + self.hidden_size = num_heads * head_dim_qk + self.hidden_size_kv = num_gqa_groups * self.head_dim_v self.max_seqlen_q = max_seqlen_q self.max_seqlen_kv = max_seqlen_kv self.dropout_p = dropout_p @@ -137,7 +139,11 @@ def _get_attention_backends( ) core_attention_bias_requires_grad = False # d=256 is supported by cuDNN 9.0+ for inference but not training - if config.attn_bias_type == "post_scale_bias" and config.head_dim <= 128: + if ( + config.attn_bias_type == "post_scale_bias" + and config.head_dim_qk <= 128 + and config.head_dim_v <= 128 + ): core_attention_bias_requires_grad = True fused_attn_backends = [] @@ -153,7 +159,8 @@ def test(): num_gqa_groups=config.num_gqa_groups, max_seqlen_q=config.max_seqlen_q, max_seqlen_kv=config.max_seqlen_kv, - head_dim=config.head_dim, + head_dim_qk=config.head_dim_qk, + head_dim_v=config.head_dim_v, attn_mask_type=config.attn_mask_type, window_size=window_size, alibi_slopes_shape=alibi_slopes_shape, @@ -218,11 +225,12 @@ def test_dot_product_attention( if dtype == torch.bfloat16: tols = dict(atol=2.5e-2, rtol=2.5e-2) config = model_configs[model] + is_mla = config.head_dim_qk != config.head_dim_v if qkv_layout is None: if config.attn_type == "self": - qkv_layout = "sb3hd" + qkv_layout = "sb3hd" if not is_mla else "sbhd_sbhd_sbhd" else: - qkv_layout = "sbhd_sb2hd" + qkv_layout = "bshd_bs2hd" if not is_mla else "bshd_bshd_bshd" if "3" in qkv_layout and config.attn_type == "cross": pytest.skip("No need to test this layout for cross attention") @@ -241,14 +249,17 @@ def test_dot_product_attention( flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends # FlashAttention does not support pad_between_seqs, but _run_dot_product_attention # mannually pads and unpads the input and output of FlashAttention for testing purposes - if pad_between_seqs: + if pad_between_seqs and not ( + config.max_seqlen_q != config.max_seqlen_kv + and config.attn_mask_type in ["causal", "padding_causal"] + ): flash_attn_supported = True # Skip if only unfused backend is supported if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2: pytest.skip("Less than two backends to compare.") - is_training = config.head_dim <= 128 + is_training = config.head_dim_qk <= 128 and config.head_dim_v <= 128 # UnfusedDotProductAttention backend if unfused_attn_supported: unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention( @@ -343,6 +354,38 @@ def test_dpa_checkpoint(dtype, model_configs, model): test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False) +model_configs_mla = { + # test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend + "mla_1_0": ModelConfig( + 8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias", head_dim_v=128 + ), # self , 0 + "mla_1_1": ModelConfig( + 4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128 + ), # cross, 0 + "mla_2_0": ModelConfig( + 2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias", head_dim_v=64 + ), # self , 1 + "mla_2_1": ModelConfig( + 1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=64 + ), # cross, 1 + "mla_3_0": ModelConfig( + 8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=64 + ), # inference + "mla_3_1": ModelConfig( + 8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128 + ), # inference +} + + +@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("model_configs", [model_configs_mla]) +@pytest.mark.parametrize("model", model_configs_mla.keys()) +def test_dpa_mla(dtype, model_configs, model): + """Test DotProductAttention module with Multi-Latent Attention (MLA)""" + test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False) + + model_configs_mask = { # test: b, h, hg, d, sq, skv, p, mask, bias "mask_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"), @@ -586,14 +629,16 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout): @pytest.mark.parametrize("qkv_layout", qkv_layouts_thd) def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout): """Test DotProductAttention module with different QKV layouts""" - pad_between_seqs = False - test_dot_product_attention( - dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs - ) pad_between_seqs = True test_dot_product_attention( dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs ) + if get_cudnn_version() >= (9, 3, 0): + # cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run + pad_between_seqs = False + test_dot_product_attention( + dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs + ) def _run_dot_product_attention( @@ -736,7 +781,8 @@ def _run_dot_product_attention( "skv": config.max_seqlen_kv, "h": config.num_heads, "hg": config.num_gqa_groups, - "d": config.head_dim, + "dqk": config.head_dim_qk, + "dv": config.head_dim_v, "t": cu_seqlens_q_after_pad[-1], "tg": cu_seqlens_kv_after_pad[-1], "3": 3, @@ -753,12 +799,16 @@ def _run_dot_product_attention( layout = layout.replace("s", "skv") layout = layout.replace("h", "hg") layout = layout.replace("t", "tg") + if i == 2: + layout = layout.replace("d", "dv") + else: + layout = layout.replace("d", "dqk") tensor_shape = [dim_to_num[j] for j in layout.split("_")] tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda") tensor_orig = tensor if qkv_format == "thd" and pad_between_seqs: tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) - if layout in ["t_h_d", "t_3_h_d", "t_h_3_d"]: + if layout in ["t_h_dqk", "t_3_h_dqk", "t_h_3_dqk"]: for i in range(1, config.batch_size + 1): valid_range = ( cu_seqlens_q_after_pad[i - 1], @@ -772,7 +822,7 @@ def _run_dot_product_attention( tensor_orig = torch.cat( [tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0 ) - if layout in ["tg_hg_d", "tg_2_hg_d", "tg_hg_2_d"]: + if layout in ["tg_hg_dqk", "tg_2_hg_dqk", "tg_hg_2_dqk", "tg_hg_dv"]: for i in range(1, config.batch_size + 1): valid_range = ( cu_seqlens_kv_after_pad[i - 1], @@ -811,13 +861,14 @@ def _run_dot_product_attention( # Create output gradient qkv_format_kv = "_".join(qkv_format) qkv_format_kv = qkv_format_kv.replace("s", "sq") + qkv_format_kv = qkv_format_kv.replace("d", "dv") out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")] out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]] out_grad = 0.001 * torch.randint(0, 200, out_grad_shape_new, dtype=dtype, device="cuda") out_grad_orig = out_grad if qkv_format == "thd" and pad_between_seqs: out_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) - if qkv_format_kv == "t_h_d": + if qkv_format_kv == "t_h_dv": for i in range(1, config.batch_size + 1): valid_range = ( cu_seqlens_q_after_pad[i - 1], @@ -851,7 +902,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: # Set up model block = DotProductAttention( config.num_heads, - config.head_dim, + config.head_dim_qk, num_gqa_groups=config.num_gqa_groups, attention_dropout=config.dropout_p, qkv_format=qkv_format, @@ -906,9 +957,10 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: if backend == "FusedAttention": if qkv_format == "thd" and pad_between_seqs: out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) - q_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) - k_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) - v_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) + if is_training: + q_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) + k_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) + v_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) for i in range(1, config.batch_size + 1): valid_range_q = ( cu_seqlens_q_after_pad[i - 1], @@ -919,15 +971,16 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: cu_seqlens_kv_after_pad[i] - pad_len[i - 1], ) out_orig = torch.cat([out_orig, out[valid_range_q[0] : valid_range_q[1]]], dim=0) - q_grad_orig = torch.cat( - [q_grad_orig, q.grad[valid_range_q[0] : valid_range_q[1]]], dim=0 - ) - k_grad_orig = torch.cat( - [k_grad_orig, k.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0 - ) - v_grad_orig = torch.cat( - [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0 - ) + if is_training: + q_grad_orig = torch.cat( + [q_grad_orig, q.grad[valid_range_q[0] : valid_range_q[1]]], dim=0 + ) + k_grad_orig = torch.cat( + [k_grad_orig, k.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0 + ) + v_grad_orig = torch.cat( + [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0 + ) if is_training: return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig) else: @@ -1168,7 +1221,7 @@ def _run_transformer_layer( # Create RoPE rotary_pos_emb = None if RoPE: - PE = RotaryPositionEmbedding(dim=config.head_dim) + PE = RotaryPositionEmbedding(dim=config.head_dim_qk) rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda") # Set up model @@ -1183,7 +1236,7 @@ def _run_transformer_layer( init_method=init_method, output_layer_init_method=output_layer_init_method, layer_number=layer_number, - kv_channels=config.head_dim, + kv_channels=config.head_dim_qk, self_attn_mask_type=config.attn_mask_type, tp_group=None, tp_size=1, @@ -1356,7 +1409,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: mha = MultiheadAttention( hidden_size=config.hidden_size, num_attention_heads=config.num_heads, - kv_channels=config.head_dim, + kv_channels=config.head_dim_qk, num_gqa_groups=config.num_gqa_groups, attention_dropout=config.dropout_p, layer_number=1, @@ -1387,7 +1440,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: "skv": config.max_seqlen_kv, "h": config.num_heads, "hg": config.num_gqa_groups, - "d": config.head_dim, + "d": config.head_dim_qk, "t": cu_seqlens_q[-1], "tg": cu_seqlens_kv[-1], "3": 3, @@ -1531,7 +1584,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: with fp8_model_init(enabled=fp8_dpa): dpa = DotProductAttention( config.num_heads, - config.head_dim, + config.head_dim_qk, num_gqa_groups=config.num_gqa_groups, attention_dropout=config.dropout_p, sequence_parallel=False, @@ -1560,7 +1613,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: "skv": config.max_seqlen_kv, "h": config.num_heads, "hg": config.num_gqa_groups, - "d": config.head_dim, + "d": config.head_dim_qk, "t": cu_seqlens_q[-1], "tg": cu_seqlens_kv[-1], "3": 3, @@ -1732,7 +1785,7 @@ def _run_custom_mha_fp8(dtype, config, backend): inp = 0.0001 * torch.randint( -100, 100, - (config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim), + (config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim_qk), dtype=dtype, device="cuda", requires_grad=True, @@ -1743,7 +1796,7 @@ def _run_custom_mha_fp8(dtype, config, backend): out_grad = 0.01 * torch.randn( config.batch_size * config.max_seqlen_q, - config.num_heads * config.head_dim, + config.num_heads * config.head_dim_qk, dtype=dtype, device="cuda", ) @@ -1766,7 +1819,7 @@ def _run_custom_mha_fp8(dtype, config, backend): return ( out.view(config.batch_size, config.max_seqlen_q, -1), dqkv.view( - config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim + config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim_qk ).contiguous(), ) @@ -1809,7 +1862,7 @@ def get_dummy_cuda_rng_tracker(): block = DotProductAttention( config.num_heads, - config.head_dim, + config.head_dim_qk, attention_dropout=config.dropout_p, sequence_parallel=False, tp_size=1, @@ -2105,7 +2158,7 @@ def __init__(self, config, params_dtype: torch.dtype = torch.float32): self.p_dropout = config.dropout_p self.h = config.num_heads self.hidden_size = config.hidden_size - self.head_dim = config.head_dim + self.head_dim = config.head_dim_qk self.fast_zero_fill = True self.mask_type = config.attn_mask_type diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index bdc459cdcc..e8361a2190 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -1083,7 +1083,7 @@ def test_export_core_attention( model = te.attention.DotProductAttention( num_attention_heads=num_attention_heads, - kv_channels=kv_channels, + k_channels=kv_channels, attention_dropout=0.5, qkv_format=qkv_format, attn_mask_type=attn_mask_type, diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 895baea789..0fe62f8cb4 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -72,8 +72,8 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) { NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left, - int64_t window_size_right) { + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, + int64_t window_size_left, int64_t window_size_right) { using namespace transformer_engine; NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; const int device_id = cuda::current_device(); @@ -84,10 +84,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( if (((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2)) && (sm_arch_ >= 90) && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) && (((cudnn_runtime_version >= 8900) && (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) && - (max_seqlen_q == max_seqlen_kv) && (max_seqlen_q <= 512) && (head_dim == 64) && - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)) || + (max_seqlen_q == max_seqlen_kv) && (max_seqlen_q <= 512) && (head_dim_qk == 64) && + (head_dim_v == 64) && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)) || ((cudnn_runtime_version >= 90201) && (max_seqlen_q % 128 == 0) && - (max_seqlen_kv % 128 == 0) && (head_dim == 128) && + (max_seqlen_kv % 128 == 0) && (head_dim_qk == 128) && (head_dim_v == 128) && ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) && ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || @@ -104,8 +104,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool flag_m512 = false; bool flag_arb = false; if ((sm_arch_ == 80 || sm_arch_ == 90) && (max_seqlen_q <= 512 && max_seqlen_q % 64 == 0) && - (max_seqlen_kv <= 512 && max_seqlen_kv % 64 == 0) && (head_dim == 64) && - (num_attn_heads == num_gqa_groups) && + (max_seqlen_kv <= 512 && max_seqlen_kv % 64 == 0) && (head_dim_qk == 64) && + (head_dim_v == 64) && (num_attn_heads == num_gqa_groups) && ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || @@ -131,11 +131,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) || (cudnn_runtime_version >= 8907)) && // head dimension - ((head_dim <= 128 && head_dim % 8 == 0) || + ((head_dim_qk <= 128 && head_dim_qk % 8 == 0 && head_dim_v <= 128 && head_dim_v % 8 == 0) || // TODO (cyang): add is_training to nvte_get_fused_attn_backend // d=256 only supported for forward - (sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim <= 256 && - head_dim % 8 == 0)) && + (sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim_qk <= 256 && + head_dim_qk % 8 == 0 && head_dim_v <= 256 && head_dim_v % 8 == 0)) && // bias type ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || ((cudnn_runtime_version >= 8906) && @@ -155,6 +155,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || ((cudnn_runtime_version >= 90300) && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && + max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) && max_seqlen_q <= max_seqlen_kv && dropout == 0.0)) && @@ -259,7 +260,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, - max_seqlen, d, window_size_left, window_size_right); + max_seqlen, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -336,7 +337,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, - max_seqlen, d, window_size_left, window_size_right); + max_seqlen, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -430,7 +431,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d, window_size_left, window_size_right); + max_seqlen_kv, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -514,7 +515,7 @@ void nvte_fused_attn_bwd_kvpacked( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d, window_size_left, window_size_right); + max_seqlen_kv, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -595,7 +596,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t b = input_cu_seqlens_q->data.shape[0] - 1; size_t h_q = input_Q->data.shape[ndim - 2]; size_t h_kv = input_K->data.shape[ndim - 2]; - size_t d = input_Q->data.shape[ndim - 1]; + size_t d_qk = input_Q->data.shape[ndim - 1]; + size_t d_v = input_V->data.shape[ndim - 1]; auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); @@ -603,13 +605,13 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d, window_size_left, window_size_right); + max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, - input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, + dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, + input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); @@ -617,18 +619,18 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) fused_attn_arbitrary_seqlen_fwd( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, input_K, input_V, - input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, - input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, - handle); + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, + input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, + wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, + fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); @@ -674,7 +676,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t b = input_cu_seqlens_q->data.shape[0] - 1; size_t h_q = input_Q->data.shape[ndim - 2]; size_t h_kv = input_K->data.shape[ndim - 2]; - size_t d = input_Q->data.shape[ndim - 1]; + size_t d_qk = input_Q->data.shape[ndim - 1]; + size_t d_v = input_V->data.shape[ndim - 1]; auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); @@ -682,15 +685,15 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d, window_size_left, window_size_right); + max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, input_Q, input_K, input_V, input_dO, output_S, - output_dQ, output_dK, output_dV, output_dBias, input_cu_seqlens_q, - input_cu_seqlens_kv, wkspace, stream, handle); + fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, + input_dO, output_S, output_dQ, output_dK, output_dV, output_dBias, + input_cu_seqlens_q, input_cu_seqlens_kv, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); #endif @@ -705,9 +708,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); } fused_attn_arbitrary_seqlen_bwd( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, input_K, - input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK, output_dV, + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, + input_K, input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK, output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else @@ -721,7 +724,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const Tensor *input_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); const Tensor *input_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); - fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, + fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O, input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv, diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 7ee7ba33bd..42fb779717 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -48,11 +48,11 @@ namespace transformer_engine { namespace fused_attn { void fused_attn_arbitrary_seqlen_fwd_impl( - int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, int64_t bias_b, - int64_t bias_h, bool is_training, float scaling_factor, float dropout_probability, - NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK, - void *devPtrV, void *devPtrBias, void *devPtrSoftmaxStats, void *devPtrO, + int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, + int64_t bias_b, int64_t bias_h, bool is_training, float scaling_factor, + float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, void *devPtrQ, + void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxStats, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, @@ -86,7 +86,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( hg, s_q, s_kv, - d, + d_qk, + d_v, bias_b, bias_h, scaling_factor, @@ -167,41 +168,41 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, NVTE_QKV_Matrix::NVTE_V_Matrix); if (is_ragged) { Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_qk}) .set_stride(q_stride) .set_ragged_offset(offset_q)); K = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_qk}) .set_stride(k_stride) .set_ragged_offset(offset_k)); V = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_v}) .set_stride(v_stride) .set_ragged_offset(offset_v)); } else { Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_qk}) .set_stride(q_stride)); K = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_qk}) .set_stride(k_stride)); V = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_v}) .set_stride(v_stride)); } @@ -265,15 +266,15 @@ void fused_attn_arbitrary_seqlen_fwd_impl( auto [O, Stats] = mha_graph->sdpa(Q, K, V, sdpa_options); std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); if (is_ragged) { O->set_output(true) - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_v}) .set_stride(o_stride) .set_ragged_offset(offset_o); } else { - O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride); + O->set_output(true).set_dim({b, h, s_q, d_v}).set_stride(o_stride); } Stats->set_output(true) @@ -360,7 +361,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( void *devOffsetsO = static_cast(devOffsetsV) + (b + 1) * sizeof(int32_t); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); cu_seqlens_padded_to_offsets<<>>( - layout_group, b, h, hg, d, static_cast(devPtrSeqOffsetsQ), + layout_group, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), static_cast(devPtrSeqOffsetsKV), static_cast(devOffsetsQ), static_cast(devOffsetsK), static_cast(devOffsetsV), static_cast(devOffsetsO)); @@ -381,13 +382,13 @@ void fused_attn_arbitrary_seqlen_fwd_impl( } void fused_attn_arbitrary_seqlen_bwd_impl( - int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, int64_t bias_b, - int64_t bias_h, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, void *devPtrQ, void *devPtrKTranspose, - void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, - void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias, - void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, + int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, + int64_t bias_b, int64_t bias_h, float scaling_factor, float dropout_probability, + NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ, + void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, + void *devPtrBias, void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, + void *devPtrdBias, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { @@ -419,7 +420,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( hg, s_q, s_kv, - d, + d_qk, + d_v, bias_b, bias_h, scaling_factor, @@ -505,61 +507,61 @@ void fused_attn_arbitrary_seqlen_bwd_impl( std::vector k_stride(4); std::vector v_stride(4); std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); if (is_ragged) { q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_qk}) .set_stride(q_stride) .set_ragged_offset(offset_q)); k = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_qk}) .set_stride(k_stride) .set_ragged_offset(offset_k)); v = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_v}) .set_stride(v_stride) .set_ragged_offset(offset_v)); o = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("O") - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_v}) .set_stride(o_stride) .set_ragged_offset(offset_o)); dO = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("dO") - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_v}) .set_stride(o_stride) .set_ragged_offset(offset_o)); } else { q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_qk}) .set_stride(q_stride)); k = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_qk}) .set_stride(k_stride)); v = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_v}) .set_stride(v_stride)); o = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("O") - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_v}) .set_stride(o_stride)); dO = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("dO") - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_v}) .set_stride(o_stride)); } stats = mha_graph->tensor(fe::graph::Tensor_attributes() @@ -644,21 +646,21 @@ void fused_attn_arbitrary_seqlen_bwd_impl( if (is_ragged) { dQ->set_output(true) - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_qk}) .set_stride(q_stride) .set_ragged_offset(offset_q); dK->set_output(true) - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_qk}) .set_stride(k_stride) .set_ragged_offset(offset_k); dV->set_output(true) - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_v}) .set_stride(v_stride) .set_ragged_offset(offset_v); } else { - dQ->set_output(true).set_dim({b, h, s_q, d}).set_stride(q_stride); - dK->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(k_stride); - dV->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(v_stride); + dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(q_stride); + dK->set_output(true).set_dim({b, hg, s_kv, d_qk}).set_stride(k_stride); + dV->set_output(true).set_dim({b, hg, s_kv, d_v}).set_stride(v_stride); } std::tuple, // q @@ -758,7 +760,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( void *devOffsetsO = static_cast(devOffsetsV) + (b + 1) * sizeof(int32_t); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); cu_seqlens_padded_to_offsets<<>>( - layout_group, b, h, hg, d, static_cast(devPtrSeqOffsetsQ), + layout_group, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), static_cast(devPtrSeqOffsetsKV), static_cast(devOffsetsQ), static_cast(devOffsetsK), static_cast(devOffsetsV), static_cast(devOffsetsO)); @@ -865,11 +867,12 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t workspace_size = 0; fused_attn_arbitrary_seqlen_fwd_impl( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, bias_b, bias_h, - is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, bias_b, + bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, + window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, + devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, + devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, + handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -941,11 +944,11 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( size_t workspace_size = 0; fused_attn_arbitrary_seqlen_bwd_impl( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, bias_b, bias_h, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, - devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, + batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, bias_b, + bias_h, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, + window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, + devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, + devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { @@ -1051,12 +1054,12 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t workspace_size = 0; fused_attn_arbitrary_seqlen_fwd_impl( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h, - is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, - devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, - stream, handle); + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, + bias_b, bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, + window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, + devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, + &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1131,12 +1134,13 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( size_t workspace_size = 0; fused_attn_arbitrary_seqlen_bwd_impl( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, - devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, + bias_b, bias_h, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, + window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, + devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, + devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, + devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, + stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1155,8 +1159,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, @@ -1233,12 +1237,12 @@ void fused_attn_arbitrary_seqlen_fwd( size_t workspace_size = 0; fused_attn_arbitrary_seqlen_fwd_impl( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h, - is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, - devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, - stream, handle); + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, + bias_b, bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, + window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, + devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, + &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1257,7 +1261,7 @@ void fused_attn_arbitrary_seqlen_fwd( void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, @@ -1302,12 +1306,13 @@ void fused_attn_arbitrary_seqlen_bwd( size_t workspace_size = 0; fused_attn_arbitrary_seqlen_bwd_impl( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, - devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, + bias_b, bias_h, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, + window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, + devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, + devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, + devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, + stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 27a2dd37ea..4b523cca1a 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -58,8 +58,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, @@ -68,7 +68,7 @@ void fused_attn_arbitrary_seqlen_fwd( void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index fcce30d6a1..bda3f5beba 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1679,6 +1679,7 @@ void fused_attn_fp8_fwd_impl_v1( s_q, s_kv, d, + d, bias_b, bias_h, scaling_factor, @@ -1976,6 +1977,7 @@ void fused_attn_fp8_bwd_impl_v1( s_q, s_kv, d, + d, bias_b, bias_h, scaling_factor, diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index 7467462d2a..56dbb278b4 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -363,29 +363,30 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu // convert cu_seqlens_padded to offsets __global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, size_t b, size_t h, - size_t hg, size_t d, int32_t *cu_seqlens_q_padded, + size_t hg, size_t d_qk, size_t d_v, + int32_t *cu_seqlens_q_padded, int32_t *cu_seqlens_kv_padded, int32_t *offsets_q, int32_t *offsets_k, int32_t *offsets_v, int32_t *offsets_o) { size_t tid = blockIdx.x * blockDim.x + threadIdx.x; if (tid < b + 1) { - offsets_o[tid] = h * d * cu_seqlens_q_padded[tid]; + offsets_o[tid] = h * d_v * cu_seqlens_q_padded[tid]; switch (layout_group) { case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: - offsets_q[tid] = h * d * cu_seqlens_q_padded[tid]; - offsets_k[tid] = hg * d * cu_seqlens_kv_padded[tid]; - offsets_v[tid] = offsets_k[tid]; + offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[tid]; + offsets_k[tid] = hg * d_qk * cu_seqlens_kv_padded[tid]; + offsets_v[tid] = hg * d_v * cu_seqlens_kv_padded[tid]; break; case NVTE_QKV_Layout_Group::NVTE_3HD: case NVTE_QKV_Layout_Group::NVTE_H3D: - offsets_q[tid] = 3 * h * d * cu_seqlens_q_padded[tid]; + offsets_q[tid] = 3 * h * d_qk * cu_seqlens_q_padded[tid]; offsets_k[tid] = offsets_q[tid]; offsets_v[tid] = offsets_q[tid]; break; case NVTE_QKV_Layout_Group::NVTE_HD_2HD: case NVTE_QKV_Layout_Group::NVTE_HD_H2D: - offsets_q[tid] = h * d * cu_seqlens_q_padded[tid]; - offsets_k[tid] = 2 * hg * d * cu_seqlens_kv_padded[tid]; + offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[tid]; + offsets_k[tid] = 2 * hg * d_qk * cu_seqlens_kv_padded[tid]; offsets_v[tid] = offsets_k[tid]; break; } diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 74d1628a33..d5cf450181 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -91,7 +91,8 @@ struct FADescriptor_v1 { std::int64_t hg; std::int64_t s_q; std::int64_t s_kv; - std::int64_t d; + std::int64_t d_qk; + std::int64_t d_v; std::int64_t bias_b; std::int64_t bias_h; float attnScale; @@ -107,11 +108,11 @@ struct FADescriptor_v1 { cudnn_frontend::DataType_t bwd_tensor_type; bool operator<(const FADescriptor_v1 &rhs) const { - return std::tie(b, h, hg, s_q, s_kv, d, bias_b, bias_h, attnScale, isTraining, + return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, bias_b, bias_h, attnScale, isTraining, dropoutProbability, layout, mask_type, window_size_left, window_size_right, deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) < - std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d, rhs.bias_b, rhs.bias_h, - rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, + std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.bias_b, + rhs.bias_h, rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type, rhs.fwd_tensor_type, rhs.bwd_tensor_type); } @@ -126,7 +127,8 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu int32_t *kv_seqlens); __global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, size_t b, size_t h, - size_t hg, size_t d, int32_t *cu_seqlens_q_padded, + size_t hg, size_t d_qk, size_t d_v, + int32_t *cu_seqlens_q_padded, int32_t *cu_seqlens_kv_padded, int32_t *offsets_q, int32_t *offsets_k, int32_t *offsets_v, int32_t *offsets_o); diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 342c53bc7f..fa358bc86c 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -147,15 +147,16 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout); * \param[in] num_gqa_groups The number of heads in K, V. * \param[in] max_seqlen_q The sequence length of Q. * \param[in] max_seqlen_kv The sequence length of K, V. - * \param[in] head_dim The head dimension of Q, K, V. + * \param[in] head_dim_qk The head dimension of Q, K. + * \param[in] head_dim_v The head dimension of V. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). */ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left, - int64_t window_size_right); + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, + int64_t window_size_left, int64_t window_size_right); /*! \brief Compute dot product attention with packed QKV input. * diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 640869ac36..382b17d207 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -19,7 +19,7 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, auto backend = nvte_get_fused_attn_backend( static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen, - head_dim, -1, -1); + head_dim, head_dim, -1, -1); return backend; } @@ -255,10 +255,10 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s /* Prepare RNG state */ auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); - auto backend = - nvte_get_fused_attn_backend(static_cast(dtype), static_cast(dtype), - qkv_layout, bias_type, mask_type, dropout_probability, attn_heads, - num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim, -1, -1); + auto backend = nvte_get_fused_attn_backend( + static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, + mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, + head_dim, head_dim, -1, -1); PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary tensors (to be propagated to the backward pass later) */ @@ -486,10 +486,10 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, /* Auxiliary tensors (propagated from the forward pass) */ NVTETensorPack aux_input_tensors; nvte_tensor_pack_create(&aux_input_tensors); - auto backend = - nvte_get_fused_attn_backend(static_cast(dtype), static_cast(dtype), - qkv_layout, bias_type, mask_type, dropout_probability, attn_heads, - num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim, -1, -1); + auto backend = nvte_get_fused_attn_backend( + static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, + mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, + head_dim, head_dim, -1, -1); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux, rng_state, bias); diff --git a/transformer_engine/paddle/csrc/common.h b/transformer_engine/paddle/csrc/common.h index 60f06a2188..6ce250432a 100644 --- a/transformer_engine/paddle/csrc/common.h +++ b/transformer_engine/paddle/csrc/common.h @@ -131,10 +131,10 @@ inline NVTE_Fused_Attn_Backend get_fused_attn_backend( NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim) { - NVTE_Fused_Attn_Backend fused_attention_backend = - nvte_get_fused_attn_backend(static_cast(q_dtype), static_cast(kv_dtype), - qkv_layout, bias_type, attn_mask_type, p_dropout, num_attn_heads, - num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, -1, -1); + NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( + static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, + attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, + head_dim, head_dim, -1, -1); return fused_attention_backend; } diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index d899934d76..0790315400 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -142,8 +142,10 @@ class AttentionParams: Maximum sequence length of the query tensor. max_seqlen_kv: int, default = 128 Maximum sequence length of the key and value tensors. - head_dim: int, default = 64 - The size of each attention head. + head_dim_qk: int, default = 64 + The size of each attention head in query and key tensors. + head_dim_v: int, default = 64 + The size of each attention head in the value tensor. attn_mask_type: str, default = `no_mask` Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`, `causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`} @@ -182,7 +184,8 @@ class AttentionParams: num_gqa_groups: int = 16 max_seqlen_q: int = 128 max_seqlen_kv: int = 128 - head_dim: int = 64 + head_dim_qk: int = 64 + head_dim_v: int = 64 attn_mask_type: str = "no_mask" window_size: Union[Tuple[int, int], None] = None alibi_slopes_shape: Union[torch.Size, List, None] = None @@ -245,7 +248,8 @@ def get_attention_backend( num_gqa_groups = attention_params.num_gqa_groups max_seqlen_q = attention_params.max_seqlen_q max_seqlen_kv = attention_params.max_seqlen_kv - head_dim = attention_params.head_dim + head_dim_qk = attention_params.head_dim_qk + head_dim_v = attention_params.head_dim_v attn_mask_type = attention_params.attn_mask_type window_size = attention_params.window_size alibi_slopes_shape = attention_params.alibi_slopes_shape @@ -352,19 +356,31 @@ def get_attention_backend( use_unfused_attention = False # Filter: Head dimension + if use_flash_attention and head_dim_qk != head_dim_v: + logger.debug("Disabling FlashAttention as it does not support MLA.") + use_flash_attention = False if use_flash_attention and ( - head_dim > 256 - or head_dim % 8 != 0 - or (head_dim > 192 and device_compute_capability not in ((8, 0), (9, 0))) + head_dim_qk > 256 + or head_dim_qk % 8 != 0 + or (head_dim_qk > 192 and device_compute_capability not in ((8, 0), (9, 0))) ): logger.debug( - "Disabling FlashAttention due to unsupported head_dim. " - "Supported: head_dim %%8 = 0, head_dim <= 256 (>192 requires sm80/90). " - "Found: head_dim = %s on sm%s.", - head_dim, + "Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. " + "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, " + "head_dim_qk <= 256 (>192 requires sm80/90). " + "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.", + head_dim_qk, + head_dim_v, ".".join([str(i) for i in device_compute_capability]), ) use_flash_attention = False + qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "") + if use_fused_attention and head_dim_qk != head_dim_v and qkv_layout_group != "hd_hd_hd": + logger.debug( + "Disabling FusedAttention as MLA is not supported with qkv_layout = %s", + qkv_layout, + ) + use_fused_attention = False # Filter: QKV layout qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) @@ -557,7 +573,8 @@ def get_attention_backend( num_gqa_groups, max_seqlen_q, max_seqlen_kv, - head_dim, + head_dim_qk, + head_dim_v, window_size[0], window_size[1], ) @@ -3132,12 +3149,14 @@ def run_iteratively(q, k, v): stride = q.stride() check_strides_qkv = all(stride == x.stride() for x in [q, k, v]) stride = k.stride() - check_strides_kv = all(stride == x.stride() for x in [k, v]) + check_strides_kv = torch.equal( + torch.Tensor(stride[:-1]) / k.shape[-1], torch.Tensor(v.stride()[:-1]) / v.shape[-1] + ) shape = q.shape check_shapes_qkv = all(shape == x.shape for x in [q, k, v]) shape = k.shape - check_shapes_kv = all(shape == x.shape for x in [k, v]) + check_shapes_kv = shape[:-1] == v.shape[:-1] last_dim_size = q.shape[-1] check_last_dim_offsets_qkv = all( @@ -5177,8 +5196,10 @@ class DotProductAttention(TransformerEngineBaseModule): ---------- num_attention_heads : int number of attention heads in the transformer layer. - kv_channels : int - number of key-query-value channels per attention head. + k_channels : int + number of channels per attention head in key. + v_channels : Optional[int] = None + number of channels per attention head in value. num_gqa_groups : Optional[int] = None number of GQA groups in the transformer layer. Grouped Query Attention is described in @@ -5264,7 +5285,8 @@ class DotProductAttention(TransformerEngineBaseModule): def __init__( self, num_attention_heads: int, - kv_channels: int, + k_channels: int, + v_channels: Optional[int] = None, num_gqa_groups: Optional[int] = None, attention_dropout: float = 0.0, qkv_format: str = "sbhd", @@ -5304,7 +5326,8 @@ def __init__( self.cp_global_ranks = cp_global_ranks self.cp_stream = cp_stream - self.hidden_size_per_attention_head = kv_channels + self.hidden_size_per_attention_head = k_channels + self.v_channels = k_channels if v_channels is None else v_channels self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size) @@ -5322,7 +5345,7 @@ def __init__( attention_dropout_ctx = self.rng_states_tracker.fork if softmax_scale is None: - softmax_scale = 1.0 / math.sqrt(kv_channels) + softmax_scale = 1.0 / math.sqrt(k_channels) self.deterministic = ( not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) @@ -5469,16 +5492,6 @@ def forward( Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type` includes '"padding"' or `"arbitrary"`. - .. note:: - - Input tensor :attr:`query_layer` must be of shape - (:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads`, - :attr:`kv_channels`) and the tensors :attr:`key_layer` and :attr:`value_layer` - must each be of shape (:attr:`sequence_length`, :attr:`batch_size`, - :attr:`num_gqa_groups`, :attr:`kv_channels`). Output of shape - (:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads` - * :attr:`kv_channels`) is returned. - .. note:: DotProductAttention supports three backends: 1) FlashAttention which calls @@ -5628,7 +5641,9 @@ def forward( assert ( query_layer.dtype == key_layer.dtype and query_layer.dtype == value_layer.dtype ), "Queries, keys and values must have the same data type!" - assert key_layer.shape == value_layer.shape, "Keys and values must have the same shape!" + assert ( + key_layer.shape[:-1] == value_layer.shape[:-1] + ), "Keys and values must have the same batch size, sequence length and number of heads!" if attn_mask_type is None: attn_mask_type = self.attn_mask_type @@ -5861,7 +5876,8 @@ def forward( num_gqa_groups=key_layer.shape[-2], max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, - head_dim=query_layer.shape[-1], + head_dim_qk=query_layer.shape[-1], + head_dim_v=value_layer.shape[-1], attn_mask_type=attn_mask_type, window_size=window_size, alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None, diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 4dc169da00..d0ba644621 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -140,7 +140,7 @@ def fused_attn_fwd_qkvpacked( output tensor, amax of O, used by the next iteration in FP8 computations attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; - if None, use 1.0/sqrt(head_dim) as the default + if None, use 1.0/sqrt(head_dim_qk) as the default dropout: float, default = 0.0 dropout probability, 0.0 means no dropout, 1.0 means no output; dropout must be 0.0 if is_training is False @@ -342,7 +342,7 @@ def fused_attn_bwd_qkvpacked( output tensor, amax of dQKV, used by the next iteration in FP8 computations attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; - if None, use 1.0/sqrt(head_dim) as the default + if None, use 1.0/sqrt(head_dim_qk) as the default dropout: float, default = 0.0 dropout probability, 0.0 means no dropout, 1.0 means no output; dropout must be 0.0 if is_training is False @@ -508,7 +508,7 @@ def fused_attn_fwd_kvpacked( output tensor, amax of O, used by the next iteration in FP8 computations attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; - if None, use 1.0/sqrt(head_dim) as the default + if None, use 1.0/sqrt(head_dim_qk) as the default dropout: float, default = 0.0 dropout probability, 0.0 means no dropout, 1.0 means no output; dropout must be 0.0 if is_training is False @@ -729,7 +729,7 @@ def fused_attn_bwd_kvpacked( output tensor, amax of dQKV, used by the next iteration in FP8 computations attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; - if None, use 1.0/sqrt(head_dim) as the default + if None, use 1.0/sqrt(head_dim_qk) as the default dropout: float, default = 0.0 dropout probability, 0.0 means no dropout, 1.0 means no output; dropout must be 0.0 if is_training is False @@ -907,7 +907,7 @@ def fused_attn_fwd( output tensor, amax of O, used by the next iteration in FP8 computations attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; - if None, use 1.0/sqrt(head_dim) as the default + if None, use 1.0/sqrt(head_dim_qk) as the default dropout: float, default = 0.0 dropout probability, 0.0 means no dropout, 1.0 means no output; dropout must be 0.0 if is_training is False @@ -1135,7 +1135,7 @@ def fused_attn_bwd( output tensor, amax of dQ, dK and dV, used by the next iteration in FP8 computations attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; - if None, use 1.0/sqrt(head_dim) as the default + if None, use 1.0/sqrt(head_dim_qk) as the default dropout: float, default = 0.0 dropout probability, 0.0 means no dropout, 1.0 means no output; dropout must be 0.0 if is_training is False diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index f06b0cb197..bd908e9336 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -14,11 +14,14 @@ * Attention **************************************************************************************************/ -NVTE_Fused_Attn_Backend get_fused_attn_backend( - const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left, int64_t window_size_right); +NVTE_Fused_Attn_Backend get_fused_attn_backend(const transformer_engine::DType q_dtype, + const transformer_engine::DType kv_dtype, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, float p_dropout, + size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, + size_t head_dim_qk, size_t head_dim_v, + int64_t window_size_left, int64_t window_size_right); std::vector fused_attn_fwd_qkvpacked( size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 9cdc79ed64..50eb7b830f 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -14,11 +14,12 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left, int64_t window_size_right) { + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right) { NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, - head_dim, window_size_left, window_size_right); + head_dim_qk, head_dim_v, window_size_left, window_size_right); return fused_attention_backend; } @@ -761,7 +762,11 @@ std::vector fused_attn_fwd( std::vector v_shape{v_sizes.begin(), v_sizes.end()}; // create output tensor O - auto O = torch::empty_like(Q); + auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); + auto o_shape = std::vector{q_sizes.begin(), q_sizes.end()}; + o_shape[o_shape.size() - 1] = v_sizes[v_sizes.size() - 1]; + std::vector o_shape_tmp{o_shape.begin(), o_shape.end()}; + auto O = torch::empty(c10::IntArrayRef(o_shape_tmp), options); // construct NVTE tensors TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias; @@ -790,7 +795,7 @@ std::vector fused_attn_fwd( descale_QKV.value().data_ptr()); te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_S.value().data_ptr(), scale_S.value().data_ptr(), descale_S.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(), + te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { @@ -801,7 +806,7 @@ std::vector fused_attn_fwd( te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, nullptr); te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, nullptr); te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); + te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, nullptr, nullptr, nullptr); } else { NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); } @@ -839,8 +844,7 @@ std::vector fused_attn_fwd( auto gen = at::get_generator_or_default( rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); - auto options = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA); - auto rng_state = torch::empty({2}, options); + auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( philox_args, static_cast(rng_state.data_ptr())); auto te_rng_state = makeTransformerEngineTensor(rng_state); @@ -935,8 +939,11 @@ std::vector fused_attn_bwd( std::vector v_shape{v_sizes.begin(), v_sizes.end()}; auto h_q = q_shape[q_shape.size() - 2]; auto h_kv = k_shape[k_shape.size() - 2]; - auto d = q_shape[q_shape.size() - 1]; + auto d_qk = q_shape[q_shape.size() - 1]; + auto d_v = v_shape[v_shape.size() - 1]; auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); + std::vector o_shape{q_sizes.begin(), q_sizes.end()}; + o_shape[o_shape.size() - 1] = d_v; at::Tensor dQ; at::Tensor dK; @@ -1015,7 +1022,7 @@ std::vector fused_attn_bwd( TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV; if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { // FP8 - if (set_zero && ((h_q * d) % block_size == 0) && ((h_kv * d) % block_size == 0) && + if (set_zero && ((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) && dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous() && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); @@ -1041,9 +1048,9 @@ std::vector fused_attn_bwd( descale_QKV.value().data_ptr()); te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, + te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, nullptr, nullptr, descale_O.value().data_ptr()); - te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, + te_dO = makeTransformerEngineTensor(dO.data_ptr(), o_shape, dqkv_type, nullptr, nullptr, descale_dO.value().data_ptr()); te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, scale_S.value().data_ptr(), descale_S.value().data_ptr()); @@ -1068,9 +1075,9 @@ std::vector fused_attn_bwd( te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, nullptr); te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); + te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, nullptr, nullptr, nullptr); te_dO = - makeTransformerEngineTensor(dO.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, nullptr); + makeTransformerEngineTensor(dO.data_ptr(), o_shape, dqkv_type, nullptr, nullptr, nullptr); te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); te_dQ = From d74e65f5c693b9a769fd418bebfbd86ae19a4648 Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Wed, 7 Aug 2024 00:47:34 +0800 Subject: [PATCH 34/72] [JAX] Reduce lowering time after cuDNN 90300 (#1032) * Support actlen = 0 after cuDNN 9.3.0 Signed-off-by: Reese Wang * Add runtime_segment < max_segment tests Signed-off-by: Reese Wang --------- Signed-off-by: Reese Wang --- tests/jax/test_fused_attn.py | 7 ++- .../jax/cpp_extensions/attention.py | 31 +++++++--- transformer_engine/jax/cpp_extensions/misc.py | 14 +++++ .../jax/csrc/extensions/attention.cpp | 58 +++++++++++++------ .../jax/csrc/extensions/pybind.cpp | 1 + transformer_engine/jax/csrc/utils.cu | 2 + transformer_engine/jax/csrc/utils.h | 1 + 7 files changed, 85 insertions(+), 29 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 796d5bcffa..b003fe4e3d 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -391,7 +391,7 @@ def generate_random_segment_ids( return segment_ids, segment_pad if get_qkv_format(self.qkv_layout) == QKVFormat.THD: - self.num_segments_per_seq = 3 + self.num_segments_per_seq = 2 self.token_q, self.segment_pad_q = generate_random_segment_ids( self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42 ) @@ -461,7 +461,8 @@ def test_forward(self): "dropout_probability": self.dropout_prob, "is_training": self.is_training, "qkv_layout": self.qkv_layout, - "max_segments_per_seq": self.num_segments_per_seq, + # +1 for testing runtime_segments < max_segments + "max_segments_per_seq": self.num_segments_per_seq + 1, } # Convert the outputs to float32 for the elementwise comparison @@ -518,7 +519,7 @@ def grad_func(func, *args, **kwargs): "dropout_probability": self.dropout_prob, "is_training": self.is_training, "qkv_layout": self.qkv_layout, - "max_segments_per_seq": self.num_segments_per_seq, + "max_segments_per_seq": self.num_segments_per_seq + 1, } # We can compute dBias only for the [1, h, s, s] layout diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 4e94de08c4..6fa43b7961 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -30,6 +30,7 @@ jax_dtype_to_te_dtype, te_dtype_to_jax_dtype, get_padded_spec, + get_cudnn_version, ) from ..sharding import ( all_reduce_sum_along_dp_fsdp, @@ -393,12 +394,12 @@ def impl( if nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD: - def _fix_len_take(x, condition): + def _fix_len_take(x, condition, fill_value=-1): x_shape = x.shape x = x.flatten() size = x.size indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0] - y = jnp.take(x, indices, fill_value=-1) + y = jnp.take(x, indices, fill_value=fill_value) return jnp.reshape(y, x_shape) def convert_to_2d(offsets, batch, max_seqlen): @@ -425,9 +426,16 @@ def convert_to_2d(offsets, batch, max_seqlen): kv_batch = reduce(operator.mul, k.shape[:-3]) # Gather valid q_seqlen, which is greater than 0 + # cuDNN version < 9.3.0: # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]] - q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0) - kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0) + # cuDNN version >= 9.3.0, which supports act_seqlen = 0 + # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, 0, 0, 0, 0]] + if get_cudnn_version() >= (9, 3, 0): + fill_value = 0 + else: + fill_value = -1 + q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value) + kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value) # Flatten the offset calculation # max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]] @@ -788,13 +796,13 @@ def impl( if nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD: - def _fix_len_take(x, condition): + def _fix_len_take(x, condition, fill_value=-1): x_shape = x.shape x = x.flatten() size = x.size indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0] # TODO(rewang): try indices_are_sorted - y = jnp.take(x, indices, fill_value=-1) + y = jnp.take(x, indices, fill_value=fill_value) return jnp.reshape(y, x_shape) def convert_to_2d(offsets, batch, max_seqlen): @@ -821,9 +829,16 @@ def convert_to_2d(offsets, batch, max_seqlen): kv_batch = reduce(operator.mul, k.shape[:-3]) # Gather valid q_seqlen, which is greater than 0 + # cuDNN version < 9.3.0: # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]] - q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0) - kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0) + # cuDNN version >= 9.3.0, which supports act_seqlen = 0 + # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, 0, 0, 0, 0]] + if get_cudnn_version() >= (9, 3, 0): + fill_value = 0 + else: + fill_value = -1 + q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value) + kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value) # Flatten the offset calculation # max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]] diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index b27e97d7b5..9ad7354815 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -3,12 +3,16 @@ # See LICENSE for license information. """JAX/TE miscellaneous for custom ops""" +import functools +from typing import Tuple + import numpy as np import jax.numpy as jnp from jax import dtypes from jax.interpreters.mlir import dtype_to_ir_type from transformer_engine.transformer_engine_jax import DType as TEDType +from transformer_engine import transformer_engine_jax from ..sharding import get_padded_spec as te_get_padded_spec @@ -128,3 +132,13 @@ def multidim_transpose(shape, static_axis_boundary, transpose_axis_boundary): *shape[transpose_axis_boundary:], *shape[transpose_start_idx:transpose_axis_boundary], ) + + +@functools.lru_cache(maxsize=None) +def get_cudnn_version() -> Tuple[int, int, int]: + """Runtime cuDNN version (major, minor, patch)""" + encoded_version = transformer_engine_jax.get_cudnn_version() + major_version_magnitude = 1000 if encoded_version < 90000 else 10000 + major, encoded_version = divmod(encoded_version, major_version_magnitude) + minor, patch = divmod(encoded_version, 100) + return (major, minor, patch) diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 382b17d207..866147b336 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -139,7 +139,13 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; // It is a WAR to pre-create all possible cuDNN graph at the JIT compile time size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch; - for (auto num_segments = input_batch; num_segments <= max_num_segments; ++num_segments) { + size_t min_num_segments = input_batch; + auto cudnn_runtime_version = cudnnGetVersion(); + if (is_ragged && cudnn_runtime_version >= 90300) { + // For cuDNN < 9.3.0, it requires to run all possible seqlens to address act_seqlen = 0 + min_num_segments = input_batch * max_segments_per_seq; + } + for (auto num_segments = min_num_segments; num_segments <= max_num_segments; ++num_segments) { // the last one is the largest which will be the returned workspace size auto q_cu_seqlens_tensor = TensorWrapper(nullptr, std::vector{num_segments + 1}, DType::kInt32); @@ -227,14 +233,19 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s size_t num_segments = input_batch; // Non-THD format, input_batch = num_segments if (is_ragged) { - // workspace can be reused here as it is not used with cuDNN graph at the same time - size_t runtime_num_segments_q = - GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream); - size_t runtime_num_segments_kv = - GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream); - NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv); - NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq); - num_segments = runtime_num_segments_q; + auto cudnn_runtime_version = cudnnGetVersion(); + if (cudnn_runtime_version >= 90300) { + num_segments = input_batch * max_segments_per_seq; + } else { + // workspace can be reused here as it is not used with cuDNN graph at the same time + size_t runtime_num_segments_q = + GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream); + size_t runtime_num_segments_kv = + GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream); + NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv); + NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq); + num_segments = runtime_num_segments_q; + } cudaMemsetAsync(output, 0, input_batch * q_max_seqlen * attn_heads * head_dim * typeToSize(dtype), stream); } @@ -366,7 +377,13 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; // It is a WAR to pre-create all possible cuDNN graph at the JIT compile time size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch; - for (auto num_segments = input_batch; num_segments <= max_num_segments; ++num_segments) { + size_t min_num_segments = input_batch; + auto cudnn_runtime_version = cudnnGetVersion(); + if (is_ragged && cudnn_runtime_version >= 90300) { + // For cuDNN < 9.3.0, it requires to run all possible seqlens to address act_seqlen = 0 + min_num_segments = input_batch * max_segments_per_seq; + } + for (auto num_segments = min_num_segments; num_segments <= max_num_segments; ++num_segments) { // the last one is the largest which will be the returned workspace size auto q_cu_seqlens_tensor = TensorWrapper(nullptr, std::vector{num_segments + 1}, DType::kInt32); @@ -460,14 +477,19 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t num_segments = input_batch; // Non-THD format, input_batch = num_segments if (is_ragged) { - // workspace can be reused here as it is not used with cuDNN graph at the same time - size_t runtime_num_segments_q = - GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream); - size_t runtime_num_segments_kv = - GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream); - NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv); - NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq); - num_segments = runtime_num_segments_q; + auto cudnn_runtime_version = cudnnGetVersion(); + if (cudnn_runtime_version >= 90300) { + num_segments = input_batch * max_segments_per_seq; + } else { + // workspace can be reused here as it is not used with cuDNN graph at the same time + size_t runtime_num_segments_q = + GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream); + size_t runtime_num_segments_kv = + GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream); + NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv); + NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq); + num_segments = runtime_num_segments_q; + } } auto q_cu_seqlens_tensor = diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 95fe3101c9..fb293f2fe1 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -59,6 +59,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor); m.def("get_fused_attn_backend", &GetFusedAttnBackend); m.def("get_cuda_version", &GetCudaRuntimeVersion); + m.def("get_cudnn_version", &GetCudnnRuntimeVersion); m.def("get_device_compute_capability", &GetDeviceComputeCapability); m.def("get_cublasLt_version", &cublasLtGetVersion); m.def("get_dact_dbias_ct_workspace_sizes", &GetDActDBiasCastTransposeWorkspaceSizes); diff --git a/transformer_engine/jax/csrc/utils.cu b/transformer_engine/jax/csrc/utils.cu index d9451dca32..8ca34013b3 100644 --- a/transformer_engine/jax/csrc/utils.cu +++ b/transformer_engine/jax/csrc/utils.cu @@ -19,6 +19,8 @@ int GetCudaRuntimeVersion() { return ver; } +size_t GetCudnnRuntimeVersion() { return cudnnGetVersion(); } + int GetDeviceComputeCapability(int gpu_id) { return transformer_engine::cuda::sm_arch(gpu_id); } __global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed, diff --git a/transformer_engine/jax/csrc/utils.h b/transformer_engine/jax/csrc/utils.h index fd3ebe8d8c..32de33bac9 100644 --- a/transformer_engine/jax/csrc/utils.h +++ b/transformer_engine/jax/csrc/utils.h @@ -22,6 +22,7 @@ namespace transformer_engine { namespace jax { int GetCudaRuntimeVersion(); +size_t GetCudnnRuntimeVersion(); int GetDeviceComputeCapability(int gpu_id); void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen, From 5bb3a412c8d8f3f798fd8e539008c21d7ebe1b98 Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Wed, 7 Aug 2024 00:48:26 +0800 Subject: [PATCH 35/72] [JAX] Add the missing 1HSS tests (#1052) Add the missing 1HSS tests Signed-off-by: Reese Wang --- tests/jax/test_fused_attn.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index b003fe4e3d..390a3e2c4e 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -295,7 +295,10 @@ def _check_configs(self): if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend: pytest.skip("Unsupported inputs combination or device compute capability.") - if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS: + if ( + self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS + and self.bias_shape != BiasShape.BIAS_1HSS + ): if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: pytest.skip( "B1SS, BHSS and 11SS bias shapes are only supported for " From 121ff62af8c6938c9cdab15e04e62c97ac524264 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 6 Aug 2024 10:00:20 -0700 Subject: [PATCH 36/72] [PyTorch] Improve logging/messaging in attention (#1074) * fix logging in attention Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove logging in fwd/bwd methods due to CPU overhead Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * WIP: fix check_set_window_size messaging Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix typo Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix window_size messaging Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove redundant imports Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 60 +++++++------------ .../pytorch/module/grouped_linear.py | 23 ------- .../pytorch/module/layernorm_linear.py | 22 ------- transformer_engine/pytorch/module/linear.py | 23 ------- 4 files changed, 22 insertions(+), 106 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 0790315400..7586cc1bcb 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -98,12 +98,12 @@ _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) # NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0 _NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0")) -log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL -log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} -logging.basicConfig( - format="[%(levelname)-8s | %(name)-19s]: %(message)s", - level=log_levels[log_level if log_level in [0, 1, 2] else 2], -) +_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL +_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} +_log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2] +_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s") +_stream_handler = logging.StreamHandler() +_stream_handler.setFormatter(_formatter) _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1")) _NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1")) @@ -266,6 +266,9 @@ def get_attention_backend( # Run config logger = logging.getLogger("DotProductAttention") + logger.setLevel(_log_level) + if not logger.hasHandlers(): + logger.addHandler(_stream_handler) device_compute_capability = get_device_compute_capability() cudnn_version = get_cudnn_version() run_config = { @@ -3236,31 +3239,28 @@ def check_set_window_size( """ orig_window_size = window_size if "causal" in attn_mask_type: - if orig_window_size is None or ( - orig_window_size[0] == -1 and orig_window_size[1] in [-1, 0] - ): + if orig_window_size is None: window_size = (-1, 0) - warnings.warn( - "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type - ) - elif orig_window_size[0] >= 0: + elif orig_window_size == (-1, -1) or ( + orig_window_size[0] >= 0 and orig_window_size[1] != 0 + ): window_size = (orig_window_size[0], 0) warnings.warn( "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type ) - else: + elif orig_window_size != (-1, 0) and (orig_window_size[0] < 0 or orig_window_size[1] != 0): assert False, ( "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type ) elif attn_mask_type in ["no_mask", "padding", "arbitrary"]: - if orig_window_size is None or ( - orig_window_size[0] == -1 and orig_window_size[1] in [-1, 0] - ): + if orig_window_size is None: + window_size = (-1, -1) + elif orig_window_size == (-1, 0): window_size = (-1, -1) warnings.warn( "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type ) - elif orig_window_size[0] < 0 or orig_window_size[1] < 0: + elif orig_window_size != (-1, -1) and (orig_window_size[0] < 0 or orig_window_size[1] < 0): assert False, ( "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type ) @@ -3560,9 +3560,7 @@ def forward( fp8_meta, deterministic, ): - logger = logging.getLogger("FusedAttnFunc_qkvpacked") if fp8: - logger.debug("Running forward in FP8") if fp8_meta["recipe"].fp8_mha: assert isinstance(qkv, Float8Tensor), "qkv must be Float8Tensors for FP8 MHA." fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv @@ -3646,7 +3644,6 @@ def forward( fp8_meta["scaling_fwd"].scale_inv.clone(), ) else: - logger.debug("Running forward in %s", qkv.dtype) out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked( is_training, max_seqlen, @@ -3699,7 +3696,6 @@ def forward( @staticmethod def backward(ctx, d_out): - logger = logging.getLogger("FusedAttnFunc_qkvpacked") if ctx.fp8_meta["recipe"].fp8_mha: assert isinstance( d_out, Float8Tensor @@ -3753,7 +3749,6 @@ def backward(ctx, d_out): else: with torch.cuda.nvtx.range("_FusedAttn_qkvpacked"): if ctx.fp8: - logger.debug("Running backward in FP8") fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_backward = get_fp8_te_dtype( ctx.fp8_meta["recipe"], fprop_tensor=False @@ -3819,7 +3814,6 @@ def backward(ctx, d_out): ctx.qkv_dtype, ).view(dqkv_fp8.shape) else: - logger.debug("Running backward in %s", qkv.dtype) if d_out.dtype == torch.uint8: d_out = d_out_f8tensor.from_float8(qkv.dtype) dqkv, *rest = fused_attn_bwd_qkvpacked( @@ -3937,9 +3931,7 @@ def forward( fp8_meta, deterministic, ): - logger = logging.getLogger("FusedAttnFunc_kvpacked") if fp8: - logger.debug("Running forward in FP8") if fp8_meta["recipe"].fp8_mha: assert isinstance(q, Float8Tensor) and isinstance( kv, Float8Tensor @@ -4036,7 +4028,6 @@ def forward( fp8_meta["scaling_fwd"].scale_inv.clone(), ) else: - logger.debug("Running forward in %s", q.dtype) out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked( is_training, max_seqlen_q, @@ -4100,7 +4091,6 @@ def forward( @staticmethod def backward(ctx, d_out): - logger = logging.getLogger("FusedAttnFunc_kvpacked") if ctx.fp8_meta["recipe"].fp8_mha: assert isinstance( d_out, Float8Tensor @@ -4158,7 +4148,6 @@ def backward(ctx, d_out): else: with torch.cuda.nvtx.range("_FusedAttn_kvpacked"): if ctx.fp8: - logger.debug("Running backward in FP8") fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_backward = get_fp8_te_dtype( ctx.fp8_meta["recipe"], fprop_tensor=False @@ -4243,7 +4232,6 @@ def backward(ctx, d_out): ctx.qkv_dtype, ).view(dkv_fp8.shape) else: - logger.debug("Running backward in %s", q.dtype) if d_out.dtype == torch.uint8: d_out = d_out_f8tensor.from_float8(q.dtype) dq, dkv, *rest = fused_attn_bwd_kvpacked( @@ -4374,9 +4362,7 @@ def forward( fp8_meta, deterministic, ): - logger = logging.getLogger("FusedAttnFunc") if fp8: - logger.debug("Running forward in FP8") fused_attention_backend = FusedAttnBackend["FP8"] fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) if fp8_meta["recipe"].fp8_mha: @@ -4544,7 +4530,6 @@ def forward( fp8_meta["scaling_fwd"].scale_inv.clone(), ) else: - logger.debug("Running forward in %s", q.dtype) out_ret, aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, @@ -4618,7 +4603,6 @@ def forward( @staticmethod def backward(ctx, d_out): - logger = logging.getLogger("FusedAttnFunc") if ctx.fp8_meta["recipe"].fp8_mha: assert isinstance( d_out, Float8Tensor @@ -4680,7 +4664,6 @@ def backward(ctx, d_out): else: with torch.cuda.nvtx.range("_FusedAttn"): if ctx.fp8: - logger.debug("Running backward in FP8") fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_backward = get_fp8_te_dtype( ctx.fp8_meta["recipe"], fprop_tensor=False @@ -4818,7 +4801,6 @@ def backward(ctx, d_out): ctx.qkv_dtype, ).view(dv_fp8.shape) else: - logger.debug("Running backward in %s", q.dtype) if d_out.dtype == torch.uint8: d_out = d_out_f8tensor.from_float8(q.dtype) dq, dk, dv, *rest = fused_attn_bwd( @@ -4959,7 +4941,6 @@ def __init__( ) -> None: super().__init__() - self.logger = logging.getLogger("FusedAttention") self.softmax_scale = softmax_scale self.attention_dropout = attention_dropout self.attention_dropout_ctx = attention_dropout_ctx @@ -5306,6 +5287,9 @@ def __init__( super().__init__() self.logger = logging.getLogger("DotProductAttention") + self.logger.setLevel(_log_level) + if not self.logger.hasHandlers(): + self.logger.addHandler(_stream_handler) self.qkv_format = qkv_format attn_mask_type = attn_mask_type.replace(",", "_") if attn_mask_type == "causal_padding": @@ -5619,7 +5603,7 @@ def forward( if self.fp8_meta["recipe"].fp8_mha: if not self.fp8_meta["recipe"].fp8_dpa: self.fp8_meta["recipe"].fp8_dpa = True - self.logger.WARNING( + self.logger.warning( """Forcing fp8_meta["recipe"].fp8_dpa=True due to """ """fp8_meta["recipe"].fp8_mha=True""" ) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 8aeb068412..c55225eed9 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -3,8 +3,6 @@ # See LICENSE for license information. """GroupedLinear API""" -import os -import logging from typing import Union, Optional, Callable, Tuple, List, Dict, Any import torch @@ -45,17 +43,6 @@ from ..graph import is_graph_capturing from ..float8_tensor import Float8Tensor -# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 -_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) -# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0 -_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0")) -log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL -log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} -logging.basicConfig( - format="[%(levelname)-8s | %(name)-19s]: %(message)s", - level=log_levels[log_level if log_level in [0, 1, 2] else 2], -) - __all__ = ["GroupedLinear"] """ @@ -97,7 +84,6 @@ def forward( is_grad_enabled: bool, *weights_and_biases: Union[Float8Tensor, torch.Tensor, None], ) -> torch.Tensor: - logger = logging.getLogger("GroupedLinear") num_gemms = len(m_splits) weights = weights_and_biases[:num_gemms] weights_fp8 = weights_and_biases[num_gemms : 2 * num_gemms] @@ -151,8 +137,6 @@ def forward( inputmats = inputmats_no_fp8 if fp8: - logger.debug("Running forward in FP8") - bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases @@ -184,8 +168,6 @@ def forward( use_split_accumulator=_2X_ACC_FPROP, ) else: - logger.debug("Running forward in %s", activation_dtype) - # Cast for native AMP weights = [cast_if_needed(w, activation_dtype) for w in weights] biases = ( @@ -286,8 +268,6 @@ def forward( @staticmethod def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: - logger = logging.getLogger("GroupedLinear") - with torch.cuda.nvtx.range("_GroupedLinear_backward"): ( fwd_scale_inverses, @@ -353,7 +333,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.requires_dgrad: if ctx.fp8: - logger.debug("Running backward in FP8") dgrad = torch.empty( (sum(ctx.m_splits), weights_fp8[i].size(1)), dtype=ctx.activation_dtype, @@ -376,8 +355,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], use_split_accumulator=_2X_ACC_DGRAD, ) else: - logger.debug("Running backward in %s", ctx.activation_dtype) - dgrad = torch.empty( (sum(ctx.m_splits), weights[0].size(1)), dtype=ctx.activation_dtype, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 262c6f8d16..10560cdad6 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -5,7 +5,6 @@ """LayerNormLinear API""" import os import warnings -import logging from typing import Any, Callable, Dict, Optional, Tuple, Union import torch @@ -48,17 +47,6 @@ from ._common import _apply_normalization, _noop_cat from ..float8_tensor import Float8Tensor -# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 -_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) -# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0 -_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0")) -log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL -log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} -logging.basicConfig( - format="[%(levelname)-8s | %(name)-19s]: %(message)s", - level=log_levels[log_level if log_level in [0, 1, 2] else 2], -) - __all__ = ["LayerNormLinear"] @@ -104,7 +92,6 @@ def forward( ub_name: str, fsdp_group: Union[dist_group_type, None], ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: - logger = logging.getLogger("LayerNormLinear") # Make sure input dimensions are compatible in_features = ln_weight.numel() assert inp.shape[-1] == in_features, "GEMM not possible" @@ -203,8 +190,6 @@ def forward( ln_out = ln_out_total if fp8: - logger.debug("Running forward in FP8") - bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype bias = cast_if_needed(bias, bias_dtype) if use_bias else bias @@ -259,8 +244,6 @@ def forward( dtype=activation_dtype, ) else: - logger.debug("Running forward in %s", activation_dtype) - # Cast for native AMP weight = cast_if_needed(weight, activation_dtype) bias = cast_if_needed(bias, activation_dtype) if use_bias else bias @@ -379,7 +362,6 @@ def forward( def backward( ctx, *grad_outputs: Tuple[torch.Tensor, ...] ) -> Tuple[Union[torch.Tensor, None], ...]: - logger = logging.getLogger("LayerNormLinear") if isinstance(grad_outputs[0], Float8Tensor): ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_outputs[ 0 @@ -500,8 +482,6 @@ def backward( ub_obj = None if ctx.fp8: - logger.debug("Running backward in FP8") - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) out_index, meta_tensor, out_te_type, out_type = ( @@ -544,8 +524,6 @@ def backward( ) clear_tensor_data(grad_output_c) else: - logger.debug("Running backward in %s", ctx.activation_dtype) - # DGRAD: Evaluated unconditionally to feed into Linear backward _, _, _ = tex.gemm( weight, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 7510254a9d..68d333262d 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -3,8 +3,6 @@ # See LICENSE for license information. """Linear API""" -import os -import logging from typing import Any, Callable, Dict, Optional, Tuple, Union import torch @@ -51,17 +49,6 @@ from ..graph import is_graph_capturing from ..float8_tensor import Float8Tensor -# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 -_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) -# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0 -_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0")) -log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL -log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} -logging.basicConfig( - format="[%(levelname)-8s | %(name)-19s]: %(message)s", - level=log_levels[log_level if log_level in [0, 1, 2] else 2], -) - __all__ = ["Linear"] @@ -97,7 +84,6 @@ def forward( is_first_module_in_mha: bool, fsdp_group: Union[dist_group_type, None], ) -> torch.Tensor: - logger = logging.getLogger("Linear") is_input_fp8 = isinstance(inp, Float8Tensor) if is_input_fp8: fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_INPUT] = inp._scale_inv[0] @@ -158,8 +144,6 @@ def forward( else: inputmat_total = inputmat if fp8: - logger.debug("Running forward in FP8") - bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype bias = cast_if_needed(bias, bias_dtype) if use_bias else bias @@ -248,8 +232,6 @@ def forward( dtype=activation_dtype, ) else: - logger.debug("Running forward in %s", activation_dtype) - # Cast for native AMP weight = cast_if_needed(weight, activation_dtype) bias = cast_if_needed(bias, activation_dtype) if use_bias else bias @@ -373,7 +355,6 @@ def forward( @staticmethod def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: - logger = logging.getLogger("Linear") if isinstance(grad_output, Float8Tensor): ctx.fp8_meta["scaling_bwd"].scale_inv[ tex.FP8BwdTensors.GRAD_OUTPUT1 @@ -450,8 +431,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.requires_dgrad: if ctx.fp8: - logger.debug("Running backward in FP8") - if ctx.is_input_fp8: out_index, meta_tensor, output_te_dtype, output_dtype = ( tex.FP8BwdTensors.GRAD_INPUT1, @@ -494,8 +473,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dtype=ctx.activation_dtype, ) else: - logger.debug("Running backward in %s", ctx.activation_dtype) - dgrad, _, _ = gemm( weight, grad_output, From 8833a8d0f114de2ac8023907c75f22ff53bb300b Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 6 Aug 2024 10:14:55 -0700 Subject: [PATCH 37/72] [PyTorch] Reduce the amount of roundup for max_seqlen in THD (#1079) reduce the roundup of max_seqlen for THD to multiples of 64 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 7586cc1bcb..c8ca157c28 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -5725,13 +5725,13 @@ def forward( seqlens_q = cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1] else: seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] - max_seqlen_q = pow(2, math.ceil(math.log2(seqlens_q.max().item()))) + max_seqlen_q = int((seqlens_q.max().item() + 63) // 64 * 64) if max_seqlen_kv is None: if cu_seqlens_kv_padded is not None: seqlens_kv = cu_seqlens_kv_padded[1:] - cu_seqlens_kv_padded[:-1] else: seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] - max_seqlen_kv = pow(2, math.ceil(math.log2(seqlens_kv.max().item()))) + max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64) batch_size = len(cu_seqlens_q) - 1 cp_size = 1 if self.cp_group is None else get_distributed_world_size(self.cp_group) From 6717554f11f9b8bd79f917560e525d538c95b3bc Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Tue, 6 Aug 2024 11:03:30 -0700 Subject: [PATCH 38/72] Add user to TE CI (#1081) Signed-off-by: Tim Moon --- .github/workflows/trigger-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index 5091e5d4f6..cd47fa9a54 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -32,7 +32,7 @@ jobs: || github.actor == 'sudhakarsingh27' || github.actor == 'Oleg-Goncharov' || github.actor == 'phu0ngng' - || github.actor == 'nvcforster' + || github.actor == 'xrennvidia' ) steps: - name: Check if comment is issued by authorized person From 86f27e12a58b9a5e2b3be9531e695a3ac93a4f69 Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Thu, 8 Aug 2024 21:28:26 +0800 Subject: [PATCH 39/72] [JAX] Support non-deterministic algo for cuDNN FA (#1056) * Support non-deterministic algo Signed-off-by: Reese Wang * Refine the helper function name Signed-off-by: Reese Wang * Move fixture to conftest.py Signed-off-by: Reese Wang --------- Signed-off-by: Reese Wang Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> --- tests/jax/conftest.py | 19 ++++++++++++ tests/jax/test_praxis_layers.py | 14 --------- .../jax/cpp_extensions/attention.py | 14 ++++++++- transformer_engine/jax/csrc/extensions.h | 6 ++-- .../jax/csrc/extensions/attention.cpp | 30 ++++++++++--------- .../jax/csrc/extensions/packing.cpp | 5 ++-- transformer_engine/jax/flax/transformer.py | 8 +++++ 7 files changed, 63 insertions(+), 33 deletions(-) diff --git a/tests/jax/conftest.py b/tests/jax/conftest.py index 55494c42d6..ccb6690a87 100644 --- a/tests/jax/conftest.py +++ b/tests/jax/conftest.py @@ -2,9 +2,12 @@ # # See LICENSE for license information. """conftest for tests/jax""" +import os import jax import pytest +from transformer_engine.transformer_engine_jax import get_device_compute_capability + @pytest.fixture(autouse=True, scope="function") def clear_live_arrays(): @@ -14,3 +17,19 @@ def clear_live_arrays(): yield for arr in jax.live_arrays(): arr.delete() + + +@pytest.fixture(autouse=True, scope="module") +def enable_fused_attn(): + """ + Enable fused attn for hopper+ arch. + Fused attn kernels on pre-hopper arch are not deterministic. + """ + if get_device_compute_capability(0) >= 90: + os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" + yield + if "NVTE_FUSED_ATTN" in os.environ: + del os.environ["NVTE_FUSED_ATTN"] + if "NVTE_ALLOW_NONDETERMINISTIC_ALGO" in os.environ: + del os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] diff --git a/tests/jax/test_praxis_layers.py b/tests/jax/test_praxis_layers.py index 92a6c80028..ccab73088a 100644 --- a/tests/jax/test_praxis_layers.py +++ b/tests/jax/test_praxis_layers.py @@ -15,7 +15,6 @@ from utils import assert_allclose -from transformer_engine.transformer_engine_jax import get_device_compute_capability from transformer_engine.common.recipe import DelayedScaling, Format from transformer_engine.jax import fp8_autocast, update_collections from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral @@ -43,19 +42,6 @@ FP8_FORMATS = [Format.E4M3, Format.HYBRID] -@pytest.fixture(autouse=True, scope="module") -def enable_fused_attn(): - """ - Enable fused attn for hopper+ arch. - Fused attn kernels on pre-hopper arch are not deterministic. - """ - if get_device_compute_capability(0) >= 90: - os.environ["NVTE_FUSED_ATTN"] = "1" - yield - if "NVTE_FUSED_ATTN" in os.environ: - del os.environ["NVTE_FUSED_ATTN"] - - def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08): for key in ref_fd: assert key in test_fd, f"{key} not found in test dict {test_fd}" diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 6fa43b7961..76ccec363b 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -3,8 +3,9 @@ # See LICENSE for license information. """JAX/TE custom ops for attention""" from dataclasses import dataclass -from functools import partial, reduce +from functools import partial, reduce, cache import operator +import os from typing import Optional, Tuple import warnings @@ -84,6 +85,12 @@ def get_fused_attn_backend(self): self.head_dim, ) + @staticmethod + @cache + def is_non_deterministic_allowed(): + """Check if non-deterministic kernels are allowed""" + return bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) + @staticmethod def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout): """Parse qkv aval""" @@ -365,6 +372,7 @@ def lowering( jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), is_training, + not FusedAttnHelper.is_non_deterministic_allowed(), ) out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False) @@ -642,6 +650,8 @@ def abstract( *bias_batch_shape, bias_heads, _, _ = bias_aval.shape bias_batch = reduce(operator.mul, bias_batch_shape) + deterministic = not FusedAttnHelper.is_non_deterministic_allowed() + input_batch = reduce(operator.mul, batch_shape) wkspace_shape, wkspace_dtype = transformer_engine_jax.get_fused_attn_bwd_workspace_sizes( input_batch, @@ -659,6 +669,7 @@ def abstract( qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), is_training, + deterministic, max_segments_per_seq, ) @@ -764,6 +775,7 @@ def lowering( jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), is_training, + not FusedAttnHelper.is_non_deterministic_allowed(), ) out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index c541fb8afa..c084ab09e9 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -147,6 +147,7 @@ struct CustomCallFusedAttnDescriptor { DType dtype; DType wkspace_dtype; bool is_training; + bool deterministic; }; pybind11::bytes PackCustomCallFusedAttnDescriptor( @@ -154,7 +155,8 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor( size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training); + NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, + bool deterministic); // Transpose @@ -260,7 +262,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, - size_t max_segments_per_seq); + bool deterministic, size_t max_segments_per_seq); void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 866147b336..1d367f5cc1 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -336,7 +336,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, - size_t max_segments_per_seq) { + bool deterministic, size_t max_segments_per_seq) { // For qkv_packed auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); @@ -392,13 +392,14 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( auto dummy_ragged_offset_tensor = TensorWrapper(nullptr, std::vector{num_segments + 1}, DType::kInt32); if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - nvte_fused_attn_bwd_qkvpacked( - qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), - dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, - qkv_layout, bias_type, mask_type, -1, -1, true, query_workspace_tensor.data(), nullptr); + nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), + q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), + q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, + bias_type, mask_type, -1, -1, deterministic, + query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { nvte_fused_attn_bwd_kvpacked( q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), @@ -408,7 +409,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1, - -1, true, query_workspace_tensor.data(), nullptr); + -1, deterministic, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), doutput_tensor.data(), @@ -419,7 +420,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1, - -1, true, query_workspace_tensor.data(), nullptr); + -1, deterministic, query_workspace_tensor.data(), nullptr); } else { NVTE_ERROR("Unsupported qkv_layout."); } @@ -467,6 +468,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, auto bias_type = descriptor.bias_type; auto mask_type = descriptor.mask_type; auto dtype = descriptor.dtype; + auto deterministic = descriptor.deterministic; auto max_segments_per_seq = descriptor.max_segments_per_seq; /* Input tensors */ @@ -539,7 +541,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, s_tensor.data(), // not used for F16 &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, -1, -1, true, workspace_tensor.data(), stream); + bias_type, mask_type, -1, -1, deterministic, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { auto q = buffers[0]; auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; @@ -566,7 +568,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, true, + dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, deterministic, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { auto q = buffers[0]; @@ -602,8 +604,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, dbias_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, true, - workspace_tensor.data(), stream); + dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, + deterministic, workspace_tensor.data(), stream); } else { NVTE_ERROR("Unsupported qkv_layout."); } diff --git a/transformer_engine/jax/csrc/extensions/packing.cpp b/transformer_engine/jax/csrc/extensions/packing.cpp index 8c948d0a8f..128564db64 100644 --- a/transformer_engine/jax/csrc/extensions/packing.cpp +++ b/transformer_engine/jax/csrc/extensions/packing.cpp @@ -68,11 +68,12 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor( size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training) { + NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, + bool deterministic) { return PackOpaque(CustomCallFusedAttnDescriptor{ input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads, head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type, - mask_type, qkv_layout, dtype, wkspace_dtype, is_training}); + mask_type, qkv_layout, dtype, wkspace_dtype, is_training, deterministic}); } } // namespace jax diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index d53a4e5202..c62c2bb77d 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -359,6 +359,14 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods kernel is not available on the system, a warning will be issued, and the module will automatically fall back to the unfused backend. + .. note:: + The DotProductAttention default setting enables non-deterministic kernels for reduced + workspace requirements and faster computation. Users can disable the non-deterministic + kernels via the :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO` environment variable: + + * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0` to allow only deterministic kernels. + * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=1` to allow non-deterministic kernels (default). + Parameters ---------- head_dim: int From fa4b866d5f7159338d3367c15cc34a2d7ff96d2c Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Fri, 9 Aug 2024 10:26:16 -0500 Subject: [PATCH 40/72] [C/PyTorch] Fixed incorrect use of `torch.distributed.new_group()` when creating intra-node group in `initialize_ub()` (#1087) * updated initialize_ub() to use new_subgroups_by_enumeration() to generate intra-node groups, added new unit tests for TE layers with comm overlap Signed-off-by: Alp Dener * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Alp Dener Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- ...th_overlap.py => te_layer_with_overlap.py} | 267 ++++++++---- qa/L1_pytorch_distributed_unittest/test.sh | 1 + .../distributed/run_gemm_with_overlap.py | 392 ++++++++++++------ .../distributed/run_layer_with_overlap.py | 352 ++++++++++++++++ .../distributed/test_comm_gemm_overlap.py | 255 +++++++++--- .../pytorch/csrc/comm_gemm_overlap.h | 76 ++-- .../pytorch/csrc/userbuffers/userbuffers.cu | 8 + transformer_engine/pytorch/module/base.py | 56 +-- 8 files changed, 1064 insertions(+), 343 deletions(-) rename examples/pytorch/comm_gemm_overlap/{ln_mlp_with_overlap.py => te_layer_with_overlap.py} (50%) create mode 100644 tests/pytorch/distributed/run_layer_with_overlap.py diff --git a/examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py similarity index 50% rename from examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py rename to examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py index 412c948a83..ab6b656be9 100644 --- a/examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py +++ b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py @@ -7,6 +7,8 @@ import os import sys import socket +import fcntl +import struct import argparse import warnings @@ -15,15 +17,37 @@ from torch.nn.parallel import DistributedDataParallel import transformer_engine.pytorch as te +import transformer_engine.pytorch.cpp_extensions as tex from transformer_engine.common.recipe import Format, DelayedScaling +warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=UserWarning) +os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" +if not tex.device_supports_multicast(): + os.environ["UB_SKIPMC"] = "1" + + +def _te_layer_argtype(name): + te_layers = [ + te.Linear, + te.LayerNormLinear, + te.LayerNormMLP, + te.MultiheadAttention, + te.TransformerLayer, + ] + layer_map = dict(zip([layer.__name__.lower() for layer in te_layers], te_layers)) + if name.lower() not in layer_map.keys(): + raise argparse.ArgumentTypeError( + f"Invalid TE layer name! Please choose from: {layer_map.keys()}" + ) + return layer_map[name.lower()] + def _parse_args(argv=None, namespace=None): parser = argparse.ArgumentParser( - description="Test a te.LayerNormMLP module with GEMM+comm overlap via Userbuffers." + description="Train a Transformer Engine module with GEMM+comm overlap via Userbuffers." ) parser.add_argument( "-i", "--num-iters", type=int, default=5, help="Number of dummy 'training' iterations." @@ -37,10 +61,10 @@ def _parse_args(argv=None, namespace=None): "-d", "--head-dim", type=int, default=128, help="Dimension of each attention head." ) parser.add_argument( - "--mlp-expansion-factor", - type=int, - default=4, - help="MLP block intermediate size as a factor of hidden dimension.", + "--layer-type", + type=_te_layer_argtype, + default=te.TransformerLayer, + help="Transformer Engine layer to train with comm+GEMM overlap.", ) parser.add_argument("--seed", type=int, default=1234, help="RNG seed.") parser.add_argument( @@ -88,9 +112,57 @@ def _parse_args(argv=None, namespace=None): help="Print out additional debug information.", ) args = parser.parse_args(argv, namespace) + if args.bootstrap_backend == "nccl": + args.bind_to_device = True return args +def _get_layer_args(config, tp_group, tp_size, reference=False): + hidden_size = config.num_heads * config.head_dim + input_shape = [config.seq_length, config.batch_size, hidden_size] + args = [hidden_size] + kwargs = { + "params_dtype": torch.float32, + "device": "cuda", + "tp_group": tp_group, + "tp_size": tp_size, + "sequence_parallel": True, + } + kwargs["ub_overlap_ag"] = not config.no_comm_overlap + + if config.layer_type is te.Linear: + input_shape[2] = hidden_size // tp_size + args.append(hidden_size) + kwargs["parallel_mode"] = "row" + kwargs["ub_overlap_rs"] = not config.no_comm_overlap + kwargs["ub_name"] = "proj" + else: + input_shape[0] = config.seq_length // tp_size + kwargs["ub_bulk_wgrad"] = not config.no_comm_overlap + kwargs["ub_bulk_dgrad"] = not config.no_comm_overlap + if config.layer_type is te.LayerNormLinear: + args.append(3 * hidden_size) + kwargs["parallel_mode"] = "column" + kwargs["ub_name"] = "qkv" + else: + kwargs["set_parallel_mode"] = True + kwargs["ub_overlap_rs"] = not config.no_comm_overlap + if config.layer_type in [te.LayerNormMLP, te.TransformerLayer]: + args.append(4 * hidden_size) + kwargs["seq_length"] = config.seq_length + if config.layer_type in [te.MultiheadAttention, te.TransformerLayer]: + args.append(config.num_heads) + kwargs["attention_dropout"] = 0.0 + kwargs["fuse_qkv_params"] = True + if config.layer_type is te.MultiheadAttention: + kwargs["input_layernorm"] = True + else: + kwargs["ub_tp_comm_overlap"] = not config.no_comm_overlap + kwargs["hidden_dropout"] = 0.0 + + return args, kwargs, input_shape + + def _train(opts): if "OMPI_COMM_WORLD_SIZE" in os.environ: # Execution with `mpirun -np N` @@ -110,19 +182,6 @@ def _train(opts): raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!") NUM_NODES = WORLD_SIZE // LOCAL_SIZE - def dist_print(msg, group=None, end="\n", debug=False): - if debug and not opts.debug: - return - group = dist.new_group() if group is None else group - group_rank = dist.get_rank(group) - group_size = dist.get_world_size(group) - all_ranks = dist.get_process_group_ranks(group) - ranks_skip = all_ranks[1] - all_ranks[0] > 1 - group_id = WORLD_RANK % group_size if ranks_skip else WORLD_RANK // group_size - if group_rank == 0 or opts.verbose: - print(f"[rank{WORLD_RANK}:node{group_id}] {msg}{end}", end="", flush=True) - dist.barrier(group) - # Initialize torch.distributed global process group and get DP/TP groups torch.cuda.set_device(LOCAL_RANK) dist_init_kwargs = { @@ -143,75 +202,117 @@ def dist_print(msg, group=None, end="\n", debug=False): assert dist.is_nccl_available() dist.init_process_group(**dist_init_kwargs) nccl_world = dist.new_group(backend="nccl") - dist_print(f"Initialized default NCCL process group with {WORLD_RANK} GPUs", nccl_world) + + def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False): + if debug and not opts.debug: + return + group_rank = dist.get_rank(group) + stream = sys.stderr if error else sys.stdout + if group_rank == src: + stream.write(f"[rank{WORLD_RANK}] {msg}{end}") + dist.barrier(group) + + dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs") # Figure out process groups for tensor- and data-parallelism (if any) if NUM_NODES > 1: # Create a list of world ranks on this node - hostnames = [None for _ in range(WORLD_SIZE)] hostname = socket.gethostname() + ifname = os.getenv( + "NVTE_UB_SOCKET_IFNAME", + os.getenv("NCCL_SOCKET_IFNAME", os.getenv("GLOO_SOCKET_IFNAME")), + ) + + if ifname is not None: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + hostname = socket.inet_ntoa( + fcntl.ioctl( + s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8")) + )[20:24] + ) + except OSError as err: + raise OSError(f"Invalid network interface: {ifname}") from err + + hostnames = [None for _ in range(WORLD_SIZE)] dist.all_gather_object(hostnames, hostname) - node_ranks = [] + unique_hosts = [] + for host in hostnames: + if host not in unique_hosts: + unique_hosts.append(host) + assert len(unique_hosts) == NUM_NODES + + ranks_per_node_list = [[] for _ in range(NUM_NODES)] + self_node_idx = -1 for i, host in enumerate(hostnames): + node_idx = unique_hosts.index(host) + ranks_per_node_list[node_idx].append(i) if host == hostname: - node_ranks.append(i) + self_node_idx = node_idx + assert self_node_idx >= 0 + self_node_ranks = ranks_per_node_list[self_node_idx] if opts.num_replicas > 1: # Split node ranks into multiple replicas - assert len(node_ranks) % opts.num_replicas == 0 - tp_size = len(node_ranks) // opts.num_replicas - found_replica = False - for replica in range(opts.num_replicas): - start = replica * tp_size - end = start + tp_size - tp_ranks = node_ranks[start:end] - if WORLD_RANK in tp_ranks: - found_replica = True + assert len(self_node_ranks) % opts.num_replicas == 0 + tp_size = len(self_node_ranks) // opts.num_replicas + ranks_per_replica_list = [] + for node_ranks in ranks_per_node_list: + for i in range(opts.num_replicas): + start = i * tp_size + end = start + tp_size + ranks_per_replica_list.append(node_ranks[start:end]) + + self_replica_idx = -1 + for i, replica_ranks in enumerate(ranks_per_replica_list): + if WORLD_RANK in replica_ranks: + self_replica_idx = i break - assert found_replica + assert self_replica_idx >= 0 + else: # The entire node is the tensor-parallel group - tp_ranks = node_ranks - - tp_group = dist.new_group(backend="nccl", ranks=tp_ranks) - tp_size = dist.get_world_size(tp_group) - tp_rank = dist.get_rank(tp_group) + ranks_per_replica_list = ranks_per_node_list + self_replica_idx = self_node_idx - # Data-parallelism across TP groups - dp_start = tp_rank - dp_end = dp_start + WORLD_SIZE - dp_ranks = list(range(dp_start, dp_end, tp_size)) - dp_group = dist.new_group(backend="nccl", ranks=dp_ranks) + tp_group, _ = dist.new_subgroups_by_enumeration(ranks_per_replica_list, backend="nccl") + ranks_per_replica_tensor = torch.tensor(ranks_per_replica_list, dtype=torch.int32) + dp_group, _ = dist.new_subgroups_by_enumeration( + ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl" + ) else: if opts.num_replicas > 1: # Mixed data- and tensor-parallelism on a single node # NOTE: Avoid dist.init_device_mesh() to support older PyTorch versions all_ranks = torch.tensor(list(range(LOCAL_SIZE)), dtype=torch.uint8, device="cpu") - mesh2d = all_ranks.reshape((opts.num_replicas, LOCAL_SIZE // opts.num_replicas)) - node_idx = (mesh2d == LOCAL_RANK).nonzero().squeeze().tolist() - - tp_ranks = mesh2d[node_idx[0], :].tolist() - tp_group = dist.new_group(backend="nccl", ranks=tp_ranks) - - dp_ranks = mesh2d[:, node_idx[1]].tolist() - dp_group = dist.new_group(backend="nccl", ranks=dp_ranks) + ranks_per_replica_tensor = all_ranks.reshape( + (opts.num_replicas, LOCAL_SIZE // opts.num_replicas) + ) + tp_group, _ = dist.new_subgroups_by_enumeration( + ranks_per_replica_tensor.tolist(), backend="nccl" + ) + dp_group, _ = dist.new_subgroups_by_enumeration( + ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl" + ) else: dp_group = None tp_group = nccl_world - tp_rank = dist.get_rank(tp_group) - tp_size = dist.get_world_size(tp_group) - + tp_rank = dist.get_rank(tp_group) + tp_size = dist.get_world_size(tp_group) dist_print( f"Created tensor-parallel group: {dist.get_process_group_ranks(tp_group)}", group=tp_group, ) if dp_group is not None: + dp_rank = dist.get_rank(dp_group) dist_print( f"Created data-parallel group: {dist.get_process_group_ranks(dp_group)}", group=dp_group, ) + else: + dp_rank = 0 # Intialize userbuffers hidden_size = opts.num_heads * opts.head_dim @@ -226,26 +327,12 @@ def dist_print(msg, group=None, end="\n", debug=False): ) # Initialize the fused LayerNorm + Multi-layer Perceptron module - torch.manual_seed(opts.seed + tp_rank) + torch.manual_seed(opts.seed + dp_rank) torch.cuda.manual_seed(opts.seed + tp_rank) - model = te.LayerNormMLP( - hidden_size, - opts.mlp_expansion_factor * hidden_size, - params_dtype=torch.bfloat16, - device="cuda", - tp_group=tp_group, - tp_size=tp_size, - set_parallel_mode=True, - sequence_parallel=True, # this is required for comm+GEMM overlap - seq_length=opts.seq_length, - ub_overlap_rs=not opts.no_comm_overlap, - ub_overlap_ag=not opts.no_comm_overlap, - ub_overlap_rs_dgrad=not opts.no_comm_overlap, - ub_bulk_dgrad=False, - ub_bulk_wgrad=not opts.no_comm_overlap, - ) + layer_args, layer_kwargs, input_size = _get_layer_args(opts, tp_group, tp_size) + model = opts.layer_type(*layer_args, **layer_kwargs) if dp_group is not None: - model = DistributedDataParallel(model, process_group=dp_group) + model = DistributedDataParallel(model, dim=1, process_group=dp_group) # Initialize optimizer with model parameters optim = torch.optim.Adam(model.parameters(), lr=0.0001) @@ -255,28 +342,28 @@ def dist_print(msg, group=None, end="\n", debug=False): fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") # Start dummy "training" iterations - dist_print("Starting training iterations...", nccl_world) + dist_print("Starting training iterations...") for i in range(opts.num_iters): - dist_print(f" Iter {i+1}", tp_group, debug=True) - - dist_print(" |-- Generate random input batch", tp_group, debug=True) - x = torch.rand( - (opts.seq_length // tp_size, opts.batch_size, hidden_size), - dtype=torch.bfloat16, - device="cuda", - requires_grad=True, - ) - - dist_print(" |-- Forward pass", tp_group, debug=True) - with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world): - y = model(x) - dist_print(" |-- Compute loss", tp_group, debug=True) - loss = y.flatten().sum() - - dist_print(" |-- Backward pass", tp_group, debug=True) + dist_print(f" Iter {i+1}", group=tp_group, debug=True) + + dist_print(" |-- Generate random input batch", group=tp_group, debug=True) + x = torch.randn(input_size, dtype=torch.float32, device="cuda", requires_grad=True) + + dist_print(" |-- Forward pass", group=tp_group, debug=True) + with torch.amp.autocast("cuda", dtype=torch.bfloat16): + with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world): + y = model(x) + if isinstance(y, tuple): + out, *_ = y + else: + out = y + dist_print(" |-- Compute loss", group=tp_group, debug=True) + loss = out.sum() + + dist_print(" |-- Backward pass", group=tp_group, debug=True) loss.backward() - dist_print(" |-- Optimizer step", tp_group, debug=True) + dist_print(" |-- Optimizer step", group=tp_group, debug=True) optim.step() torch.cuda.synchronize() diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 71c55851d5..fef48fd4b0 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -5,6 +5,7 @@ set -e : ${TE_PATH:=/opt/transformerengine} +pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py git clone https://github.com/NVIDIA/Megatron-LM.git cd Megatron-LM diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py index d7dc3e1ce1..5ba70ccbdd 100644 --- a/tests/pytorch/distributed/run_gemm_with_overlap.py +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -46,17 +46,20 @@ def _mapped_argtype(opt, typemap): def _parse_args(argv=None, namespace=None): parser = argparse.ArgumentParser(description="Test comm+GEMM overlap with Userbuffers.") parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.") - parser.add_argument("-s", "--seq-length", type=int, default=2048, help="Input sequence length.") + parser.add_argument("-s", "--seq-length", type=int, default=512, help="Input sequence length.") parser.add_argument( - "-n", "--num-heads", type=int, default=64, help="Number of attention heads." + "-n", "--num-heads", type=int, default=12, help="Number of attention heads." ) parser.add_argument( - "-d", "--head-dim", type=int, default=128, help="Dimension of each attention head." + "-d", "--head-dim", type=int, default=64, help="Dimension of each attention head." ) parser.add_argument("--seed", type=int, default=1234, help="RNG seed.") parser.add_argument( "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context." ) + parser.add_argument( + "--fp8-output", action="store_true", default=False, help="Get FP8 output from GEMM." + ) parser.add_argument( "--p2p", action="store_true", default=False, help="Test overlap with P2P comms." ) @@ -106,7 +109,7 @@ def _parse_args(argv=None, namespace=None): help="Set device clock speed to a fixed value via `nvidia-smi`.", ) parser.add_argument( - "--scale", type=float, default=1e-2, help="Set scaling factor for input and weight tensors." + "--std", type=float, default=0.023, help="Standard deviation for input and weight tensors." ) parser.add_argument( "--tcp-init", @@ -135,6 +138,9 @@ def _parse_args(argv=None, namespace=None): + "initialization." ), ) + parser.add_argument( + "--use-cuda-graphs", action="store_true", default=False, help="Use CUDA graphs." + ) parser.add_argument( "-v", "--verbose", action="store_true", default=False, help="Verbose info messages." ) @@ -150,14 +156,17 @@ def _parse_args(argv=None, namespace=None): if opts.fp8: warnings.warn("Bulk overlap is supported in FP8 but only tested in BF16.") opts.fp8 = False - elif opts.comm_type == 1 and not opts.p2p: - warnings.warn("All-gather overlap is only supported with point-2-point comms.") - opts.p2p = True + elif opts.comm_type == 1: + if opts.atomic: + setattr(opts, "atomic_rs_p2p", opts.p2p) + if not opts.p2p: + warnings.warn("All-gather overlap is only supported with point-2-point comms.") + opts.p2p = True if opts.atomic: if not te.fp8.check_fp8_support(): assert not opts.fp8, "Atomic GEMM is only supported in FP8." - opts.fp8 = True + opts.fp8 = True return opts @@ -203,13 +212,14 @@ def _main(opts): print(f"[rank:{LOCAL_RANK}] {msg}\n", end="", flush=True) # Info printout - def dist_print(msg, src=None, info=False, section=False, group=None): + def dist_print(msg, src=None, info=False, error=False, section=False, group=None): group = dist.new_group() if group is None else group rank = dist.get_rank(group) + stream = sys.stderr if error else sys.stdout if info or opts.verbose: if section: if rank == (0 if src is None else src): - print("\n", end="", flush=True) + stream.write("\n") dist.barrier(group) if src is None or rank == src: prefix = "[GLOBAL] " if src is not None else f"[rank:{rank}] " @@ -217,7 +227,7 @@ def dist_print(msg, src=None, info=False, section=False, group=None): msg = "\n".join( [prefix + lines[0]] + [(" " * len(prefix)) + line for line in lines[1:]] ) - print(msg + "\n", end="", flush=True) + stream.write(msg + "\n") dist.barrier(group) # Initialize torch.distributed global process group and get TP group @@ -312,7 +322,9 @@ def dist_print(msg, src=None, info=False, section=False, group=None): hidden_size = opts.num_heads * opts.head_dim inp_shape = (opts.seq_length, opts.batch_size, hidden_size) outer_size = reduce(operator.mul, inp_shape[:-1], 1) - ubuf_dtype = torch.uint8 if opts.fp8 and opts.comm_type == 1 else torch.bfloat16 + ubuf_dtype = torch.bfloat16 + if opts.fp8 and not opts.bulk_overlap and (opts.comm_type == 1 or opts.fp8_output): + ubuf_dtype = torch.uint8 sample_buffer = torch.empty((outer_size, hidden_size), dtype=ubuf_dtype, device="cuda") ub_obj = ub_obj = ( tex.UbufP2PCommOverlap( @@ -331,7 +343,7 @@ def dist_print(msg, src=None, info=False, section=False, group=None): 3, # Max concurrent GEMM streams opts.comm_type == 0, # overlap with reduce scatter opts.atomic, # use a single GEMM with atomic-counters - True, # Use copy engine for P2P communications + not (opts.atomic and bool(int(os.getenv("NVTE_AG_P2P_MULTI_ATOMIC", "0")))), ub_callbacks, ) if opts.p2p @@ -349,7 +361,7 @@ def dist_print(msg, src=None, info=False, section=False, group=None): 4, # Number of communication splits True, # Set SM margin 3, # Max concurrent GEMM streams - opts.atomic, # uUe a single GEMM with atomic-counters + opts.atomic, # Use a single GEMM with atomic-counters ub_callbacks, ) ) @@ -357,25 +369,49 @@ def dist_print(msg, src=None, info=False, section=False, group=None): # Numerical check on AG + atomic GEMM requires testing an AG+RS pair ub_obj2 = None if opts.atomic and opts.comm_type == 1 and opts.check_numerics: - sample_buffer2 = torch.empty((outer_size, hidden_size), dtype=torch.bfloat16, device="cuda") - ub_obj2 = tex.UbufP2PCommOverlap( - sample_buffer2, # Sample userbuffer - WORLD_RANK, # World rank - WORLD_SIZE, # World size - LOCAL_RANK, # Rank within the node - LOCAL_SIZE, # Number of ranks/GPUs per node - 0, # Node ID - 1, # Number of nodes - tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) - 1, # Number of communication SMs - 1, # CGA cluster size - True, # Set SM margin - False, # Aggregate 2X GEMM chunks - 3, # Max concurrent GEMM streams - True, # overlap with reduce scatter - True, # use a single GEMM with atomic-counters - True, # use copy engine for P2P communications - ub_callbacks, + sample_buffer2 = torch.empty( + (outer_size, hidden_size), + dtype=torch.uint8 if opts.fp8_output else torch.bfloat16, + device="cuda", + ) + ub_obj2 = ( + tex.UbufP2PCommOverlap( + sample_buffer2, # Sample userbuffer + WORLD_RANK, # World rank + WORLD_SIZE, # World size + LOCAL_RANK, # Rank within the node + LOCAL_SIZE, # Number of ranks/GPUs per node + 0, # Node ID + 1, # Number of nodes + tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) + 1, # Number of communication SMs + 1, # CGA cluster size + True, # Set SM margin + False, # Aggregate 2X GEMM chunks + 3, # Max concurrent GEMM streams + True, # overlap with reduce scatter + True, # use a single GEMM with atomic-counters + True, # use copy engine for P2P communications + ub_callbacks, + ) + if opts.atomic_rs_p2p + else tex.UbufCommOverlap( + sample_buffer2, # Sample userbuffer + WORLD_RANK, # World rank + WORLD_SIZE, # World size + LOCAL_RANK, # Rank within the node + LOCAL_SIZE, # Number of ranks/GPUs per node + 0, # Node ID + 1, # Number of nodes + tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) + 16, # Number of communication SMs + 2, # CGA cluster size + 4, # Number of communication splits + True, # Set SM margin + 3, # Max concurrent GEMM streams + True, # uUe a single GEMM with atomic-counters + ub_callbacks, + ) ) # Figure out problem sizing: @@ -409,43 +445,53 @@ def dist_print(msg, src=None, info=False, section=False, group=None): # Initialize distributed input tensor and GEMM kernels torch.manual_seed(opts.seed + tp_rank) torch.cuda.manual_seed(opts.seed + tp_rank) - inp = torch.mul(torch.rand(local_inp_shape, dtype=torch.bfloat16, device="cuda"), opts.scale) - kernel_t = torch.mul( - torch.rand(local_kernel_t_shape, dtype=torch.bfloat16, device="cuda"), opts.scale + inp = torch.nn.init.normal_( + torch.empty(local_inp_shape, dtype=torch.bfloat16, device="cuda"), + mean=0.0, + std=opts.std, + ) + kernel_t = torch.nn.init.normal_( + torch.empty(local_kernel_t_shape, dtype=torch.bfloat16, device="cuda"), + mean=0.0, + std=opts.std, ) if ub_obj2 is not None: - kernel2_t = torch.mul( - torch.rand(local_kernel2_t_shape, dtype=torch.bfloat16, device="cuda"), opts.scale + kernel2_t = torch.nn.init.normal_( + torch.empty(local_kernel2_t_shape, dtype=torch.bfloat16, device="cuda"), + mean=0.0, + std=opts.std, ) # Gather global tensors and calculate reference result (need these first for Fp8 scales) if opts.bulk_overlap: ker_g = torch.transpose(kernel_t, 0, 1) inp_g = inp - bulk_inp = torch.mul( - torch.rand(bulk_inp_shape, dtype=torch.bfloat16, device="cuda"), opts.scale + bulk_inp = torch.nn.init.normal_( + torch.empty(bulk_inp_shape, dtype=torch.bfloat16, device="cuda"), + mean=0.0, + std=opts.std, ) else: if opts.comm_type == 1: # AG Kernel: (K/P, N) -> gather -> (K, N) -> T -> (N, K) ker_g = torch.transpose( te.distributed.gather_along_first_dim(kernel_t, tp_group)[0], 0, 1 - ) + ).to(dtype=torch.float32) # AG Input: (M/P, N) -> gather -> (M, N) - inp_g = te.distributed.gather_along_first_dim(inp, tp_group)[0] + inp_g = te.distributed.gather_along_first_dim(inp, tp_group)[0].to(dtype=torch.float32) if ub_obj2 is not None: ker2_g = te.distributed.gather_along_first_dim( torch.transpose(kernel2_t, 0, 1), tp_group - )[0] + )[0].to(dtype=torch.float32) else: # RS Kernel: (N, K/P) -> T -> (K/P, N) -> gather -> (K, N) ker_g = te.distributed.gather_along_first_dim( torch.transpose(kernel_t, 0, 1), tp_group - )[0] + )[0].to(dtype=torch.float32) # RS Input: (M, K/P) -> T -> (K/P, M) -> gather -> (K, M) -> T -> (M, K) inp_g = torch.transpose( te.distributed.gather_along_first_dim(torch.transpose(inp, 0, 1), tp_group)[0], 0, 1 - ) + ).to(dtype=torch.float32) if opts.bulk_overlap: if opts.comm_type == 1: @@ -459,7 +505,7 @@ def dist_print(msg, src=None, info=False, section=False, group=None): else: ref_g = torch.matmul(inp_g, ker_g) if ub_obj2 is not None: - inp2_g = torch.mul(ref_g, opts.scale) + inp2_g = torch.nn.functional.gelu(ref_g) ref2_g = torch.matmul(inp2_g, ker2_g) if opts.fp8: @@ -483,7 +529,10 @@ def dist_print(msg, src=None, info=False, section=False, group=None): fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_WEIGHT].copy_(ker_amax) ref_amax = torch.max(torch.abs(ref_g)) fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_OUTPUT].copy_(ref_amax) - if ub_obj2 is not None: + if opts.bulk_overlap and opts.comm_type == 0: + bulk_amax = torch.max(torch.abs(bulk_inp)) + fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_OUTPUT].copy_(bulk_amax) + elif ub_obj2 is not None: inp2_amax = torch.max(torch.abs(inp2_g)) fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_INPUT].copy_(inp2_amax) ker2_amax = torch.max(torch.abs(ker2_g)) @@ -502,7 +551,11 @@ def dist_print(msg, src=None, info=False, section=False, group=None): kernel_t_fp8 = tex.cast_to_fp8( kernel_t, fp8_meta, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype ) - if ub_obj2 is not None: + if opts.bulk_overlap and opts.comm_type == 0: + bulk_inp_fp8 = tex.cast_to_fp8( + bulk_inp, fp8_meta, tex.FP8Tensors.GEMM2_OUTPUT, fp8_dtype + ) + elif ub_obj2 is not None: kernel2_t_fp8 = tex.cast_to_fp8( kernel2_t, fp8_meta, tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype ) @@ -521,7 +574,14 @@ def dist_print(msg, src=None, info=False, section=False, group=None): rtol=0.125, atol=0.0675, ) - if ub_obj2 is not None: + if opts.bulk_overlap and opts.comm_type == 0: + torch.allclose( + bulk_inp.to(dtype=torch.float32), + bulk_inp_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT], + rtol=0.125, + atol=0.0675, + ) + elif ub_obj2 is not None: torch.allclose( kernel2_t.to(dtype=torch.float32), kernel2_t_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_WEIGHT], @@ -534,6 +594,8 @@ def dist_print(msg, src=None, info=False, section=False, group=None): ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_INPUT]) if ub_obj2 is not None: ub_obj2.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT]) + elif opts.bulk_overlap: + ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT]) else: ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_OUTPUT]) @@ -556,7 +618,7 @@ def dist_print(msg, src=None, info=False, section=False, group=None): ) else: if opts.bulk_overlap: - ub_obj.copy_input_to_ubuf(bulk_inp, 0) + ub_obj.copy_input_to_ubuf(bulk_inp_fp8 if opts.fp8 else bulk_inp, 0) ubuf_out = None else: ubuf_out = ub_obj.get_ubuf_output(1) @@ -565,80 +627,131 @@ def dist_print(msg, src=None, info=False, section=False, group=None): (outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda" ) + # Wrap GEMM ops in condensed functions to make CUDA Graphs easier to use + def _fp8_gemm(): + return tex.fp8_gemm( + kernel_t_fp8, + fp8_meta.scale_inv, + tex.FP8FwdTensors.GEMM1_WEIGHT, + fp8_dtype, + gemm_inp, + fp8_meta.scale_inv, + tex.FP8FwdTensors.GEMM1_INPUT, + fp8_dtype, + torch.uint8 if opts.fp8_output else torch.bfloat16, + te.module.base.get_workspace(), + bias=None, + use_bias=False, + gelu=False, + use_split_accumulator=te.module.base._2X_ACC_FPROP, + ub_algo=ub_algo, + ub=ub_obj, + extra_output_tensor=rs_out, + out=ubuf_out, + D_dtype=fp8_dtype if opts.fp8_output else None, + fp8_meta_tensor=fp8_meta if opts.fp8_output else None, + out_index=tex.FP8FwdTensors.GEMM1_OUTPUT if opts.fp8_output else None, + ) + + def _fp8_gemm2(gemm1_out): + gemm2_inp = tex.gelu( + ( + tex.cast_from_fp8( + gemm1_out, + fp8_meta, + tex.FP8FwdTensors.GEMM1_OUTPUT, + fp8_dtype, + tex.DType.kFloat32, + ) + if opts.fp8_output + else gemm1_out + ), + fp8_meta, + tex.FP8FwdTensors.GEMM2_INPUT, + fp8_dtype, + ) + return tex.fp8_gemm( + kernel2_t_fp8, + fp8_meta.scale_inv, + tex.FP8FwdTensors.GEMM2_WEIGHT, + fp8_dtype, + gemm2_inp, + fp8_meta.scale_inv, + tex.FP8FwdTensors.GEMM2_INPUT, + fp8_dtype, + torch.uint8 if opts.fp8_output else torch.bfloat16, + te.module.base.get_workspace(), + bias=None, + use_bias=False, + gelu=False, + use_split_accumulator=te.module.base._2X_ACC_FPROP, + ub_algo=( + tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + if opts.atomic_rs_p2p + else tex.UbufOverlapAlgo.ATOMIC_GEMM_RS + ), + ub=ub_obj2, + extra_output_tensor=rs_out2, + out=ubuf_out2, + D_dtype=fp8_dtype if opts.fp8_output else None, + fp8_meta_tensor=fp8_meta if opts.fp8_output else None, + out_index=tex.FP8FwdTensors.GEMM2_OUTPUT if opts.fp8_output else None, + ) + + def _gemm(): + return tex.gemm( + kernel_t, + gemm_inp, + torch.bfloat16, + te.module.base.get_workspace(), + bias=None, + use_bias=False, + gelu=False, + ub_algo=ub_algo, + ub=ub_obj, + extra_output_tensor=rs_out, + out=ubuf_out, + ) + # Trigger GEMM total_iters = opts.warmup_iters + opts.timing_iters start_events = [torch.cuda.Event(enable_timing=True) for _ in range(total_iters)] end_events = [torch.cuda.Event(enable_timing=True) for _ in range(total_iters)] torch.cuda.synchronize() - if opts.fp8: + if opts.use_cuda_graphs: + # Trace the CUDA graph first + g = torch.cuda.CUDAGraph() + if opts.fp8: + if ub_obj is None: + with torch.cuda.graph(g): + all_outputs = _fp8_gemm() + else: + with torch.cuda.graph(g): + all_outputs = _fp8_gemm() + _ = _fp8_gemm2(all_outputs[0]) + else: + with torch.cuda.graph(g): + all_outputs = _gemm() + + # Now replay the CUDA graph in a loop for i in range(total_iters): start_events[i].record() - all_outputs = tex.fp8_gemm( - kernel_t_fp8, - fp8_meta.scale_inv, - tex.FP8FwdTensors.GEMM1_WEIGHT, - fp8_dtype, - gemm_inp, - fp8_meta.scale_inv, - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype, - torch.bfloat16, - te.module.base.get_workspace(), - bias=None, - use_bias=False, - gelu=False, - use_split_accumulator=te.module.base._2X_ACC_FPROP, - ub_algo=ub_algo, - ub=ub_obj, - extra_output_tensor=rs_out, - out=ubuf_out, - ) + g.replay() end_events[i].record() - if ub_obj2 is not None: - gemm2_inp = tex.cast_to_fp8( - torch.mul(all_outputs[0], opts.scale), - fp8_meta, - tex.FP8FwdTensors.GEMM2_INPUT, - fp8_dtype, - ) - all_outputs = tex.fp8_gemm( - kernel2_t_fp8, - fp8_meta.scale_inv, - tex.FP8FwdTensors.GEMM2_WEIGHT, - fp8_dtype, - gemm2_inp, - fp8_meta.scale_inv, - tex.FP8FwdTensors.GEMM2_INPUT, - fp8_dtype, - torch.bfloat16, - te.module.base.get_workspace(), - bias=None, - use_bias=False, - gelu=False, - use_split_accumulator=te.module.base._2X_ACC_FPROP, - ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P, - ub=ub_obj2, - extra_output_tensor=rs_out2, - out=ubuf_out2, - ) + else: for i in range(total_iters): - start_events[i].record() - all_outputs = tex.gemm( - kernel_t, - gemm_inp, - torch.bfloat16, - te.module.base.get_workspace(), - bias=None, - use_bias=False, - gelu=False, - ub_algo=ub_algo, - ub=ub_obj, - extra_output_tensor=rs_out, - out=ubuf_out, - ) - end_events[i].record() + if opts.fp8: + start_events[i].record() + all_outputs = _fp8_gemm() + end_events[i].record() + if ub_obj2 is not None: + _fp8_gemm2(all_outputs[0]) + else: + start_events[i].record() + all_outputs = _gemm() + end_events[i].record() torch.cuda.synchronize() gpu_times = [ @@ -679,7 +792,11 @@ def dist_print(msg, src=None, info=False, section=False, group=None): ref_out = ref_g output_info += f"output: {list(test_out.shape)} | reference: {list(ref_out.shape)}" - dist_print(output_info, src=0 if opts.comm_type == 0 else None, section=True) + dist_print( + output_info, + src=0 if opts.comm_type == 0 else None, + section=True, + ) test_nonzeros = torch.count_nonzero(test_out) ref_nonzeros = torch.count_nonzero(ref_out) @@ -691,11 +808,21 @@ def dist_print(msg, src=None, info=False, section=False, group=None): if opts.comm_type == 1: if ub_obj2 is not None: # AG+RS Output: (M/P, N) -> gather -> (M, N) - output = rs_out2 + output = rs_out2.to(dtype=torch.float32) test_out = te.distributed.gather_along_first_dim(output, tp_group)[0] else: # AG Output: (M, K/P) -> T -> (K/P, M) -> gather -> (K, M) -> T -> (M, K) - output = all_outputs[0] + output = ( + tex.cast_from_fp8( + all_outputs[0], + fp8_meta, + tex.FP8FwdTensors.GEMM1_OUTPUT, + fp8_dtype, + tex.DType.kFloat32, + ) + if opts.fp8_output + else all_outputs[0] + ) test_out = torch.transpose( te.distributed.gather_along_first_dim( torch.transpose(output, 0, 1), tp_group @@ -705,7 +832,7 @@ def dist_print(msg, src=None, info=False, section=False, group=None): ) else: # RS Output: (M/P, N) -> gather -> (M, N) - output = rs_out + output = rs_out.to(dtype=torch.float32) test_out = te.distributed.gather_along_first_dim(output, tp_group)[0] if opts.fp8: @@ -755,30 +882,33 @@ def dist_print(msg, src=None, info=False, section=False, group=None): torch.cuda.synchronize() dist.barrier(tp_group) - test_out = test_out.to(dtype=torch.float32) - ref_out = ref_out.to(dtype=torch.float32) - error_below_tol = torch.allclose( - test_out, - ref_out, - rtol=0.125 if opts.fp8 else 0.02, - atol=0.0675 if opts.fp8 else 0.001, - ) diff = torch.abs(test_out - ref_out).flatten() m = torch.argmax(diff) abs_err = diff[m].item() - rel_err = abs_err / (ref_out.flatten()[m].item() + 1e-5) - if not error_below_tol: + rel_err = abs_err / max(abs(ref_out.flatten()[m].item()), 1e-5) + rtol = 0.125 if opts.fp8 else 0.02 + atol = 0.0625 if opts.fp8 else 0.001 + if rel_err > rtol and abs_err > atol: numerics_failed = True numerics_info = ( "NUMERICAL CHECK FAILED: " + f"Outputs not close enough at index {m.item()} " - + f"with {test_out.flatten()[m].item()} vs {ref_out.flatten()[m].item()} " - + f"(abs error = {abs_err} | rel error = {rel_err})." + + f"with {test_out.flatten()[m].item()} vs {ref_out.flatten()[m].item()} | " + + f"rel. error = {rel_err} (tol = {rtol}) | " + + f"abs. error = {abs_err} (tol = {atol})" ) else: - numerics_info = f"NUMERICAL CHECK PASSED: abs error = {abs_err} | rel error = {rel_err}" + numerics_info = "NUMERICAL CHECK PASSED: " + if rel_err <= rtol: + numerics_info += f"rel. error = {rel_err} (tol = {rtol})" + ( + " | " if abs_err < atol else "" + ) + if abs_err <= atol: + numerics_info += f"abs. error = {abs_err} (tol = {atol})" - dist_print(numerics_info, src=0, section=True, info=True, group=tp_group) + dist_print( + numerics_info, src=0, section=True, info=True, error=numerics_failed, group=tp_group + ) dist.barrier(tp_group) if LOCAL_RANK == 0: diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py new file mode 100644 index 0000000000..e5653bda01 --- /dev/null +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -0,0 +1,352 @@ +#!/usr/bin/python3 + +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os +import sys +import socket +import argparse +import warnings +from functools import partial + +import torch +import torch.distributed as dist + +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import Format, DelayedScaling + +warnings.filterwarnings("ignore", category=DeprecationWarning) +warnings.filterwarnings("ignore", category=FutureWarning) +warnings.filterwarnings("ignore", category=UserWarning) + + +def _te_layer_argtype(name): + te_layers = [ + te.Linear, + te.LayerNormLinear, + te.LayerNormMLP, + te.MultiheadAttention, + te.TransformerLayer, + ] + layer_map = dict(zip([layer.__name__.lower() for layer in te_layers], te_layers)) + if name.lower() not in layer_map.keys(): + raise argparse.ArgumentTypeError( + f"Invalid TE layer name! Please choose from: {layer_map.keys()}" + ) + return layer_map[name.lower()] + + +def _get_layer_args(config, tp_group, tp_size, reference=False): + hidden_size = config.num_heads * config.head_dim + input_shape = [config.seq_length, config.batch_size, hidden_size] + args = [hidden_size] + kwargs = { + "params_dtype": torch.float32, + "device": "cuda", + "tp_group": tp_group, + "tp_size": tp_size, + "sequence_parallel": True, + } + kwargs["ub_overlap_ag"] = not reference + + if config.layer_type is te.Linear: + input_shape[2] = hidden_size // tp_size + args.append(hidden_size) + kwargs["parallel_mode"] = "row" + kwargs["ub_overlap_rs"] = not reference + kwargs["ub_name"] = "proj" + else: + input_shape[0] = config.seq_length // tp_size + kwargs["ub_bulk_wgrad"] = not reference + kwargs["ub_bulk_dgrad"] = not reference + if config.layer_type is te.LayerNormLinear: + args.append(3 * hidden_size) + kwargs["parallel_mode"] = "column" + kwargs["ub_name"] = "qkv" + else: + kwargs["set_parallel_mode"] = True + kwargs["ub_overlap_rs"] = not reference + if config.layer_type in [te.LayerNormMLP, te.TransformerLayer]: + args.append(4 * hidden_size) + kwargs["seq_length"] = config.seq_length + if config.layer_type in [te.MultiheadAttention, te.TransformerLayer]: + args.append(config.num_heads) + kwargs["attention_dropout"] = 0.0 + kwargs["fuse_qkv_params"] = True + if config.layer_type is te.MultiheadAttention: + kwargs["input_layernorm"] = True + else: + kwargs["ub_tp_comm_overlap"] = not reference + kwargs["hidden_dropout"] = 0.0 + + return args, kwargs, input_shape + + +def _parse_args(argv=None, namespace=None): + parser = argparse.ArgumentParser( + description="Test a Transformer Engine layer with GEMM+comm overlap via Userbuffers." + ) + parser.add_argument("-l", "--layer-type", type=_te_layer_argtype, default=te.LayerNormMLP) + parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.") + parser.add_argument("-s", "--seq-length", type=int, default=2048, help="Input sequence length.") + parser.add_argument( + "-n", "--num-heads", type=int, default=12, help="Number of attention heads." + ) + parser.add_argument( + "-d", "--head-dim", type=int, default=64, help="Dimension of each attention head." + ) + parser.add_argument("--seed", type=int, default=42, help="RNG seed.") + parser.add_argument( + "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context." + ) + parser.add_argument( + "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8." + ) + parser.add_argument( + "--tcp-init", + action="store_true", + default=False, + help="Initialize torch.distributed with TcpStore.", + ) + parser.add_argument( + "--bind-to-device", + action="store_true", + default=False, + help="Initialize torch.distributed with `device_id` to bind each rank to a single device.", + ) + parser.add_argument( + "--bootstrap-backend", + type=str.lower, + default="nccl", + choices=["gloo", "mpi", "nccl"], + help="Communications backend for host tensor collectives during Userbuffers bootstrapping.", + ) + parser.add_argument( + "--use-cuda-graphs", action="store_true", default=False, help="Use CUDA Graphs." + ) + parser.add_argument( + "--debug", + action="store_true", + default=False, + help="Print out additional debug information.", + ) + args = parser.parse_args(argv, namespace) + + if args.use_cuda_graphs and args.layer_type in [te.MultiheadAttention, te.TransformerLayer]: + warnings.warn(f"{args.layer_type.__name__} does not support CUDA Graphs!") + args.use_cuda_graphs = False + + return args + + +def _compare_tensors(name, test, ref, rtol, atol): + # Make sure tensors aren't zero and we don't pass trivially + if test.count_nonzero() == 0: + if ref.count_nonzero() == 0: + warnings.warn( + f"WARNING: {name} is a zero-tensor for both test and reference models!", + category=RuntimeWarning, + ) + else: + numerics_info = ( + f"NUMERICAL CHECK FAILED: {name} is a zero-tensor but does not match reference!" + ) + return 1, numerics_info + + diff = torch.abs(test - ref).flatten() + m = torch.argmax(diff) + abs_err = diff[m].item() + rel_err = abs_err / max(abs(ref.flatten()[m].item()), 1e-5) + numerics_failed = 0 + if rel_err > rtol and abs_err > atol: + numerics_failed = 1 + numerics_info = ( + "NUMERICAL CHECK FAILED: " + + f"{name} not close enough at index {m.item()} " + + f"with {test.flatten()[m].item()} vs {ref.flatten()[m].item()} | " + + f"rel. error = {rel_err} (tol = {rtol}) | " + + f"abs. error = {abs_err} (tol = {atol})" + ) + else: + numerics_info = f"NUMERICAL CHECK PASSED: {name} | " + if rel_err <= rtol: + numerics_info += f"rel. error = {rel_err} (tol = {rtol})" + ( + " | " if abs_err <= atol else "." + ) + if abs_err <= atol: + numerics_info += f" abs. error = {abs_err} (tol = {atol})" + + return numerics_failed, numerics_info + + +def _train(opts): + if "OMPI_COMM_WORLD_SIZE" in os.environ: + # Execution with `mpirun -np N` + WORLD_RANK = int(os.getenv("OMPI_COMM_WORLD_RANK", "0")) + WORLD_SIZE = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE", "1")) + opts.tcp_init = True + opts.bind_to_device = True + opts.bootstrap_backend = "mpi" + elif "TORCHELASTIC_RUN_ID" in os.environ: + WORLD_RANK = int(os.getenv("RANK", "0")) + WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + else: + raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!") + assert LOCAL_SIZE == WORLD_SIZE + + def dist_print(msg, src=None, end="\n", debug=False, error=False): + if debug and not opts.debug: + return + stream = sys.stderr if error else sys.stdout + if WORLD_RANK == (0 if src is None else src): + stream.write(f"[rank{WORLD_RANK}] {msg}{end}\n") + dist.barrier() + + # Set device and initialize RNG states + torch.cuda.set_device(WORLD_RANK) + torch.manual_seed(opts.seed) + torch.cuda.manual_seed(opts.seed) + + # Initialize torch.distributed global process group and get DP/TP groups + dist_init_kwargs = { + "backend": "nccl", + "rank": WORLD_RANK, + "world_size": WORLD_SIZE, + } + if opts.tcp_init: + MASTER_ADDR = os.getenv("MASTER_ADDR", socket.gethostbyname(socket.gethostname())) + MASTER_PORT = os.getenv("MASTER_PORT", "1234") + dist_init_kwargs["init_method"] = f"tcp://{MASTER_ADDR}:{MASTER_PORT}" + if opts.bind_to_device or opts.bootstrap_backend == "nccl": + dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") + assert dist.is_nccl_available() + dist.init_process_group(**dist_init_kwargs) + nccl_world = dist.new_group(backend="nccl") + dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs") + + # Intialize userbuffers + te.module.base.initialize_ub( + [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim], + WORLD_SIZE, + use_fp8=opts.fp8, + dtype=torch.bfloat16, + bootstrap_backend=opts.bootstrap_backend, + ) + + # Initialize the Transformer Engine layer with overlap + args, kwargs, input_shape = _get_layer_args(opts, nccl_world, WORLD_SIZE) + with te.fp8_model_init(enabled=opts.fp8_init): + test_model = opts.layer_type(*args, **kwargs) + dist_print("Initialized test model...", debug=True) + + # Initialize the reference model and copy all parameters + ref_args, ref_kwargs, _ = _get_layer_args(opts, nccl_world, WORLD_SIZE, reference=True) + with te.fp8_model_init(enabled=opts.fp8_init): + ref_model = opts.layer_type(*ref_args, **ref_kwargs) + dist_print("Initialized reference model...", debug=True) + for test_param, ref_param in zip(test_model.parameters(), ref_model.parameters()): + with torch.no_grad(): + ref_param.copy_(test_param) + torch.testing.assert_close(test_param, ref_param, rtol=0.0, atol=0.0) + dist_print("Copied parameters from test model to reference model...", debug=True) + + # Fp8 recipe setup + fp8_format = Format.HYBRID + fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") + + # Prepare random input tensors + test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True) + test_x.retain_grad() + ref_x = torch.empty_like(test_x).requires_grad_(True) + with torch.no_grad(): + ref_x.copy_(test_x) + torch.testing.assert_close(test_x, ref_x, rtol=0.0, atol=0.0) + ref_x.retain_grad() + + # Execute fwd/bwd and collect tensors to test + def run_fwd_bwd(model, x): + with torch.amp.autocast("cuda", dtype=torch.bfloat16): + with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world): + y = model(x) + if isinstance(y, tuple): + out, *_ = y + else: + out = y + loss = out.sum() + loss.backward() + return out + + torch_rng_state = torch.get_rng_state() + cuda_rng_state = torch.cuda.get_rng_state(torch.device(f"cuda:{WORLD_RANK}")) + if opts.use_cuda_graphs: + test_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(test_graph): + test_out = run_fwd_bwd(test_model, test_x) + test_graph.replay() + del test_graph + else: + test_out = run_fwd_bwd(test_model, test_x) + test_grads = [test_out, test_x.grad] + names = ["output", "input.grad"] + for test_name, test_param in test_model.named_parameters(): + if test_param.requires_grad and "layer_norm" not in test_name: + test_grads.append(test_param.grad) + names.append(test_name + ".grad") + + torch.set_rng_state(torch_rng_state) + torch.cuda.set_rng_state(cuda_rng_state, torch.device(f"cuda:{WORLD_RANK}")) + if opts.use_cuda_graphs: + ref_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(ref_graph): + ref_out = run_fwd_bwd(ref_model, ref_x) + ref_graph.replay() + del ref_graph + else: + ref_out = run_fwd_bwd(ref_model, ref_x) + ref_grads = [ref_out, ref_x.grad] + for ref_name, ref_param in ref_model.named_parameters(): + if ref_param.requires_grad and "layer_norm" not in ref_name: + ref_grads.append(ref_param.grad) + + # Make sure we have the same number of gradients + numerics_failed = torch.tensor([0], dtype=torch.uint8, device="cuda") + if len(test_grads) != len(ref_grads): + numerics_failed[0] = 1 + numerics_info = ( + "NUMERICAL CHECK FAILED: Incorrect number of gradients, " + + f"expected {len(ref_grads)} but got {len(test_grads)}." + ) + dist_print(numerics_info, src=WORLD_RANK, error=True) + dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world) + + # Now validate accuracy + if not bool(numerics_failed.item()): + for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)): + rtol = 0.125 if opts.fp8 else 0.025 + atol = 0.0625 if opts.fp8 else 0.00125 + grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol) + dist_print(grad_info, src=WORLD_RANK, error=grad_failed) + numerics_failed[0] = int(grad_failed) + dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world) + if bool(numerics_failed.item()): + break + + te.module.base.destroy_ub() + dist_print("Destroying Userbuffers objects...", debug=True) + + dist_print("Destroying all process groups...", debug=True) + dist.destroy_process_group() + if opts.debug and WORLD_RANK == 0: + print("Exiting...\n", end="", flush=True) + + return numerics_failed[0].item() + + +if __name__ == "__main__": + sys.exit(_train(_parse_args())) diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index d0745aebf6..63310195ae 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -7,16 +7,27 @@ import pytest import torch +import transformer_engine.pytorch as te import transformer_engine.pytorch.cpp_extensions as tex from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +if torch.cuda.device_count() < 2: + pytest.skip("Comm+GEMM overlap requires at least 2 GPUs.") + fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() RNG_SEED: int = 1234 -SEQ_LENGTH: int = 2024 +SEQ_LENGTH: int = 512 BATCH_SIZE: int = 2 -NUM_HEADS: int = 64 -HEAD_DIM: int = 128 +NUM_HEADS: int = 12 +HEAD_DIM: int = 64 +TE_LAYERS = [ + te.Linear, + te.LayerNormLinear, + te.LayerNormMLP, + te.MultiheadAttention, + te.TransformerLayer, +] TEST_ROOT = Path(__file__).parent.resolve() NUM_PROCS: int = min(torch.cuda.device_count(), 4) @@ -32,66 +43,28 @@ os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" -@pytest.mark.skipif(NUM_PROCS < 2, reason="Comm+GEMM overlap requires at least 2 GPUs.") -@pytest.mark.parametrize( - "fp8,p2p,comm_type,aggregate,atomic,bulk", - [ - # FP8, P2P, Type, Aggregate, Atomic, Bulk - (False, True, "AG", False, False, False), - (False, True, "AG", True, False, False), - (True, True, "AG", False, False, False), - (True, True, "AG", True, False, False), - (False, False, "RS", False, False, False), - (False, True, "RS", False, False, False), - (True, False, "RS", False, False, False), - (True, True, "RS", False, False, False), - (True, False, "RS", False, True, False), - (True, True, "RS", False, True, False), - (False, False, "AG", False, False, True), - (False, False, "RS", False, False, True), - ], - ids=[ - " AG -> SPLIT GEMM | BF16 | RING-EXCHANGE ", - " AG -> SPLIT GEMM | BF16 | RING-EXCHANGE (2X AGGREGATED) ", - " AG -> SPLIT GEMM | FP8 | RING-EXCHANGE ", - " AG -> SPLIT GEMM | FP8 | RING-EXCHANGE (2X AGGREGATED) ", - " SPLIT GEMM -> RS | BF16 | PIPELINE ", - " SPLIT GEMM -> RS | BF16 | RING-EXCHANGE ", - " SPLIT GEMM -> RS | FP8 | PIPELINE ", - " SPLIT GEMM -> RS | FP8 | RING-EXCHANGE ", - " ATOMIC GEMM -> RS | FP8 | PIPELINE ", - " ATOMIC GEMM -> RS | FP8 | RING-EXCHANGE ", - " BULK AG & GEMM | BF16 | PIPELINE ", - " BULK RS & GEMM | BF16 | PIPELINE ", - ], -) -def test_gemm_with_overlap(fp8, p2p, comm_type, aggregate, atomic, bulk): - """ - Test comm+GEMM overlap algorithms with direct calls to - te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm - """ +def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggregate): test_path = TEST_ROOT / "run_gemm_with_overlap.py" - test_cmd = ( - LAUNCH_CMD - + [str(test_path)] - + [ - "--check-numerics", - f"--seed={RNG_SEED}", - f"--seq-length={SEQ_LENGTH}", - f"--batch-size={BATCH_SIZE}", - f"--num-heads={NUM_HEADS}", - f"--head-dim={HEAD_DIM}", - f"--comm-type={comm_type}", - ] - ) + test_cmd = LAUNCH_CMD + [ + str(test_path), + "--check-numerics", + f"--seed={RNG_SEED}", + f"--seq-length={SEQ_LENGTH}", + f"--batch-size={BATCH_SIZE}", + f"--num-heads={NUM_HEADS}", + f"--head-dim={HEAD_DIM}", + f"--comm-type={comm_type}", + ] if bulk: test_cmd.append("--bulk-overlap") else: - if fp8: + if fp8_in: if not fp8_available: pytest.skip(reason_for_no_fp8) test_cmd.append("--fp8") + if fp8_out: + test_cmd.append("--fp8-output") if p2p: test_cmd.append("--p2p") if aggregate: @@ -101,5 +74,173 @@ def test_gemm_with_overlap(fp8, p2p, comm_type, aggregate, atomic, bulk): pytest.skip("Device compute capability 9.0 or higher required for Atomic GEMM.") test_cmd.append("--atomic") - output = subprocess.run(test_cmd, env=os.environ, text=True, capture_output=True, check=False) - assert "NUMERICAL CHECK PASSED" in str(output) + result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) + if ( + result.returncode != 0 + or "NUMERICAL CHECK FAILED" in result.stderr.decode() + or "NUMERICAL CHECK PASSED" not in result.stdout.decode() + ): + raise AssertionError(result.stderr.decode()) + + +def _run_layer_with_overlap(layer_type, fp8, fp8_init): + test_path = TEST_ROOT / "run_layer_with_overlap.py" + test_cmd = LAUNCH_CMD + [ + str(test_path), + f"--seed={RNG_SEED}", + f"--seq-length={SEQ_LENGTH}", + f"--batch-size={BATCH_SIZE}", + f"--num-heads={NUM_HEADS}", + f"--head-dim={HEAD_DIM}", + f"--layer-type={layer_type}", + ] + + if fp8: + if not fp8_available: + pytest.skip(reason_for_no_fp8) + test_cmd.append("--fp8") + if fp8_init: + test_cmd.append("--fp8-init") + + os.environ["PYTORCH_JIT"] = "0" + os.environ["NVTE_TORCH_COMPILE"] = "0" + os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" + + result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) + + os.unsetenv("PYTORCH_JIT") + os.unsetenv("NVTE_TORCH_COMPILE") + os.unsetenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO") + + if ( + result.returncode != 0 + or "NUMERICAL CHECK FAILED" in result.stderr.decode() + or "NUMERICAL CHECK PASSED" not in result.stdout.decode() + ): + raise AssertionError(result.stderr.decode()) + + +@pytest.mark.parametrize( + "fp8,aggregate", + [ + (False, False), + (False, True), + (True, False), + (True, True), + ], + ids=[ + " BF16 IN - RING-EXCHANGE ", + " BF16 IN - RING-EXCHANGE - 2x AGGREGATED ", + " FP8 IN - RING-EXCHANGE ", + " FP8 IN - RING-EXCHANGE - 2x AGGREGATED ", + ], +) +def test_split_all_gather_overlaps(fp8, aggregate): + """ + Test (split GEMM -> all-gather) overlaps with direct calls to te.cpp_extensions.gemm or + te.cpp_extensions.fp8_gemm. + """ + _run_gemm_with_overlap("AG", False, True, False, fp8, False, aggregate) + + +@pytest.mark.parametrize( + "fp8_in,fp8_out,p2p", + [ + (False, False, False), + (False, False, True), + (True, False, False), + (True, False, True), + (True, True, False), + (True, True, True), + ], + ids=[ + " BF16 IN - BF16 OUT - PIPELINE ", + " BF16 IN - BF16 OUT - RING-EXCHANGE ", + " FP8 IN - BF16 OUT - PIPELINE ", + " FP8 IN - BF16 OUT - RING-EXCHANGE ", + " FP8 IN - FP8 OUT - PIPELINE ", + " FP8 IN - FP8 OUT - RING-EXCHANGE ", + ], +) +def test_split_reduce_scatter_overlaps(fp8_in, fp8_out, p2p): + """ + Test (reduce-scatter -> split GEMM) overlaps with direct calls to te.cpp_extensions.gemm or + te.cpp_extensions.fp8_gemm. + """ + _run_gemm_with_overlap("RS", False, p2p, False, fp8_in, fp8_out, False) + + +@pytest.mark.parametrize( + "ag_type,rs_type,p2p,fp8_out", + [ + (0, 0, False, False), + (0, 1, False, False), + (0, 1, False, True), + (0, 2, False, False), + (0, 2, False, True), + (0, 0, True, False), + (0, 0, True, True), + (1, 0, True, False), + (1, 0, True, True), + ], + ids=[ + " NON-ATOMIC AG - NON-ATOMIC RS - PIPELINE - BF16 OUT ", + " NON-ATOMIC AG - ATOMIC RS - PIPELINE - BF16 OUT ", + " NON-ATOMIC AG - ATOMIC RS - PIPELINE - FP8 OUT ", + " NON-ATOMIC AG - MULTI-ATOMIC RS - PIPELINE - BF16 OUT ", + " NON-ATOMIC AG - MULTI-ATOMIC RS - PIPELINE - FP8 OUT ", + " NON-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - BF16 OUT ", + " NON-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - FP8 OUT ", + " MULTI-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - BF16 OUT ", + " MULTI-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - FP8 OUT ", + ], +) +def test_atomic_gemm_overlaps(ag_type, rs_type, p2p, fp8_out): + """ + Test paired (all-gather -> atomic GEMM) and (atomic GEMM -> reduce-scatter) overlaps with + direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. + """ + os.environ["NVTE_AG_P2P_MULTI_ATOMIC"] = str(ag_type) + os.environ["NVTE_RS_STRIDED_ATOMIC"] = str(rs_type) + _run_gemm_with_overlap("AG", False, p2p, True, True, fp8_out, False) + + +@pytest.mark.parametrize( + "comm_type,fp8", + [ + ("AG", False), + ("RS", False), + ("RS", True), + ], + ids=[" ALL-GATHER - BF16 ", " REDUCE-SCATTER - BF16 ", " REDUCE-SCATTER - FP8 "], +) +def test_bulk_overlaps(comm_type, fp8): + """ + Test bulk overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. + """ + _run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False) + + +@pytest.mark.parametrize( + "layer_type", + [layer.__name__ for layer in TE_LAYERS], + ids=[(" " + layer.__name__ + " ") for layer in TE_LAYERS], +) +@pytest.mark.parametrize( + "fp8,fp8_init", + [ + (False, False), + (True, False), + (True, True), + ], + ids=[ + " BF16 GEMM - BF16 PARAMS ", + " FP8 GEMM - BF16 PARAMS ", + " FP8 GEMM - FP8 PARAMS ", + ], +) +def test_layers_with_overlap(layer_type, fp8, fp8_init): + """ + Test Transformer Engine layers with comm+GEMM overlap. + """ + _run_layer_with_overlap(layer_type, fp8, fp8_init) diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h index 88609b6ddb..bae46cffc9 100644 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -166,7 +166,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { // Initialize userbuf communicator if (!comm_created) { if (myrank == 0) { - printf("!!! [UB] Create UbufCommOverlap Communicator\n"); + printf("!!! [UB] Create Userbuffers Communicator\n"); } #ifdef NVTE_UB_WITH_MPI create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); @@ -184,16 +184,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { // Allocate and register extra userbuffers int ubuf_bytes = sample.numel() * sample.element_size(); - if (transformer_engine::getenv("UB_SKIPMC")) { - _ubuf = torch::zeros_like(sample); - _ubuf_ptr = _ubuf.data_ptr(); - _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, - _ub_comm, false); - } else { - _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, - _ub_comm, true); - _ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options()); - } + _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, + _ub_comm, true); + _ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options()); if (_ub_comm->myrank == 0) { printf("!!! [UB] Register UBuf %d\n", _ub_reg); @@ -264,6 +257,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int comm_type, at::Tensor rs_output) { + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; @@ -319,6 +313,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; int output_c_dim1 = _ubuf.size(1); output_tensor = torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options()); + _ub_comm->sms = ori_sms; return {D, output_tensor}; } // bulk_overlap @@ -336,6 +331,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, bool gemm_overlap, at::Tensor rs_output) { + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; @@ -352,7 +348,6 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); int *counter_ptr = reinterpret_cast(counter.data_ptr()); char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); - int ori_sms = _ub_comm->sms; // Catch up the default torch stream at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); @@ -388,7 +383,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { assert(_ubuf_scale_inv_initialized); float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - B_type, fp8_type, + D_type, fp8_type, reducescatter2_userbuff_strided_atomic_fp8( rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits, &counter_ptr[i], _ub_comm, (cudaStream_t)_stream_comm);); @@ -402,7 +397,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { assert(_ubuf_scale_inv_initialized); float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - B_type, fp8_type, + D_type, fp8_type, reducescatter2_userbuff_strided_multiatomic_fp8( rs_output_ptr, d_scale_inv_ptr, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits, counter_ptr, _ub_comm, (cudaStream_t)_stream_comm);); @@ -413,10 +408,8 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { } break; } else { + assert(_ubuf.element_size() != 1); consumer(counter_ptr, i, (cudaStream_t)_stream_comm); - // if (i == _num_splits-1) { - // _ub_comm->sms = UB_MAX_SM; - // } reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm); } @@ -447,6 +440,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { size_t workspaceSize, bool accumulate, bool use_split_accumulator, bool gemm_overlap, at::Tensor rs_output) { // Get GEMM dimensions + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; @@ -464,7 +458,6 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); - int ori_sms = _ub_comm->sms; // Catch up the default torch stream at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); @@ -517,7 +510,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { assert(_ubuf_scale_inv_initialized); float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - B_type, fp8_type, + D_type, fp8_type, reducescatter2_userbuff_stridedoutput_fp8( rs_output_ptr, d_scale_inv_ptr, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm);); @@ -541,7 +534,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { assert(_ubuf_scale_inv_initialized); float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - B_type, fp8_type, + D_type, fp8_type, reducescatter2_userbuff_stridedoutput_fp8( rs_output_ptr, d_scale_inv_ptr, _ub_reg, (_num_splits - 1) * output_chunk_size, m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm);); @@ -577,7 +570,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { assert(_ubuf_scale_inv_initialized); float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - B_type, fp8_type, + D_type, fp8_type, reducescatter2_userbuff_stridedoutput_fp8( rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * output_chunk_size, m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm);); @@ -682,7 +675,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { // Initialize userbuf communicator if (!comm_created) { if (myrank == 0) { - printf("!!! [UB] Create UbufP2PCommOverlap Communicator\n"); + printf("!!! [UB] Create Userbuffers Communicator\n"); } #ifdef NVTE_UB_WITH_MPI create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); @@ -708,19 +701,11 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ubuf_bytes = static_cast(ubuf_bytes / tp_size * (tp_size * 2 - 1)); num_ubuf_chunks = static_cast(tp_size * 2 - 1); } - if (transformer_engine::getenv("UB_SKIPMC")) { - _ubuf = torch::zeros({sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)}, - sample.options()); - _ubuf_ptr = _ubuf.data_ptr(); - _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, - _ub_comm, false); - } else { - _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, - _ub_comm, true); - _ubuf = - torch::from_blob(_ubuf_ptr, {sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)}, - sample.options()); - } + + _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, + _ub_comm, true); + _ubuf = torch::from_blob( + _ubuf_ptr, {sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)}, sample.options()); if (_ub_comm->myrank == 0) { printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg); } @@ -728,9 +713,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { // Create tensor chunks for easy management char *ubuf_byte_ptr = reinterpret_cast(_ubuf.data_ptr()); for (int i = 0; i < num_ubuf_chunks; i++) { - torch::Tensor ubuf_chunk = torch::from_blob( - ubuf_byte_ptr, {sample.size(0) / tp_size, sample.size(1)}, sample.options()); - _ubufs.push_back(ubuf_chunk); + auto ubuf_chunk = torch::from_blob(ubuf_byte_ptr, {sample.size(0) / tp_size, sample.size(1)}, + sample.options()); + _ubufs.push_back(std::move(ubuf_chunk)); ubuf_byte_ptr += ubuf_chunk_bytes; } @@ -769,6 +754,8 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC"); if (_rank == 0 && env_p != nullptr) { if (env_p[0] == '1') { + _use_ce = 0; + _ub_comm->push = 1; printf("!!userbuffers_sendrecv_multi_atomic_shuffle\n"); } } @@ -818,6 +805,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; @@ -866,6 +854,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC"); if (env_p != nullptr && env_p[0] == '1') { if (i == 0) { + _ub_comm->use_ce = 0; userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes, _ub_comm, _next_rank, _prev_rank, _tp_size, counter_ptr, true, (cudaStream_t)_stream_recv); @@ -906,6 +895,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { n_chunk * m * D.element_size(), cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main)); // Return the last N rows of D_buffer + _ub_comm->sms = ori_sms; torch::Tensor D_return = D_buffer.narrow(0, n_chunk, n); return D_return; } // atomic_gemm_overlap_ag @@ -926,6 +916,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; @@ -1078,6 +1069,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); at::cuda::setCurrentCUDAStream(stream_main); + _ub_comm->sms = ori_sms; return D; } // split_overlap_ag @@ -1094,6 +1086,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor rs_output) { + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; @@ -1149,7 +1142,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - B_type, fp8_type, + D_type, fp8_type, reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size, _ubufs[0].numel(), (cudaStream_t)stream_main);); } else { @@ -1157,6 +1150,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); torch::sum_out(rs_output, reduce_buf, 0); } + _ub_comm->sms = ori_sms; } /* @@ -1171,6 +1165,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor rs_output) { + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; @@ -1245,7 +1240,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - B_type, fp8_type, + D_type, fp8_type, reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size, _ubufs[0].numel(), (cudaStream_t)stream_main);); } else { @@ -1259,6 +1254,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { } NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); + _ub_comm->sms = ori_sms; } /* diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu index 03a1a6a3df..0cd2a0253b 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu +++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu @@ -1861,6 +1861,14 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) } +template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e5m2>( + void *output, float *scale, const int handler, const int offset, const int rowelements, + const int colelements, const int strideelements, communicator *comm, cudaStream_t stream); + +template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>( + void *output, float *scale, const int handler, const int offset, const int rowelements, + const int colelements, const int strideelements, communicator *comm, cudaStream_t stream); + template void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset, const int elements, communicator *comm, cudaStream_t stream) { diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index cbcda20fe8..651a1a4c1a 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -107,7 +107,7 @@ def initialize_ub( world_size = torch.distributed.get_world_size(mpi_group) local_rank = world_rank % tp_size local_size = tp_size - node_id = world_rank // tp_size + self_node_idx = world_rank // tp_size num_nodes = world_size // tp_size ub_callbacks = tex.UbufBootstrapCallbacks() else: @@ -127,13 +127,6 @@ def initialize_ub( world_rank = torch.distributed.get_rank(world_group) world_size = torch.distributed.get_world_size(world_group) - if world_rank == 0: - print( - f'!!! [NVTE] Bootstrapping Userbuffers with backend="{bootstrap_backend}"\n', - end="", - flush=True, - ) - # Construct an intra-node communicator based on global ranks that share the same hostname # NOTE: If the user specified a valid network interface for NCCL or GLOO, use the host # address on that interface instead of the hostname. This can help avoid issues when @@ -157,28 +150,41 @@ def initialize_ub( hostnames = [None for _ in range(world_size)] torch.distributed.all_gather_object(hostnames, hostname, world_group) - intra_node_ranks = [] - for i, host in enumerate(hostnames): - if host == hostname: - intra_node_ranks.append(i) - if len(intra_node_ranks) == world_size: + unique_hosts = [] + for host in hostnames: + if host not in unique_hosts: + unique_hosts.append(host) + num_nodes = len(unique_hosts) + + if num_nodes > 1: + ranks_per_node_list = [[] for _ in range(num_nodes)] + self_node_idx = -1 + for i, host in enumerate(hostnames): + node_idx = unique_hosts.index(host) + ranks_per_node_list[node_idx].append(i) + if host == hostname: + self_node_idx = node_idx + assert self_node_idx >= 0, "Internal TE error!" + + intra_node_group, _ = torch.distributed.new_subgroups_by_enumeration( + ranks_per_node_list, backend=bootstrap_backend + ) + local_rank = torch.distributed.get_rank(intra_node_group) + local_size = torch.distributed.get_world_size(intra_node_group) + intra_node_ranks = torch.distributed.get_process_group_ranks(intra_node_group) + + else: + self_node_idx = 0 intra_node_group = world_group local_rank = world_rank local_size = world_size intra_node_ranks = list(range(world_size)) - else: - intra_node_group = torch.distributed.new_group( - backend=bootstrap_backend, ranks=intra_node_ranks - ) - local_rank = torch.distributed.get_rank(intra_node_group) - local_size = torch.distributed.get_world_size(intra_node_group) - node_id = world_rank // local_size - num_nodes = world_size // local_size + if world_rank == 0: + print(f"!!! [UB] Number of physical nodes: {num_nodes}\n", end="", flush=True) if local_rank == 0: print( - f"!!! [NVTE] Number of physical nodes: {num_nodes}\n" - + f"!!! [NVTE] Global ranks on node {node_id}: {intra_node_ranks}\n", + f"!!! [UB] Global ranks on node {self_node_idx}: {intra_node_ranks}\n", end="", flush=True, ) @@ -293,7 +299,7 @@ def add_ub( world_size, # World size local_rank, # Rank within the node local_size, # Number of ranks/GPUs per node - node_id, # Node ID + self_node_idx, # Node ID num_nodes, # Number of nodes tp_size, # Tensor-parallel group size (may be different than local_size) num_sm, # Number of communication SMs @@ -313,7 +319,7 @@ def add_ub( world_size, # World size local_rank, # Rank within the node local_size, # Number of ranks/GPUs per node - node_id, # Node ID + self_node_idx, # Node ID num_nodes, # Number of nodes tp_size, # Tensor-parallel group size (may be different than local_size) num_sm, # Number of communication SMs From 8b3260599a0e2fab382717e58860bc7184717fbf Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Sat, 10 Aug 2024 07:01:17 +0800 Subject: [PATCH 41/72] [PyTorch] Reduce the CPU overheads of `GroupedLinear` (#1072) * use fused_multi_cast_transpose Signed-off-by: Xin Yao * fix input being empty tensor Signed-off-by: Xin Yao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * allocate output tensors in C++ Signed-off-by: Xin Yao * simplify code Signed-off-by: Xin Yao * avoid cudaGetDriverEntryPoint Signed-off-by: Xin Yao * reduce torch.Tensor() calls Signed-off-by: Xin Yao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update test Signed-off-by: Xin Yao --------- Signed-off-by: Xin Yao Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- tests/pytorch/test_numerics.py | 3 +- .../pytorch/cpp_extensions/gemm.py | 19 +++-- .../pytorch/cpp_extensions/transpose.py | 25 +++++- transformer_engine/pytorch/csrc/extensions.h | 5 ++ .../pytorch/csrc/extensions/pybind.cpp | 3 + .../pytorch/csrc/extensions/transpose.cu | 83 ++++++++++++++----- transformer_engine/pytorch/csrc/ts_fp8_op.cpp | 6 +- .../pytorch/module/grouped_linear.py | 46 +++++----- 8 files changed, 138 insertions(+), 52 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 7eed97a0ca..a219f24674 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1228,7 +1228,8 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False inp_hidden_states.retain_grad() m = config.seq_len // 16 - dist = torch.sort(torch.randint(0, m, (num_gemms - 1,))).values.tolist() + dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist() + dist.append(dist[-1]) # Manually add a zero m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist) m_splits = m_splits * 16 assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 38392a5795..8502f70491 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Python interface for GEMM extensions""" +import functools from typing import Optional, Tuple, Union, List import torch import transformer_engine_torch as tex @@ -13,6 +14,12 @@ __all__ = ["gemm", "fp8_gemm", "grouped_gemm", "fp8_grouped_gemm"] +@functools.lru_cache(maxsize=None) +def _empty_tensor() -> torch.Tensor: + """Get tensor with no entries and no data""" + return torch.Tensor() + + def fp8_gemm( A: torch.Tensor, A_scale_inv: torch.Tensor, @@ -39,7 +46,7 @@ def fp8_gemm( ) -> torch.Tensor: """TN layout GEMM with fp8 inputs.""" - empty_tensor = torch.Tensor() + empty_tensor = _empty_tensor() if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: assert fp8_meta_tensor is not None and out_index is not None assert_dim_for_fp8_exec(A) @@ -195,7 +202,7 @@ def gemm( assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported." transa = layout[0] == "T" transb = layout[1] == "T" - empty_tensor = torch.Tensor() + empty_tensor = _empty_tensor() fp8_index = -1 # dummy index if out is None: @@ -313,8 +320,8 @@ def grouped_gemm( transa = layout[0] == "T" transb = layout[1] == "T" num_gemms = len(A) - empty_tensor = torch.Tensor() - empty_tensors = [torch.Tensor()] * num_gemms + empty_tensor = _empty_tensor() + empty_tensors = [empty_tensor] * num_gemms if gelu and not grad: gelu_input = [ @@ -401,8 +408,8 @@ def fp8_grouped_gemm( """ num_gemms = len(A) - empty_tensor = torch.Tensor() - empty_tensors = [torch.Tensor()] * num_gemms + empty_tensor = _empty_tensor() + empty_tensors = [empty_tensor] * num_gemms if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: assert fp8_meta_tensor is not None and out_offset is not None for a, b in zip(A, B): diff --git a/transformer_engine/pytorch/cpp_extensions/transpose.py b/transformer_engine/pytorch/cpp_extensions/transpose.py index de83bcd7f5..d96b743b9e 100644 --- a/transformer_engine/pytorch/cpp_extensions/transpose.py +++ b/transformer_engine/pytorch/cpp_extensions/transpose.py @@ -3,7 +3,7 @@ # See LICENSE for license information. """Python interface for transpose extensions""" -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch import transformer_engine_torch as tex from ..constants import TE_DType @@ -13,6 +13,7 @@ "fp8_cast_transpose_fused", "fp8_cast_transpose_bgrad_fused", "fp8_cast_transpose_bgrad_dgelu_fused", + "fp8_multi_cast_transpose_fused", "fp8_transpose_bgrad_fused", ] @@ -118,3 +119,25 @@ def fp8_cast_transpose_bgrad_dgelu_fused( amax_offset=int(fp8_tensor), scale_inv_offset=int(fp8_tensor), ) + + +def fp8_multi_cast_transpose_fused( + input_list: List[torch.Tensor], + fp8_meta_tensor: tex.FP8TensorMeta, + scale_indices: List[int], + amax_indices: List[int], + scale_inv_indices: List[int], + otype: tex.DType, +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """Cast + Transpose with FP8 output""" + + return tex.fused_multi_cast_transpose_alloc( + input_list, + fp8_meta_tensor.scale, + fp8_meta_tensor.amax_history, + fp8_meta_tensor.scale_inv, + scale_indices, + amax_indices, + scale_inv_indices, + otype, + ) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index bd908e9336..cd5bda8b63 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -180,6 +180,11 @@ void fused_multi_cast_transpose(std::vector input_list, std::vector scale_inv_output_list, transformer_engine::DType otype); +std::tuple, std::vector> fused_multi_cast_transpose_alloc( + std::vector input_list, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + std::vector scale_indices, std::vector amax_indices, + std::vector scale_inv_indices, transformer_engine::DType otype); + at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype); void fp8_transpose_noalloc(at::Tensor input, at::Tensor output, transformer_engine::DType otype); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 89bce77ded..c97c66dd98 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -84,6 +84,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose, "Fused Multi-tensor Cast + Transpose", py::call_guard()); + m.def("fused_multi_cast_transpose_alloc", &fused_multi_cast_transpose_alloc, + "Fused Multi-tensor Cast + Transpose with allocating output tensors", + py::call_guard()); m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8", py::call_guard()); m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8", py::call_guard()); diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cu b/transformer_engine/pytorch/csrc/extensions/transpose.cu index 473954d099..56f6b56769 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cu +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cu @@ -75,7 +75,7 @@ std::vector fused_cast_transpose_bgrad(at::Tensor grad_output, at::T // Return immediately if tensors are empty if (M == 0 || N == 0) { - return {grad_bias, grad_output_cast, grad_output_transpose}; + return {grad_bias.zero_(), grad_output_cast, grad_output_transpose}; } // Get pointers for FP8 scale, amax, scale-inverse @@ -196,22 +196,21 @@ std::vector fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output, return {grad_bias, dgelu, dgelu_transpose}; } -void fused_multi_cast_transpose(std::vector input_list, - std::vector scale_list, - std::vector cast_output_list, - std::vector transposed_output_list, - std::vector amax_list, - std::vector scale_inv_list, - transformer_engine::DType otype) { +void fused_multi_cast_transpose_base(std::vector input_list, + std::vector scale_dptr_list, + std::vector cast_output_list, + std::vector transposed_output_list, + std::vector amax_dptr_list, + std::vector scale_inv_dptr_list, + transformer_engine::DType otype) { using namespace transformer_engine; // Extract properties from PyTorch tensors - std::vector input_dptr_list, scale_dptr_list, cast_output_dptr_list, - transposed_output_dptr_list, amax_dptr_list, scale_inv_dptr_list; - std::vector> input_shape_list, scale_shape_list, cast_output_shape_list, - transposed_output_shape_list, amax_shape_list, scale_inv_shape_list; - std::vector input_type_list, scale_type_list, cast_output_type_list, - transposed_output_type_list, amax_type_list, scale_inv_type_list; + std::vector input_dptr_list, cast_output_dptr_list, transposed_output_dptr_list; + std::vector> input_shape_list, cast_output_shape_list, + transposed_output_shape_list; + std::vector input_type_list, cast_output_type_list, + transposed_output_type_list; auto extract_tensor_props_skip_dtype = [](at::Tensor& tensor, std::vector& dptr_list, std::vector>& shape_list) { dptr_list.push_back(tensor.data_ptr()); @@ -232,20 +231,14 @@ void fused_multi_cast_transpose(std::vector input_list, }; for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { extract_tensor_props(input_list[tensor_id], input_dptr_list, input_shape_list, input_type_list); - extract_tensor_props(scale_list[tensor_id], scale_dptr_list, scale_shape_list, scale_type_list); extract_tensor_props_skip_dtype(cast_output_list[tensor_id], cast_output_dptr_list, cast_output_shape_list); cast_output_type_list.push_back(otype); extract_tensor_props_skip_dtype(transposed_output_list[tensor_id], transposed_output_dptr_list, transposed_output_shape_list); transposed_output_type_list.push_back(otype); - extract_tensor_props(amax_list[tensor_id], amax_dptr_list, amax_shape_list, amax_type_list); - extract_tensor_props(scale_inv_list[tensor_id], scale_inv_dptr_list, scale_inv_shape_list, - scale_inv_type_list); } - transformer_engine::TensorWrapper workspace; - // Construct TE tensors std::vector nvte_input_list, nvte_cast_output_list, nvte_transposed_output_list; std::vector tensor_wrappers; @@ -257,6 +250,7 @@ void fused_multi_cast_transpose(std::vector input_list, return tensor_wrappers.back().data(); }; for (size_t i = 0; i < input_dptr_list.size(); ++i) { + if (input_dptr_list[i] == nullptr) continue; nvte_input_list.emplace_back(make_tensor(input_dptr_list[i], input_shape_list[i], input_type_list[i], nullptr, nullptr, nullptr)); nvte_cast_output_list.emplace_back( @@ -280,6 +274,55 @@ void fused_multi_cast_transpose(std::vector input_list, at::cuda::getCurrentCUDAStream()); } +void fused_multi_cast_transpose(std::vector input_list, + std::vector scale_list, + std::vector cast_output_list, + std::vector transposed_output_list, + std::vector amax_list, + std::vector scale_inv_list, + transformer_engine::DType otype) { + std::vector scale_dptr_list, amax_dptr_list, scale_inv_dptr_list; + for (size_t i = 0; i < scale_list.size(); ++i) { + scale_dptr_list.push_back(scale_list[i].data_ptr()); + amax_dptr_list.push_back(amax_list[i].data_ptr()); + scale_inv_dptr_list.push_back(scale_inv_list[i].data_ptr()); + } + + fused_multi_cast_transpose_base(input_list, scale_dptr_list, cast_output_list, + transposed_output_list, amax_dptr_list, scale_inv_dptr_list, + otype); +} + +std::tuple, std::vector> fused_multi_cast_transpose_alloc( + std::vector input_list, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + std::vector scale_indices, std::vector amax_indices, + std::vector scale_inv_indices, transformer_engine::DType otype) { + using namespace transformer_engine; + + std::vector cast_output_list; + std::vector transposed_output_list; + std::vector scale_dptr_list, amax_dptr_list, scale_inv_dptr_list; + for (size_t i = 0; i < input_list.size(); ++i) { + auto input_i = input_list[i]; + // construct cast output tensors + auto cast_output_i = allocateTorchTensor(input_i.size(0), input_i.size(1), DType::kByte); + cast_output_list.push_back(cast_output_i); + // construct transposed output tensors + auto transposed_output_i = allocateTorchTensor(input_i.size(1), input_i.size(0), DType::kByte); + transposed_output_list.push_back(transposed_output_i); + // construct amax/scale/scale_inv dptr lists + amax_dptr_list.push_back(getDataPtr(amax, amax_indices[i])); + scale_dptr_list.push_back(getDataPtr(scale, scale_indices[i])); + scale_inv_dptr_list.push_back(getDataPtr(scale_inv, scale_inv_indices[i])); + } + + fused_multi_cast_transpose_base(input_list, scale_dptr_list, cast_output_list, + transposed_output_list, amax_dptr_list, scale_inv_dptr_list, + otype); + + return std::make_tuple(std::move(cast_output_list), std::move(transposed_output_list)); +} + at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype) { using namespace transformer_engine; diff --git a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp index e1bcfecc13..8515092ae0 100644 --- a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp +++ b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp @@ -258,7 +258,8 @@ at::Tensor te_gemm_ts(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_te // Set an external SM Margin to all the GEMMs. // This comes in handy when DP is overlapped with GEMMs - const int sm_count = transformer_engine::cuda::sm_count(); + const int device_id = at::cuda::current_device(); + const int sm_count = transformer_engine::cuda::sm_count(device_id); int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; @@ -293,7 +294,8 @@ std::vector te_grouped_gemm_ts( // Set an external SM Margin to all the GEMMs. // This comes in handy when DP is overlapped with GEMMs - const int sm_count = transformer_engine::cuda::sm_count(); + const int device_id = at::cuda::current_device(); + const int sm_count = transformer_engine::cuda::sm_count(device_id); int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); te_grouped_gemm(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B, B_scale_inverse, diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index c55225eed9..a91ff5c361 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -34,7 +34,7 @@ from ..cpp_extensions import ( cast_to_fp8, fp8_cast_transpose_bgrad_fused, - fp8_cast_transpose_fused, + fp8_multi_cast_transpose_fused, fp8_grouped_gemm, grouped_gemm, ) @@ -82,12 +82,12 @@ def forward( activation_dtype: torch.dtype, parallel_mode: Union[str, None], is_grad_enabled: bool, + weights_fp8: List[Union[Float8Tensor, None]], *weights_and_biases: Union[Float8Tensor, torch.Tensor, None], ) -> torch.Tensor: num_gemms = len(m_splits) weights = weights_and_biases[:num_gemms] - weights_fp8 = weights_and_biases[num_gemms : 2 * num_gemms] - biases = weights_and_biases[2 * num_gemms :] + biases = weights_and_biases[num_gemms:] # Make sure input dimensions are compatible in_features = weights[0].shape[-1] @@ -113,15 +113,15 @@ def forward( and not sequence_parallel ): # FP8 input for forward, FP8 input transpose for backward wgrad - for i in range(num_gemms): - mat, mat_t = fp8_cast_transpose_fused( - inputmats_no_fp8[i], - fp8_meta["scaling_fwd"], - _GEMM_INPUT + i, - fp8_dtype_forward, - ) - inputmats.append(mat) - inputmats_t.append(mat_t) + indices = list(range(_GEMM_INPUT, _GEMM_INPUT + num_gemms)) + inputmats, inputmats_t = fp8_multi_cast_transpose_fused( + inputmats_no_fp8, + fp8_meta["scaling_fwd"], + indices, # scale_indices + indices, # amax_indices + indices, # scale_inv_indices + fp8_dtype_forward, + ) else: # FP8 input for forward inputmats = [ @@ -308,13 +308,15 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) else: if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - for i in range(ctx.num_gemms): - grad_output_c[i], grad_output_t[i] = fp8_cast_transpose_fused( - grad_output_mats[i], - ctx.fp8_meta["scaling_bwd"], - _GRAD_OUTPUT + i, - fp8_dtype_backward, - ) + indices = list(range(_GRAD_OUTPUT, _GRAD_OUTPUT + ctx.num_gemms)) + grad_output_c, grad_output_t = fp8_multi_cast_transpose_fused( + grad_output_mats, + ctx.fp8_meta["scaling_bwd"], + indices, # scale_indices + indices, # amax_indices + indices, # scale_inv_indices + fp8_dtype_backward, + ) else: for i in range(ctx.num_gemms): grad_output_c[i] = cast_to_fp8( @@ -334,7 +336,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.requires_dgrad: if ctx.fp8: dgrad = torch.empty( - (sum(ctx.m_splits), weights_fp8[i].size(1)), + (sum(ctx.m_splits), weights_fp8[0].size(1)), dtype=ctx.activation_dtype, device=grad_output.device, ) @@ -487,8 +489,8 @@ def handle_custom_ddp_from_mcore(w, wgrad): None, # activation_dtype None, # parallel_mode None, # is_grad_enabled + None, # weights_fp8 *wgrad_list, - *([None] * ctx.num_gemms), # weights_fp8 *grad_biases, ) @@ -799,8 +801,8 @@ def forward( self.activation_dtype, self.parallel_mode, torch.is_grad_enabled(), + weight_tensors_fp8, *weight_tensors, - *weight_tensors_fp8, *bias_tensors, ) out = linear_fn(*args) From e0aa7992c549395e88762ce337fa4e1ae988e2bd Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Fri, 9 Aug 2024 18:04:11 -0700 Subject: [PATCH 42/72] [PyTorch] Branching operations (#1027) * Add op for in-place add Signed-off-by: Tim Moon * Add op for in-place add Signed-off-by: Tim Moon * Add op that adds extra output to fuser Signed-off-by: Tim Moon * Add fused op for GEMM+bias+add Signed-off-by: Tim Moon * Add fused op for dgrad+add Signed-off-by: Tim Moon * Add documentation Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix linter warnings Signed-off-by: Tim Moon * Review suggestions from @ptrendx Output tensor dtype and device take precedence over weight tensor in linear functional API. Move some index calculation to fuser constructor. Avoid some unnecessary dereferences. Signed-off-by: Tim Moon * Debug test failures Signed-off-by: Tim Moon * Update transformer_engine/pytorch/ops/fuser.py Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --------- Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/test_fusible_ops.py | 381 +++++++++++++++++- transformer_engine/pytorch/ops/__init__.py | 2 + .../pytorch/ops/basic/__init__.py | 2 + .../pytorch/ops/basic/add_in_place.py | 79 ++++ .../pytorch/ops/basic/basic_linear.py | 205 ++++++++-- .../pytorch/ops/basic/make_extra_output.py | 80 ++++ .../ops/{fused_forward => fused}/__init__.py | 10 +- .../pytorch/ops/fused/backward_linear_add.py | 156 +++++++ .../forward_linear_bias_activation.py} | 13 +- .../ops/fused/forward_linear_bias_add.py | 196 +++++++++ transformer_engine/pytorch/ops/fuser.py | 138 ++++++- transformer_engine/pytorch/ops/op.py | 115 ++++-- transformer_engine/pytorch/ops/sequential.py | 46 ++- 13 files changed, 1300 insertions(+), 123 deletions(-) create mode 100644 transformer_engine/pytorch/ops/basic/add_in_place.py create mode 100644 transformer_engine/pytorch/ops/basic/make_extra_output.py rename transformer_engine/pytorch/ops/{fused_forward => fused}/__init__.py (52%) create mode 100644 transformer_engine/pytorch/ops/fused/backward_linear_add.py rename transformer_engine/pytorch/ops/{fused_forward/linear_bias_activation.py => fused/forward_linear_bias_activation.py} (93%) create mode 100644 transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 9aab3b2702..3523e1cda5 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -15,8 +15,10 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.ops._common import is_float8_tensor -from transformer_engine.pytorch.ops.fused_forward import ( +from transformer_engine.pytorch.ops.fused import ( + BackwardLinearAdd, ForwardLinearBiasActivation, + ForwardLinearBiasAdd, ) from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine_torch as tex @@ -84,15 +86,14 @@ def make_reference_and_test_tensors( """ ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) + test = ref.to(device=test_device, dtype=test_dtype) if test_is_fp8: - test = Float8Tensor.to_float8(ref) + test = Float8Tensor.to_float8(test) test._transpose = test._data.reshape(-1, test.size(-1)).transpose(0, 1) test._transpose = test._transpose.contiguous() test._transpose_invalid = False - else: - test = ref.to(device=test_device, dtype=test_dtype) - if test.data_ptr() == ref.data_ptr(): - test = test.clone() + elif test.data_ptr() == ref.data_ptr(): + test = test.clone() ref.copy_(test) ref.requires_grad_(requires_grad) test.requires_grad_(requires_grad) @@ -320,14 +321,13 @@ def setup_class(cls) -> None: torch.manual_seed(seed) torch.cuda.manual_seed(seed) - @pytest.mark.parametrize("in_shape", ((1,),)) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", ("cuda", "cpu")) @pytest.mark.parametrize("fp8", (False, True)) def test_identity( self, *, - in_shape: Iterable[int], + in_shape: Iterable[int] = (1,), dtype: torch.dtype, device: torch.device, fp8: bool, @@ -737,6 +737,123 @@ def test_linear( db_test = op.bias.grad.to(dtype=torch.float64, device="cpu") torch.testing.assert_close(db_test, b_ref.grad, **tols) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("device", ("cuda", "cpu")) + @pytest.mark.parametrize("fp8", (False, True)) + def test_add_in_place( + self, + *, + in_shape: Iterable[int] = (1,), + dtype: torch.dtype, + device: torch.device, + fp8: bool, + ) -> None: + + # Skip invalid configurations + if fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) + if fp8 and torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + + # Random data + x1_ref, x1_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8, + ) + x2_ref, x2_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = x2_ref.detach() + y_ref += x1_ref + dx1_ref = dy_ref + dx2_ref = dy_ref + + # Implementation with fusible operation + op = te_ops.AddInPlace() + y_test = op(x1_test, x2_test) + y_test.backward(dy_test) + + # Check results + tols = dtype_tols(dtype) + if fp8: + tols = dtype_tols(x1_test._fp8_dtype) + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu") + dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx1_test, dx1_ref, rtol=0, atol=0) + torch.testing.assert_close(dx2_test, dx2_ref, rtol=0, atol=0) + + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("device", ("cuda", "cpu")) + @pytest.mark.parametrize("fp8", (False, True)) + def test_make_extra_output( + self, + *, + in_shape: Iterable[int] = (1,), + dtype: torch.dtype, + device: torch.device, + fp8: bool, + ) -> None: + + # Skip invalid configurations + if fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) + if fp8 and torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8, + ) + dy1_ref, dy1_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + dy2_ref, dy2_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y1_ref = x_ref + y2_ref = x_ref + (y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward() + + # Implementation with fusible operation + op = te_ops.MakeExtraOutput() + y1_test, y2_test = op(x_test) + (y1_test * dy1_test + y2_test * dy2_test).sum().backward() + + # Check results + tols = dtype_tols(dtype) + y1_test = y1_test.to(dtype=torch.float64, device="cpu") + y2_test = y2_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y1_test, y1_ref, rtol=0, atol=0) + torch.testing.assert_close(y2_test, y2_ref, rtol=0, atol=0) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + class TestFusedOps: """Tests for fused operations""" @@ -754,7 +871,7 @@ def setup_class(cls) -> None: @pytest.mark.parametrize("fp8_compute", (False, True)) @pytest.mark.parametrize("fp8_input", (False, True)) @pytest.mark.parametrize("fp8_weight", (False, True)) - def test_linear_bias_activation( + def test_forward_linear_bias_activation( self, *, bias: bool = True, @@ -766,7 +883,7 @@ def test_linear_bias_activation( fp8_input: bool, fp8_weight: bool, ) -> None: - """GEMM + bias + activation""" + """Forward GEMM + bias + activation""" # Make input and weight shapes consistent out_features, in_features = weight_shape @@ -951,3 +1068,247 @@ def test_fp8_linear( torch.testing.assert_close(dx_test, x_ref.grad, **tols) torch.testing.assert_close(dw0_test, w0_ref.grad, **tols) torch.testing.assert_close(dw1_test, w1_ref.grad, **tols) + + @pytest.mark.parametrize("bias", (False, True)) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("fp8_compute", (False, True)) + def test_forward_linear_bias_add( + self, + *, + bias: bool, + weight_shape: tuple[int, int] = (16, 16), + in_shape: Iterable[int] = (16, -1), + dtype: torch.dtype, + device: torch.device = "cuda", + fp8_compute: bool, + fp8_input: bool = False, + fp8_weight: bool = False, + fp8_output: bool = False, + ) -> None: + """Forward GEMM + bias + add""" + + # Make input and weight shapes consistent + out_features, in_features = weight_shape + in_shape = list(in_shape)[:-1] + [in_features] + out_shape = in_shape[:-1] + [out_features] + + # Skip invalid configurations + if fp8_input or fp8_weight or fp8_output or fp8_compute: + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + if fp8_compute: + if ( + math.prod(in_shape[:-1]) % 16 != 0 + or in_features % 16 != 0 + or out_features % 16 != 0 + ): + pytest.skip("FP8 GEMMs require dims that are divisible by 16") + if fp8_output and not fp8_compute: + pytest.skip("FP8 output requires FP8 compute") + if fp8_compute and dtype not in (torch.float16, torch.bfloat16): + pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") + + # Random data + x1_ref, x1_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=(fp8_compute or fp8_input), + ) + w_ref, w_test = make_reference_and_test_tensors( + (out_features, in_features), + test_dtype=dtype, + test_device=device, + test_is_fp8=(fp8_compute or fp8_weight), + ) + b_ref, b_test = None, None + if bias: + b_ref, b_test = make_reference_and_test_tensors( + out_features, + test_dtype=dtype, + test_device=device, + ) + x2_ref, x2_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8_output, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = torch.nn.functional.linear(x1_ref, w_ref, bias=b_ref) + x2_ref + y_ref.backward(dy_ref) + + # Implementation with fusible operations + with te.fp8_model_init(enabled=fp8_weight): + model = te_ops.Sequential( + te_ops.Linear( + in_features, + out_features, + bias=bias, + device=device, + dtype=dtype, + ), + te_ops.AddInPlace(), + ) + with torch.no_grad(): + model[0].weight.copy_(w_test) + if bias: + model[0].bias.copy_(b_test) + del w_test + del b_test + with te.fp8_autocast(enabled=fp8_compute): + y_test = model(x1_test, x2_test) + y_test.backward(dy_test) + + # Check that forward operations have been fused + forward_ops = model._module_groups[0]._forward_ops + assert len(forward_ops) == 1 + assert isinstance(forward_ops[0][0], ForwardLinearBiasAdd) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if fp8_compute: + tols = dtype_tols( + model[0].weight._fp8_dtype + if is_float8_tensor(model[0].weight) + else tex.DType.kFloat8E4M3 + ) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu") + dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx1_test, x1_ref.grad, **tols) + torch.testing.assert_close(dx2_test, x2_ref.grad, **tols) + torch.testing.assert_close(dw_test, w_ref.grad, **tols) + if bias: + db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(db_test, b_ref.grad, **tols) + + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("fp8_compute", (False, True)) + def test_backward_linear_add( + self, + *, + weight_shape: tuple[int, int] = (16, 16), + in_shape: Iterable[int] = (16, -1), + dtype: torch.dtype, + device: torch.device = "cuda", + fp8_compute: bool, + fp8_input: bool = False, + fp8_weight: bool = False, + fp8_output: bool = False, + ) -> None: + """Backward dgrad GEMM + add""" + + # Make input and weight shapes consistent + out_features, in_features = weight_shape + in_shape = list(in_shape)[:-1] + [in_features] + out_shape = in_shape[:-1] + [out_features] + + # Skip invalid configurations + if fp8_input or fp8_weight or fp8_output or fp8_compute: + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + if fp8_compute: + if ( + math.prod(in_shape[:-1]) % 16 != 0 + or in_features % 16 != 0 + or out_features % 16 != 0 + ): + pytest.skip("FP8 GEMMs require dims that are divisible by 16") + if fp8_output and not fp8_compute: + pytest.skip("FP8 output requires FP8 compute") + if fp8_compute and dtype not in (torch.float16, torch.bfloat16): + pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=(fp8_compute or fp8_input), + ) + w_ref, w_test = make_reference_and_test_tensors( + (out_features, in_features), + test_dtype=dtype, + test_device=device, + test_is_fp8=(fp8_compute or fp8_weight), + ) + dy1_ref, dy1_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + dy2_ref, dy2_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y1_ref = torch.nn.functional.linear(x_ref, w_ref) + y2_ref = x_ref + (y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward() + + # Implementation with fusible operations + with te.fp8_model_init(enabled=fp8_weight): + model = te_ops.Sequential( + te_ops.MakeExtraOutput(), + te_ops.Linear( + in_features, + out_features, + bias=False, + device=device, + dtype=dtype, + ), + ) + with torch.no_grad(): + model[1].weight.copy_(w_test) + del w_test + with te.fp8_autocast(enabled=fp8_compute): + y1_test, y2_test = model(x_test) + (y1_test * dy1_test + y2_test * dy2_test).sum().backward() + + # Check that backward operations have been fused + backward_ops = model._module_groups[0]._backward_ops + assert len(backward_ops) == 1 + assert isinstance(backward_ops[0][0], BackwardLinearAdd) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if fp8_compute: + tols = dtype_tols( + model[1].weight._fp8_dtype + if is_float8_tensor(model[1].weight) + else tex.DType.kFloat8E4M3 + ) + + # Check results + y1_test = y1_test.to(dtype=torch.float64, device="cpu") + y2_test = y2_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y1_test, y1_ref, **tols) + torch.testing.assert_close(y2_test, y2_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + torch.testing.assert_close(dw_test, w_ref.grad, **tols) diff --git a/transformer_engine/pytorch/ops/__init__.py b/transformer_engine/pytorch/ops/__init__.py index ec3d4fd315..f437f877b4 100644 --- a/transformer_engine/pytorch/ops/__init__.py +++ b/transformer_engine/pytorch/ops/__init__.py @@ -9,11 +9,13 @@ """ from transformer_engine.pytorch.ops.basic import ( + AddInPlace, AllGather, AllReduce, BasicLinear, Bias, Identity, + MakeExtraOutput, ReduceScatter, Reshape, ) diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index 3621910c8b..1003cc0337 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -4,10 +4,12 @@ """Single tensor operations supported by the operation fuser.""" +from .add_in_place import AddInPlace from .all_gather import AllGather from .all_reduce import AllReduce from .basic_linear import BasicLinear from .bias import Bias from .identity import Identity +from .make_extra_output import MakeExtraOutput from .reduce_scatter import ReduceScatter from .reshape import Reshape diff --git a/transformer_engine/pytorch/ops/basic/add_in_place.py b/transformer_engine/pytorch/ops/basic/add_in_place.py new file mode 100644 index 0000000000..041888f5d7 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/add_in_place.py @@ -0,0 +1,79 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fusible operation for in-place add.""" + +from __future__ import annotations +from collections.abc import Iterable +from typing import Any, Optional + +import torch + +from transformer_engine.pytorch.ops.op import ( + BasicOperation, + OperationContext, +) + + +class AddInPlace(BasicOperation): + """Add in-place + + This operation requires an extra tensor input to the operation + fuser. The main input is added in-place to the extra input, and a + view of the extra input is output. + + This operation is considered an advanced feature and most users + are discouraged from using it. In-place operations break some + autograd assumptions and they can result in subtle, esoteric bugs. + + Compare to `MakeExtraOutput`, which does a similar operation in + the backward pass. + + """ + + # Operation expects buffer for output tensor + num_extra_inputs: int = 1 + + def op_forward(self, *args, **kwargs) -> None: + raise RuntimeError( + "{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_forward` instead of `op_forward`." + ) + + def op_backward(self, *args, **kwargs) -> None: + raise RuntimeError( + "{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_backward` instead of `op_backward`." + ) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + basic_op_prev_ops: list[Optional[BasicOperation]], + basic_op_next_ops: list[Optional[BasicOperation]], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + output = basic_op_extra_inputs[0][0].detach() + output += input_ + return output, [()] + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + Iterable[Iterable[Optional[torch.Tensor]]], + Iterable[Iterable[Optional[torch.Tensor]]], + ]: + return grad_output, [], [(grad_output,)] diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 49923e7af8..826807d1c0 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -12,7 +12,11 @@ import torch -from transformer_engine.pytorch.cpp_extensions import fp8_gemm, gemm +from transformer_engine.pytorch.cpp_extensions import ( + FP8TensorMeta, + fp8_gemm, + gemm, +) from transformer_engine.pytorch.distributed import ( CudaRNGStatesTracker, gather_along_first_dim, @@ -32,6 +36,7 @@ canonicalize_device, canonicalize_dtype, convert_tensor, + devices_match, is_float8_tensor, reshape, ) @@ -308,6 +313,8 @@ def _functional_forward( bias: Optional[torch.Tensor] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, + out: Optional[torch.Tensor] = None, + accumulate_into_out: bool = False, tensor_parallel_mode: Optional[str] = None, tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, sequence_parallel: bool = False, @@ -330,6 +337,10 @@ def _functional_forward( Tensor device dtype: torch.dtype, default = default dtype Tensor datatype + out: torch.Tensor, optional + Output tensor + accumulate_into_out: bool, default = `False` + Add result to output tensor instead of overwriting tensor_parallel_mode: {`None`, "column", "row"}, default = `None` Mode for tensor parallelism tensor_parallel_group: torch.distributed.ProcessGroup, default = world group @@ -365,19 +376,25 @@ def _functional_forward( # Check device if device is None: - device = weight.device + device = weight.device if out is None else out.device device = canonicalize_device(device) if device.type != "cuda": raise ValueError(f"Only CUDA devices are supported (got {device})") + if out is not None and not devices_match(out.device, device): + raise ValueError( + f"Output tensor has invalid device (expected {device}, got {out.device})" + ) # Check datatype if dtype is None: - dtype = weight.dtype + dtype = weight.dtype if out is None else out.dtype dtype = canonicalize_dtype(dtype) if dtype not in (torch.float32, torch.float16, torch.bfloat16): raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})") + if out is not None and out.dtype != dtype: + raise ValueError(f"Output tensor has invalid dtype (expected {dtype}, got {out.dtype})") - # Check tensor dims + # Check input tensor dims input_dims = tuple(input.size()) weight_dims = tuple(weight.size()) if len(weight_dims) != 2: @@ -389,6 +406,32 @@ def _functional_forward( "are not compatible" ) + # Check output tensor dims + output_dims: list[int] + if out is None: + output_dims = list(input_dims) + output_dims[0] = -1 + output_dims[-1] = weight_dims[0] + else: + output_dims = list(out.size()) + if len(output_dims) == 0 or weight_dims[0] != output_dims[-1]: + raise ValueError( + f"Output tensor (shape={output_dims}) " + f"and weight tensor (shape={weight_dims}) " + "are not compatible" + ) + + # Check if accumulating into output tensor + if accumulate_into_out: + if out is None: + raise ValueError( + "Attempted to accumulate into output tensor without providing output tensor" + ) + if tensor_parallel_mode == "row": + raise ValueError( + "Accumulating into output tensor is not supported with row tensor parallelism" + ) + # Check if FP8 is enabled if with_fp8_compute: if input_fp8_meta is None and not is_float8_tensor(input): @@ -399,9 +442,18 @@ def _functional_forward( input_fp8_meta = None weight_fp8_meta = None output_fp8_meta = None - with_fp8_output = ( - with_fp8_compute and tensor_parallel_mode != "row" and output_fp8_meta is not None - ) + with_fp8_output = with_fp8_compute and tensor_parallel_mode != "row" + if out is None: + with_fp8_output = with_fp8_output and output_fp8_meta is not None + else: + if is_float8_tensor(out): + if not with_fp8_output: + raise ValueError( + "Output tensor is a Float8Tensor, but FP8 output is not supported" + ) + out._reset_caches() + else: + with_fp8_output = False # Check input tensor x_local = reshape( @@ -476,7 +528,9 @@ def _functional_forward( # Construct output tensor y = None - if with_fp8_output: + if out is not None: + y = reshape(out, (-1, output_dims[-1])) + elif with_fp8_output: fp8_dtype = get_fp8_te_dtype( output_fp8_meta["recipe"], fprop_tensor=True, @@ -506,19 +560,31 @@ def _functional_forward( x_async = None if with_fp8_compute: kwargs = dict( + accumulate=accumulate_into_out, out=y, bias=b, use_bias=(b is not None), ) if with_fp8_output: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=y._fp8_meta_forward, - ) + if y._fp8_meta is None: + # Hackily create FP8TensorMeta if needed + fp8_meta = FP8TensorMeta() + fp8_meta.scale = y._scale_inv.reciprocal() + fp8_meta.amax_history = torch.empty(1, 1, dtype=torch.float32, device=device) + fp8_meta.scale_inv = y._scale_inv + fp8_meta_index = 0 + else: + # Get FP8TensorMeta from Float8Tensor + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=y._fp8_meta_forward, + ) + fp8_meta = y._fp8_meta[fp8_meta_key] + fp8_meta_index = y._fp8_meta_index kwargs.update( dict( out=y._data, - out_index=y._fp8_meta_index, - fp8_meta_tensor=y._fp8_meta[fp8_meta_key], + out_index=fp8_meta_index, + fp8_meta_tensor=fp8_meta, D_dtype=y._fp8_dtype, ) ) @@ -541,6 +607,7 @@ def _functional_forward( x, y.dtype, get_workspace(), + accumulate=accumulate_into_out, out=y, bias=b, use_bias=(b is not None), @@ -553,13 +620,11 @@ def _functional_forward( else: torch.distributed.all_reduce(y, group=tensor_parallel_group) - # Reshape output tensor - output_dims = list(input_dims) - output_dims[0] = -1 - output_dims[-1] = weight_dims[0] - output = reshape(y, output_dims) + # Reshape output tensor if needed + if out is None: + out = reshape(y, output_dims) - return output, x_local, w + return out, x_local, w @staticmethod def _functional_backward( @@ -573,6 +638,10 @@ def _functional_backward( weight_requires_grad: bool = True, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, + grad_weight: Optional[torch.Tensor] = None, + accumulate_into_grad_weight: bool = False, + grad_input: Optional[torch.Tensor] = None, + accumulate_into_grad_input: bool = False, tensor_parallel_mode: Optional[str] = None, tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, sequence_parallel: bool = False, @@ -581,8 +650,6 @@ def _functional_backward( weight_fp8_meta: Optional[dict[str, Any]] = None, grad_output_fp8_meta: Optional[dict[str, Any]] = None, grad_input_fp8_meta: Optional[dict[str, Any]] = None, - accumulate_into_grad_weight: bool = False, - grad_weight: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Functional API for backward pass @@ -608,6 +675,14 @@ def _functional_backward( Tensor device dtype: torch.dtype, default = default dtype Tensor datatype + grad_weight: torch.Tensor, optional + Loss gradient w.r.t. weight tensor + accumulate_into_grad_weight: bool, default = `False` + Add result to weight grad instead of overwriting + grad_input: torch.Tensor, optional + Loss gradient w.r.t. input tensor + accumulate_into_grad_input: bool, default = `False` + Add result to input grad instead of overwriting tensor_parallel_mode: {`None`, "column", "row"}, default = `None` Mode for tensor parallelism tensor_parallel_group: torch.distributed.ProcessGroup, default = world group @@ -632,10 +707,6 @@ def _functional_backward( grad_output_fp8_meta: dict, optional FP8 metadata for casting loss gradient w.r.t. input tensor to FP8 - accumulate_into_grad_weight: bool, default = `False` - Accumulate into weight grad instead of overwriting - grad_weight: torch.Tensor, optional - Loss gradient w.r.t. weight tensor Returns ------- @@ -678,6 +749,34 @@ def _functional_backward( f"and weight tensor (shape={weight_dims}) " "are not compatible" ) + if grad_input is not None and tuple(grad_input.size()) != input_dims: + raise ValueError( + f"Grad input tensor (shape={tuple(grad_input.size())}) " + f"does not match expected shape ({input_dims})" + ) + + # Check grad input tensor + if not input_requires_grad: + grad_input = None + if grad_input is not None and not devices_match(grad_input.device, device): + raise ValueError( + f"Grad input tensor has invalid device (expected {device}, got {grad_input.device})" + ) + if grad_input is not None and grad_input.dtype != dtype: + raise ValueError( + f"Grad input tensor has invalid dtype (expected {dtype}, got {grad_input.dtype})" + ) + if accumulate_into_grad_input: + if grad_input is None: + raise ValueError( + "Attempted to accumulate into grad input tensor " + "without providing grad input tensor" + ) + if tensor_parallel_mode == "column": + raise ValueError( + "Accumulating into grad input tensor " + "is not supported with column tensor parallelism" + ) # Check if FP8 is enabled if with_fp8_compute: @@ -689,11 +788,19 @@ def _functional_backward( grad_output_fp8_meta = None grad_input_fp8_meta = None with_fp8_grad_input = ( - with_fp8_compute - and input_requires_grad - and tensor_parallel_mode != "column" - and grad_input_fp8_meta is not None + with_fp8_compute and input_requires_grad and tensor_parallel_mode != "column" ) + if grad_input is None: + with_fp8_grad_input = with_fp8_grad_input and grad_input_fp8_meta is not None + else: + if is_float8_tensor(grad_input): + if not with_fp8_grad_input: + raise ValueError( + "Grad input tensor is a Float8Tensor, but FP8 output is not supported" + ) + grad_input._reset_caches() + else: + with_fp8_grad_input = False # Check grad output tensor dy_async = None @@ -806,7 +913,9 @@ def _functional_backward( w = w.from_float8() # Construct grad input tensor - if with_fp8_grad_input: + if grad_input is not None: + dx = reshape(grad_input, (-1, input_dims[-1])) + elif with_fp8_grad_input: fp8_dtype = get_fp8_te_dtype( grad_input_fp8_meta["recipe"], fprop_tensor=False, @@ -835,16 +944,32 @@ def _functional_backward( _wait_async(dy_async) dy_async = None if with_fp8_compute: - kwargs = dict(out=dx) + kwargs = dict( + accumulate=accumulate_into_grad_input, + out=dx, + ) if with_fp8_grad_input: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=dx._fp8_meta_forward, - ) + if dx._fp8_meta is None: + # Hackily create FP8TensorMeta if needed + fp8_meta = FP8TensorMeta() + fp8_meta.scale = dx._scale_inv.reciprocal() + fp8_meta.amax_history = torch.empty( + 1, 1, dtype=torch.float32, device=device + ) + fp8_meta.scale_inv = dx._scale_inv + fp8_meta_index = 0 + else: + # Get FP8TensorMeta from Float8Tensor + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=dx._fp8_meta_forward, + ) + fp8_meta = dx._fp8_meta[fp8_meta_key] + fp8_meta_index = dx._fp8_meta_index kwargs.update( dict( out=dx._data, - out_index=dx._fp8_meta_index, - fp8_meta_tensor=dx._fp8_meta[fp8_meta_key], + out_index=fp8_meta_index, + fp8_meta_tensor=fp8_meta, D_dtype=dx._fp8_dtype, ) ) @@ -867,6 +992,7 @@ def _functional_backward( dy, dx.dtype, get_workspace(), + accumulate=accumulate_into_grad_input, layout="NN", out=dx, ) @@ -936,8 +1062,7 @@ def _functional_backward( _wait_async(dy_async) _wait_async(x_async) _wait_async(dx_async) - grad_input = None - if dx is not None: + if dx is not None and grad_input is None: grad_input = reshape(dx, input_dims) return grad_input, grad_weight @@ -1027,6 +1152,8 @@ def op_backward( weight_requires_grad=ctx.weight_requires_grad, device=self.device, dtype=self.dtype, + grad_weight=grad_weight, + accumulate_into_grad_weight=accumulate_into_main_grad, tensor_parallel_mode=self.tensor_parallel_mode, tensor_parallel_group=self.tensor_parallel_group, sequence_parallel=self.sequence_parallel, @@ -1034,8 +1161,6 @@ def op_backward( weight_fp8_meta=ctx.weight_fp8_meta, grad_output_fp8_meta=ctx.grad_output_fp8_meta, grad_input_fp8_meta=ctx.grad_input_fp8_meta, - accumulate_into_grad_weight=accumulate_into_main_grad, - grad_weight=grad_weight, ) # Clear input tensor if possible diff --git a/transformer_engine/pytorch/ops/basic/make_extra_output.py b/transformer_engine/pytorch/ops/basic/make_extra_output.py new file mode 100644 index 0000000000..db1651c184 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/make_extra_output.py @@ -0,0 +1,80 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Make extra tensor output in operation fuser.""" + +from __future__ import annotations +from collections.abc import Iterable +from typing import Any, Optional + +import torch + +from transformer_engine.pytorch.ops.op import ( + BasicOperation, + OperationContext, +) + + +class MakeExtraOutput(BasicOperation): + """Make extra output in operation fuser + + If this operation is included in the operation fuser, then the + operation fuser will return the intermediate tensor as an extra + tensor output. In the backward pass, the gradient is directly + accumulated into the gradient w.r.t. the extra output. + + This operation is considered an advanced feature and most users + are discouraged from using it. In-place operations break some + autograd assumptions and they can result in subtle, esoteric bugs. + + Compare to `AddInPlace`, which does a similar operation in the + backward pass. + + """ + + # Operation expects buffer for output tensor + num_extra_outputs: int = 1 + + def op_forward(self, *args, **kwargs) -> None: + raise RuntimeError( + "{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_forward` instead of `op_forward`." + ) + + def op_backward(self, *args, **kwargs) -> None: + raise RuntimeError( + "{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_backward` instead of `op_backward`." + ) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + basic_op_prev_ops: list[Optional[BasicOperation]], + basic_op_next_ops: list[Optional[BasicOperation]], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + return input_, [(input_,)] + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + Iterable[Iterable[Optional[torch.Tensor]]], + Iterable[Iterable[Optional[torch.Tensor]]], + ]: + grad_input = basic_op_grad_extra_outputs[0][0] + grad_input += grad_output + return grad_input, [], [()] diff --git a/transformer_engine/pytorch/ops/fused_forward/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py similarity index 52% rename from transformer_engine/pytorch/ops/fused_forward/__init__.py rename to transformer_engine/pytorch/ops/fused/__init__.py index ed523a067a..bd832254d8 100644 --- a/transformer_engine/pytorch/ops/fused_forward/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -4,7 +4,15 @@ """Compound tensor operation supported by the operation fuser.""" -from .linear_bias_activation import ( +from .backward_linear_add import ( + BackwardLinearAdd, + fuse_backward_linear_add, +) +from .forward_linear_bias_activation import ( ForwardLinearBiasActivation, fuse_forward_linear_bias_activation, ) +from .forward_linear_bias_add import ( + ForwardLinearBiasAdd, + fuse_forward_linear_bias_add, +) diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_add.py b/transformer_engine/pytorch/ops/fused/backward_linear_add.py new file mode 100644 index 0000000000..138eca3d96 --- /dev/null +++ b/transformer_engine/pytorch/ops/fused/backward_linear_add.py @@ -0,0 +1,156 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fused backward dgrad GEMM + add.""" + +from __future__ import annotations +from typing import Optional + +import torch + +from transformer_engine.pytorch.ops.basic import BasicLinear, MakeExtraOutput +from transformer_engine.pytorch.ops.op import ( + FusedOperation, + FusibleOperation, + OperationContext, +) +from ...utils import clear_tensor_data + + +class BackwardLinearAdd(FusedOperation): + """Fused backward dgrad GEMM + add + + Column tensor parallelism is not supported since that requires + communication immediately after the dgrad GEMM. + + """ + + def __init__( + self, + *, + linear: BasicLinear, + backward_add: MakeExtraOutput, + ) -> None: + super().__init__((linear, backward_add)) + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + list[tuple[Optional[torch.Tensor], ...]], + list[tuple[()]], + ]: + + # Get basic operations + linear_op = self.basic_ops[0] + linear_op_ctx = basic_op_ctxs[0] + + # Saved tensors from forward pass + (x_local,) = linear_op_ctx.saved_tensors + + # wgrad fusion + accumulate_into_main_grad = linear_op._accumulate_into_main_grad + grad_weight = None + if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad: + if not hasattr(linear_op.weight, "main_grad"): + raise RuntimeError( + "BasicLinear op is configured with " + "accumulate_into_main_grad=True, " + "but weight parameter does not have main_grad attribute" + ) + grad_weight = linear_op.weight.main_grad.detach() + else: + accumulate_into_main_grad = False + + # Linear backward pass + grad_input = basic_op_grad_extra_outputs[1][0] + grad_input, grad_weight = BasicLinear._functional_backward( + grad_output=grad_output, + input=x_local, + weight=linear_op.weight, + input_dims=linear_op_ctx.input_dims, + weight_dims=linear_op.weight.size(), + input_requires_grad=linear_op_ctx.input_requires_grad, + weight_requires_grad=linear_op_ctx.weight_requires_grad, + device=linear_op.device, + dtype=linear_op.dtype, + grad_weight=grad_weight, + accumulate_into_grad_weight=accumulate_into_main_grad, + grad_input=grad_input, + accumulate_into_grad_input=True, + tensor_parallel_mode=linear_op.tensor_parallel_mode, + tensor_parallel_group=linear_op.tensor_parallel_group, + sequence_parallel=linear_op.sequence_parallel, + with_fp8_compute=linear_op_ctx.with_fp8_compute, + weight_fp8_meta=linear_op_ctx.weight_fp8_meta, + grad_output_fp8_meta=linear_op_ctx.grad_output_fp8_meta, + grad_input_fp8_meta=linear_op_ctx.grad_input_fp8_meta, + ) + if accumulate_into_main_grad: + grad_weight = None + + # Clear input tensor if possible + if linear_op_ctx.has_prev_op: + clear_tensor_data(x_local) + + return grad_input, [(grad_weight,), ()], [(), ()] + + +def fuse_backward_linear_add( + ops: list[tuple[FusibleOperation, list[int]]], +) -> list[tuple[FusibleOperation, list[int]]]: + """Fused backward dgrad GEMM + add + + Parameters + ---------- + ops: list of tuples + Forward pass operations and the indices of the corresponding + basic operations. + + Returns + ------- + ops: list of tuples + Updated forward pass operations + + """ + + # Scan through ops, fusing if possible + out = [] + window = [] + while len(ops) >= 2: + out.extend(window) + + # Check if first op is linear + window, ops = ops[:1], ops[1:] + op, _ = window[0] + if not isinstance(op, BasicLinear): + continue + if op.tensor_parallel_mode == "column": + # Row tensor-parallelism requires communication after the + # GEMM + continue + + # Check if second op is "make extra output" + op, _ = ops[0] + if not isinstance(op, MakeExtraOutput): + continue + window.extend(ops[:1]) + ops = ops[1:] + + # Replace window with fused op + op = BackwardLinearAdd( + linear=window[0][0], + backward_add=window[1][0], + ) + basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] + window = [(op, basic_op_idxs)] + + # Return list of ops + out.extend(window) + out.extend(ops) + return out diff --git a/transformer_engine/pytorch/ops/fused_forward/linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py similarity index 93% rename from transformer_engine/pytorch/ops/fused_forward/linear_bias_activation.py rename to transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 1504dc4a53..5fd52405e4 100644 --- a/transformer_engine/pytorch/ops/fused_forward/linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -2,9 +2,10 @@ # # See LICENSE for license information. -"""Fused operation for GEMM, bias, activation in the forward pass.""" +"""Fused operation for forward GEMM + bias + activation.""" from __future__ import annotations +from collections.abc import Iterable from typing import Any, Optional import torch @@ -20,7 +21,7 @@ class ForwardLinearBiasActivation(FusedOperation): - """Fused GEMM, bias, activation in the forward pass + """Fused forward GEMM + bias + activation Bias and activation are both optional. Row tensor parallelism is not supported since that requires communication immediately after @@ -60,10 +61,12 @@ def fuser_forward( self, basic_op_ctxs: list[OperationContext], input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], basic_op_prev_ops: list[Optional[BasicOperation]], basic_op_next_ops: list[Optional[BasicOperation]], basic_op_kwargs: list[dict[str, Any]], - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: # Get basic operations idx = self._op_idxs["linear"] @@ -128,13 +131,13 @@ def fuser_forward( linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None - return output + return output, [() for _ in range(len(self.basic_ops))] def fuse_forward_linear_bias_activation( ops: list[tuple[FusibleOperation, list[int]]], ) -> list[tuple[FusibleOperation, list[int]]]: - """Fuse GEMM, bias, activation in the forward pass + """Fuse forward GEMM + bias + activation Parameters ---------- diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py new file mode 100644 index 0000000000..6ddee2849a --- /dev/null +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -0,0 +1,196 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fused operation for forward GEMM + bias + add.""" + +from __future__ import annotations +from collections.abc import Iterable +from typing import Any, Optional + +import torch + +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.ops.basic import AddInPlace, BasicLinear, Bias +from transformer_engine.pytorch.ops.op import ( + BasicOperation, + FusedOperation, + FusibleOperation, + OperationContext, +) + + +class ForwardLinearBiasAdd(FusedOperation): + """Fused forward GEMM + bias + add + + Bias is optional. Row tensor parallelism is not supported since + that requires communication immediately after the GEMM. + + """ + + def __init__( + self, + *, + linear: BasicLinear, + bias: Optional[Bias], + add: AddInPlace, + ) -> None: + + # Basic operations that comprise this fused operation + op_idxs = dict( + linear=0, + bias=None, + add=None, + ) + ops = [linear] + if bias is not None: + op_idxs["bias"] = len(ops) + ops.append(bias) + op_idxs["add"] = len(ops) + ops.append(add) + + # Initialize base class + super().__init__(ops) + + # Index of each basic operations + self._op_idxs: dict[str, Optional[int]] = op_idxs + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + basic_op_prev_ops: list[Optional[BasicOperation]], + basic_op_next_ops: list[Optional[BasicOperation]], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + + # Get basic operations + idx = self._op_idxs["linear"] + linear_op = self.basic_ops[idx] + linear_op_ctx = basic_op_ctxs[idx] + if self._op_idxs["bias"] is None: + bias_op = None + bias = None + else: + idx = self._op_idxs["bias"] + bias_op = self.basic_ops[idx] + bias = bias_op.bias + if basic_op_kwargs[idx]: + raise ValueError("Bias operation forward does not expect keyword arguments") + + # FP8 metadata + with_fp8_compute = FP8GlobalStateManager.is_fp8_enabled() + input_fp8_meta = None + weight_fp8_meta = None + output_fp8_meta = None + grad_output_fp8_meta = None + grad_input_fp8_meta = None + if with_fp8_compute: + input_fp8_meta = linear_op.get_fp8_meta("input") + weight_fp8_meta = linear_op.get_fp8_meta("param") + grad_output_fp8_meta = linear_op.get_fp8_meta("grad_output") + prev_op = basic_op_prev_ops[0] + if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0: + grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output") + + # Linear forward + output = basic_op_extra_inputs[self._op_idxs["add"]][0] + output, x_local, _ = BasicLinear._functional_forward( + input=input_, + weight=linear_op.weight, + bias=bias, + device=linear_op.device, + dtype=linear_op.dtype, + out=output, + accumulate_into_out=True, + tensor_parallel_mode=linear_op.tensor_parallel_mode, + tensor_parallel_group=linear_op.tensor_parallel_group, + sequence_parallel=linear_op.sequence_parallel, + with_fp8_compute=with_fp8_compute, + input_fp8_meta=input_fp8_meta, + weight_fp8_meta=weight_fp8_meta, + output_fp8_meta=output_fp8_meta, + ) + + # Save state for backward pass + linear_op_ctx.save_for_backward(x_local) + linear_op_ctx.with_fp8_compute = with_fp8_compute + linear_op_ctx.weight_fp8_meta = weight_fp8_meta + linear_op_ctx.grad_output_fp8_meta = grad_output_fp8_meta + linear_op_ctx.grad_input_fp8_meta = grad_input_fp8_meta + linear_op_ctx.input_dims = input_.size() + linear_op_ctx.input_requires_grad = input_.requires_grad + linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad + linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None + + return output, [() for _ in range(len(self.basic_ops))] + + +def fuse_forward_linear_bias_add( + ops: list[tuple[FusibleOperation, list[int]]], +) -> list[tuple[FusibleOperation, list[int]]]: + """Fuse forward GEMM + bias + add + + Parameters + ---------- + ops: list of tuples + Forward pass operations and the indices of the corresponding + basic operations. + + Returns + ------- + ops: list of tuples + Updated forward pass operations + + """ + + # Scan through ops, fusing if possible + out = [] + window = [] + while len(ops) >= 2: + out.extend(window) + + # Check if first op is linear + window, ops = ops[:1], ops[1:] + op, _ = window[0] + if not isinstance(op, BasicLinear): + continue + if op.tensor_parallel_mode == "row": + # Row tensor-parallelism requires communication after the + # GEMM + continue + linear = op + op, _ = ops[0] + + # Check if next op is bias + bias = None + if isinstance(op, Bias): + bias = op + window.extend(ops[:1]) + ops = ops[1:] + if len(ops) == 0: + continue + op, _ = ops[0] + + # Check if next op is add in-place + if not isinstance(op, AddInPlace): + continue + add = op + window.extend(ops[:1]) + ops = ops[1:] + + # Replace window with fused op + op = ForwardLinearBiasAdd( + linear=linear, + bias=bias, + add=add, + ) + basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] + window = [(op, basic_op_idxs)] + + # Return list of ops + out.extend(window) + out.extend(ops) + return out diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 06ea608ed8..a7c99c592d 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -16,11 +16,18 @@ FusibleOperation, OperationContext, ) -from transformer_engine.pytorch.ops.fused_forward import ( +from transformer_engine.pytorch.ops.fused import ( + fuse_backward_linear_add, fuse_forward_linear_bias_activation, + fuse_forward_linear_bias_add, ) +def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]: + """Split tuple at index""" + return t[:idx], t[idx:] + + class _OperationFuserAutogradFunction(torch.autograd.Function): """Autograd function for a pipeline of operations @@ -38,8 +45,10 @@ def forward( backward_ops: list[tuple[FusibleOperation, list[int]]], basic_ops: list[BasicOperation], basic_op_kwargs: list[dict[str, Any]], - *params: torch.nn.Parameter, - ) -> torch.Tensor: + num_params: int, + num_extra_inputs: int, + *params_and_extra_inputs: torch.nn.Parameter, + ) -> torch.Tensor | tuple[torch.Tensor, ...]: """Forward pass Parameters @@ -60,39 +69,82 @@ def forward( Basic operations basic_op_kwargs: list of dict Keyword arguments to BasicOperation - *params: torch.nn.Parameter - Parameters in operation pipeline + num_params: int + Number of parameter tensors to include in autograd graph. + *params_and_extra_inputs: torch.Tensor + Other tensor inputs to include in autograd graph. Consists + of parameter tensors, followed by extra operation inputs. + + Returns + ------- + Output tensor(s). If none of the operations have any extra + tensor outputs, then the pipeline's output tensor is returned. + Otherwise, a tuple with the pipeline's output tensor and extra + tensor outputs is returned. """ # Operation autograd contexts basic_op_ctxs = [OperationContext() for _ in range(len(basic_ops))] + # Unflatten list of parameters and extra tensor inputs + if len(params_and_extra_inputs) != num_params + num_extra_inputs: + raise ValueError( + f"Expected {num_params + num_extra_inputs} extra tensor arguments " + f"({num_params} parameters, {num_extra_inputs} extra inputs), " + f"but got {len(params_and_extra_inputs)}" + ) + _, extra_inputs = _split_tuple(params_and_extra_inputs, num_params) + basic_op_extra_inputs = [] + for op in basic_ops: + xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs) + basic_op_extra_inputs.append(xs) + # Apply forward ops x = input_ requires_grad = x.requires_grad + extra_outputs = [None for _ in range(len(basic_ops))] for op, basic_op_idxs in forward_ops: # Forward op + extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs] prev_ops = [basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs] next_ops = [ basic_ops[idx + 1] if (idx < len(basic_ops) - 1) else None for idx in basic_op_idxs ] - x = op.fuser_forward( + x, fused_op_extra_outputs = op.fuser_forward( [basic_op_ctxs[idx] for idx in basic_op_idxs], x, - prev_ops, - next_ops, - [basic_op_kwargs[idx] for idx in basic_op_idxs], + basic_op_extra_inputs=extra_inputs, + basic_op_prev_ops=prev_ops, + basic_op_next_ops=next_ops, + basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs], ) + for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs): + extra_outputs[idx] = ys # Check if backward op is required if not requires_grad: requires_grad = any(param.requires_grad for param in op.parameters()) + if not requires_grad: + requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs) for idx in basic_op_idxs: basic_op_ctxs[idx]._requires_grad = requires_grad x.requires_grad_(requires_grad=requires_grad) + # Flatten list of extra outputs + extra_outputs_flat = [] + for idx, ys in enumerate(extra_outputs): + ys = list(ys) + num_extra_outputs = basic_ops[idx].num_extra_outputs + if len(ys) != num_extra_outputs: + raise RuntimeError( + f"Expected op {idx} to generate " + "{num_extra_outputs} extra inputs, " + f"but got {len(ys)}" + ) + extra_outputs_flat.extend(ys) + # Flatten list of saved tensors to_save = [] for ctx in basic_op_ctxs: @@ -108,8 +160,13 @@ def forward( func_ctx.backward_ops = backward_ops func_ctx.basic_ops = basic_ops func_ctx.basic_op_ctxs = basic_op_ctxs + func_ctx.num_params = num_params + func_ctx.num_extra_inputs = num_extra_inputs + func_ctx.num_extra_outputs = len(extra_outputs_flat) func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() + if extra_outputs_flat: + return x, *extra_outputs_flat return x @staticmethod @@ -117,6 +174,7 @@ def forward( def backward( func_ctx: Any, grad_output: torch.Tensor, + *grad_extra_outputs: torch.Tensor, ) -> tuple[Optional[torch.Tensor], ...]: """Backward pass""" @@ -126,15 +184,25 @@ def backward( basic_op_ctxs = func_ctx.basic_op_ctxs # Unflatten list of saved tensors - saved_tensors = func_ctx.saved_tensors for ctx in basic_op_ctxs: - ctx.saved_tensors = saved_tensors[slice(*ctx._saved_tensors_range)] + ctx.saved_tensors = func_ctx.saved_tensors[slice(*ctx._saved_tensors_range)] ctx._saved_tensors_range = None - del saved_tensors + + # Unflatten list of extra tensor output grads + if len(grad_extra_outputs) != func_ctx.num_extra_outputs: + raise ValueError( + f"Expected grads for {func_ctx.num_extra_outputs} extra tensor outputs, " + f"but got {len(grad_extra_outputs)}" + ) + basic_op_grad_extra_outputs = [] + for op in basic_ops: + dys, grad_extra_outputs = _split_tuple(grad_extra_outputs, op.num_extra_outputs) + basic_op_grad_extra_outputs.append(dys) # Apply backward ops dx = grad_output grad_params = [None for _ in range(len(basic_ops))] + grad_extra_inputs = [None for _ in range(len(basic_ops))] for op, basic_op_idxs in backward_ops: # Stop if no more gradients are required @@ -143,13 +211,17 @@ def backward( break # Backward op - dx, fused_op_dparams = op.fuser_backward( + grad_extra_outputs = [basic_op_grad_extra_outputs[idx] for idx in basic_op_idxs] + dx, fused_op_grad_params, fused_op_grad_extra_inputs = op.fuser_backward( [basic_op_ctxs[idx] for idx in basic_op_idxs], dx, + basic_op_grad_extra_outputs=grad_extra_outputs, ) - for idx, basic_op_dparams in zip(basic_op_idxs, fused_op_dparams): - grad_params[idx] = basic_op_dparams + for idx, dparams in zip(basic_op_idxs, fused_op_grad_params): + grad_params[idx] = dparams basic_op_ctxs[idx].saved_tensors = None + for idx, dxs in zip(basic_op_idxs, fused_op_grad_extra_inputs): + grad_extra_inputs[idx] = dxs # Flatten list of parameter gradients grad_params_flat = [] @@ -166,6 +238,22 @@ def backward( ) grad_params_flat.extend(dparams) + # Flatten list of parameter gradients + grad_extra_inputs_flat = [] + for idx, dxs in enumerate(grad_extra_inputs): + num_extra_inputs = basic_ops[idx].num_extra_inputs + if dxs is None: + dxs = [None for _ in range(num_extra_inputs)] + else: + dxs = list(dxs) + if len(dxs) != num_extra_inputs: + raise RuntimeError( + f"Expected op {idx} to generate grads " + f"for {num_extra_inputs} extra inputs, " + f"but got {len(dxs)}" + ) + grad_extra_inputs_flat.extend(dxs) + # Update FP8 scaling factors if func_ctx.is_first_module and not is_graph_capturing(): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) @@ -176,7 +264,10 @@ def backward( None, # backward_ops None, # basic_ops None, # basic_op_kwargs - *grad_params_flat, # params + None, # num_params + None, # num_extra_inputs + *grad_params_flat, + *grad_extra_inputs_flat, ) @@ -208,6 +299,9 @@ def __init__( self._num_basic_ops: int = len(basic_ops) self._basic_ops: list[BasicOperation] = basic_ops + # Number of extra tensor inputs + self._num_extra_inputs: int = sum(op.num_extra_inputs for op in basic_ops) + # Ops for forward and backward pass self._forward_ops: list[tuple[FusibleOperation, list[int]]] self._backward_ops: list[tuple[FusibleOperation, list[int]]] @@ -224,6 +318,7 @@ def _fuse_forward_ops( ops: list[tuple[FusibleOperation, list[int]]], ) -> list[tuple[FusibleOperation, list[int]]]: """Attempt to fuse operations in forward pass""" + ops = fuse_forward_linear_bias_add(ops) ops = fuse_forward_linear_bias_activation(ops) return ops @@ -233,6 +328,7 @@ def _fuse_backward_ops( ops: list[tuple[FusibleOperation, list[int]]], ) -> list[tuple[FusibleOperation, list[int]]]: """Attempt to fuse operations in backward pass""" + ops = fuse_backward_linear_add(ops) return ops def fuse_ops(self) -> None: @@ -243,8 +339,9 @@ def fuse_ops(self) -> None: def __call__( self, input: torch.Tensor, # pylint: disable=redefined-builtin + *extra_inputs: torch.Tensor, basic_op_kwargs: Optional[list[dict[str, Any]]] = None, - ) -> torch.Tensor: + ) -> torch.Tensor | tuple[torch.Tensor, ...]: # Initialization before forward pass for op in self._basic_ops: @@ -255,9 +352,7 @@ def __call__( basic_op_kwargs = [{} for _ in range(len(self._basic_ops))] # Flatten list of parameters - params = [] - for op in self._basic_ops: - params.extend(op.parameters()) + params = [param for op in self._basic_ops for param in op.parameters()] # Fuser forward pass return _OperationFuserAutogradFunction.apply( @@ -266,5 +361,8 @@ def __call__( self._backward_ops, self._basic_ops, basic_op_kwargs, + len(params), + self._num_extra_inputs, *params, + *extra_inputs, ) diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 3d90d07b84..47c6567056 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -67,10 +67,12 @@ def fuser_forward( self, basic_op_ctxs: list[OperationContext], input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], basic_op_prev_ops: list[Optional[BasicOperation]], basic_op_next_ops: list[Optional[BasicOperation]], basic_op_kwargs: list[dict[str, Any]], - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: """Forward pass This op is either a basic op or the fusion of basic ops, so @@ -82,24 +84,27 @@ def fuser_forward( Parameters ---------- basic_op_ctxs: list of OperationContext - Contexts for corresponding basic operations + Contexts for basic operations input_: torch.Tensor Input tensor + basic_op_extra_inputs: list of torch.Tensor + Extra tensor inputs to basic operations basic_op_prev_ops: list of BasicOperation - Basic operations that preceed each of the corresponding - basic operations (or `None` if corresponding basic op is - first) + Basic operations that preceed this operation's basic + operations basic_op_next_ops: list of BasicOperation - Basic operations that follow each of the corresponding - basic operations (or `None` if corresponding basic op is - last) + Basic operations that follow this operation's basic + operations basic_op_kwargs: list of dict - Keyword arguments to forward functions of corresponding - basic operations + Keyword arguments to forward functions of basic + operations. Returns ------- - torch.Tensor: Output tensor. + torch.Tensor: + Output tensor. + Iterable of torch.Tensor: + Extra tensor outputs from basic operations. """ raise NotImplementedError( @@ -110,7 +115,13 @@ def fuser_backward( self, basic_op_ctxs: list[OperationContext], grad_output: torch.Tensor, - ) -> tuple[torch.Tensor, Iterable[Iterable[Optional[torch.Tensor]]]]: + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + Iterable[Iterable[Optional[torch.Tensor]]], + Iterable[Iterable[Optional[torch.Tensor]]], + ]: """Backward pass This op is either a basic op or the fusion of basic ops, so @@ -122,24 +133,21 @@ def fuser_backward( Parameters ---------- basic_op_ctxs: list of OperationContext - Contexts for corresponding basic operations. + Contexts for basic operations grad_output: torch.Tensor - Loss gradient w.r.t. operation output. - basic_op_prev_ops: list of BasicOperation - Basic operations that preceed each of the corresponding - basic operations (or `None` if corresponding basic op is - first) - basic_op_next_ops: list of BasicOperation - Basic operations that follow each of the corresponding - basic operations (or `None` if corresponding basic op is - last) + Loss gradient w.r.t. operation output + basic_op_grad_extra_outputs: list of tuple of torch.Tensor + Loss gradients w.r.t. extra tensor outputs from basic + operations. Returns ------- torch.Tensor: Loss gradient w.r.t. operation input Iterable of iterable of torch.Tensor: - Loss gradients w.r.t. parameters for corresponding basic + Loss gradients w.r.t. parameters for basic operations + Iterable of iterable of torch.Tensor: + Loss gradients w.r.t. extra tensor inputs to basic operations """ @@ -156,6 +164,11 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): """ + # Number of extra tensor inputs + num_extra_inputs: int = 0 + # Number of extra tensor outputs + num_extra_outputs: int = 0 + def __init__(self) -> None: super().__init__() @@ -297,6 +310,7 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, + *, prev_op: Optional[BasicOperation] = None, next_op: Optional[BasicOperation] = None, **kwargs: Any, @@ -309,6 +323,10 @@ def op_forward( Context to coordinate between forward and backward passes input_: torch.Tensor Input tensor + prev_op: BasicOperation, optional + Basic operation that preceeds this operation + next_op: BasicOperation, optional + Basic operation that follows this operation Returns ------- @@ -345,35 +363,63 @@ def fuser_forward( self, basic_op_ctxs: list[OperationContext], input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], basic_op_prev_ops: list[Optional[BasicOperation]], basic_op_next_ops: list[Optional[BasicOperation]], basic_op_kwargs: list[dict[str, Any]], - ) -> torch.Tensor: - return self.op_forward( + ) -> tuple[torch.Tensor, list[tuple[()]]]: + if self.num_extra_inputs > 0 or self.num_extra_outputs > 0: + raise RuntimeError( + "{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It should override `fuser_forward` instead of `op_forward`." + ) + output = self.op_forward( basic_op_ctxs[0], input_, - basic_op_prev_ops[0], - basic_op_next_ops[0], + prev_op=basic_op_prev_ops[0], + next_op=basic_op_next_ops[0], **basic_op_kwargs[0], ) + return output, [()] def fuser_backward( self, basic_op_ctxs: list[OperationContext], grad_output: torch.Tensor, - ) -> tuple[torch.Tensor, Iterable[Iterable[Optional[torch.Tensor]]]]: + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + list[Iterable[Optional[torch.Tensor]]], + list[tuple[()]], + ]: + if self.num_extra_inputs > 0 or self.num_extra_outputs > 0: + raise RuntimeError( + "{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It should override `fuser_backward` instead of `op_backward`." + ) grad_input, grad_params = self.op_backward(basic_op_ctxs[0], grad_output) - return grad_input, [grad_params] + return grad_input, [grad_params], [()] def forward( self, input: torch.Tensor, # pylint: disable=redefined-builtin + *extra_inputs: torch.Tensor, **kwargs: Any, - ) -> torch.Tensor: + ) -> torch.Tensor | tuple[torch.Tensor, ...]: """Apply operation""" from .fuser import OperationFuser - return OperationFuser([self], fuse_ops=False)(input, [kwargs]) + return OperationFuser([self], fuse_ops=False)( + input, + *extra_inputs, + basic_op_kwargs=[kwargs], + ) class FusedOperation(FusibleOperation): @@ -417,6 +463,7 @@ def pre_forward(self) -> None: def forward( self, input: torch.Tensor, # pylint: disable=redefined-builtin + *extra_inputs: torch.Tensor, basic_op_kwargs: Optional[list[dict[str, Any]]] = None, ) -> torch.Tensor: """Apply operation""" @@ -424,4 +471,8 @@ def forward( basic_op_kwargs = [{} for _ in range(len(self.basic_ops))] from .fuser import OperationFuser - return OperationFuser([self], fuse_ops=False)(input, basic_op_kwargs) + return OperationFuser([self], fuse_ops=False)( + input, + *extra_inputs, + basic_op_kwargs=basic_op_kwargs, + ) diff --git a/transformer_engine/pytorch/ops/sequential.py b/transformer_engine/pytorch/ops/sequential.py index 57b4036bba..c5e25fe1f2 100644 --- a/transformer_engine/pytorch/ops/sequential.py +++ b/transformer_engine/pytorch/ops/sequential.py @@ -144,28 +144,44 @@ def _make_module_groups( modules: Iterable[torch.nn.Module], ) -> list[OperationFuser | torch.nn.Module]: """Make list of modules, with fusible operations grouped together""" - module_groups = [] - fusible_ops = [] - - def maybe_add_fuser(): - nonlocal fusible_ops - if fusible_ops: - module_groups.append(OperationFuser(fusible_ops, fuse_ops=True)) - fusible_ops = [] + # Group fusible operations together + groups = [] for module in modules: if isinstance(module, FusibleOperation): - fusible_ops.append(module) + if not groups or not isinstance(groups[-1], list): + groups.append([]) + groups[-1].append(module) else: - maybe_add_fuser() - module_groups.append(module) - maybe_add_fuser() - return module_groups + groups.append(module) + for idx, group in enumerate(groups): + if isinstance(group, list): + groups[idx] = OperationFuser(group, fuse_ops=True) + + # Check if operations expect extra input or output tensors + # Note: If any op has extra inputs or outputs, then the entire + # Sequential must be made up of TE ops. + if len(groups) > 1: + ops = [] + for group in groups: + if isinstance(group, OperationFuser): + ops.extend(group._basic_ops) + num_extra_inputs = sum(op.num_extra_inputs for op in ops) + num_extra_outputs = sum(op.num_extra_outputs for op in ops) + if num_extra_inputs > 0 or num_extra_outputs > 0: + raise RuntimeError( + f"`Sequential` expects {num_extra_inputs} extra inputs " + f"and {num_extra_outputs} extra outputs, " + "but it contains non-fusible operations" + ) + + return groups def forward( self, input: torch.Tensor, # pylint: disable=redefined-builtin - ) -> torch.Tensor: + *extra_inputs: torch.Tensor, + ) -> torch.Tensor | tuple[torch.Tensor, ...]: """Forward pass""" # Create module groups if needed @@ -175,5 +191,5 @@ def forward( # Forward pass for each module group x = input for module_group in self._module_groups: - x = module_group(x) + x = module_group(x, *extra_inputs) return x From 44c8924f67b8c1da2cc8cff0cde99a9fdad08050 Mon Sep 17 00:00:00 2001 From: vasunvidia <108759426+vasunvidia@users.noreply.github.com> Date: Mon, 12 Aug 2024 14:23:07 -0700 Subject: [PATCH 43/72] Bug fix for num_warmup_iters=0 case (#1095) Buf fix for num_warmup_iters=0 case Signed-off-by: Vasudevan Rengasamy Co-authored-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index f6331c9b2a..e2642bc360 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -213,7 +213,7 @@ def _make_graphed_callables( only_inputs=True, allow_unused=allow_unused_input, ) - del outputs, grad_inputs + del outputs, grad_inputs torch.cuda.synchronize() # All captures here share a mempool. To avoid replays corrupting each other's memory, From ed3fb6b29c538b540db68cc724aa4960ee68839d Mon Sep 17 00:00:00 2001 From: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> Date: Mon, 12 Aug 2024 15:31:53 -0700 Subject: [PATCH 44/72] TE with threading build (#1092) * added threading build back * integrating threading for pytorch and paddle extensions * added messages --------- Signed-off-by: Phuong Nguyen Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- build_tools/paddle.py | 3 ++- build_tools/pytorch.py | 2 +- build_tools/utils.py | 4 ++-- transformer_engine/common/CMakeLists.txt | 16 ++++++++++++++++ 4 files changed, 21 insertions(+), 4 deletions(-) diff --git a/build_tools/paddle.py b/build_tools/paddle.py index 163f094fce..f3140cf028 100644 --- a/build_tools/paddle.py +++ b/build_tools/paddle.py @@ -6,6 +6,7 @@ from pathlib import Path import setuptools +import os from .utils import cuda_version @@ -62,7 +63,7 @@ def setup_paddle_extension( print("Could not determine CUDA Toolkit version") else: if version >= (11, 2): - nvcc_flags.extend(["--threads", "4"]) + nvcc_flags.extend(["--threads", os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1")]) if version >= (11, 0): nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"]) if version >= (11, 8): diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index e423ffe907..9b858653de 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -68,7 +68,7 @@ def setup_pytorch_extension( print("Could not determine CUDA Toolkit version") else: if version >= (11, 2): - nvcc_flags.extend(["--threads", "4"]) + nvcc_flags.extend(["--threads", os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1")]) if version >= (11, 0): nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"]) if version >= (11, 8): diff --git a/build_tools/utils.py b/build_tools/utils.py index 3230ad35bf..a0837c1c04 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -37,8 +37,8 @@ def get_max_jobs_for_parallel_build() -> int: num_jobs = 0 # Check environment variable - if os.getenv("NVTE_MAX_BUILD_JOBS"): - num_jobs = int(os.getenv("NVTE_MAX_BUILD_JOBS")) + if os.getenv("NVTE_BUILD_MAX_JOBS"): + num_jobs = int(os.getenv("NVTE_BUILD_MAX_JOBS")) elif os.getenv("MAX_JOBS"): num_jobs = int(os.getenv("MAX_JOBS")) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index b814ef5974..048e7fd61a 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -14,6 +14,22 @@ set(CMAKE_CUDA_STANDARD_REQUIRED ON) project(transformer_engine LANGUAGES CUDA CXX) +set(BUILD_THREADS_PER_JOB $ENV{NVTE_BUILD_THREADS_PER_JOB}) +if (NOT BUILD_THREADS_PER_JOB) + set(BUILD_THREADS_PER_JOB 1) +endif() +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --threads ${BUILD_THREADS_PER_JOB}") + +if(DEFINED ENV{MAX_JOBS}) + set(JOBS $ENV{MAX_JOBS}) +elseif(DEFINED ENV{NVTE_BUILD_MAX_JOBS}) + set(JOBS $ENV{NVTE_BUILD_MAX_JOBS}) +else() + set(JOBS "max number of") +endif() + +message(STATUS "Parallel build with ${JOBS} jobs and ${BUILD_THREADS_PER_JOB} threads per job") + if (CMAKE_BUILD_TYPE STREQUAL "Debug") set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -G") endif() From def4d1cbfd24e4bb28608d045634a817f638abb7 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Mon, 12 Aug 2024 15:45:35 -0700 Subject: [PATCH 45/72] Remove duplicate test (#1082) Signed-off-by: Przemek Tredak --- qa/L0_pytorch_unittest/test.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 90c5e499f3..e6ccf3b82f 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -19,7 +19,6 @@ NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py -pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py From b48403861b651572bdd39e4a17c24f5c0930370a Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 13 Aug 2024 08:06:22 -0700 Subject: [PATCH 46/72] Pin NLTK version to fix JAX ci (#1096) Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> --- qa/L0_jax_unittest/test.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index 3db1807fe2..81bbfa1065 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -4,6 +4,7 @@ set -xe +pip install nltk==3.8.1 pip install pytest==8.2.1 : ${TE_PATH:=/opt/transformerengine} From ec49a52b4524192ab5a0b2557d37747cb09d1d92 Mon Sep 17 00:00:00 2001 From: vasunvidia <108759426+vasunvidia@users.noreply.github.com> Date: Tue, 13 Aug 2024 09:10:22 -0700 Subject: [PATCH 47/72] Dgrad ReduceScatter overlap fix (#1088) * DGRAD-RS overlap bug fix This PR fixes a bug in enabling DGRAD-RS overlap by adding the layer to the correct method list. Previously, the RS-DGRAD overlap layer was incorrectly added to pipeline method list even if ring_exchange method is specified in config. Signed-off-by: Vasudevan Rengasamy * Bug fix for ring_exchange ReduceScatter ring_exchange RS uses main_stream for last GEMM chunk. But the send/recv streams wait for stream_compute during last chunk. Signed-off-by: Vasudevan Rengasamy --------- Signed-off-by: Vasudevan Rengasamy Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani --- .../pytorch/csrc/comm_gemm_overlap.h | 19 ++++++++----------- transformer_engine/pytorch/module/base.py | 4 +++- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h index bae46cffc9..3b4e126943 100644 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -1205,11 +1205,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { torch::Tensor workspace_chunk = torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}, workspace.options()); - if (i == _tp_size - 1) { - at::cuda::setCurrentCUDAStream(stream_main); - } else { - at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); - } + at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); te_gemm(A, A_scale_inverse, A_type, transa, input_b_chunk, B_scale_inverse, B_type, transb, _ubufs[i], D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms); @@ -1230,6 +1226,13 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { recv_rank, (cudaStream_t)_stream_recv); } } + at::cuda::setCurrentCUDAStream(stream_main); + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); + } + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); @@ -1248,12 +1251,6 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); torch::sum_out(rs_output, reduce_buf, 0); } - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); - } - NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); _ub_comm->sms = ori_sms; } diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 651a1a4c1a..3613e1fa5e 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -340,7 +340,9 @@ def add_ub( layers_reduce_scatter_overlap.remove(wgrad_name) layers_all_gather_overlap.remove(name) layers_reduce_scatter_overlap.append(name) - methods["pipeline"].append(name) + methods["bulk"].remove(name) + new_method = ub_cfgs[name]["method"] + methods[new_method].append(name) for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]: ub_cfg = get_default_config(name) From b8d453ef4d22f0ab1b097fe2f915976ee6ae3817 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 13 Aug 2024 11:20:34 -0700 Subject: [PATCH 48/72] [PyTorch] Merge `k_channels` and `v_channels` back to `kv_channels` (#1094) * merge k_channels and v_channels back to kv_channels and accept a tuple Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix isinstance call Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix MLA tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/fused_attn/test_fused_attn.py | 2 +- tests/pytorch/test_onnx_export.py | 2 +- transformer_engine/pytorch/attention.py | 32 ++++++++++++++------- 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index afc2081752..ac82a5424e 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -902,7 +902,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: # Set up model block = DotProductAttention( config.num_heads, - config.head_dim_qk, + (config.head_dim_qk, config.head_dim_v), num_gqa_groups=config.num_gqa_groups, attention_dropout=config.dropout_p, qkv_format=qkv_format, diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index e8361a2190..bdc459cdcc 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -1083,7 +1083,7 @@ def test_export_core_attention( model = te.attention.DotProductAttention( num_attention_heads=num_attention_heads, - k_channels=kv_channels, + kv_channels=kv_channels, attention_dropout=0.5, qkv_format=qkv_format, attn_mask_type=attn_mask_type, diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index c8ca157c28..3fc805bdc6 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -5177,10 +5177,9 @@ class DotProductAttention(TransformerEngineBaseModule): ---------- num_attention_heads : int number of attention heads in the transformer layer. - k_channels : int - number of channels per attention head in key. - v_channels : Optional[int] = None - number of channels per attention head in value. + kv_channels : Union[int, Tuple[int, int]] + the head size in key and value tensors. If the same, :attr:`kv_channels` can be + an integer; if not, :attr:`kv_channels` should be a tuple of two integers. num_gqa_groups : Optional[int] = None number of GQA groups in the transformer layer. Grouped Query Attention is described in @@ -5242,7 +5241,7 @@ class DotProductAttention(TransformerEngineBaseModule): For that, please use `get_qkv_layout` to gain the layout information. softmax_scale: Optional[float], default = `None` softmax scale for the attention scores. If `None`, defaults to - `1.0 / math.sqrt(kv_channels)`. + `1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`. Parallelism parameters ---------------------- @@ -5266,8 +5265,7 @@ class DotProductAttention(TransformerEngineBaseModule): def __init__( self, num_attention_heads: int, - k_channels: int, - v_channels: Optional[int] = None, + kv_channels: Union[int, Tuple[int, int]], num_gqa_groups: Optional[int] = None, attention_dropout: float = 0.0, qkv_format: str = "sbhd", @@ -5310,8 +5308,12 @@ def __init__( self.cp_global_ranks = cp_global_ranks self.cp_stream = cp_stream - self.hidden_size_per_attention_head = k_channels - self.v_channels = k_channels if v_channels is None else v_channels + self.hidden_size_per_attention_head_k = ( + kv_channels if isinstance(kv_channels, int) else kv_channels[0] + ) + self.hidden_size_per_attention_head_v = ( + kv_channels if isinstance(kv_channels, int) else kv_channels[1] + ) self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size) @@ -5329,7 +5331,9 @@ def __init__( attention_dropout_ctx = self.rng_states_tracker.fork if softmax_scale is None: - softmax_scale = 1.0 / math.sqrt(k_channels) + softmax_scale = 1.0 / math.sqrt( + kv_channels if isinstance(kv_channels, int) else kv_channels[0] + ) self.deterministic = ( not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) @@ -5628,6 +5632,14 @@ def forward( assert ( key_layer.shape[:-1] == value_layer.shape[:-1] ), "Keys and values must have the same batch size, sequence length and number of heads!" + assert ( + key_layer.shape[-1] == self.hidden_size_per_attention_head_k + ), f"Keys have head_dim = {key_layer.shape[-1]}, " + "but expected head_dim = {self.hidden_size_per_attention_head_k}!" + assert ( + value_layer.shape[-1] == self.hidden_size_per_attention_head_v + ), f"Values have head_dim = {value_layer.shape[-1]}, " + "but expected head_dim = {self.hidden_size_per_attention_head_v}!" if attn_mask_type is None: attn_mask_type = self.attn_mask_type From 88c0c9143140eb747f59f8d32b865667b6c6acb6 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 13 Aug 2024 13:03:51 -0700 Subject: [PATCH 49/72] [PyTorch] Update docs/example and benchmarks/ scripts (#1075) * update example/benchmark scripts Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix head_dim after MLA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update notebook Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- benchmarks/attention/benchmark_attention.py | 20 +- .../arbitrary_mask_to_post_scale_bias.py | 16 +- docs/examples/attention/attention.ipynb | 171 +++++++++--------- docs/examples/attention/example_attention.py | 15 +- tests/pytorch/fused_attn/test_fused_attn.py | 24 ++- 5 files changed, 123 insertions(+), 123 deletions(-) diff --git a/benchmarks/attention/benchmark_attention.py b/benchmarks/attention/benchmark_attention.py index e5df485eda..bfd7bf8471 100644 --- a/benchmarks/attention/benchmark_attention.py +++ b/benchmarks/attention/benchmark_attention.py @@ -11,9 +11,7 @@ import transformer_engine from tests.pytorch.fused_attn.test_fused_attn import ( ModelConfig, - _is_flash_attention_supported, - _is_fused_attention_supported, - _is_unfused_attention_supported, + _get_attention_backends, _run_dot_product_attention, ) @@ -29,8 +27,6 @@ workspace_opt = True # QKV memory layout qkv_layout = "bshd_bshd_bshd" -# sliding window attention -swa = False # padding between sequences for qkv_format=thd pad_between_seqs = False # training mode @@ -64,7 +60,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp ckpt_attn, qkv_layout, workspace_opt, - swa, pad_between_seqs, is_training, ) @@ -76,7 +71,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp ckpt_attn, qkv_layout, workspace_opt, - swa, pad_between_seqs, is_training, ) @@ -97,7 +91,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp ckpt_attn, qkv_layout, workspace_opt, - swa, pad_between_seqs, is_training, ) @@ -115,7 +108,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp ckpt_attn, qkv_layout, workspace_opt, - swa, pad_between_seqs, is_training, ) @@ -205,13 +197,15 @@ def main(): ) for model in model_configs.keys(): config = model_configs[model] - fused_attn_supported, fused_attn_backend = _is_fused_attention_supported( + available_backends, fused_attn_backends = _get_attention_backends( config, - dtype, + qkv_dtype=dtype, qkv_layout=qkv_layout, + window_size=config.window_size, + pad_between_seqs=pad_between_seqs, ) - fused_attn_supported = fused_attn_supported and not swa - flash_attn_supported = _is_flash_attention_supported(config) + flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + print( f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}' f'{" and flash-attention" if flash_attn_supported else ""}...' diff --git a/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py b/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py index cd8ab85ba2..85ce01079c 100644 --- a/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py +++ b/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py @@ -6,7 +6,6 @@ import torch from typing import Tuple from tests.pytorch.fused_attn.test_fused_attn import ModelConfig -from transformer_engine.pytorch.distributed import _set_cuda_rng_state from transformer_engine.pytorch.attention import DotProductAttention # Initialize RNG state @@ -22,7 +21,7 @@ def reset_rng_states() -> None: """Revert back to initial RNG state""" torch.set_rng_state(_cpu_rng_state) - _set_cuda_rng_state(_cuda_rng_state) + torch.cuda.set_rng_state(_cuda_rng_state) def _run_dot_product_attention( @@ -40,7 +39,7 @@ def _run_dot_product_attention( [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda" ) inp = torch.randn( - [config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim], + [config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim_qk], dtype=dtype, device="cuda", ) @@ -51,7 +50,7 @@ def _run_dot_product_attention( k.requires_grad = True v.requires_grad = True out_grad = torch.randn( - [config.batch_size, config.max_seqlen_q, config.num_heads * config.head_dim], + [config.batch_size, config.max_seqlen_q, config.num_heads * config.head_dim_v], dtype=dtype, device="cuda", ) @@ -80,7 +79,7 @@ def _run_dot_product_attention( block = DotProductAttention( config.num_heads, - config.head_dim, + config.head_dim_qk, num_gqa_groups=config.num_gqa_groups, qkv_format="bshd", attention_dropout=config.dropout_p, @@ -89,6 +88,8 @@ def _run_dot_product_attention( get_rng_state_tracker=None, tp_group=None, layer_number=1, + attn_mask_type="no_mask", + window_size=(-1, -1), ).to(dtype=dtype, device="cuda") # Run a forward and backward pass @@ -103,6 +104,7 @@ def _run_dot_product_attention( attn_mask_type=config.attn_mask_type, # 'arbitrary' core_attention_bias_type=config.attn_bias_type, # 'no_bias' core_attention_bias=bias, # None + window_size=(-1, -1), ) out.backward(out_grad) @@ -116,6 +118,7 @@ def _run_dot_product_attention( attn_mask_type=config.attn_mask_type, # no_mask core_attention_bias_type=config.attn_bias_type, # 'post_scale_bias' core_attention_bias=bias, # bias + window_size=(-1, -1), ) out.backward(out_grad) @@ -133,6 +136,7 @@ def _run_dot_product_attention( config = model_configs["test_bias"] fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(dtype, config, "bs3hd") +print() print("Run with arbitrary mask:") config = model_configs["test_mask"] unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(dtype, config, "bs3hd") @@ -140,4 +144,6 @@ def _run_dot_product_attention( torch.testing.assert_close(unfused_attn_fwd, fused_attn_fwd, atol=2.5e-2, rtol=2.5e-2) for i in range(3): torch.testing.assert_close(unfused_attn_bwd[i], fused_attn_bwd[i], atol=2.5e-2, rtol=2.5e-2) + +print() print("Test passed!") diff --git a/docs/examples/attention/attention.ipynb b/docs/examples/attention/attention.ipynb index 515f420790..27017b4773 100644 --- a/docs/examples/attention/attention.ipynb +++ b/docs/examples/attention/attention.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "8ae3bc43", + "id": "040f466a", "metadata": {}, "source": [ "# Attention Is All You Need!\n", @@ -23,7 +23,7 @@ }, { "cell_type": "markdown", - "id": "47421c01", + "id": "89a7d849", "metadata": {}, "source": [ "## 1. Attention Backends\n", @@ -71,7 +71,7 @@ }, { "cell_type": "markdown", - "id": "e52f60f0", + "id": "c90a2573", "metadata": {}, "source": [ "### 1.1 Flash vs. Non-Flash\n", @@ -85,30 +85,30 @@ "- **Recomputation:** The non-flash algorithm stores the softmax matrix (quadratic to sequence length) to global memory for the backward pass, while the flash algorithm only saves the softmax normalization factors (linear to sequence length). This reduces the amount of memory required as well as the bandwidth utilization between global memory and shared memory. Even though there is extra computation incurred in order to recalculate the attention in the backward pass, the bandwidth savings still provide significant improvement in efficiency.\n", "\n", "
\n", - "Note \n", + "Note: \n", " \n", - "Transformer Engine's flash-attention backend, available in PyTorch, and cuDNN attention backend (sub-backends 1 and 2), in PyTorch, JAX and PaddlePaddle, are both based on the flash algorithm.\n", + "Transformer Engine's flash-attention backend, available in PyTorch, and cuDNN attention backend (sub-backends 1 and 2), available in PyTorch, JAX and PaddlePaddle, are both based on the flash algorithm.\n", "
\n" ] }, { "cell_type": "markdown", - "id": "bb909ac4", + "id": "b5ce567d", "metadata": {}, "source": [ "### 1.2 flash-attention\n", "\n", "The flash-attention backend, available only in PyTorch, is a module wrapped around the public `flash-attn` package [[3]](https://github.com/Dao-AILab/flash-attention). \n", "\n", - "The flash-attention backend supports `flash-attn`'s features as they are released, and to facilitate the use of `flash-attn`, flash-attention also offers a few functionalities such as converting the `attention_mask` to cumulative sequence lengths `cu_seqlens` for `padding` mask. Please see `transformer_engine.pytorch.attention.FlashAttention` for more details.\n", + "The flash-attention backend supports `flash-attn`'s features as well as a few extra functionalities to facilitate the use of `flash-attn`, such as converting the `attention_mask` to cumulative sequence lengths `cu_seqlens` for `padding` mask use cases. Please see `transformer_engine.pytorch.attention.FlashAttention` for details.\n", "\n", - "The `flash-attn` dependency is regularly updated in Transformer Engine. As of v1.7, Transformer Engine supports `flash-attn` 2.0.6+ (see [setup.py](https://github.com/NVIDIA/TransformerEngine/blob/main/setup.py)).\n", + "The `flash-attn` dependency is regularly updated in Transformer Engine. As of v1.10, Transformer Engine supports `flash-attn` 2.0.6+ (see [setup.py](https://github.com/NVIDIA/TransformerEngine/blob/main/setup.py)).\n", "\n", - "To understand `flash-attn`'s performance, please refer to their [benchmarks](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#performance).\n", + "To understand `flash-attn`'s performance, please refer to their benchmarks [here](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#performance).\n", "\n", "### 1.3 cuDNN Attention\n", "\n", - "The cuDNN attention backend, available in PyTorch, JAX and PaddlePaddle, offers another high-performance solution to the attention calculation. It requires [cuDNN](https://developer.nvidia.com/cudnn) to run, and has several sub-backends to support the different precisions and sequence lengths. Out of the three, sub-backends 1 and 2 are based on the flash algorithm, as `flash-attn` is.\n", + "The cuDNN attention backend, available in PyTorch, JAX and PaddlePaddle, offers another high-performance solution to the attention calculation. It requires [cuDNN](https://developer.nvidia.com/cudnn) to run, and has several sub-backends to support the different precisions and sequence lengths.\n", "\n", "\n", " \n", @@ -153,14 +153,14 @@ " \n", "
\n", "\n", - "The cuDNN attention backend and flash-attention backend have several notable differences. As of Transformer Engine 1.7, cuDNN 9.0 and `flash-attn` 2.4.2,\n", + "The cuDNN attention backend and flash-attention backend have several notable differences. As of Transformer Engine 1.10, cuDNN 9.3 and `flash-attn` 2.4.2,\n", "\n", "- flash-attention only supports the PyTorch framework while cuDNN attention supports PyTorch, JAX and PaddlePaddle.\n", "- flash-attention supports BF16, FP16 precisions while cuDNN attention also supports FP8 (through its sub-backend 2).\n", - "- flash-attention supports `bshd`, `thd` input formats, without any transposes, and `sbhd` format, with transposes, while cuDNN attention supports all three without transposes (see Section 3.1 for more details).\n", + "- flash-attention supports `bshd`, `thd` input formats, without any transposes, and `sbhd` format, with transposes, while cuDNN attention supports all three formats without transposes (see Section 3.1 for more details).\n", "- flash-attention does not support `post_scale_bias`, and cuDNN attention does.\n", - "- flash-attention supports sliding window attention, and cuDNN attention does not.\n", - "- flash-attention uses bottom right diagonal for `causal` mask in cross attention, and cuDNN attention uses top left (see `flash-attn`'s [change log](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag)).\n", + "- flash-attention supports KV-caching and paged attention, and cuDNN attention does not.\n", + "- flash-attention uses bottom right diagonal for `causal` mask in cross attention (see [change log](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag)), and cuDNN attention supports both top left and bottom right.\n", "- flash-attention outperforms cuDNN attention on Ampere architectures, and cuDNN attention has 20-50% advantages on Hopper architectures, based on our benchmarks for a number of commonly-used model configurations.\n", "\n", "To compare cuDNN attention and flash-attention, users can modify the `model_configs` dictionary in [benchmarks/attention/benchmark_attention.py](https://github.com/NVIDIA/TransformerEngine/blob/main/benchmarks/attention/benchmark_attention.py) to collect performance numbers. The script runs each entry in `model_configs` for `num_iters` times, each time with one forward pass and one backward pass. Both backends are tried, and if one backend does not have support for the specific user input, the runtimes and speedups in the final table would be 0." @@ -169,7 +169,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9a380859", + "id": "c5b8e3d7", "metadata": {}, "outputs": [], "source": [ @@ -184,25 +184,25 @@ }, { "cell_type": "code", - "execution_count": 2, - "id": "0584bb01", + "execution_count": 1, + "id": "50852cb5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Device 0: NVIDIA H100 PCIe GPU, sm90 compute capability, 79.1GB memory\n", + "Device 0: NVIDIA H100 80GB HBM3 GPU, sm90 compute capability, 79.1GB memory\n", "Running test_0 with cuDNN attention and flash-attention...\n", "Running test_1 with cuDNN attention and flash-attention...\n", "Running test_2 with cuDNN attention...\n", "Running test_3 with cuDNN attention and flash-attention...\n", "\n", " cuDNN fwd+bwd (ms) flash-attn fwd+bwd (ms) cuDNN vs flash speedup\n", - "test_0 0.0638 0.0858 1.3454\n", - "test_1 0.5415 0.7496 1.3842\n", - "test_2 1.2302 0.0000 0.0000\n", - "test_3 12.0122 19.0716 1.5877\n" + "test_0 0.0340 0.0468 1.3786\n", + "test_1 0.3664 0.5850 1.5968\n", + "test_2 0.9332 0.0000 0.0000\n", + "test_3 7.4875 11.8879 1.5877\n" ] } ], @@ -212,7 +212,7 @@ }, { "cell_type": "markdown", - "id": "45e53fc9", + "id": "9a615119", "metadata": {}, "source": [ "## 2. Backend Selection\n", @@ -253,35 +253,35 @@ }, { "cell_type": "markdown", - "id": "6dfeade3", + "id": "e6c0f3f0", "metadata": {}, "source": [ "### 2.1 Debug Information\n", "\n", - "To find out which backend is being used during runtime, users can turn on these debugging flags. Logging is done using the `logging` package.\n", + "To find out which backend is being used during runtime, we have the following two debugging flags. Logging is done by using the `logging` package.\n", "```\n", "NVTE_DEBUG = 0/1 # disables/enables debugging\n", "NVTE_DEBUG_LEVEL = 0/1/2 # enables logging.WARNING/INFO/DEBUG-level messages\n", "```\n", "
\n", - "Note\n", + "Note:\n", " \n", - "These flags are supported in PyTorch only as of Transformer Engine 1.7. JAX and PaddlePaddle support is expected to be added in the future.\n", + "These flags are supported in PyTorch only as of Transformer Engine 1.10. JAX and PaddlePaddle support is expected to be added in the future.\n", "
" ] }, { "cell_type": "markdown", - "id": "7e3b7981", + "id": "16660323", "metadata": {}, "source": [ - "The [example_attention.py](https://raw.githubusercontent.com/NVIDIA/TransformerEngine/main/docs/examples/attention/example_attention.py) script runs a very basic model with two attention backends, cuDNN attention and flash-attention. Here `NVTE_DEBUG_LEVEL=1` allows us to find out which backend/sub-backend was actually used during runtime." + "The example script [example_attention.py](https://raw.githubusercontent.com/NVIDIA/TransformerEngine/main/docs/examples/attention/example_attention.py) runs a very basic model with two attention backends, cuDNN attention and flash-attention. Here `NVTE_DEBUG_LEVEL=1` allows us to find out which backend/sub-backend is used in runtime." ] }, { "cell_type": "code", - "execution_count": 22, - "id": "961c51d4", + "execution_count": 24, + "id": "906b8cf1", "metadata": {}, "outputs": [ { @@ -293,7 +293,7 @@ "[INFO | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)\n", "\n", "Run flash-attention...\n", - "[INFO | DotProductAttention]: Running with FlashAttention backend \n", + "[INFO | DotProductAttention]: Running with FlashAttention backend\n", "\n", "Test passed.\n" ] @@ -305,16 +305,16 @@ }, { "cell_type": "markdown", - "id": "11bfbbd7", + "id": "8ca99461", "metadata": {}, "source": [ - "To collect more information, users can turn on `NVTE_DEBUG_LEVEL=2`. In this example, it allows us to find out more about the run config. Users are encouraged to provide if users intend to file a bug with Transformer Engine. For example, " + "`NVTE_DEBUG_LEVEL=2` allows us to find out more about the backend selection logic. Users are encouraged to double check the `config` and provide it to the Transformer Engine team if they would like to file a bug. " ] }, { "cell_type": "code", - "execution_count": 25, - "id": "162a2be1", + "execution_count": 23, + "id": "d3637094", "metadata": {}, "outputs": [ { @@ -323,16 +323,18 @@ "text": [ "\n", "Run cuDNN attention...\n", + "[DEBUG | DotProductAttention]: Running with config={'transformer_engine_version': '1.10.0.dev0+ee85a91', 'compute_capability': 'sm90', 'flash_attn_version': , 'cudnn_version': '9.3.0', 'qkv_type': , 'qkv_dtype': torch.bfloat16, 'qkv_layout': 'bshd_bshd_bshd', 'batch_size': 2, 'num_heads': 16, 'num_gqa_groups': 16, 'max_seqlen_q': 512, 'max_seqlen_kv': 512, 'head_dim_qk': 64, 'head_dim_v': 64, 'attn_mask_type': 'no_mask', 'window_size': (-1, -1), 'alibi_slopes_shape': None, 'core_attention_bias_type': 'no_bias', 'core_attention_bias_shape': None, 'core_attention_bias_requires_grad': False, 'pad_between_seqs': False, 'attention_dropout': 0.0, 'context_parallel': False, 'deterministic': False, 'is_training': True, 'fp8': False, 'fp8_meta': {'fp8_checkpoint': False, 'fp8_group': None, 'recipe': margin=0, format=HYBRID, amax_history_len=1024, wgrad_override=False, fp8_dpa=False, fp8_mha=False}}\n", "[DEBUG | DotProductAttention]: Disabling FlashAttention due to NVTE_FLASH_ATTN=0\n", + "[DEBUG | DotProductAttention]: Available backends = {FlashAttention=False, FusedAttention=True (sub-backend 1), UnfusedDotProductAttention=True}\n", + "[DEBUG | DotProductAttention]: Selected backend = FusedAttention (sub-backend 1)\n", "[INFO | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)\n", - "[DEBUG | DotProductAttention]: Running with {'compute_capability': 'sm90', 'q_dtype': torch.bfloat16, 'k_dtype': torch.bfloat16, 'v_dtype': torch.bfloat16, 'q_shape': [2, 512, 16, 64], 'k_shape': [2, 512, 16, 64], 'v_shape': [2, 512, 16, 64], 'qkv_format': 'bshd', 'qkv_layout': 'bshd_bshd_bshd', 'mask_type': 'no_mask', 'bias_type': 'no_bias', 'bias_shape': None, 'dropout': 0.0, 'context_parallel': False, 'is_training': True, 'transformer_engine_version': , 'flash_attn_version': , 'cudnn_version': '9.2.0'}\n", - "[DEBUG | FusedAttnFunc ]: Running forward in torch.bfloat16\n", - "[DEBUG | FusedAttnFunc ]: Running backward in torch.bfloat16\n", "\n", "Run flash-attention...\n", + "[DEBUG | DotProductAttention]: Running with config={'transformer_engine_version': '1.10.0.dev0+ee85a91', 'compute_capability': 'sm90', 'flash_attn_version': , 'cudnn_version': '9.3.0', 'qkv_type': , 'qkv_dtype': torch.bfloat16, 'qkv_layout': 'bshd_bshd_bshd', 'batch_size': 2, 'num_heads': 16, 'num_gqa_groups': 16, 'max_seqlen_q': 512, 'max_seqlen_kv': 512, 'head_dim_qk': 64, 'head_dim_v': 64, 'attn_mask_type': 'no_mask', 'window_size': (-1, -1), 'alibi_slopes_shape': None, 'core_attention_bias_type': 'no_bias', 'core_attention_bias_shape': None, 'core_attention_bias_requires_grad': False, 'pad_between_seqs': False, 'attention_dropout': 0.0, 'context_parallel': False, 'deterministic': False, 'is_training': True, 'fp8': False, 'fp8_meta': {'fp8_checkpoint': False, 'fp8_group': None, 'recipe': margin=0, format=HYBRID, amax_history_len=1024, wgrad_override=False, fp8_dpa=False, fp8_mha=False}}\n", "[DEBUG | DotProductAttention]: Disabling FusedAttention due to NVTE_FUSED_ATTN=0\n", - "[INFO | DotProductAttention]: Running with FlashAttention backend \n", - "[DEBUG | DotProductAttention]: Running with {'compute_capability': 'sm90', 'q_dtype': torch.bfloat16, 'k_dtype': torch.bfloat16, 'v_dtype': torch.bfloat16, 'q_shape': [2, 512, 16, 64], 'k_shape': [2, 512, 16, 64], 'v_shape': [2, 512, 16, 64], 'qkv_format': 'bshd', 'qkv_layout': 'bshd_bshd_bshd', 'mask_type': 'no_mask', 'bias_type': 'no_bias', 'bias_shape': None, 'dropout': 0.0, 'context_parallel': False, 'is_training': True, 'transformer_engine_version': , 'flash_attn_version': , 'cudnn_version': '9.2.0'}\n", + "[DEBUG | DotProductAttention]: Available backends = {FlashAttention=True, FusedAttention=False, UnfusedDotProductAttention=True}\n", + "[DEBUG | DotProductAttention]: Selected backend = FlashAttention\n", + "[INFO | DotProductAttention]: Running with FlashAttention backend\n", "\n", "Test passed.\n" ] @@ -344,7 +346,7 @@ }, { "cell_type": "markdown", - "id": "779a51e6", + "id": "611d8fdb", "metadata": {}, "source": [ "### 2.2 User Control\n", @@ -392,28 +394,29 @@ }, { "cell_type": "markdown", - "id": "ccd5650d", + "id": "e60a2a3e", "metadata": {}, "source": [ "## 3. Backend Support\n", "\n", - "Transformer Engine supports commonly-used features such as self and cross attention, FP16/BF16 precisions, dropout, and checkpointing. But it also offers a range of other features. As of v1.7, Transformer Engine's attention backends have the following support matrix.\n", + "Transformer Engine supports commonly-used features such as self and cross attention, FP16/BF16 precisions, dropout, and checkpointing. But it also offers a range of other features. As of v1.10, Transformer Engine's attention backends have the following support matrix.\n", "\n", - "| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Context Parallelism | Determinism Possible |\n", - "| :---------------- | :-------- | :----------- | :----------------------- | :------ | :------------------ | :------------ |\n", - "| cuDNN attention (all frameworks) | BF16, FP16, FP8 (PyTorch only) | sm80+ | No | Yes | Yes (only for `bshd`,`sbhd`) | Yes |\n", - "| flash-attention (PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | Yes (only for `bshd`,`thd`) | Yes |\n", - "| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | No | Yes |\n", + "| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Multi-Latent Attention | Context Parallelism | Determinism Possible |\n", + "| :---------------- | :-------- | :----------- | :----------------------- | :------ | :--------------------- | :------------------ | :------------ |\n", + "| cuDNN attention (all frameworks) | BF16, FP16, FP8 (PyTorch only) | sm80+ | No | Yes | Yes | Yes (`bshd`,`sbhd`, `thd`) | Yes |\n", + "| flash-attention (PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | No | Yes (`bshd`,`thd`) | Yes |\n", + "| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | Yes (PyTorch only) | No | Yes |\n", "\n", "Some unit tests are provided to serve as a starting point for integrating such features into users' models. For example,\n", "- sliding window attention: [test_dpa_swa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n", "- MQA/GQA: [test_te_layer_mqa_gqa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n", + "- Multi-Latent Attention: [test_dpa_mla](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n", "- context parallelism: [test_cp_with_fused_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py), [test_cp_with_flash_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py)" ] }, { "cell_type": "markdown", - "id": "8439b389", + "id": "fbdcb327", "metadata": {}, "source": [ "### 3.1 QKV Layout\n", @@ -439,7 +442,7 @@ "**qkv_layout=thd_thd_thd:**\n", "`q`, `k`, `v` have variable sequence lengths in a batch. They are all contiguous and have no interleaving.\n", "\n", - "As of v1.7, Transformer Engine has the following support matrix.\n", + "As of v1.10, Transformer Engine has the following support matrix.\n", "\n", "\n", " \n", @@ -480,16 +483,16 @@ }, { "cell_type": "markdown", - "id": "0290f8e9", + "id": "855d9616", "metadata": {}, "source": [ "### 3.2 Attention Mask\n", "\n", - "Transformer Engine supports 5 mask types, and all the masks are defined as `True` masking out the corresponding element and `False` including the corresponding element in attention calculation.\n", + "Transformer Engine supports 7 mask types, and all the masks are defined as `True` masking out the corresponding element and `False` including the corresponding element in attention calculation.\n", "\n", - "- `no_mask`, `padding`, `causal`, `padding_causal` (equivalent to `causal_padding`), `arbitrary`\n", + "- `no_mask`, `padding`, `causal`, `causal_bottom_right`, `padding_causal`, `padding_causal_bottom_right`, `arbitrary`\n", "\n", - "Different backends offer different support for attention mask. As of Transformer Engine 1.7,\n", + "Different backends offer different support for attention mask. As of Transformer Engine 1.10,\n", "\n", "
\n", " \n", @@ -498,34 +501,25 @@ " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", "
Requires `attention_mask`
flash-attention`no_mask`, `causal`, `padding`, `padding_causal``no_mask`, `causal`: No
`padding`, `padding_causal`: Yes if `cu_seqlens` not provided
cuDNN attention`no_mask`, `causal`, `padding`, `padding_causal``no_mask`, `causal`: Noflash-attention
  • `no_mask`, `causal` (self-attention),
  • `padding`, `padding_causal` (self-attention),
  • `causal_bottom_right`, `padding_causal_bottom_right`
  • `no_mask`, `causal` `causal_bottom_right`: No
  • `padding`, `padding_causal`, `padding_causal_bottom_right`: Yes if `cu_seqlens` not provided
  • `arbitrary`: Yes
  • \n", - " `padding`, `padding_causal`: Yes if `cu_seqlens` not provided\n", - " cuDNN attention
  • `no_mask`, `causal`,
  • `padding`, `padding_causal`,
  • `causal_bottom_right`, `padding_causal_bottom_right`
  • Framework-native attention`no_mask`, `causal`, `arbitrary``no_mask`, `causal`: NoFramework-native attention
  • All (PyTorch)
  • `no_mask`, `causal`, `padding` (Jax, PaddlePaddle)
  • `arbitrary`: Yes
    \n", "\n", - "**padding and padding_causal:** For these two mask types, users need to provide sequence length information to help Transformer Engine figure out where each sequence ends in a batch. As of Transformer Engine 1.7, there are two options to do so in PyTorch and one in JAX and PaddlePaddle.\n", + "**Padding masks:** For `padding`, `padding_causal`, `padding_causal_bottom_right` mask types, users need to provide sequence length information to help Transformer Engine figure out where each sequence ends in a batch. As of Transformer Engine 1.10, there are two options to do so in PyTorch and one in JAX and PaddlePaddle.\n", "\n", "* PyTorch: When both options are provided by the user, `cu_seqlens` is preferred as there is no extra conversion needed.\n", " - `cu_seqlens`: Users can provide cumulative sequence length tensors `cu_seqlens_q` and `cu_seqlens_kv` for `q` and `k`/`v` to the flash-attention or cuDNN attention backend. An example of `cu_seqlens` is `[0, 2, 6, 7]` for a batch of 3 `[aa000, bbbb0, c0000]`.\n", @@ -536,13 +530,13 @@ "\n", "**qkv_format=thd:** Transformer Engine extracts the max sequence length information from `q`, `k`, `v` if `max_seqlen_q` and `max_seqlen_kv` are not provided. This requires GPU-CPU copy and synchronization operations. For performance reasons, please set `max_seqlen_q` and `max_seqlen_kv` to their appropriate values for `thd` QKV format.\n", "\n", - "**Arbitrary mask:** cuDNN does not support `Arbitrary` mask type as of v9.0. However, users can convert the mask to a regular `post_scale_bias` bias and achieve the same functionality. An example script for this conversion is [arbitrary_mask_to_post_scale_bias.py](https://raw.githubusercontent.com/NVIDIA/TransformerEngine/main/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py).\n" + "**Arbitrary mask:** cuDNN does not support `Arbitrary` mask type as of v9.3. However, users can convert the mask to a regular `post_scale_bias` bias and achieve the same functionality. An example script for this conversion is [arbitrary_mask_to_post_scale_bias.py](https://raw.githubusercontent.com/NVIDIA/TransformerEngine/main/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py).\n" ] }, { "cell_type": "code", - "execution_count": 6, - "id": "b1b7cdd4", + "execution_count": 33, + "id": "a1f25a9b", "metadata": {}, "outputs": [ { @@ -550,27 +544,29 @@ "output_type": "stream", "text": [ "Run with post_scale_bias:\n", - "[DotProductAttention]: using cuDNN attention (sub-backend 1)\n", + "[INFO | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)\n", + "\n", "Run with arbitrary mask:\n", - "[DotProductAttention]: using unfused DPA\n", + "[INFO | DotProductAttention]: Running with UnfusedDotProductAttention backend\n", + "\n", "Test passed!\n" ] } ], "source": [ - "!NVTE_DEBUG=1 python arbitrary_mask_to_post_scale_bias.py" + "!NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python arbitrary_mask_to_post_scale_bias.py" ] }, { "cell_type": "markdown", - "id": "e045c284", + "id": "dda4a589", "metadata": {}, "source": [ "Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py).\n", "\n", "### 3.3 Attention Bias\n", "\n", - "Transformer Engine supports 4 attention bias types, `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes). As of Transformer Engine 1.7, their support matrix is as follows.\n", + "Transformer Engine supports 4 attention bias types, `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes). As of Transformer Engine 1.10, their support matrix is as follows.\n", "\n", "\n", " \n", @@ -617,25 +613,20 @@ }, { "cell_type": "markdown", - "id": "8b8a4e40", + "id": "a0702339", "metadata": {}, "source": [ "### 3.4 FP8 Attention\n", "\n", "A unique feature of Transformer Engine is its FP8 support, not only for the `Linear` layers but also for dot product attention. Transformer Engine's FP8 attention support is through its cuDNN attention sub-backend 2. Recall Figure 1: the two `MatMul` operations are performed in FP8 for computational efficiency, and the `SoftMax` operation is performed in FP32 for numerical accuracy.\n", "\n", - "Transformer Engine supports FP8 attention through its [C APIs](../../api/c/fused_attn.rst), and [PyTorch API](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention), as of v1.7. Its PyTorch API offers two options, both controlled through the FP8 recipe definition, `transformer_engine.common.recipe.DelayedScaling`.\n", + "Transformer Engine supports FP8 attention through its [C APIs](../../api/c/fused_attn.rst), and [PyTorch API](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention), as of v1.10. Its PyTorch API offers two options, both controlled through the FP8 recipe definition, `transformer_engine.common.recipe.DelayedScaling`.\n", "\n", "- `DelayedScaling.fp8_dpa=True (default=False)`: This enables the use of cuDNN attention sub-backend 2, when it does support the provided user inputs. The `FusedAttention` module for cuDNN attention takes FP16 or BF16 tensors as inputs, performs dot product attention in FP8, and returns attention logits in FP16 or BF16 (same as the input type). Casting operations are required to cast tensors to FP8 at the beginning, and back to FP16/BF16 at the end of the module.\n", "\n", "- `DelayedScaling.fp8_mha=True (default=False)`: This option, on top of `fp8_dpa=True`, removes the casting operations at the beginning and end of the `FusedAttention` module. This feature is experimental. \n", "\n", - "Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`. This should result in the following print when the debug flags are turned on, `NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2`.\n", - "```\n", - "[DEBUG | DotProductAttention]: Running with fp8_recipe.fp8_mha=False, fp8_recipe.fp8_dpa=True and NVTE_FP8_DPA_BWD=0\n", - "[DEBUG | FusedAttnFunc ]: Running forward in FP8\n", - "[DEBUG | FusedAttnFunc ]: Running backward in torch.bfloat16\n", - "```" + "Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`." ] } ], diff --git a/docs/examples/attention/example_attention.py b/docs/examples/attention/example_attention.py index 2ed7303417..15022005a1 100644 --- a/docs/examples/attention/example_attention.py +++ b/docs/examples/attention/example_attention.py @@ -11,9 +11,7 @@ import transformer_engine from tests.pytorch.fused_attn.test_fused_attn import ( ModelConfig, - _is_flash_attention_supported, - _is_fused_attention_supported, - _is_unfused_attention_supported, + _get_attention_backends, _run_dot_product_attention, ) @@ -60,7 +58,6 @@ def example_attention(model, fused_attn_supported, flash_attn_supported): ckpt_attn, qkv_layout, workspace_opt, - swa, pad_between_seqs, is_training, ) @@ -75,7 +72,6 @@ def example_attention(model, fused_attn_supported, flash_attn_supported): ckpt_attn, qkv_layout, workspace_opt, - swa, pad_between_seqs, is_training, ) @@ -94,13 +90,14 @@ def main(): models = ["test_0"] for model in models: config = model_configs[model] - fused_attn_supported, fused_attn_backend = _is_fused_attention_supported( + available_backends, fused_attn_backends = _get_attention_backends( config, - dtype, + qkv_dtype=dtype, qkv_layout=qkv_layout, + window_size=config.window_size, + pad_between_seqs=pad_between_seqs, ) - fused_attn_supported = fused_attn_supported and not swa - flash_attn_supported = _is_flash_attention_supported(config) + flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends example_attention(model, fused_attn_supported, flash_attn_supported) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index ac82a5424e..82a3c8576b 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -8,6 +8,7 @@ import os from importlib.metadata import version from typing import Any, Dict, List, Tuple, Union, Optional +from contextlib import contextmanager import pytest import torch @@ -108,6 +109,16 @@ def __init__( self.window_size = window_size +@contextmanager +def logging_context(highest_level=logging.WARNING): + previous_level = logging.root.manager.disable + logging.disable(highest_level) + try: + yield + finally: + logging.disable(previous_level) + + def _get_attention_backends( config: ModelConfig, qkv_dtype: torch.dtype, @@ -180,12 +191,13 @@ def test(): return available_backends, fused_attention_backend backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"} - for i in range(3): - os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) - _attention_backends["backend_selection_requires_update"] = True - available_backends, fused_attention_backend = test() - if fused_attention_backend == FusedAttnBackend[backends[i]]: - fused_attn_backends.append(fused_attention_backend) + with logging_context(): + for i in range(3): + os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) + _attention_backends["backend_selection_requires_update"] = True + available_backends, fused_attention_backend = test() + if fused_attention_backend == FusedAttnBackend[backends[i]]: + fused_attn_backends.append(fused_attention_backend) return available_backends, fused_attn_backends From 516dacc2868e44cb61b5b46b667938fc739ed7b9 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> Date: Tue, 13 Aug 2024 13:49:30 -0700 Subject: [PATCH 50/72] Timing for build (#1048) * add timing for build * using perf_counter --------- Signed-off-by: Phuong Nguyen --- build_tools/build_ext.py | 5 +++++ setup.py | 22 +++++++++++++++++++++- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py index 631b2b3627..d7351a8617 100644 --- a/build_tools/build_ext.py +++ b/build_tools/build_ext.py @@ -10,6 +10,7 @@ import sys import sysconfig import copy +import time from pathlib import Path from subprocess import CalledProcessError @@ -81,6 +82,7 @@ def _build_cmake(self, build_dir: Path, install_dir: Path) -> None: build_command.append(str(max_jobs)) # Run CMake commands + start_time = time.perf_counter() for command in [configure_command, build_command, install_command]: print(f"Running command {' '.join(command)}") try: @@ -88,6 +90,9 @@ def _build_cmake(self, build_dir: Path, install_dir: Path) -> None: except (CalledProcessError, OSError) as e: raise RuntimeError(f"Error when running CMake: {e}") + total_time = time.perf_counter() - start_time + print(f"Time for build_ext: {total_time:.2f} seconds") + def get_build_ext(extension_cls: Type[setuptools.Extension]): class _CMakeBuildExtension(extension_cls): diff --git a/setup.py b/setup.py index 6a8bae2793..d1dd35c027 100644 --- a/setup.py +++ b/setup.py @@ -5,10 +5,12 @@ """Installation script.""" import os +import time from pathlib import Path from typing import List, Tuple import setuptools +from wheel.bdist_wheel import bdist_wheel from build_tools.build_ext import CMakeExtension, get_build_ext from build_tools.utils import ( @@ -39,10 +41,23 @@ install_and_import("pybind11[global]") from pybind11.setup_helpers import build_ext as BuildExtension +# Start timing +start_time = time.perf_counter() + CMakeBuildExtension = get_build_ext(BuildExtension) +class TimedBdist(bdist_wheel): + """Helper class to measure build time""" + + def run(self): + start_time = time.perf_counter() + super().run() + total_time = time.perf_counter() - start_time + print(f"Time for bdist_wheel: {total_time:.2f} seconds") + + def setup_common_extension() -> CMakeExtension: """Setup CMake extension for common library""" # Project directory root @@ -141,7 +156,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: }, description="Transformer acceleration library", ext_modules=ext_modules, - cmdclass={"build_ext": CMakeBuildExtension}, + cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, python_requires=">=3.8, <3.13", classifiers=[ "Programming Language :: Python :: 3.8", @@ -156,3 +171,8 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: include_package_data=True, package_data={"": ["VERSION.txt"]}, ) + + # End timing + end_time = time.perf_counter() + total_time = end_time - start_time + print(f"Total build time: {total_time:.2f} seconds") From 218f46cb990161f7b959419834ebce2cd1fa3bee Mon Sep 17 00:00:00 2001 From: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> Date: Tue, 13 Aug 2024 15:59:12 -0700 Subject: [PATCH 51/72] [Test] Remove test_dgeglu.cu which is already included in test_act.cu (#1097) rm test_dgeglu Signed-off-by: Phuong Nguyen --- tests/cpp/operator/CMakeLists.txt | 1 - tests/cpp/operator/test_dgeglu.cu | 126 ------------------------------ 2 files changed, 127 deletions(-) delete mode 100644 tests/cpp/operator/test_dgeglu.cu diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index e302be57bd..e590d8e92a 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -10,7 +10,6 @@ add_executable(test_operator test_cast_transpose_dbias_dgelu.cu test_cast_transpose_dgeglu.cu test_act.cu - test_dgeglu.cu test_layernorm.cu test_rmsnorm.cu test_multi_cast_transpose.cu diff --git a/tests/cpp/operator/test_dgeglu.cu b/tests/cpp/operator/test_dgeglu.cu deleted file mode 100644 index 0924e2b4c9..0000000000 --- a/tests/cpp/operator/test_dgeglu.cu +++ /dev/null @@ -1,126 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#include -#include "../test_common.h" - -using namespace transformer_engine; - -namespace { - -template -inline CType gelu(const IType val) { - CType cval = val; - return cval * (0.5f + 0.5f * tanhf(cval * (0.79788456f + 0.03567741f * cval * cval))); -} - -template -inline CType dgelu(const IType val) { - CType cval = val; - const CType tanh_out = tanhf(0.79788456f * cval * (1.f + 0.044715f * cval * cval)); - return 0.5f * cval * ((1.f - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * cval * cval)) + - 0.5f * (1.f + tanh_out); -} - -template -void compute_ref_dgeglu(const IT *grad_h, const IT *input_h, OT *output_h, const size_t N, - const size_t H) { - const size_t col = H * 2; - - for (size_t i = 0; i < N; i++) { - for (size_t j = 0; j < H; j++) { - CT grad_elt = CT(grad_h[i * H + j]); - CT gelu_elt = CT(input_h[i * col + j]); - CT gate_elt = CT(input_h[i * col + H + j]); - - CT after_dgelu = dgelu(gelu_elt) * grad_elt * gate_elt; - CT after_dgate = grad_elt * gelu(gelu_elt); - - output_h[i * col + j] = OT(after_dgelu); - output_h[i * col + H + j] = OT(after_dgate); - } - } -} - -template -void performTestDGeGLU(const size_t N, const size_t H) { - using namespace test; - - using CType = fp32; - - DType itype = TypeInfo::dtype; - DType otype = TypeInfo::dtype; - - Tensor grad({N, H}, itype); - Tensor input({N, H * 2}, itype); - Tensor output({N, H * 2}, otype); - - fillUniform(&grad); - fillUniform(&input); - - std::unique_ptr ref_output = std::make_unique(N * H * 2); - - nvte_dgeglu(grad.data(), input.data(), output.data(), 0); - - compute_ref_dgeglu(grad.cpu_dptr(), input.cpu_dptr(), - ref_output.get(), N, H); - - cudaDeviceSynchronize(); - auto err = cudaGetLastError(); - ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); - - auto [atol, rtol] = getTolerances(otype); - compareResults("output_dgelu", output, ref_output.get(), atol, rtol); -} - -std::vector> test_cases = { - {4096, 2048}, {768, 2816}, {256, 5120}, {128, 10240}, {256, 256}, {257, 259}, {128, 128 + 1}}; - -} // namespace - -class DGeGLUTestSuite - : public ::testing::TestWithParam>> {}; - -TEST_P(DGeGLUTestSuite, TestDGeGLU) { - using namespace transformer_engine; - using namespace test; - - const DType input_type = std::get<0>(GetParam()); - const DType output_type = std::get<1>(GetParam()); - const auto size = std::get<2>(GetParam()); - - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( - input_type, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( - output_type, OutputType, - performTestDGeGLU(size.first, size.second););); -} - -INSTANTIATE_TEST_SUITE_P( - OperatorTest, DGeGLUTestSuite, - ::testing::Combine(::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), - ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), - ::testing::ValuesIn(test_cases)), - [](const testing::TestParamInfo &info) { - std::string name = test::typeName(std::get<0>(info.param)) + "X" + - test::typeName(std::get<1>(info.param)) + "X" + - std::to_string(std::get<2>(info.param).first) + "X" + - std::to_string(std::get<2>(info.param).second); - return name; - }); From dcc50c8e3d48453fa6c5bd90455a67eedffe8201 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Tue, 13 Aug 2024 15:59:34 -0700 Subject: [PATCH 52/72] Use minimal CUDA container for PyTorch GitHub build (#1091) * Use minimal CUDA container for PyTorch GitHub build Signed-off-by: Tim Moon * Accidentally installed PyTorch in wrong test Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Debug sanity test Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Install PyTorch build dependencies Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Include NumPy as a dependency Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Disable sanity import test Signed-off-by: Tim Moon --------- Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- .github/workflows/build.yml | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index fb7ab345d1..313aee6ab8 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -35,9 +35,14 @@ jobs: name: 'PyTorch' runs-on: ubuntu-latest container: - image: nvcr.io/nvidia/pytorch:24.05-py3 + image: nvcr.io/nvidia/cuda:12.5.0-devel-ubuntu22.04 options: --user root steps: + - name: 'Dependencies' + run: | + apt-get update + apt-get install -y git python3.9 pip ninja-build cudnn9-cuda-12 + pip install cmake torch pydantic importlib-metadata>=1.0 packaging pybind11 - name: 'Checkout' uses: actions/checkout@v3 with: @@ -48,7 +53,8 @@ jobs: NVTE_FRAMEWORK: pytorch MAX_JOBS: 1 - name: 'Sanity check' - run: python tests/pytorch/test_sanity_import.py + if: false # Sanity import test requires Flash Attention + run: python3 tests/pytorch/test_sanity_import.py jax: name: 'JAX' runs-on: ubuntu-latest From 4b2b39b4e1c64b841a8c361fccf046b15a4ad60f Mon Sep 17 00:00:00 2001 From: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> Date: Tue, 13 Aug 2024 19:09:25 -0700 Subject: [PATCH 53/72] [TE/JAX] Prototype for New XLA Custom Calls with FFI (#946) * implemented custom call with ffi in csrc * moved headers of misc to misc.h, add ffi.h * ActLu and DActLu lowering with ffi_lowering * CastTranspose with ffi_lowering * enabled cudaGraph * added 4d input test case to TestActivationLu * added operand_output_aliases for CastTranspose * added env var NVTE_JAX_WITH_FFI, default value = 1 * replace casting ActivationEnum by taking its value --------- Signed-off-by: Phuong Nguyen --- build_tools/jax.py | 4 + tests/jax/test_custom_call_compute.py | 16 +-- .../jax/cpp_extensions/activation.py | 107 +++++++++-------- .../jax/cpp_extensions/custom_call.py | 25 +++- transformer_engine/jax/cpp_extensions/misc.py | 25 ++++ .../jax/cpp_extensions/transpose.py | 82 +++++++------ transformer_engine/jax/csrc/extensions.h | 25 ++-- .../jax/csrc/extensions/activation.cpp | 112 ++++++++++++++++-- .../jax/csrc/extensions/ffi.cpp | 51 ++++++++ transformer_engine/jax/csrc/extensions/ffi.h | 25 ++++ transformer_engine/jax/csrc/extensions/misc.h | 30 +++++ .../jax/csrc/extensions/normalization.cpp | 1 - .../jax/csrc/extensions/pybind.cpp | 14 ++- .../jax/csrc/extensions/transpose.cpp | 56 +++++++++ 14 files changed, 450 insertions(+), 123 deletions(-) create mode 100644 transformer_engine/jax/csrc/extensions/ffi.cpp create mode 100644 transformer_engine/jax/csrc/extensions/ffi.h create mode 100644 transformer_engine/jax/csrc/extensions/misc.h diff --git a/build_tools/jax.py b/build_tools/jax.py index 72a22f683e..aba9c749d0 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -11,6 +11,8 @@ from .utils import cuda_path, all_files_in_dir from typing import List +from jax.extend import ffi + def setup_jax_extension( csrc_source_files, @@ -27,12 +29,14 @@ def setup_jax_extension( # Header files cuda_home, _ = cuda_path() + jax_ffi_include = ffi.include_dir() include_dirs = [ cuda_home / "include", common_header_files, common_header_files / "common", common_header_files / "common" / "include", csrc_header_files, + jax_ffi_include, ] # Compile flags diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 5006f87a9d..6991d83d4c 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -432,7 +432,7 @@ def ref_act_lu(inputs): def primitive_func(self, inputs): return jnp.mean(activation_lu(inputs, activation_type=self.activation_type)) - @pytest.mark.parametrize("shape", [(32, 1, 64), (64, 1, 256)]) + @pytest.mark.parametrize("shape", [(32, 1, 64), (16, 64, 1, 256)]) @pytest.mark.parametrize( "activation_type", [ @@ -450,7 +450,7 @@ def primitive_func(self, inputs): ) def test_activation_lu(self, random_inputs, activation_type): x = random_inputs - x = jnp.repeat(x, len(activation_type), axis=1) + x = jnp.repeat(x, len(activation_type), axis=-2) self.activation_type = activation_type value_n_grad_primitive_func = jit(value_and_grad(self.primitive_func, (0,))) @@ -511,7 +511,7 @@ def _prim_func_bwd(ctx, g): _prim_func.defvjp(_prim_func_fwd, _prim_func_bwd) - dx_trans_no_use = jnp.empty([x.shape[i] for i in self.transpose_indices], dtype=x.dtype) + dx_trans_no_use = jnp.empty([x.shape[i] for i in self.transpose_axes], dtype=x.dtype) dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype) amax_no_use = jnp.zeros(1, jnp.float32) value_n_grad_primitive_func = value_and_grad( @@ -520,7 +520,7 @@ def _prim_func_bwd(ctx, g): return value_n_grad_primitive_func(x, dx_trans_no_use, dbias_no_use, amax_no_use) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("shape", [(32, 1, 64), (64, 1, 256)]) + @pytest.mark.parametrize("shape", [(32, 1, 64), (16, 64, 1, 256)]) @pytest.mark.parametrize( "activation_type", [ @@ -541,10 +541,12 @@ def test_activation_lu(self, random_inputs, activation_type): self.scale = jnp.ones(1, jnp.float32) self.scale_inv = jnp.ones(1, jnp.float32) self.activation_type = activation_type - self.transpose_indices = (1, 2, 0) x = random_inputs - x = jnp.repeat(x, len(activation_type), axis=1) + x = jnp.repeat(x, len(activation_type), axis=-2) + axes = jnp.arange(x.ndim) + self.transpose_axes = tuple([*axes[-2:]] + [*axes[:-2]]) + print(self.transpose_axes) prim_out, (prim_grad, prim_grad_trans, dbias, amax) = self.prim_func(x) ref_out, (ref_grad,) = self.ref_func(x, activation_type) @@ -556,7 +558,7 @@ def test_activation_lu(self, random_inputs, activation_type): assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE) assert_allclose( prim_grad_trans, - jnp.transpose(ref_grad, self.transpose_indices), + jnp.transpose(ref_grad, self.transpose_axes), dtype=FP8Helper.BWD_DTYPE, ) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index bdc377cb27..56359646b1 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -11,6 +11,7 @@ from jax import core, dtypes from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding +from jax.extend import ffi from transformer_engine import transformer_engine_jax from transformer_engine.transformer_engine_jax import NVTE_Activation_Type @@ -22,6 +23,7 @@ jax_dtype_to_te_dtype, jax_dtype_to_ir_dtype, get_padded_spec, + is_ffi_enabled, ) from .quantization import _jax_cast_fp8 from ..sharding import all_reduce_max_along_all_axes_except_PP @@ -109,25 +111,29 @@ def lowering(ctx, x, *, act_enum): """ (x_aval,) = ctx.avals_in assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - ir_x_type = ir.RankedTensorType(x.type) - ir_x_shape = ir_x_type.shape - out_shape = ir_x_shape[:-2] + [ir_x_shape[-1]] - - out_types = [ - ir.RankedTensorType.get(out_shape, ir_x_type.element_type), - ] - operands = [x] - operand_shapes = [ir_x_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - hidden_size = ir_x_shape[-1] - batch_size = reduce(operator.mul, ir_x_shape[:-2]) - in_dtype = jax_dtype_to_te_dtype(x_aval.dtype) - opaque = transformer_engine_jax.pack_common_descriptor( - (batch_size, hidden_size), in_dtype, in_dtype, act_enum - ) + if is_ffi_enabled(): + name = "te_act_lu_ffi" + out = ffi.ffi_lowering(name)(ctx, x, act_enum=act_enum) + else: + ir_x_type = ir.RankedTensorType(x.type) + ir_x_shape = ir_x_type.shape + out_shape = ir_x_shape[:-2] + [ir_x_shape[-1]] + + out_types = [ + ir.RankedTensorType.get(out_shape, ir_x_type.element_type), + ] + operands = [x] + operand_shapes = [ir_x_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + hidden_size = ir_x_shape[-1] + batch_size = reduce(operator.mul, ir_x_shape[:-2]) + in_dtype = jax_dtype_to_te_dtype(x_aval.dtype) + opaque = transformer_engine_jax.pack_common_descriptor( + (batch_size, hidden_size), in_dtype, in_dtype, act_enum + ) - out = custom_caller(ActLuPrimitive.name, args, opaque, False) + out = custom_caller(ActLuPrimitive.name, args, opaque, False) return out @@ -189,7 +195,7 @@ def act_lu(inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]) if not ActLuPrimitive.enabled(): return _jax_act_lu(inputs, activation_type) - act_type_id = ActivationEnum[activation_type] + act_type_id = ActivationEnum[activation_type].value return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_type_id) @@ -231,34 +237,38 @@ def lowering(ctx, dz, x, *, act_enum): in_aval, gi_aval = ctx.avals_in assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert gi_aval.dtype == in_aval.dtype - ir_in_type = ir.RankedTensorType(dz.type) - ir_in_shape = ir_in_type.shape - gi_type = ir.RankedTensorType(x.type) - gi_shape = gi_type.shape - # assert ir_in_shape == gi_shape - for axis in range(len(ir_in_shape) - 1): - assert ir_in_shape[axis] == gi_shape[axis] - - ir_batch_size = reduce(operator.mul, ir_in_shape[:-1]) - i_hidden_size = ir_in_shape[-1] - g_hidden_size = gi_shape[-1] - assert i_hidden_size == g_hidden_size - out_dtype = ir_in_type.element_type - out_shape = gi_shape - - out_types = [ - ir.RankedTensorType.get(out_shape, out_dtype), - ] - operands = [dz, x] - operand_shapes = [ir_in_shape, gi_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - in_dtype = jax_dtype_to_te_dtype(in_aval.dtype) - opaque = transformer_engine_jax.pack_common_descriptor( - (ir_batch_size, i_hidden_size), in_dtype, in_dtype, act_enum - ) + if is_ffi_enabled(): + name = "te_dact_lu_ffi" + out = ffi.ffi_lowering(name)(ctx, dz, x, act_enum=act_enum) + else: + ir_in_type = ir.RankedTensorType(dz.type) + ir_in_shape = ir_in_type.shape + gi_type = ir.RankedTensorType(x.type) + gi_shape = gi_type.shape + # assert ir_in_shape == gi_shape + for axis in range(len(ir_in_shape) - 1): + assert ir_in_shape[axis] == gi_shape[axis] + + ir_batch_size = reduce(operator.mul, ir_in_shape[:-1]) + i_hidden_size = ir_in_shape[-1] + g_hidden_size = gi_shape[-1] + assert i_hidden_size == g_hidden_size + out_dtype = ir_in_type.element_type + out_shape = gi_shape + + out_types = [ + ir.RankedTensorType.get(out_shape, out_dtype), + ] + operands = [dz, x] + operand_shapes = [ir_in_shape, gi_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + in_dtype = jax_dtype_to_te_dtype(in_aval.dtype) + opaque = transformer_engine_jax.pack_common_descriptor( + (ir_batch_size, i_hidden_size), in_dtype, in_dtype, act_enum + ) - out = custom_caller(DActLuPrimitive.name, args, opaque, False) + out = custom_caller(DActLuPrimitive.name, args, opaque, False) return out @@ -320,12 +330,11 @@ def dact_lu( dact_lu fusion wrapper Return dgated_act_lu(inputs) """ - if not DActLuPrimitive.enabled(): _, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), act_lu_inputs) return vjp_func(inputs)[0] - act_type_id = ActivationEnum[activation_type] + act_type_id = ActivationEnum[activation_type].value return DActLuPrimitive.outer_primitive.bind(inputs, act_lu_inputs, act_enum=act_type_id) @@ -487,7 +496,7 @@ def act_lu_fp8( casted_output, updated_amax = _jax_cast_fp8(act_lu_output, scale, amax, out_dtype) return casted_output, updated_amax - act_type_id = ActivationEnum[activation_type] + act_type_id = ActivationEnum[activation_type].value return ActLuFp8Primitive.outer_primitive.bind( x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_type_id ) diff --git a/transformer_engine/jax/cpp_extensions/custom_call.py b/transformer_engine/jax/cpp_extensions/custom_call.py index 36396a977c..8e58ed3bed 100644 --- a/transformer_engine/jax/cpp_extensions/custom_call.py +++ b/transformer_engine/jax/cpp_extensions/custom_call.py @@ -3,12 +3,14 @@ # See LICENSE for license information. """JAX/TE custom call""" from dataclasses import dataclass +from enum import IntEnum from jax.lib import xla_client from jax.interpreters import mlir from transformer_engine import transformer_engine_jax +from .misc import is_ffi_enabled try: from jaxlib.hlo_helpers import custom_call @@ -17,8 +19,25 @@ # version, so we still need this import. pass + +class CustomCallAPIVersion(IntEnum): + """Enum for selecting between old and new custom call registration API""" + + OPAQUE = 0 + FFI = 1 + + for _name, _value in transformer_engine_jax.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="CUDA") + if _name.endswith("_ffi"): + if is_ffi_enabled(): + # COMMAND_BUFFER_COMPATIBLE i.e. cudaGraph enabled by default + xla_client.register_custom_call_target( + _name, _value, platform="CUDA", api_version=CustomCallAPIVersion.FFI.value + ) + else: + xla_client.register_custom_call_target( + _name, _value, platform="CUDA", api_version=CustomCallAPIVersion.OPAQUE.value + ) @dataclass @@ -79,7 +98,7 @@ def custom_caller(name, args, opaque, has_side_effect, **kwargs): result_layouts=args.output_layouts, backend_config=opaque, has_side_effect=has_side_effect, - **kwargs + **kwargs, ).results else: # Need to disable one pylint error as the second function @@ -93,6 +112,6 @@ def custom_caller(name, args, opaque, has_side_effect, **kwargs): result_layouts=args.output_layouts, backend_config=opaque, has_side_effect=has_side_effect, - **kwargs + **kwargs, ) return out diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 9ad7354815..58b8db4c88 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -3,10 +3,14 @@ # See LICENSE for license information. """JAX/TE miscellaneous for custom ops""" +import os import functools from typing import Tuple +from importlib.metadata import version as get_pkg_version +from packaging.version import Version as PkgVersion import numpy as np + import jax.numpy as jnp from jax import dtypes from jax.interpreters.mlir import dtype_to_ir_type @@ -142,3 +146,24 @@ def get_cudnn_version() -> Tuple[int, int, int]: major, encoded_version = divmod(encoded_version, major_version_magnitude) minor, patch = divmod(encoded_version, 100) return (major, minor, patch) + + +@functools.lru_cache(maxsize=None) +def jax_version_meet_requirement(version: str): + """ + Helper function checking if required JAX version is available + """ + jax_version = PkgVersion(get_pkg_version("jax")) + jax_version_required = PkgVersion(version) + return jax_version >= jax_version_required + + +def is_ffi_enabled(): + """ + Helper function checking if XLA Custom Call with FFI is enabled + """ + is_supported = jax_version_meet_requirement("0.4.31") + # New APIs with FFI are enabled by default + is_enabled = int(os.getenv("NVTE_JAX_WITH_FFI", "1")) + assert is_enabled in (0, 1), "Invalid NVTE_JAX_WITH_FFI value" + return is_supported and is_enabled diff --git a/transformer_engine/jax/cpp_extensions/transpose.py b/transformer_engine/jax/cpp_extensions/transpose.py index cc64951a95..e503792dc0 100644 --- a/transformer_engine/jax/cpp_extensions/transpose.py +++ b/transformer_engine/jax/cpp_extensions/transpose.py @@ -11,6 +11,7 @@ from jax import dtypes from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding +from jax.extend import ffi from transformer_engine import transformer_engine_jax from transformer_engine.transformer_engine_jax import DType as TEDType @@ -25,6 +26,7 @@ get_padded_spec, multidim_transpose, normalize_axis_boundary, + is_ffi_enabled, ) from .activation import ActivationEnum from .activation import _jax_act_lu @@ -262,45 +264,49 @@ def lowering( assert amax_aval.dtype == jnp.float32 assert scale_aval.dtype == jnp.float32 assert scale_inv_aval.dtype == jnp.float32 - ir_x_type = ir.RankedTensorType(x.type) - ir_x_shape = ir_x_type.shape - if static_axis_boundary >= 0: - for i in range(static_axis_boundary + 1): - assert ir_x_shape[i] == 1 - ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) - ir_amax_type = ir.RankedTensorType(amax.type) - ir_amax_dtype = ir_amax_type.element_type - ir_amax_shape = ir_amax_type.shape - ir_scale_shape = ir_amax_shape - ir_scale_inv_shape = ir_amax_shape - - transposed_x_shape = multidim_transpose( - ir_x_shape, static_axis_boundary, transpose_axis_boundary - ) - - out_types = [ - ir.RankedTensorType.get(ir_x_shape, ir_out_dtype), - ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype), - ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), - ] - operands = [x, amax, scale, scale_inv] - operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - contracted_x_shape = ( - reduce(operator.mul, ir_x_shape[:transpose_axis_boundary]), - reduce(operator.mul, ir_x_shape[transpose_axis_boundary:]), - ) - opaque = transformer_engine_jax.pack_common_descriptor( - contracted_x_shape, - jax_dtype_to_te_dtype(x_aval.dtype), - jax_dtype_to_te_dtype(out_dtype), - ) - - out = custom_caller( - CastTransposePrimitive.name, args, opaque, False, operand_output_aliases={1: 2} - ) + if is_ffi_enabled(): + name = "te_cast_transpose_ffi" + out = ffi.ffi_lowering(name, operand_output_aliases={1: 2})( + ctx, x, amax, scale, scale_inv, transpose_axis=transpose_axis_boundary + ) + else: + ir_x_type = ir.RankedTensorType(x.type) + ir_x_shape = ir_x_type.shape + if static_axis_boundary >= 0: + for i in range(static_axis_boundary + 1): + assert ir_x_shape[i] == 1 + ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) + ir_amax_type = ir.RankedTensorType(amax.type) + ir_amax_dtype = ir_amax_type.element_type + ir_amax_shape = ir_amax_type.shape + ir_scale_shape = ir_amax_shape + ir_scale_inv_shape = ir_amax_shape + + transposed_x_shape = multidim_transpose( + ir_x_shape, static_axis_boundary, transpose_axis_boundary + ) + out_types = [ + ir.RankedTensorType.get(ir_x_shape, ir_out_dtype), + ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype), + ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), + ] + operands = [x, amax, scale, scale_inv] + operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + contracted_x_shape = ( + reduce(operator.mul, ir_x_shape[:transpose_axis_boundary]), + reduce(operator.mul, ir_x_shape[transpose_axis_boundary:]), + ) + opaque = transformer_engine_jax.pack_common_descriptor( + contracted_x_shape, + jax_dtype_to_te_dtype(x_aval.dtype), + jax_dtype_to_te_dtype(out_dtype), + ) + out = custom_caller( + CastTransposePrimitive.name, args, opaque, False, operand_output_aliases={1: 2} + ) return out @staticmethod diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index c084ab09e9..433f4f770d 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -13,8 +13,6 @@ #include #include #include -#include -#include #include #include @@ -27,23 +25,14 @@ #include "common/common.h" #include "common/util/logging.h" +#include "extensions/ffi.h" +#include "extensions/misc.h" +#include "transformer_engine/activation.h" #include "utils.h" namespace transformer_engine { namespace jax { -constexpr int kMaxNumDim = 8; - -// TODO: Rename Shape to ??? -struct Shape { - int num_dim; - size_t dims[kMaxNumDim]; - - void from_vector(const std::vector &shape); - - std::vector to_vector() const; -}; - // Phuong: These 3 functions need to stay in the header file for compilation purpose // 1. inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } @@ -62,8 +51,6 @@ const T *UnpackOpaque(const char *opaque, size_t opaque_len) { return reinterpret_cast(opaque); } -std::vector MakeShapeVector(NVTEShape shape); - // Packing struct CustomCallCommonDescriptor { @@ -167,6 +154,8 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype); +XLA_FFI_DECLARE_HANDLER_SYMBOL(CastTransposeHandler); + void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); // Activation @@ -179,6 +168,10 @@ void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t op void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuHandler); + pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype); diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 51563a8ccd..1e8998b365 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -3,15 +3,16 @@ * * See LICENSE for license information. ************************************************************************/ - #include "transformer_engine/activation.h" #include "extensions.h" #include "transformer_engine/transpose.h" +#include "xla/ffi/api/c_api.h" namespace transformer_engine { namespace jax { +// TODO: We won't need this function anymore when we move to the new XLA custom calls size_t get_activation_len(NVTE_Activation_Type activation_enum) { switch (activation_enum) { case NVTE_Activation_Type::GELU: @@ -43,8 +44,7 @@ size_t get_activation_len(NVTE_Activation_Type activation_enum) { void ActLuImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale, cudaStream_t stream, float *scale_inverse, float *amax, void *output, - NVTE_Activation_Type act_enum) { - auto act_len = get_activation_len(act_enum); + NVTE_Activation_Type act_enum, size_t act_len) { auto input_shape = std::vector{m, n * act_len}; auto output_shape = std::vector{m, n}; auto input_tensor = TensorWrapper(input, input_shape, static_cast(in_dtype)); @@ -95,12 +95,39 @@ void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaqu auto m = desc.shape.dims[0]; auto n = desc.shape.dims[1]; auto act_enum = static_cast(desc.act_enum); - ; + auto act_len = get_activation_len(act_enum); ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr, nullptr, output, - act_enum); + act_enum, act_len); } +Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type output_buf, + int64_t act_enum) { + auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); + + auto *input = input_buf.untyped_data(); + auto *output = output_buf->untyped_data(); + + auto input_dims = input_buf.dimensions(); + auto m = std::accumulate(input_dims.begin(), input_dims.end() - 2, 1, std::multiplies<>()); + auto n = input_dims.back(); + auto act_len = input_dims.end()[-2]; + auto act_type = static_cast(act_enum); + + ActLuImpl(input, m, n, in_dtype, out_dtype, nullptr, stream, nullptr, nullptr, output, act_type, + act_len); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Ret() // output + .Attr("act_enum")); + void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { auto *input = buffers[0]; float *amax = reinterpret_cast(buffers[1]); @@ -119,10 +146,10 @@ void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t op auto m = desc.shape.dims[0]; auto n = desc.shape.dims[1]; auto act_enum = static_cast(desc.act_enum); - ; + auto act_len = get_activation_len(act_enum); ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out, output, - act_enum); + act_enum, act_len); } void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -134,7 +161,6 @@ void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaq auto m = desc.shape.dims[0]; auto n = desc.shape.dims[1]; auto act_enum = static_cast(desc.act_enum); - ; auto act_len = get_activation_len(act_enum); auto input_shape = std::vector{m, n}; @@ -182,6 +208,76 @@ void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaq } } +Error_Type DActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf, + Result_Type output_buf, int64_t act_enum) { + auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); + + auto *input = input_buf.untyped_data(); + auto *act_input = act_input_buf.untyped_data(); + auto *output = output_buf->untyped_data(); + + auto act_input_dims = act_input_buf.dimensions(); + auto m = + std::accumulate(act_input_dims.begin(), act_input_dims.end() - 2, 1, std::multiplies<>()); + auto n = act_input_dims.back(); + auto act_len = act_input_dims.end()[-2]; + + auto input_shape = std::vector{m, n}; + auto act_input_shape = std::vector{m, n * act_len}; + auto output_shape = std::vector{m, n * act_len}; + + auto input_tensor = TensorWrapper(input, input_shape, static_cast(in_dtype)); + auto act_input_tensor = TensorWrapper(act_input, act_input_shape, static_cast(in_dtype)); + auto output_tensor = TensorWrapper(output, output_shape, static_cast(out_dtype)); + + auto act_type = static_cast(act_enum); + switch (act_type) { + case NVTE_Activation_Type::GELU: + nvte_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::GEGLU: + nvte_dgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::SILU: + nvte_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::SWIGLU: + nvte_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::RELU: + nvte_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::REGLU: + nvte_dreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::QGELU: + nvte_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::QGEGLU: + nvte_dqgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::SRELU: + nvte_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::SREGLU: + nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; + default: + NVTE_ERROR("Unsupported ActivationEnum"); + break; + } + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuHandler, DActLuFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // act_input + .Ret() // output + .Attr("act_enum")); + pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype) { auto input_shape = std::vector{batch_size, hidden_size}; diff --git a/transformer_engine/jax/csrc/extensions/ffi.cpp b/transformer_engine/jax/csrc/extensions/ffi.cpp new file mode 100644 index 0000000000..19fd50cbd1 --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/ffi.cpp @@ -0,0 +1,51 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ +#include "extensions/ffi.h" + +#include + +#include "common/util/logging.h" + +namespace transformer_engine { +namespace jax { + +// For XLA_FFI_DataType Enum Reference: https://github.com/openxla/xla/blob/d054e8366c4e8807726961feeb28b1cdba681888/xla/ffi/api/c_api.h#L163-L186 +DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) { + switch (type) { + case xla::ffi::DataType::F16: + return DType::kFloat16; + break; + case xla::ffi::DataType::F32: + return DType::kFloat32; + break; + case xla::ffi::DataType::BF16: + return DType::kBFloat16; + break; + case xla::ffi::DataType::F8E5M2: + return DType::kFloat8E5M2; + break; + case xla::ffi::DataType::F8E4M3FN: + return DType::kFloat8E4M3; + break; + default: + auto type_num = static_cast(type); + NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType %d", + static_cast(type_num)); + break; + } +} + +Error_Type ffi_with_cuda_error_check() { + cudaError_t last_error = cudaGetLastError(); + if (last_error != cudaSuccess) { + return Error_Type(XLA_FFI_Error_Code_INTERNAL, + std::string("CUDA error: ") + cudaGetErrorString(last_error)); + } + return Error_Type::Success(); +} + +} // namespace jax +} // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/ffi.h b/transformer_engine/jax/csrc/extensions/ffi.h new file mode 100644 index 0000000000..77132c3fca --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/ffi.h @@ -0,0 +1,25 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include + +#include + +namespace transformer_engine { +namespace jax { + +using Buffer_Type = xla::ffi::AnyBuffer; +using Result_Type = xla::ffi::Result; +using Error_Type = xla::ffi::Error; +using FFI = xla::ffi::Ffi; +using FFI_Stream_Type = xla::ffi::PlatformStream; + +DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type); +Error_Type ffi_with_cuda_error_check(); + +} // namespace jax +} // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h new file mode 100644 index 0000000000..7f6179e91c --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -0,0 +1,30 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include +#include +#include + +namespace transformer_engine { +namespace jax { + +constexpr int kMaxNumDim = 8; + +struct Shape { + int num_dim; + size_t dims[kMaxNumDim]; + + void from_vector(const std::vector &shape); + + std::vector to_vector() const; +}; + +std::vector MakeShapeVector(NVTEShape shape); + +} // namespace jax +} // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp b/transformer_engine/jax/csrc/extensions/normalization.cpp index 9585e2edf1..14e046fa54 100644 --- a/transformer_engine/jax/csrc/extensions/normalization.cpp +++ b/transformer_engine/jax/csrc/extensions/normalization.cpp @@ -3,7 +3,6 @@ * * See LICENSE for license information. ************************************************************************/ - #include "extensions.h" #include "transformer_engine/layer_norm.h" #include "transformer_engine/rmsnorm.h" diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index fb293f2fe1..0a2172bb1b 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -14,6 +14,13 @@ pybind11::capsule EncapsulateFunction(T *fn) { return pybind11::capsule(reinterpret_cast(fn), "xla._CUSTOM_CALL_TARGET"); } +template +pybind11::capsule EncapsulateFFI(T *fn) { + static_assert(std::is_invocable_r_v, + "Encapsulated function must be an XLA FFI handler"); + return pybind11::capsule(reinterpret_cast(fn), "xla._CUSTOM_CALL_TARGET"); +} + pybind11::dict Registrations() { pybind11::dict dict; dict["te_transpose"] = EncapsulateFunction(Transpose); @@ -44,6 +51,10 @@ pybind11::dict Registrations() { EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward); dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward); dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward); + + dict["te_cast_transpose_ffi"] = EncapsulateFFI(CastTransposeHandler); + dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler); + dict["te_dact_lu_ffi"] = EncapsulateFFI(DActLuHandler); return dict; } @@ -114,7 +125,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("QGELU", NVTE_Activation_Type::QGELU) .value("QGEGLU", NVTE_Activation_Type::QGEGLU) .value("SRELU", NVTE_Activation_Type::SRELU) - .value("SREGLU", NVTE_Activation_Type::SREGLU); + .value("SREGLU", NVTE_Activation_Type::SREGLU) + .export_values(); pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend) diff --git a/transformer_engine/jax/csrc/extensions/transpose.cpp b/transformer_engine/jax/csrc/extensions/transpose.cpp index 3e53b7521f..7a2e31312a 100644 --- a/transformer_engine/jax/csrc/extensions/transpose.cpp +++ b/transformer_engine/jax/csrc/extensions/transpose.cpp @@ -7,6 +7,7 @@ #include "transformer_engine/transpose.h" #include "extensions.h" +#include "xla/ffi/api/ffi.h" namespace transformer_engine { namespace jax { @@ -66,6 +67,61 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size stream); } +Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, + Buffer_Type scale_buf, Buffer_Type scale_inv_buf, + Result_Type input_cast_buf, Result_Type input_cast_trans_buf, + Result_Type amax_out_buf, int64_t transpose_axis) { + auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(input_cast_buf->element_type()); + + auto *input = input_buf.untyped_data(); + float *amax = reinterpret_cast(amax_buf.untyped_data()); + float *scale = reinterpret_cast(scale_buf.untyped_data()); + float *scale_inv = reinterpret_cast(scale_inv_buf.untyped_data()); + + auto *input_cast = input_cast_buf->untyped_data(); + auto *input_cast_trans = input_cast_trans_buf->untyped_data(); + float *amax_out = reinterpret_cast(amax_out_buf->untyped_data()); + assert(amax == amax_out); + + if (!use_fp8(out_dtype)) { + scale = nullptr; + scale_inv = nullptr; + amax_out = nullptr; + } + + auto input_dims = input_buf.dimensions(); + if (transpose_axis < 0) transpose_axis += input_dims.size(); + auto m = std::accumulate(input_dims.begin(), input_dims.begin() + transpose_axis, 1, + std::multiplies<>()); + auto n = std::accumulate(input_dims.begin() + transpose_axis, input_dims.end(), 1, + std::multiplies<>()); + auto input_shape = std::vector{m, n}; + auto input_trans_shape = std::vector{n, m}; + + auto input_tensor = TensorWrapper(input, input_shape, in_dtype); + auto input_cast_tensor = + TensorWrapper(input_cast, input_shape, out_dtype, amax_out, scale, scale_inv); + auto input_cast_trans_tensor = + TensorWrapper(input_cast_trans, input_trans_shape, out_dtype, amax_out, scale, scale_inv); + + nvte_cast_transpose(input_tensor.data(), input_cast_tensor.data(), input_cast_trans_tensor.data(), + stream); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(CastTransposeHandler, CastTransposeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // amax + .Arg() // scale + .Arg() // scale_inv + .Ret() // input_cast + .Ret() // input_cast_trans + .Ret() // amax_out + .Attr("transpose_axis")); + pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype) { auto input_shape = std::vector{batch_size, hidden_size}; From 0075a46aeb896045f0701fc4433d678a064a8e0c Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 13 Aug 2024 19:09:52 -0700 Subject: [PATCH 54/72] Upgrade NLTK version to circumvent unsafe pickling in v3.8.1 (#1102) * Switch to nltk>3.8.1 and new data Signed-off-by: Kirthi Shankar Sivamani * fix nltk install Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- examples/jax/encoder/requirements.txt | 2 +- examples/jax/encoder/test_model_parallel_encoder.py | 2 +- examples/jax/encoder/test_multigpu_encoder.py | 2 +- examples/jax/encoder/test_multiprocessing_encoder.py | 2 +- examples/jax/encoder/test_single_gpu_encoder.py | 2 +- qa/L0_jax_unittest/test.sh | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/jax/encoder/requirements.txt b/examples/jax/encoder/requirements.txt index 40b1915c96..26af82aa49 100644 --- a/examples/jax/encoder/requirements.txt +++ b/examples/jax/encoder/requirements.txt @@ -1,4 +1,4 @@ datasets flax>=0.7.1 -nltk +nltk>=3.8.2 optax diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 716d543d5b..646d6e0a12 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -168,7 +168,7 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn): def data_preprocess(dataset, vocab, word_id, max_seq_len): """Convert tokens to numbers.""" - nltk.download("punkt") + nltk.download("punkt_tab") dataset_size = len(dataset["sentence"]) output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index c6223ed5bb..005ae50e72 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -147,7 +147,7 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn): def data_preprocess(dataset, vocab, word_id, max_seq_len): """Convert tokens to numbers.""" - nltk.download("punkt") + nltk.download("punkt_tab") dataset_size = len(dataset["sentence"]) output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index c9620aa2be..286c064e96 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -250,7 +250,7 @@ def eval_model( def data_preprocess(dataset, vocab, word_id, max_seq_len): """Convert tokens to numbers.""" - nltk.download("punkt") + nltk.download("punkt_tab") dataset_size = len(dataset["sentence"]) output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index 674f7de815..363759afea 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -144,7 +144,7 @@ def eval_model(state, test_ds, batch_size, var_collect): def data_preprocess(dataset, vocab, word_id, max_seq_len): """Convert tokens to numbers.""" - nltk.download("punkt") + nltk.download("punkt_tab") dataset_size = len(dataset["sentence"]) output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index 81bbfa1065..db3aa31951 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -4,7 +4,7 @@ set -xe -pip install nltk==3.8.1 +pip install "nltk>=3.8.2" pip install pytest==8.2.1 : ${TE_PATH:=/opt/transformerengine} From ba0fe9a7e9fc5e9bbb375f79bb423bc1301a334a Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Wed, 14 Aug 2024 10:10:40 +0800 Subject: [PATCH 55/72] [JAX] Propagate sm_margin to the underly layernorm kernels (#1089) * Propagate sm_margin to the underly layernorm kernels --------- Signed-off-by: Reese Wang Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> --- .../jax/cpp_extensions/normalization.py | 32 +++++++++++++---- transformer_engine/jax/csrc/extensions.h | 4 +-- .../jax/csrc/extensions/normalization.cpp | 34 +++++++++++-------- 3 files changed, 47 insertions(+), 23 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index f1d3a7f28d..caf9272b02 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. """JAX/TE custom ops for normalization""" -from functools import partial, reduce +from functools import partial, reduce, cache import operator import os import warnings @@ -40,6 +40,18 @@ ] +@cache +def get_forward_sm_margin(): + """Retrieves the number of stream multiprocessors (SM) reserved for other kernels""" + return int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) + + +@cache +def get_backward_sm_margin(): + """Retrieves the number of stream multiprocessors (SM) reserved for other kernels""" + return int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) + + class LayerNormFwdPrimitive(BasePrimitive): """ Layer Normalization Forward Primitive @@ -77,6 +89,7 @@ def abstract(x_aval, gamma_aval, beta_aval, **kwargs): True, kwargs["zero_centered_gamma"], kwargs["epsilon"], + get_forward_sm_margin(), ) wkspace_aval = out_aval.update( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) @@ -136,7 +149,7 @@ def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon): operand_shapes = [x_shape, g_shape, b_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) + sm_margin = get_forward_sm_margin() opaque = transformer_engine_jax.pack_norm_descriptor( batch_size, @@ -354,6 +367,7 @@ def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, **kwargs): True, kwargs["zero_centered_gamma"], kwargs["epsilon"], + get_backward_sm_margin(), ) ) wkspace_aval = dx_aval.update( @@ -420,7 +434,7 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon): operand_shapes = [dz_shape, mu_shape, rsigma_shape, x_shape, g_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) + sm_margin = get_backward_sm_margin() wkspace_aval, barrier_aval, dgamma_part_aval, dbeta_part_aval = ctx.avals_out[-4:] opaque = transformer_engine_jax.pack_norm_descriptor( @@ -591,6 +605,7 @@ def abstract(x_aval, gamma_aval, **kwargs): False, False, kwargs["epsilon"], + get_forward_sm_margin(), ) wkspace_aval = out_aval.update( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) @@ -638,7 +653,7 @@ def lowering(ctx, x, gamma, *, epsilon): operand_shapes = [x_shape, g_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) + sm_margin = get_forward_sm_margin() opaque = transformer_engine_jax.pack_norm_descriptor( batch_size, @@ -776,6 +791,7 @@ def abstract(dz_aval, x_aval, rsigma_aval, gamma_aval, **kwargs): False, False, kwargs["epsilon"], + get_backward_sm_margin(), ) ) wkspace_aval = dx_aval.update( @@ -829,7 +845,7 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon): operand_shapes = [dz_shape, rsigma_shape, x_shape, g_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) + sm_margin = get_backward_sm_margin() opaque = transformer_engine_jax.pack_norm_descriptor( batch_size, @@ -989,6 +1005,7 @@ def abstract( True, zero_centered_gamma, epsilon, + get_forward_sm_margin(), ) out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) @@ -1076,7 +1093,7 @@ def lowering( ] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) + sm_margin = get_forward_sm_margin() opaque = transformer_engine_jax.pack_norm_descriptor( batch_size, @@ -1296,6 +1313,7 @@ def abstract(x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval, out_dtyp False, False, epsilon, + get_forward_sm_margin(), ) out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) @@ -1365,7 +1383,7 @@ def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon): operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) + sm_margin = get_forward_sm_margin() opaque = transformer_engine_jax.pack_norm_descriptor( batch_size, diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 433f4f770d..b872370715 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -186,7 +186,7 @@ void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *o pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype, bool is_layer_norm, bool zero_centered_gamma, - float eps); + float eps, int sm_margin); void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); @@ -196,7 +196,7 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, bool is_layer_norm, bool zero_centered_gamma, - float eps); + float eps, int sm_margin); void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp b/transformer_engine/jax/csrc/extensions/normalization.cpp index 14e046fa54..fb40400e62 100644 --- a/transformer_engine/jax/csrc/extensions/normalization.cpp +++ b/transformer_engine/jax/csrc/extensions/normalization.cpp @@ -13,7 +13,7 @@ namespace jax { pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype, bool is_layer_norm, bool zero_centered_gamma, - float eps) { + float eps, int sm_margin) { auto input_shape = std::vector{batch_size, hidden_size}; auto weight_shape = std::vector{hidden_size}; auto intermediates_shape = std::vector{batch_size}; @@ -26,7 +26,7 @@ pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidd // dummy tensor wrappers that will carry workspace size info later TensorWrapper dummy_work_tensor, dummy_barrier_tensor; - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); + auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin; auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; if (is_layer_norm) { auto beta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype); @@ -53,7 +53,7 @@ void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspac DType in_dtype, void *weight, DType w_dtype, void *bias, void *output, DType out_dtype, void *workspace, DType work_dtype, void *barrier, DType barrier_dtype, void *mu, void *rsigma, float *amax, float *scale, - float *scale_inv, cudaStream_t stream) { + float *scale_inv, int sm_margin, cudaStream_t stream) { auto input_shape = std::vector{batch_size, hidden_size}; auto weight_shape = std::vector{hidden_size}; auto intermediates_shape = std::vector{batch_size}; @@ -70,7 +70,7 @@ void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspac auto output_tensor = TensorWrapper(output, input_shape, out_dtype, amax, scale, scale_inv); auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, DType::kFloat32); - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); + auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin; auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; auto workspace_tensor = TensorWrapper(workspace, workspace_shape, work_dtype); @@ -94,7 +94,7 @@ void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspac pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, bool is_layer_norm, bool zero_centered_gamma, - float eps) { + float eps, int sm_margin) { auto input_shape = std::vector{batch_size, hidden_size}; auto weight_shape = std::vector{hidden_size}; auto intermediates_shape = std::vector{batch_size}; @@ -111,7 +111,7 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hid // dummy tensor wrappers that will carry workspace size info later TensorWrapper dummy_work_tensor, dummy_barrier_tensor; TensorWrapper dummy_dgamma_part_tensor, dummy_dbeta_part_tensor; - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); + auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin; auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd; // initialize dBeta information here -- layernorm will modify but RMSnorm will not @@ -151,7 +151,7 @@ void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace void *weight, DType w_dtype, void *ograd, void *workspace, DType wkspace_dtype, void *barrier, DType barrier_dtype, void *mu, void *rsigma, void *xgrad, void *wgrad, void *dbeta, void *dgamma_part, - DType dgamma_dtype, void *dbeta_part, DType dbeta_dtype, + DType dgamma_dtype, void *dbeta_part, DType dbeta_dtype, int sm_margin, cudaStream_t stream) { auto input_shape = std::vector{batch_size, hidden_size}; auto weight_shape = std::vector{hidden_size}; @@ -173,7 +173,7 @@ void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace auto xgrad_tensor = TensorWrapper(xgrad, input_shape, x_dtype); auto wgrad_tensor = TensorWrapper(wgrad, weight_shape, w_dtype); - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); + auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin; auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd; auto workspace_shape = std::vector{wkspace_size}; @@ -227,13 +227,14 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque auto barrier_dtype = desc.barrier_dtype; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; + auto sm_margin = desc.sm_margin; auto out_dtype = DType::kFloat8E4M3; LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, - stream); + sm_margin, stream); } void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -262,11 +263,12 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s auto eps = desc.eps; auto out_dtype = in_dtype; auto zero_centered_gamma = desc.zero_centered_gamma; + auto sm_margin = desc.sm_margin; LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, - stream); + sm_margin, stream); } void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -286,6 +288,7 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, auto dbeta_part_dtype = desc.dbeta_part_dtype; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; + auto sm_margin = desc.sm_margin; auto *ograd = buffers[0]; auto *mu = buffers[1]; @@ -304,7 +307,7 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu, rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part, - dbeta_part_dtype, stream); + dbeta_part_dtype, sm_margin, stream); } void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -334,12 +337,13 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, auto barrier_dtype = desc.barrier_dtype; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; + auto sm_margin = desc.sm_margin; auto out_dtype = DType::kFloat8E4M3; LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, - stream); + sm_margin, stream); } void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -367,12 +371,13 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz auto barrier_dtype = desc.barrier_dtype; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; + auto sm_margin = desc.sm_margin; auto out_dtype = in_dtype; LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, - stream); + sm_margin, stream); } void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -406,12 +411,13 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si auto dbeta_part_dtype = DType::kByte; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; + auto sm_margin = desc.sm_margin; LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape, dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu, rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part, - dbeta_part_dtype, stream); + dbeta_part_dtype, sm_margin, stream); } } // namespace jax From 67900e8dc17b49e93eaafa80db1bd96f3a307d6f Mon Sep 17 00:00:00 2001 From: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> Date: Wed, 14 Aug 2024 14:50:53 -0700 Subject: [PATCH 56/72] Remove Total Time Measurement (#1105) Remove total time measurement Signed-off-by: Phuong Nguyen --- setup.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/setup.py b/setup.py index d1dd35c027..e418cb95ff 100644 --- a/setup.py +++ b/setup.py @@ -41,9 +41,6 @@ install_and_import("pybind11[global]") from pybind11.setup_helpers import build_ext as BuildExtension -# Start timing -start_time = time.perf_counter() - CMakeBuildExtension = get_build_ext(BuildExtension) @@ -55,7 +52,7 @@ def run(self): start_time = time.perf_counter() super().run() total_time = time.perf_counter() - start_time - print(f"Time for bdist_wheel: {total_time:.2f} seconds") + print(f"Total time for bdist_wheel: {total_time:.2f} seconds") def setup_common_extension() -> CMakeExtension: @@ -171,8 +168,3 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: include_package_data=True, package_data={"": ["VERSION.txt"]}, ) - - # End timing - end_time = time.perf_counter() - total_time = end_time - start_time - print(f"Total build time: {total_time:.2f} seconds") From 8ef3308ad61f913ab9f745ee2d71b6b5ec302450 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> Date: Wed, 14 Aug 2024 14:58:29 -0700 Subject: [PATCH 57/72] [TE/JAX] Add default include path for XLA FFI (#1104) * add default path for ffi include * add an option to get XLA_HOME from env --------- Signed-off-by: Phuong Nguyen --- build_tools/jax.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/build_tools/jax.py b/build_tools/jax.py index aba9c749d0..f829230f50 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -2,7 +2,8 @@ # # See LICENSE for license information. -"""Paddle-paddle related extensions.""" +"""JAX related extensions.""" +import os from pathlib import Path import setuptools @@ -11,7 +12,24 @@ from .utils import cuda_path, all_files_in_dir from typing import List -from jax.extend import ffi + +def xla_path() -> str: + """XLA root path lookup. + Throws FileNotFoundError if XLA source is not found.""" + + try: + from jax.extend import ffi + except ImportError: + if os.getenv("XLA_HOME"): + xla_home = Path(os.getenv("XLA_HOME")) + else: + xla_home = "/opt/xla" + else: + xla_home = ffi.include_dir() + + if not os.path.isdir(xla_home): + raise FileNotFoundError("Could not find xla source.") + return xla_home def setup_jax_extension( @@ -29,14 +47,14 @@ def setup_jax_extension( # Header files cuda_home, _ = cuda_path() - jax_ffi_include = ffi.include_dir() + xla_home = xla_path() include_dirs = [ cuda_home / "include", common_header_files, common_header_files / "common", common_header_files / "common" / "include", csrc_header_files, - jax_ffi_include, + xla_home, ] # Compile flags From cc329b79e9bc167a6c91ebc82787de228ebfe9f4 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 14 Aug 2024 16:42:58 -0700 Subject: [PATCH 58/72] Bump minimum CUDA version to 12.0 (#1103) * Bump minimum CUDA version to 12.0 Signed-off-by: Tim Moon * Debug CUDA version check Signed-off-by: Tim Moon * Debug CMake build Signed-off-by: Tim Moon * Review suggestions from @ksivaman and @ptrendx Remove logic for CUDA <12.0 in PyTorch and Paddle builds. Update version in docs and README. Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/workflows/build.yml | 2 +- README.rst | 6 +-- build_tools/build_ext.py | 4 +- build_tools/paddle.py | 18 ++++--- build_tools/pytorch.py | 18 ++++--- docs/installation.rst | 4 +- transformer_engine/common/CMakeLists.txt | 63 +++++++++++++----------- 7 files changed, 67 insertions(+), 48 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 313aee6ab8..896d8f927e 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -12,7 +12,7 @@ jobs: name: 'Core' runs-on: ubuntu-latest container: - image: nvcr.io/nvidia/cuda:12.5.0-devel-ubuntu22.04 + image: nvcr.io/nvidia/cuda:12.0.0-devel-ubuntu22.04 options: --user root steps: - name: 'Dependencies' diff --git a/README.rst b/README.rst index 085c91ca49..25ed8af1de 100644 --- a/README.rst +++ b/README.rst @@ -149,8 +149,8 @@ Installation Pre-requisites ^^^^^^^^^^^^^^^^^^^^ * Linux x86_64 -* CUDA 11.8+ for Hopper and CUDA 12.1+ for Ada -* NVIDIA Driver supporting CUDA 11.8 or later +* CUDA 12.0+ for Hopper and CUDA 12.1+ for Ada +* NVIDIA Driver supporting CUDA 12.0 or later * cuDNN 8.1 or later * For fused attention, CUDA 12.1 or later, NVIDIA Driver supporting CUDA 12.1 or later, and cuDNN 8.9 or later. @@ -182,7 +182,7 @@ From source Compiling with FlashAttention-2 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Transformer Engine release v0.11.0 adds support for FlashAttention-2 in PyTorch for improved performance. +Transformer Engine release v0.11.0 adds support for FlashAttention-2 in PyTorch for improved performance. It is a known issue that FlashAttention-2 compilation is resource-intensive and requires a large amount of RAM (see `bug `_), which may lead to out of memory errors during the installation of Transformer Engine. Please try setting **MAX_JOBS=1** in the environment to circumvent the issue. diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py index d7351a8617..f71cef08ea 100644 --- a/build_tools/build_ext.py +++ b/build_tools/build_ext.py @@ -70,8 +70,8 @@ def _build_cmake(self, build_dir: Path, install_dir: Path) -> None: configure_command.append(f"-Dpybind11_DIR={pybind11_dir}") # CMake build and install commands - build_command = [_cmake_bin, "--build", build_dir] - install_command = [_cmake_bin, "--install", build_dir] + build_command = [_cmake_bin, "--build", build_dir, "--verbose"] + install_command = [_cmake_bin, "--install", build_dir, "--verbose"] # Check whether parallel build is restricted max_jobs = get_max_jobs_for_parallel_build() diff --git a/build_tools/paddle.py b/build_tools/paddle.py index f3140cf028..f410682875 100644 --- a/build_tools/paddle.py +++ b/build_tools/paddle.py @@ -62,12 +62,18 @@ def setup_paddle_extension( except FileNotFoundError: print("Could not determine CUDA Toolkit version") else: - if version >= (11, 2): - nvcc_flags.extend(["--threads", os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1")]) - if version >= (11, 0): - nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"]) - if version >= (11, 8): - nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"]) + if version < (12, 0): + raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer") + nvcc_flags.extend( + ( + "--threads", + os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1"), + "-gencode", + "arch=compute_80,code=sm_80", + "-gencode", + "arch=compute_90,code=sm_90", + ) + ) # Construct Paddle CUDA extension sources = [str(path) for path in sources] diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index 9b858653de..f932f0695e 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -67,12 +67,18 @@ def setup_pytorch_extension( except FileNotFoundError: print("Could not determine CUDA Toolkit version") else: - if version >= (11, 2): - nvcc_flags.extend(["--threads", os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1")]) - if version >= (11, 0): - nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"]) - if version >= (11, 8): - nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"]) + if version < (12, 0): + raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer") + nvcc_flags.extend( + ( + "--threads", + os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1"), + "-gencode", + "arch=compute_80,code=sm_80", + "-gencode", + "arch=compute_90,code=sm_90", + ) + ) # Libraries library_dirs = [] diff --git a/docs/installation.rst b/docs/installation.rst index 5dd10a79d1..012f3303cb 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -12,8 +12,8 @@ Prerequisites .. _driver link: https://www.nvidia.com/drivers 1. Linux x86_64 -2. `CUDA 11.8 `__ -3. |driver link|_ supporting CUDA 11.8 or later. +2. `CUDA 12.0 `__ +3. |driver link|_ supporting CUDA 12.0 or later. 4. `cuDNN 8.1 `__ or later. 5. For FP8/FP16/BF16 fused attention, `CUDA 12.1 `__ or later, |driver link|_ supporting CUDA 12.1 or later, and `cuDNN 8.9.1 `__ or later. diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 048e7fd61a..7fab75dca0 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -4,39 +4,27 @@ cmake_minimum_required(VERSION 3.21) +# Language options if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90) endif() - set(CMAKE_CXX_STANDARD 17) set(CMAKE_CUDA_STANDARD 17) set(CMAKE_CUDA_STANDARD_REQUIRED ON) - -project(transformer_engine LANGUAGES CUDA CXX) - -set(BUILD_THREADS_PER_JOB $ENV{NVTE_BUILD_THREADS_PER_JOB}) -if (NOT BUILD_THREADS_PER_JOB) - set(BUILD_THREADS_PER_JOB 1) -endif() -set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --threads ${BUILD_THREADS_PER_JOB}") - -if(DEFINED ENV{MAX_JOBS}) - set(JOBS $ENV{MAX_JOBS}) -elseif(DEFINED ENV{NVTE_BUILD_MAX_JOBS}) - set(JOBS $ENV{NVTE_BUILD_MAX_JOBS}) -else() - set(JOBS "max number of") -endif() - -message(STATUS "Parallel build with ${JOBS} jobs and ${BUILD_THREADS_PER_JOB} threads per job") - if (CMAKE_BUILD_TYPE STREQUAL "Debug") set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -G") endif() +# Transformer Engine library +project(transformer_engine LANGUAGES CUDA CXX) + +# CUDA Toolkit find_package(CUDAToolkit REQUIRED) +if (CUDAToolkit_VERSION VERSION_LESS 12.0) + message(FATAL_ERROR "CUDA 12.0+ is required, but found CUDA ${CUDAToolkit_VERSION}") +endif() -# Check for cuDNN frontend API +# cuDNN frontend API set(CUDNN_FRONTEND_INCLUDE_DIR "${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/include") if(NOT EXISTS "${CUDNN_FRONTEND_INCLUDE_DIR}") @@ -47,10 +35,11 @@ if(NOT EXISTS "${CUDNN_FRONTEND_INCLUDE_DIR}") endif() include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake) +# Python find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) -include_directories(${PROJECT_SOURCE_DIR}/..) # Configure Transformer Engine library +include_directories(${PROJECT_SOURCE_DIR}/..) set(transformer_engine_SOURCES) list(APPEND transformer_engine_SOURCES pycudnn.cpp @@ -89,8 +78,6 @@ add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") -target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING) - # Configure dependencies target_link_libraries(transformer_engine PUBLIC CUDA::cublas @@ -100,7 +87,10 @@ target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") -# Make header files with C++ strings +# Hack to enable dynamic loading in cuDNN frontend +target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING) + +# Helper functions to make header files with C++ strings function(make_string_header STRING STRING_NAME) configure_file(util/string_header.h.in "string_headers/${STRING_NAME}.h" @@ -112,10 +102,11 @@ function(make_string_header_from_file file_ STRING_NAME) "string_headers/${STRING_NAME}.h" @ONLY) endfunction() + +# Header files with C++ strings list(GET CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES 0 cuda_include_path) make_string_header("${cuda_include_path}" string_path_cuda_include) - make_string_header_from_file(transpose/rtc/cast_transpose_fusion.cu string_code_transpose_rtc_cast_transpose_fusion_cu) make_string_header_from_file(transpose/rtc/cast_transpose.cu @@ -126,7 +117,6 @@ make_string_header_from_file(utils.cuh string_code_utils_cuh) make_string_header_from_file(util/math.h string_code_util_math_h) - target_include_directories(transformer_engine PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/string_headers") @@ -139,6 +129,23 @@ set_source_files_properties(fused_softmax/scaled_masked_softmax.cu set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3") +# Number of parallel build jobs +if(ENV{MAX_JOBS}) + set(BUILD_JOBS_STR "$ENV{MAX_JOBS}") +elseif(ENV{NVTE_BUILD_MAX_JOBS}) + set(BUILD_JOBS_STR "$ENV{NVTE_BUILD_MAX_JOBS}") +else() + set(BUILD_JOBS_STR "max") +endif() +message(STATUS "Parallel build jobs: ${BUILD_JOBS_STR}") + +# Number of threads per parallel build job +set(BUILD_THREADS_PER_JOB $ENV{NVTE_BUILD_THREADS_PER_JOB}) +if (NOT BUILD_THREADS_PER_JOB) + set(BUILD_THREADS_PER_JOB 1) +endif() +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --threads ${BUILD_THREADS_PER_JOB}") +message(STATUS "Threads per parallel build job: ${BUILD_THREADS_PER_JOB}") + # Install library install(TARGETS transformer_engine DESTINATION .) - From a326e351a1fb9c4ff8ee970a407c1f4f35f663af Mon Sep 17 00:00:00 2001 From: Marks101 <46690260+Marks101@users.noreply.github.com> Date: Thu, 15 Aug 2024 02:40:35 +0200 Subject: [PATCH 59/72] [PyTorch] Fix issues with cross attention (#1069) Signed-off-by: Markus Schnoes Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 3 ++- transformer_engine/pytorch/transformer.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 3fc805bdc6..b2fb22c8fc 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -5778,7 +5778,8 @@ def forward( assert ( attention_mask is not None ), "Please provide attention_mask for padding!" - if max_seqlen_q == max_seqlen_kv: + if self.attention_type == "self": + assert max_seqlen_q == max_seqlen_kv cu_seqlens_q = get_cu_seqlens(attention_mask) cu_seqlens_kv = cu_seqlens_q else: diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 130cf91f0e..e40653d998 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -652,7 +652,7 @@ def forward( hidden_states, attention_mask=attention_mask, attn_mask_type=self_attn_mask_type, - window_size=enc_dec_window_size, + window_size=window_size, inference_params=inference_params, is_first_microbatch=is_first_microbatch, checkpoint_core_attention=checkpoint_core_attention, @@ -679,6 +679,8 @@ def forward( inter_attention_outputs = self.inter_attention( hidden_states, attention_mask=enc_dec_attn_mask, + attn_mask_type=enc_dec_attn_mask_type, + window_size=enc_dec_window_size, encoder_output=encoder_output, is_first_microbatch=is_first_microbatch, checkpoint_core_attention=checkpoint_core_attention, From 941364df3e6f6c00b4aace90024e9492eb09b511 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 15 Aug 2024 12:14:01 -0700 Subject: [PATCH 60/72] Fix docstring related to `t` in `thd` (#1111) fix typos regarding t in thd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/include/transformer_engine/fused_attn.h | 2 +- transformer_engine/pytorch/attention.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index fa358bc86c..ae08f2a4aa 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -22,7 +22,7 @@ extern "C" { /*! \enum NVTE_QKV_Layout * \brief Memory layouts of QKV tensors. * `S`, `B`, `H`, `D`, and `T` stand for sequence length, batch size, number of heads, - * head size, and the total number of sequences in a batch, i.e. `t = sum(s_i) for i = 0...b-1`. + * head size, and the total number of tokens in a batch, i.e. `t = sum(s_i) for i = 0...b-1`. * `SBHD` and `BSHD`-based layouts are used when sequences in a batch are of equal length * or padded to the same length, and `THD`-based layouts are used when sequences have * different lengths in a batch. diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index b2fb22c8fc..6a7c034b8d 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -3122,7 +3122,7 @@ def get_qkv_layout( qkv_format: str, default = `sbhd` Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for the sequence length dimension, `b` batch size, `h` the number of attention heads, - `d` head size, and `t` the total number of sequences in a batch, i.e. + `d` head size, and `t` the total number of tokens in a batch, i.e. `t = sum(s_i) for i = 0...b-1`. Returns @@ -5232,7 +5232,7 @@ class DotProductAttention(TransformerEngineBaseModule): qkv_format: str, default = `sbhd` dimension format for `query_layer`, `key_layer` and `value_layer`, {`sbhd`, `bshd`, `thd`}. `s` stands for the sequence length, `b` batch size, - `h` the number of heads, `d` head size, and `t` the total number of sequences + `h` the number of heads, `d` head size, and `t` the total number of tokens in a batch, with `t = sum(s_i), for i = 0...b-1`. `sbhd` and `bshd` formats are used for when sequences in a batch are of equal length or padded to equal length, and the `thd` format is used for when sequences in a batch From 304078568d44aad1ea9fd5b533cf710125a47c9a Mon Sep 17 00:00:00 2001 From: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com> Date: Fri, 16 Aug 2024 09:59:08 -0700 Subject: [PATCH 61/72] Add a CP implementation variant with KV all-gather. (#1060) * add window_size to AttnFuncWithCP Signed-off-by: Xiaowei Ren * add seq_offsets_qkvo for cudnn thd Signed-off-by: Xiaowei Ren * add seq_offsets_qkvo to AttnFuncWithCP Signed-off-by: Xiaowei Ren * fix seq_offsets calculation of cudnn thd Signed-off-by: Xiaowei Ren * remove a thd assert Signed-off-by: Xiaowei Ren * fix bias for thd test Signed-off-by: Xiaowei Ren * add thd test for cudnn FA with CP Signed-off-by: Xiaowei Ren * skip GQA/MQA test for cuDNN THD Signed-off-by: Xiaowei Ren * make sure seq_offsets are computed with qkv_group of hd_hd_hd while CP>1 Signed-off-by: Xiaowei Ren * fix seq_offsets inputs Signed-off-by: Xiaowei Ren * remove two comments Signed-off-by: Xiaowei Ren * fix attn mask type for cudnn thd with cp Signed-off-by: Xiaowei Ren * fix attn_mask_type check Signed-off-by: Xiaowei Ren * fix attn_mask_type for cudnn fa with thd Signed-off-by: Xiaowei Ren * fix a typo Signed-off-by: Xiaowei Ren * fix out dout in bwd Signed-off-by: Xiaowei Ren * assert cudnn+thd does not support attn bias Signed-off-by: Xiaowei Ren * check if attn_mask_type has padding Signed-off-by: Xiaowei Ren * minor change Signed-off-by: Xiaowei Ren * change cp test batch size to 2 Signed-off-by: Xiaowei Ren * fix code format Signed-off-by: Xiaowei Ren * fix two assert info Signed-off-by: Xiaowei Ren * fix assert comment Signed-off-by: Xiaowei Ren * fix assert comments Signed-off-by: Xiaowei Ren * minor fix Signed-off-by: Xiaowei Ren * fix assert comments Signed-off-by: Xiaowei Ren * assert swa+CP cannot work with thd format Signed-off-by: Xiaowei Ren * add a new CP function for swa Signed-off-by: Xiaowei Ren * add a missing dgrads Signed-off-by: Xiaowei Ren * minor change Signed-off-by: Xiaowei Ren * add draft fwd function for swa+cp Signed-off-by: Xiaowei Ren * minor change Signed-off-by: Xiaowei Ren * enable flash attention for swa+cp Signed-off-by: Xiaowei Ren * remove an assert of swa+cp Signed-off-by: Xiaowei Ren * call SWAFuncWithCP for swa+cp Signed-off-by: Xiaowei Ren * typo fix Signed-off-by: Xiaowei Ren * use 2hd layout Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change qkv_format check Signed-off-by: Xiaowei Ren * add a code comment Signed-off-by: Xiaowei Ren * tensor shape bug fix Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * tensor shape fix Signed-off-by: Xiaowei Ren * add function to compute cu_seqlens of a cp rank Signed-off-by: Xiaowei Ren * add cu_seqlens and cu_seqlens_padded to context parallelism Signed-off-by: Xiaowei Ren * typo fix Signed-off-by: Xiaowei Ren * minor change Signed-off-by: Xiaowei Ren * fix FlashAttention output sequence length Signed-off-by: Xiaowei Ren * fix cu_seqlens_kv_per_step calculation Signed-off-by: Xiaowei Ren * zero dQKV for ending padded tokens Signed-off-by: Xiaowei Ren * zero dQKV tensors of FlashAttention Signed-off-by: Xiaowei Ren * fix softmax_lse correction Signed-off-by: Xiaowei Ren * remove padded tokens of KV to save comounication Signed-off-by: Xiaowei Ren * do not need to zero dkv for FlashAttention any mroe Signed-off-by: Xiaowei Ren * zero out tensors Signed-off-by: Xiaowei Ren * remove redundant code Signed-off-by: Xiaowei Ren * fix CP unit test Signed-off-by: Xiaowei Ren * fix kv shape of cp test with thd format Signed-off-by: Xiaowei Ren * update cp unit test Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add simple code framework Signed-off-by: Xiaowei Ren * try not to have a separate CP function for SWA Signed-off-by: Xiaowei Ren * backup some code change Signed-off-by: Xiaowei Ren * back up code Signed-off-by: Xiaowei Ren * clean up fwd implementation of SWAFuncWithCP Signed-off-by: Xiaowei Ren * remove redundant code Signed-off-by: Xiaowei Ren * code cleaning Signed-off-by: Xiaowei Ren * fix assert info Signed-off-by: Xiaowei Ren * reduce kv chunk concat overheads Signed-off-by: Xiaowei Ren * minor change Signed-off-by: Xiaowei Ren * make AttnFuncWithCP and SWAFuncWithCP have same API Signed-off-by: Xiaowei Ren * add a docstring Signed-off-by: Xiaowei Ren * preliminary implementation of SWAFuncWithCP forward seems working Signed-off-by: Xiaowei Ren * fix output shape of SWAFuncWithCP Signed-off-by: Xiaowei Ren * code refactoring for FlashAttention and add a code placeholder for bwd Signed-off-by: Xiaowei Ren * use gather_along_first_dim Signed-off-by: Xiaowei Ren * finish the preliminary implementation of bwd Signed-off-by: Xiaowei Ren * remove redundant code Signed-off-by: Xiaowei Ren * fix assert condition Signed-off-by: Xiaowei Ren * add draft implementation of SWA+CP with FusedAttention Signed-off-by: Xiaowei Ren * fix attention mask type of swa+cp Signed-off-by: Xiaowei Ren * code cleaning Signed-off-by: Xiaowei Ren * add qkv_layout Signed-off-by: Xiaowei Ren * add missing window_size argument Signed-off-by: Xiaowei Ren * typo fix Signed-off-by: Xiaowei Ren * fix kv shape of swa+cp Signed-off-by: Xiaowei Ren * bug and typo fix Signed-off-by: Xiaowei Ren * fix dout shape Signed-off-by: Xiaowei Ren * add multi stream in fwd of swa+cp Signed-off-by: Xiaowei Ren * save chunk_ids_to_kv_ag in fwd Signed-off-by: Xiaowei Ren * add multi stream in bwd of swa+cp Signed-off-by: Xiaowei Ren * minor fix to cp stream sync Signed-off-by: Xiaowei Ren * rename AttnFuncWithCP Signed-off-by: Xiaowei Ren * check if window size is None Signed-off-by: Xiaowei Ren * fix docstring of AttnFuncWithCP Signed-off-by: Xiaowei Ren * minor fix Signed-off-by: Xiaowei Ren * add env var for users to choose KV ag or KV p2p Signed-off-by: Xiaowei Ren * update cp tests Signed-off-by: Xiaowei Ren * fix window size in cp unit test Signed-off-by: Xiaowei Ren * fix pytest skip messages Signed-off-by: Xiaowei Ren * add cp_comm_type into API Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * code cleaning Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * assert sequence length divisible requirements Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add support table of context parallelism Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * typo and code format fix Signed-off-by: Xiaowei Ren * do not print multiple disabling messages Signed-off-by: Xiaowei Ren * bug fix Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix device in torch.arange and adjust code for the PR of MLA Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix typos and clean asserts Signed-off-by: Xiaowei Ren --------- Signed-off-by: Xiaowei Ren Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xiaowei Ren --- .../fused_attn/run_fused_attn_with_cp.py | 68 +- .../fused_attn/test_fused_attn_with_cp.py | 63 +- transformer_engine/pytorch/attention.py | 686 ++++++++++++++++-- transformer_engine/pytorch/transformer.py | 6 +- 4 files changed, 736 insertions(+), 87 deletions(-) diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py index c8f3c8c458..2433a8a09d 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py @@ -13,7 +13,9 @@ dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16} -def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend="FlashAttention"): +def run_dpa_with_cp( + dtype="bf16", model=None, qkv_format="bshd", kernel_backend="FlashAttention", cp_comm_type="p2p" +): """Test DotProductAttention module with context parallelism""" os.environ["NVTE_FLASH_ATTN"] = "0" @@ -24,10 +26,16 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend= if kernel_backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" config = model_configs_fused_attn[model] - if qkv_format == "thd" and ( - config.num_heads != config.num_gqa_groups or config.attn_bias_type == "post_scale_bias" - ): - return + + assert config.attn_mask_type in [ + "causal", + "no_mask", + ], f"{config.attn_mask_type} is an unsupported attention mask type!" + if kernel_backend == "FusedAttention" and qkv_format == "thd": + if "causal" in config.attn_mask_type: + config.attn_mask_type = "padding_causal" + else: + config.attn_mask_type = "padding" rank = int(os.getenv("RANK", "0")) world_size = int(os.getenv("WORLD_SIZE", "1")) @@ -49,73 +57,77 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend= assert rank in cp_comm_ranks cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") - assert config.attn_mask_type in [ - "causal", - "no_mask", - ], f"{config.attn_mask_type} is an unsupported attention mask type!" - - if kernel_backend == "FusedAttention" and qkv_format == "thd": - if "causal" in config.attn_mask_type: - config.attn_mask_type = "padding_causal" - else: - config.attn_mask_type = "padding" - # instantiate core attn module core_attn = DotProductAttention( config.num_heads, - config.head_dim, + config.head_dim_qk, num_gqa_groups=config.num_gqa_groups, attention_dropout=config.dropout_p, qkv_format=qkv_format, attn_mask_type=config.attn_mask_type, + window_size=config.window_size, ) core_attn = core_attn.cuda() # create flash attn inputs if qkv_format == "bshd": - q_input_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim) + q_input_shape = ( + config.batch_size, + config.max_seqlen_q, + config.num_heads, + config.head_dim_qk, + ) kv_input_shape = ( config.batch_size, config.max_seqlen_kv, config.num_gqa_groups, - config.head_dim, + config.head_dim_qk, ) attn_output_shape = ( config.batch_size, config.max_seqlen_q, - config.num_heads * config.head_dim, + config.num_heads * config.head_dim_qk, ) cu_seqlens_q = None cu_seqlens_kv = None cu_seqlens_q_padded = None cu_seqlens_kv_padded = None elif qkv_format == "sbhd": - q_input_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim) + q_input_shape = ( + config.max_seqlen_q, + config.batch_size, + config.num_heads, + config.head_dim_qk, + ) kv_input_shape = ( config.max_seqlen_kv, config.batch_size, config.num_gqa_groups, - config.head_dim, + config.head_dim_qk, ) attn_output_shape = ( config.max_seqlen_q, config.batch_size, - config.num_heads * config.head_dim, + config.num_heads * config.head_dim_qk, ) cu_seqlens_q = None cu_seqlens_kv = None cu_seqlens_q_padded = None cu_seqlens_kv_padded = None elif qkv_format == "thd": - q_input_shape = (config.batch_size * config.max_seqlen_q, config.num_heads, config.head_dim) + q_input_shape = ( + config.batch_size * config.max_seqlen_q, + config.num_heads, + config.head_dim_qk, + ) kv_input_shape = ( config.batch_size * config.max_seqlen_q, config.num_gqa_groups, - config.head_dim, + config.head_dim_qk, ) attn_output_shape = ( config.batch_size * config.max_seqlen_q, - config.num_heads * config.head_dim, + config.num_heads * config.head_dim_qk, ) seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32) seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2) @@ -211,7 +223,9 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend= ) bias_ = bias_.index_select(2, seq_idx) bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1]) - core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks, torch.cuda.Stream()) + core_attn.set_context_parallel_group( + cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), cp_comm_type + ) out_ = core_attn( q_, k_, diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index 31a653b505..0074d18cec 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -16,11 +16,17 @@ ) model_configs_flash_attn = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: b, h, hg, d, sq, skv, p, mask, bias "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA + "cp_1_2": ModelConfig( + 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) + ), # MHA "cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA "cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA + "cp_2_2": ModelConfig( + 2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) + ), # GQA } @@ -39,7 +45,28 @@ def get_bash_arguments(**kwargs): @pytest.mark.parametrize("dtype", ["bf16", "fp16"]) @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) -def test_cp_with_flash_attention(dtype, model, qkv_format): +@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"]) +def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): + config = model_configs_flash_attn[model] + if cp_comm_type == "all_gather" and qkv_format == "thd": + pytest.skip( + f"CP implementation with KV all-gather does not support {qkv_format} format yet!" + ) + if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type: + pytest.skip( + f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask" + " type yet!" + ) + if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": + pytest.skip( + f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias" + " type yet!" + ) + if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1): + pytest.skip( + f"CP implementation with KV P2P does not support window size {config.window_size} yet!" + ) + subprocess.run( get_bash_arguments( dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FlashAttention" @@ -49,7 +76,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format): model_configs_fused_attn = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: b, h, hg, d, sq, skv, p, mask, bias "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA "cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA @@ -66,9 +93,37 @@ def test_cp_with_flash_attention(dtype, model, qkv_format): @pytest.mark.parametrize("dtype", ["bf16", "fp16"]) @pytest.mark.parametrize("model", model_configs_fused_attn.keys()) @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) -def test_cp_with_fused_attention(dtype, model, qkv_format): +@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"]) +def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type): if qkv_format == "thd" and get_device_compute_capability() < (9, 0): pytest.skip("THD format is only supported on sm90+.") + if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0): + pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0") + + config = model_configs_fused_attn[model] + if qkv_format == "thd" and config.num_heads != config.num_gqa_groups: + pytest.skip(f"{qkv_format} format does not support QGA/MQA yet!") + if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias": + pytest.skip(f"{qkv_format} format does not support {config.attn_bias_type} bias type yet!") + if cp_comm_type == "all_gather" and qkv_format == "thd": + pytest.skip( + f"CP implementation with KV all-gather does not support {qkv_format} format yet!" + ) + if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type: + pytest.skip( + f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask" + " type yet!" + ) + if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": + pytest.skip( + f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias" + " type yet!" + ) + if config.window_size != (-1, 0) and config.window_size != (-1, -1): + pytest.skip( + f"Fused attention does not support sliding window attention + context parallelism yet!" + ) + subprocess.run( get_bash_arguments( dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FusedAttention" diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 6a7c034b8d..904dbbde01 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -65,6 +65,8 @@ set_all_rng_states, CudaRNGStatesTracker, graph_safe_rng_available, + gather_along_first_dim, + reduce_scatter_along_first_dim, ) from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo @@ -321,13 +323,6 @@ def get_attention_backend( logger.debug("Disabling FusedAttention as it requires compute capability sm80+") use_fused_attention = False - # Filter: Context parallelism - if context_parallel and use_unfused_attention: - logger.debug( - "Disabling UnfusedDotProductAttention as it does not support context parallelism" - ) - use_unfused_attention = False - # Filter: Data type if use_flash_attention and ( qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type == Float8Tensor @@ -398,6 +393,81 @@ def get_attention_backend( ) use_flash_attention = False + # Filter: Context parallelism + # qkv_format | attn_mask_type | attn_bias_type | supported backends + # ---------------------------------------------------------------------------------------------------- + # bshd, sbhd | self-attention: | no_bias, post_scale_bias | FlashAttention, FusedAttention + # | no_mask, causal | | + # | cross-attention: | | + # | no_mask | | + # thd | self-attention: | no_bias | FlashAttention, FusedAttention + # | padding, padding_causal | | if no padding between sequences, + # | cross-attention: | | FusedAttention + # | padding | | if there is padding between sequences + # Note: context parallelism requires seq_len % (cp_size * 2) == 0 for each sequence in q, k, v. + if context_parallel and use_unfused_attention: + logger.debug( + "Disabling UnfusedDotProductAttention as it does not support context parallelism" + ) + use_unfused_attention = False + if context_parallel and use_flash_attention: + if "bottom_right" in attn_mask_type: + logger.debug( + "Disabling FlashAttention as it does not support context parallelism with" + " causal_bottom_right masking" + ) + use_flash_attention = False + elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv: + logger.debug( + "Disabling FlashAttention as it does not support context parallelism with causal" + " masking for cross-attention" + ) + use_flash_attention = False + elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]: + logger.debug( + "Disabling FlashAttention as it does not support context parallelism with bias type" + " of %s", + core_attention_bias_type, + ) + use_flash_attention = False + elif qkv_format == "thd" and core_attention_bias_type != "no_bias": + logger.debug( + "Disabling FlashAttention as it does not support context parallelism with attention" + " bias for THD format" + ) + use_flash_attention = False + if context_parallel and use_fused_attention: + if "bottom_right" in attn_mask_type: + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with" + " causal_bottom_right masking" + ) + use_fused_attention = False + elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv: + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with causal" + " masking for cross-attention" + ) + use_fused_attention = False + elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]: + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with bias type" + " of %s", + core_attention_bias_type, + ) + use_fused_attention = False + elif qkv_format == "thd" and core_attention_bias_type != "no_bias": + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with attention" + " bias for THD format" + ) + use_fused_attention = False + elif head_dim_qk != head_dim_v: + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with MLA" + ) + use_fused_attention = False + # Filter: Attention mask # attn_mask_type | supported backends # ------------------------------------------------------------------- @@ -498,11 +568,10 @@ def get_attention_backend( if ( use_flash_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]) - and (not _flash_attn_2_3_plus or context_parallel) + and not _flash_attn_2_3_plus ): logger.debug( - "Disabling FlashAttention as sliding window attention requires " - "flash-attn 2.3+ and no context parallelism" + "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+" ) use_flash_attention = False @@ -1222,11 +1291,11 @@ def get_cu_seqlens_on_cp_rank( return cu_seqlens_on_cp_rank -class AttnFuncWithCP(torch.autograd.Function): +class AttnFuncWithCPAndKVP2P(torch.autograd.Function): """ - Attention implementation with context parallelism. - Split attention compute into multiple steps, and overlap current-step - compute with next-step communication. + Attention implementation with context parallelism. Exchange KV between CP ranks + with P2P in ring topology. Split attention compute into multiple steps, and overlap + current-step compute with next-step communication. """ @staticmethod @@ -1267,6 +1336,7 @@ def forward( padding = "padding" in attn_mask_type if qkv_format in ["bshd", "sbhd"]: + seq_dim = qkv_format.index("s") qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:] else: qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format @@ -1280,6 +1350,9 @@ def forward( cu_seqlens_q_per_step = [None for _ in range(cp_size)] cu_seqlens_kv_per_step = [None for _ in range(cp_size)] + assert qkv_format == "thd" or ( + q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 + ), "Sequence length per GPU needs to be divisible by 2!" if causal: if qkv_format == "bshd": # [b, s, np, hn] -> [b, 2, s//2, np, hn] @@ -1295,6 +1368,9 @@ def forward( "Only support bias shape of [b, h, sq, sk] for forward, " "and [1, h, sq, sk] for backward!" ) + assert ( + attn_bias.shape[-2] % 2 == 0 and attn_bias.shape[-1] % (2 * cp_size) == 0 + ), "Sequence length does not meet divisible requirements!" # [b, np, sq, sk] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)] attn_bias_ = attn_bias.view( *attn_bias.shape[:-2], @@ -1310,7 +1386,7 @@ def forward( assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8" fa_optional_forward_kwargs = {} if _flash_attn_2_3_plus: - fa_optional_forward_kwargs["window_size"] = [-1, 0] if causal else [-1, -1] + fa_optional_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) if _flash_attn_2_4_plus: fa_optional_forward_kwargs["alibi_slopes"] = None if _flash_attn_2_5_7_plus: @@ -1546,7 +1622,7 @@ def forward( # [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn] kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) if _flash_attn_2_3_plus: - fa_optional_forward_kwargs["window_size"] = [-1, -1] + fa_optional_forward_kwargs["window_size"] = (-1, -1) ( _, _, @@ -1667,7 +1743,7 @@ def forward( # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) if _flash_attn_2_3_plus: - fa_optional_forward_kwargs["window_size"] = [-1, -1] + fa_optional_forward_kwargs["window_size"] = (-1, -1) ( _, _, @@ -1821,8 +1897,6 @@ def forward( torch.cuda.current_stream().wait_stream(flash_attn_streams[1]) softmax_lse = softmax_lse.to(torch.float) - if qkv_format in ["bshd", "sbhd"]: - seq_dim = qkv_format.index("s") for i in range(cp_size): if qkv_format == "bshd": out_per_step[i] = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:]) @@ -1849,8 +1923,6 @@ def forward( cu_seqlens_q_padded, False, ) - else: - assert False, f"{qkv_format} is an unsupported qkv_format!" else: if qkv_format in ["bshd", "sbhd"]: flash_attn_fwd_out_correction( @@ -1869,8 +1941,6 @@ def forward( cu_seqlens_q_padded, True, ) - else: - assert False, f"{qkv_format} is an unsupported qkv_format!" kv = p2p_comm_buffers[-1] if use_fused_attention: @@ -2056,7 +2126,7 @@ def backward(ctx, dout): out_ = out.view(-1, *out.shape[-2:]) dout_ = dout.view(-1, *dout.shape[-2:]) if _flash_attn_2_3_plus: - fa_optional_backward_kwargs["window_size"] = [-1, 0] + fa_optional_backward_kwargs["window_size"] = (-1, 0) _flash_attn_backward( dout_, q_, @@ -2141,7 +2211,7 @@ def backward(ctx, dout): out_ = out.view(-1, *out.shape[-2:]) dout_ = dout.view(-1, *dout.shape[-2:]) if _flash_attn_2_3_plus: - fa_optional_backward_kwargs["window_size"] = [-1, -1] + fa_optional_backward_kwargs["window_size"] = (-1, -1) _flash_attn_backward( dout_, q_, @@ -2232,7 +2302,7 @@ def backward(ctx, dout): out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:]) dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:]) if _flash_attn_2_3_plus: - fa_optional_backward_kwargs["window_size"] = [-1, -1] + fa_optional_backward_kwargs["window_size"] = (-1, -1) _flash_attn_backward( dout_, q_, @@ -2291,7 +2361,7 @@ def backward(ctx, dout): out_ = out.view(-1, *out.shape[-2:]) dout_ = dout.view(-1, *dout.shape[-2:]) if _flash_attn_2_3_plus: - fa_optional_backward_kwargs["window_size"] = [-1, -1] + fa_optional_backward_kwargs["window_size"] = (-1, -1) _flash_attn_backward( dout_, q_, @@ -2486,6 +2556,455 @@ def backward(ctx, dout): ) +@jit_fuser +def get_seq_chunk_ids_to_all_gathered_kv( + local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size_left +): + """Compute sequence chunk ids to the all-gathered KV.""" + seq_end_idx = (local_chunk_id + 1) * max_seqlen_kv + seq_start_idx = max(0, seq_end_idx - max_seqlen_q - window_size_left) + seqlen = seq_end_idx - seq_start_idx + num_chunks = (seqlen + max_seqlen_kv - 1) // max_seqlen_kv + chunk_ids = torch.arange( + local_chunk_id - num_chunks + 1, + local_chunk_id + 1, + dtype=torch.int32, + device="cuda", + ) + chunk_ids_to_all_gathered_kv = torch.where( + chunk_ids < cp_size, 2 * chunk_ids, 2 * (2 * cp_size - chunk_ids) - 1 + ) + return chunk_ids_to_all_gathered_kv + + +class AttnFuncWithCPAndKVAllGather(torch.autograd.Function): + """ + Attention implementation with context parallelism. + KV all-gather between CP ranks is exposed. + """ + + @staticmethod + def forward( + ctx, + is_training, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + dropout_p, + cp_group, + cp_stream, + softmax_scale, + qkv_format, + attn_mask_type, + attn_bias_type, + attn_bias, + deterministic, + use_fused_attention, + window_size, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + cp_size = get_distributed_world_size(cp_group) + rank = get_distributed_rank(cp_group) + + causal = "causal" in attn_mask_type + padding = "padding" in attn_mask_type + assert causal and not padding, f"{attn_mask_type} mask type is not supported!" + if use_fused_attention and causal and "bottom_right" not in attn_mask_type: + attn_mask_type = attn_mask_type + "_bottom_right" + + assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!" + assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!" + assert ( + use_fused_attention or _flash_attn_2_3_plus + ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!" + fa_optional_forward_kwargs = {} + if _flash_attn_2_4_plus: + fa_optional_forward_kwargs["alibi_slopes"] = None + + assert qkv_format != "thd", f"{qkv_format} format is not supported!" + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + + seq_dim = qkv_format.index("s") + assert ( + q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 + ), "Sequence length per GPU needs to be divisible by 2!" + + max_seqlen_q = max_seqlen_q // (2 * cp_size) + max_seqlen_kv = max_seqlen_kv // (2 * cp_size) + cu_seqlens_q = cu_seqlens_q // (2 * cp_size) + cu_seqlens_kv = cu_seqlens_kv // (2 * cp_size) + cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size) + cu_seqlens_kv_padded = cu_seqlens_kv_padded // (2 * cp_size) + + if causal: + if qkv_format == "bshd": + # [b, s, np, hn] -> [b, 2, s//2, np, hn] + q = q.view(q.shape[0], 2, q.shape[1] // 2, *q.shape[2:]) + # [b, s, np, hn] -> [s, b, np, hn] + k, v = [x.transpose(0, 1).contiguous() for x in [k, v]] + elif qkv_format == "sbhd": + # [s, b, np, hn] -> [2, s//2, b, np, hn] + q = q.view(2, q.shape[0] // 2, *q.shape[1:]) + + # create two streams to resolve wave quantization issue of Flash Attn in each step + flash_attn_streams = [torch.cuda.current_stream(), cp_stream] + + k_ag, _ = gather_along_first_dim(k, cp_group) + v_ag, _ = gather_along_first_dim(v, cp_group) + cp_stream.wait_stream(torch.cuda.current_stream()) + k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) + v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) + + local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] + chunk_ids_to_kv_ag_per_step = [None, None] + out_per_step = [None, None] + softmax_lse_per_step = [None, None] + rng_states = [None, None] + out = torch.empty_like(q) + + for i in range(len(local_seq_chunk_ids) + 1): + if i < len(local_seq_chunk_ids): + with torch.cuda.stream(flash_attn_streams[i]): + chunk_ids_to_kv_ag = get_seq_chunk_ids_to_all_gathered_kv( + local_seq_chunk_ids[i], + cp_size, + max_seqlen_q, + max_seqlen_kv, + ( + max_seqlen_kv * cp_size * 2 + if (window_size is None or window_size[0] == -1) + else window_size[0] + ), + ) + chunk_ids_to_kv_ag_per_step[i] = chunk_ids_to_kv_ag + num_kv_chunks = chunk_ids_to_kv_ag.numel() + if qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] + q_ = q[:, i].contiguous() + # [num_kv_chunks, sq//2, b, np, hn] -> [b, num_kv_chunks*sq//2, np, hn] + k_ = ( + torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag) + .movedim(2, 0) + .contiguous() + .view(k.shape[1], -1, *k.shape[-2:]) + ) + v_ = ( + torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag) + .movedim(2, 0) + .contiguous() + .view(v.shape[1], -1, *v.shape[-2:]) + ) + elif qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + q_ = q[i].contiguous() + # [num_kv_chunks, sq//2, b, np, hn] -> [num_kv_chunks*sq//2, b, np, hn] + k_ = torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag).view( + -1, *k.shape[-3:] + ) + v_ = torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag).view( + -1, *v.shape[-3:] + ) + if use_fused_attention: + out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_kv * num_kv_chunks, + cu_seqlens_q, + cu_seqlens_kv * num_kv_chunks, + q_, + k_, + v_, + TE_DType[q.dtype], + tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=attn_mask_type, + attn_bias_type=attn_bias_type, + attn_bias=attn_bias, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded * num_kv_chunks, + window_size=window_size, + ) + else: + q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]] + _, _, _, _, out_per_step[i], softmax_lse_per_step[i], _, rng_states[i] = ( + _flash_attn_forward( + q_, + k_, + v_, + cu_seqlens_q, + cu_seqlens_kv * num_kv_chunks, + max_seqlen_q, + max_seqlen_kv * num_kv_chunks, + dropout_p, + softmax_scale, + causal=True, + return_softmax=False, + window_size=window_size, + **fa_optional_forward_kwargs, + ) + ) + + if i > 0: + with torch.cuda.stream(flash_attn_streams[i - 1]): + if qkv_format == "bshd": + out[:, i - 1].copy_(out_per_step[i - 1].view_as(out[:, i - 1])) + elif qkv_format == "sbhd": + out[i - 1].copy_(out_per_step[i - 1].view_as(out[i - 1])) + + torch.cuda.current_stream().wait_stream(cp_stream) + + if use_fused_attention: + if qkv_format == "bshd": + out = out.view(out.shape[0], -1, *out.shape[-2:]) + elif qkv_format == "sbhd": + out = out.view(-1, *out.shape[-3:]) + else: + out = out.view(-1, *out.shape[-2:]) + + ctx.save_for_backward( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + *chunk_ids_to_kv_ag_per_step, + *out_per_step, + *softmax_lse_per_step, + *rng_states, + ) + ctx.cp_group = cp_group + ctx.cp_stream = cp_stream + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_kv = max_seqlen_kv + ctx.softmax_scale = softmax_scale + ctx.qkv_format = qkv_format + ctx.attn_mask_type = attn_mask_type + ctx.attn_bias_type = attn_bias_type + ctx.deterministic = deterministic + ctx.use_fused_attention = use_fused_attention + ctx.window_size = window_size + return out + + @staticmethod + def backward(ctx, dout): + cp_size = get_distributed_world_size(ctx.cp_group) + rank = get_distributed_rank(ctx.cp_group) + + (q, k, v, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ( + ctx.saved_tensors[:7] + ) + chunk_ids_to_kv_ag_per_step = ctx.saved_tensors[7:9] + out_per_step = ctx.saved_tensors[9:11] + softmax_lse_per_step = ctx.saved_tensors[11:13] + rng_states = ctx.saved_tensors[13:15] + + qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format + + dout = dout.view_as(q) + dq = torch.empty_like(q) + dk = torch.zeros( + (2 * cp_size, k.shape[0] // 2, *k.shape[1:]), dtype=k.dtype, device=k.device + ) + dv = torch.zeros_like(dk) + dq_per_step = [None, None] + dk_per_step = [None, None] + dv_per_step = [None, None] + + # create two streams to resolve wave quantization issue of Flash Attn in each step + flash_attn_streams = [torch.cuda.current_stream(), ctx.cp_stream] + # synchronize dkv update across steps + dkv_update_done = torch.cuda.Event() + + k_ag, _ = gather_along_first_dim(k, ctx.cp_group) + v_ag, _ = gather_along_first_dim(v, ctx.cp_group) + ctx.cp_stream.wait_stream(torch.cuda.current_stream()) + k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) + v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) + + local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] + + fa_optional_backward_kwargs = {} + if _flash_attn_2_4_plus: + fa_optional_backward_kwargs["alibi_slopes"] = None + if _flash_attn_2_4_1_plus: + fa_optional_backward_kwargs["deterministic"] = ctx.deterministic + + for i in range(len(local_seq_chunk_ids) + 1): + if i < len(local_seq_chunk_ids): + with torch.cuda.stream(flash_attn_streams[i]): + chunk_ids_to_kv_ag = chunk_ids_to_kv_ag_per_step[i] + num_kv_chunks = chunk_ids_to_kv_ag.numel() + out_ = out_per_step[i] + if ctx.qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] + q_ = q[:, i].contiguous() + # [num_kv_chunks, sq//2, b, np, hn] -> [b, num_kv_chunks*sq//2, np, hn] + k_ = ( + torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag) + .movedim(2, 0) + .contiguous() + .view(k.shape[1], -1, *k.shape[-2:]) + ) + v_ = ( + torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag) + .movedim(2, 0) + .contiguous() + .view(v.shape[1], -1, *v.shape[-2:]) + ) + dout_ = dout[:, i].contiguous().view_as(out_) + elif ctx.qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + q_ = q[i].contiguous() + # [num_kv_chunks, sq//2, b, np, hn] -> [num_kv_chunks*sq//2, b, np, hn] + k_ = torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag).view( + -1, *k.shape[-3:] + ) + v_ = torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag).view( + -1, *v.shape[-3:] + ) + dout_ = dout[i].contiguous().view_as(out_) + if ctx.use_fused_attention: + dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ + torch.empty_like(x) for x in [q_, k_, v_] + ] + aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]] + dq_per_step[i], dk_per_step[i], dv_per_step[i], _ = fused_attn_bwd( + ctx.max_seqlen_q, + ctx.max_seqlen_kv * num_kv_chunks, + cu_seqlens_q, + cu_seqlens_kv * num_kv_chunks, + q_, + k_, + v_, + out_, + dout_, + TE_DType[q.dtype], + TE_DType[k.dtype], + aux_ctx_tensors, + tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded * num_kv_chunks, + attn_scale=ctx.softmax_scale, + dropout=ctx.dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=ctx.attn_mask_type, + attn_bias_type=ctx.attn_bias_type, + window_size=ctx.window_size, + ) + else: + q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]] + dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ + torch.empty_like(x) for x in [q_, k_, v_] + ] + _flash_attn_backward( + dout_, + q_, + k_, + v_, + out_, + softmax_lse_per_step[i], + dq_per_step[i], + dk_per_step[i], + dv_per_step[i], + cu_seqlens_q, + cu_seqlens_kv * num_kv_chunks, + ctx.max_seqlen_q, + ctx.max_seqlen_kv * num_kv_chunks, + ctx.dropout_p, + ctx.softmax_scale, + True, + window_size=ctx.window_size, + rng_state=rng_states[i], + **fa_optional_backward_kwargs, + ) + + if i > 0: + with torch.cuda.stream(flash_attn_streams[i - 1]): + chunk_ids_to_kv_ag = chunk_ids_to_kv_ag_per_step[i - 1] + num_kv_chunks = chunk_ids_to_kv_ag.numel() + if ctx.qkv_format == "bshd": + dq[:, i - 1].copy_(dq_per_step[i - 1].view_as(dq[:, i - 1])) + dk_per_step[i - 1] = ( + dk_per_step[i - 1] + .view(k.shape[1], num_kv_chunks, -1, *k.shape[-2:]) + .movedim(0, 2) + .contiguous() + ) + dv_per_step[i - 1] = ( + dv_per_step[i - 1] + .view(v.shape[1], num_kv_chunks, -1, *v.shape[-2:]) + .movedim(0, 2) + .contiguous() + ) + elif ctx.qkv_format == "sbhd": + dq[i - 1].copy_(dq_per_step[i - 1].view_as(dq[i - 1])) + dk_per_step[i - 1] = dk_per_step[i - 1].view( + num_kv_chunks, -1, *k.shape[-3:] + ) + dv_per_step[i - 1] = dv_per_step[i - 1].view( + num_kv_chunks, -1, *v.shape[-3:] + ) + + # wait until dkv update of last step is done + if i > 1: + flash_attn_streams[i - 1].wait_event(dkv_update_done) + dk.index_add_(0, chunk_ids_to_kv_ag, dk_per_step[i - 1]) + dv.index_add_(0, chunk_ids_to_kv_ag, dv_per_step[i - 1]) + if i < len(local_seq_chunk_ids): + flash_attn_streams[i - 1].record_event(dkv_update_done) + + torch.cuda.current_stream().wait_stream(ctx.cp_stream) + + dk = dk.view(-1, *dk.shape[-3:]) + dv = dv.view(-1, *dv.shape[-3:]) + dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group) + dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group) + + if ctx.qkv_format == "bshd": + dq = dq.view(dq.shape[0], -1, *dq.shape[-2:]) + dk = dk.transpose(0, 1).contiguous() + dv = dv.transpose(0, 1).contiguous() + elif ctx.qkv_format == "sbhd": + dq = dq.view(-1, *dq.shape[-3:]) + + return ( + None, + dq, + dk, + dv, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + def attn_forward_func_with_cp( is_training, q, @@ -2501,6 +3020,7 @@ def attn_forward_func_with_cp( cp_group, cp_global_ranks, cp_stream, + cp_comm_type, softmax_scale=None, qkv_format="bshd", attn_mask_type="causal", @@ -2508,8 +3028,12 @@ def attn_forward_func_with_cp( attn_bias=None, deterministic=False, use_fused_attention=False, + window_size=None, ) -> torch.Tensor: - """Attention implementation with context parallelism""" + """ + Attention implementation with context parallelism. + """ + assert qkv_format in [ "bshd", "sbhd", @@ -2533,29 +3057,62 @@ def attn_forward_func_with_cp( assert ( cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None ), "cu_seqlens_q_padded and cu_seqlens_kv_padded cannot be None with context parallelism!" - out = AttnFuncWithCP.apply( - is_training, - q, - k, - v, - cu_seqlens_q, - cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - dropout_p, - cp_group, - cp_global_ranks, - cp_stream, - softmax_scale, - qkv_format, - attn_mask_type, - attn_bias_type, - attn_bias, - deterministic, - use_fused_attention, + + sliding_window_attn = ( + window_size is not None and window_size != (-1, 0) and window_size != (-1, -1) ) + + if sliding_window_attn or cp_comm_type == "all_gather": + out = AttnFuncWithCPAndKVAllGather.apply( + is_training, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + dropout_p, + cp_group, + cp_stream, + softmax_scale, + qkv_format, + attn_mask_type, + attn_bias_type, + attn_bias, + deterministic, + use_fused_attention, + window_size, + ) + elif cp_comm_type == "p2p": + out = AttnFuncWithCPAndKVP2P.apply( + is_training, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + dropout_p, + cp_group, + cp_global_ranks, + cp_stream, + softmax_scale, + qkv_format, + attn_mask_type, + attn_bias_type, + attn_bias, + deterministic, + use_fused_attention, + ) + else: + raise ValueError(f"Unsupported communication type: {cp_comm_type}!") + return out @@ -3316,6 +3873,7 @@ def forward( cp_group: Optional[dist_group_type] = None, cp_global_ranks: List[int] = None, cp_stream: torch.cuda.Stream = None, + cp_comm_type: str = "p2p", ) -> torch.Tensor: """flash-attn fprop""" @@ -3424,10 +3982,6 @@ def forward( max_seqlen_kv = seqlens_kv.max().item() if context_parallel: - assert window_size in ( - (-1, -1), - (-1, 0), - ), "Sliding window attention is not supported with context parallelism." assert ( alibi_slopes is None ), "Alibi slope bias addition is not supported with context parallelism." @@ -3447,10 +4001,12 @@ def forward( cp_group, cp_global_ranks, cp_stream, + cp_comm_type, softmax_scale=self.softmax_scale, qkv_format="bshd" if qkv_format == "sbhd" else qkv_format, attn_mask_type=attn_mask_type, deterministic=self.deterministic, + window_size=window_size, ) else: @@ -4995,6 +5551,7 @@ def forward( cp_group: Optional[dist_group_type] = None, cp_global_ranks: List[int] = None, cp_stream: torch.cuda.Stream = None, + cp_comm_type: str = "p2p", fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, ) -> torch.Tensor: @@ -5107,12 +5664,14 @@ def forward( cp_group, cp_global_ranks, cp_stream, + cp_comm_type, softmax_scale=self.softmax_scale, qkv_format=qkv_format, attn_mask_type=attn_mask_type, attn_bias_type=core_attention_bias_type, attn_bias=core_attention_bias, use_fused_attention=True, + window_size=window_size, ) else: with self.attention_dropout_ctx(): @@ -5260,6 +5819,9 @@ class DotProductAttention(TransformerEngineBaseModule): compute and communication overlapping. To address the wave quantization issue of each split step, we add an additional CUDA stream so that we can overlap two flash attention kernels. + cp_comm_type : str + inter-gpu communication type for context parallelism. + Can be "p2p" or "all_gather". """ def __init__( @@ -5280,6 +5842,7 @@ def __init__( cp_group: Optional[dist_group_type] = None, cp_global_ranks: List[int] = None, cp_stream: torch.cuda.Stream = None, + cp_comm_type: str = "p2p", softmax_scale: Optional[float] = None, ) -> None: super().__init__() @@ -5307,6 +5870,7 @@ def __init__( self.cp_group = cp_group self.cp_global_ranks = cp_global_ranks self.cp_stream = cp_stream + self.cp_comm_type = cp_comm_type self.hidden_size_per_attention_head_k = ( kv_channels if isinstance(kv_channels, int) else kv_channels[0] @@ -5430,6 +5994,7 @@ def set_context_parallel_group( cp_group: Union[dist_group_type, None], cp_global_ranks: List[int], cp_stream: torch.cuda.Stream, + cp_comm_type: str = "p2p", ) -> None: """ Set the context parallel attributes for the given @@ -5443,10 +6008,14 @@ def set_context_parallel_group( list of global ranks in the context group. cp_stream : torch.cuda.Stream cuda stream for context parallel execution. + cp_comm_type : str + inter-gpu communication type for context parallelism. + Can be "p2p" or "all_gather". """ self.cp_group = cp_group self.cp_global_ranks = cp_global_ranks self.cp_stream = cp_stream + self.cp_comm_type = cp_comm_type @no_torch_dynamo(recursive=False) def forward( @@ -5943,6 +6512,7 @@ def forward( cp_group=self.cp_group, cp_global_ranks=self.cp_global_ranks, cp_stream=self.cp_stream, + cp_comm_type=self.cp_comm_type, max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, ) @@ -5985,6 +6555,7 @@ def forward( cp_group=self.cp_group, cp_global_ranks=self.cp_global_ranks, cp_stream=self.cp_stream, + cp_comm_type=self.cp_comm_type, fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa, fp8_meta=self.fp8_meta, ) @@ -6009,6 +6580,7 @@ def forward( cp_group=self.cp_group, cp_global_ranks=self.cp_global_ranks, cp_stream=self.cp_stream, + cp_comm_type=self.cp_comm_type, fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa, fp8_meta=self.fp8_meta, ) @@ -6437,6 +7009,7 @@ def set_context_parallel_group( cp_group: Union[dist_group_type, None], cp_global_ranks: List[int], cp_stream: torch.cuda.Stream, + cp_comm_type: str = "p2p", ) -> None: """ Set the context parallel attributes for the given @@ -6450,13 +7023,16 @@ def set_context_parallel_group( list of global ranks in the context group. cp_stream : torch.cuda.Stream cuda stream for context parallel execution. + cp_comm_type : str + inter-gpu communication type for context parallelism. + Can be "p2p" or "all_gather". """ # Deep iterate but skip self to avoid infinite recursion. for index, child in enumerate(self.modules()): if index == 0: continue if hasattr(child, "set_context_parallel_group"): - child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream) + child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream, cp_comm_type) def forward( self, diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index e40653d998..f026da23ef 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -487,6 +487,7 @@ def set_context_parallel_group( cp_group: Union[dist_group_type, None], cp_global_ranks: List[int], cp_stream: torch.cuda.Stream, + cp_comm_type: str = "p2p", ) -> None: """ Set the context parallel attributes for the given @@ -500,13 +501,16 @@ def set_context_parallel_group( list of global ranks in the context group. cp_stream : torch.cuda.Stream cuda stream for context parallel execution. + cp_comm_type : str + inter-gpu communication type for context parallelism. + Can be "p2p" or "all_gather". """ # Deep iterate but skip self to avoid infinite recursion. for index, child in enumerate(self.modules()): if index == 0: continue if hasattr(child, "set_context_parallel_group"): - child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream) + child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream, cp_comm_type) def forward( self, From 4edcff5777be08b6f89658572c433aa8f36acf0d Mon Sep 17 00:00:00 2001 From: Shijie Date: Sat, 17 Aug 2024 05:17:39 +0800 Subject: [PATCH 62/72] [PyTorch] Support dtype casting in fused adam (#977) * support dtype casting fusion in FusedAdam Signed-off-by: Shijie Wang * minor changes Signed-off-by: Shijie Wang * fix lint Signed-off-by: Shijie Wang * changes based on review comments Signed-off-by: Shijie Wang * remove unused code Signed-off-by: Shijie Wang * code refactor Signed-off-by: Shijie Wang * fix typo Signed-off-by: Shijie Wang * refactor Signed-off-by: Shijie Wang * remove unused code Signed-off-by: Shijie Wang * Fix linter warnings Signed-off-by: Tim Moon * Copy CUDA headers for framework sdists Signed-off-by: Tim Moon --------- Signed-off-by: Shijie Wang Signed-off-by: Tim Moon Co-authored-by: Tim Moon --- build_tools/utils.py | 41 +- tests/pytorch/test_fused_optimizer.py | 87 +++- transformer_engine/pytorch/csrc/extensions.h | 7 + .../multi_tensor/multi_tensor_adam.cu | 273 ++++++++++++- .../pytorch/csrc/extensions/pybind.cpp | 3 + .../pytorch/csrc/multi_tensor_apply.cuh | 53 ++- .../pytorch/optimizers/__init__.py | 1 + .../pytorch/optimizers/fused_adam.py | 377 ++++++++++-------- 8 files changed, 647 insertions(+), 195 deletions(-) diff --git a/build_tools/utils.py b/build_tools/utils.py index a0837c1c04..81b9a896cb 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -14,7 +14,7 @@ import importlib from pathlib import Path from subprocess import CalledProcessError -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union @functools.lru_cache(maxsize=None) @@ -254,12 +254,39 @@ def get_frameworks() -> List[str]: return _frameworks -def copy_common_headers(te_src, dst): - headers = te_src / "common" - for file_path in glob.glob(os.path.join(str(headers), "**", "*.h"), recursive=True): - new_path = os.path.join(dst, file_path[len(str(te_src)) + 1 :]) - Path(new_path).parent.mkdir(exist_ok=True, parents=True) - shutil.copy(file_path, new_path) +def copy_common_headers( + src_dir: Union[Path, str], + dst_dir: Union[Path, str], +) -> None: + """Copy headers from core library + + src_dir should be the transformer_engine directory within the root + Transformer Engine repository. All .h and .cuh files within + transformer_engine/common are copied into dst_dir. Relative paths + are preserved. + + """ + + # Find common header files in src dir + headers = glob.glob( + os.path.join(str(src_dir), "common", "**", "*.h"), + recursive=True, + ) + headers.extend( + glob.glob( + os.path.join(str(src_dir), "common", "**", "*.cuh"), + recursive=True, + ) + ) + headers = [Path(path) for path in headers] + + # Copy common header files to dst dir + src_dir = Path(src_dir) + dst_dir = Path(dst_dir) + for path in headers: + new_path = dst_dir / path.relative_to(src_dir) + new_path.parent.mkdir(exist_ok=True, parents=True) + shutil.copy(path, new_path) def install_and_import(package): diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index 8a50648391..ee6739fbf6 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -10,6 +10,13 @@ from torch import nn from torch.testing._internal.common_device_type import largeTensorTest import transformer_engine.pytorch as te +from transformer_engine.pytorch.attention import MultiheadAttention +from transformer_engine.pytorch import fp8_model_init +from transformer_engine.pytorch.utils import is_bf16_compatible +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager + +# Check if FP8 is supported +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() class TestFusedOptimizer(unittest.TestCase): @@ -169,6 +176,83 @@ def test_frozen_model(self): torch.testing.assert_close(ref_param, tst_param) + @unittest.skipIf(not is_bf16_compatible(), "bf16 if not supported") + def test_bf16_model_weight_cast(self): + dtype = torch.bfloat16 + model = MultiheadAttention( + hidden_size=1024, + num_attention_heads=16, + layer_number=1, + params_dtype=dtype, + fuse_qkv_params=True, + ).cuda() + ref_params = [] + master_params = [] + model_params = [] + for p in model.parameters(): + if p.requires_grad: + ref_params.append(p.detach().clone().float()) + master_params.append(p.detach().clone().float()) + model_params.append(p) + options = { + "lr": 5e-4, + "betas": (0.9, 0.999), + "eps": 1e-08, + "weight_decay": 0, + "amsgrad": False, + } + ref_optim = torch.optim.Adam(ref_params, **options) + tst_optim = te.optimizers.FusedAdam(model_params, master_weights=master_params, **options) + + for i in range(self.iters): + self.gen_grad(ref_params, master_params) + ref_optim.step() + tst_optim.step() + torch.testing.assert_close(ref_params, master_params) + model_params_to_fp32 = [p.float() for p in model_params] + torch.testing.assert_close( + ref_params, model_params_to_fp32, rtol=1e-3, atol=1e-3, equal_nan=True + ) + + @unittest.skipIf(not fp8_available, reason=reason_for_no_fp8) + def test_fp8_model_weight_cast(self): + dtype = torch.bfloat16 + with fp8_model_init(enabled=True): + model = MultiheadAttention( + hidden_size=1024, + num_attention_heads=16, + layer_number=1, + params_dtype=dtype, + fuse_qkv_params=True, + ).cuda() + ref_params = [] + master_params = [] + model_params = [] + for p in model.parameters(): + if p.requires_grad: + ref_params.append(p.detach().clone().float()) + master_params.append(p.detach().clone().float()) + model_params.append(p) + options = { + "lr": 5e-4, + "betas": (0.9, 0.999), + "eps": 1e-08, + "weight_decay": 0, + "amsgrad": False, + } + ref_optim = torch.optim.Adam(ref_params, **options) + tst_optim = te.optimizers.FusedAdam(model_params, master_weights=master_params, **options) + + for i in range(self.iters): + self.gen_grad(ref_params, master_params) + ref_optim.step() + tst_optim.step() + torch.testing.assert_close(ref_params, master_params) + model_params_to_fp32 = [p.float() for p in model_params] + torch.testing.assert_close( + ref_params, model_params_to_fp32, rtol=1e-2, atol=1e-2, equal_nan=True + ) + class TestFusedSGD(TestFusedOptimizer): def __init__(self, *args, **kwargs): @@ -345,8 +429,9 @@ def testGradScalerCapturableMaster(self): if m.__class__ in [torch.nn.Conv2d]: m.half() params_ = [p for p in self.model_.parameters() if p.requires_grad] + master_weights = [p.float() for p in self.model_.parameters() if p.requires_grad] optimizer_ = te.optimizers.FusedAdam( - params_, lr=self.lr, capturable=True, master_weights=True + params_, lr=self.lr, capturable=True, master_weights=master_weights ) scaler = torch.cuda.amp.GradScaler(enabled=True) scaler_ = torch.cuda.amp.GradScaler(enabled=True) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index cd5bda8b63..05e4e97112 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -423,12 +423,19 @@ std::tuple multi_tensor_unscale_l2norm_cuda( int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, at::Tensor inv_scale, at::optional per_tensor_python); +using transformer_engine::DType; void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, const float lr, const float beta1, const float beta2, const float epsilon, const int step, const int mode, const int bias_correction, const float weight_decay); +void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, const float lr, + const float beta1, const float beta2, const float epsilon, + const int step, const int mode, const int bias_correction, + const float weight_decay, DType fp8_dtype); + void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, at::Tensor lr, const float beta1, const float beta2, diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu index 2752f92348..09b53a8976 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu @@ -8,16 +8,19 @@ #include #include #include +#include // Another possibility: // #include #include +#include "common/utils.cuh" #include "multi_tensor_apply.cuh" #include "type_shim.h" #define BLOCK_SIZE 512 #define ILP 4 +#define THREADS_PER_WARP 32 typedef enum { ADAM_MODE_0 = 0, // L2 regularization mode @@ -25,6 +28,156 @@ typedef enum { } adamMode_t; using MATH_T = float; +using fp8e4m3 = __nv_fp8_e4m3; +using fp8e5m2 = __nv_fp8_e5m2; +using transformer_engine::DType; + +template +struct is_fp8 : std::false_type {}; + +template <> +struct is_fp8 : std::true_type {}; + +template <> +struct is_fp8 : std::true_type {}; + +template +struct FP8Data { + float scale; + float *amax_ptr; + float *scale_inv_ptr; + float max; + int warp_id; +}; + +template <> +struct FP8Data {}; + +template +struct AdamFunctorMaster { + static constexpr bool is_fp8_type = is_fp8::value; + + __device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem, + TensorListMetadata<5, is_fp8_type> &tl, // NOLINT(*) + const float beta1, const float beta2, + const float beta1_correction, + const float beta2_correction, const float epsilon, + const float lr, adamMode_t mode, const float decay) { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + FP8Data fp8_data; + + index_t tensor_loc = tl.block_to_tensor[blockIdx.x]; + + // potentially use to pass in list of scalar + // int tensor_num = tl.start_tensor_this_launch + tensor_loc; + + index_t chunk_idx = tl.block_to_chunk[blockIdx.x]; + index_t n = tl.sizes[tensor_loc]; + + GRAD_T *g = reinterpret_cast(tl.addresses[0][tensor_loc]); + g += chunk_idx * chunk_size; + + PARAM_T *p = reinterpret_cast(tl.addresses[1][tensor_loc]); + p += chunk_idx * chunk_size; + + FULL_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); + m += chunk_idx * chunk_size; + + FULL_T *v = reinterpret_cast(tl.addresses[3][tensor_loc]); + v += chunk_idx * chunk_size; + + FULL_T *p_master = reinterpret_cast(tl.addresses[4][tensor_loc]); + p_master += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + if constexpr (is_fp8_type) { + float *scale_ptr = reinterpret_cast(tl.fp8_meta_addresses[0][tensor_loc]); + fp8_data.scale = scale_ptr != nullptr ? *scale_ptr : 1; + fp8_data.amax_ptr = reinterpret_cast(tl.fp8_meta_addresses[1][tensor_loc]); + fp8_data.scale_inv_ptr = reinterpret_cast(tl.fp8_meta_addresses[2][tensor_loc]); + fp8_data.warp_id = threadIdx.x / THREADS_PER_WARP; + fp8_data.max = 0; + } + + // see note in multi_tensor_scale_kernel.cu + for (index_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { + MATH_T r_g[ILP]; + MATH_T r_p[ILP]; + MATH_T r_m[ILP]; + MATH_T r_v[ILP]; +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + r_g[ii] = static_cast(g[i]); + r_p[ii] = static_cast(p_master[i]); + r_m[ii] = static_cast(m[i]); + r_v[ii] = static_cast(v[i]); + } else { + r_g[ii] = MATH_T(0); + r_p[ii] = MATH_T(0); + r_m[ii] = MATH_T(0); + r_v[ii] = MATH_T(0); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + if (mode == ADAM_MODE_0) { // L2 + r_g[ii] = r_g[ii] + (decay * r_p[ii]); + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = next_m_unbiased / denom; + r_p[ii] = r_p[ii] - (lr * update); + } else { // weight decay + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]); + r_p[ii] = r_p[ii] - (lr * update); + } + } + +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + p_master[i] = static_cast(r_p[ii]); + m[i] = static_cast(r_m[ii]); + v[i] = static_cast(r_v[ii]); + if constexpr (is_fp8_type) { + __builtin_assume(fp8_data.max >= 0); + fp8_data.max = fmaxf(fabsf(r_p[ii]), fp8_data.max); + p[i] = static_cast(r_p[ii] * fp8_data.scale); + } else { + p[i] = static_cast(r_p[ii]); + } + } + } + } + + if constexpr (is_fp8_type) { + fp8_data.max = transformer_engine::reduce_max( + fp8_data.max, fp8_data.warp_id); + if (threadIdx.x == 0) { + if (fp8_data.amax_ptr != nullptr) { + transformer_engine::atomicMaxFloat(fp8_data.amax_ptr, fp8_data.max); + } + if (fp8_data.scale_inv_ptr != nullptr) { + *fp8_data.scale_inv_ptr = __frcp_rn(fp8_data.scale); + } + } + } + } +}; template struct AdamFunctor { @@ -338,22 +491,114 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, } } + const auto p_in_type = tensor_lists[1][0].scalar_type(); + auto tl_size = tensor_lists.size(); + + // case 4: g, p, m, v + // case 5: g, p, m, v, p_master + TORCH_CHECK(tl_size == 4 || tl_size == 5, "tensor list must contain 4 or 5"); + + if (requires_64bit_indexing) { + if (tl_size == 4) { + // Assume single type across p,g,m1,m2 now + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + p_in_type, 0, "adam", + multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, + AdamFunctor(), beta1, beta2, + bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, + weight_decay);) + } else { + // g, p, m, v, p_master + const auto g_in_type = tensor_lists[0][0].scalar_type(); + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + p_in_type, 0, "adam", + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + g_in_type, 1, "adam", + multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, + tensor_lists, + AdamFunctorMaster(), + beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, + (adamMode_t)mode, weight_decay);)); + } + } else { + if (tl_size == 4) { + // Assume single type across p,g,m1,m2 now + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + p_in_type, 0, "adam", + multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamFunctor(), beta1, beta2, + bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, + weight_decay);) + } else { + const auto g_in_type = tensor_lists[0][0].scalar_type(); + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + p_in_type, 0, "adam", + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + g_in_type, 1, "adam", + multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamFunctorMaster(), + beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, + (adamMode_t)mode, weight_decay);)); + } + } + AT_CUDA_CHECK(cudaGetLastError()); +} + +void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, const float lr, + const float beta1, const float beta2, const float epsilon, + const int step, const int mode, const int bias_correction, + const float weight_decay, DType fp8_dtype) { + using namespace at; + + // Handle bias correction mode + float bias_correction1 = 1.0f, bias_correction2 = 1.0f; + if (bias_correction == 1) { + bias_correction1 = 1 - std::pow(beta1, step); + bias_correction2 = 1 - std::pow(beta2, step); + } + + size_t max_size = 0; + bool requires_64bit_indexing = false; + for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) { + for (auto it2 = it->begin(); it2 != it->end(); it2++) { + if (it2->numel() > max_size) { + max_size = it2->numel(); + if (max_size >= INT_MAX) { + requires_64bit_indexing = true; + break; + } + } + } + if (requires_64bit_indexing) { + break; + } + } + + const auto g_in_type = tensor_lists[0][0].scalar_type(); + auto tl_size = tensor_lists.size(); + + // case 8: g, p_fp8, m, v, p_master, scale, amax, scale_inv + TORCH_CHECK(tl_size == 8, "tensor list must contain 8 tensors"); + if (requires_64bit_indexing) { - // Assume single type across p,g,m1,m2 now - DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( - tensor_lists[0][0].scalar_type(), 0, "adam", - multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, - AdamFunctor(), beta1, beta2, - bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, - weight_decay);) + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + fp8_dtype, FP8_T, + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + g_in_type, 0, "adam", + multi_tensor_apply<5, true>( + (int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, + AdamFunctorMaster(), beta1, beta2, + bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, weight_decay);)); } else { - // Assume single type across p,g,m1,m2 now - DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( - tensor_lists[0][0].scalar_type(), 0, "adam", - multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - AdamFunctor(), beta1, beta2, - bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, - weight_decay);) + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + fp8_dtype, FP8_T, + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + g_in_type, 0, "adam", + multi_tensor_apply<5, true>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamFunctorMaster(), + beta1, beta2, bias_correction1, bias_correction2, epsilon, + lr, (adamMode_t)mode, weight_decay);)); } AT_CUDA_CHECK(cudaGetLastError()); } diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index c97c66dd98..11b47ccdec 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -191,6 +191,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("multi_tensor_adam", &multi_tensor_adam_cuda, "Compute and apply gradient update to parameters for Adam optimizer", py::call_guard()); + m.def("multi_tensor_adam_fp8", &multi_tensor_adam_fp8_cuda, + "Compute and apply gradient update to parameters for Adam optimizer", + py::call_guard()); m.def("multi_tensor_adam_capturable", &multi_tensor_adam_capturable_cuda, "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph " "support and LR scheduling", diff --git a/transformer_engine/pytorch/csrc/multi_tensor_apply.cuh b/transformer_engine/pytorch/csrc/multi_tensor_apply.cuh index 4996dfd05e..e85ec3afc2 100644 --- a/transformer_engine/pytorch/csrc/multi_tensor_apply.cuh +++ b/transformer_engine/pytorch/csrc/multi_tensor_apply.cuh @@ -12,38 +12,55 @@ #include #include +#include "common/common.h" + // This header is the one-stop shop for all your multi-tensor apply needs. // TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson) constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24}; constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320}; -template -struct TensorListMetadata { +template +struct TensorListMetadataBase { void *addresses[n][depth_to_max_tensors[n - 1]]; int sizes[depth_to_max_tensors[n - 1]]; unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; - int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a full int. + int block_to_chunk[depth_to_max_blocks[n - 1]]; int start_tensor_this_launch; }; +template +struct TensorListMetadata : public TensorListMetadataBase {}; + +template +struct TensorListMetadata : public TensorListMetadataBase { + void *fp8_meta_addresses[3][depth_to_max_tensors[n - 1]]; +}; + template __global__ void multi_tensor_apply_kernel(int64_t chunk_size, volatile int *noop_flag, T tl, U callable, ArgTypes... args) { - // Hand the chunk information to the user-supplied functor to process however it likes. + // Hand the chunk information to the user-supplied functor to process however + // it likes. callable(chunk_size, noop_flag, tl, args...); } -template +template void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor &noop_flag, const std::vector> &tensor_lists, T callable, ArgTypes... args) { - TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth"); + if constexpr (USE_FP8) { + TORCH_CHECK(tensor_lists.size() == depth + 3, + "tensor_lists.size() != depth + 3, tensor_lists should have 3 more tensors (scale, " + "amax, scale_inv) for fp8"); + } else { + TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth"); + } int len0 = tensor_lists[0].size(); TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0"); auto ref_device = tensor_lists[0][0].device(); TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda"); - for (int l = 0; l < tensor_lists.size(); l++) { // No range-based for because I need indices + for (int l = 0; l < depth; l++) { // No range-based for because I need indices TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists"); for (int t = 0; t < tensor_lists[l].size(); t++) { // TODO: Print which tensor fails. @@ -58,9 +75,14 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor } } + if constexpr (USE_FP8) { + TORCH_CHECK(tensor_lists[depth].size() == len0 && tensor_lists[depth + 1].size() == len0, + "Size mismatch among tensor lists"); + } + int ntensors = tensor_lists[0].size(); - TensorListMetadata tl; + TensorListMetadata tl; const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0])); auto stream = at::cuda::getCurrentCUDAStream(); @@ -72,12 +94,15 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel(); for (int d = 0; d < depth; d++) tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); + if constexpr (USE_FP8) { + for (int i = 0; i < 3; i++) + tl.fp8_meta_addresses[i][loc_tensor_info] = tensor_lists[depth + i][t].data_ptr(); + } loc_tensor_info++; auto chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; for (auto chunk = 0; chunk < chunks_this_tensor; chunk++) { - // std::cout << chunks_this_tensor << std::endl; tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1; tl.block_to_chunk[loc_block_info] = chunk; loc_block_info++; @@ -87,7 +112,6 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]); bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1); if (tensors_full || blocks_full || last_chunk) { - // using accscalar_t = acc_type; multi_tensor_apply_kernel<<>>( chunk_size, noop_flag.data_ptr(), tl, callable, args...); @@ -100,7 +124,14 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor tl.start_tensor_this_launch = t + 1; } else { tl.sizes[0] = tl.sizes[loc_tensor_info - 1]; - for (int d = 0; d < depth; d++) tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1]; + for (int d = 0; d < depth; d++) { + tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1]; + } + if constexpr (USE_FP8) { + for (int i = 0; i < 3; i++) { + tl.fp8_meta_addresses[i][0] = tl.fp8_meta_addresses[i][loc_tensor_info - 1]; + } + } loc_tensor_info = 1; tl.start_tensor_this_launch = t; } diff --git a/transformer_engine/pytorch/optimizers/__init__.py b/transformer_engine/pytorch/optimizers/__init__.py index 8cbe720a74..fc9bdc304a 100644 --- a/transformer_engine/pytorch/optimizers/__init__.py +++ b/transformer_engine/pytorch/optimizers/__init__.py @@ -8,6 +8,7 @@ multi_tensor_l2norm, multi_tensor_unscale_l2norm, multi_tensor_adam, + multi_tensor_adam_fp8, multi_tensor_adam_capturable, multi_tensor_adam_capturable_master, multi_tensor_sgd, diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index 91ce502390..322b93a1d8 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -5,9 +5,27 @@ """Fused Adam optimizer.""" import torch import transformer_engine_torch as tex +from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from .multi_tensor_apply import multi_tensor_applier +def get_fp8_meta(fp8_tensor): + """FP8 metadata getter.""" + if fp8_tensor._fp8_meta is None: + raise RuntimeError("FP8 meta data is not initialized.") + + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=fp8_tensor._fp8_meta_forward, + ) + + fp8_meta_index = fp8_tensor._fp8_meta_index + scale = fp8_tensor._fp8_meta[fp8_meta_key].scale[fp8_meta_index] + amax = fp8_tensor._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] + scale_inv = fp8_tensor._scale_inv + return scale, amax, scale_inv + + class FusedAdam(torch.optim.Optimizer): """Implements Adam algorithm. @@ -50,9 +68,11 @@ class FusedAdam(torch.optim.Optimizer): method is called. (default: True) capturable (bool, optional): whether to use the version of the optimizer that can be used with CUDA Graphs. (default: False) - master_weights (bool, optional): whether to maintain FP32 master weights - in the optimizer with FP16 mixed precision training, currently can - only be used with capturable set to True. (default: False) + master_weights (list of torch.Tensor, optional): master weights to use + for mixed precision training. If provided, the optimizer will update + the master weights and then cast the master weights to the model weights. + If not provided, the optimizer will update the model weights directly. + (default: None) .. _Adam - A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 @@ -72,15 +92,12 @@ def __init__( amsgrad=False, set_grad_none=True, capturable=False, - master_weights=False, + master_weights=None, ): if amsgrad: raise RuntimeError("FusedAdam does not support the AMSGrad variant.") - if master_weights and not capturable: - raise RuntimeError( - "Master weights is currently only supported with the capturable version." - ) + # If the optimizer is capturable then LR should be a tensor (on GPU) lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr defaults = dict( @@ -95,20 +112,10 @@ def __init__( self.set_grad_none = set_grad_none self.capturable = capturable - self.master_weights = master_weights - # Create full precision master weights - self.param_groups_master = [] - for _, pg in enumerate(self.param_groups): - param_list = pg["params"] - self.param_groups_master.append( - { - "params": [ - p.clone().detach().float() if self.master_weights else None - for p in param_list - ], - } - ) + if master_weights is not None: + assert isinstance(master_weights, list), "master_weights must be a list if provided" + self.master_weights = master_weights if capturable: for idx, group in enumerate(self.param_groups): @@ -123,6 +130,7 @@ def __init__( # Skip buffer self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda") self.multi_tensor_adam = tex.multi_tensor_adam + self.multi_tensor_adam_fp8 = tex.multi_tensor_adam_fp8 self.multi_tensor_adam_capturable = tex.multi_tensor_adam_capturable self.multi_tensor_adam_capturable_master = tex.multi_tensor_adam_capturable_master @@ -147,7 +155,9 @@ def step(self, closure=None, grad_scaler=None): if closure is not None: loss = closure() - for group, group_master in zip(self.param_groups, self.param_groups_master): + master_param_idx = 0 + + for group in self.param_groups: if len(group["params"]) == 0: continue device = group["params"][0].device @@ -166,51 +176,131 @@ def step(self, closure=None, grad_scaler=None): ) # create lists for multi-tensor apply - g_16, p_16, m_16, v_16 = [], [], [], [] - g_bf, p_bf, m_bf, v_bf = [], [], [], [] - g_32, p_32, m_32, v_32 = [], [], [], [] - p_16_master = [] - p_32_master = [] - - for p, p_master in zip(group["params"], group_master["params"]): - if p.grad is None: - continue - if p.grad.data.is_sparse: - raise RuntimeError("FusedAdam does not support sparse gradients.") - + p_main_of_fp8_model = [] + p_main_of_f16_model = [] + g_of_fp8_model = [] + g_of_f16_model = [] + g_of_f32_model = [] + m_of_fp8_model = [] + m_of_f16_model = [] + m_of_f32_model = [] + v_of_fp8_model = [] + v_of_f16_model = [] + v_of_f32_model = [] + p_fp8_model = [] + p_f16_model = [] + p_f32_model = [] + # fp8 meta + scales = [] + amaxes = [] + scale_invs = [] + + # Only used when extra params include fp8 tensors. Otherwise, it doesn't matter what the out_dtype is. + out_dtype = tex.DType.kFloat32 + + has_fp16 = False + has_bf16 = False + + for p in group["params"]: state = self.state[p] + # State initialization if len(state) == 0: # Exponential moving average of gradient values state["exp_avg"] = torch.zeros_like(p.data).float() # Exponential moving average of squared gradient values state["exp_avg_sq"] = torch.zeros_like(p.data).float() + # Master weights + if self.master_weights and p.dtype != torch.float32: + # model weights can be fp32/bf16/fp16/fp8 + # If it's fp32, it has no corresponding master weights + state["master_param"] = self.master_weights[master_param_idx] + master_param_idx += 1 + assert ( + state["master_param"].shape == p.shape + ), "Master weights shape must match model weights shape" + else: + state["master_param"] = None + + p_master = state["master_param"] + p_grad = p.grad + + if self.master_weights and p_master is not None and p_master.grad is not None: + p_grad = p_master.grad + + if p_grad is None: + continue + if p_grad.data.is_sparse: + raise RuntimeError("FusedAdam does not support sparse gradients.") - if p.dtype == torch.float16: + if isinstance(p, Float8Tensor): + out_dtype = p._fp8_dtype + p_fp8_model.append(p._data.data) + scale, amax, scale_inv = get_fp8_meta(p) + scales.append(scale) + amaxes.append(amax) + scale_invs.append(scale_inv) if self.master_weights: - p_16_master.append(p_master.data) - g_16.append(p.grad.data) - p_16.append(p.data) - m_16.append(state["exp_avg"]) - v_16.append(state["exp_avg_sq"]) - elif p.dtype == torch.bfloat16: - g_bf.append(p.grad) - p_bf.append(p) - m_bf.append(state["exp_avg"]) - v_bf.append(state["exp_avg_sq"]) - elif p.dtype == torch.float32: + p_main_of_fp8_model.append(p_master.data) + g_of_fp8_model.append(p_grad.data) + m_of_fp8_model.append(state["exp_avg"]) + v_of_fp8_model.append(state["exp_avg_sq"]) + elif p.dtype in [torch.float16, torch.bfloat16]: + has_fp16 = has_fp16 or p.dtype == torch.float16 + has_bf16 = has_bf16 or p.dtype == torch.bfloat16 + p_f16_model.append(p.data) if self.master_weights: - p_32_master.append(p_master.data) - g_32.append(p.grad.data) - p_32.append(p.data) - m_32.append(state["exp_avg"]) - v_32.append(state["exp_avg_sq"]) + p_main_of_f16_model.append(p_master.data) + g_of_f16_model.append(p_grad.data) + m_of_f16_model.append(state["exp_avg"]) + v_of_f16_model.append(state["exp_avg_sq"]) + elif p.dtype == torch.float32: + p_f32_model.append(p.data) + g_of_f32_model.append(p_grad.data) + m_of_f32_model.append(state["exp_avg"]) + v_of_f32_model.append(state["exp_avg_sq"]) else: - raise RuntimeError("FusedAdam only support fp16 and fp32.") + raise RuntimeError("FusedAdam only support model weights in fp16/bf16 and fp8") + + if self.capturable and len(p_fp8_model) > 0: + raise RuntimeError( + "FusedAdam does not support FP8 model weights with capturable=True." + ) + + if has_fp16 and has_bf16: + # simple to add support for this, but not needed for now + raise RuntimeError( + "FusedAdam does not support a mix of float16 and bfloat16 model weights." + ) + + def apply_multi_tensor_adam(adam_func, tensor_lists, inv_scale=None, out_dtype=None): + # Closures defined in a loop can have unexpected + # behavior when called outside the loop. However, this + # function is called in the same loop iteration as it + # is defined. + # pylint: disable=cell-var-from-loop + inv_scale_arg = () if inv_scale is None else (inv_scale,) + out_dtype_arg = () if out_dtype is None else (out_dtype,) + multi_tensor_applier( + adam_func, + self._dummy_overflow_buf, + tensor_lists, + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self.adam_w_mode, + bias_correction, + group["weight_decay"], + *inv_scale_arg, + *out_dtype_arg, + ) - # If the optimizer is capturable, then if there's a grad scaler it works - # on the GPU + a different multi_tensor_applier should be called if self.capturable: + # If the optimizer is capturable, then if there's a grad scaler it works + # on the GPU + a different multi_tensor_applier should be called + # overflow check of gradients found_inf = ( grad_scaler._check_inf_per_device(self)[device] @@ -228,113 +318,76 @@ def step(self, closure=None, grad_scaler=None): scale = torch.ones((1,), device=device) inv_scale = torch.ones((1,), device=device) - if len(g_16) > 0: - multi_tensor_applier( - ( - self.multi_tensor_adam_capturable_master - if self.master_weights - else self.multi_tensor_adam_capturable - ), - self._dummy_overflow_buf, - ( - [g_16, p_16, m_16, v_16, p_16_master] - if self.master_weights - else [g_16, p_16, m_16, v_16] - ), - group["lr"], - beta1, - beta2, - group["eps"], - group["step"], - self.adam_w_mode, - bias_correction, - group["weight_decay"], - inv_scale, - ) - - if len(g_bf) > 0: - multi_tensor_applier( - self.multi_tensor_adam_capturable, - self._dummy_overflow_buf, - [g_bf, p_bf, m_bf, v_bf], - group["lr"], - beta1, - beta2, - group["eps"], - group["step"], - self.adam_w_mode, - bias_correction, - group["weight_decay"], - inv_scale, - ) - - if len(g_32) > 0: - multi_tensor_applier( - ( - self.multi_tensor_adam_capturable_master - if self.master_weights - else self.multi_tensor_adam_capturable - ), - self._dummy_overflow_buf, - ( - [g_32, p_32, m_32, v_32, p_32_master] - if self.master_weights - else [g_32, p_32, m_32, v_32] - ), - group["lr"], - beta1, - beta2, - group["eps"], - group["step"], - self.adam_w_mode, - bias_correction, - group["weight_decay"], - inv_scale, - ) - else: - if len(g_16) > 0: - multi_tensor_applier( - self.multi_tensor_adam, - self._dummy_overflow_buf, - [g_16, p_16, m_16, v_16], - group["lr"], - beta1, - beta2, - group["eps"], - group["step"], - self.adam_w_mode, - bias_correction, - group["weight_decay"], - ) - - if len(g_bf) > 0: - multi_tensor_applier( - self.multi_tensor_adam, - self._dummy_overflow_buf, - [g_bf, p_bf, m_bf, v_bf], - group["lr"], - beta1, - beta2, - group["eps"], - group["step"], - self.adam_w_mode, - bias_correction, - group["weight_decay"], - ) - - if len(g_32) > 0: - multi_tensor_applier( - self.multi_tensor_adam, - self._dummy_overflow_buf, - [g_32, p_32, m_32, v_32], - group["lr"], - beta1, - beta2, - group["eps"], - group["step"], - self.adam_w_mode, - bias_correction, - group["weight_decay"], - ) + if self.master_weights: + if len(p_f16_model) > 0: + tensor_lists = [ + g_of_f16_model, + p_f16_model, + m_of_f16_model, + v_of_f16_model, + p_main_of_f16_model, + ] + apply_multi_tensor_adam( + self.multi_tensor_adam_capturable_master, tensor_lists, inv_scale + ) + if len(p_f32_model) > 0: + tensor_lists = [ + g_of_f32_model, + p_f32_model, + m_of_f32_model, + v_of_f32_model, + ] + apply_multi_tensor_adam( + self.multi_tensor_adam_capturable, tensor_lists, inv_scale + ) + else: + if len(p_f16_model) > 0: + tensor_lists = [g_of_f16_model, p_f16_model, m_of_f16_model, v_of_f16_model] + apply_multi_tensor_adam( + self.multi_tensor_adam_capturable, tensor_lists, inv_scale + ) + if len(p_f32_model) > 0: + tensor_lists = [g_of_f32_model, p_f32_model, m_of_f32_model, v_of_f32_model] + apply_multi_tensor_adam( + self.multi_tensor_adam_capturable, tensor_lists, inv_scale + ) + + elif self.master_weights: # and self.capturable=False + if len(p_f16_model) > 0: + tensor_lists = [ + g_of_f16_model, + p_f16_model, + m_of_f16_model, + v_of_f16_model, + p_main_of_f16_model, + ] + apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists) + if len(p_fp8_model) > 0: + tensor_lists = [ + g_of_fp8_model, + p_fp8_model, + m_of_fp8_model, + v_of_fp8_model, + p_main_of_fp8_model, + scales, + amaxes, + scale_invs, + ] + apply_multi_tensor_adam(self.multi_tensor_adam_fp8, tensor_lists, out_dtype) + if len(p_f32_model) > 0: + tensor_lists = [ + g_of_f32_model, + p_f32_model, + m_of_f32_model, + v_of_f32_model, + ] + apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists) + else: # self.master_weights=False and self.capturable=False + if len(p_f16_model) > 0: + tensor_lists = [g_of_f16_model, p_f16_model, m_of_f16_model, v_of_f16_model] + apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists) + if len(p_f32_model) > 0: + tensor_lists = [g_of_f32_model, p_f32_model, m_of_f32_model, v_of_f32_model] + apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists) return loss From 3bc2c1f387b5adebabb4d327f79e67aee73d5de7 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 30 Aug 2024 17:10:12 -0700 Subject: [PATCH 63/72] Changed version to 1.10.0 Signed-off-by: Przemek Tredak --- build_tools/VERSION.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt index a597619ec0..81c871de46 100644 --- a/build_tools/VERSION.txt +++ b/build_tools/VERSION.txt @@ -1 +1 @@ -1.10.0.dev0 +1.10.0 From 442212041ec64e216fe0e1eeb4d5e9b201300816 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Mon, 19 Aug 2024 08:49:08 -0700 Subject: [PATCH 64/72] Remove the commit hash from the release documentation (#1118) Signed-off-by: Przemek Tredak --- docs/conf.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index 77751994d8..7a50ce76cf 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -47,7 +47,10 @@ git_sha = git_sha[:7] if len(git_sha) > 7 else git_sha -version = str(te_version + "-" + git_sha) +if "dev" in te_version: + version = str(te_version + "-" + git_sha) +else: + version = str(te_version) release = te_version # hack: version is used for html creation, so put the version picker From 8683e4c9e5cb71cc037e5aa32e2868dc9e8b6e1f Mon Sep 17 00:00:00 2001 From: hXl3s Date: Tue, 20 Aug 2024 19:01:37 +0200 Subject: [PATCH 65/72] feat(pytorch): Allow TransformerLayer and MultiheadAttention to accept sequence length parameters (#1066) * Added ability for seqlen for transformer and mha layer Signed-off-by: Lukasz Pierscieniewski * Documentation for new parameters Signed-off-by: Lukasz Pierscieniewski * Add tests for THD layout, assert for THD layout with KV-Cache Signed-off-by: Lukasz Pierscieniewski * Fixed tests Signed-off-by: Lukasz Pierscieniewski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Move THD logic in shape calculation, add missing optional in params Signed-off-by: Lukasz Pierscieniewski * Skip the THD test on GPUs older than Ampere Signed-off-by: Przemek Tredak --------- Signed-off-by: Lukasz Pierscieniewski Signed-off-by: Przemek Tredak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani Co-authored-by: Przemek Tredak --- tests/pytorch/test_numerics.py | 47 ++++++++++++++++++- transformer_engine/pytorch/attention.py | 44 +++++++++++++---- .../pytorch/module/layernorm_mlp.py | 3 +- transformer_engine/pytorch/transformer.py | 20 ++++++++ 4 files changed, 102 insertions(+), 12 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index a219f24674..a2023f539a 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -34,11 +34,13 @@ from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint from transformer_engine.pytorch.cpp_extensions import fp8_gemm, fp8_grouped_gemm, gemm, grouped_gemm from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace +from transformer_engine.pytorch.utils import get_device_compute_capability import transformer_engine_torch as tex # Only run FP8 tests on H100. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +sm_80plus = get_device_compute_capability() >= (8, 0) seed = 1234 torch.manual_seed(seed) @@ -1548,8 +1550,29 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): attn_input_format="bshd", ) - for (n1, p1), (n2, p2) in zip(block_bshd.named_parameters(), block_sbhd.named_parameters()): - assert torch.all(torch.eq(p1, p2)), f"{n1}, {n2} not identical" + torch.manual_seed(0) + block_thd = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + layernorm_epsilon=config.eps, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0, + attention_dropout=0, + kv_channels=config.embed, + params_dtype=dtype, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + device="cuda", + attn_input_format="thd", + self_attn_mask_type="padding_causal", + ) + + for (n1, p1), (n2, p2), (n3, p3) in zip( + block_bshd.named_parameters(), block_sbhd.named_parameters(), block_thd.named_parameters() + ): + assert torch.all(torch.eq(p1, p2) & torch.eq(p1, p3)), f"{n1}, {n2} and {n3} not identical" x_sbhd = torch.randn( (config.seq_len, bs, config.hidden_size), @@ -1559,6 +1582,8 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): ) x_bshd = x_sbhd.transpose(0, 1).contiguous() + x_thd = x_bshd.reshape(bs * config.seq_len, config.hidden_size).contiguous() + x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.seq_len # To make sure forward is also identical (just in case some module decides # to act fancy) @@ -1576,6 +1601,24 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): y_sbhd.transpose(0, 1).contiguous(), ) + # THD is not supported in float32 and on GPUs older than Ampere, skip the test here + if dtype != torch.float32 and sm_80plus: + # To make sure forward is also identical (just in case some module decides + # to act fancy) + torch.manual_seed(0) + y_thd = block_thd( + x_thd, + cu_seqlens_q=x_thd_cumsum, + cu_seqlens_kv=x_thd_cumsum, + max_seqlen_q=config.seq_len, + max_seqlen_kv=config.seq_len, + ) + + torch.testing.assert_close( + y_bshd, + y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(), + ) + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 904dbbde01..71bc15fdad 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -7048,6 +7048,10 @@ def forward( core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, fast_zero_fill: bool = True, ) -> Tuple[Union[torch.Tensor, None], ...]: """ @@ -7113,6 +7117,18 @@ def forward( ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads]. It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j)) to the attention score of query i and key j. + cu_seqlens_q: Optional[torch.Tensor], default = `None` + Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`, + with shape [batch_size + 1] and dtype torch.int32. + cu_seqlens_kv: Optional[torch.Tensor], default = `None` + Cumulative sum of sequence lengths (without offset) in a batch for `key_layer` + and `value_layer`, with shape [batch_size + 1] and dtype torch.int32. + max_seqlen_q: Optional[int], default = `None` + Maximum sequence length in `query_layer`. + Calculated from `cu_seqlens_q` if not provided. + max_seqlen_kv: Optional[int], default = `None` + Maximum sequence length in `key_layer` and `value_layer`. + Calculated from `cu_seqlens_kv` if not provided. fast_zero_fill: bool, default = `True` Whether to set output tensors to 0 or not before use. """ @@ -7139,6 +7155,9 @@ def forward( # ================================================= if inference_params and self.layer_number is not None: + assert ( + self.qkv_format != "thd" + ), "qkv_format == thd is not supported for an inference with KV-cache!" if self.layer_number not in inference_params.key_value_memory_dict: inf_max_seq_len = inference_params.max_sequence_length inf_max_batch_size = inference_params.max_batch_size @@ -7221,13 +7240,18 @@ def forward( dim=split_dim, ) - # query: -> [sq, b, np, hn] - # key, value: -> [sq, b, ng, hn] - query_layer, key_layer, value_layer = ( - x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head) - for x in (query_layer, key_layer, value_layer) - ) - + if self.qkv_format == "thd": + query_layer, key_layer, value_layer = ( + x.reshape(x.size(0), -1, self.hidden_size_per_attention_head) + for x in (query_layer, key_layer, value_layer) + ) + else: + # query: -> [sq, b, np, hn] + # key, value: -> [sq, b, ng, hn] + query_layer, key_layer, value_layer = ( + x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head) + for x in (query_layer, key_layer, value_layer) + ) elif self.attention_type == "cross": # Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)] mixed_kv_layer = self.key_value( @@ -7341,8 +7365,10 @@ def forward( key_layer, value_layer, qkv_format=self.qkv_format, - cu_seqlens_q=None, - cu_seqlens_kv=None, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, attention_mask=attention_mask, attn_mask_type=attn_mask_type, window_size=window_size, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index be6df21322..dc9bef645f 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -13,6 +13,7 @@ from .base import ( get_workspace, + _ub_communicators, get_ub, TransformerEngineBaseModule, _2X_ACC_FPROP, @@ -1297,7 +1298,7 @@ def __init__( self.gemm_gelu_fusion = ( bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0"))) and self.activation == "gelu" - and not get_ub("fc1_fprop").is_atomic_gemm() + and ((_ub_communicators is None) or (not get_ub("fc1_fprop").is_atomic_gemm())) ) if tp_group is None: diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index f026da23ef..4cbee3d628 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -529,6 +529,10 @@ def forward( core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, fast_zero_fill: bool = True, ) -> torch.Tensor: """ @@ -604,6 +608,18 @@ def forward( ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads]. It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j)) to the attention score of query i and key j. + cu_seqlens_q: Optional[torch.Tensor], default = `None` + Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`, + with shape [batch_size + 1] and dtype torch.int32. + cu_seqlens_kv: Optional[torch.Tensor], default = `None` + Cumulative sum of sequence lengths (without offset) in a batch for `key_layer` + and `value_layer`, with shape [batch_size + 1] and dtype torch.int32. + max_seqlen_q: Optional[int], default = `None` + Maximum sequence length in `query_layer`. + Calculated from `cu_seqlens_q` if not provided. + max_seqlen_kv: Optional[int], default = `None` + Maximum sequence length in `key_layer` and `value_layer`. + Calculated from `cu_seqlens_kv` if not provided. fast_zero_fill: bool, default = `True` Whether to set output tensors to 0 or not before use. inference_params: InferenceParams, default = None @@ -664,6 +680,10 @@ def forward( core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, fast_zero_fill=fast_zero_fill, ) From 311b6b6001a1e26689a4efb4b6cfd0756ceea283 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com> Date: Tue, 20 Aug 2024 20:17:03 -0700 Subject: [PATCH 66/72] Add FP8 support to CP implementation with KV P2P (#1114) * add window_size to AttnFuncWithCP Signed-off-by: Xiaowei Ren * add seq_offsets_qkvo for cudnn thd Signed-off-by: Xiaowei Ren * add seq_offsets_qkvo to AttnFuncWithCP Signed-off-by: Xiaowei Ren * fix seq_offsets calculation of cudnn thd Signed-off-by: Xiaowei Ren * remove a thd assert Signed-off-by: Xiaowei Ren * fix bias for thd test Signed-off-by: Xiaowei Ren * add thd test for cudnn FA with CP Signed-off-by: Xiaowei Ren * skip GQA/MQA test for cuDNN THD Signed-off-by: Xiaowei Ren * make sure seq_offsets are computed with qkv_group of hd_hd_hd while CP>1 Signed-off-by: Xiaowei Ren * fix seq_offsets inputs Signed-off-by: Xiaowei Ren * remove two comments Signed-off-by: Xiaowei Ren * fix attn mask type for cudnn thd with cp Signed-off-by: Xiaowei Ren * fix attn_mask_type check Signed-off-by: Xiaowei Ren * fix attn_mask_type for cudnn fa with thd Signed-off-by: Xiaowei Ren * fix a typo Signed-off-by: Xiaowei Ren * fix out dout in bwd Signed-off-by: Xiaowei Ren * assert cudnn+thd does not support attn bias Signed-off-by: Xiaowei Ren * check if attn_mask_type has padding Signed-off-by: Xiaowei Ren * minor change Signed-off-by: Xiaowei Ren * change cp test batch size to 2 Signed-off-by: Xiaowei Ren * fix code format Signed-off-by: Xiaowei Ren * fix two assert info Signed-off-by: Xiaowei Ren * fix assert comment Signed-off-by: Xiaowei Ren * fix assert comments Signed-off-by: Xiaowei Ren * minor fix Signed-off-by: Xiaowei Ren * fix assert comments Signed-off-by: Xiaowei Ren * assert swa+CP cannot work with thd format Signed-off-by: Xiaowei Ren * add a new CP function for swa Signed-off-by: Xiaowei Ren * add a missing dgrads Signed-off-by: Xiaowei Ren * minor change Signed-off-by: Xiaowei Ren * add draft fwd function for swa+cp Signed-off-by: Xiaowei Ren * minor change Signed-off-by: Xiaowei Ren * enable flash attention for swa+cp Signed-off-by: Xiaowei Ren * remove an assert of swa+cp Signed-off-by: Xiaowei Ren * call SWAFuncWithCP for swa+cp Signed-off-by: Xiaowei Ren * typo fix Signed-off-by: Xiaowei Ren * use 2hd layout Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change qkv_format check Signed-off-by: Xiaowei Ren * add a code comment Signed-off-by: Xiaowei Ren * tensor shape bug fix Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * tensor shape fix Signed-off-by: Xiaowei Ren * add function to compute cu_seqlens of a cp rank Signed-off-by: Xiaowei Ren * add cu_seqlens and cu_seqlens_padded to context parallelism Signed-off-by: Xiaowei Ren * typo fix Signed-off-by: Xiaowei Ren * minor change Signed-off-by: Xiaowei Ren * fix FlashAttention output sequence length Signed-off-by: Xiaowei Ren * fix cu_seqlens_kv_per_step calculation Signed-off-by: Xiaowei Ren * zero dQKV for ending padded tokens Signed-off-by: Xiaowei Ren * zero dQKV tensors of FlashAttention Signed-off-by: Xiaowei Ren * fix softmax_lse correction Signed-off-by: Xiaowei Ren * remove padded tokens of KV to save comounication Signed-off-by: Xiaowei Ren * do not need to zero dkv for FlashAttention any mroe Signed-off-by: Xiaowei Ren * zero out tensors Signed-off-by: Xiaowei Ren * remove redundant code Signed-off-by: Xiaowei Ren * fix CP unit test Signed-off-by: Xiaowei Ren * fix kv shape of cp test with thd format Signed-off-by: Xiaowei Ren * update cp unit test Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add simple code framework Signed-off-by: Xiaowei Ren * try not to have a separate CP function for SWA Signed-off-by: Xiaowei Ren * backup some code change Signed-off-by: Xiaowei Ren * back up code Signed-off-by: Xiaowei Ren * clean up fwd implementation of SWAFuncWithCP Signed-off-by: Xiaowei Ren * remove redundant code Signed-off-by: Xiaowei Ren * code cleaning Signed-off-by: Xiaowei Ren * fix assert info Signed-off-by: Xiaowei Ren * reduce kv chunk concat overheads Signed-off-by: Xiaowei Ren * minor change Signed-off-by: Xiaowei Ren * make AttnFuncWithCP and SWAFuncWithCP have same API Signed-off-by: Xiaowei Ren * add a docstring Signed-off-by: Xiaowei Ren * preliminary implementation of SWAFuncWithCP forward seems working Signed-off-by: Xiaowei Ren * fix output shape of SWAFuncWithCP Signed-off-by: Xiaowei Ren * code refactoring for FlashAttention and add a code placeholder for bwd Signed-off-by: Xiaowei Ren * use gather_along_first_dim Signed-off-by: Xiaowei Ren * finish the preliminary implementation of bwd Signed-off-by: Xiaowei Ren * remove redundant code Signed-off-by: Xiaowei Ren * fix assert condition Signed-off-by: Xiaowei Ren * add draft implementation of SWA+CP with FusedAttention Signed-off-by: Xiaowei Ren * fix attention mask type of swa+cp Signed-off-by: Xiaowei Ren * code cleaning Signed-off-by: Xiaowei Ren * add qkv_layout Signed-off-by: Xiaowei Ren * add missing window_size argument Signed-off-by: Xiaowei Ren * typo fix Signed-off-by: Xiaowei Ren * fix kv shape of swa+cp Signed-off-by: Xiaowei Ren * bug and typo fix Signed-off-by: Xiaowei Ren * fix dout shape Signed-off-by: Xiaowei Ren * add multi stream in fwd of swa+cp Signed-off-by: Xiaowei Ren * save chunk_ids_to_kv_ag in fwd Signed-off-by: Xiaowei Ren * add multi stream in bwd of swa+cp Signed-off-by: Xiaowei Ren * minor fix to cp stream sync Signed-off-by: Xiaowei Ren * rename AttnFuncWithCP Signed-off-by: Xiaowei Ren * check if window size is None Signed-off-by: Xiaowei Ren * fix docstring of AttnFuncWithCP Signed-off-by: Xiaowei Ren * minor fix Signed-off-by: Xiaowei Ren * add env var for users to choose KV ag or KV p2p Signed-off-by: Xiaowei Ren * update cp tests Signed-off-by: Xiaowei Ren * fix window size in cp unit test Signed-off-by: Xiaowei Ren * fix pytest skip messages Signed-off-by: Xiaowei Ren * add cp_comm_type into API Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * code cleaning Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add deterministic konb in cuDNN fused attn backend Signed-off-by: Xiaowei Ren * pass fp8 and fp8_meta to attn_func_with_cp Signed-off-by: Xiaowei Ren * assert only Fused Attn can support FP8+CP Signed-off-by: Xiaowei Ren * remove redundant assert Signed-off-by: Xiaowei Ren * add a fwd draft implementation of FP8 + CP Signed-off-by: Xiaowei Ren * save fp8 and fp8_meta Signed-off-by: Xiaowei Ren * assert sequence length divisible requirements Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove a redundant qkv_layout compute Signed-off-by: Xiaowei Ren * if condition change Signed-off-by: Xiaowei Ren * some typo fix Signed-off-by: Xiaowei Ren * add support table of context parallelism Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * typo and code format fix Signed-off-by: Xiaowei Ren * do not print multiple disabling messages Signed-off-by: Xiaowei Ren * bug fix Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix aux_ctx_tensors of FP8 Signed-off-by: Xiaowei Ren * bug fix Signed-off-by: Xiaowei Ren * fix device in torch.arange and adjust code for the PR of MLA Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * commit code change for FP8+CP Signed-off-by: Xiaowei Ren * commit more code change for FP8+CP Signed-off-by: Xiaowei Ren * commit more fp8 code for FP8+CP Signed-off-by: Xiaowei Ren * bug fixes Signed-off-by: Xiaowei Ren * bug fix Signed-off-by: Xiaowei Ren * cast merged CP results from FP32 to BF16 Signed-off-by: Xiaowei Ren * typo fix Signed-off-by: Xiaowei Ren * minor change Signed-off-by: Xiaowei Ren * fix softmax_lse Signed-off-by: Xiaowei Ren * fix some bugs of FP8 dkv exchange Signed-off-by: Xiaowei Ren * typo fix Signed-off-by: Xiaowei Ren * add FP8 unit test Signed-off-by: Xiaowei Ren * fix typos and clean asserts Signed-off-by: Xiaowei Ren * fix get_p2p_comm_info Signed-off-by: Xiaowei Ren * fix dkv p2p exchange Signed-off-by: Xiaowei Ren * minor fix Signed-off-by: Xiaowei Ren * change FP8 dkv P2P to A2A Signed-off-by: Xiaowei Ren * add FP8+CP unit test Signed-off-by: Xiaowei Ren * typo fix Signed-off-by: Xiaowei Ren * assert amax reduction is needed for FP8+CP Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove duplicated code Signed-off-by: Xiaowei Ren * destroy process group in CP unit test Signed-off-by: Xiaowei Ren * remove interval from fp8_recipe because it has been deprecated Signed-off-by: Xiaowei Ren * try to fix the failed CP test with the latest CI pipeline Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove redundant f before string Signed-off-by: Xiaowei Ren * change META_O_CP Signed-off-by: Xiaowei Ren --------- Signed-off-by: Xiaowei Ren Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xiaowei Ren --- .../fused_attn/run_fused_attn_with_cp.py | 147 ++-- .../fused_attn/test_fused_attn_with_cp.py | 12 +- transformer_engine/pytorch/attention.py | 696 ++++++++++++------ 3 files changed, 592 insertions(+), 263 deletions(-) diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py index 2433a8a09d..6c775fb127 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py @@ -2,15 +2,18 @@ # # See LICENSE for license information. -import os, sys +import os, sys, logging +from contextlib import nullcontext import torch import torch.distributed as dist from transformer_engine.pytorch.attention import DotProductAttention from transformer_engine.pytorch.attention import get_cu_seqlens_on_cp_rank import transformer_engine_torch as tex from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn +from transformer_engine.pytorch.fp8 import fp8_autocast +from transformer_engine.common.recipe import DelayedScaling -dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16} +dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} def run_dpa_with_cp( @@ -57,6 +60,9 @@ def run_dpa_with_cp( assert rank in cp_comm_ranks cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") + if dtype == "fp8": + fp8_recipe = DelayedScaling(fp8_dpa=True) + # instantiate core attn module core_attn = DotProductAttention( config.num_heads, @@ -171,18 +177,27 @@ def run_dpa_with_cp( # run core_attn without CP for x in [q, k, v]: x.requires_grad = True - out = core_attn( - q, - k, - v, - core_attention_bias_type=config.attn_bias_type, - core_attention_bias=bias, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1], - cu_seqlens_kv_padded=None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1], - ) - out.backward(dout) + + if dtype == "fp8": + fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group) + else: + fp8_context = nullcontext() + + with fp8_context: + out = core_attn( + q, + k, + v, + core_attention_bias_type=config.attn_bias_type, + core_attention_bias=bias, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1], + cu_seqlens_kv_padded=( + None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1] + ), + ) + out.backward(dout) # run core_attn wit CP q_, k_, v_, dout_, *rest = [ @@ -226,31 +241,34 @@ def run_dpa_with_cp( core_attn.set_context_parallel_group( cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), cp_comm_type ) - out_ = core_attn( - q_, - k_, - v_, - core_attention_bias_type=config.attn_bias_type, - core_attention_bias=bias_, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1], - cu_seqlens_kv_padded=None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1], - ) - out_.backward(dout_) + + if dtype == "fp8": + core_attn.reset_fp8_meta_tensors() + fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group) + else: + fp8_context = nullcontext() + + with fp8_context: + out_ = core_attn( + q_, + k_, + v_, + core_attention_bias_type=config.attn_bias_type, + core_attention_bias=bias_, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1], + cu_seqlens_kv_padded=( + None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1] + ), + ) + out_.backward(dout_) for x in [out_, q_.grad, k_.grad, v_.grad]: assert torch.all(~torch.isnan(x)) assert torch.all(~torch.isinf(x)) # compare results with and without CP - tols = dict(atol=5e-3, rtol=5e-3) - if dtype == "bf16": - if config.num_heads == config.num_gqa_groups: - tols = dict(atol=2.5e-2, rtol=2.5e-2) - else: - tols = dict(atol=3.5e-2, rtol=3.5e-2) - if qkv_format == "bshd" or qkv_format == "sbhd": dq, dk, dv, out = [ x.view( @@ -309,32 +327,55 @@ def run_dpa_with_cp( else: assert False, f"{qkv_format} is an unsupported qkv_format!" + if dtype == "bf16": + if config.num_heads == config.num_gqa_groups: + tols = dict(atol=2.5e-2, rtol=2.5e-2) + else: + tols = dict(atol=3.5e-2, rtol=3.5e-2) + elif dtype == "fp16": + tols = dict(atol=5e-3, rtol=5e-3) + elif dtype == "fp8": + tols = dict(atol=5e-1, rtol=5e-1) + rmse_tol = 0.1 + else: + assert False, f"{dtype} is an unsupported dtype!" + + def _rmse(a, b): + return torch.sqrt((a - b).square().mean()).item() + + def _error(a, b): + if dtype != "fp8": + torch.testing.assert_close(a, b, **tols) + else: + try: + torch.testing.assert_close(a, b, **tols) + except Exception as e: + logging.debug(e) + + rmse = _rmse(a, b) + rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item()) + assert ( + rmse < rmse_tol * rmse_range + ), "RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( + rmse, rmse_tol * rmse_range, rmse_tol, rmse_range + ) + if qkv_format == "bshd": - torch.testing.assert_close(out_[:, 0], out[:, 0], **tols) - torch.testing.assert_close(dq_[:, 0], dq[:, 0], **tols) - torch.testing.assert_close(dk_[:, 0], dk[:, 0], **tols) - torch.testing.assert_close(dv_[:, 0], dv[:, 0], **tols) - torch.testing.assert_close(out_[:, 1], out[:, 1], **tols) - torch.testing.assert_close(dq_[:, 1], dq[:, 1], **tols) - torch.testing.assert_close(dk_[:, 1], dk[:, 1], **tols) - torch.testing.assert_close(dv_[:, 1], dv[:, 1], **tols) + for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]): + _error(a[:, 0], b[:, 0]) + _error(a[:, 1], b[:, 1]) elif qkv_format == "sbhd": - torch.testing.assert_close(out_[0], out[0], **tols) - torch.testing.assert_close(dq_[0], dq[0], **tols) - torch.testing.assert_close(dk_[0], dk[0], **tols) - torch.testing.assert_close(dv_[0], dv[0], **tols) - torch.testing.assert_close(out_[1], out[1], **tols) - torch.testing.assert_close(dq_[1], dq[1], **tols) - torch.testing.assert_close(dk_[1], dk[1], **tols) - torch.testing.assert_close(dv_[1], dv[1], **tols) + for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]): + _error(a[0], b[0]) + _error(a[1], b[1]) elif qkv_format == "thd": - torch.testing.assert_close(out_, out, **tols) - torch.testing.assert_close(dq_, dq, **tols) - torch.testing.assert_close(dk_, dk, **tols) - torch.testing.assert_close(dv_, dv, **tols) + for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]): + _error(a, b) else: assert False, f"{qkv_format} is an unsupported qkv_format!" + dist.destroy_process_group() + def main(**kwargs): run_dpa_with_cp(**kwargs) diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index 0074d18cec..82875e2791 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -90,7 +90,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.") @pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") -@pytest.mark.parametrize("dtype", ["bf16", "fp16"]) +@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"]) @pytest.mark.parametrize("model", model_configs_fused_attn.keys()) @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) @pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"]) @@ -121,8 +121,16 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type): ) if config.window_size != (-1, 0) and config.window_size != (-1, -1): pytest.skip( - f"Fused attention does not support sliding window attention + context parallelism yet!" + "Fused attention does not support sliding window attention + context parallelism yet!" + ) + if cp_comm_type == "all_gather" and dtype == "fp8": + pytest.skip( + "CP implementation with KV all-gather does not support FP8 + context parallelism yet!" ) + if dtype == "fp8" and qkv_format == "thd": + pytest.skip("FP8 attention cannot work with THD format yet!") + if dtype == "fp8" and config.attn_bias_type != "no_bias": + pytest.skip("FP8 attention cannot work with bias yet!") subprocess.run( get_bash_arguments( diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 71bc15fdad..8fac4778c8 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -95,6 +95,9 @@ META_DO = tex.FP8BwdTensors.GRAD_INPUT2 META_S = tex.FP8FwdTensors.GEMM3_OUTPUT META_DP = tex.FP8BwdTensors.GRAD_INPUT3 +# repurpose some unused amax history buffers for partial results of CP fwd and bwd +META_O_CP = tex.FP8FwdTensors.GEMM2_OUTPUT +META_DQKV_CP = tex.FP8BwdTensors.GRAD_INPUT1 # NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) @@ -654,18 +657,6 @@ def get_attention_backend( logger.debug("Disabling FusedAttention as no backend supports the provided input") use_fused_attention = False fused_attention_backend = None - if ( - use_fused_attention - and context_parallel - and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"] - ): - logger.debug( - "Disabling FusedAttention as only sub-backend %s does not support " - "context parallellism", - int(fused_attention_backend), - ) - use_fused_attention = False - fused_attention_backend = None if ( use_fused_attention and window_size is not None @@ -1322,6 +1313,8 @@ def forward( attn_bias, deterministic, use_fused_attention, + fp8, + fp8_meta, ): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -1407,6 +1400,43 @@ def forward( # synchronize fwd results correction across steps fwd_results_correction_done = torch.cuda.Event() + if fp8: + if use_fused_attention: + fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + fused_attn_qkv_dtype = fp8_dtype_forward + fused_attn_backend = FusedAttnBackend["FP8"] + if fp8_meta["recipe"].fp8_mha: + assert ( + isinstance(q, Float8Tensor) + and isinstance(k, Float8Tensor) + and isinstance(v, Float8Tensor) + ), "q/k/v must be Float8Tensors for FP8 MHA!" + fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv + q_fp8, k_fp8, v_fp8 = q, k, v + q, k, v = q_fp8._data, k_fp8._data, v_fp8._data + else: + q_f16, k_f16, v_f16 = q, k, v + q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) + if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + k, v = [ + cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) + for x in [k_f16, v_f16] + ] + fp8_meta_kwargs = {} + fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv[META_QKV] + fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv[META_S] + fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale[META_S] + fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale[META_O_CP] + amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + else: + assert False, "FP8 is only supported with Fused Attention!" + else: + q_f16 = q + if use_fused_attention: + fp8_meta_kwargs = {} + fused_attn_qkv_dtype = TE_DType[q.dtype] + fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + p2p_comm_buffers = [None for _ in range(cp_size)] if use_fused_attention and qkv_format in ["bshd", "sbhd"]: p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3) @@ -1433,7 +1463,23 @@ def forward( batch_p2p_comm, ) - kv_inputs[i % 2] = p2p_comm_buffers[i] + if ( + not fp8 + or fp8_meta["recipe"].fp8_mha + or int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + ): + kv_inputs[i % 2] = p2p_comm_buffers[i] + else: + # KV exchange is in BF16/FP16, cast received KV in each step + kv_inputs[i % 2] = cast_to_fp8( + p2p_comm_buffers[i], + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + ) + if fp8 and use_fused_attention: + fp8_meta_kwargs["amax_s"] = amax_per_step[0][i] + fp8_meta_kwargs["amax_o"] = amax_per_step[1][i] if causal: if i == 0: if pad_between_seqs_q: @@ -1474,38 +1520,40 @@ def forward( ), dim=-1, ).contiguous() - out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = ( - fused_attn_fwd( - is_training, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), - TE_DType[q.dtype], - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type=attn_mask_type, - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - ) + out_per_step[i], aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + q_inputs[i % 2], + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + fused_attn_qkv_dtype, + fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=attn_mask_type, + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs[i % 2], + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + **fp8_meta_kwargs, ) - if len(rest) > 0: - attn_biases[i] = rest[0] + if fp8: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors + attn_biases[i] = rest[0] if len(rest) > 0 else None else: # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) @@ -1572,42 +1620,44 @@ def forward( if attn_bias is not None: idx = (rank - i) % cp_size attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous() - out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = ( - fused_attn_fwd( - is_training, - max_seqlen_q, - max_seqlen_kv // 2, - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), - TE_DType[q.dtype], - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type="padding" if padding else "no_mask", - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=( - None - if cu_seqlens_kv_padded is None - else cu_seqlens_kv_padded // 2 - ), - ) + out_per_step[i], aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_kv // 2, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + q_inputs[i % 2], + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + fused_attn_qkv_dtype, + fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type="padding" if padding else "no_mask", + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs[i % 2], + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=( + None + if cu_seqlens_kv_padded is None + else cu_seqlens_kv_padded // 2 + ), + **fp8_meta_kwargs, ) - if len(rest) > 0: - attn_biases[i] = rest[0] + if fp8: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors + attn_biases[i] = rest[0] if len(rest) > 0 else None else: # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) @@ -1693,42 +1743,44 @@ def forward( ), dim=-1, ).contiguous() - out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = ( - fused_attn_fwd( - is_training, - max_seqlen_q // 2, - max_seqlen_kv, - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), - TE_DType[q.dtype], - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type="padding" if padding else "no_mask", - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=( - None - if cu_seqlens_q_padded is None - else cu_seqlens_q_padded // 2 - ), - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - ) + out_per_step[i], aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q // 2, + max_seqlen_kv, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + q_inputs[i % 2], + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + fused_attn_qkv_dtype, + fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type="padding" if padding else "no_mask", + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs[i % 2], + cu_seqlens_q_padded=( + None + if cu_seqlens_q_padded is None + else cu_seqlens_q_padded // 2 + ), + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + **fp8_meta_kwargs, ) - if len(rest) > 0: - attn_biases[i] = rest[0] + if fp8: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors + attn_biases[i] = rest[0] if len(rest) > 0 else None else: if qkv_format == "thd": # [t, np, hn] -> [t/2, np, hn] @@ -1795,38 +1847,40 @@ def forward( ), dim=-1, ).contiguous() - out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = ( - fused_attn_fwd( - is_training, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - q, - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), - TE_DType[q.dtype], - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type=attn_mask_type, - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - ) + out_per_step[i], aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + q, + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + fused_attn_qkv_dtype, + fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=attn_mask_type, + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs[i % 2], + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + **fp8_meta_kwargs, ) - if len(rest) > 0: - attn_biases[i] = rest[0] + if fp8: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors + attn_biases[i] = rest[0] if len(rest) > 0 else None else: # [b, sq, np, hn] -> [b*sq, np, hn] q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) @@ -1866,8 +1920,16 @@ def forward( softmax_lse_per_step[i - 1].squeeze_(-1) with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): + if fp8: + out_per_step[i - 1] = cast_from_fp8( + out_per_step[i - 1], + fp8_meta["scaling_fwd"], + META_O_CP, + fp8_dtype_forward, + TE_DType[torch.float32], + ) if i == 1: - out = torch.zeros_like(q) + out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape) softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double) if causal and qkv_format != "thd": # [b, np, sq] -> [b, np, 2, sq//2] @@ -1951,13 +2013,55 @@ def forward( else: out = out.view(-1, *out.shape[-2:]) + if fp8 and use_fused_attention: + amax_cp_fwd = amax_per_step.amax(dim=1) + fp8_meta["scaling_fwd"].amax_history[0][META_S] = amax_cp_fwd[0] + fp8_meta["scaling_fwd"].amax_history[0][META_O_CP] = amax_cp_fwd[1] + + out_f16 = out.to(q_fp8.dtype if fp8 and fp8_meta["recipe"].fp8_mha else q_f16.dtype) + if fp8 and (fp8_meta["recipe"].fp8_mha or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))): + out_fp8 = cast_to_fp8(out_f16, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward) + + if fp8 and fp8_meta["recipe"].fp8_mha: + out_ret = Float8Tensor( + data=out_fp8, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=META_O, + fp8_dtype=fp8_dtype_forward, + dtype=q_fp8.dtype, + ) + else: + out_ret = out_f16 + + if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + q_save, kv_save, out_save = q, kv, out_fp8 + fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone() + fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone() + elif fp8 and fp8_meta["recipe"].fp8_mha: + kv_fp8 = Float8Tensor( + data=kv, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=META_QKV, + fp8_dtype=fp8_dtype_forward, + dtype=k_fp8.dtype, + ) + q_save, kv_save, out_save = q_fp8, kv_fp8, out_f16 + fp8_fwd_scales, fp8_fwd_scale_invs = None, None + else: + q_save, kv_save, out_save = q_f16, kv, out_f16 + fp8_fwd_scales, fp8_fwd_scale_invs = None, None + ctx.save_for_backward( - q, - kv, - out, + q_save, + kv_save, + out_save, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, + fp8_fwd_scales, + fp8_fwd_scale_invs, *cu_seqlens_q_per_step, *cu_seqlens_kv_per_step, *rng_states, @@ -1976,7 +2080,9 @@ def forward( ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention - return out + ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + ctx.fp8_meta = fp8_meta + return out_ret @staticmethod def backward(ctx, dout): @@ -1987,10 +2093,11 @@ def backward(ctx, dout): batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) (q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[:6] - cu_seqlens_q_per_step = ctx.saved_tensors[6 : 6 + cp_size] - cu_seqlens_kv_per_step = ctx.saved_tensors[6 + cp_size : 6 + cp_size * 2] - rng_states = ctx.saved_tensors[6 + cp_size * 2 : 6 + cp_size * 3] - attn_biases = ctx.saved_tensors[6 + cp_size * 3 : 6 + cp_size * 4] + (fp8_fwd_scales, fp8_fwd_scale_invs) = ctx.saved_tensors[6:8] + cu_seqlens_q_per_step = ctx.saved_tensors[8 : 8 + cp_size] + cu_seqlens_kv_per_step = ctx.saved_tensors[8 + cp_size : 8 + cp_size * 2] + rng_states = ctx.saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3] + attn_biases = ctx.saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4] causal = "causal" in ctx.attn_mask_type padding = "padding" in ctx.attn_mask_type @@ -2025,22 +2132,60 @@ def backward(ctx, dout): if ctx.use_fused_attention: # [b, np, sq//2] -> [b, np, sq//2, 1] softmax_lse_.unsqueeze_(-1) - if ctx.use_fused_attention: # [b, np, sq] -> [b, np, sq, 1] softmax_lse.unsqueeze_(-1) + + if ctx.fp8: + if ctx.use_fused_attention: + fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) + fused_attn_qkv_dtype = fp8_dtype_backward + fused_attn_dqkv_dtype = fp8_dtype_backward + fused_attn_backend = FusedAttnBackend["FP8"] + dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device) + dkv_fp8 = torch.empty((cp_size, *kv.shape), dtype=kv.dtype, device=kv.device) + dkv_fp8_ = torch.empty_like(dkv_fp8) + dout_dtype = dout.dtype + if ctx.fp8_meta["recipe"].fp8_mha: + assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" + ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv + dout = dout._data + else: + dout = cast_to_fp8( + dout, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward + ) + p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]] + fp8_meta_kwargs = {} + fp8_meta_kwargs["d_scale_qkv"] = fp8_fwd_scale_invs[META_QKV] + fp8_meta_kwargs["d_scale_s"] = fp8_fwd_scale_invs[META_S] + fp8_meta_kwargs["d_scale_o"] = fp8_fwd_scale_invs[META_O] + fp8_meta_kwargs["d_scale_do"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] + fp8_meta_kwargs["d_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP] + fp8_meta_kwargs["q_scale_s"] = fp8_fwd_scales[META_S] + fp8_meta_kwargs["q_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale[META_DP] + fp8_meta_kwargs["q_scale_dqkv"] = ctx.fp8_meta["scaling_bwd"].scale[META_DQKV_CP] + amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + else: + assert False, "FP8 is only supported with Fused Attention!" + else: + if ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha: + q, kv, dout = [x.from_float8(x.dtype) for x in [q, kv, dout]] + dq = torch.empty_like(q) + if ctx.qkv_format == "thd" and causal: + dq[cu_seqlens_q_padded[-1] :].fill_(0) + p2p_comm_buffers = [ + torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), + torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), + ] + p2p_comm_buffers[0][0].copy_(kv) + if ctx.use_fused_attention: + fp8_meta_kwargs = {} + fused_attn_qkv_dtype = TE_DType[q.dtype] + fused_attn_dqkv_dtype = TE_DType[q.dtype] + fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + out = out.view(*q.shape) dout = dout.view(*q.shape) - # Flash Attn outputs - dq = torch.empty_like(q) - if ctx.qkv_format == "thd" and causal: - dq[cu_seqlens_q_padded[-1] :].fill_(0) - - p2p_comm_buffers = [ - torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), - torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), - ] - p2p_comm_buffers[0][0].copy_(kv) send_recv_reqs = [] fa_optional_backward_kwargs = {} @@ -2056,18 +2201,40 @@ def backward(ctx, dout): send_tensor = p2p_comm_buffers[i % 2] recv_tensor = p2p_comm_buffers[(i + 1) % 2] - if i == 0: - send_tensor = send_tensor[0] - recv_tensor = recv_tensor[0] - if i == (cp_size - 1): - send_tensor = send_tensor[1] - recv_tensor = recv_tensor[1] - - send_recv_reqs = flash_attn_p2p_communicate( - rank, send_tensor, send_dst, recv_tensor, recv_src, ctx.cp_group, batch_p2p_comm - ) + if ctx.fp8: + if i < cp_size - 1: + send_recv_reqs = flash_attn_p2p_communicate( + rank, + send_tensor[0], + send_dst, + recv_tensor[0], + recv_src, + ctx.cp_group, + batch_p2p_comm, + ) + else: + dkv_a2a_req = torch.distributed.all_to_all_single( + dkv_fp8, + dkv_fp8_, + group=ctx.cp_group, + async_op=True, + ) + send_recv_reqs = [dkv_a2a_req] + else: + if i == 0: + send_tensor = send_tensor[0] + recv_tensor = recv_tensor[0] + if i == (cp_size - 1): + send_tensor = send_tensor[1] + recv_tensor = recv_tensor[1] + send_recv_reqs = flash_attn_p2p_communicate( + rank, send_tensor, send_dst, recv_tensor, recv_src, ctx.cp_group, batch_p2p_comm + ) kv = p2p_comm_buffers[i % 2][0] + if ctx.fp8 and ctx.use_fused_attention: + fp8_meta_kwargs["amax_dp"] = amax_per_step[0][i] + fp8_meta_kwargs["amax_dqkv"] = amax_per_step[0][i] # In reversed order of fwd if causal: if i == (cp_size - 1): @@ -2090,7 +2257,14 @@ def backward(ctx, dout): dout_ = dout.view(-1, *dout.shape[-3:]) elif ctx.qkv_format == "thd": q_, kv_, out_, dout_ = q, kv, out, dout - aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] + if ctx.fp8: + aux_ctx_tensors = [ + softmax_lse, + softmax_lse, + rng_states[cp_size - i - 1], + ] + else: + aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( @@ -2103,10 +2277,10 @@ def backward(ctx, dout): kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], out_, dout_, - TE_DType[q.dtype], - TE_DType[kv.dtype], + fused_attn_qkv_dtype, + fused_attn_dqkv_dtype, aux_ctx_tensors, - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, attn_scale=ctx.softmax_scale, @@ -2114,6 +2288,8 @@ def backward(ctx, dout): qkv_layout=qkv_layout, attn_mask_type=ctx.attn_mask_type, attn_bias_type=ctx.attn_bias_type, + deterministic=ctx.deterministic, + **fp8_meta_kwargs, ) else: # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] @@ -2169,7 +2345,14 @@ def backward(ctx, dout): q_, out_, dout_ = q, out, dout # [2, t, np, hn] -> [2, t/2, np, hn] kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0) - aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] + if ctx.fp8: + aux_ctx_tensors = [ + softmax_lse, + softmax_lse, + rng_states[cp_size - i - 1], + ] + else: + aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( @@ -2182,10 +2365,10 @@ def backward(ctx, dout): kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], out_, dout_, - TE_DType[q.dtype], - TE_DType[kv.dtype], + fused_attn_qkv_dtype, + fused_attn_dqkv_dtype, aux_ctx_tensors, - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=( None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2 @@ -2195,6 +2378,8 @@ def backward(ctx, dout): qkv_layout=qkv_layout, attn_mask_type="padding" if padding else "no_mask", attn_bias_type=ctx.attn_bias_type, + deterministic=ctx.deterministic, + **fp8_meta_kwargs, ) else: # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] @@ -2256,7 +2441,14 @@ def backward(ctx, dout): out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1) dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1) kv_ = kv - aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]] + if ctx.fp8: + aux_ctx_tensors = [ + softmax_lse_, + softmax_lse_, + rng_states[cp_size - i - 1], + ] + else: + aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( @@ -2269,10 +2461,10 @@ def backward(ctx, dout): kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], out_, dout_, - TE_DType[q.dtype], - TE_DType[kv.dtype], + fused_attn_qkv_dtype, + fused_attn_dqkv_dtype, aux_ctx_tensors, - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + fused_attn_backend, cu_seqlens_q_padded=( None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2 ), @@ -2282,6 +2474,8 @@ def backward(ctx, dout): qkv_layout=qkv_layout, attn_mask_type="padding" if padding else "no_mask", attn_bias_type=ctx.attn_bias_type, + deterministic=ctx.deterministic, + **fp8_meta_kwargs, ) else: if ctx.qkv_format == "thd": @@ -2325,7 +2519,10 @@ def backward(ctx, dout): ) else: if ctx.use_fused_attention: - aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] + if ctx.fp8: + aux_ctx_tensors = [softmax_lse, softmax_lse, rng_states[cp_size - i - 1]] + else: + aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( @@ -2338,10 +2535,10 @@ def backward(ctx, dout): kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1], out, dout, - TE_DType[q.dtype], - TE_DType[kv.dtype], + fused_attn_qkv_dtype, + fused_attn_dqkv_dtype, aux_ctx_tensors, - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, attn_scale=ctx.softmax_scale, @@ -2349,6 +2546,8 @@ def backward(ctx, dout): qkv_layout=qkv_layout, attn_mask_type=ctx.attn_mask_type, attn_bias_type=ctx.attn_bias_type, + deterministic=ctx.deterministic, + **fp8_meta_kwargs, ) else: # [b, sq, np, hn] -> [b*sq, np, hn] @@ -2383,6 +2582,8 @@ def backward(ctx, dout): **fa_optional_backward_kwargs, ) + if ctx.fp8: + dq = dq_fp8[(rank + i + 1) % cp_size] if i >= (cp_size - rank - 1) or not causal: # [b*sq, np, hn] -> [b, 2, sq//2, np, hn] if causal # [b*sq, np, hn] -> [b, sq, np, hn] if not causal @@ -2395,7 +2596,17 @@ def backward(ctx, dout): # [b*sq//2, np, hn] -> [sq//2, b, np, hn] dq_ = dq_.view(-1, *dq.shape[-3:]) - if causal: + if ctx.fp8: + if i >= (cp_size - rank - 1) or not causal: + dq.copy_(dq_) + else: + if ctx.qkv_format == "bshd": + dq[:, 0, ...].fill_(0) + dq[:, 1, ...].copy_(dq_) + elif ctx.qkv_format == "sbhd": + dq[0].fill_(0) + dq[1].copy_(dq_) + elif causal: if i > (cp_size - rank - 1): dq.add_(dq_) elif i == (cp_size - rank - 1): @@ -2450,7 +2661,13 @@ def backward(ctx, dout): for req in send_recv_reqs: req.wait() - dkv = p2p_comm_buffers[(i + 1) % 2][1] + if ctx.fp8: + if i < cp_size - 1: + dkv = dkv_fp8_[(rank + i + 1) % cp_size] + else: + dkv = dkv_fp8[(rank + i + 1) % cp_size] + else: + dkv = p2p_comm_buffers[(i + 1) % 2][1] if ctx.use_fused_attention: dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0) if ctx.qkv_format in ["bshd", "sbhd"]: @@ -2469,7 +2686,17 @@ def backward(ctx, dout): # [2, b*sk, np, hn] -> [2, b, sk, np, hn] if not causal dkv_ = dkv_.view(*dkv.shape) - if causal: + if ctx.fp8: + if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): + if ctx.qkv_format == "bshd": + dkv[:, :, 0, ...].copy_(dkv_) + dkv[:, :, 1, ...].fill_(0) + elif ctx.qkv_format == "sbhd": + dkv[:, 0, ...].copy_(dkv_) + dkv[:, 1, ...].fill_(0) + else: + dkv.copy_(dkv_) + elif causal: if i == (cp_size - 1): if rank == 0: if ctx.qkv_format == "bshd": @@ -2507,6 +2734,26 @@ def backward(ctx, dout): else: dkv.add_(dkv_) + if ctx.fp8 and ctx.use_fused_attention: + amax_cp_bwd = amax_per_step.amax(dim=1) + ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP] = amax_cp_bwd[0] + ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV_CP] = amax_cp_bwd[1] + if ctx.qkv_format in ["bshd", "sbhd"]: + # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or + # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn] + dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:]) + dq, dkv = [ + cast_from_fp8( + x, + ctx.fp8_meta["scaling_bwd"], + META_DQKV_CP, + fp8_dtype_backward, + TE_DType[torch.float32], + ) + for x in [dq_fp8, dkv_fp8] + ] + dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]] + if causal: if ctx.qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] @@ -2527,6 +2774,25 @@ def backward(ctx, dout): dkv_[:, cu_seqlens_kv_padded[-1] :].fill_(0) dkv = dkv_ + if ctx.fp8 and ctx.fp8_meta["recipe"].fp8_mha: + dq, dkv = [ + cast_to_fp8(x, ctx.fp8_meta["scaling_bwd"], META_DQKV, fp8_dtype_backward) + for x in [dq, dkv] + ] + dq, dk, dv = [ + Float8Tensor( + data=x, + fp8_meta=ctx.fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=META_DQKV, + fp8_dtype=fp8_dtype_backward, + dtype=dout_dtype, + ) + for x in [dq, dkv[0], dkv[1]] + ] + else: + dk, dv = dkv[0], dkv[1] + if attn_dbias is not None: # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk] attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1) @@ -2534,8 +2800,8 @@ def backward(ctx, dout): return ( None, dq, - dkv[0], - dkv[1], + dk, + dv, None, None, None, @@ -2553,12 +2819,14 @@ def backward(ctx, dout): attn_dbias, None, None, + None, + None, ) -@jit_fuser +@torch.compile def get_seq_chunk_ids_to_all_gathered_kv( - local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size_left + local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size_left, device ): """Compute sequence chunk ids to the all-gathered KV.""" seq_end_idx = (local_chunk_id + 1) * max_seqlen_kv @@ -2569,7 +2837,7 @@ def get_seq_chunk_ids_to_all_gathered_kv( local_chunk_id - num_chunks + 1, local_chunk_id + 1, dtype=torch.int32, - device="cuda", + device=device, ) chunk_ids_to_all_gathered_kv = torch.where( chunk_ids < cp_size, 2 * chunk_ids, 2 * (2 * cp_size - chunk_ids) - 1 @@ -2683,6 +2951,7 @@ def forward( if (window_size is None or window_size[0] == -1) else window_size[0] ), + k.device, ) chunk_ids_to_kv_ag_per_step[i] = chunk_ids_to_kv_ag num_kv_chunks = chunk_ids_to_kv_ag.numel() @@ -3029,6 +3298,8 @@ def attn_forward_func_with_cp( deterministic=False, use_fused_attention=False, window_size=None, + fp8=False, + fp8_meta=None, ) -> torch.Tensor: """ Attention implementation with context parallelism. @@ -3109,6 +3380,8 @@ def attn_forward_func_with_cp( attn_bias, deterministic, use_fused_attention, + fp8, + fp8_meta, ) else: raise ValueError(f"Unsupported communication type: {cp_comm_type}!") @@ -5638,9 +5911,21 @@ def forward( and (fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen) ) + if fp8: + assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, ( + f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}" + " is required for FP8 attention!" + ) + assert fp8_meta is not None, "FP8 metadata fp8_meta is required for FP8 attention!" + assert not context_parallel or fp8_meta["recipe"].reduce_amax, ( + "Amax reduction across TP+CP group is necessary when using context parallelism with" + " FP8!" + ) + if context_parallel: assert ( - fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen + fp8 + or fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen ), f"{fused_attention_backend} does not work with context parallelism!" assert core_attention_bias_type not in [ "alibi" @@ -5670,19 +5955,14 @@ def forward( attn_mask_type=attn_mask_type, attn_bias_type=core_attention_bias_type, attn_bias=core_attention_bias, + deterministic=self.deterministic, use_fused_attention=True, window_size=window_size, + fp8=fp8, + fp8_meta=fp8_meta, ) else: with self.attention_dropout_ctx(): - if fp8: - assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, ( - f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}" - " is required for FP8 attention!" - ) - assert ( - fp8_meta is not None - ), "FP8 metadata fp8_meta is required for FP8 attention!" output = FusedAttnFunc.apply( self.training, max_seqlen_q, From bcf38d9eb424c682857b6154cdadd929eff9b2fe Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 20 Aug 2024 20:56:19 -0700 Subject: [PATCH 67/72] [PyTorch] Add support for padding mask in `UnfusedDotProductAttention` (#1073) * add support for padding in UnfusedDPA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add support for padding_causal/_bottom_right Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix padding_causal/_bottom_right Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * need to test max512 backend Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix mask logic in unfused Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use actual_seqlen for alibi/causal_bottom_right padding Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fixes and convert causal to causal_bottom_right for inference Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use causal in kv cache inference test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * simplify get_alibi logic Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * simplify the non-padding path for get_alibi Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * avoid batch_size loop in generating padding_causal/_bottom_right masks Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/test_numerics.py | 6 +- .../common/fused_attn/fused_attn.cpp | 5 +- transformer_engine/pytorch/attention.py | 172 +++++++++++++----- transformer_engine/pytorch/softmax.py | 39 ++-- transformer_engine/pytorch/transformer.py | 2 +- 5 files changed, 155 insertions(+), 69 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index a2023f539a..85cd4fc256 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1655,8 +1655,8 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, ffn_hidden_size=4 * D, num_attention_heads=H, attn_input_format=input_format, - self_attn_mask_type="causal_bottom_right", - enc_dec_attn_mask_type="causal_bottom_right", + self_attn_mask_type="causal", + enc_dec_attn_mask_type="causal", layer_number=layer_number, attention_dropout=0.0, params_dtype=dtype, @@ -1670,7 +1670,7 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, qkv_format=input_format, layer_number=layer_number, attention_dropout=0.0, - attn_mask_type="causal_bottom_right", + attn_mask_type="causal", params_dtype=dtype, ) .cuda() diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 0fe62f8cb4..70f1fa409f 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -142,7 +142,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS || (bias_type == NVTE_Bias_Type::NVTE_ALIBI && attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK && - attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && sm_arch_ >= 90) || + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && + sm_arch_ >= 90) || (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 90))) || ((cudnn_runtime_version >= 90000) && (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 80))) && diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 8fac4778c8..6a46d6c3c1 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -472,19 +472,25 @@ def get_attention_backend( use_fused_attention = False # Filter: Attention mask - # attn_mask_type | supported backends - # ------------------------------------------------------------------- - # no_mask | All - # padding | FlashAttention, FusedAttention - # causal | - # self-attention | All - # cross-attention | FusedAttention - # padding_causal | - # self-attention | FlashAttention, FusedAttention - # cross-attention | FusedAttention - # causal_bottom_right | All - # padding_causal_bottom_right | FlashAttention, FusedAttention - # arbitrary | UnfusedDotProductAttention + # attn_mask_type | attention_mask | supported backends + # ---------------------------------------------------------------------------------------- + # no_mask | None | All + # padding | | All + # self-attention | One tensor in shape [b, 1, 1, sq] | + # cross-attention | Tuple of two tensors in shapes | + # | [b, 1, 1, sq] and [b, 1, 1, skv] | + # causal | None | + # self-attention | | All + # cross-attention | | FusedAttention, UnfusedDotProductAttention + # padding_causal | Same as "padding" | + # self-attention | | All + # cross-attention | | FusedAttention, UnfusedDotProductAttention + # causal_bottom_right | None | All + # padding_causal_bottom_right | Same as "padding" | + # self-attention | | All + # cross-attention | | FlashAttention, UnfusedDotProductAttention + # arbitrary | One tensor in shape broadcastable to | UnfusedDotProductAttention + # | [b, h, sq, skv] | if attn_mask_type == "arbitrary": if use_flash_attention: logger.debug("Disabling FlashAttention for arbitrary mask") @@ -492,9 +498,6 @@ def get_attention_backend( if use_fused_attention: logger.debug("Disabling FusedAttention for arbitrary mask") use_fused_attention = False - if use_unfused_attention and "padding" in attn_mask_type: - logger.debug("Disabling UnfusedDotProductAttention for %s mask", attn_mask_type) - use_unfused_attention = False if ( use_flash_attention and _flash_attn_2_1_plus @@ -780,7 +783,7 @@ def get_attention_backend( class InferenceParams: # pylint: disable=too-few-public-methods """ Inference parameters that are passed to the main model in order - to efficienly calculate and store the context during inference. + to efficiently calculate and store the context during inference. Parameters ---------- @@ -886,6 +889,8 @@ def get_alibi( num_heads: int, max_seqlen_q: int, max_seqlen_kv: int, + actual_seqlens_q: Optional[torch.Tensor] = None, + actual_seqlens_kv: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, bias_dtype: Optional[torch.dtype] = None, bottom_right_alignment: bool = True, @@ -899,6 +904,10 @@ def get_alibi( Maximum sequence length for queries. max_seqlen_kv: int Maximum sequence length for keys and values. + actual_seqlens_q: Optional[torch.Tensor], default = `None` + Actual sequence lengths for queries, in shape [batch_size]. + actual_seqlens_kv: Optional[torch.Tensor], default = `None` + Actual sequence lengths for keys and values, in shape [batch_size]. alibi_slopes: Optional[torch.Tensor], default = `None` Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads]. bias_dtype: Optional[torch.dtype], default = `None` @@ -912,10 +921,12 @@ def get_alibi( alibi_slopes: torch.Tensor ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads]. alibi_bias: torch.Tensor - ALiBi bias in FP32 or `bias_dtype`. If `alibi_slopes` is in [num_heads] shape, - then `alibi_bias` is in [1, num_heads, max_seqlen_q, max_seqlen_kv], and if - `alibi_slopes` is in [batch_size, num_heads], then the bias is in - [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. + ALiBi bias in FP32 or `bias_dtype`. Its shape is + (1) [1, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in [num_heads] shape, + and `actual_seqlens_q` and `actual_seqlens_kv` are `None`; or + (2) [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in + [batch_size, num_heads] shape, or, if `alibi_slopes` is in [num_heads] shape and + `actual_seqlens_q` and `actual_seqlens_kv` are not `None`. """ global _alibi_cache if _alibi_cache["_alibi_slopes_require_update"]: @@ -941,17 +952,23 @@ def get_alibi( slopes_shape = torch.Size([1, _alibi_cache["_alibi_slopes"].shape[0], 1, 1]) if _alibi_cache["_alibi_slopes"].dim() == 2: slopes_shape = torch.Size([*_alibi_cache["_alibi_slopes"].shape[:], 1, 1]) - if bottom_right_alignment: - bias = torch.arange(1 - max_seqlen_kv, 1, dtype=torch.int32, device="cuda").view( - 1, 1, 1, max_seqlen_kv - ) - else: - bias = torch.arange( - 1 - max_seqlen_q, max_seqlen_kv - max_seqlen_q + 1, dtype=torch.int32, device="cuda" - ).view(1, 1, 1, max_seqlen_kv) - bias = bias - torch.arange(1 - max_seqlen_q, 1, dtype=torch.int32, device="cuda").view( + bias = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( 1, 1, max_seqlen_q, 1 + ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view( + 1, 1, 1, max_seqlen_kv ) + if actual_seqlens_q is None and actual_seqlens_kv is None: + if bottom_right_alignment: + bias = bias + max_seqlen_kv - max_seqlen_q + elif actual_seqlens_q is not None and actual_seqlens_kv is not None: + batch_size = actual_seqlens_q.shape[0] + bias = bias.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + if bottom_right_alignment: + bias = bias + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1) + else: + assert ( + False + ), "actual_seqlens_q and actual_seqlens_kv need to be both None or torch.Tensors!" bias = bias.abs().mul(-1) bias = bias * _alibi_cache["_alibi_slopes"].view(slopes_shape) _alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv @@ -3705,6 +3722,7 @@ class UnfusedDotProductAttention(torch.nn.Module): def __init__( self, softmax_scale: float, + attention_type: str = "self", attention_dropout: float = 0.0, attention_dropout_ctx: Optional[Callable] = nullcontext, layer_number: Optional[int] = None, @@ -3712,6 +3730,7 @@ def __init__( super().__init__() self.softmax_scale = softmax_scale + self.attention_type = attention_type self.attention_dropout_ctx = attention_dropout_ctx self.layer_number = layer_number @@ -3751,6 +3770,58 @@ def forward( query_layer, key_layer, value_layer = [ x.transpose(0, 1) for x in [query_layer, key_layer, value_layer] ] + batch_size, max_seqlen_q, max_seqlen_kv = ( + query_layer.shape[1], + query_layer.shape[0], + key_layer.shape[0], + ) + if "padding" in attn_mask_type: + if self.attention_type == "self": + assert attention_mask.shape == ( + batch_size, + 1, + 1, + max_seqlen_q, + ), "attention_mask should be a single tensor with [b, 1, 1, sq] shape!" + attention_mask = torch.logical_or( + attention_mask.squeeze(1).unsqueeze(3), attention_mask + ) + else: + assert ( + len(attention_mask) == 2 + and attention_mask[0].shape == (batch_size, 1, 1, max_seqlen_q) + and attention_mask[1].shape == (batch_size, 1, 1, max_seqlen_kv) + ), ( + "attention_mask should be a tuple of two tensors with shapes " + "[b, 1, 1, sq] and [b, 1, 1, skv]!" + ) + attention_mask = torch.logical_or( + attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1] + ) + mask = attention_mask.squeeze(1).logical_not() + actual_seqlens_q = mask[:, :, 0].sum(dim=1) + actual_seqlens_kv = mask[:, 0, :].sum(dim=1) + mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( + 1, 1, max_seqlen_q, 1 + ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view( + 1, 1, 1, max_seqlen_kv + ) + if attn_mask_type == "padding_causal": + attention_mask = torch.logical_or( + torch.where(mask.view(1, 1, max_seqlen_q, max_seqlen_kv) < 0, 1, 0), + attention_mask, + ) + if attn_mask_type == "padding_causal_bottom_right": + attention_mask = torch.logical_or( + torch.where( + mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1) + < 0, + 1, + 0, + ), + attention_mask, + ) batch_size, seqlen = query_layer.shape[1], query_layer.shape[0] apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16 @@ -3805,7 +3876,7 @@ def forward( key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] beta=0.0, alpha=scale, - ) + ).view(*output_size) elif core_attention_bias_type == "pre_scale_bias": assert core_attention_bias is not None, "core_attention_bias should not be None!" @@ -3813,10 +3884,7 @@ def forward( query_layer.transpose(0, 1), # [b * np, sq, hn] key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] ) - matmul_result = ( - matmul_result.view(output_size[0], output_size[1], output_size[2], output_size[3]) - + core_attention_bias - ).view(-1, output_size[2], output_size[3]) + matmul_result = matmul_result.view(*output_size) + core_attention_bias matmul_result *= scale elif core_attention_bias_type in ["post_scale_bias", "alibi"]: @@ -3827,6 +3895,8 @@ def forward( output_size[1], output_size[2], output_size[3], + actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None, + actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None, alibi_slopes=alibi_slopes, bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"], ) @@ -3837,26 +3907,21 @@ def forward( beta=0.0, alpha=scale, ) - matmul_result = ( - ( - matmul_result.view( - output_size[0], output_size[1], output_size[2], output_size[3] - ) - + core_attention_bias - ) - .view(-1, output_size[2], output_size[3]) - .to(dtype=query_layer.dtype) + matmul_result = (matmul_result.view(*output_size) + core_attention_bias).to( + dtype=query_layer.dtype ) - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - # attention scores and attention mask [b, np, sq, sk] softmax_scale = self.layer_number if apply_qk_layer_scaling else None attention_probs = self.scale_mask_softmax( - attention_scores, attention_mask, attn_mask_type, softmax_scale + matmul_result, attention_mask, attn_mask_type, softmax_scale ) + # mask out the pad positions in softmax results, mostly for the rows (pad tokens from q) + # the columns (pad tokens from k) are already zeroed out during softmax + if "padding" in attn_mask_type: + attention_probs = attention_probs.masked_fill(attention_mask, 0) + # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. with self.attention_dropout_ctx(): @@ -6232,7 +6297,10 @@ def __init__( ) self.unfused_attention = UnfusedDotProductAttention( - softmax_scale, **attn_kwargs, layer_number=layer_number + softmax_scale, + attention_type=attention_type, + **attn_kwargs, + layer_number=layer_number, ) def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument @@ -6522,6 +6590,11 @@ def forward( if inference_params is not None: assert self.layer_number is not None, "Layer number must be set!" + # convert causal to causal_bottom_right in inference when KV-caching is in use + # so users can run with the same attn_mask_type for training and inference + if attn_mask_type in ["causal", "padding_causal"]: + attn_mask_type = attn_mask_type + "_bottom_right" + if qkv_format == "bshd": key_layer = key_layer.transpose(0, 1) value_layer = value_layer.transpose(0, 1) @@ -6628,7 +6701,6 @@ def forward( attention_mask is not None ), "Please provide attention_mask for padding!" if self.attention_type == "self": - assert max_seqlen_q == max_seqlen_kv cu_seqlens_q = get_cu_seqlens(attention_mask) cu_seqlens_kv = cu_seqlens_q else: diff --git a/transformer_engine/pytorch/softmax.py b/transformer_engine/pytorch/softmax.py index 3632d2f367..4fb8a28857 100644 --- a/transformer_engine/pytorch/softmax.py +++ b/transformer_engine/pytorch/softmax.py @@ -329,25 +329,22 @@ def is_kernel_available(self, mask: torch.Tensor, b: int, np: int, sq: int, sk: return False # sk must be 16 ~ 16384 if sk % 8 != 0: return False # sk must be divisor of 8 - if self.attn_mask_type == "arbitrary": - return False # Custom masks not supported - + if sq == 1: + return False # sq must be > 1 if self.attn_mask_type == "causal" and sq != sk: return False # Fused causal kernel only support causal_bottom_right if ( sq % 4 == 0 # sq must be divisor of 4 and attn_batches % 4 == 0 # np * b must be divisor of 4 - and self.attn_mask_type != "arbitrary" # Custom masks not supported ): batch_per_block = self.get_batch_per_block(int(sk)) - - if self.attn_mask_type == "padding": + if "padding" in self.attn_mask_type or self.attn_mask_type == "arbitrary": if ( mask is not None and sq % batch_per_block == 0 - and mask.shape[-2] == sq - and mask.shape[-1] == sk + and mask.shape[0] in [1, b] + and mask.shape[1:] == (1, sq, sk) ): return True else: @@ -358,13 +355,21 @@ def is_kernel_available(self, mask: torch.Tensor, b: int, np: int, sq: int, sk: def forward_fused_softmax( self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None ) -> torch.Tensor: - """Fused masked softmax kernel""" + """ + Fused masked softmax path. + attn_mask_type | module + ----------------------------------------------------------------------------------------- + no_mask | ScaledSoftmax + causal (self-attention), causal_bottom_right | ScaledAlignedCausalMaskedSoftmax + padding, padding_causal, padding_causal_bottom_right | ScaledMaskedSoftmax + arbitrary ([1, 1, sq, sk] or [b, 1, sq, sk]) | ScaledMaskedSoftmax + """ scale = 1.0 if scale is None else scale - if "causal" in self.attn_mask_type: + if self.attn_mask_type in ["causal", "causal_bottom_right"]: return ScaledAlignedCausalMaskedSoftmax.apply(inp, scale) - # input is 4D tensor (b, np, sq, sk) + # input is 4D tensor (1, 1, sq, sk) or (b, 1, sq, sk) if mask is not None and self.attn_mask_type != "no_mask": return ScaledMaskedSoftmax.apply(inp, mask, scale) return ScaledSoftmax.apply(inp, scale) @@ -379,13 +384,19 @@ def forward_torch_softmax( if scale is not None: inp = inp * scale - if "causal" in self.attn_mask_type: + if self.attn_mask_type in ["causal", "causal_bottom_right"]: seq_len_q, seq_len_k = inp.size(2), inp.size(3) if is_in_onnx_export_mode() and self.kvcache_max_seq > 0: assert self.kvcache_max_seq >= seq_len_k - mask = _get_onnx_export_causal_mask(seq_len_q, seq_len_k, self.onnx_causal_mask) + causal_mask = _get_onnx_export_causal_mask( + seq_len_q, seq_len_k, self.onnx_causal_mask + ) + else: + causal_mask = _get_default_causal_mask(self.attn_mask_type, seq_len_q, seq_len_k) + if mask is None: + mask = causal_mask else: - mask = _get_default_causal_mask(self.attn_mask_type, seq_len_q, seq_len_k) + mask = torch.logical_or(mask, causal_mask) mask_output = inp if mask is not None and self.attn_mask_type != "no_mask": diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 4cbee3d628..bd6e27594d 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -624,7 +624,7 @@ def forward( Whether to set output tensors to 0 or not before use. inference_params: InferenceParams, default = None Inference parameters that are passed to the main model in order - to efficienly calculate and store the context during inference. + to efficiently calculate and store the context during inference. """ if self_attn_mask_type is None: From fc6e641b1d5b62d5c511e30652c3e14278d1930c Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 21 Aug 2024 22:33:22 -0700 Subject: [PATCH 68/72] Re-add framework specific required dependencies for source build (#1124) * Re-add framework specific required dependencies for source build Signed-off-by: Kirthi Shankar Sivamani * fix build Signed-off-by: Kirthi Shankar Sivamani * Fix Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- setup.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/setup.py b/setup.py index e418cb95ff..6cee4690dc 100644 --- a/setup.py +++ b/setup.py @@ -89,6 +89,18 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: if not found_pybind11(): setup_reqs.append("pybind11") + # Framework-specific requirements + if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): + if "pytorch" in frameworks: + install_reqs.extend(["torch", "flash-attn>=2.0.6,<=2.5.8,!=2.0.9,!=2.1.0"]) + test_reqs.extend(["numpy", "onnxruntime", "torchvision", "prettytable"]) + if "jax" in frameworks: + install_reqs.extend(["jax", "flax>=0.7.1"]) + test_reqs.extend(["numpy", "praxis"]) + if "paddle" in frameworks: + install_reqs.append("paddlepaddle-gpu") + test_reqs.append("numpy") + return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]] From a37a36c21a0c94b0a7b356ff37df19d4fa89267b Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 27 Aug 2024 06:50:06 -0700 Subject: [PATCH 69/72] Hide non-necessary symbols from shared object (#1136) Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/common/CMakeLists.txt | 4 ++++ transformer_engine/common/libtransformer_engine.version | 4 ++++ 2 files changed, 8 insertions(+) create mode 100644 transformer_engine/common/libtransformer_engine.version diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 7fab75dca0..58bd4f828c 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -15,6 +15,10 @@ if (CMAKE_BUILD_TYPE STREQUAL "Debug") set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -G") endif() +# Hide non-necessary symbols in shared object. +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version") +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version") + # Transformer Engine library project(transformer_engine LANGUAGES CUDA CXX) diff --git a/transformer_engine/common/libtransformer_engine.version b/transformer_engine/common/libtransformer_engine.version new file mode 100644 index 0000000000..0683ec01ea --- /dev/null +++ b/transformer_engine/common/libtransformer_engine.version @@ -0,0 +1,4 @@ +{ + global: *nvte*; *transformer_engine*; + local: *; +}; From 61f8415f502e9f6bb2b0b58eb27d28921735acf3 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com> Date: Thu, 29 Aug 2024 22:44:16 -0700 Subject: [PATCH 70/72] Fix QKV dtype in the bwd of FP8+CP (#1134) * fix qkv_dtype of FP8+CP Signed-off-by: Xiaowei Ren * config cp correction dtype of FP8+CP Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * code style change Signed-off-by: Xiaowei Ren * always do FP8 CP correction in FP32 Signed-off-by: Xiaowei Ren --------- Signed-off-by: Xiaowei Ren Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 6a46d6c3c1..ff121527d3 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -2155,8 +2155,9 @@ def backward(ctx, dout): if ctx.fp8: if ctx.use_fused_attention: + fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) - fused_attn_qkv_dtype = fp8_dtype_backward + fused_attn_qkv_dtype = fp8_dtype_forward fused_attn_dqkv_dtype = fp8_dtype_backward fused_attn_backend = FusedAttnBackend["FP8"] dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device) @@ -2198,7 +2199,7 @@ def backward(ctx, dout): if ctx.use_fused_attention: fp8_meta_kwargs = {} fused_attn_qkv_dtype = TE_DType[q.dtype] - fused_attn_dqkv_dtype = TE_DType[q.dtype] + fused_attn_dqkv_dtype = TE_DType[dout.dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] out = out.view(*q.shape) From 669b8164b4cb4591ed01f8ba45b4aeebc090b334 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 20 Aug 2024 20:14:41 -0700 Subject: [PATCH 71/72] Update cudnn-frontend to v1.6.1 (#1108) * update FE to 1.6 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update to 1.6.1-rc for testing Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update to fe 1.6.1 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- 3rdparty/cudnn-frontend | 2 +- .../common/fused_attn/fused_attn_fp8.cu | 30 +++++++++++++++---- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 98ca4e1941..2533f5e5c1 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 98ca4e1941fe3263f128f74f10063a3ea35c7019 +Subproject commit 2533f5e5c1877fd76266133c1479ef1643ce3a8b diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index bda3f5beba..fb7765e1a4 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1835,8 +1835,14 @@ void fused_attn_fp8_fwd_impl_v1( generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride); - amax_o->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); - amax_s->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + amax_o->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + amax_s->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); Stats->set_output(true) .set_data_type(fe::DataType_t::FLOAT) @@ -2182,10 +2188,22 @@ void fused_attn_fp8_bwd_impl_v1( dQ->set_output(true).set_dim({b, h, s_q, d}).set_stride(q_stride); dK->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(k_stride); dV->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(v_stride); - amax_dQ->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); - amax_dK->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); - amax_dV->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); - amax_dP->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + amax_dQ->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + amax_dK->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + amax_dV->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + amax_dP->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); dO->set_data_type(bwd_tensor_type); dQ->set_data_type(bwd_tensor_type); From a7e9d3e7d9015f9233c5e768263c8f7b9c26953e Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 3 Sep 2024 09:24:52 -0700 Subject: [PATCH 72/72] Improvements for building wheels (#1148) * Improvements for wheels Signed-off-by: Kirthi Shankar Sivamani * fix Signed-off-by: Kirthi Shankar Sivamani * Fixes for wheel build Signed-off-by: Kirthi Shankar Sivamani * Move package finder to common Signed-off-by: Kirthi Shankar Sivamani * format Signed-off-by: Kirthi Shankar Sivamani * Fixes Signed-off-by: Kirthi Shankar Sivamani * Lint Signed-off-by: Kirthi Shankar Sivamani * fix Signed-off-by: Kirthi Shankar Sivamani * FIx Signed-off-by: Kirthi Shankar Sivamani * Fix CI and distributed test Signed-off-by: Kirthi Shankar Sivamani * fix Signed-off-by: Kirthi Shankar Sivamani * fix paddle ci Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- build_tools/utils.py | 3 +- build_tools/wheel_utils/Dockerfile.aarch | 2 +- build_tools/wheel_utils/Dockerfile.x86 | 2 +- build_tools/wheel_utils/build_wheels.sh | 56 +++++++---- qa/L0_jax_wheel/test.sh | 26 +++-- qa/L0_paddle_wheel/test.sh | 27 +++-- qa/L0_pytorch_wheel/test.sh | 26 +++-- qa/L1_pytorch_distributed_unittest/test.sh | 4 + setup.py | 109 ++++++++++++--------- transformer_engine/common/__init__.py | 11 +++ transformer_engine/jax/__init__.py | 35 ++++++- transformer_engine/paddle/__init__.py | 32 ++++++ transformer_engine/pytorch/__init__.py | 37 ++++++- 13 files changed, 280 insertions(+), 90 deletions(-) diff --git a/build_tools/utils.py b/build_tools/utils.py index 81b9a896cb..27ceea844b 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -296,7 +296,7 @@ def install_and_import(package): globals()[main_package] = importlib.import_module(main_package) -def uninstall_te_fw_packages(): +def uninstall_te_wheel_packages(): subprocess.check_call( [ sys.executable, @@ -304,6 +304,7 @@ def uninstall_te_fw_packages(): "pip", "uninstall", "-y", + "transformer_engine_cu12", "transformer_engine_torch", "transformer_engine_paddle", "transformer_engine_jax", diff --git a/build_tools/wheel_utils/Dockerfile.aarch b/build_tools/wheel_utils/Dockerfile.aarch index a0bcd80347..7d839958cb 100644 --- a/build_tools/wheel_utils/Dockerfile.aarch +++ b/build_tools/wheel_utils/Dockerfile.aarch @@ -33,4 +33,4 @@ ENV CUDA_PATH=/usr/local/cuda ENV CUDADIR=/usr/local/cuda ENV NVTE_RELEASE_BUILD=1 -CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_aarch64", "true", "false", "false", "true"] +CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_aarch64", "true", "true", "false", "false", "false"] diff --git a/build_tools/wheel_utils/Dockerfile.x86 b/build_tools/wheel_utils/Dockerfile.x86 index 602d99ed4d..7dedf2a761 100644 --- a/build_tools/wheel_utils/Dockerfile.x86 +++ b/build_tools/wheel_utils/Dockerfile.x86 @@ -33,4 +33,4 @@ ENV CUDA_PATH=/usr/local/cuda ENV CUDADIR=/usr/local/cuda ENV NVTE_RELEASE_BUILD=1 -CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_x86_64", "true", "true", "true", "true"] +CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_x86_64", "true", "true", "true", "true", "true"] diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh index 1896fc4e42..7682a2b6aa 100644 --- a/build_tools/wheel_utils/build_wheels.sh +++ b/build_tools/wheel_utils/build_wheels.sh @@ -5,10 +5,11 @@ set -e PLATFORM=${1:-manylinux_2_28_x86_64} -BUILD_COMMON=${2:-true} -BUILD_JAX=${3:-true} +BUILD_METAPACKAGE=${2:-true} +BUILD_COMMON=${3:-true} BUILD_PYTORCH=${4:-true} -BUILD_PADDLE=${5:-true} +BUILD_JAX=${5:-true} +BUILD_PADDLE=${6:-true} export NVTE_RELEASE_BUILD=1 export TARGET_BRANCH=${TARGET_BRANCH:-} @@ -20,12 +21,33 @@ cd /TransformerEngine git checkout $TARGET_BRANCH git submodule update --init --recursive +if $BUILD_METAPACKAGE ; then + cd /TransformerEngine + NVTE_BUILD_METAPACKAGE=1 /opt/python/cp310-cp310/bin/python setup.py bdist_wheel 2>&1 | tee /wheelhouse/logs/metapackage.txt + mv dist/* /wheelhouse/ +fi + if $BUILD_COMMON ; then + VERSION=`cat build_tools/VERSION.txt` + WHL_BASE="transformer_engine-${VERSION}" + + # Create the wheel. /opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt + + # Repack the wheel for cuda specific package, i.e. cu12. + /opt/python/cp38-cp38/bin/wheel unpack dist/* + # From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore). + sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" + /opt/python/cp38-cp38/bin/wheel pack ${WHL_BASE} + + # Rename the wheel to make it python version agnostic. whl_name=$(basename dist/*) IFS='-' read -ra whl_parts <<< "$whl_name" - whl_name_target="${whl_parts[0]}-${whl_parts[1]}-py3-none-${whl_parts[4]}" - mv dist/"$whl_name" /wheelhouse/"$whl_name_target" + whl_name_target="${whl_parts[0]}_cu12-${whl_parts[1]}-py3-none-${whl_parts[4]}" + rm -rf $WHL_BASE dist + mv *.whl /wheelhouse/"$whl_name_target" fi if $BUILD_PYTORCH ; then @@ -37,8 +59,8 @@ fi if $BUILD_JAX ; then cd /TransformerEngine/transformer_engine/jax - /opt/python/cp38-cp38/bin/pip install jax jaxlib - /opt/python/cp38-cp38/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt + /opt/python/cp310-cp310/bin/pip install "jax[cuda12_local]" jaxlib + /opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt cp dist/* /wheelhouse/ fi @@ -48,30 +70,30 @@ if $BUILD_PADDLE ; then dnf -y install libcudnn8-devel.x86_64 libcudnn8.x86_64 cd /TransformerEngine/transformer_engine/paddle - /opt/python/cp38-cp38/bin/pip install /wheelhouse/*.whl + /opt/python/cp38-cp38/bin/pip install /wheelhouse/*.whl --no-deps /opt/python/cp38-cp38/bin/pip install paddlepaddle-gpu==2.6.1 /opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp38.txt - /opt/python/cp38-cp38/bin/pip uninstall -y transformer-engine paddlepaddle-gpu + /opt/python/cp38-cp38/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - /opt/python/cp39-cp39/bin/pip install /wheelhouse/*.whl + /opt/python/cp39-cp39/bin/pip install /wheelhouse/*.whl --no-deps /opt/python/cp39-cp39/bin/pip install paddlepaddle-gpu==2.6.1 /opt/python/cp39-cp39/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp39.txt - /opt/python/cp39-cp39/bin/pip uninstall -y transformer-engine paddlepaddle-gpu + /opt/python/cp39-cp39/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - /opt/python/cp310-cp310/bin/pip install /wheelhouse/*.whl + /opt/python/cp310-cp310/bin/pip install /wheelhouse/*.whl --no-deps /opt/python/cp310-cp310/bin/pip install paddlepaddle-gpu==2.6.1 /opt/python/cp310-cp310/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp310.txt - /opt/python/cp310-cp310/bin/pip uninstall -y transformer-engine paddlepaddle-gpu + /opt/python/cp310-cp310/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - /opt/python/cp311-cp311/bin/pip install /wheelhouse/*.whl + /opt/python/cp311-cp311/bin/pip install /wheelhouse/*.whl --no-deps /opt/python/cp311-cp311/bin/pip install paddlepaddle-gpu==2.6.1 /opt/python/cp311-cp311/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp311.txt - /opt/python/cp311-cp311/bin/pip uninstall -y transformer-engine paddlepaddle-gpu + /opt/python/cp311-cp311/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - /opt/python/cp312-cp312/bin/pip install /wheelhouse/*.whl + /opt/python/cp312-cp312/bin/pip install /wheelhouse/*.whl --no-deps /opt/python/cp312-cp312/bin/pip install paddlepaddle-gpu==2.6.1 /opt/python/cp312-cp312/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp312.txt - /opt/python/cp312-cp312/bin/pip uninstall -y transformer-engine paddlepaddle-gpu + /opt/python/cp312-cp312/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu mv dist/* /wheelhouse/ fi diff --git a/qa/L0_jax_wheel/test.sh b/qa/L0_jax_wheel/test.sh index 109633495b..2c3b832933 100644 --- a/qa/L0_jax_wheel/test.sh +++ b/qa/L0_jax_wheel/test.sh @@ -6,16 +6,30 @@ set -e : "${TE_PATH:=/opt/transformerengine}" +pip install wheel + cd $TE_PATH -pip uninstall -y transformer-engine -export NVTE_RELEASE_BUILD=1 -python setup.py bdist_wheel +pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-jax + +VERSION=`cat $TE_PATH/build_tools/VERSION.txt` +WHL_BASE="transformer_engine-${VERSION}" + +# Core wheel. +NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel +wheel unpack dist/* +sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" +sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" +mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" +wheel pack ${WHL_BASE} +rm dist/*.whl +mv *.whl dist/ +NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel + cd transformer_engine/jax -python setup.py sdist +NVTE_RELEASE_BUILD=1 python setup.py sdist -export NVTE_RELEASE_BUILD=0 pip install dist/* cd $TE_PATH -pip install dist/* +pip install dist/*.whl --no-deps python $TE_PATH/tests/jax/test_sanity_import.py diff --git a/qa/L0_paddle_wheel/test.sh b/qa/L0_paddle_wheel/test.sh index e2d6d38dd4..30fbb1df1f 100644 --- a/qa/L0_paddle_wheel/test.sh +++ b/qa/L0_paddle_wheel/test.sh @@ -6,15 +6,28 @@ set -e : "${TE_PATH:=/opt/transformerengine}" +pip install wheel==0.44.0 pydantic + cd $TE_PATH -pip uninstall -y transformer-engine -export NVTE_RELEASE_BUILD=1 -python setup.py bdist_wheel -pip install dist/* -cd transformer_engine/paddle -python setup.py bdist_wheel +pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-paddle -export NVTE_RELEASE_BUILD=0 +VERSION=`cat $TE_PATH/build_tools/VERSION.txt` +WHL_BASE="transformer_engine-${VERSION}" + +# Core wheel. +NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel +wheel unpack dist/* +sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" +sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" +mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" +wheel pack ${WHL_BASE} +rm dist/*.whl +mv *.whl dist/ +NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel +pip install dist/*.whl --no-deps + +cd transformer_engine/paddle +NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel pip install dist/* python $TE_PATH/tests/paddle/test_sanity_import.py diff --git a/qa/L0_pytorch_wheel/test.sh b/qa/L0_pytorch_wheel/test.sh index e108e93cdb..fd8457c44b 100644 --- a/qa/L0_pytorch_wheel/test.sh +++ b/qa/L0_pytorch_wheel/test.sh @@ -6,16 +6,30 @@ set -e : "${TE_PATH:=/opt/transformerengine}" +pip install wheel + cd $TE_PATH -pip uninstall -y transformer-engine -export NVTE_RELEASE_BUILD=1 -python setup.py bdist_wheel +pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-torch + +VERSION=`cat $TE_PATH/build_tools/VERSION.txt` +WHL_BASE="transformer_engine-${VERSION}" + +# Core wheel. +NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel +wheel unpack dist/* +sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" +sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" +mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" +wheel pack ${WHL_BASE} +rm dist/*.whl +mv *.whl dist/ +NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel + cd transformer_engine/pytorch -python setup.py sdist +NVTE_RELEASE_BUILD=1 python setup.py sdist -export NVTE_RELEASE_BUILD=0 pip install dist/* cd $TE_PATH -pip install dist/* +pip install dist/*.whl --no-deps python $TE_PATH/tests/pytorch/test_sanity_import.py diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index fef48fd4b0..50394c33a9 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -4,6 +4,10 @@ set -e +# pkg_resources is deprecated in setuptools 70+ and the packaging submodule +# has been removed from it. This is a temporary fix until upstream MLM fix. +pip install setuptools==69.5.1 + : ${TE_PATH:=/opt/transformerengine} pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py diff --git a/setup.py b/setup.py index 6cee4690dc..942f57d3c1 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,8 @@ remove_dups, get_frameworks, install_and_import, - uninstall_te_fw_packages, + remove_dups, + uninstall_te_wheel_packages, ) from build_tools.te_version import te_version @@ -105,46 +106,69 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: if __name__ == "__main__": - # Dependencies - setup_requires, install_requires, test_requires = setup_requirements() - __version__ = te_version() - ext_modules = [setup_common_extension()] - if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): - # Remove residual FW packages since compiling from source - # results in a single binary with FW extensions included. - uninstall_te_fw_packages() - if "pytorch" in frameworks: - from build_tools.pytorch import setup_pytorch_extension - - ext_modules.append( - setup_pytorch_extension( - "transformer_engine/pytorch/csrc", - current_file_path / "transformer_engine" / "pytorch" / "csrc", - current_file_path / "transformer_engine", + with open("README.rst", encoding="utf-8") as f: + long_description = f.read() + + # Settings for building top level empty package for dependency management. + if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))): + assert bool( + int(os.getenv("NVTE_RELEASE_BUILD", "0")) + ), "NVTE_RELEASE_BUILD env must be set for metapackage build." + ext_modules = [] + cmdclass = {} + package_data = {} + include_package_data = False + setup_requires = [] + install_requires = ([f"transformer_engine_cu12=={__version__}"],) + extras_require = { + "pytorch": [f"transformer_engine_torch=={__version__}"], + "jax": [f"transformer_engine_jax=={__version__}"], + "paddle": [f"transformer_engine_paddle=={__version__}"], + } + else: + setup_requires, install_requires, test_requires = setup_requirements() + ext_modules = [setup_common_extension()] + cmdclass = {"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist} + package_data = {"": ["VERSION.txt"]} + include_package_data = True + extras_require = {"test": test_requires} + + if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): + # Remove residual FW packages since compiling from source + # results in a single binary with FW extensions included. + uninstall_te_wheel_packages() + if "pytorch" in frameworks: + from build_tools.pytorch import setup_pytorch_extension + + ext_modules.append( + setup_pytorch_extension( + "transformer_engine/pytorch/csrc", + current_file_path / "transformer_engine" / "pytorch" / "csrc", + current_file_path / "transformer_engine", + ) ) - ) - if "jax" in frameworks: - from build_tools.jax import setup_jax_extension - - ext_modules.append( - setup_jax_extension( - "transformer_engine/jax/csrc", - current_file_path / "transformer_engine" / "jax" / "csrc", - current_file_path / "transformer_engine", + if "jax" in frameworks: + from build_tools.jax import setup_jax_extension + + ext_modules.append( + setup_jax_extension( + "transformer_engine/jax/csrc", + current_file_path / "transformer_engine" / "jax" / "csrc", + current_file_path / "transformer_engine", + ) ) - ) - if "paddle" in frameworks: - from build_tools.paddle import setup_paddle_extension - - ext_modules.append( - setup_paddle_extension( - "transformer_engine/paddle/csrc", - current_file_path / "transformer_engine" / "paddle" / "csrc", - current_file_path / "transformer_engine", + if "paddle" in frameworks: + from build_tools.paddle import setup_paddle_extension + + ext_modules.append( + setup_paddle_extension( + "transformer_engine/paddle/csrc", + current_file_path / "transformer_engine" / "paddle" / "csrc", + current_file_path / "transformer_engine", + ) ) - ) # Configure package setuptools.setup( @@ -157,13 +181,10 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: "transformer_engine/build_tools", ], ), - extras_require={ - "test": test_requires, - "pytorch": [f"transformer_engine_torch=={__version__}"], - "jax": [f"transformer_engine_jax=={__version__}"], - "paddle": [f"transformer_engine_paddle=={__version__}"], - }, + extras_require=extras_require, description="Transformer acceleration library", + long_description=long_description, + long_description_content_type="text/x-rst", ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, python_requires=">=3.8, <3.13", @@ -177,6 +198,6 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: setup_requires=setup_requires, install_requires=install_requires, license_files=("LICENSE",), - include_package_data=True, - package_data={"": ["VERSION.txt"]}, + include_package_data=include_package_data, + package_data=package_data, ) diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index f4eb2c419f..46cfa9176a 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -4,6 +4,7 @@ """FW agnostic user-end APIs""" +import sys import glob import sysconfig import subprocess @@ -15,6 +16,16 @@ import transformer_engine +def is_package_installed(package): + """Checks if a pip package is installed.""" + return ( + subprocess.run( + [sys.executable, "-m", "pip", "show", package], capture_output=True, check=False + ).returncode + == 0 + ) + + def get_te_path(): """Find Transformer Engine install path using pip""" return Path(transformer_engine.__path__[0]).parent diff --git a/transformer_engine/jax/__init__.py b/transformer_engine/jax/__init__.py index 3200c8a019..05adbd624c 100644 --- a/transformer_engine/jax/__init__.py +++ b/transformer_engine/jax/__init__.py @@ -5,21 +5,50 @@ # pylint: disable=wrong-import-position,wrong-import-order +import logging import ctypes +from importlib.metadata import version -from transformer_engine.common import get_te_path +from transformer_engine.common import get_te_path, is_package_installed from transformer_engine.common import _get_sys_extension def _load_library(): """Load shared library with Transformer Engine C extensions""" + module_name = "transformer_engine_jax" + + if is_package_installed(module_name): + assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`." + assert is_package_installed( + "transformer_engine_cu12" + ), "Could not find `transformer-engine-cu12`." + assert ( + version(module_name) + == version("transformer-engine") + == version("transformer-engine-cu12") + ), ( + "TransformerEngine package version mismatch. Found" + f" {module_name} v{version(module_name)}, transformer-engine" + f" v{version('transformer-engine')}, and transformer-engine-cu12" + f" v{version('transformer-engine-cu12')}. Install transformer-engine using 'pip install" + " transformer-engine[jax]==VERSION'" + ) + + if is_package_installed("transformer-engine-cu12"): + if not is_package_installed(module_name): + logging.info( + "Could not find package %s. Install transformer-engine using 'pip" + " install transformer-engine[jax]==VERSION'", + module_name, + ) + extension = _get_sys_extension() try: so_dir = get_te_path() / "transformer_engine" - so_path = next(so_dir.glob(f"transformer_engine_jax.*.{extension}")) + so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) except StopIteration: so_dir = get_te_path() - so_path = next(so_dir.glob(f"transformer_engine_jax.*.{extension}")) + so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL) diff --git a/transformer_engine/paddle/__init__.py b/transformer_engine/paddle/__init__.py index 62fa1fe626..50cf2186d6 100644 --- a/transformer_engine/paddle/__init__.py +++ b/transformer_engine/paddle/__init__.py @@ -6,9 +6,41 @@ # pylint: disable=wrong-import-position,wrong-import-order +import logging +from importlib.metadata import version + +from transformer_engine.common import is_package_installed + def _load_library(): """Load shared library with Transformer Engine C extensions""" + module_name = "transformer_engine_paddle" + + if is_package_installed(module_name): + assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`." + assert is_package_installed( + "transformer_engine_cu12" + ), "Could not find `transformer-engine-cu12`." + assert ( + version(module_name) + == version("transformer-engine") + == version("transformer-engine-cu12") + ), ( + "TransformerEngine package version mismatch. Found" + f" {module_name} v{version(module_name)}, transformer-engine" + f" v{version('transformer-engine')}, and transformer-engine-cu12" + f" v{version('transformer-engine-cu12')}. Install transformer-engine using 'pip install" + " transformer-engine[paddle]==VERSION'" + ) + + if is_package_installed("transformer-engine-cu12"): + if not is_package_installed(module_name): + logging.info( + "Could not find package %s. Install transformer-engine using 'pip" + " install transformer-engine[paddle]==VERSION'", + module_name, + ) + from transformer_engine import transformer_engine_paddle # pylint: disable=unused-import diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 20b6f79da6..07ade71905 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -6,25 +6,54 @@ # pylint: disable=wrong-import-position,wrong-import-order +import logging import importlib +import importlib.util import sys import torch +from importlib.metadata import version -from transformer_engine.common import get_te_path +from transformer_engine.common import get_te_path, is_package_installed from transformer_engine.common import _get_sys_extension def _load_library(): """Load shared library with Transformer Engine C extensions""" + module_name = "transformer_engine_torch" + + if is_package_installed(module_name): + assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`." + assert is_package_installed( + "transformer_engine_cu12" + ), "Could not find `transformer-engine-cu12`." + assert ( + version(module_name) + == version("transformer-engine") + == version("transformer-engine-cu12") + ), ( + "TransformerEngine package version mismatch. Found" + f" {module_name} v{version(module_name)}, transformer-engine" + f" v{version('transformer-engine')}, and transformer-engine-cu12" + f" v{version('transformer-engine-cu12')}. Install transformer-engine using 'pip install" + " transformer-engine[pytorch]==VERSION'" + ) + + if is_package_installed("transformer-engine-cu12"): + if not is_package_installed(module_name): + logging.info( + "Could not find package %s. Install transformer-engine using 'pip" + " install transformer-engine[pytorch]==VERSION'", + module_name, + ) + extension = _get_sys_extension() try: so_dir = get_te_path() / "transformer_engine" - so_path = next(so_dir.glob(f"transformer_engine_torch.*.{extension}")) + so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) except StopIteration: so_dir = get_te_path() - so_path = next(so_dir.glob(f"transformer_engine_torch.*.{extension}")) + so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) - module_name = "transformer_engine_torch" spec = importlib.util.spec_from_file_location(module_name, so_path) solib = importlib.util.module_from_spec(spec) sys.modules[module_name] = solib