From c9c90765c1a4b0eb7649f9a0dcedcf0b9e4ebaae Mon Sep 17 00:00:00 2001 From: andreasfloros <77194848+andreasfloros@users.noreply.github.com> Date: Tue, 15 Aug 2023 05:25:29 +0000 Subject: [PATCH] grad_mode decorators without paren (#107086) This PR implements the feature described in #107036 for `no_grad`, `enable_grad` and `inference_mode`. Users can still use the above as before but they can also use them without parentheses. For example: ```python import torch a = torch.ones(1, requires_grad=True) def do_something(): print(2 * a) with torch.no_grad(): do_something() # tensor([2.]) torch.no_grad()(do_something)() # tensor([2.]) torch.no_grad(do_something)() # tensor([2.]) do_something() # tensor([2.], grad_fn=) ``` For `inference_mode`, decorating without parenthesis is equivalent to decorating with the default `mode=True`, similiar to how dataclasses behave (https://docs.python.org/3/library/dataclasses.html#module-contents) Closes #107036 Pull Request resolved: https://github.com/pytorch/pytorch/pull/107086 Approved by: https://github.com/albanD --- test/test_autograd.py | 45 +++++++++++++++++++++++++------------ torch/autograd/grad_mode.py | 42 ++++++++++++++++++++++++++++------ torch/utils/_contextlib.py | 9 ++++++++ 3 files changed, 75 insertions(+), 21 deletions(-) diff --git a/test/test_autograd.py b/test/test_autograd.py index b028ea1ec5a6f..305aac4b9618b 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -1972,18 +1972,20 @@ def test_no_grad(self): with torch.no_grad(): w = x + y - @torch.no_grad() def adder(x, y): return x + y - z = adder(x, y) + adders = [torch.no_grad()(adder), torch.no_grad(adder)] - self.assertFalse(w.requires_grad) - self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5))) - self.assertIsNone(w.grad_fn) - self.assertFalse(z.requires_grad) - self.assertRaises(RuntimeError, lambda: z.backward(torch.ones(5, 5))) - self.assertIsNone(z.grad_fn) + for adder in adders: + z = adder(x, y) + + self.assertFalse(w.requires_grad) + self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5))) + self.assertIsNone(w.grad_fn) + self.assertFalse(z.requires_grad) + self.assertRaises(RuntimeError, lambda: z.backward(torch.ones(5, 5))) + self.assertIsNone(z.grad_fn) # test nested decorator and with-statement on no_grad with torch.no_grad(): @@ -1991,6 +1993,17 @@ def adder(x, y): w = adder(x, y) self.assertFalse(torch.is_grad_enabled()) + def test_enable_grad_decorator_no_paren(self): + x = torch.ones(1, requires_grad=True) + + @torch.enable_grad + def doubler(x): + return x * 2 + + with torch.no_grad(): + z = doubler(x) + self.assertTrue(z.requires_grad) + def test_set_grad_generator_functions(self): @torch.no_grad() def gen_no_grad(): @@ -10160,15 +10173,19 @@ def test_inference_mode_context_manager(self): self.assertFalse(torch.is_inference_mode_enabled()) def test_inference_mode_decorator(self): - for mode in (True, False): - @torch.inference_mode(mode) - def func(x): - self.assertEqual(torch.is_inference_mode_enabled(), mode) - return x * x + def func(x): + self.assertEqual(torch.is_inference_mode_enabled(), mode) + return x * x + for mode in (True, False, None): + if mode is None: + decorated = torch.inference_mode(func) + mode = True + else: + decorated = torch.inference_mode(mode)(func) for requires_grad in (True, False): c = torch.ones(1, 2, 3, requires_grad=requires_grad) - d = func(c) + d = decorated(c) self.assertTrue(not mode or torch.is_inference(d)) self.assertEqual(d.requires_grad, requires_grad and not mode) diff --git a/torch/autograd/grad_mode.py b/torch/autograd/grad_mode.py index 627989f313b42..5834e602fc582 100644 --- a/torch/autograd/grad_mode.py +++ b/torch/autograd/grad_mode.py @@ -2,7 +2,10 @@ import torch -from torch.utils._contextlib import _DecoratorContextManager +from torch.utils._contextlib import ( + _DecoratorContextManager, + _NoParamDecoratorContextManager, +) __all__ = [ "no_grad", @@ -13,7 +16,7 @@ ] -class no_grad(_DecoratorContextManager): +class no_grad(_NoParamDecoratorContextManager): r"""Context-manager that disables gradient calculation. Disabling gradient calculation is useful for inference, when you are sure @@ -29,7 +32,7 @@ class no_grad(_DecoratorContextManager): This context manager is thread local; it will not affect computation in other threads. - Also functions as a decorator. (Make sure to instantiate with parenthesis.) + Also functions as a decorator. .. note:: No-grad is one of several mechanisms that can enable or @@ -54,6 +57,12 @@ class no_grad(_DecoratorContextManager): >>> z = doubler(x) >>> z.requires_grad False + >>> @torch.no_grad + ... def tripler(x): + ... return x * 3 + >>> z = tripler(x) + >>> z.requires_grad + False >>> # factory function exception >>> with torch.no_grad(): ... a = torch.nn.Parameter(torch.rand(10)) @@ -74,7 +83,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: torch.set_grad_enabled(self.prev) -class enable_grad(_DecoratorContextManager): +class enable_grad(_NoParamDecoratorContextManager): r"""Context-manager that enables gradient calculation. Enables gradient calculation, if it has been disabled via :class:`~no_grad` @@ -83,7 +92,7 @@ class enable_grad(_DecoratorContextManager): This context manager is thread local; it will not affect computation in other threads. - Also functions as a decorator. (Make sure to instantiate with parenthesis.) + Also functions as a decorator. .. note:: enable_grad is one of several mechanisms that can enable or @@ -111,6 +120,13 @@ class enable_grad(_DecoratorContextManager): ... z = doubler(x) >>> z.requires_grad True + >>> @torch.enable_grad + ... def tripler(x): + ... return x * 3 + >>> with torch.no_grad(): + ... z = tripler(x) + >>> z.requires_grad + True """ @@ -191,7 +207,7 @@ class inference_mode(_DecoratorContextManager): This context manager is thread local; it will not affect computation in other threads. - Also functions as a decorator. (Make sure to instantiate with parenthesis.) + Also functions as a decorator. .. note:: Inference mode is one of several mechanisms that can enable or @@ -199,7 +215,8 @@ class inference_mode(_DecoratorContextManager): more information on how they compare. Args: - mode (bool): Flag whether to enable or disable inference mode + mode_or_func (bool or function): Either a boolean flag whether to enable or disable inference mode + or a Python function to decorate with inference mode enabled Example:: >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD) @@ -220,6 +237,12 @@ class inference_mode(_DecoratorContextManager): >>> out = func(x) >>> out.requires_grad False + >>> @torch.inference_mode + ... def doubler(x): + ... return x * 2 + >>> out = doubler(x) + >>> out.requires_grad + False """ @@ -230,6 +253,11 @@ def __init__(self, mode: bool = True) -> None: self._inference_mode_raii_context: Optional[torch._C._InferenceMode] = None self.mode = mode + def __new__(cls, mode_or_func=True): + if isinstance(mode_or_func, bool): + return super().__new__(cls) + return cls()(mode_or_func) + def __enter__(self) -> None: self._inference_mode_context = torch._C._InferenceMode(self.mode) self._inference_mode_context.__enter__() diff --git a/torch/utils/_contextlib.py b/torch/utils/_contextlib.py index d34d2b9e15e70..f58235f5852a0 100644 --- a/torch/utils/_contextlib.py +++ b/torch/utils/_contextlib.py @@ -141,3 +141,12 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: def clone(self): # override this method if your children class takes __init__ parameters return self.__class__() + + +class _NoParamDecoratorContextManager(_DecoratorContextManager): + """Allow a context manager to be used as a decorator without parentheses""" + + def __new__(cls, orig_func=None): + if orig_func is None: + return super().__new__(cls) + return cls()(orig_func)