diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py
index 67fc4db0d0..6a2a801efd 100644
--- a/transformer_engine/pytorch/distributed.py
+++ b/transformer_engine/pytorch/distributed.py
@@ -516,6 +516,12 @@ def checkpoint(
     kwargs : dict
              dictionary of string keys for keyword arguments to :attr:`function`.
     """
+    only_tensor_args = True
+    for arg in args:
+        if not isinstance(arg, torch.Tensor):
+            only_tensor_args = False
+            break
+
     # Pop out te.distributed.checkpoint() arguments
     global _USE_REENTRANT_ACTIVATION_RECOMPUTE
     _USE_REENTRANT_ACTIVATION_RECOMPUTE = kwargs.pop("use_reentrant", True)
@@ -523,6 +529,27 @@ def checkpoint(
     tp_group = kwargs.pop("tp_group", None)
     get_rng_state_tracker = kwargs.pop("get_rng_state_tracker", None)
 
+    # Ensure backward compatibility.
+    if not only_tensor_args:
+        warnings.warn(
+            "Passing non-tensor non-keyword arguments is deprecated and support will be removed in "
+            "future releases of TransformerEngine. `distribute_saved_activations`, `tp_group`, and "
+            "`get_rng_state_tracker` must be passed as keyword arguments to `checkpoint`.",
+            DeprecationWarning, stacklevel=2,
+        )
+        assert len(args) > 3, "Incorrect number of arguments for deprecated `checkpoint` API."
+        assert (
+            isinstance(args[0], bool) and callable(args[1])
+            and isinstance(args[2], None | dist_group_type)
+        ), "Incorrect arguments for deprecated `checkpoint` API."
+        for arg in args[3:]:
+            assert (
+                isinstance(arg, None | torch.Tensor)
+            ), f"Expected tensor argument, found {type(arg)}."
+
+        distribute_saved_activations, get_rng_state_tracker, tp_group = args[:3]  # pylint: disable=unbalanced-tuple-unpacking
+        args = args[3:]
+
     # Trigger the native PyTorch checkpoint if:
     # 1. `function` is a `torch.nn.Module`
     # AND
@@ -555,16 +582,6 @@ def checkpoint(
     assert torch.distributed.is_initialized(), "torch.distributed is not initialized."
     tp_group = torch.distributed.GroupMember.WORLD if tp_group is None else tp_group
 
-    # Make sure at least one tensor input has `requires_grad=True`
-    input_requires_grad = False
-    for arg in args:
-        if isinstance(arg, torch.Tensor) and arg.requires_grad:
-            input_requires_grad = True
-            break
-    assert input_requires_grad, (
-        "`use_reentrant=True` requires at least one input tensor with `requires_grad=True`."
-    )
-
     return _CheckpointFunction.apply(
         function,
         distribute_saved_activations,
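
Below is a minimal, self-contained sketch of the argument-classification rule the patch relies on: the deprecated calling convention is assumed whenever any positional argument after `function` is not a tensor, in which case the first three positional arguments are reinterpreted as `distribute_saved_activations`, `get_rng_state_tracker`, and `tp_group` and the remainder as tensor inputs. The helper name below is hypothetical; it only mirrors the `only_tensor_args` scan added in the diff.

import torch


def uses_deprecated_positional_api(args):
    # Mirrors the `only_tensor_args` scan added above: any non-tensor
    # positional argument signals the old calling convention.
    return any(not isinstance(arg, torch.Tensor) for arg in args)


hidden = torch.randn(8, 1024, requires_grad=True)

# Supported form: only tensor inputs are positional, everything else is a
# keyword argument, e.g.
#   checkpoint(module, hidden, distribute_saved_activations=False,
#              tp_group=None, get_rng_state_tracker=None, use_reentrant=True)
assert not uses_deprecated_positional_api((hidden,))

# Deprecated form: (distribute_saved_activations, get_rng_state_tracker,
# tp_group, *tensors) passed positionally, e.g.
#   checkpoint(module, False, rng_tracker_fn, None, hidden)
# This now emits a DeprecationWarning but is still unpacked correctly.
assert uses_deprecated_positional_api((False, lambda: None, None, hidden))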