diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py
index 6a2a801efd..239cecf39b 100644
--- a/transformer_engine/pytorch/distributed.py
+++ b/transformer_engine/pytorch/distributed.py
@@ -516,12 +516,6 @@ def checkpoint(
     kwargs : dict
             dictionary of string keys for keyword arguments to :attr:`function`.
     """
-    only_tensor_args = True
-    for arg in args:
-        if not isinstance(arg, torch.Tensor):
-            only_tensor_args = False
-            break
-
     # Pop out te.distributed.checkpoint() arguments
     global _USE_REENTRANT_ACTIVATION_RECOMPUTE
     _USE_REENTRANT_ACTIVATION_RECOMPUTE = kwargs.pop("use_reentrant", True)
@@ -530,23 +524,14 @@ def checkpoint(
     get_rng_state_tracker = kwargs.pop("get_rng_state_tracker", None)
 
     # Ensure backward compatibility.
-    if not only_tensor_args:
+    if (len(args) > 3 and isinstance(args[0], bool) and callable(args[1])
+        and isinstance(args[2], None | dist_group_type)):
         warnings.warn(
             "Passing non-tensor non-keyword arguments is deprecated and support will be removed in "
             "future releases of TransformerEngine. `distribute_saved_activations`, `tp_group`, and "
             "`get_rng_state_tracker` must be passed as keyword arguments to `checkpoint`.",
             DeprecationWarning,
             stacklevel=2,
         )
-        assert len(args) > 3, "Incorrect number of arguments for deprecated `checkpoint` API."
-        assert (
-            isinstance(args[0], bool) and callable(args[1])
-            and isinstance(args[2], None | dist_group_type)
-        ), "Incorrect arguments for deprecated `checkpoint` API."
-        for arg in args[3:]:
-            assert (
-                isinstance(arg, None | torch.Tensor)
-            ), f"Expected tensor argument, found {type(arg)}."
-
         distribute_saved_activations, get_rng_state_tracker, tp_group = args[:3]  # pylint: disable=unbalanced-tuple-unpacking
         args = args[3:]
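
Note (not part of the patch): with this change, positional arguments are only treated as the deprecated (distribute_saved_activations, get_rng_state_tracker, tp_group) triple when they actually match that shape, instead of whenever any non-tensor positional argument is present; tensor-only positional arguments are simply forwarded to the wrapped function. A minimal usage sketch follows; the `te.distributed.checkpoint` access path mirrors the comment in the patched code, while `layer`, `inp`, `rng_tracker`, and the keyword values shown are illustrative assumptions rather than values taken from this diff.

# Usage sketch (illustrative, not part of the patch).
# Assumes `layer` is an nn.Module and `inp` is a torch.Tensor.
import transformer_engine.pytorch as te

# Preferred: TE-specific options go in as keyword arguments; the remaining
# positional tensors are passed straight through to `layer`.
out = te.distributed.checkpoint(
    layer,
    inp,
    distribute_saved_activations=False,
    tp_group=None,
    get_rng_state_tracker=None,
    use_reentrant=True,
)

# Deprecated positional form that the new check still recognizes (and warns on),
# where `rng_tracker` is a callable and `tp_group` a process group or None:
#   te.distributed.checkpoint(layer, False, rng_tracker, tp_group, inp)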