Avoid COW materialize in various places (1) (pytorch#124984)
Most, but not all, of these cases were found automatically with `git grep -n '^\s*\<const\>.*\*.*=.*\<data_ptr\>'`.

Part of pytorch#97856

Pull Request resolved: pytorch#124984
Approved by: https://github.com/Skylion007
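
The whole change follows one mechanical pattern: wherever a tensor's data is only read, the pointer is obtained with `const_data_ptr<T>()` instead of `data_ptr<T>()`. The mutable accessor has to materialize (copy) lazily shared copy-on-write (COW) storage before it can hand out a writable pointer, while the const accessor can return a pointer into the shared storage as-is. A minimal sketch of the pattern is below; the `sum_int64` helper is hypothetical and not part of this PR.

```cpp
#include <ATen/ATen.h>

// Hypothetical read-only helper illustrating the pattern applied in this PR:
// requesting the pointer via const_data_ptr<T>() lets a COW tensor be read
// without materializing (copying) its lazily shared storage.
int64_t sum_int64(const at::Tensor& t) {
  TORCH_CHECK(t.scalar_type() == at::kLong, "expected an int64 tensor");
  const at::Tensor contig = t.contiguous();

  // Before: int64_t* p = contig.data_ptr<int64_t>();   // writable, may force a COW copy
  const int64_t* p = contig.const_data_ptr<int64_t>();  // read-only, no copy needed

  int64_t sum = 0;
  for (int64_t i = 0; i < contig.numel(); ++i) {
    sum += p[i];
  }
  return sum;
}
```

Write-side pointers (for example into freshly allocated output tensors) keep using `data_ptr<T>()` throughout the hunks below.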
kurtamohler authored and pytorchmergebot committed Apr 26, 2024
1 parent 2ea1e84 commit abcb42c
Showing 52 changed files with 147 additions and 145 deletions.
8 changes: 4 additions & 4 deletions aten/src/ATen/NestedTensorImpl.cpp
@@ -81,7 +81,7 @@ inline std::vector<int64_t> construct_opt_sizes(const at::Tensor& sizes) {
std::vector<int64_t> result(1, sizes.sizes()[0]);
if (sizes.dim() > 0) {
size_t nested_dim = result.size();
- int64_t* sizes_ptr = sizes.data_ptr<int64_t>();
+ const int64_t* sizes_ptr = sizes.const_data_ptr<int64_t>();
result.resize(nested_dim + sizes.sizes()[1]);
int64_t sizes_size_0 = sizes.sizes()[0];
int64_t sizes_size_1 = sizes.sizes()[1];
@@ -114,7 +114,7 @@ at::Tensor construct_nested_strides(const at::Tensor& sizes) {
return sizes;
}
at::Tensor strides = sizes.new_empty(sizes.sizes());
- const int64_t* sizes_ptr = sizes.data_ptr<int64_t>();
+ const int64_t* sizes_ptr = sizes.const_data_ptr<int64_t>();
int64_t* strides_ptr = strides.data_ptr<int64_t>();
for (int64_t i = 0; i < sizes.size(0); i++) {
strides_ptr[orig_dim - 1] = 1;
@@ -152,7 +152,7 @@ at::Tensor construct_offsets(const at::Tensor& sizes) {
std::iota(offsets_ptr, offsets_ptr + ntensors, 0);
return offsets;
}
- const int64_t* sizes_ptr = sizes.data_ptr<int64_t>();
+ const int64_t* sizes_ptr = sizes.const_data_ptr<int64_t>();
offsets_ptr[0] = 0;
for (const auto i : c10::irange(ntensors - 1)) {
const int64_t row_product = std::accumulate(sizes_ptr, sizes_ptr + orig_dim, 1, std::multiplies());
@@ -344,7 +344,7 @@ int64_t get_numel_from_nested_size_tensor(const at::Tensor& tensor) {
static_cast<uint64_t>(std::numeric_limits<int64_t>::max()),
static_cast<uint64_t>(std::numeric_limits<size_t>::max()));

- const int64_t* sizes_ptr = tensor.data_ptr<int64_t>();
+ const int64_t* sizes_ptr = tensor.const_data_ptr<int64_t>();
const auto nt_dim = tensor.size(1);
uint64_t num_elements{0};

11 changes: 6 additions & 5 deletions aten/src/ATen/NestedTensorImpl.h
@@ -228,7 +228,8 @@ inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt) {
}
const Tensor &sizemat = nt->get_nested_sizes(),
&stridemat = nt->get_nested_strides();
- int64_t* offsets_ptr = nt->get_storage_offsets().data_ptr<int64_t>();
+ const int64_t* offsets_ptr =
+     nt->get_storage_offsets().const_data_ptr<int64_t>();
int64_t orig_dim = sizemat.size(1);
// nesting scalars
if (orig_dim == 0) {
@@ -243,8 +244,8 @@ inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt) {
// nesting tensors
else {
// if any underlying tensor is non-contiguous
- const int64_t *sizemat_ptr = sizemat.data_ptr<int64_t>(),
-               *stridemat_ptr = stridemat.data_ptr<int64_t>();
+ const int64_t *sizemat_ptr = sizemat.const_data_ptr<int64_t>(),
+               *stridemat_ptr = stridemat.const_data_ptr<int64_t>();
for (int64_t i = 0; i < ntensors; i++) {
if (stridemat_ptr[orig_dim - 1] != 1) {
return false;
@@ -263,8 +264,8 @@ inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt) {
if (offsets_ptr[0] != 0) {
return false;
}
- sizemat_ptr = sizemat.data_ptr<int64_t>();
- stridemat_ptr = stridemat.data_ptr<int64_t>();
+ sizemat_ptr = sizemat.const_data_ptr<int64_t>();
+ stridemat_ptr = stridemat.const_data_ptr<int64_t>();
for (int64_t i = 1; i < ntensors; i++) {
if (offsets_ptr[i] !=
offsets_ptr[i - 1] + *sizemat_ptr * *stridemat_ptr) {
8 changes: 4 additions & 4 deletions aten/src/ATen/core/Formatting.cpp
@@ -72,7 +72,7 @@ static std::tuple<double, int> __printFormat(std::ostream& stream, const Tensor&
return std::make_tuple(1., 0);
}
bool intMode = true;
- auto self_p = self.data_ptr<double>();
+ auto self_p = self.const_data_ptr<double>();
for (const auto i : c10::irange(size)) {
auto z = self_p[i];
if(std::isfinite(z)) {
@@ -189,7 +189,7 @@ static void __printMatrix(std::ostream& stream, const Tensor& self, int64_t line
}
for (const auto l : c10::irange(self.size(0))) {
Tensor row = self.select(0,l);
- double *row_ptr = row.data_ptr<double>();
+ const double *row_ptr = row.const_data_ptr<double>();
for (const auto c : c10::irange(firstColumn, lastColumn+1)) {
stream << std::setw(sz) << row_ptr[c]/scale;
if(c == lastColumn) {
@@ -279,15 +279,15 @@ std::ostream& print(std::ostream& stream, const Tensor & tensor_, int64_t linesi
tensor = tensor_.to(kCPU, kDouble).contiguous();
}
if(tensor.ndimension() == 0) {
- stream << defaultfloat << tensor.data_ptr<double>()[0] << '\n';
+ stream << defaultfloat << tensor.const_data_ptr<double>()[0] << '\n';
stream << "[ " << tensor_.toString() << "{}";
} else if(tensor.ndimension() == 1) {
if (tensor.numel() > 0) {
auto [scale, sz] = __printFormat(stream, tensor);
if(scale != 1) {
printScale(stream, scale);
}
- double* tensor_p = tensor.data_ptr<double>();
+ const double* tensor_p = tensor.const_data_ptr<double>();
for (const auto i : c10::irange(tensor.size(0))) {
stream << std::setw(sz) << tensor_p[i]/scale << '\n';
}
10 changes: 5 additions & 5 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -1516,7 +1516,7 @@ void _linalg_check_errors(
} else {
// Find the first non-zero info
auto infos_cpu = infos.to(at::kCPU);
- auto ptr = infos_cpu.data_ptr<int32_t>();
+ auto ptr = infos_cpu.const_data_ptr<int32_t>();
auto n = infos.numel();
auto info_ptr = std::find_if(ptr, ptr + n, [](int32_t x) { return x != 0; });
info = *info_ptr;
@@ -2794,13 +2794,13 @@ static void linalg_eig_make_complex_eigenvectors_impl(Tensor& result, const Tens
auto matrix_stride = matrixStride(real_vectors);

auto result_data = result.data_ptr<c10::complex<scalar_t>>();
- auto real_vectors_data = real_vectors.data_ptr<scalar_t>();
- auto values_data = complex_values.data_ptr<c10::complex<scalar_t>>();
+ auto real_vectors_data = real_vectors.const_data_ptr<scalar_t>();
+ auto values_data = complex_values.const_data_ptr<c10::complex<scalar_t>>();

for (auto b = decltype(batch_size){0}; b < batch_size; b++) {
- scalar_t* vecs = &real_vectors_data[b * matrix_stride];
+ const scalar_t* vecs = &real_vectors_data[b * matrix_stride];
c10::complex<scalar_t>* res = &result_data[b * matrix_stride];
- c10::complex<scalar_t>* vals = &values_data[b * n];
+ const c10::complex<scalar_t>* vals = &values_data[b * n];
for (auto j = decltype(n){0}; j < n; j++) {
if (vals[j].imag() == 0.0) { // eigenvalue is real, then v(j) = VR(:,j)
for (auto i = decltype(n){0}; i < n; i++) {
2 changes: 1 addition & 1 deletion aten/src/ATen/native/ForeachUtils.h
@@ -216,7 +216,7 @@ inline std::vector<c10::Scalar> convert_tensor_to_scalar_list(
scalarList_.scalar_type(),
"convert_tensor_to_scalar_list",
[&]() {
- const scalar_t* scalar_data = scalarList_.data_ptr<scalar_t>();
+ const scalar_t* scalar_data = scalarList_.const_data_ptr<scalar_t>();
TORCH_CHECK(
(expect_length == scalarList_.size(0)),
"Expected length of scalars to match input of length ",
6 changes: 3 additions & 3 deletions aten/src/ATen/native/QuantizedLinear.cpp
@@ -64,7 +64,7 @@ Tensor fbgemm_linear_int8_weight_fp32_activation(
"and will be removed in a future PyTorch release.")

const Tensor input_contig = input.contiguous();
- const float* input_ptr = input_contig.data_ptr<float>();
+ const float* input_ptr = input_contig.const_data_ptr<float>();

TORCH_CHECK(input.dim() >= 2);
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
@@ -305,7 +305,7 @@ Tensor fbgemm_pack_quantized_matrix(const Tensor& weight) {
const int64_t K = weight.size(1);
const int64_t N = weight.size(0);
const Tensor weight_contig = weight.contiguous();
- const int8_t* weight_ptr = weight_contig.data_ptr<int8_t>();
+ const int8_t* weight_ptr = weight_contig.const_data_ptr<int8_t>();
auto ptr = std::make_unique<fbgemm::PackBMatrix<int8_t>>(
/*trans=*/fbgemm::matrix_op_t::Transpose,
/*nRow=*/K,
@@ -424,7 +424,7 @@ Tensor fbgemm_linear_fp16_weight_fp32_activation(
TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");

const Tensor input_contig = input.contiguous();
- const float* input_ptr = input_contig.data_ptr<float>();
+ const float* input_ptr = input_contig.const_data_ptr<float>();

// Pull out the PackedGemmMatrixFP16 instance from the owning tensor
const fbgemm::PackedGemmMatrixFP16& packed_weight_fp16 =
4 changes: 2 additions & 2 deletions aten/src/ATen/native/SummaryOps.cpp
@@ -43,7 +43,7 @@ Tensor _bincount_cpu_template(
int64_t nbins = static_cast<int64_t>(*self.max().data_ptr<input_t>()) + 1L;
nbins = std::max(nbins, minlength); // at least minlength # of bins

- const input_t* self_p = self.data_ptr<input_t>();
+ const input_t* self_p = self.const_data_ptr<input_t>();
if (has_weights) {
output = at::zeros(
{nbins},
@@ -52,7 +52,7 @@ Tensor _bincount_cpu_template(
weights.options().device_opt(),
weights.options().pinned_memory_opt());
weights_t* output_p = output.data_ptr<weights_t>();
- const weights_t* weights_p = weights.data_ptr<weights_t>();
+ const weights_t* weights_p = weights.const_data_ptr<weights_t>();
for (const auto i : c10::irange(self_size)) {
output_p[self_p[i]] += weights_p[i];
}
4 changes: 2 additions & 2 deletions aten/src/ATen/native/TensorConversions.cpp
@@ -1479,7 +1479,7 @@ void convert_indices_from_coo_to_csr_cpu(
const Tensor& input,
const int64_t size) {
int64_t numel = input.numel();
- const input_t* data_in = input.data_ptr<input_t>();
+ const input_t* data_in = input.const_data_ptr<input_t>();
output_t* data_out = result.data_ptr<output_t>();

if (numel == 0) {
@@ -1525,7 +1525,7 @@ void convert_indices_from_csr_to_coo_cpu(
batch_indices.copy_(at::sparse::full_coo_indices(crow_indices.sizes().slice(0, batch_ndim), crow_indices.options())
.repeat_interleave(nnz, 1));
}
- const input_t* crow_indices_data_in = crow_indices_->data_ptr<input_t>();
+ const input_t* crow_indices_data_in = crow_indices_->const_data_ptr<input_t>();
TORCH_INTERNAL_ASSERT(indices.is_contiguous());
auto row0 = indices.select(0, transpose ? batch_ndim + 1 : batch_ndim + 0);
auto row1 = indices.select(0, transpose ? batch_ndim + 0 : batch_ndim + 1);
32 changes: 16 additions & 16 deletions aten/src/ATen/native/TensorShape.cpp
@@ -2058,7 +2058,7 @@ Tensor index_select_sparse_cpu(const Tensor& self, int64_t dim, const Tensor& in
// fill in src_int_idx, sorted_int_idx, int_counts
{
const auto sorted_len = sorted.numel();
- const auto* ptr_sorted = sorted.data_ptr<int64_t>();
+ const auto* ptr_sorted = sorted.const_data_ptr<int64_t>();
const auto* ptr_sorted_start = ptr_sorted;
const auto* ptr_sorted_end = ptr_sorted + sorted_len;

@@ -2121,7 +2121,7 @@ Tensor index_select_sparse_cpu(const Tensor& self, int64_t dim, const Tensor& in
auto* ptr_selected_src = selected_src.data_ptr<int64_t>();

const auto thread_offsets = compressed_int_counts.cumsum(0).sub_(compressed_int_counts);
- const auto* ptr_sorted_idx = sorted_idx.data_ptr<int64_t>();
+ const auto* ptr_sorted_idx = sorted_idx.const_data_ptr<int64_t>();
at::parallel_for(0, n_threads_src, 1, [&](int64_t tid, C10_UNUSED int64_t _) {
const auto start = tid * chunk_size_src;
const auto end = std::min(start + chunk_size_src, src_len);
@@ -2163,7 +2163,7 @@ Tensor index_select_sparse_cpu(const Tensor& self, int64_t dim, const Tensor& in
bool run_in_parallel = true) -> Tensor {
auto cidx = at::empty({len + 1}, idx.options());

- const auto* ptr_idx = idx.data_ptr<int64_t>();
+ const auto* ptr_idx = idx.const_data_ptr<int64_t>();
auto* ptr_cidx = cidx.data_ptr<int64_t>();

const auto idx_len = idx.numel();
@@ -2202,7 +2202,7 @@ Tensor index_select_sparse_cpu(const Tensor& self, int64_t dim, const Tensor& in
}
else {
auto* ptr_counts = counts.data_ptr<int64_t>();
- const auto* ptr_vals = t.data_ptr<int64_t>();
+ const auto* ptr_vals = t.const_data_ptr<int64_t>();
for (C10_UNUSED const auto _ : c10::irange(t.numel())) {
++ptr_counts[*ptr_vals++];
}
@@ -2310,10 +2310,10 @@ Tensor index_select_sparse_cpu(const Tensor& self, int64_t dim, const Tensor& in
const auto src_idx_len = src_intersection_offsets.const_data_ptr<int64_t>()[size - 1];
auto src_idx = at::empty({src_idx_len}, src.options());

- const auto* ptr_src = src.data_ptr<int64_t>();
- const auto* ptr_intersection_counts = intersection_counts.data_ptr<int64_t>();
- const auto* ptr_src_intersection_counts = src_intersection_counts.data_ptr<int64_t>();
- const auto* ptr_src_intersection_offsets = src_intersection_offsets.data_ptr<int64_t>();
+ const auto* ptr_src = src.const_data_ptr<int64_t>();
+ const auto* ptr_intersection_counts = intersection_counts.const_data_ptr<int64_t>();
+ const auto* ptr_src_intersection_counts = src_intersection_counts.const_data_ptr<int64_t>();
+ const auto* ptr_src_intersection_offsets = src_intersection_offsets.const_data_ptr<int64_t>();
auto* ptr_src_idx = src_idx.data_ptr<int64_t>();

const auto src_len = src.numel();
@@ -2362,16 +2362,16 @@ Tensor index_select_sparse_cpu(const Tensor& self, int64_t dim, const Tensor& in
auto counts_per_thread = idx_counts_per_thread.mul_(src_counts).sum(-1);
return counts_per_thread.cumsum(0).sub_(counts_per_thread);
}();
- const auto* ptr_thread_offset = thread_offset.data_ptr<int64_t>();
+ const auto* ptr_thread_offset = thread_offset.const_data_ptr<int64_t>();

auto idx_selected = at::empty({res_len}, idx.options());
auto src_selected = at::empty({res_len}, src.options());

- const auto* ptr_idx = idx.data_ptr<int64_t>();
- const auto* ptr_src_counts = src_counts.data_ptr<int64_t>();
- const auto* ptr_intersection_counts = intersection_counts.data_ptr<int64_t>();
- const auto* ptr_src_idx = src_idx.data_ptr<int64_t>();
- const auto* ptr_src_idx_offsets = src_idx_offsets.data_ptr<int64_t>();
+ const auto* ptr_idx = idx.const_data_ptr<int64_t>();
+ const auto* ptr_src_counts = src_counts.const_data_ptr<int64_t>();
+ const auto* ptr_intersection_counts = intersection_counts.const_data_ptr<int64_t>();
+ const auto* ptr_src_idx = src_idx.const_data_ptr<int64_t>();
+ const auto* ptr_src_idx_offsets = src_idx_offsets.const_data_ptr<int64_t>();
auto* ptr_idx_selected = idx_selected.data_ptr<int64_t>();
auto* ptr_src_selected = src_selected.data_ptr<int64_t>();

@@ -2433,8 +2433,8 @@ Tensor index_select_sparse_cpu(const Tensor& self, int64_t dim, const Tensor& in
}
}();

- const auto* ptr_outer = outer.data_ptr<int64_t>();
- const auto* ptr_inner = inner.data_ptr<int64_t>();
+ const auto* ptr_outer = outer.const_data_ptr<int64_t>();
+ const auto* ptr_inner = inner.const_data_ptr<int64_t>();
// NOTE: if very critical, replace std::vector with
// a data structure that operates on stack up to some limit.
auto outer_selected_idx = std::vector<int64_t>();
@@ -186,7 +186,7 @@ PackedLinearWeightQnnp::PackedLinearWeightQnnp(

std::tie(w_zero_points_, w_scales_) =
make_zero_points_and_scales_tensor(weight_contig);
- const float* weight_scales_data = w_scales_.data_ptr<float>();
+ const float* weight_scales_data = w_scales_.const_data_ptr<float>();
at::Tensor qnnp_weight = at::_empty_affine_quantized(
weight_contig.sizes(),
at::device(c10::kCPU).dtype(c10::kQUInt8),
@@ -160,7 +160,7 @@ BCSRSerializationType PackedLinearWeight::serialize() {
BCSRSerializationType PackedLinearWeightQnnp::serialize() {
at::Tensor w_scales_compact;
at::Tensor w_zero_points_compact;
- const float* w_scales_data_ptr = w_scales_.data_ptr<float>();
+ const float* w_scales_data_ptr = w_scales_.const_data_ptr<float>();
std::function<int8_t(uint8_t)> subtract_128 = [](uint8_t v) {
return static_cast<int8_t>(static_cast<int16_t>(v) - 128);
};
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cpu/HistogramKernel.cpp
@@ -292,10 +292,10 @@ void infer_bin_edges_from_input(const Tensor& input, const int64_t N,

TORCH_INTERNAL_ASSERT(min.is_contiguous() && max.is_contiguous());

- const scalar_t *min_data = min.data_ptr<scalar_t>();
+ const scalar_t *min_data = min.const_data_ptr<scalar_t>();
std::copy(min_data, min_data + N, leftmost_edges.begin());

- const scalar_t *max_data = max.data_ptr<scalar_t>();
+ const scalar_t *max_data = max.const_data_ptr<scalar_t>();
std::copy(max_data, max_data + N, rightmost_edges.begin());
}

2 changes: 1 addition & 1 deletion aten/src/ATen/native/cpu/MultinomialKernel.cpp
@@ -140,7 +140,7 @@ multinomial_with_replacement_apply(
/* cumulative probability distribution vector */
Tensor cum_dist = at::empty({n_categories}, self.options().dtype(kFloat));

- const scalar_t* const self_ptr = self.data_ptr<scalar_t>();
+ const scalar_t* const self_ptr = self.const_data_ptr<scalar_t>();
float* const cum_dist_ptr = cum_dist.data_ptr<float>();
int64_t* const result_ptr = result.data_ptr<int64_t>();

2 changes: 1 addition & 1 deletion aten/src/ATen/native/cpu/SparseFactories.cpp
@@ -29,7 +29,7 @@ void _spdiags_kernel_cpu(
"spdiags_cpu",
[&] {
auto* const values_write_ptr = values.data_ptr<scalar_t>();
- const auto* const diagonals_ptr = diagonals.data_ptr<scalar_t>();
+ const auto* const diagonals_ptr = diagonals.const_data_ptr<scalar_t>();

cpu_kernel(
iter,
18 changes: 9 additions & 9 deletions aten/src/ATen/native/cpu/UpSampleKernelAVXAntialias.h
@@ -66,7 +66,7 @@ at::Tensor unpack_rgb(const at::Tensor& packed_tensor) {
// into as 32 bits. This generalizes to num_channels <= 4 and also works for
// non-channels_last tensors.

- const uint8_t* packed = (const uint8_t*)packed_tensor.data_ptr<uint8_t>();
+ const uint8_t* packed = (const uint8_t*)packed_tensor.const_data_ptr<uint8_t>();
auto num_pixels = packed_tensor.size(1) * packed_tensor.size(2);
auto num_channels = packed_tensor.size(0);

@@ -180,18 +180,18 @@ void ImagingResampleHorizontal(
// Although this may not be needed if / when we port all this code to use
// Vec.h since this would potentially give us another fall-back implem

- const int16_t* kk = (int16_t*)(horiz_indices_weights[3].data_ptr<double>());
+ const int16_t* kk = (int16_t*)(horiz_indices_weights[3].const_data_ptr<double>());

auto xout = unpacked_output.size(2);
auto yout = unpacked_output.size(1);
auto xin = unpacked_input.size(2);
TORCH_INTERNAL_ASSERT(num_channels == unpacked_input.size(0));

- const int64_t* idx_ptr_xmin = horiz_indices_weights[0].data_ptr<int64_t>();
- const int64_t* idx_ptr_size = horiz_indices_weights[1].data_ptr<int64_t>();
+ const int64_t* idx_ptr_xmin = horiz_indices_weights[0].const_data_ptr<int64_t>();
+ const int64_t* idx_ptr_size = horiz_indices_weights[1].const_data_ptr<int64_t>();

uint8_t* unpacked_output_p = unpacked_output.data_ptr<uint8_t>();
- const uint8_t* unpacked_input_p = unpacked_input.data_ptr<uint8_t>();
+ const uint8_t* unpacked_input_p = unpacked_input.const_data_ptr<uint8_t>();

int64_t yy = 0;
auto xout_stride = xout * num_channels;
@@ -255,13 +255,13 @@ void ImagingResampleVertical(
// basic_loop_aa_vertical<uint8_t>)
// Although this may not be needed if / when we port all this code to use
// Vec.h since this would potentially give us another fall-back implem
- const int16_t* kk = (int16_t*)(vert_indices_weights[3].data_ptr<double>());
+ const int16_t* kk = (int16_t*)(vert_indices_weights[3].const_data_ptr<double>());

- const int64_t* idx_ptr_xmin = vert_indices_weights[0].data_ptr<int64_t>();
- const int64_t* idx_ptr_size = vert_indices_weights[1].data_ptr<int64_t>();
+ const int64_t* idx_ptr_xmin = vert_indices_weights[0].const_data_ptr<int64_t>();
+ const int64_t* idx_ptr_size = vert_indices_weights[1].const_data_ptr<int64_t>();

uint8_t* unpacked_output_p = unpacked_output.data_ptr<uint8_t>();
- const uint8_t* unpacked_input_p = unpacked_input.data_ptr<uint8_t>();
+ const uint8_t* unpacked_input_p = unpacked_input.const_data_ptr<uint8_t>();

auto xout = unpacked_output.size(2);
auto yout = unpacked_output.size(1);