Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Graph Visualization 0.4 #283

Merged
merged 7 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions src/nnsight/intervention/contexts/tracer.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
import inspect
import weakref
from functools import wraps
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
TypeVar, Union)
from typing import Any, Callable, Dict, Optional, TypeVar, Union

from ...tracing.contexts import Tracer
from ...tracing.graph import Proxy
from ..graph import (InterventionNodeType, InterventionProxy,
InterventionProxyType)
from . import LocalContext
from ... import CONFIG


class InterventionTracer(Tracer[InterventionNodeType, InterventionProxyType]):
"""Extension of base Tracer to add additional intervention functionality and type hinting for intervention proxies.
"""
Expand Down Expand Up @@ -48,3 +44,18 @@ def inner(*args, **kwargs):

# TODO: error
pass

@classmethod
def style(cls) -> Dict[str, Any]:
"""Visualization style for this protocol node.

Returns:
- Dict: dictionary style.
"""

default_style = super().style()

default_style["node"] = {"color": "purple", "shape": "polygon", "sides": 6}
default_style["arg_kname"][1] = "method"

return default_style
34 changes: 22 additions & 12 deletions src/nnsight/intervention/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def context_dependency(
self,
context_node: InterventionNode,
intervention_subgraphs: List[SubGraph],
):
) -> None:

context_graph: SubGraph = context_node.args[0]

Expand All @@ -64,25 +64,26 @@ def context_dependency(

for intervention_subgraph in intervention_subgraphs:

if intervention_subgraph.subset[-1] < start:
continue

if intervention_subgraph.subset[0] > end:
# continue if the subgraph does not overlap with the context's graph
if intervention_subgraph.subset[-1] < start or end < intervention_subgraph.subset[0]:
continue

for intervention_index in intervention_subgraph.subset:

if intervention_index >= start and intervention_index <= end:
# if there's an overlapping node, make the context depend on the intervention node in the subgraph
if start <= intervention_index and intervention_index <= end:

# the first node in the subgraph is an InterventionProtocol node
intervention_node = intervention_subgraph[0]

context_node._dependencies.add(intervention_node.index)
intervention_node._listeners.add(context_node.index)
# TODO: maybe we don't need this
intervention_subgraph.subset.append(context_node.index)

break

def compile(self) -> None:
def compile(self) -> Optional[Dict[str, List[InterventionNode]]]:

if self.compiled:
return self.interventions
Expand All @@ -94,11 +95,14 @@ def compile(self) -> None:

start = self[0].index

# is the first node corresponding to an executable graph?
# occurs when a Conditional or Iterator context is explicitly entered by a user
if isinstance(self[0].target, type) and issubclass(
self[0].target, Context
):
graph = self[0].args[0]

# handle emtpy if statments or for loops
if len(graph) > 0:
start = graph[0].index

Expand All @@ -108,10 +112,12 @@ def compile(self) -> None:
defer_start: int = None
context_node: InterventionNode = None

# looping over all the nodes created within this graph's context
for index in range(start, end):

node: InterventionNodeType = self.nodes[index]

# is this node part of an inner context's subgraph?
if context_node is None and node.graph is not self:

context_node = self.nodes[node.graph[-1].index + 1]
Expand All @@ -124,6 +130,7 @@ def compile(self) -> None:

if node.target is InterventionProtocol:

# build intervention subgraph
subgraph = SubGraph(self, subset=sorted(list(node.subgraph())))

module_path, *_ = node.args
Expand All @@ -132,10 +139,13 @@ def compile(self) -> None:

intervention_subgraphs.append(subgraph)

# if the InterventionProtocol is defined within a sub-context
if context_node is not None:


# make the current context node dependent on this intervention node
context_node._dependencies.add(node.index)
node._listeners.add(context_node.index)
# TODO: maybe we don't need this
self.subset.append(node.index)

graph: SubGraph = node.graph
Expand All @@ -145,13 +155,13 @@ def compile(self) -> None:
node.kwargs["start"] = context_start
node.kwargs["defer_start"] = defer_start

node.graph = self

else:

node.kwargs["start"] = self.subset.index(subgraph.subset[0])
node.kwargs["defer_start"] = node.kwargs["start"]

node.graph = self

elif node.target is GradProtocol:

subgraph = SubGraph(self, subset=sorted(list(node.subgraph())))
Expand All @@ -172,12 +182,12 @@ def compile(self) -> None:

node.kwargs["start"] = context_start

node.graph = self

else:

node.kwargs["start"] = self.subset.index(subgraph.subset[1])

node.graph = self

elif node.target is ApplyModuleProtocol:

node.graph = self
Expand Down
29 changes: 21 additions & 8 deletions src/nnsight/intervention/protocols/grad.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Dict

import torch

from ...tracing.protocols import Protocol
from ...tracing.graph import GraphType

if TYPE_CHECKING:
from ..graph import InterventionNode, InterventionNodeType

Expand Down Expand Up @@ -36,11 +37,7 @@ def execute(cls, node: "InterventionNode") -> None:
hook = None

def grad(value):


# print(backwards_iteration, node)



# Set the value of the Node.
node.set_value(value)

Expand All @@ -57,4 +54,20 @@ def grad(value):
return value

# Register hook.
hook = tensor.register_hook(grad)
hook = tensor.register_hook(grad)

@classmethod
def style(cls) -> Dict[str, Any]:
"""Visualization style for this protocol node.

Returns:
- Dict: dictionary style.
"""

default_style = super().style()

default_style["node"] = {"color": "green4", "shape": "box"}

return default_style


22 changes: 19 additions & 3 deletions src/nnsight/intervention/protocols/intervention.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from typing import TYPE_CHECKING, Any, List

from typing import TYPE_CHECKING, Any, Dict

import torch
from ... import util
from .entrypoint import EntryPoint

if TYPE_CHECKING:
from ..graph import InterventionNodeType
from ..interleaver import Interleaver
from ..graph import InterventionNodeType, InterventionProxyType, InterventionGraph, InterventionProxy, InterventionNode

class InterventionProtocol(EntryPoint):

Expand Down Expand Up @@ -196,3 +195,20 @@ def narrow(acts: torch.Tensor):
def execute(cls, node: "InterventionNodeType"):
# To prevent the node from looking like its executed when calling Graph.execute
node.executed = False

@classmethod
def style(cls) -> Dict[str, Any]:
"""Visualization style for this protocol node.

Returns:
- Dict: dictionary style.
"""

default_style = super().style()

default_style["node"] = {"color": "green4", "shape": "box"}
default_style["arg_kname"][0] = "module_path"
default_style["arg_kname"][1] = "batch_group"
default_style["arg_kname"][2] = "call_counter"

return default_style
24 changes: 18 additions & 6 deletions src/nnsight/intervention/protocols/module.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from typing import TYPE_CHECKING, Any, Dict

import inspect
from typing import TYPE_CHECKING
import torch
from ...tracing.protocols import Protocol
from typing_extensions import Self
from ...tracing.graph import SubGraph

from ... import util
from ...tracing.protocols import Protocol

if TYPE_CHECKING:

from ..graph import InterventionProxyType, InterventionNode, InterventionGraph
from ..graph import InterventionGraph, InterventionNode

class ApplyModuleProtocol(Protocol):
"""Protocol that references some root model, and calls its .forward() method given some input.
Expand Down Expand Up @@ -92,3 +91,16 @@ def execute(cls, node: "InterventionNode") -> None:

node.set_value(output)

@classmethod
def style(cls) -> Dict[str, Any]:
"""Visualization style for this protocol node.

Returns:
- Dict: dictionary style.
"""

default_style = super().style()

default_style["node"] = {"color": "green4", "shape": "polygon", "sides": 6}

return default_style
21 changes: 16 additions & 5 deletions src/nnsight/intervention/protocols/swap.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Dict

import torch
from ...tracing.protocols import Protocol
from ... import util

if TYPE_CHECKING:
from ..graph import InterventionNodeType, InterventionGraph, InterventionProxyType
from ..graph import InterventionNodeType


class SwapProtocol(Protocol):
Expand All @@ -22,4 +21,16 @@ def execute(cls, node: "InterventionNodeType") -> None:

intervention_node.kwargs['swap'] = value


@classmethod
def style(cls) -> Dict[str, Any]:
"""Visualization style for this protocol node.

Returns:
- Dict: dictionary style.
"""

default_style = super().style()

default_style["node"] = {"color": "green4", "shape": "ellipse"}

return default_style
9 changes: 6 additions & 3 deletions src/nnsight/tracing/contexts/base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from __future__ import annotations

from contextlib import AbstractContextManager
from typing import Any, Callable, Generic, Optional, Type, Union
from typing import Generic, Optional, Type

from typing_extensions import Self

from ... import CONFIG
from ...tracing.graph import Node, NodeType, Proxy, ProxyType
from ..backends import Backend, ExecutionBackend
from ..graph import Graph, GraphType, SubGraph
from ..graph import Graph, GraphType, SubGraph, viz_graph
from ..protocols import Protocol
from ... import CONFIG

class Context(Protocol, AbstractContextManager, Generic[GraphType]):
"""A `Context` represents a scope (or slice) of a computation graph with specific logic for adding and executing nodes defined within it.
Expand Down Expand Up @@ -81,6 +81,9 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None:

self.backend(graph)

def vis(self, *args, **kwargs):
viz_graph(self.graph, *args, **kwargs)

@classmethod
def execute(cls, node: NodeType):

Expand Down
20 changes: 17 additions & 3 deletions src/nnsight/tracing/contexts/conditional.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from __future__ import annotations

from typing import Any, Optional
from typing import Any, Dict, Optional

from ...tracing.graph import NodeType, ProxyType, SubGraph
from ...tracing.graph import NodeType, SubGraph
from ..contexts import Context


class Condition(Context[SubGraph]):

def __init__(
self, condition: Any, branch: NodeType = False, *args, **kwargs
self, condition: Optional[NodeType], branch: Optional[NodeType] = None, *args, **kwargs
) -> None:
super().__init__(*args, **kwargs)

Expand All @@ -32,6 +32,7 @@ def execute(cls, node: NodeType):
condition: Any
condition, branch = node.prepare_inputs((condition, branch))

# else case has a True condition
if condition is None and not branch:
condition = True

Expand All @@ -45,5 +46,18 @@ def execute(cls, node: NodeType):
graph.clean()
node.set_value(branch)

@classmethod
def style(cls) -> Dict[str, Any]:
"""Visualization style for this protocol node.

Returns:
- Dict: dictionary style.
"""

default_style = super().style()

default_style["node"] = {"color": "#FF8C00", "shape": "polygon", "sides": 6}
default_style["edge"][2] = {"style": "solid", "label": "branch", "color": "#FF8C00", "fontsize": 10}

return default_style

Loading
Loading