From 4df84889cb5743113b0fa59839b941486df16ace Mon Sep 17 00:00:00 2001
From: Przemyslaw Tredak
Date: Mon, 30 Sep 2024 19:33:23 -0700
Subject: [PATCH] Removed the unused options from GroupedLinear docs and fixed the bug with offsets (#1220)

* Removing the unused options from GroupedLinear docs and fixing the bug with offsets

Signed-off-by: Przemyslaw Tredak

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* offsets -> fp8_meta_offsets

Signed-off-by: Przemyslaw Tredak

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Przemyslaw Tredak
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .../pytorch/module/grouped_linear.py | 91 ++++++-------------
 1 file changed, 27 insertions(+), 64 deletions(-)

diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index 0bad1306c3..14edd64249 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -44,18 +44,6 @@
 
 __all__ = ["GroupedLinear"]
 
-"""
-The offset for fp8_meta_index.
-_GEMM_INPUT = 0
-_GEMM_WEIGHT = num_gemms
-_GEMM_OUTPUT = 2 * num_gemms
-Must be properly set in GroupedLinear's initialization.
-"""
-_GEMM_INPUT = 0
-_GEMM_WEIGHT = 0
-_GEMM_OUTPUT = 0
-_GRAD_OUTPUT = 0
-
 
 class _GroupedLinear(torch.autograd.Function):
     """GroupedLinear semi-top level module
@@ -74,12 +62,9 @@ def forward(
         fp8_meta: Dict[str, Any],
         fuse_wgrad_accumulation: bool,
         cpu_offloading: bool,
-        tp_group: Union[dist_group_type, None],
-        tp_size: int,
         sequence_parallel: bool,
-        tensor_parallel: bool,
         activation_dtype: torch.dtype,
-        parallel_mode: Union[str, None],
+        fp8_meta_offsets: Dict[str, int],
         is_grad_enabled: bool,
         weights_fp8: List[Union[Float8Tensor, None]],
         *weights_and_biases: Union[Float8Tensor, torch.Tensor, None],
@@ -103,7 +88,6 @@ def forward(
         inputmats_t = []
         inputmat_scale_inv = None
 
-        global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
         if fp8:
             fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
             inputmat_scale_inv = torch.empty([num_gemms], dtype=torch.float32, device=inp.device)
@@ -114,7 +98,9 @@ def forward(
                 and not sequence_parallel
             ):
                 # FP8 input for forward, FP8 input transpose for backward wgrad
-                indices = list(range(_GEMM_INPUT, _GEMM_INPUT + num_gemms))
+                indices = list(
+                    range(fp8_meta_offsets["input"], fp8_meta_offsets["input"] + num_gemms)
+                )
                 inputmats, inputmats_t = fp8_multi_cast_transpose_fused(
                     inputmats_no_fp8,
                     fp8_meta["scaling_fwd"],
@@ -130,7 +116,7 @@ def forward(
                     cast_to_fp8(
                         inputmats_no_fp8[i],
                         fp8_meta["scaling_fwd"],
-                        _GEMM_INPUT + i,
+                        fp8_meta_offsets["input"] + i,
                         fp8_dtype_forward,
                         scale_inv=inputmat_scale_inv,
                     )
@@ -194,14 +180,14 @@ def forward(
             for i in range(num_gemms):
                 # amax of input
                 amin, amax = inputmats[i].aminmax()
-                fp8_meta["scaling_fwd"].amax_history[0][_GEMM_INPUT + i] = torch.max(
-                    -amin, amax
-                ).float()
+                fp8_meta["scaling_fwd"].amax_history[0][fp8_meta_offsets["input"] + i] = (
+                    torch.max(-amin, amax).float()
+                )
                 # amax of weight
                 amin, amax = weights[i].aminmax()
-                fp8_meta["scaling_fwd"].amax_history[0][_GEMM_WEIGHT + i] = torch.max(
-                    -amin, amax
-                ).float()
+                fp8_meta["scaling_fwd"].amax_history[0][fp8_meta_offsets["weight"] + i] = (
+                    torch.max(-amin, amax).float()
+                )
 
         out = torch.empty(
             [sum(m_splits), weights[0].size(0)],
@@ -266,11 +252,8 @@ def forward(
             ctx.is_first_microbatch = is_first_microbatch
             ctx.use_bias = use_bias
             ctx.sequence_parallel = sequence_parallel
-            ctx.tensor_parallel = tensor_parallel
             ctx.inp_shape = inp.shape
-            ctx.parallel_mode = parallel_mode
-            ctx.tp_group = tp_group
-            ctx.tp_size = tp_size
+            ctx.fp8_meta_offsets = fp8_meta_offsets
             ctx.requires_dgrad = inp.requires_grad
             ctx.reduce_and_update_bwd_fp8_tensors = False
             if ctx.fp8 and requires_grad(inp, weights[0], biases[0]):
@@ -300,7 +283,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                     w.main_grad = main_grads[i]
                     weights[i] = w
 
-            global _GEMM_INPUT, _GEMM_WEIGHT, _GRAD_OUTPUT
             # preprocess grad_output
             grad_output = grad_output.contiguous()
             grad_output_mats = torch.split(
@@ -318,13 +300,18 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                             fp8_cast_transpose_bgrad_fused(
                                 grad_output_mats[i],
                                 ctx.fp8_meta["scaling_bwd"],
-                                _GRAD_OUTPUT + i,
+                                ctx.fp8_meta_offsets["grad_output"] + i,
                                 fp8_dtype_backward,
                             )
                         )
                 else:
                     if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
-                        indices = list(range(_GRAD_OUTPUT, _GRAD_OUTPUT + ctx.num_gemms))
+                        indices = list(
+                            range(
+                                ctx.fp8_meta_offsets["grad_output"],
+                                ctx.fp8_meta_offsets["grad_output"] + ctx.num_gemms,
+                            )
+                        )
                         grad_output_c, grad_output_t = fp8_multi_cast_transpose_fused(
                             grad_output_mats,
                             ctx.fp8_meta["scaling_bwd"],
@@ -338,7 +325,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                            grad_output_c[i] = cast_to_fp8(
                                grad_output_mats[i],
                                ctx.fp8_meta["scaling_bwd"],
-                               _GRAD_OUTPUT + i,
+                               ctx.fp8_meta_offsets["grad_output"] + i,
                                fp8_dtype_backward,
                            )
 
@@ -363,7 +350,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                        weights_fp8[0]._fp8_dtype,
                        grad_output_c,
                        ctx.fp8_meta["scaling_bwd"].scale_inv,
-                       _GRAD_OUTPUT,
+                       ctx.fp8_meta_offsets["grad_output"],
                        fp8_dtype_backward,
                        [dgrad],
                        ctx.activation_dtype,
@@ -416,7 +403,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                        fp8_dtype_forward,
                        grad_output_t,
                        ctx.fp8_meta["scaling_bwd"].scale_inv,
-                       _GRAD_OUTPUT,
+                       ctx.fp8_meta_offsets["grad_output"],
                        fp8_dtype_backward,
                        wgrad_list,
                        ctx.activation_dtype,
@@ -497,12 +484,9 @@ def handle_custom_ddp_from_mcore(w, wgrad):
             None,  # fp8_meta
             None,  # fuse_wgrad_accumulation
             None,  # cpu_offloading
-            None,  # tp_group
-            None,  # tp_size
             None,  # sequence_parallel
-            None,  # tensor_parallel
             None,  # activation_dtype
-            None,  # parallel_mode
+            None,  # fp8_meta_offsets
             None,  # is_grad_enabled
             None,  # weights_fp8
             *wgrad_list,
@@ -536,23 +520,6 @@ class GroupedLinear(TransformerEngineBaseModule):
               responsibility to ensure all parameters are moved to the GPU before running the
               forward pass.
 
-    Parallelism parameters
-    ----------------------
-    sequence_parallel : bool, default = `False`
-                       if set to `True`, uses sequence parallelism.
-    tp_group : ProcessGroup, default = `None`
-              tensor parallel process group.
-    tp_size : int, default = 1
-             used as TP (tensor parallel) world size when TP groups are not formed during
-             initialization. In this case, users must call the
-             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
-             forward pass to supply the tensor parallel group needed for tensor and sequence
-             parallel collectives.
-    parallel_mode : {None, 'column', 'row'}, default = `None`
-                   used to decide whether this GroupedLinear layer is Column Parallel Linear or Row
-                   Parallel Linear as described `here `_.
-                   When set to `None`, no communication is performed.
-
     Optimization parameters
     -----------------------
     fuse_wgrad_accumulation : bool, default = 'False'
@@ -613,8 +580,7 @@ def __init__(
         self.get_rng_state_tracker = get_rng_state_tracker
         self.rng_tracker_name = rng_tracker_name
 
-        global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
-        _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT = 0, num_gemms, 2 * num_gemms
+        self._offsets = {"input": 0, "weight": num_gemms, "output": 2 * num_gemms, "grad_output": 0}
 
         if tp_group is None:
             self.tp_size = tp_size
@@ -651,7 +617,7 @@ def __init__(
                 ),
                 init_fn=init_method,
                 get_rng_state_tracker=get_rng_state_tracker,
-                fp8_meta_index=_GEMM_WEIGHT + i,
+                fp8_meta_index=self._offsets["weight"] + i,
             )
 
             # Construct bias parameters if needed
@@ -774,7 +740,7 @@ def forward(
                    weight_tensors_fp8[i] = self.get_fp8_workspace(
                        tensor=weight_tensors[i],
                        fp8_meta_forward=True,
-                       fp8_meta_index=_GEMM_WEIGHT + i,
+                       fp8_meta_index=self._offsets["weight"] + i,
                        cache_name=(None if is_first_microbatch is None else f"weight{i}"),
                        update_workspace=update_workspace,
                        skip_update_flag=skip_fp8_weight_update,
@@ -798,12 +764,9 @@ def forward(
                 self.fp8_meta,
                 self.fuse_wgrad_accumulation,
                 CPUOffloadEnabled,
-                self.tp_group,
-                self.tp_size,
                 self.sequence_parallel,
-                self.tp_size > 1,
                 self.activation_dtype,
-                self.parallel_mode,
+                self._offsets,
                 torch.is_grad_enabled(),
                 weight_tensors_fp8,
                 *weight_tensors,
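
Note on the offsets fix: before this patch the fp8-meta offsets were module-level globals (_GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT, _GRAD_OUTPUT) that every GroupedLinear.__init__ overwrote with its own num_gemms, so two instances with different num_gemms could end up reading offsets computed for the other one. The patch stores the layout per instance (self._offsets) and threads it through the autograd function as fp8_meta_offsets. The snippet below is a minimal illustration of that per-instance layout under assumed, hypothetical names (make_fp8_meta_offsets and TinyGroupedModule are not TransformerEngine APIs):

    # Illustrative sketch only -- hypothetical names, not the TransformerEngine API.
    # It mirrors the per-instance offset layout introduced by this patch and shows
    # why module-level _GEMM_* globals were fragile once two modules with different
    # num_gemms coexisted.

    def make_fp8_meta_offsets(num_gemms: int) -> dict:
        # Same layout as self._offsets in the patch: inputs use meta slots
        # [0, num_gemms), weights [num_gemms, 2*num_gemms), outputs
        # [2*num_gemms, 3*num_gemms); grad_output indexing starts at 0 in the
        # backward meta storage.
        return {"input": 0, "weight": num_gemms, "output": 2 * num_gemms, "grad_output": 0}


    class TinyGroupedModule:
        """Hypothetical stand-in for a module that owns several GEMMs."""

        def __init__(self, num_gemms: int) -> None:
            self.num_gemms = num_gemms
            # Per-instance state: another instance with a different num_gemms
            # cannot overwrite these values.
            self._offsets = make_fp8_meta_offsets(num_gemms)

        def weight_meta_index(self, gemm_id: int) -> int:
            return self._offsets["weight"] + gemm_id


    a = TinyGroupedModule(num_gemms=2)
    b = TinyGroupedModule(num_gemms=4)   # with globals, this would have clobbered a's offsets
    assert a.weight_meta_index(1) == 3   # 2 + 1
    assert b.weight_meta_index(1) == 5   # 4 + 1, unaffected by `a`

The per-instance dictionary is also what lets the removed tp_group/tp_size/parallel_mode arguments disappear from _GroupedLinear.forward without any replacement: the only state the autograd function still needs from the module is the offset layout itself.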