Merge pull request #103 from ndif-team/dev
Dev
JadenFiotto-Kaufman authored Apr 1, 2024
2 parents 180fbf5 + bdbf682 commit 6ca5db2
Showing 3 changed files with 27 additions and 10 deletions.
17 changes: 17 additions & 0 deletions src/nnsight/envoy.py
@@ -51,6 +51,23 @@ def __init__(module: torch.nn.Module, module_path: str = ""):

self._add_envoy(module, name)

def _update(self, module: torch.nn.Module) -> None:
"""Updates the ._model attribute using a new model of the same architecture.
Used when loading the real weights (dispatching) and need to replace the underlying modules.
"""

self._module = module

self._hook_handle.remove()

self._hook_handle = self._module.register_forward_hook(
self._hook, with_kwargs=True
)

for i, module in enumerate(self._module.children()):

self._sub_envoys[i]._update(module)

def _add_envoy(self, module: torch.nn.Module, name: str):

envoy = Envoy(module, module_path=f"{self._module_path}.{name}")
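The new `_update` method re-points an existing envoy tree at freshly loaded modules instead of rebuilding it, removing the old forward hook and registering a new one at every level. A minimal, self-contained sketch of that pattern, using a simplified, hypothetical `ModuleWrapper` in place of nnsight's actual `Envoy`:

```python
import torch

class ModuleWrapper:
    """Simplified, hypothetical stand-in for an Envoy-like wrapper (illustration only)."""

    def __init__(self, module: torch.nn.Module):
        self._module = module
        self._hook_handle = module.register_forward_hook(self._hook, with_kwargs=True)
        # One child wrapper per child module, in the same order as .children().
        self._sub_wrappers = [ModuleWrapper(child) for child in module.children()]

    def _hook(self, module, args, kwargs, output):
        # Placeholder; the real Envoy intercepts inputs and outputs here.
        return output

    def _update(self, module: torch.nn.Module) -> None:
        """Point this wrapper (and its children) at a new module of the same architecture."""
        self._module = module
        # Drop the hook on the old (e.g. meta-device) module and hook the new one.
        self._hook_handle.remove()
        self._hook_handle = self._module.register_forward_hook(self._hook, with_kwargs=True)
        # Recurse; child order must match because wrappers are paired by index.
        for i, child in enumerate(self._module.children()):
            self._sub_wrappers[i]._update(child)

# Usage: wrap a weightless shell, then later swap in the real modules.
with torch.device("meta"):
    shell = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
wrapper = ModuleWrapper(shell)
wrapper._update(torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()))
```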
11 changes: 5 additions & 6 deletions src/nnsight/models/NNsightModel.py
@@ -74,13 +74,12 @@ def __init__(
# Also do .to('meta') because why not.
with accelerate.init_empty_weights(include_buffers=True):
self._model = self._load(self._model_key, *args, **kwargs).to("meta")

self._envoy = Envoy(self._model)

if dispatch and not self._dispatched:
# Dispatch ._model on initialization vs lazy dispatching.
self.dispatch_model()

else:
self._envoy = Envoy(self._model)
self.dispatch_model()

logger.info(f"Initialized `{self._model_key}`")

@@ -267,13 +266,13 @@ def interleave(
return output

def dispatch_model(self, *args, **kwargs) -> None:
"""Dispatch ._model to have real parameters using .load(...)."""
"""Dispatch ._model to have real parameters using ._load(...)."""

logger.info(f"Dispatching `{self._model_key}`...")

self._model = self._load(self._model_key, *self._args, *args, **kwargs, **self._kwargs)

self._envoy = Envoy(self._model)
self._envoy._update(self._model)

self._dispatched = True

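With this change the Envoy is built once from the meta-initialized model in `__init__`, and `dispatch_model` updates it in place via `._update` instead of constructing a new Envoy, so references into the envoy tree survive dispatch. A small standalone sketch of the meta-then-dispatch flow it builds on (the `build_model` helper is hypothetical and merely stands in for `._load`):

```python
import accelerate
import torch

def build_model() -> torch.nn.Module:
    # Hypothetical stand-in for ._load(model_key, ...).
    return torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())

# Lazy path: materialize only the module structure, with no real parameter storage.
with accelerate.init_empty_weights(include_buffers=True):
    meta_model = build_model().to("meta")

print(next(meta_model.parameters()).is_meta)  # True: nothing allocated yet

# Dispatch path: load real parameters only when they are actually needed,
# then hand the new modules to the existing wrapper tree (see _update above).
real_model = build_model()
print(next(real_model.parameters()).is_meta)  # False: real tensors now exist
```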
9 changes: 5 additions & 4 deletions src/nnsight/tracing/Node.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import inspect
import weakref
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Union)

@@ -133,12 +134,12 @@ def __init__(
# (unless the arg is already .done(), for when you want to apply things to proxies after model execution?)
util.apply(
self.args,
lambda x: x.listeners.append(self) if not x.done() else None,
lambda x: x.listeners.append(weakref.proxy(self)) if not x.done() else None,
Node,
)
util.apply(
self.kwargs,
lambda x: x.listeners.append(self) if not x.done() else None,
lambda x: x.listeners.append(weakref.proxy(self)) if not x.done() else None,
Node,
)

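Appending `weakref.proxy(self)` instead of `self` means that being listed as a listener no longer keeps a node alive, so finished parts of the graph can be garbage collected. A minimal sketch of the weak-listener pattern, using a hypothetical `Source`/`Listener` pair rather than nnsight's `Node`:

```python
import weakref

class Listener:
    def on_event(self) -> None:
        print("event received")

class Source:
    def __init__(self):
        # Weak proxies: registering a listener does not extend its lifetime.
        self.listeners = []

    def register(self, listener: Listener) -> None:
        self.listeners.append(weakref.proxy(listener))

    def notify(self) -> None:
        for listener in self.listeners:
            try:
                listener.on_event()
            except ReferenceError:
                # The underlying listener has already been garbage collected.
                pass

source = Source()
listener = Listener()
source.register(listener)
source.notify()  # prints "event received"

del listener     # no strong reference remains; the proxy now dangles
source.notify()  # ReferenceError is caught and the dead listener is skipped
```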
@@ -290,7 +291,7 @@ def execute(self) -> None:

tensor: torch.Tensor = args[0]
backward_idx: int = args[1]

hook = None

def grad(value):
@@ -307,7 +308,7 @@ def grad(value):
value = self.graph.get_swap(value)

backward_idx = -1

hook.remove()

return value
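The second hunk assigns the hook handle before defining `grad` so the closure can detach itself after its first call. A standalone sketch of that one-shot gradient-hook pattern in plain PyTorch (not `Node.execute` itself):

```python
import torch

x = torch.ones(3, requires_grad=True)
y = (x * 2.0).sum()

hook = None  # assigned below; named first so the closure can call .remove()

def grad(value: torch.Tensor) -> torch.Tensor:
    print("got gradient:", value)
    hook.remove()  # one-shot: detach the hook after the first backward pass
    return value

hook = x.register_hook(grad)

y.backward()  # the hook fires exactly once and then removes itself
```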
