Skip to content

Commit

Permalink
Add env var option to override torch.Tensor.__repr__ (#384)
Browse files Browse the repository at this point in the history
The __repr__ for large tensors is way too slow and the debugger takes
forever to access them stalling for minutes at a time. This gives the
option to replace it.
  • Loading branch information
sogartar authored Nov 6, 2024
1 parent 8ff3c95 commit 0e253c6
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions sharktank/sharktank/types/tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from copy import deepcopy
from collections.abc import Collection, Sequence
from numbers import Integral, Number
import os

from abc import ABC, abstractmethod
from dataclasses import dataclass
Expand Down Expand Up @@ -55,6 +56,18 @@
"UnreducedTensor",
]

if (
"SHARKTANK_OVERRIDE_TORCH_TENSOR_REPR" in os.environ
and os.environ["SHARKTANK_OVERRIDE_TORCH_TENSOR_REPR"] != "0"
):

def _tensor_debugger_friendly_repr(self: torch.Tensor):
"""Override for the torch.Tensor.__repr__ so it does not take forever when the
debugger wants to query many/large tensors."""
return f"Tensor({list(self.shape)}, {self.dtype})"

Tensor.__repr__ = _tensor_debugger_friendly_repr

# JSON encodable value types.
MetaDataValueType = Union[int, bool, float, str]
UnnamedTensorName = "<unnamed>"
Expand Down

0 comments on commit 0e253c6

Please sign in to comment.