Skip to content

Commit

Permalink
Merge pull request #38 from JadenFiotto-Kaufman/proxy_operators
Browse files Browse the repository at this point in the history
More proxy operators
  • Loading branch information
JadenFiotto-Kaufman authored Jan 3, 2024
2 parents 7c6c777 + b6847ff commit a6a26e0
Showing 1 changed file with 55 additions and 2 deletions.
57 changes: 55 additions & 2 deletions src/nnsight/tracing/Proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, node: "Node") -> None:

def __getstate__(self):
return self.__dict__

def __setstate__(self, d: dict):
self.__dict__ = d

Expand All @@ -45,7 +45,6 @@ def __call__(self, *args, **kwargs) -> Proxy:
if self.node.args[0] is self.node.graph.module_proxy.node and not isinstance(
self.node.proxy_value, torch.nn.Module
):

value = self.node.proxy_value.__func__(
self.node.graph.module_proxy, *args, **kwargs
)
Expand Down Expand Up @@ -83,30 +82,78 @@ def __len__(self) -> Proxy:
args=[self.node],
)

def __abs__(self) -> Proxy:
return self.node.graph.add(
target=operator.abs,
args=[self.node],
)

def __invert__(self) -> Proxy:
return self.node.graph.add(
target=operator.invert,
args=[self.node],
)

def __add__(self, other: Union[Proxy, Any]) -> Proxy:
return self.node.graph.add(
target=operator.add,
args=[self.node, other],
)

def __radd__(self, other: Union[Proxy, Any]) -> Proxy:
return self.node.graph.add(
target=operator.add,
args=[self.node, other],
)

def __sub__(self, other: Union[Proxy, Any]) -> Proxy:
return self.node.graph.add(
target=operator.sub,
args=[self.node, other],
)

def __rsub__(self, other: Union[Proxy, Any]) -> Proxy:
return self.node.graph.add(
target=operator.sub,
args=[self.node, other],
)

def __pow__(self, other: Union[Proxy, Any]) -> Proxy:
return self.node.graph.add(
target=operator.pow,
args=[self.node, other],
)

def __rpow__(self, other: Union[Proxy, Any]) -> Proxy:
return self.node.graph.add(
target=operator.pow,
args=[self.node, other],
)

def __mul__(self, other: Union[Proxy, Any]) -> Proxy:
return self.node.graph.add(
target=operator.mul,
args=[self.node, other],
)

def __rmul__(self, other: Union[Proxy, Any]) -> Proxy:
return self.node.graph.add(
target=operator.mul,
args=[self.node, other],
)

def __mod__(self, other: Union[Proxy, Any]) -> Proxy:
return self.node.graph.add(
target=operator.mod,
args=[self.node, other],
)

def __rmod__(self, other: Union[Proxy, Any]) -> Proxy:
return self.node.graph.add(
target=operator.mod,
args=[self.node, other],
)

def __matmul__(self, other: Union[Proxy, Any]) -> Proxy:
return self.node.graph.add(
target=operator.matmul,
Expand All @@ -119,6 +166,12 @@ def __truediv__(self, other: Union[Proxy, Any]) -> Proxy:
args=[self.node, other],
)

def __rtruediv__(self, other: Union[Proxy, Any]) -> Proxy:
return self.node.graph.add(
target=operator.truediv,
args=[self.node, other],
)

def __bool__(self) -> bool:
return self.node.proxy_value.__bool__()

Expand Down

0 comments on commit a6a26e0

Please sign in to comment.