diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py index 5b8917fcd48..b293adaca81 100644 --- a/src/accelerate/utils/__init__.py +++ b/src/accelerate/utils/__init__.py @@ -245,6 +245,7 @@ from .megatron_lm import prepare_model_optimizer_scheduler as megatron_lm_prepare_model_optimizer_scheduler from .megatron_lm import prepare_optimizer as megatron_lm_prepare_optimizer from .megatron_lm import prepare_scheduler as megatron_lm_prepare_scheduler +from .deprecation import deprecated from .memory import find_executable_batch_size, release_memory from .other import ( check_os_kernel, diff --git a/src/accelerate/utils/deprecation.py b/src/accelerate/utils/deprecation.py new file mode 100644 index 00000000000..9f70ef23b78 --- /dev/null +++ b/src/accelerate/utils/deprecation.py @@ -0,0 +1,80 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import textwrap +import warnings +from typing import Callable, TypeVar + +from typing_extensions import ParamSpec + + +_T = TypeVar("_T") +_P = ParamSpec("_P") + + +def deprecated(since: str, removed_in: str, instruction: str) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + """Marks functions as deprecated. + + It will result in a warning when the function is called and a note in the docstring. + + Args: + since (`str`): + The version when the function was first deprecated. + removed_in (`str`): + The version when the function will be removed. + instruction (`str`): + The action users should take. + + Returns: + `Callable`: A decorator that will mark the function as deprecated. + """ + + def decorator(function: Callable[_P, _T]) -> Callable[_P, _T]: + @functools.wraps(function) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: + warnings.warn( + f"'{function.__module__}.{function.__name__}' " + f"is deprecated in version {since} and will be " + f"removed in {removed_in}. {instruction}.", + category=FutureWarning, + stacklevel=2, + ) + return function(*args, **kwargs) + + # Add a deprecation note to the docstring. + docstring = function.__doc__ or "" + + deprecation_note = textwrap.dedent( + f"""\ + .. deprecated:: {since} + Deprecated and will be removed in version {removed_in}. {instruction}. + """ + ) + + # Split docstring at first occurrence of newline + summary_and_body = docstring.split("\n\n", 1) + if len(summary_and_body) > 1: + summary, body = summary_and_body + body = textwrap.dedent(body) + new_docstring_parts = [deprecation_note, "\n\n", summary, body] + else: + summary = summary_and_body[0] + new_docstring_parts = [deprecation_note, "\n\n", summary] + + wrapper.__doc__ = "".join(new_docstring_parts) + + return wrapper + + return decorator diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index 1e6b1c9c6f6..eff3966b0a3 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -31,6 +31,7 @@ from ..state import AcceleratorState from .constants import SAFE_WEIGHTS_NAME, WEIGHTS_NAME from .dataclasses import AutocastKwargs, CustomDtype, DistributedType +from .deprecation import deprecated from .imports import ( is_mlu_available, is_mps_available, @@ -471,11 +472,8 @@ class FindTiedParametersResult(list): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + @deprecated(since="1.0.0rc0", removed_in="1.3.0", instruction="use another method instead") def values(self): - warnings.warn( - "The 'values' method of FindTiedParametersResult is deprecated and will be removed in Accelerate v1.3.0. ", - FutureWarning, - ) return sum([x[1:] for x in self], []) diff --git a/tests/test_utils.py b/tests/test_utils.py index f240d64402a..2d9f969837f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -14,6 +14,7 @@ import os import pickle import tempfile +import textwrap import unittest import warnings from collections import UserDict, namedtuple @@ -57,6 +58,7 @@ save, send_to_device, ) +from accelerate.utils.deprecation import deprecated from accelerate.utils.operations import is_namedtuple @@ -417,6 +419,83 @@ def test_convert_dict_to_env_variables(self): valid_env_items = convert_dict_to_env_variables(env) assert valid_env_items == ["ACCELERATE_DEBUG_MODE=1\n", "OTHER_ENV=2\n"] + def test_deprecated(self): + @deprecated("0.2.0", "0.3.0", "toy instruction") + def long_deprecated_demo(arg1: int, arg2: int) -> tuple: + """This is a long summary. This is a long summary. This is a long + summary. This is a long summary. + + Args: + arg1 (int): Description. + arg2 (int): Description. + + Returns: + Description. + """ + return arg1, arg2 + + with pytest.warns( + FutureWarning, match="deprecated in version 0.2.0 and will be removed in 0.3.0. toy instruction." + ): + self.assertEqual((1, 2), long_deprecated_demo(1, 2)) + + long_expected_docstring = textwrap.dedent(""" + .. deprecated:: 0.2.0 + Deprecated and will be removed in version 0.3.0. toy instruction. + + This is a long summary. This is a long summary. This is a long + summary. This is a long summary. + + Args: + arg1 (int): Description. + arg2 (int): Description. + + Returns: + Description. + """) + + long_expected_docstring = "".join(long_expected_docstring.split()) + long_actual_docstring = "".join(long_deprecated_demo.__doc__.split()) + + self.assertEqual(long_expected_docstring, long_actual_docstring) + + @deprecated("0.2.0", "0.3.0", "toy instruction") + def short_deprecated_demo(): + """Short summary.""" + + short_expected_docstring = textwrap.dedent(""" + .. deprecated:: 0.2.0 + Deprecated and will be removed in version 0.3.0. toy instruction. + + Short summary. + """) + short_expected_docstring = "".join(short_expected_docstring.split()) + short_actual_docstring = "".join(short_deprecated_demo.__doc__.split()) + + self.assertEqual(short_expected_docstring, short_actual_docstring) + + @deprecated("0.2.0", "0.3.0", "toy instruction") + class OldClass: + """Old class docstring.""" + + def method(self): + pass + + with pytest.warns( + FutureWarning, match="deprecated in version 0.2.0 and will be removed in 0.3.0. toy instruction." + ): + OldClass() + + class_expected_docstring = textwrap.dedent(""" + .. deprecated:: 0.2.0 + Deprecated and will be removed in version 0.3.0. toy instruction. + Old class docstring. + """) + class_expected_docstring = "".join(class_expected_docstring.split()) + class_actual_docstring = "".join(OldClass.__doc__.split()) + + self.assertEqual(class_expected_docstring, class_actual_docstring) + def test_has_offloaded_params(self): model = RegressionModel() assert not has_offloaded_params(model)