From 4df8a12ae9dcf71a85ba04a5a278be8bf2dba439 Mon Sep 17 00:00:00 2001
From: Benson Ma
Date: Wed, 22 Jan 2025 14:40:17 -0800
Subject: [PATCH] Updates and fixes to tensor_accessor.h (2I/N) (#3571)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3571

X-link: https://github.com/facebookresearch/FBGEMM/pull/656

- Fix `TensorAccessorBase` constructor to work with empty tensors, which are used in FBGEMM code
- Add better logging for errors

Reviewed By: basilwong

Differential Revision: D68048640

fbshipit-source-id: d0b7ead4dd032dc08993ef51c40bc83dab9a38b6
---
 .../embedding_backward_split_cpu_template.cpp |  12 +-
 .../forward/embedding_forward_split_cpu.cpp   |  37 ++--
 .../fbgemm_gpu/embedding_forward_split_cpu.h  |   7 +-
 .../fbgemm_gpu/utils/tensor_accessor.h        | 184 ++++++++++++------
 fbgemm_gpu/test/tbe/utils/cpu_kernel_test.cpp |  28 ++-
 fbgemm_gpu/test/utils/tensor_accessor_test.cu |  92 +++++++++
 .../tensor_accessor_with_memcheck_test.cu     | 127 ++++++++++++
 7 files changed, 393 insertions(+), 94 deletions(-)
 create mode 100644 fbgemm_gpu/test/utils/tensor_accessor_test.cu
 create mode 100644 fbgemm_gpu/test/utils/tensor_accessor_with_memcheck_test.cu

diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_cpu_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_cpu_template.cpp
index 29cc9eb8b8..3f6095c955 100644
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_cpu_template.cpp
+++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_cpu_template.cpp
@@ -87,14 +87,16 @@ for (const auto t : c10::irange(num_tables)) {
   int feature_begin = table_to_feature_offset[t];
   int64_t hash_size = get_hash_size(feature_begin);

+#ifdef FBGEMM_GPU_MEMCHECK
+  const auto func_name = "::internal::csr2csc";
+#endif
+
+  using weight_t = at::acc_type<scalar_t, true>;
+
   ::internal::csr2csc(
       cscs[t],
       B,
-      offsets.accessor<int64_t, 1>(),
-      indices.accessor<int64_t, 1>(),
-      indice_weights.defined()
-          ? indice_weights.accessor<at::acc_type<scalar_t, true>, 1>()
-          : at::TensorAccessor<at::acc_type<scalar_t, true>, 1>(nullptr, nullptr, nullptr),
+      MAKE_TA_WITH_NAME(func_name, offsets, int64_t, 1),
+      MAKE_TA_WITH_NAME(func_name, indices, int64_t, 1),
+      MAKE_TA_WITH_NAME(func_name, indice_weights, weight_t, 1),
       pooling_mode,
       table_to_feature_offset + t,
       hash_size);
diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_cpu.cpp b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_cpu.cpp
index 720d0a2612..5117583415 100644
--- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_cpu.cpp
+++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_cpu.cpp
@@ -14,6 +14,7 @@
 #include "fbgemm_gpu/utils/cpu_utils.h"
 #include "fbgemm_gpu/utils/dispatch_macros.h"
 #include "fbgemm_gpu/utils/ops_utils.h"
+#include "fbgemm_gpu/utils/tensor_accessor.h"
 #ifdef FBCODE_CAFFE2
 #include
 #else
@@ -384,9 +385,9 @@ template <typename index_t, typename scalar_t, bool IS_VALUE_PAIR>
 void csr2csc_template_(
     HyperCompressedSparseColumn& csc,
     int B,
-    const at::TensorAccessor<index_t, 1>& csr_offsets,
-    const at::TensorAccessor<index_t, 1>& csr_indices,
-    const at::TensorAccessor<scalar_t, 1>& csr_weights,
+    const pta::TensorAccessor<index_t, 1>& csr_offsets,
+    const pta::TensorAccessor<index_t, 1>& csr_indices,
+    const pta::TensorAccessor<scalar_t, 1>& csr_weights,
     int64_t pooling_mode,
     const int* table_to_feature_offset,
     int64_t num_embeddings) {
@@ -585,9 +586,9 @@ void csr2csc_template_(
   template void csr2csc_template_(                           \
       HyperCompressedSparseColumn & csc,                     \
       int B,                                                 \
-      const at::TensorAccessor<index_t, 1>& csr_offsets,     \
-      const at::TensorAccessor<index_t, 1>& csr_indices,     \
-      const at::TensorAccessor<scalar_t, 1>& csr_weights,    \
+      const pta::TensorAccessor<index_t, 1>& csr_offsets,    \
+      const pta::TensorAccessor<index_t, 1>& csr_indices,    \
+      const pta::TensorAccessor<scalar_t, 1>& csr_weights,   \
       int64_t pooling_mode,                                  \
       const int* table_to_feature_offset,                    \
       int64_t num_embeddings);
@@ -613,9 +614,9 @@ template <typename index_t, typename scalar_t>
 void csr2csc(
     HyperCompressedSparseColumn& csc,
     int B,
-    const at::TensorAccessor<index_t, 1>& csr_offsets,
-    const at::TensorAccessor<index_t, 1>& csr_indices,
-    const at::TensorAccessor<scalar_t, 1>& csr_weights,
+    const pta::TensorAccessor<index_t, 1>& csr_offsets,
+    const pta::TensorAccessor<index_t, 1>& csr_indices,
+    const pta::TensorAccessor<scalar_t, 1>& csr_weights,
     int64_t pooling_mode,
     const int* table_to_feature_offset,
     int64_t num_embeddings) {
@@ -644,15 +645,15 @@ void csr2csc(
   }
 }

-#define INSTANTIATE_CSR2CSC_0(index_t, scalar_t) \
-  template void csr2csc(                         \
-      HyperCompressedSparseColumn & csc,         \
-      int B,                                     \
-      const at::TensorAccessor<index_t, 1>& csr_offsets,    \
-      const at::TensorAccessor<index_t, 1>& csr_indices,    \
-      const at::TensorAccessor<scalar_t, 1>& csr_weights,   \
-      int64_t pooling_mode,                      \
-      const int* table_to_feature_offset,        \
+#define INSTANTIATE_CSR2CSC_0(index_t, scalar_t)             \
+  template void csr2csc(                                     \
+      HyperCompressedSparseColumn & csc,                     \
+      int B,                                                 \
+      const pta::TensorAccessor<index_t, 1>& csr_offsets,    \
+      const pta::TensorAccessor<index_t, 1>& csr_indices,    \
+      const pta::TensorAccessor<scalar_t, 1>& csr_weights,   \
+      int64_t pooling_mode,                                  \
+      const int* table_to_feature_offset,                    \
       int64_t num_embeddings);

 #define INSTANTIATE_CSR2CSC_1(index_t) \
diff --git a/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_split_cpu.h b/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_split_cpu.h
index 2025f9d7fb..b06e7f878e 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_split_cpu.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_split_cpu.h
@@ -11,6 +11,7 @@
 #include
 #include
 #include "fbgemm/Utils.h"
+#include "fbgemm_gpu/utils/tensor_accessor.h"

 at::Tensor split_embedding_codegen_forward_cpu(
     at::Tensor weights,
@@ -120,9 +121,9 @@ template <typename index_t, typename scalar_t>
 void csr2csc(
     HyperCompressedSparseColumn& csc,
     int B,
-    const at::TensorAccessor<index_t, 1>& csr_offsets,
-
const at::TensorAccessor& csr_indices, - const at::TensorAccessor& csr_weights, + const pta::TensorAccessor& csr_offsets, + const pta::TensorAccessor& csr_indices, + const pta::TensorAccessor& csr_weights, int64_t pooling_mode, const int* table_to_feature_offset, int64_t num_embeddings); diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_accessor.h b/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_accessor.h index ae2604b6d1..c39d72c5b4 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_accessor.h +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_accessor.h @@ -19,10 +19,39 @@ #include #include +//////////////////////////////////////////////////////////////////////////////// +// Extended TensorAccessor +// +// This file contains TensorAccessor and PackedTensorAccessor implementations +// that are used in FBGEMM_GPU for additional bounds checks that are not +// available in the standard ATen implementation. Using the builder macro +// MAKE_TA_WITH_NAME and MAKE_PTA_WITH_NAME, bounds checks can be enabled using +// the FBGEMM_GPU_MEMCHECK flag. +// +// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/core/TensorAccessor.h +// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/core/TensorBase.h +//////////////////////////////////////////////////////////////////////////////// + namespace fbgemm_gpu { -static constexpr size_t PTR_NAME_MAX_LEN = 16; -static constexpr size_t FUNC_NAME_MAX_LEN = 64; +static constexpr size_t PTR_NAME_MAX_LEN = 32; +static constexpr size_t FUNC_NAME_MAX_LEN = 128; + +C10_HOST_DEVICE inline void +copy_str(char* dst, const char* src, const size_t max_len) { + // Count src buffer length up to max_len + size_t len = 0; + for (len = 0; src[len] != 0 && len < max_len; len++) { + // no action - calculating string length + } + len = len < (max_len - 1) ? len : (max_len - 1); + + // Copy src to dst + for (auto i = 0; i < len; i++) { + dst[i] = src[i]; + } + dst[len] = '\0'; +} // The PtrTraits argument to the TensorAccessor/GenericPackedTensorAccessor // is used to enable the __restrict__ keyword/modifier for the data @@ -53,7 +82,7 @@ struct RestrictPtrTraits { template < typename T, size_t N, - template class PtrTraits = DefaultPtrTraits, + template class PtrTraits = at::DefaultPtrTraits, typename index_t = int64_t> class TensorAccessorBase { public: @@ -65,16 +94,17 @@ class TensorAccessorBase { const index_t* const strides, const char* const ptr_name, const char* const func_name) - : data_(data), - sizes_(sizes), - strides_(strides), - ptr_name_(ptr_name), - func_name_(func_name) { - numel_ = 1; - for (size_t d = 0; d < N; d++) { - numel_ += (sizes[d] - 1) * strides[d]; + : data_(data), sizes_(sizes), strides_(strides) { + if (sizes && strides) { + numel_ = 1; + for (size_t d = 0; d < N; d++) { + numel_ += (sizes[d] - 1) * strides[d]; + } } + copy_str(ptr_name_, ptr_name, PTR_NAME_MAX_LEN); + copy_str(func_name_, func_name, FUNC_NAME_MAX_LEN); } + C10_HOST at::IntArrayRef sizes() const { return at::IntArrayRef(sizes_, N); } @@ -93,17 +123,20 @@ class TensorAccessorBase { C10_HOST_DEVICE const PtrType data() const { return data_; } + C10_HOST_DEVICE T& at(index_t idx) const { if (idx < 0) { printf( - "ERROR: idx < 0, tensor %s in %s, idx %lld\n", + "ERROR: idx < 0, tensor %s in %s, idx %ld\n", ptr_name_, func_name_, static_cast(idx)); + // NOTE: CUDA_KERNEL_ASSERT appears to be a no-op when HIPified; need to + // figure a workaround for this. 
CUDA_KERNEL_ASSERT(idx >= 0) } else if (idx >= numel_) { printf( - "ERROR: idx >= numel, tensor %s in %s, idx %lld, numel %lld\n", + "ERROR: idx >= numel, tensor %s in %s, idx %ld, numel %ld\n", ptr_name_, func_name_, static_cast(idx), @@ -118,8 +151,8 @@ class TensorAccessorBase { const index_t* const sizes_; const index_t* const strides_; index_t numel_; - const char* const ptr_name_; - const char* const func_name_; + char ptr_name_[PTR_NAME_MAX_LEN]; + char func_name_[FUNC_NAME_MAX_LEN]; }; // The `TensorAccessor` is typically instantiated for CPU `Tensor`s using @@ -129,7 +162,7 @@ class TensorAccessorBase { template < typename T, size_t N, - template class PtrTraits = DefaultPtrTraits, + template class PtrTraits = at::DefaultPtrTraits, typename index_t = int64_t> class TensorAccessor : public TensorAccessorBase { public: @@ -228,10 +261,12 @@ class GenericPackedTensorAccessorBase { : data_(data) { std::copy(sizes, sizes + N, std::begin(sizes_)); std::copy(strides, strides + N, std::begin(strides_)); - // Compute numel_ - numel_ = 1; - for (const auto d : c10::irange(N)) { - numel_ += (sizes[d] - 1) * strides[d]; + if (sizes != nullptr && strides != nullptr) { + // Compute numel_ + numel_ = 1; + for (const auto d : c10::irange(N)) { + numel_ += (sizes[d] - 1) * strides[d]; + } } copy_str(ptr_name_, ptr_name, PTR_NAME_MAX_LEN); copy_str(func_name_, func_name, FUNC_NAME_MAX_LEN); @@ -249,36 +284,32 @@ class GenericPackedTensorAccessorBase { const char* const ptr_name, const char* const func_name) : data_(data) { - for (const auto i : c10::irange(N)) { - this->sizes_[i] = sizes[i]; - this->strides_[i] = strides[i]; - } - // Compute numel_ - numel_ = 1; - for (const auto d : c10::irange(N)) { - numel_ += (sizes[d] - 1) * strides[d]; + if (sizes != nullptr && strides != nullptr) { + for (const auto i : c10::irange(N)) { + this->sizes_[i] = sizes[i]; + this->strides_[i] = strides[i]; + } + // Compute numel_ + numel_ = 1; + for (const auto d : c10::irange(N)) { + numel_ += (sizes[d] - 1) * strides[d]; + } } copy_str(ptr_name_, ptr_name, PTR_NAME_MAX_LEN); copy_str(func_name_, func_name, FUNC_NAME_MAX_LEN); } - C10_HOST void copy_str(char* dst, const char* src, const size_t max_len) { - const auto len = std::min(strlen(src), max_len - 1); - std::memcpy(dst, src, sizeof(char) * len); - dst[len] = '\0'; - } - C10_HOST_DEVICE T& at(index_t idx) const { if (idx < 0) { printf( - "ERROR: idx < 0, tensor %s in %s, idx %lld\n", + "ERROR: idx < 0, tensor %s in %s, idx %ld\n", ptr_name_, func_name_, static_cast(idx)); CUDA_KERNEL_ASSERT(idx >= 0) } else if (idx >= numel_) { printf( - "ERROR: idx >= numel, tensor %s in %s, idx %lld, numel %lld\n", + "ERROR: idx >= numel, tensor %s in %s, idx %ld, numel %ld\n", ptr_name_, func_name_, static_cast(idx), @@ -308,9 +339,17 @@ class GenericPackedTensorAccessorBase { index_t numel_; char ptr_name_[PTR_NAME_MAX_LEN]; char func_name_[FUNC_NAME_MAX_LEN]; + C10_HOST void bounds_check_(index_t i) const { TORCH_CHECK_INDEX( 0 <= i && i < index_t{N}, +#ifdef FBGEMM_GPU_MEMCHECK + "[ ", + func_name_, + " ][ ", + ptr_name_, + " ]: ", +#endif "Index ", i, " is not within bounds of a tensor of dimension ", @@ -525,7 +564,7 @@ inline void check_tensor_dim( #endif "to have ", N, - "dims, but found ", + " dims, but found ", tensor.dim(), " instead!"); } @@ -587,36 +626,52 @@ inline pta::TensorAccessor make_tensor_accessor( static_assert( N > 0, - "accessor is used for indexing tensor, for scalars use *data_ptr()"); + "Accessor is used for indexing tensor, for scalars use 
*data_ptr()"); - fbgemm_gpu::check_tensor_dim( - tensor + // If the tensor is defined, then check the tensor dimensions and scalar type + // before building and returning the accessor. + if (tensor.defined()) { + fbgemm_gpu::check_tensor_dim( + tensor #ifdef FBGEMM_GPU_MEMCHECK - , - func_name, - tensor_name + , + func_name, + tensor_name #endif - ); + ); - fbgemm_gpu::check_scalar_type( - tensor + fbgemm_gpu::check_scalar_type( + tensor #ifdef FBGEMM_GPU_MEMCHECK - , - func_name, - tensor_name + , + func_name, + tensor_name #endif - ); + ); #ifdef FBGEMM_GPU_MEMCHECK - return fbgemm_gpu::TensorAccessor( - static_cast::PtrType>(tensor.data_ptr()), - tensor.sizes().data(), - tensor.strides().data(), - tensor_name, - func_name); + return fbgemm_gpu::TensorAccessor( + static_cast::PtrType>(tensor.data_ptr()), + tensor.sizes().data(), + tensor.strides().data(), + tensor_name, + func_name); #else - return tensor.accessor(); + return tensor.accessor(); #endif + + } else { + // Else, just return a null tensor accessor - this is useful for cases where + // optionals are not used. + +#ifdef FBGEMM_GPU_MEMCHECK + return fbgemm_gpu::TensorAccessor( + nullptr, nullptr, nullptr, tensor_name, func_name); +#else + return pta::TensorAccessor( + nullptr, nullptr, nullptr); +#endif + } } //////////////////////////////////////////////////////////////////////////////// @@ -678,7 +733,7 @@ template < pta::PackedTensorAccessor32 make_packed_tensor_accessor32( #ifdef FBGEMM_GPU_MEMCHECK const at::Tensor& tensor, - const char* const ptr_name, + const char* const tensor_name, const char* const func_name) { #else const at::Tensor& tensor) { @@ -687,11 +742,18 @@ pta::PackedTensorAccessor32 make_packed_tensor_accessor32( TORCH_CHECK( tensor.numel() <= static_cast(std::numeric_limits::max()), +#ifdef FBGEMM_GPU_MEMCHECK + "[ ", + func_name, + " ]: Tensor ", + tensor_name, + " ", +#endif "numel needs to be smaller than int32_t max; otherwise, please use packed_accessor64"); #ifdef FBGEMM_GPU_MEMCHECK return make_generic_packed_tensor_accessor( - tensor, ptr_name, func_name); + tensor, tensor_name, func_name); #else return tensor.packed_accessor32(); #endif @@ -704,7 +766,7 @@ template < pta::PackedTensorAccessor64 make_packed_tensor_accessor64( #ifdef FBGEMM_GPU_MEMCHECK const at::Tensor& tensor, - const char* const ptr_name, + const char* const tensor_name, const char* const func_name) { #else const at::Tensor& tensor) { @@ -712,7 +774,7 @@ pta::PackedTensorAccessor64 make_packed_tensor_accessor64( #ifdef FBGEMM_GPU_MEMCHECK return make_generic_packed_tensor_accessor( - tensor, ptr_name, func_name); + tensor, tensor_name, func_name); #else return tensor.packed_accessor64(); #endif diff --git a/fbgemm_gpu/test/tbe/utils/cpu_kernel_test.cpp b/fbgemm_gpu/test/tbe/utils/cpu_kernel_test.cpp index 8a94d1b370..b9661c126d 100644 --- a/fbgemm_gpu/test/tbe/utils/cpu_kernel_test.cpp +++ b/fbgemm_gpu/test/tbe/utils/cpu_kernel_test.cpp @@ -13,8 +13,15 @@ #include "fbgemm_gpu/embedding_common.h" #include "fbgemm_gpu/embedding_forward_split_cpu.h" +#include "fbgemm_gpu/utils/tensor_accessor.h" #include "torch/types.h" // @manual=//caffe2:torch-cpp-cpu +#if FBGEMM_GPU_MEMCHECK +#define FBGEMM_MEM_CHECK_ONLY +#else +#define FBGEMM_MEM_CHECK_ONLY maybe_unused +#endif + template void test_csr2csc() { internal::HyperCompressedSparseColumn csc; @@ -27,13 +34,14 @@ void test_csr2csc() { int table_to_feature_offset[2] = {0, 1}; int num_embeddings = 10; + const auto no_weights = at::empty({0}, at::TensorOptions().dtype(at::kFloat)); + 
[[FBGEMM_MEM_CHECK_ONLY]] const auto func_name1 = "::internal::csr2csc_1"; ::internal::csr2csc( csc, B, - offsets.accessor(), - indices.accessor(), - at::TensorAccessor, 1>( - nullptr, nullptr, nullptr), // no weights + MAKE_TA_WITH_NAME(func_name1, offsets, T, 1), + MAKE_TA_WITH_NAME(func_name1, indices, T, 1), + MAKE_TA_WITH_NAME(func_name1, no_weights, float, 1), pooling_mode, table_to_feature_offset, num_embeddings); @@ -61,12 +69,16 @@ void test_csr2csc() { internal::HyperCompressedSparseColumn csc_weighted; at::Tensor indice_weights = torch::tensor( {1.0f, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f}, torch::kFloat32); + + [[maybe_unused]] const auto func_name2 = "::internal::csr2csc_2"; + using weight_t = at::acc_type; + ::internal::csr2csc( csc_weighted, B, - offsets.accessor(), - indices.accessor(), - indice_weights.accessor, 1>(), + MAKE_TA_WITH_NAME(func_name2, offsets, T, 1), + MAKE_TA_WITH_NAME(func_name2, indices, T, 1), + MAKE_TA_WITH_NAME(func_name2, indice_weights, weight_t, 1), pooling_mode, table_to_feature_offset, num_embeddings); @@ -99,3 +111,5 @@ TEST(CpuKernelTest, csr2csc_test_int32) { TEST(CpuKernelTest, csr2csc_test_int64) { test_csr2csc(); } + +#undef FBGEMM_MEM_CHECK_ONLY diff --git a/fbgemm_gpu/test/utils/tensor_accessor_test.cu b/fbgemm_gpu/test/utils/tensor_accessor_test.cu new file mode 100644 index 0000000000..a69b933d94 --- /dev/null +++ b/fbgemm_gpu/test/utils/tensor_accessor_test.cu @@ -0,0 +1,92 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include // @manual=//caffe2:torch-cpp-cpu + +// DISABLE compilation in FBGEMM_GPU_MEMCHECK mode as a test +#ifdef FBGEMM_GPU_MEMCHECK +#undef FBGEMM_GPU_MEMCHECK +#endif + +#include "fbgemm_gpu/utils/tensor_accessor.h" + +template +void test_ta_create_1(const at::Tensor& tensor) { + [[maybe_unused]] const auto func_name = "test_ta_create"; + [[maybe_unused]] const auto accessor = + MAKE_TA_WITH_NAME(func_name, tensor, T, 1); +} + +template +void test_ta_create_2(const at::Tensor& tensor) { + [[maybe_unused]] const auto func_name = "test_ta_create"; + [[maybe_unused]] const auto accessor = + MAKE_TA_WITH_NAME(func_name, tensor, float, N); +} + +void test_ta_create_3(const at::Tensor& tensor) { + [[maybe_unused]] const auto func_name = "test_ta_create"; + [[maybe_unused]] const auto accessor = + MAKE_TA_WITH_NAME(func_name, tensor, float, 1); +} + +TEST(TensorAccessorTest, test_ta_create) { + const auto tensor = torch::tensor( + {1.0f, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f}, torch::kFloat32); + // Test mismatched types + EXPECT_THROW({ test_ta_create_1(tensor); }, std::exception); + EXPECT_THROW({ test_ta_create_1(tensor); }, std::exception); + EXPECT_THROW({ test_ta_create_1(tensor); }, std::exception); + + // Test invalid dimensions + EXPECT_THROW({ test_ta_create_2<2>(tensor); }, std::exception); + EXPECT_THROW({ test_ta_create_2<3>(tensor); }, std::exception); + EXPECT_THROW({ test_ta_create_2<4>(tensor); }, std::exception); + + // Test valid type and dimension + EXPECT_NO_THROW({ test_ta_create_3(tensor); }); +} + +template +void test_pta_create_1(const at::Tensor& tensor) { + [[maybe_unused]] const auto func_name = "test_pta_create"; + [[maybe_unused]] const auto accessor = + MAKE_PTA_WITH_NAME(func_name, tensor, T, 1, 64); +} + +template +void test_pta_create_2(const at::Tensor& tensor) { + [[maybe_unused]] 
const auto func_name = "test_pta_create"; + [[maybe_unused]] const auto accessor = + MAKE_PTA_WITH_NAME(func_name, tensor, float, N, 64); +} + +void test_pta_create_3(const at::Tensor& tensor) { + [[maybe_unused]] const auto func_name = "test_pta_create"; + [[maybe_unused]] const auto accessor = + MAKE_PTA_WITH_NAME(func_name, tensor, float, 1, 64); +} + +TEST(PackedTensorAccessorTest, test_pta_create) { + const auto tensor = torch::tensor( + {1.0f, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f}, torch::kFloat32); + // Test mismatched types + EXPECT_THROW({ test_pta_create_1(tensor); }, std::exception); + EXPECT_THROW({ test_pta_create_1(tensor); }, std::exception); + EXPECT_THROW({ test_pta_create_1(tensor); }, std::exception); + + // Test invalid dimensions + EXPECT_THROW({ test_pta_create_2<2>(tensor); }, std::exception); + EXPECT_THROW({ test_pta_create_2<3>(tensor); }, std::exception); + EXPECT_THROW({ test_pta_create_2<4>(tensor); }, std::exception); + + // Test valid type and dimension + EXPECT_NO_THROW({ test_pta_create_3(tensor); }); +} diff --git a/fbgemm_gpu/test/utils/tensor_accessor_with_memcheck_test.cu b/fbgemm_gpu/test/utils/tensor_accessor_with_memcheck_test.cu new file mode 100644 index 0000000000..29651e75a1 --- /dev/null +++ b/fbgemm_gpu/test/utils/tensor_accessor_with_memcheck_test.cu @@ -0,0 +1,127 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include // @manual=//caffe2:torch-cpp-cpu + +// ENABLE compilation in FBGEMM_GPU_MEMCHECK mode as a test +#ifndef FBGEMM_GPU_MEMCHECK +#define FBGEMM_GPU_MEMCHECK +#endif + +#include "fbgemm_gpu/utils/tensor_accessor.h" + +template +void test_ta_create_1(const at::Tensor& tensor) { + const auto func_name = "test_ta_make"; + [[maybe_unused]] const auto accessor = + MAKE_TA_WITH_NAME(func_name, tensor, T, 1); +} + +template +void test_ta_create_2(const at::Tensor& tensor) { + const auto func_name = "test_ta_make"; + [[maybe_unused]] const auto accessor = + MAKE_TA_WITH_NAME(func_name, tensor, float, N); +} + +void test_ta_create_3(const at::Tensor& tensor) { + const auto func_name = "test_ta_make"; + [[maybe_unused]] const auto accessor = + MAKE_TA_WITH_NAME(func_name, tensor, float, 1); +} + +TEST(TensorAccessorWithMemcheckTest, test_ta_create) { + const auto tensor = torch::tensor( + {1.0f, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f}, torch::kFloat32); + // Test mismatched types + EXPECT_THROW({ test_ta_create_1(tensor); }, std::exception); + EXPECT_THROW({ test_ta_create_1(tensor); }, std::exception); + EXPECT_THROW({ test_ta_create_1(tensor); }, std::exception); + + // Test invalid dimensions + EXPECT_THROW({ test_ta_create_2<2>(tensor); }, std::exception); + EXPECT_THROW({ test_ta_create_2<3>(tensor); }, std::exception); + EXPECT_THROW({ test_ta_create_2<4>(tensor); }, std::exception); + + // Test valid type and dimension + EXPECT_NO_THROW({ test_ta_create_3(tensor); }); +} + +template +void test_ta_access() { + const auto func_name = "ta_access"; + const auto tensor = at::empty({0}, at::TensorOptions().dtype(DType)); + const auto accessor = MAKE_TA_WITH_NAME(func_name, tensor, T, 1); + + EXPECT_DEATH({ accessor.at(10); }, "idx < numel_"); +} + +// NOTE: CUDA_KERNEL_ASSERT appears to be a no-op when HIPified +#ifndef __HIPCC__ +TEST(TensorAccessorWithMemcheckTest, test_ta_access) { + test_ta_access(); + test_ta_access(); +} 
+#endif
+
+template <typename T>
+void test_pta_create_1(const at::Tensor& tensor) {
+  [[maybe_unused]] const auto func_name = "test_pta_create";
+  [[maybe_unused]] const auto accessor =
+      MAKE_PTA_WITH_NAME(func_name, tensor, T, 1, 64);
+}
+
+template <size_t N>
+void test_pta_create_2(const at::Tensor& tensor) {
+  [[maybe_unused]] const auto func_name = "test_pta_create";
+  [[maybe_unused]] const auto accessor =
+      MAKE_PTA_WITH_NAME(func_name, tensor, float, N, 64);
+}
+
+void test_pta_create_3(const at::Tensor& tensor) {
+  [[maybe_unused]] const auto func_name = "test_pta_create";
+  [[maybe_unused]] const auto accessor =
+      MAKE_PTA_WITH_NAME(func_name, tensor, float, 1, 64);
+}
+
+TEST(PackedTensorAccessorTest, test_pta_create) {
+  const auto tensor = torch::tensor(
+      {1.0f, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f}, torch::kFloat32);
+  // Test mismatched types
+  EXPECT_THROW({ test_pta_create_1(tensor); }, std::exception);
+  EXPECT_THROW({ test_pta_create_1(tensor); }, std::exception);
+  EXPECT_THROW({ test_pta_create_1(tensor); }, std::exception);
+
+  // Test invalid dimensions
+  EXPECT_THROW({ test_pta_create_2<2>(tensor); }, std::exception);
+  EXPECT_THROW({ test_pta_create_2<3>(tensor); }, std::exception);
+  EXPECT_THROW({ test_pta_create_2<4>(tensor); }, std::exception);
+
+  // Test valid type and dimension
+  EXPECT_NO_THROW({ test_pta_create_3(tensor); });
+}
+
+template <typename T, at::ScalarType DType>
+void test_pta_access() {
+  const auto func_name = "test_pta_access";
+  const auto tensor = at::empty({0}, at::TensorOptions().dtype(DType));
+  const auto accessor = MAKE_PTA_WITH_NAME(func_name, tensor, T, 1, 64);
+
+  EXPECT_DEATH({ accessor.at(10); }, "idx < numel_");
+}
+
+#ifndef __HIPCC__
+TEST(PackedTensorAccessorWithMemcheckTest, test_pta_access) {
+  test_pta_access();
+  test_pta_access();
+}
+#endif
+
+#undef FBGEMM_GPU_MEMCHECK
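
Usage sketch (not part of the patch): the snippet below illustrates how callers are expected to build accessors through MAKE_TA_WITH_NAME after this change, including the case where an optional tensor (such as indice_weights above) is undefined. The helper name sum_1d and its func_name string are made up for illustration; MAKE_TA_WITH_NAME, pta::TensorAccessor, and the null-accessor behavior for undefined tensors come from the updated fbgemm_gpu/utils/tensor_accessor.h, and the sketch assumes an FBGEMM_GPU build where that header is on the include path.

#include <ATen/ATen.h>

#include "fbgemm_gpu/utils/tensor_accessor.h"

// Hypothetical helper: sums a 1-D float tensor that may be undefined.
float sum_1d(const at::Tensor& values) {
  // In FBGEMM_GPU_MEMCHECK builds the macro expands to the bounds-checked
  // fbgemm_gpu::TensorAccessor and records the tensor/function names used in
  // its error messages; otherwise it falls back to the plain at accessor.
  [[maybe_unused]] const auto func_name = "sum_1d";
  const auto acc = MAKE_TA_WITH_NAME(func_name, values, float, 1);

  float total = 0.0f;
  // With this patch an undefined tensor yields a null accessor instead of
  // throwing, so callers can probe data() for nullptr (the convention csr2csc
  // uses for csr_weights) rather than hand-building
  // at::TensorAccessor(nullptr, nullptr, nullptr) at every call site.
  if (acc.data() != nullptr) {
    for (int64_t i = 0; i < acc.size(0); ++i) {
      total += acc[i];
    }
  }
  return total;
}

The same pattern appears in the embedding_backward_split_cpu_template.cpp hunk above, where the defined()/undefined ternary around indice_weights is replaced by a single MAKE_TA_WITH_NAME call.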