diff --git a/scripts/all_reduce_test.py b/scripts/all_reduce_test.py index 31201b45..a331916d 100644 --- a/scripts/all_reduce_test.py +++ b/scripts/all_reduce_test.py @@ -51,7 +51,7 @@ def main(config: Config): globals={"all_reduce": all_reduce, "mat": mat, "transfer_dtype": transfer_dtype}, ) elif config.transfer_dtype is not None and torch.uint8: - from zeroband.compression import uniform_8bit_quantize + from zeroband.C.compression import uniform_8bit_quantize t0 = benchmark.Timer( stmt="all_reduce(mat, quantization_func=foo)",