Skip to content

Commit

Permalink
grad_mode decorators without paren (pytorch#107086)
Browse files Browse the repository at this point in the history
This PR implements the feature described in pytorch#107036 for `no_grad`, `enable_grad` and `inference_mode`.

Users can still use the above as before but they can also use them without parentheses.

For example:

```python
import torch

a = torch.ones(1, requires_grad=True)

def do_something():
    print(2 * a)

with torch.no_grad():
    do_something()  # tensor([2.])

torch.no_grad()(do_something)()  # tensor([2.])

torch.no_grad(do_something)()  # tensor([2.])

do_something()  # tensor([2.], grad_fn=<MulBackward0>)
```

For `inference_mode`, decorating without parenthesis is equivalent to decorating with the default `mode=True`, similiar to how dataclasses behave (https://docs.python.org/3/library/dataclasses.html#module-contents)

Closes pytorch#107036

Pull Request resolved: pytorch#107086
Approved by: https://github.com/albanD
  • Loading branch information
andreasfloros authored and pytorchmergebot committed Aug 15, 2023
1 parent ba1da47 commit c9c9076
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 21 deletions.
45 changes: 31 additions & 14 deletions test/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1972,25 +1972,38 @@ def test_no_grad(self):
with torch.no_grad():
w = x + y

@torch.no_grad()
def adder(x, y):
return x + y

z = adder(x, y)
adders = [torch.no_grad()(adder), torch.no_grad(adder)]

self.assertFalse(w.requires_grad)
self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))
self.assertIsNone(w.grad_fn)
self.assertFalse(z.requires_grad)
self.assertRaises(RuntimeError, lambda: z.backward(torch.ones(5, 5)))
self.assertIsNone(z.grad_fn)
for adder in adders:
z = adder(x, y)

self.assertFalse(w.requires_grad)
self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))
self.assertIsNone(w.grad_fn)
self.assertFalse(z.requires_grad)
self.assertRaises(RuntimeError, lambda: z.backward(torch.ones(5, 5)))
self.assertIsNone(z.grad_fn)

# test nested decorator and with-statement on no_grad
with torch.no_grad():
self.assertFalse(torch.is_grad_enabled())
w = adder(x, y)
self.assertFalse(torch.is_grad_enabled())

def test_enable_grad_decorator_no_paren(self):
x = torch.ones(1, requires_grad=True)

@torch.enable_grad
def doubler(x):
return x * 2

with torch.no_grad():
z = doubler(x)
self.assertTrue(z.requires_grad)

def test_set_grad_generator_functions(self):
@torch.no_grad()
def gen_no_grad():
Expand Down Expand Up @@ -10160,15 +10173,19 @@ def test_inference_mode_context_manager(self):
self.assertFalse(torch.is_inference_mode_enabled())

def test_inference_mode_decorator(self):
for mode in (True, False):
@torch.inference_mode(mode)
def func(x):
self.assertEqual(torch.is_inference_mode_enabled(), mode)
return x * x
def func(x):
self.assertEqual(torch.is_inference_mode_enabled(), mode)
return x * x
for mode in (True, False, None):
if mode is None:
decorated = torch.inference_mode(func)
mode = True
else:
decorated = torch.inference_mode(mode)(func)

for requires_grad in (True, False):
c = torch.ones(1, 2, 3, requires_grad=requires_grad)
d = func(c)
d = decorated(c)
self.assertTrue(not mode or torch.is_inference(d))
self.assertEqual(d.requires_grad, requires_grad and not mode)

Expand Down
42 changes: 35 additions & 7 deletions torch/autograd/grad_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

import torch

from torch.utils._contextlib import _DecoratorContextManager
from torch.utils._contextlib import (
_DecoratorContextManager,
_NoParamDecoratorContextManager,
)

__all__ = [
"no_grad",
Expand All @@ -13,7 +16,7 @@
]


class no_grad(_DecoratorContextManager):
class no_grad(_NoParamDecoratorContextManager):
r"""Context-manager that disables gradient calculation.
Disabling gradient calculation is useful for inference, when you are sure
Expand All @@ -29,7 +32,7 @@ class no_grad(_DecoratorContextManager):
This context manager is thread local; it will not affect computation
in other threads.
Also functions as a decorator. (Make sure to instantiate with parenthesis.)
Also functions as a decorator.
.. note::
No-grad is one of several mechanisms that can enable or
Expand All @@ -54,6 +57,12 @@ class no_grad(_DecoratorContextManager):
>>> z = doubler(x)
>>> z.requires_grad
False
>>> @torch.no_grad
... def tripler(x):
... return x * 3
>>> z = tripler(x)
>>> z.requires_grad
False
>>> # factory function exception
>>> with torch.no_grad():
... a = torch.nn.Parameter(torch.rand(10))
Expand All @@ -74,7 +83,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
torch.set_grad_enabled(self.prev)


class enable_grad(_DecoratorContextManager):
class enable_grad(_NoParamDecoratorContextManager):
r"""Context-manager that enables gradient calculation.
Enables gradient calculation, if it has been disabled via :class:`~no_grad`
Expand All @@ -83,7 +92,7 @@ class enable_grad(_DecoratorContextManager):
This context manager is thread local; it will not affect computation
in other threads.
Also functions as a decorator. (Make sure to instantiate with parenthesis.)
Also functions as a decorator.
.. note::
enable_grad is one of several mechanisms that can enable or
Expand Down Expand Up @@ -111,6 +120,13 @@ class enable_grad(_DecoratorContextManager):
... z = doubler(x)
>>> z.requires_grad
True
>>> @torch.enable_grad
... def tripler(x):
... return x * 3
>>> with torch.no_grad():
... z = tripler(x)
>>> z.requires_grad
True
"""

Expand Down Expand Up @@ -191,15 +207,16 @@ class inference_mode(_DecoratorContextManager):
This context manager is thread local; it will not affect computation
in other threads.
Also functions as a decorator. (Make sure to instantiate with parenthesis.)
Also functions as a decorator.
.. note::
Inference mode is one of several mechanisms that can enable or
disable gradients locally see :ref:`locally-disable-grad-doc` for
more information on how they compare.
Args:
mode (bool): Flag whether to enable or disable inference mode
mode_or_func (bool or function): Either a boolean flag whether to enable or disable inference mode
or a Python function to decorate with inference mode enabled
Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
Expand All @@ -220,6 +237,12 @@ class inference_mode(_DecoratorContextManager):
>>> out = func(x)
>>> out.requires_grad
False
>>> @torch.inference_mode
... def doubler(x):
... return x * 2
>>> out = doubler(x)
>>> out.requires_grad
False
"""

Expand All @@ -230,6 +253,11 @@ def __init__(self, mode: bool = True) -> None:
self._inference_mode_raii_context: Optional[torch._C._InferenceMode] = None
self.mode = mode

def __new__(cls, mode_or_func=True):
if isinstance(mode_or_func, bool):
return super().__new__(cls)
return cls()(mode_or_func)

def __enter__(self) -> None:
self._inference_mode_context = torch._C._InferenceMode(self.mode)
self._inference_mode_context.__enter__()
Expand Down
9 changes: 9 additions & 0 deletions torch/utils/_contextlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,12 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
def clone(self):
# override this method if your children class takes __init__ parameters
return self.__class__()


class _NoParamDecoratorContextManager(_DecoratorContextManager):
"""Allow a context manager to be used as a decorator without parentheses"""

def __new__(cls, orig_func=None):
if orig_func is None:
return super().__new__(cls)
return cls()(orig_func)

0 comments on commit c9c9076

Please sign in to comment.