Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Streaming protocol #250

Closed
wants to merge 12 commits into from
100 changes: 93 additions & 7 deletions src/nnsight/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# #
# :::: ::: :::: ::: :::::::: ::::::::::: :::::::: ::: ::: ::::::::::: ::::::: :::::::: #
# :+:+: :+: :+:+: :+: :+: :+: :+: :+: :+: :+: :+: :+: :+: :+: :+: :+: #
Expand All @@ -8,10 +8,10 @@
# #+# #+#+# #+# #+#+# #+# #+# #+# #+# #+# #+# #+# #+# #+# #+# #+# #+# #+# #
# ### #### ### #### ######## ########### ######## ### ### ### ####### ### ######## #
# #
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
import os
from functools import wraps
from typing import Dict, Union
from typing import Callable, Dict, Union

import torch
import yaml
Expand Down Expand Up @@ -49,11 +49,11 @@
from torch._subclasses.fake_tensor import FakeTensor


def _bool(self):
def fake_bool(self):
    """Truth-value replacement registered for ``FakeTensor.__bool__``.

    Always reports the tensor as truthy, regardless of ``self``.

    Returns:
        bool: Always ``True``.
    """
    return True


DEFAULT_PATCHER.add(Patch(FakeTensor, _bool, "__bool__"))
DEFAULT_PATCHER.add(Patch(FakeTensor, fake_bool, "__bool__"))


def fake_tensor_new_wrapper(fn):
Expand Down Expand Up @@ -111,10 +111,11 @@ def noop(input: torch.Tensor, *args, **kwargs):
)

import warnings

_str = str
_bool = bool

try:



from torch.amp.autocast_mode import autocast, is_autocast_available

Expand Down Expand Up @@ -548,3 +549,88 @@ def set_module_tensor_to_device(
# Module-level aliases for methods of the global tracing context, so callers
# can use `nnsight.apply(...)`, `nnsight.log(...)` and `nnsight.cond(...)` directly.
apply = GlobalTracingContext.GLOBAL_TRACING_CONTEXT.apply
log = GlobalTracingContext.GLOBAL_TRACING_CONTEXT.log
cond = GlobalTracingContext.GLOBAL_TRACING_CONTEXT.cond

import inspect

from . import util
from .intervention import InterventionProxy


def trace(fn: Callable):
    """Helper decorator to add a function to the intervention graph via `.apply(...)`.

    This is opposed to entering the function during tracing and tracing all
    of its inner operations.

    Args:
        fn (Callable): Function to apply.

    Returns:
        Callable: Traceable function. It carries the metadata of `fn`
            (name, docstring, `__wrapped__`) courtesy of `functools.wraps`.
    """

    @wraps(fn)
    def inner(*args, **kwargs):
        # Hand the function and its arguments to the module-level `apply`
        # instead of invoking `fn` directly.
        return apply(fn, *args, **kwargs)

    return inner

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see this pattern defined in numerous locations (__init__.py, GraphBasedContext.py). Do you think this could be reused as a util function (with a descriptive docstring)?


def local(object: Union[Callable, InterventionProxy]):
    """Helper decorator to add a function to the intervention graph via `.apply(...)`
    AND convert all input Proxies to local ones via `.local()`.

    If a non-function is passed in, it's assumed to be an `InterventionProxy`
    and `.local()` is called on it and returned.

    Args:
        object (Callable | InterventionProxy): Function to apply or Proxy to make local.

    Returns:
        Callable | InterventionProxy: Traceable local function or local Proxy.
    """

    # `inspect.isroutine` (rather than `callable`) decides which branch runs:
    # anything that is not a function/method/builtin is treated as a Proxy.
    if inspect.isroutine(object):

        fn = trace(object)

        @wraps(fn)
        def inner(*args, **kwargs):

            # Replace every InterventionProxy found in the arguments with
            # the result of its `.local()` call before invoking `fn`.
            args, kwargs = util.apply(
                (args, kwargs), lambda x: x.local(), InterventionProxy
            )

            return fn(*args, **kwargs)

        return inner

    return object.local()


def remote(object: Union[Callable, Any]):
    """Helper decorator to add a function to the intervention graph via `.apply(...)`
    AND convert all input Proxies to downloaded local ones via `.local()`
    AND convert the output to an uploaded remote one via `remote()`.

    If a non-function is passed in, `remote(object)` is called and returned.

    Args:
        object (Callable | Any): Function to apply or object to make remote.

    Returns:
        Callable | InterventionProxy: Traceable local -> remote function or remote Proxy.
    """

    if inspect.isroutine(object):

        # `local` already wraps the function with `.apply(...)` and
        # localizes its Proxy inputs; only the output handling is added here.
        fn = local(object)

        @wraps(fn)
        def inner(*args, **kwargs):

            return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.remote(
                fn(*args, **kwargs)
            )

        return inner

    return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.remote(object)
12 changes: 12 additions & 0 deletions src/nnsight/contexts/GraphBasedContext.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,18 @@ def log(self, *data: Any) -> None:
data (Any): Data to print.
"""
self.apply(print, *data)

def remote(self, data: Any) -> InterventionProxy:
    """Streams data remotely when it becomes available locally.
    The remote service will block until the local value is uploaded and received.

    Is a no-op when not executing remotely.

    Args:
        data (Any): Data to stream.

    Returns:
        InterventionProxy: Proxy.
    """

    # Registers the data on this context's graph through the streaming upload protocol.
    return protocols.StreamingUploadProtocol.add(self.graph, data)

def bool(self, *args, **kwargs) -> InterventionProxy:
"""NNsight helper method to create a traceable bool."""
Expand Down
4 changes: 4 additions & 0 deletions src/nnsight/contexts/Tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
if TYPE_CHECKING:
from ..models.mixins import RemoteableMixin
from ..models.NNsightModel import NNsight
from ..tracing.Node import Node


class Tracer(GraphBasedContext, RemoteMixin, BridgeMixin, EditMixin):
Expand Down Expand Up @@ -179,6 +180,9 @@ def remote_backend_handle_result_value(self, value: Dict[str, Any]) -> None:
# TODO : graph mismatch handle. hash json ?
for node_name, node_value in value.items():
self.graph.nodes[node_name]._value = node_value

def remote_backend_get_stream_node(self, name: str, graph_id: str) -> "Node":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I get that the function name provides context for when the function is called (and perhaps the only actual use case), but why not use a more general name, like get_node()? It is just a getter function after all.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see, it overrides the function signature in RemoteBackend, and you addressed it with a comment. Is there a reason why this needs to be inherited from RemoteBackend, instead of from GraphBasedContext or just defined directly in this class? (If there is a good reason, I feel like it should be documented.)

return self.graph.nodes[name]

def remote_backend_cleanup(self):

Expand Down
Loading
Loading