From 08382bbdc55bb165fcc28c7fb23b394269068bf5 Mon Sep 17 00:00:00 2001 From: Adam Belfki Date: Tue, 23 Jul 2024 16:04:50 -0400 Subject: [PATCH 1/2] Use meta device to load parameters --- src/nnsight/models/NNsightModel.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/nnsight/models/NNsightModel.py b/src/nnsight/models/NNsightModel.py index 32d32204..058772c4 100644 --- a/src/nnsight/models/NNsightModel.py +++ b/src/nnsight/models/NNsightModel.py @@ -66,9 +66,7 @@ def __init__( # Otherwise load from _load(...). if not self._custom_model: - # accelerate.init_empty_weights makes all parameters loaded on the 'meta' device. - # Also do .to('meta') because why not. - with accelerate.init_empty_weights(include_buffers=True): + with torch.device('meta'): self._model = self._load(self._model_key, *args, **kwargs).to("meta") self._envoy = Envoy(self._model) From 30b8409b9dae8572c3e8cb4437c87f0fc6414d72 Mon Sep 17 00:00:00 2001 From: Adam Belfki Date: Tue, 23 Jul 2024 16:06:01 -0400 Subject: [PATCH 2/2] Add support for more torch functions via patching --- src/nnsight/__init__.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/nnsight/__init__.py b/src/nnsight/__init__.py index bde37c2e..a0dff340 100644 --- a/src/nnsight/__init__.py +++ b/src/nnsight/__init__.py @@ -33,10 +33,9 @@ for key, value in getmembers(math, isbuiltin): DEFAULT_PATCHER.add(Patch(math, proxy_wrapper(value), key)) -# TODO THis does not work. Because of accelerate also patching? because they are overloaded? -#DEFAULT_PATCHER.add(Patch(torch, proxy_wrapper(torch.zeros), "zeros")) -# DEFAULT_PATCHER.add(Patch(torch, proxy_wrapper(torch.ones), "ones")) -# DEFAULT_PATCHER.add(Patch(torch, proxy_wrapper(torch.rand), "rand")) +DEFAULT_PATCHER.add(Patch(torch, proxy_wrapper(torch.zeros), "zeros")) +DEFAULT_PATCHER.add(Patch(torch, proxy_wrapper(torch.ones), "ones")) +DEFAULT_PATCHER.add(Patch(torch, proxy_wrapper(torch.rand), "rand")) from torch._subclasses.fake_tensor import FakeTensor