Skip to content

Commit

Permalink
Fix warning from torch.load starting in torch 2.4 (#1064)
Browse files Browse the repository at this point in the history
* Fix warning from torch.load starting in torch 2.4

See discussion in #1063

Starting from PyTorch 2.4, there is a warning when torch.load is called
without setting the weights_only argument. This is because in the
future, the default will switch from False to True, which can result in
a lot of errors when trying to load torch files (which are pickle files
and thus insecure).

In this PR, we add a possibility for the user to influence the kwargs
passed to torch.load so that they can control that behavior. If not
further indicated by the user, we will use the same defaults as the
installed torch version. Therefore, users will only encounter this issue
via skorch if they would have encountered it via torch anyway.

Since it's not 100% certain if the default will switch in torch 2.6.0,
we may have to adjust the version check in the future.

Besides directly testing the kwargs being passed on, a test was also
added that net.load_params does not give any warnings. This is already
indirectly tested through some accelerate tests that are currently
failing with torch 2.4, but it's better to have an explicit test.

After this is merged, the CI should pass when using torch 2.4.0.

* Reviewer feedback: return kwargs directly

* Reviewer feedback: One more test w/o monkeypatch

Instead, rely on the installed torch version and skip if it doesn't fit.

* Reviewer feedback: rename function, fix typo
  • Loading branch information
BenjaminBossan authored Sep 19, 2024
1 parent 9252477 commit e724424
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 1 deletion.
36 changes: 35 additions & 1 deletion skorch/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from skorch.utils import to_device
from skorch.utils import to_numpy
from skorch.utils import to_tensor
from skorch.utils import get_default_torch_load_kwargs


# pylint: disable=too-many-instance-attributes
Expand Down Expand Up @@ -235,6 +236,33 @@ class NeuralNet:
callbacks.
Implementation note: It is the job of the callbacks to honor this setting.
torch_load_kwargs : dict or None (default=None)
Additional arguments that will be passed to torch.load when load pickled
parameters.
In particular, this is important to because PyTorch will switch (probably
in version 2.6.0) to only allow weights to be loaded for security reasons
(i.e weights_only switches from False to True). As a consequence, loading
pickled parameters may raise an error after upgrading torch because some
types are used that are considered insecure. In skorch, we will also make
that switch at the same time. To resolve the error, follow the
instructions in the torch error message to designate the offending types
as secure. Only do this if you trust the source of the file.
If you want to keep loading non-weight types the same way as before,
please pass:
torch_load_kwargs={'weights_only': False}
You should be aware that this is considered insecure and should only be
used if you trust the source of the file. However, this does not introduce
new insecurities, it rather corresponds to the status quo from before
torch made the switch.
Another way to avoid this issue is to pass use_safetensors=True when
calling save_params and load_params. This avoid using pickle in favor of
the safetensors format, which is secure by design.
Attributes
----------
prefixes_ : list of str
Expand Down Expand Up @@ -311,6 +339,7 @@ def __init__(
device='cpu',
compile=False,
use_caching='auto',
torch_load_kwargs=None,
**kwargs
):
self.module = module
Expand All @@ -330,6 +359,7 @@ def __init__(
self.device = device
self.compile = compile
self.use_caching = use_caching
self.torch_load_kwargs = torch_load_kwargs

self._check_deprecated_params(**kwargs)
history = kwargs.pop('history', None)
Expand Down Expand Up @@ -2620,10 +2650,14 @@ def _get_state_dict(f_name):

return state_dict
else:
torch_load_kwargs = self.torch_load_kwargs
if torch_load_kwargs is None:
torch_load_kwargs = get_default_torch_load_kwargs()

def _get_state_dict(f_name):
map_location = get_map_location(self.device)
self.device = self._check_device(self.device, map_location)
return torch.load(f_name, map_location=map_location)
return torch.load(f_name, map_location=map_location, **torch_load_kwargs)

kwargs_full = {}
if checkpoint is not None:
Expand Down
108 changes: 108 additions & 0 deletions skorch/tests/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from unittest.mock import patch
import sys
import time
import warnings
from contextlib import ExitStack

from flaky import flaky
Expand All @@ -30,6 +31,7 @@
import torch
from torch import nn

import skorch
from skorch.tests.conftest import INFERENCE_METHODS
from skorch.utils import flatten
from skorch.utils import to_numpy
Expand Down Expand Up @@ -561,6 +563,17 @@ def test_load_params_unknown_attribute_raises(self, net_fit):
with pytest.raises(AttributeError, match=msg):
net_fit.load_params(f_unknown='some-file.pt')

def test_load_params_no_warning(self, net_fit, tmp_path, recwarn):
# See discussion in 1063
# Ensure that there is no FutureWarning (and DeprecationWarning for good
# measure) caused by torch.load.
net_fit.save_params(f_params=tmp_path / 'weights.pt')
net_fit.load_params(f_params=tmp_path / 'weights.pt')
assert not any(
isinstance(warning.message, (DeprecationWarning, FutureWarning))
for warning in recwarn.list
)

@pytest.mark.parametrize('use_safetensors', [False, True])
def test_save_load_state_dict_file(
self, net_cls, module_cls, net_fit, data, tmpdir, use_safetensors):
Expand Down Expand Up @@ -2983,6 +2996,101 @@ def test_save_load_state_dict_custom_module(
weights_loaded = net_new.custom_.state_dict()['sequential.3.weight']
assert (weights_before == weights_loaded).all()

def test_torch_load_kwargs_auto_weights_only_false_when_load_params(
self, net_cls, module_cls, monkeypatch, tmp_path
):
# Here we assume that the torch version is low enough that weights_only
# defaults to False. Check that when no argument is set in skorch, the
# right default is used.
# See discussion in 1063
net = net_cls(module_cls).initialize()
net.save_params(f_params=tmp_path / 'params.pkl')
state_dict = net.module_.state_dict()
expected_kwargs = {"weights_only": False}

mock_torch_load = Mock(return_value=state_dict)
monkeypatch.setattr(torch, "load", mock_torch_load)
monkeypatch.setattr(
skorch.net, "get_default_torch_load_kwargs", lambda: expected_kwargs
)

net.load_params(f_params=tmp_path / 'params.pkl')

call_kwargs = mock_torch_load.call_args_list[0].kwargs
del call_kwargs['map_location'] # we're not interested in that
assert call_kwargs == expected_kwargs

def test_torch_load_kwargs_auto_weights_only_true_when_load_params(
self, net_cls, module_cls, monkeypatch, tmp_path
):
# Here we assume that the torch version is high enough that weights_only
# defaults to True. Check that when no argument is set in skorch, the
# right default is used.
# See discussion in 1063
net = net_cls(module_cls).initialize()
net.save_params(f_params=tmp_path / 'params.pkl')
state_dict = net.module_.state_dict()
expected_kwargs = {"weights_only": True}

mock_torch_load = Mock(return_value=state_dict)
monkeypatch.setattr(torch, "load", mock_torch_load)
monkeypatch.setattr(
skorch.net, "get_default_torch_load_kwargs", lambda: expected_kwargs
)

net.load_params(f_params=tmp_path / 'params.pkl')

call_kwargs = mock_torch_load.call_args_list[0].kwargs
del call_kwargs['map_location'] # we're not interested in that
assert call_kwargs == expected_kwargs

def test_torch_load_kwargs_forwarded_to_torch_load(
self, net_cls, module_cls, monkeypatch, tmp_path
):
# Here we check that custom set torch load args are forwarded to
# torch.load.
# See discussion in 1063
expected_kwargs = {'weights_only': 123, 'foo': 'bar'}
net = net_cls(module_cls, torch_load_kwargs=expected_kwargs).initialize()
net.save_params(f_params=tmp_path / 'params.pkl')
state_dict = net.module_.state_dict()

mock_torch_load = Mock(return_value=state_dict)
monkeypatch.setattr(torch, "load", mock_torch_load)

net.load_params(f_params=tmp_path / 'params.pkl')

call_kwargs = mock_torch_load.call_args_list[0].kwargs
del call_kwargs['map_location'] # we're not interested in that
assert call_kwargs == expected_kwargs

def test_torch_load_kwargs_auto_weights_false_pytorch_lt_2_6(
self, net_cls, module_cls, monkeypatch, tmp_path
):
# Same test as
# test_torch_load_kwargs_auto_weights_only_false_when_load_params but
# without monkeypatching get_default_torch_load_kwargs. There is no
# corresponding test for >= 2.6.0 since it's not clear yet if the switch
# will be made in that version.
# See discussion in 1063.
from skorch._version import Version

if Version(torch.__version__) >= Version('2.6.0'):
pytest.skip("Test only for torch < v2.6.0")

net = net_cls(module_cls).initialize()
net.save_params(f_params=tmp_path / 'params.pkl')
state_dict = net.module_.state_dict()
expected_kwargs = {"weights_only": False}

mock_torch_load = Mock(return_value=state_dict)
monkeypatch.setattr(torch, "load", mock_torch_load)
net.load_params(f_params=tmp_path / 'params.pkl')

call_kwargs = mock_torch_load.call_args_list[0].kwargs
del call_kwargs['map_location'] # we're not interested in that
assert call_kwargs == expected_kwargs

def test_custom_module_params_passed_to_optimizer(
self, net_custom_module_cls, module_cls):
# custom module parameters should automatically be passed to the optimizer
Expand Down
16 changes: 16 additions & 0 deletions skorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from skorch.exceptions import DeviceWarning
from skorch.exceptions import NotInitializedError
from ._version import Version

try:
import torch_geometric
Expand Down Expand Up @@ -768,3 +769,18 @@ def _check_f_arguments(caller_name, **kwargs):
key = 'module_' if key == 'f_params' else key[2:] + '_'
kwargs_module[key] = val
return kwargs_module, kwargs_other


def get_default_torch_load_kwargs():
"""Returns the kwargs passed to torch.load that correspond to the current
torch version.
The plan is to switch from weights_only=False to True in PyTorch version
2.6.0, but depending on what happens, this may require updating.
"""
version_torch = Version(torch.__version__)
version_default_switch = Version('2.6.0')
if version_torch >= version_default_switch:
return {"weights_only": True}
return {"weights_only": False}

0 comments on commit e724424

Please sign in to comment.