From c4dc41fac631d57bc69a30e583281ef70b1fe60b Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Fri, 4 Oct 2024 16:15:28 +0800 Subject: [PATCH] eliminate expensive long cast in average tensors --- src/zeroband/csrc/compress.cpp | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/zeroband/csrc/compress.cpp b/src/zeroband/csrc/compress.cpp index 8c0bb419..d31012e2 100644 --- a/src/zeroband/csrc/compress.cpp +++ b/src/zeroband/csrc/compress.cpp @@ -3,6 +3,12 @@ #include #include +#include +#include +#include +#include +#include + namespace py = pybind11; constexpr int n_bins = 256; // 8-bit quantization @@ -105,17 +111,16 @@ torch::Tensor average_buckets(const torch::Tensor& tensor, const torch::Tensor& } torch::Tensor average_buckets_multithread(const torch::Tensor& tensor, const torch::Tensor& quant_weight, int64_t n_bins) { - std::cout << quant_weight.sizes() << tensor.sizes() << std::endl; torch::NoGradGuard no_grad; auto flat_tensor = tensor.flatten().contiguous(); - auto flat_quant_weight = quant_weight.flatten().to(torch::kLong).contiguous(); + auto flat_quant_weight = quant_weight.flatten().contiguous(); auto options = flat_tensor.options(); auto bin_sums = torch::zeros({n_bins}, options); auto bin_counts = torch::zeros({n_bins}, options.dtype(torch::kLong)); // Get raw pointers float* tensor_data = flat_tensor.data_ptr(); - int64_t* quant_data = flat_quant_weight.data_ptr(); + uint8_t* quant_data = flat_quant_weight.data_ptr(); float* sums_data = bin_sums.data_ptr(); int64_t* counts_data = bin_counts.data_ptr(); int64_t numel = flat_tensor.numel(); @@ -133,8 +138,8 @@ torch::Tensor average_buckets_multithread(const torch::Tensor& tensor, const tor std::vector local_counts(n_bins, 0); for (int64_t i = start; i < end; ++i) { - int64_t bin = quant_data[i]; - if (bin >= 0 && bin < n_bins) { + uint8_t bin = quant_data[i]; + if (bin < n_bins) { // No need to check for >= 0 as uint8_t is always non-negative local_sums[bin] += tensor_data[i]; local_counts[bin]++; } @@ -188,8 +193,7 @@ std::tuple uniform_8bit_quantize(torch::Tensor ten // Call average_buckets to create the lookup table torch::Tensor lookup = average_buckets_multithread(tensor, quantized_tensor, n_bins); - return std::make_tuple(centered_tensor, centered_tensor); - //return std::make_tuple(quantized_tensor, lookup); + return std::make_tuple(quantized_tensor, lookup); }