Skip to content

Commit

Permalink
Merge pull request #178 from ndif-team/meta-device
Browse files Browse the repository at this point in the history
Meta device
  • Loading branch information
JadenFiotto-Kaufman authored Jul 23, 2024
2 parents dabe1c3 + 30b8409 commit 041d18b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 7 deletions.
7 changes: 3 additions & 4 deletions src/nnsight/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,9 @@
for key, value in getmembers(math, isbuiltin):
DEFAULT_PATCHER.add(Patch(math, proxy_wrapper(value), key))

# TODO THis does not work. Because of accelerate also patching? because they are overloaded?
#DEFAULT_PATCHER.add(Patch(torch, proxy_wrapper(torch.zeros), "zeros"))
# DEFAULT_PATCHER.add(Patch(torch, proxy_wrapper(torch.ones), "ones"))
# DEFAULT_PATCHER.add(Patch(torch, proxy_wrapper(torch.rand), "rand"))
DEFAULT_PATCHER.add(Patch(torch, proxy_wrapper(torch.zeros), "zeros"))
DEFAULT_PATCHER.add(Patch(torch, proxy_wrapper(torch.ones), "ones"))
DEFAULT_PATCHER.add(Patch(torch, proxy_wrapper(torch.rand), "rand"))

from torch._subclasses.fake_tensor import FakeTensor

Expand Down
4 changes: 1 addition & 3 deletions src/nnsight/models/NNsightModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,7 @@ def __init__(

# Otherwise load from _load(...).
if not self._custom_model:
# accelerate.init_empty_weights makes all parameters loaded on the 'meta' device.
# Also do .to('meta') because why not.
with accelerate.init_empty_weights(include_buffers=True):
with torch.device('meta'):
self._model = self._load(self._model_key, *args, **kwargs).to("meta")

self._envoy = Envoy(self._model)
Expand Down

0 comments on commit 041d18b

Please sign in to comment.