Lint fixes test folders#1 (#1393)

Closed · wants to merge 1 commit
7 changes: 5 additions & 2 deletions ruff.toml
@@ -7,15 +7,18 @@ include = [
"torchao/quantization/**/*.py",
"torchao/dtypes/**/*.py",
"torchao/sparsity/**/*.py",
"torchao/profiler/**/*.py",
"torchao/testing/**/*.py",
"torchao/prototype/low_bit_optim/**.py",
"torchao/utils.py",
"torchao/ops.py",
"torchao/_executorch_ops.py",
# Test folders
"test/dora/**/*.py",
"test/dtypes/**/*.py",
"test/float8/**/*.py",
"test/galore/**/*.py",
"test/hqq/**/*.py",
"test/quantization/**/*.py",
"test/dtypes/**/*.py",
"test/sparsity/**/*.py",
"test/prototype/low_bit_optim/**.py",
]
20 changes: 6 additions & 14 deletions test/dora/test_dora_fusion.py
@@ -1,19 +1,17 @@
import itertools
import sys

import pytest
import torch

from torchao.prototype.dora.kernels.matmul import triton_mm
from torchao.prototype.dora.kernels.smallk import triton_mm_small_k

if sys.version_info < (3, 11):
pytest.skip("requires Python >= 3.11", allow_module_level=True)

triton = pytest.importorskip("triton", reason="requires triton")

import itertools

import torch

from torchao.prototype.dora.kernels.matmul import triton_mm
from torchao.prototype.dora.kernels.smallk import triton_mm_small_k

torch.manual_seed(0)

# Test configs
@@ -48,13 +46,7 @@ def _arg_to_id(arg):


def check(expected, actual, dtype):
if dtype == torch.float32:
atol = 1e-4
elif dtype == torch.float16:
atol = 1e-3
elif dtype == torch.bfloat16:
atol = 1e-2
else:
if dtype not in [torch.float32, torch.float16, torch.bfloat16]:
raise ValueError(f"Unsupported dtype: {dtype}")
diff = (expected - actual).abs().max()
print(f"diff: {diff}")
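For context, the import layout these DoRA tests converge on is: standard-library and third-party imports grouped at the top (satisfying ruff's import-sorting rules), with the module-level skip guards immediately after. A minimal sketch of the pattern, not the exact file contents:

```python
import sys

import pytest
import torch

from torchao.prototype.dora.kernels.matmul import triton_mm

# Module-level guards: if either condition fails, pytest skips the whole
# file at collection time instead of reporting import errors per test.
if sys.version_info < (3, 11):
    pytest.skip("requires Python >= 3.11", allow_module_level=True)

triton = pytest.importorskip("triton", reason="requires triton")
```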
16 changes: 5 additions & 11 deletions test/dora/test_dora_layer.py
@@ -1,32 +1,26 @@
import itertools
import sys

import pytest
import torch

from torchao.prototype.dora.dora_layer import BNBDoRALinear, DoRALinear, HQQDoRALinear

if sys.version_info < (3, 11):
pytest.skip("requires Python >= 3.11", allow_module_level=True)

bnbnn = pytest.importorskip("bitsandbytes.nn", reason="requires bitsandbytes")
hqq_core = pytest.importorskip("hqq.core.quantize", reason="requires hqq")

import itertools

import torch

# Import modules as opposed to classes directly, otherwise pytest.importorskip always skips
Linear4bit = bnbnn.Linear4bit
BaseQuantizeConfig = hqq_core.BaseQuantizeConfig
HQQLinear = hqq_core.HQQLinear
from torchao.prototype.dora.dora_layer import BNBDoRALinear, DoRALinear, HQQDoRALinear


def check(expected, actual, dtype):
if dtype == torch.float32:
atol = 1e-4
elif dtype == torch.float16:
atol = 1e-3
elif dtype == torch.bfloat16:
atol = 1e-2
else:
if dtype not in [torch.float32, torch.float16, torch.bfloat16]:
raise ValueError(f"Unsupported dtype: {dtype}")
diff = (expected - actual).abs().max()
print(f"diff: {diff}")
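The comment about importing modules rather than classes refers to the pattern below: `pytest.importorskip` returns the imported module (or skips the whole test file when the optional dependency is missing), and the classes the tests need are then bound from that module object. A minimal sketch:

```python
import pytest

# importorskip returns the module object, or skips the entire test module
# when the optional dependency is not installed.
bnbnn = pytest.importorskip("bitsandbytes.nn", reason="requires bitsandbytes")
hqq_core = pytest.importorskip("hqq.core.quantize", reason="requires hqq")

# Bind the needed classes from the modules that importorskip returned.
Linear4bit = bnbnn.Linear4bit
BaseQuantizeConfig = hqq_core.BaseQuantizeConfig
HQQLinear = hqq_core.HQQLinear
```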
2 changes: 1 addition & 1 deletion test/galore/memory_analysis_utils.py
@@ -50,7 +50,7 @@ def _convert_to_units(df, col):
convert_cols_to_MB = {col: partial(_convert_to_units, col=col) for col in COL_NAMES}

df = pd.DataFrame(
[l[1:] for l in df.iloc[:, 1].to_list()], columns=COL_NAMES
[row[1:] for row in df.iloc[:, 1].to_list()], columns=COL_NAMES
).assign(**convert_cols_to_MB)
df["Total"] = df.sum(axis=1)
return df
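The only change in this file renames the loop variable `l` to `row`; single-letter names like `l` trip ruff's ambiguous-variable-name rule (E741) because they are easily confused with `1` and `I`. A small standalone sketch of the same comprehension, with illustrative data and column names:

```python
import pandas as pd

# Hypothetical records shaped like the profiler rows: drop the first field, keep the rest.
records = [("step0", 1.0, 2.0), ("step1", 3.0, 4.0)]
df = pd.DataFrame([row[1:] for row in records], columns=["Parameters", "Gradients"])
```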
4 changes: 2 additions & 2 deletions test/galore/profile_memory_usage.py
@@ -86,7 +86,7 @@ def run(args, file_prefix):
model_config = LlamaConfig()
try:
model_config_dict = getattr(model_configs, args.model_config.upper())
except:
except Exception:
raise ValueError(f"Model config {args.model_config} not found")
model_config.update(model_config_dict)
model = LlamaForCausalLM(model_config).to("cuda")
@@ -163,7 +163,7 @@ def run(args, file_prefix):
if args.torch_profiler:
print(f"Finished profiling, outputs saved to {args.output_dir}/{file_prefix}*")
else:
print(f"Finished profiling")
print("Finished profiling")


if __name__ == "__main__":
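The two fixes in this file map to ruff rules E722 (bare `except`) and F541 (f-string without placeholders). A minimal sketch of the exception-handling pattern, with hypothetical `model_configs` and `name` standing in for the script's objects:

```python
import types

# Hypothetical stand-ins for the script's model_configs module and CLI argument.
model_configs = types.SimpleNamespace(LLAMA100M={"hidden_size": 768})
name = "llama100M"

try:
    model_config_dict = getattr(model_configs, name.upper())
except Exception:  # a bare `except:` would also trap SystemExit and KeyboardInterrupt
    raise ValueError(f"Model config {name} not found")

print("Finished profiling")  # no placeholders, so the f-string prefix was redundant
```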
2 changes: 0 additions & 2 deletions test/galore/profiling_utils.py
@@ -75,7 +75,6 @@ def get_cuda_memory_usage(units="MB", show=True):


def export_memory_snapshot(prefix) -> None:

# Prefix for file names.
timestamp = datetime.now().strftime(TIME_FORMAT_STR)
file_prefix = f"{prefix}_{timestamp}"
@@ -115,7 +114,6 @@ def trace_handler(
export_memory_timeline=True,
print_table=True,
):

timestamp = datetime.now().strftime(TIME_FORMAT_STR)
file_prefix = os.path.join(output_dir, f"{prefix}_{timestamp}")

113 changes: 79 additions & 34 deletions test/hqq/test_hqq_affine.py
@@ -1,91 +1,136 @@
import unittest

import torch

from torchao.quantization import (
ZeroPointDomain,
MappingType,
ZeroPointDomain,
int4_weight_only,
uintx_weight_only,
)

from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
)
from torchao.quantization import (
uintx_weight_only,
int4_weight_only,
)

cuda_available = torch.cuda.is_available()

#Parameters
device = 'cuda:0'
compute_dtype = torch.bfloat16
group_size = 64
mapping_type = MappingType.ASYMMETRIC
block_size = (1, group_size) #axis=1
preserve_zero = False
# Parameters
device = "cuda:0"
compute_dtype = torch.bfloat16
group_size = 64
mapping_type = MappingType.ASYMMETRIC
block_size = (1, group_size) # axis=1
preserve_zero = False
zero_point_domain = ZeroPointDomain.FLOAT
zero_point_dtype = compute_dtype
inner_k_tiles = 8
in_features = 4096
out_features = 11800
torch_seed = 100
zero_point_dtype = compute_dtype
inner_k_tiles = 8
in_features = 4096
out_features = 11800
torch_seed = 100


def _init_data(in_features, out_features, compute_dtype, device, torch_seed):
torch.random.manual_seed(torch_seed)
linear_layer = torch.nn.Linear(in_features, out_features, bias=False).to(device)
x = torch.randn((1, linear_layer.in_features), dtype=torch.float, device=device)/20.
x = (
torch.randn((1, linear_layer.in_features), dtype=torch.float, device=device)
/ 20.0
)
y_ref = linear_layer(x)
W = linear_layer.weight.data.clone().to(device=device, dtype=compute_dtype)
return W, x, y_ref


def _eval_hqq(dtype):
W, x, y_ref = _init_data(in_features, out_features, compute_dtype, device, torch_seed)
W, x, y_ref = _init_data(
in_features, out_features, compute_dtype, device, torch_seed
)

dummy_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, bias=False)
dummy_linear = torch.nn.Linear(
in_features=in_features, out_features=out_features, bias=False
)
dummy_linear.weight.data = W
if dtype == torch.uint4:
q_tensor_hqq = int4_weight_only(group_size=max(block_size), use_hqq=True)(dummy_linear).weight
q_tensor_hqq = int4_weight_only(group_size=max(block_size), use_hqq=True)(
dummy_linear
).weight
else:
q_tensor_hqq = uintx_weight_only(dtype, group_size=max(block_size), use_hqq=True)(dummy_linear).weight
q_tensor_hqq = uintx_weight_only(
dtype, group_size=max(block_size), use_hqq=True
)(dummy_linear).weight

quant_linear_layer = torch.nn.Linear(W.shape[1], W.shape[0], bias=False, device=W.device)
quant_linear_layer = torch.nn.Linear(
W.shape[1], W.shape[0], bias=False, device=W.device
)
del quant_linear_layer.weight
quant_linear_layer.weight = q_tensor_hqq
dequantize_error = (W - q_tensor_hqq.dequantize()).abs().mean().item()
dot_product_error = (y_ref - quant_linear_layer(x.to(compute_dtype))).abs().mean().item()
dot_product_error = (
(y_ref - quant_linear_layer(x.to(compute_dtype))).abs().mean().item()
)

return dequantize_error, dot_product_error


@unittest.skipIf(not cuda_available, "Need CUDA available")
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "Need torch 2.3+")
class TestHQQ(unittest.TestCase):
def _test_hqq(self, dtype=None, ref_dequantize_error=None, ref_dot_product_error=None):
if(dtype is None): return
def _test_hqq(
self, dtype=None, ref_dequantize_error=None, ref_dot_product_error=None
):
if dtype is None:
return
dequantize_error, dot_product_error = _eval_hqq(dtype)
self.assertTrue(dequantize_error < ref_dequantize_error)
self.assertTrue(dot_product_error < ref_dot_product_error)

def test_hqq_plain_8bit(self):
self._test_hqq(dtype=torch.uint8, ref_dequantize_error=5e-5, ref_dot_product_error=0.00013)
self._test_hqq(
dtype=torch.uint8, ref_dequantize_error=5e-5, ref_dot_product_error=0.00013
)

def test_hqq_plain_7bit(self):
self._test_hqq(dtype=torch.uint7, ref_dequantize_error=6e-05, ref_dot_product_error=0.000193)
self._test_hqq(
dtype=torch.uint7,
ref_dequantize_error=6e-05,
ref_dot_product_error=0.000193,
)

def test_hqq_plain_6bit(self):
self._test_hqq(dtype=torch.uint6, ref_dequantize_error=0.0001131, ref_dot_product_error=0.000353)
self._test_hqq(
dtype=torch.uint6,
ref_dequantize_error=0.0001131,
ref_dot_product_error=0.000353,
)

def test_hqq_plain_5bit(self):
self._test_hqq(dtype=torch.uint5, ref_dequantize_error=0.00023, ref_dot_product_error=0.000704)
self._test_hqq(
dtype=torch.uint5,
ref_dequantize_error=0.00023,
ref_dot_product_error=0.000704,
)

def test_hqq_plain_4bit(self):
self._test_hqq(dtype=torch.uint4, ref_dequantize_error=0.000487, ref_dot_product_error=0.001472)
self._test_hqq(
dtype=torch.uint4,
ref_dequantize_error=0.000487,
ref_dot_product_error=0.001472,
)

def test_hqq_plain_3bit(self):
self._test_hqq(dtype=torch.uint3, ref_dequantize_error=0.00101, ref_dot_product_error=0.003047)
self._test_hqq(
dtype=torch.uint3,
ref_dequantize_error=0.00101,
ref_dot_product_error=0.003047,
)

def test_hqq_plain_2bit(self):
self._test_hqq(dtype=torch.uint2, ref_dequantize_error=0.002366, ref_dot_product_error=0.007255)
self._test_hqq(
dtype=torch.uint2,
ref_dequantize_error=0.002366,
ref_dot_product_error=0.007255,
)


if __name__ == "__main__":
unittest.main()
17 changes: 9 additions & 8 deletions test/hqq/test_triton_mm.py
@@ -1,20 +1,21 @@
# Skip entire test if following module not available, otherwise CI failure
import itertools

import pytest
import torch

from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm

triton = pytest.importorskip(
"triton", minversion="3.0.0", reason="Triton > 3.0.0 required to run this test"
)
hqq = pytest.importorskip("hqq", reason="hqq required to run this test")
hqq_quantize = pytest.importorskip("hqq.core.quantize", reason="hqq required to run this test")
hqq_quantize = pytest.importorskip(
"hqq.core.quantize", reason="hqq required to run this test"
)
HQQLinear = hqq_quantize.HQQLinear
BaseQuantizeConfig = hqq_quantize.BaseQuantizeConfig

import itertools

import torch

from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm

# Test configs
SHAPES = [
[16, 128, 128],
@@ -96,7 +97,7 @@ def test_mixed_mm(
W_q = W_q.to(dtype=quant_dtype)
W_q = (
W_q.reshape(meta["shape"])
if quant_config["weight_quant_params"]["bitpack"] == False
if not quant_config["weight_quant_params"]["bitpack"]
else W_q
)
W_dq = hqq_linear.dequantize()
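The `== False` change corresponds to ruff's E712 rule (comparison to `False`); a plain truthiness check is equivalent for a boolean flag and reads more naturally. A minimal sketch with a hypothetical config dict:

```python
# Hypothetical quant_config mirroring the shape used in the test.
quant_config = {"weight_quant_params": {"bitpack": False}}

# Preferred over `... == False` (ruff E712): rely on the boolean directly.
if not quant_config["weight_quant_params"]["bitpack"]:
    print("weights are not bit-packed; reshape to the stored meta['shape']")
```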
14 changes: 6 additions & 8 deletions test/hqq/test_triton_qkv_fused.py
@@ -1,4 +1,9 @@
import itertools

import pytest
import torch

from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm

triton = pytest.importorskip(
"triton", minversion="3.0.0", reason="Triton > 3.0.0 required to run this test"
@@ -10,13 +15,6 @@
HQQLinear = hqq_quantize.HQQLinear
BaseQuantizeConfig = hqq_quantize.BaseQuantizeConfig

import itertools

import torch
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear, Quantizer

from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm

torch.manual_seed(0)
# N, K = shape
Q_SHAPES = [[4096, 4096]]
@@ -60,7 +58,7 @@ def quantize_helper(
W_q = W_q.to(dtype=quant_dtype)
W_q = (
W_q.reshape(meta["shape"])
if quant_config["weight_quant_params"]["bitpack"] == False
if not quant_config["weight_quant_params"]["bitpack"]
else W_q
)
