Skip to content

Commit

Permalink
Merge branch 'dev' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
JadenFiotto-Kaufman committed Jan 31, 2025
2 parents 824c0a8 + eccb2fb commit 80ac87f
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 1,238 deletions.
16 changes: 16 additions & 0 deletions src/nnsight/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,22 @@ def inner(cls, fake_mode, elem, device, constant=None):
Patch(FakeTensor, fake_tensor_new_wrapper(FakeTensor.__new__), "__new__")
)

def autoamp_init_wrapper(fn):

@wraps(fn)
def inner(self, device_type, dtype=None, **kwargs):

if device_type == "meta":
dtype = torch.get_autocast_cpu_dtype()

return fn(self, device_type, dtype, **kwargs)

return inner

DEFAULT_PATCHER.add(
Patch(torch.autocast, autoamp_init_wrapper(torch.autocast.__init__), "__init__")
)

DEFAULT_PATCHER.__enter__()

from .intervention.contexts import GlobalInterventionTracingContext
Expand Down
356 changes: 0 additions & 356 deletions src/nnsight/models/LanguageModel.py

This file was deleted.

Loading

0 comments on commit 80ac87f

Please sign in to comment.