Skip to content

Commit

Permalink
Merge pull request #67 from JadenFiotto-Kaufman/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
JadenFiotto-Kaufman authored Jan 31, 2024
2 parents 9151fa4 + fbc9574 commit 98979b3
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/nnsight/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ def where(input: torch.Tensor, *args, **kwargs):
if isinstance(args[0], torch.Tensor)
else type(args[0])
)
if isinstance(args[0], torch.Tensor):
return torch.zeros_like(torch.broadcast_tensors(input, args[0])[0], dtype=input.dtype, device="meta")
return torch.zeros_like(input, dtype=input.dtype, device="meta")
return meta_nonzero(input, as_tuple=True)

Expand Down
10 changes: 8 additions & 2 deletions src/nnsight/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Module for utility functions and classes used throughout the package."""

import time
import types
from functools import wraps
from typing import Any, Callable, Collection, Type

Expand All @@ -20,7 +21,7 @@ def apply(data: Collection, fn: Callable, cls: Type) -> Collection:
"""
if isinstance(data, cls):
return fn(data)

data_type = type(data)

if data_type == list:
Expand Down Expand Up @@ -65,7 +66,12 @@ def wrap(object: object, wrapper: Type, *args, **kwargs) -> object:
if isinstance(object, wrapper):
return object

object.__class__ = type(object.__class__.__name__, (wrapper, object.__class__), {})
new_class = types.new_class(
object.__class__.__name__,
(object.__class__, wrapper),
)

object.__class__ = new_class

wrapper.__init__(object, *args, **kwargs)

Expand Down

0 comments on commit 98979b3

Please sign in to comment.