Commit c4dc41f
eliminate expensive long cast in average tensors
Jackmin801 committed Oct 4, 2024
1 parent a5318b9 commit c4dc41f
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions src/zeroband/csrc/compress.cpp
@@ -3,6 +3,12 @@
 #include <cmath>
 #include <immintrin.h>
 
+#include <thread>
+#include <vector>
+#include <algorithm>
+#include <chrono>
+#include <iostream>
+
 namespace py = pybind11;
 
 constexpr int n_bins = 256; // 8-bit quantization
@@ -105,17 +111,16 @@ torch::Tensor average_buckets(const torch::Tensor& tensor, const torch::Tensor&
 }
 
 torch::Tensor average_buckets_multithread(const torch::Tensor& tensor, const torch::Tensor& quant_weight, int64_t n_bins) {
-    std::cout << quant_weight.sizes() << tensor.sizes() << std::endl;
     torch::NoGradGuard no_grad;
     auto flat_tensor = tensor.flatten().contiguous();
-    auto flat_quant_weight = quant_weight.flatten().to(torch::kLong).contiguous();
+    auto flat_quant_weight = quant_weight.flatten().contiguous();
     auto options = flat_tensor.options();
     auto bin_sums = torch::zeros({n_bins}, options);
     auto bin_counts = torch::zeros({n_bins}, options.dtype(torch::kLong));
 
     // Get raw pointers
     float* tensor_data = flat_tensor.data_ptr<float>();
-    int64_t* quant_data = flat_quant_weight.data_ptr<int64_t>();
+    uint8_t* quant_data = flat_quant_weight.data_ptr<uint8_t>();
     float* sums_data = bin_sums.data_ptr<float>();
     int64_t* counts_data = bin_counts.data_ptr<int64_t>();
     int64_t numel = flat_tensor.numel();
@@ -133,8 +138,8 @@ torch::Tensor average_buckets_multithread(const torch::Tensor& tensor, const tor
         std::vector<int64_t> local_counts(n_bins, 0);
 
         for (int64_t i = start; i < end; ++i) {
-            int64_t bin = quant_data[i];
-            if (bin >= 0 && bin < n_bins) {
+            uint8_t bin = quant_data[i];
+            if (bin < n_bins) { // No need to check for >= 0 as uint8_t is always non-negative
                 local_sums[bin] += tensor_data[i];
                 local_counts[bin]++;
             }
@@ -188,8 +193,7 @@ std::tuple<torch::Tensor, torch::Tensor> uniform_8bit_quantize(torch::Tensor ten
     // Call average_buckets to create the lookup table
     torch::Tensor lookup = average_buckets_multithread(tensor, quantized_tensor, n_bins);
 
-    return std::make_tuple(centered_tensor, centered_tensor);
-    //return std::make_tuple(quantized_tensor, lookup);
+    return std::make_tuple(quantized_tensor, lookup);
 }


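For context, here is a minimal single-threaded sketch of the bucket-averaging pattern this commit optimizes (illustrative names, assuming libtorch; the repo's average_buckets_multithread shards this loop across std::thread workers). The point of the change: quant_weight already holds uint8 bucket codes, so reading the buffer through data_ptr<uint8_t>() avoids materializing an int64 copy eight times the byte size via .to(torch::kLong).

#include <torch/torch.h>
#include <cstdint>

// Sketch only: single-threaded version of the hot loop, with the cast removed.
torch::Tensor average_buckets_sketch(const torch::Tensor& values,
                                     const torch::Tensor& codes,  // uint8 bucket ids
                                     int64_t n_bins) {
    torch::NoGradGuard no_grad;
    auto flat_vals  = values.flatten().contiguous();
    auto flat_codes = codes.flatten().contiguous();  // stays uint8: no kLong copy
    auto sums   = torch::zeros({n_bins}, flat_vals.options());
    auto counts = torch::zeros({n_bins}, flat_vals.options().dtype(torch::kLong));

    const float*   v = flat_vals.data_ptr<float>();
    const uint8_t* c = flat_codes.data_ptr<uint8_t>();
    float*   s = sums.data_ptr<float>();
    int64_t* k = counts.data_ptr<int64_t>();

    for (int64_t i = 0; i < flat_vals.numel(); ++i) {
        if (c[i] < n_bins) {  // uint8 is never negative, so one bound suffices
            s[c[i]] += v[i];
            ++k[c[i]];
        }
    }
    // Clamp empty buckets to avoid division by zero.
    return sums / counts.clamp_min(1).to(torch::kFloat);
}

With n_bins fixed at 256, the guard can never fail for uint8 codes, but keeping it makes the sketch safe for smaller bin counts.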

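A rough way to see the cost being eliminated (hypothetical micro-benchmark, not part of the repo; tensor size and timings are arbitrary): the old path allocated and filled an int64 tensor eight times the byte size of the uint8 codes on every call, and the loop then streamed those 8-byte values instead of 1-byte ones.

#include <torch/torch.h>
#include <chrono>
#include <iostream>

int main() {
    auto codes = torch::randint(0, 256, {1 << 26}, torch::kUInt8);  // 64M codes, 64 MiB

    auto t0 = std::chrono::steady_clock::now();
    auto as_long = codes.to(torch::kLong).contiguous();  // old path: 512 MiB allocation + copy
    auto t1 = std::chrono::steady_clock::now();
    auto direct = codes.contiguous();                    // new path: no copy if already contiguous
    auto t2 = std::chrono::steady_clock::now();

    std::cout << "cast to long: "
              << std::chrono::duration<double, std::milli>(t1 - t0).count() << " ms\n"
              << "direct uint8: "
              << std::chrono::duration<double, std::milli>(t2 - t1).count() << " ms\n";
    return 0;
}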