Commit

Merge pull request #179 from ndif-team/dev
Dev
JadenFiotto-Kaufman authored Jul 23, 2024
2 parents da36d3c + 041d18b commit c1f52c5
Showing 7 changed files with 94 additions and 15 deletions.
7 changes: 3 additions & 4 deletions src/nnsight/__init__.py
@@ -33,10 +33,9 @@
for key, value in getmembers(math, isbuiltin):
DEFAULT_PATCHER.add(Patch(math, proxy_wrapper(value), key))

# TODO THis does not work. Because of accelerate also patching? because they are overloaded?
#DEFAULT_PATCHER.add(Patch(torch, proxy_wrapper(torch.zeros), "zeros"))
# DEFAULT_PATCHER.add(Patch(torch, proxy_wrapper(torch.ones), "ones"))
# DEFAULT_PATCHER.add(Patch(torch, proxy_wrapper(torch.rand), "rand"))
DEFAULT_PATCHER.add(Patch(torch, proxy_wrapper(torch.zeros), "zeros"))
DEFAULT_PATCHER.add(Patch(torch, proxy_wrapper(torch.ones), "ones"))
DEFAULT_PATCHER.add(Patch(torch, proxy_wrapper(torch.rand), "rand"))

from torch._subclasses.fake_tensor import FakeTensor

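This hunk re-enables the torch.zeros, torch.ones, and torch.rand patches that had been commented out behind the TODO. A rough standalone sketch of what a proxy_wrapper-style patch does is below; the deferred-call behavior and the is_proxy check are illustrative assumptions, not nnsight's actual implementation.

# Illustrative sketch only: wrap a factory function so that calls involving a
# tracing proxy are deferred instead of executed eagerly.
import torch

def proxy_wrapper_sketch(fn):
    def inner(*args, **kwargs):
        if any(getattr(a, "is_proxy", False) for a in args):
            # nnsight would add a node to the intervention graph here; this
            # sketch just returns a marker describing the deferred call.
            return ("deferred", fn.__name__, args, kwargs)
        return fn(*args, **kwargs)
    return inner

patched_zeros = proxy_wrapper_sketch(torch.zeros)
print(patched_zeros(2, 3).shape)  # no proxy arguments -> plain torch.Size([2, 3])
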
4 changes: 3 additions & 1 deletion src/nnsight/intervention.py
@@ -107,6 +107,8 @@ def grad(self, value: Union[InterventionProxy, Any]) -> None:
"""
self.node.add(target="swap", args=[self.grad.node, value], value=True)

self.__dict__["_grad"] = None

def __call__(self, *args, **kwargs) -> Self:

# We don't want to call backward on fake tensors
@@ -148,7 +150,7 @@ def __setattr__(
) -> None:

if key == "grad":
getattr(self.__class__, key).fset(self, value)
return getattr(self.__class__, key).fset(self, value)

return super().__setattr__(key, value)

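Setting .grad on a proxy records a swap of the gradient value, and the new self.__dict__["_grad"] = None line appears to clear the cached gradient proxy so that a later .grad access picks up the swapped value (test_grad_setting below exercises exactly this). As a plain-PyTorch analogue, not the nnsight mechanism itself, a tensor hook that doubles an intermediate gradient looks like this:

# Plain-PyTorch analogue of doubling an intermediate gradient, roughly what
# `module.output.grad = module.output.grad * 2` expresses inside a trace.
import torch

x = torch.randn(3, requires_grad=True)
y = x * 2
y.register_hook(lambda g: g * 2)  # replace dL/dy with 2 * dL/dy
y.sum().backward()
print(x.grad)  # tensor([4., 4., 4.]): dL/dy = 1, doubled to 2, times dy/dx = 2
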
4 changes: 1 addition & 3 deletions src/nnsight/models/NNsightModel.py
@@ -66,9 +66,7 @@ def __init__(

# Otherwise load from _load(...).
if not self._custom_model:
# accelerate.init_empty_weights makes all parameters loaded on the 'meta' device.
# Also do .to('meta') because why not.
with accelerate.init_empty_weights(include_buffers=True):
with torch.device('meta'):
self._model = self._load(self._model_key, *args, **kwargs).to("meta")

self._envoy = Envoy(self._model)
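
The construction context switches from accelerate.init_empty_weights to torch.device('meta'). A minimal sketch of the behavior, assuming PyTorch >= 2.0 where torch.device works as a context manager:

# Modules built inside the meta-device context get shape and dtype metadata
# but no allocated storage, so even very large models are cheap to construct.
import torch

with torch.device("meta"):
    layer = torch.nn.Linear(4096, 4096)

print(layer.weight.device)  # meta
print(layer.weight.shape)   # torch.Size([4096, 4096]); no data is materialized
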
18 changes: 18 additions & 0 deletions src/nnsight/pydantics/format/types.py
@@ -86,6 +86,14 @@ def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> slice:
self.step.compile(graph, nodes),
)

class EllipsisModel(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
type_name: Literal["ELLIPSIS"] = "ELLIPSIS"

def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> type(...): # It will be better to use EllipsisType, but it requires python>=3.10
return ...



class ListModel(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -165,6 +173,14 @@ def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> FUNCTION:
),
]

EllipsisType = Annotated[
type(...), # It will be better to use EllipsisType, but it requires python>=3.10
AfterValidator(
lambda value: EllipsisModel()
),
]


ListType = Annotated[list, AfterValidator(lambda value: ListModel(values=value))]

TupleType = Annotated[
@@ -201,6 +217,7 @@ def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> FUNCTION:
ListModel,
TupleModel,
DictModel,
EllipsisModel
],
Field(discriminator="type_name"),
],
@@ -212,5 +229,6 @@ def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> FUNCTION:
ListType,
TupleType,
DictType,
EllipsisType
],
]
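
These hunks teach the serialization format about Ellipsis, so index expressions such as t[..., 0] that appear in a traced graph can be round-tripped. A standalone reconstruction of the pattern is below; Holder is a hypothetical stand-in for the models that embed these Annotated value types, not part of this commit.

# Standalone reconstruction of the Ellipsis serialization pattern above.
from typing import Annotated, Literal

from pydantic import AfterValidator, BaseModel, ConfigDict

class EllipsisModel(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    type_name: Literal["ELLIPSIS"] = "ELLIPSIS"

    def compile(self):
        return ...

EllipsisType = Annotated[type(...), AfterValidator(lambda value: EllipsisModel())]

class Holder(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    value: EllipsisType

h = Holder(value=...)
print(h.value)                   # an EllipsisModel instance
print(h.value.compile() is ...)  # True
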
13 changes: 7 additions & 6 deletions src/nnsight/tracing/Node.py
@@ -2,7 +2,8 @@

import inspect
import weakref
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Union)

import torch
from torch._subclasses.fake_tensor import FakeTensor
@@ -45,8 +46,6 @@ def __init__(
args = list()
if kwargs is None:
kwargs = dict()

args = list(args)

self.name = name
self.graph: "Graph" = graph
@@ -64,7 +63,9 @@

# Resolve values from completed tracer/runner contexts
self.args = util.apply(self.args, lambda x: x.value if x.done() else x, Node)
self.kwargs = util.apply(self.kwargs, lambda x: x.value if x.done() else x, Node)
self.kwargs = util.apply(
self.kwargs, lambda x: x.value if x.done() else x, Node
)

# Add all arguments that are nodes to nodes dependencies
# (unless the arg is already .done(), for when you want to apply things to proxies after model execution?)
@@ -187,8 +188,8 @@ def prepare_inputs(
Returns:
Any: Prepared inputs.
"""
inputs = util.apply(inputs, lambda x : x, object)

inputs = util.apply(inputs, lambda x: x, inspect._empty)

def _value(node: Proxy | Node):

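The Node changes route argument resolution through util.apply with different type filters. An illustrative sketch of what an apply-style helper over nested containers does (not nnsight's actual util.apply, just the general shape of the operation):

# Walk a nested structure and apply fn to every element of the given type.
from typing import Any, Callable

def apply(data: Any, fn: Callable[[Any], Any], cls: type) -> Any:
    if isinstance(data, cls):
        return fn(data)
    if isinstance(data, list):
        return [apply(x, fn, cls) for x in data]
    if isinstance(data, tuple):
        return tuple(apply(x, fn, cls) for x in data)
    if isinstance(data, dict):
        return {k: apply(v, fn, cls) for k, v in data.items()}
    return data

print(apply({"a": [1, 2], "b": (3, "x")}, lambda n: n * 10, int))
# {'a': [10, 20], 'b': (30, 'x')}
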
8 changes: 7 additions & 1 deletion src/nnsight/tracing/Proxy.py
@@ -221,13 +221,19 @@ def __rtruediv__(self, other: Union[Proxy, Any]) -> Self:
target=operator.truediv,
args=[other, self.node],
)

def __floordiv__(self, other: Union[Proxy, Any]) -> Self:
return self.node.add(
target=operator.floordiv,
args=[self.node, other],
)

def __rfloordiv__(self, other: Union[Proxy, Any]) -> Self:
return self.node.add(
target=operator.floordiv,
args=[other, self.node],
)

def __eq__(self, other: Union[Proxy, Any]) -> Self:
return self.node.add(target=operator.eq, args=[self.node, other])

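With __floordiv__ and __rfloordiv__ defined, proxy // other and other // proxy are recorded as operator.floordiv nodes instead of failing at trace time. A hedged usage sketch modeled on tests/test_tiny.py (the trace/save calls mirror that test; the floor division is the newly supported part):

from collections import OrderedDict

import torch

from nnsight import NNsight

net = torch.nn.Sequential(
    OrderedDict(
        [
            ("layer1", torch.nn.Linear(5, 10)),
            ("layer2", torch.nn.Linear(10, 2)),
        ]
    )
)
model = NNsight(net)

with model.trace(torch.rand((1, 5))):
    halved = (model.layer2.output // 2).save()  # newly supported operation

print(halved.value.shape)  # torch.Size([1, 2])
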
55 changes: 55 additions & 0 deletions tests/test_tiny.py
@@ -0,0 +1,55 @@
from collections import OrderedDict

import pytest
import torch

import nnsight
from nnsight import NNsight

input_size = 5
hidden_dims = 10
output_size = 2


@pytest.fixture(scope="module")
def tiny_model(device: str):

net = torch.nn.Sequential(
OrderedDict(
[
("layer1", torch.nn.Linear(input_size, hidden_dims)),
("layer2", torch.nn.Linear(hidden_dims, output_size)),
]
)
)

return NNsight(net).to(device)


@pytest.fixture
def tiny_input():
return torch.rand((1, input_size))


@torch.no_grad()
def test_tiny(tiny_model: NNsight, tiny_input: torch.Tensor):

with tiny_model.trace(tiny_input):

hs = tiny_model.layer2.output.save()

assert isinstance(hs.value, torch.Tensor)


def test_grad_setting(tiny_model: NNsight, tiny_input: torch.Tensor):
with tiny_model.trace(tiny_input):
l1_grad = tiny_model.layer1.output.grad.clone().save()

tiny_model.layer1.output.grad = tiny_model.layer1.output.grad.clone() * 2

l1_grad_double = tiny_model.layer1.output.grad.save()

loss = tiny_model.output.sum()
loss.backward()

assert torch.equal(l1_grad.value * 2, l1_grad_double.value)
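
The fixtures above depend on a device string fixture that is not defined in this file. A minimal conftest.py that would satisfy it might look like the sketch below; this is an assumption for running the test locally, not part of this commit.

# Assumed conftest.py sketch; the real test suite may configure the device
# differently (e.g. via a command-line option).
import pytest

@pytest.fixture(scope="module")
def device() -> str:
    return "cpu"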
