diff --git a/sharktank/sharktank/models/punet/tools/import_brevitas_dataset.py b/sharktank/sharktank/models/punet/tools/import_brevitas_dataset.py index 3cbf63dc8..22ca1f591 100644 --- a/sharktank/sharktank/models/punet/tools/import_brevitas_dataset.py +++ b/sharktank/sharktank/models/punet/tools/import_brevitas_dataset.py @@ -143,7 +143,7 @@ def _get_json_tensor( if qp.get("input_zp_dtype") is not None else "torch.int8" ) - quantization_dtype = tensors._serialized_name_to_dtype( + quantization_dtype = tensors.serialized_name_to_dtype( quantization_type.split(".")[-1] ) if output_scale is not None: diff --git a/sharktank/sharktank/types/layouts.py b/sharktank/sharktank/types/layouts.py index 54210da9b..586e4f673 100644 --- a/sharktank/sharktank/types/layouts.py +++ b/sharktank/sharktank/types/layouts.py @@ -22,8 +22,8 @@ register_quantized_layout, MetaDataValueType, QuantizedLayout, - _dtype_to_serialized_name, - _serialized_name_to_dtype, + dtype_to_serialized_name, + serialized_name_to_dtype, ) from .layout_utils import ( @@ -96,7 +96,7 @@ def create( m = planes.get("m") dtype_str = metadata.get("dtype") if dtype_str is not None: - dtype = _serialized_name_to_dtype(dtype_str) + dtype = serialized_name_to_dtype(dtype_str) else: # Backwards compat with old serialized. Emulate original behavior # before mixed precision. @@ -106,7 +106,7 @@ def create( @property def metadata(self) -> Optional[dict[str, MetaDataValueType]]: """Additional metadata needed to reconstruct a layout.""" - return {"dtype": _dtype_to_serialized_name(self._dtype)} + return {"dtype": dtype_to_serialized_name(self._dtype)} @property def planes(self) -> dict[str, torch.Tensor]: diff --git a/sharktank/sharktank/types/quantizers.py b/sharktank/sharktank/types/quantizers.py index 21f1c89ec..575c969de 100644 --- a/sharktank/sharktank/types/quantizers.py +++ b/sharktank/sharktank/types/quantizers.py @@ -38,8 +38,8 @@ QuantizedTensor, UnnamedTensorName, register_inference_tensor, - _serialized_name_to_dtype, - _dtype_to_serialized_name, + serialized_name_to_dtype, + dtype_to_serialized_name, ) __all__ = [ @@ -246,7 +246,7 @@ def create( raise IOError("Missing property") from e axis = int(extra_properties["axis"]) if "axis" in extra_properties else None disable_saturate = bool(extra_properties.get("disable_saturate")) - dtype = _serialized_name_to_dtype(dtype_name) + dtype = serialized_name_to_dtype(dtype_name) return cls( name=name, scale=scale, @@ -272,7 +272,7 @@ def add_to_archive(self, builder: ShardedArchiveBuilder) -> InferenceTensorMetad scale_name = f"{self.name}:scale" rscale_name = f"{self.name}:rscale" offset_name = f"{self.name}:offset" - extra_properties = {"dtype": _dtype_to_serialized_name(self._dtype)} + extra_properties = {"dtype": dtype_to_serialized_name(self._dtype)} if self._axis is not None: extra_properties["axis"] = self._axis if self._disable_saturate: @@ -388,7 +388,7 @@ def create( dtype_name = extra_properties["dtype"] except KeyError as e: raise IOError("Missing property") from e - dtype = _serialized_name_to_dtype(dtype_name) + dtype = serialized_name_to_dtype(dtype_name) return cls( name=name, dtype=dtype, @@ -400,7 +400,7 @@ def globals(self) -> dict[str, torch.Tensor]: def add_to_archive(self, builder: ShardedArchiveBuilder) -> InferenceTensorMetadata: """Adds this tensor to the global archive.""" - extra_properties = {"dtype": _dtype_to_serialized_name(self._dtype)} + extra_properties = {"dtype": dtype_to_serialized_name(self._dtype)} raw_tensors = {} return InferenceTensorMetadata( self.serialized_name(), diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index 13c64e8c2..221475371 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -40,6 +40,7 @@ __all__ = [ "AnyTensor", "DefaultPrimitiveTensor", + "dtype_to_serialized_name", "flatten_tensor_tree", "InferenceTensor", "MetaDataValueType", @@ -49,6 +50,7 @@ "QuantizedTensor", "register_quantized_layout", "ReplicatedTensor", + "serialized_name_to_dtype", "ShardedTensor", "SplitPrimitiveTensor", "torch_tree_flatten", @@ -1248,7 +1250,7 @@ def unbox_tensor(t: Any) -> Tensor: ######################################################################################## -def _dtype_to_serialized_name(dtype: torch.dtype) -> str: +def dtype_to_serialized_name(dtype: torch.dtype) -> str: try: return _DTYPE_TO_NAME[dtype] except KeyError as e: @@ -1257,7 +1259,7 @@ def _dtype_to_serialized_name(dtype: torch.dtype) -> str: ) from e -def _serialized_name_to_dtype(dtype_name: str) -> torch.dtype: +def serialized_name_to_dtype(dtype_name: str) -> torch.dtype: try: return _NAME_TO_DTYPE[dtype_name] except KeyError as e: