[BE] use DeviceIndex instead of int64_t for related device interfaces (pytorch#103068)

This PR unifies the device interfaces in aten/*cpp and torch/csrc/*cpp to use **c10::DeviceIndex**.
Pull Request resolved: pytorch#103068
Approved by: https://github.com/malfet
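For context: c10::DeviceIndex is the narrow integer alias defined in c10/core/Device.h (an int8_t at the time of this PR), i.e. the same type that c10::Device stores internally. A minimal sketch of the kind of signature change this commit makes, not part of the diff itself:

#include <c10/core/Device.h>  // defines c10::DeviceIndex

// Before: device indices were widened to int64_t at the interface boundary.
// int64_t current_device();
// bool hasPrimaryContext(int64_t device_index);

// After: the interfaces carry the same narrow type that c10::Device uses.
c10::DeviceIndex current_device();
bool hasPrimaryContext(c10::DeviceIndex device_index);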
cyyever authored and pytorchmergebot committed Aug 25, 2023
1 parent 4656e09 commit d9fb716
Showing 18 changed files with 73 additions and 56 deletions.
@@ -82,6 +82,7 @@ namespace impl {
c10::QScheme,
c10::ScalarType,
c10::Device,
+c10::DeviceIndex,
c10::Layout,
c10::MemoryFormat,
at::Dimname
8 changes: 8 additions & 0 deletions aten/src/ATen/core/jit_type.h
@@ -11,6 +11,7 @@
#include <c10/util/Optional.h>
#include <c10/core/SymFloat.h>
#include <c10/core/SymBool.h>
+#include <c10/core/Device.h>

#include <array>
#include <memory>
@@ -1869,6 +1870,13 @@ struct getTypePtr_<int64_t> final {
}
};

+template <>
+struct getTypePtr_<DeviceIndex> final {
+static decltype(auto) call() {
+return IntType::get();
+}
+};

template <>
struct getMaybeFakeTypePtr_<SymInt, false> final {
static decltype(auto) call() {
4 changes: 2 additions & 2 deletions aten/src/ATen/core/op_registration/infer_schema.h
@@ -38,8 +38,8 @@ constexpr int checkStaticTypes() {
// Give nice error messages for some of the common error cases.
// Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT
static_assert(guts::conjunction<
-bool_t<!std::is_integral<Types>::value || std::is_same<Types, int64_t>::value || std::is_same<Types, bool>::value>...
->::value, "INVALID TYPE: Only int64_t and bool are supported as an integral argument type");
+bool_t<!std::is_integral<Types>::value || std::is_same<Types, int8_t>::value || std::is_same<Types, int64_t>::value || std::is_same<Types, bool>::value>...
+>::value, "INVALID TYPE: Only int8_t, int64_t and bool are supported as an integral argument type");
static_assert(guts::conjunction<
bool_t<!std::is_same<Types, float>::value>...
>::value, "INVALID TYPE: float is not supported as an argument type, use double instead");
22 changes: 11 additions & 11 deletions aten/src/ATen/cuda/detail/CUDAHooks.cpp
@@ -42,13 +42,13 @@
#include <memory>

namespace c10::cuda::_internal {
-void setHasPrimaryContext(bool (*func)(int64_t));
+void setHasPrimaryContext(bool (*func)(DeviceIndex));
}

namespace at::cuda::detail {

const at::cuda::NVRTC& nvrtc();
-int64_t current_device();
+DeviceIndex current_device();

static void (*magma_init_fn)() = nullptr;

@@ -57,7 +57,7 @@ void set_magma_init_fn(void (*fn)()) {
}

namespace {
-bool _hasPrimaryContext(int64_t device_index) {
+bool _hasPrimaryContext(DeviceIndex device_index) {
TORCH_CHECK(device_index >= 0 && device_index < at::cuda::device_count(),
"hasPrimaryContext expects a valid device index, but got device_index=", device_index);
unsigned int ctx_flags;
@@ -226,7 +226,7 @@ const at::cuda::NVRTC& CUDAHooks::nvrtc() const {
return at::cuda::detail::nvrtc();
}

-int64_t current_device() {
+DeviceIndex current_device() {
int device;
cudaError_t err = c10::cuda::GetDevice(&device);
if (err == cudaSuccess) {
@@ -235,11 +235,11 @@ int64_t current_device() {
return -1;
}

-int64_t CUDAHooks::current_device() const {
+DeviceIndex CUDAHooks::current_device() const {
return at::cuda::detail::current_device();
}

-bool CUDAHooks::hasPrimaryContext(int64_t device_index) const {
+bool CUDAHooks::hasPrimaryContext(DeviceIndex device_index) const {
return _hasPrimaryContext(device_index);
}

@@ -414,27 +414,27 @@ double CUDAHooks::batchnormMinEpsilonCuDNN() const {
#endif
}

-int64_t CUDAHooks::cuFFTGetPlanCacheMaxSize(int64_t device_index) const {
+int64_t CUDAHooks::cuFFTGetPlanCacheMaxSize(DeviceIndex device_index) const {
return at::native::detail::cufft_get_plan_cache_max_size_impl(device_index);
}

-void CUDAHooks::cuFFTSetPlanCacheMaxSize(int64_t device_index, int64_t max_size) const {
+void CUDAHooks::cuFFTSetPlanCacheMaxSize(DeviceIndex device_index, int64_t max_size) const {
at::native::detail::cufft_set_plan_cache_max_size_impl(device_index, max_size);
}

-int64_t CUDAHooks::cuFFTGetPlanCacheSize(int64_t device_index) const {
+int64_t CUDAHooks::cuFFTGetPlanCacheSize(DeviceIndex device_index) const {
return at::native::detail::cufft_get_plan_cache_size_impl(device_index);
}

-void CUDAHooks::cuFFTClearPlanCache(int64_t device_index) const {
+void CUDAHooks::cuFFTClearPlanCache(DeviceIndex device_index) const {
at::native::detail::cufft_clear_plan_cache_impl(device_index);
}

int CUDAHooks::getNumGPUs() const {
return at::cuda::device_count();
}

-void CUDAHooks::deviceSynchronize(int64_t device_index) const {
+void CUDAHooks::deviceSynchronize(DeviceIndex device_index) const {
at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
c10::cuda::device_synchronize();
}
14 changes: 7 additions & 7 deletions aten/src/ATen/cuda/detail/CUDAHooks.h
@@ -29,8 +29,8 @@ struct CUDAHooks : public at::CUDAHooksInterface {
bool hasCuSOLVER() const override;
bool hasROCM() const override;
const at::cuda::NVRTC& nvrtc() const override;
-int64_t current_device() const override;
-bool hasPrimaryContext(int64_t device_index) const override;
+DeviceIndex current_device() const override;
+bool hasPrimaryContext(DeviceIndex device_index) const override;
Allocator* getCUDADeviceAllocator() const override;
Allocator* getPinnedMemoryAllocator() const override;
bool compiledWithCuDNN() const override;
@@ -43,12 +43,12 @@ struct CUDAHooks : public at::CUDAHooksInterface {
long versionCuDNN() const override;
std::string showConfig() const override;
double batchnormMinEpsilonCuDNN() const override;
-int64_t cuFFTGetPlanCacheMaxSize(int64_t device_index) const override;
-void cuFFTSetPlanCacheMaxSize(int64_t device_index, int64_t max_size) const override;
-int64_t cuFFTGetPlanCacheSize(int64_t device_index) const override;
-void cuFFTClearPlanCache(int64_t device_index) const override;
+int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex device_index) const override;
+void cuFFTSetPlanCacheMaxSize(DeviceIndex device_index, int64_t max_size) const override;
+int64_t cuFFTGetPlanCacheSize(DeviceIndex device_index) const override;
+void cuFFTClearPlanCache(DeviceIndex device_index) const override;
int getNumGPUs() const override;
-void deviceSynchronize(int64_t device_index) const override;
+void deviceSynchronize(DeviceIndex device_index) const override;
};

}}} // at::cuda::detail
14 changes: 7 additions & 7 deletions aten/src/ATen/detail/CUDAHooksInterface.h
@@ -114,11 +114,11 @@ struct TORCH_API CUDAHooksInterface {
TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP);
}

-virtual bool hasPrimaryContext(int64_t device_index) const {
+virtual bool hasPrimaryContext(DeviceIndex device_index) const {
TORCH_CHECK(false, "Cannot call hasPrimaryContext(", device_index, ") without ATen_cuda library. ", CUDA_HELP);
}

-virtual int64_t current_device() const {
+virtual DeviceIndex current_device() const {
return -1;
}

@@ -167,27 +167,27 @@ struct TORCH_API CUDAHooksInterface {
"Cannot query batchnormMinEpsilonCuDNN() without ATen_cuda library. ", CUDA_HELP);
}

-virtual int64_t cuFFTGetPlanCacheMaxSize(int64_t /*device_index*/) const {
+virtual int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex /*device_index*/) const {
TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
}

-virtual void cuFFTSetPlanCacheMaxSize(int64_t /*device_index*/, int64_t /*max_size*/) const {
+virtual void cuFFTSetPlanCacheMaxSize(DeviceIndex /*device_index*/, int64_t /*max_size*/) const {
TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
}

-virtual int64_t cuFFTGetPlanCacheSize(int64_t /*device_index*/) const {
+virtual int64_t cuFFTGetPlanCacheSize(DeviceIndex /*device_index*/) const {
TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
}

-virtual void cuFFTClearPlanCache(int64_t /*device_index*/) const {
+virtual void cuFFTClearPlanCache(DeviceIndex /*device_index*/) const {
TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
}

virtual int getNumGPUs() const {
return 0;
}

-virtual void deviceSynchronize(int64_t /*device_index*/) const {
+virtual void deviceSynchronize(DeviceIndex /*device_index*/) const {
TORCH_CHECK(false, "Cannot synchronize CUDA device without ATen_cuda library. ", CUDA_HELP);
}
};
8 changes: 4 additions & 4 deletions aten/src/ATen/native/SpectralOps.cpp
@@ -793,19 +793,19 @@ Tensor fft_ifftshift(const Tensor& x, at::OptionalIntArrayRef dim_opt) {

// We call the following methods via CUDA hooks because they are really only
// valid when CUDA is available. See native/cuda/CuFFTPlanCache.h for more details.
-int64_t _cufft_get_plan_cache_max_size(int64_t device_index) {
+int64_t _cufft_get_plan_cache_max_size(DeviceIndex device_index) {
return detail::getCUDAHooks().cuFFTGetPlanCacheMaxSize(device_index);
}

-void _cufft_set_plan_cache_max_size(int64_t device_index, int64_t max_size) {
+void _cufft_set_plan_cache_max_size(DeviceIndex device_index, int64_t max_size) {
detail::getCUDAHooks().cuFFTSetPlanCacheMaxSize(device_index, max_size);
}

-int64_t _cufft_get_plan_cache_size(int64_t device_index) {
+int64_t _cufft_get_plan_cache_size(DeviceIndex device_index) {
return detail::getCUDAHooks().cuFFTGetPlanCacheSize(device_index);
}

-void _cufft_clear_plan_cache(int64_t device_index) {
+void _cufft_clear_plan_cache(DeviceIndex device_index) {
detail::getCUDAHooks().cuFFTClearPlanCache(device_index);
}

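As the comment above says, these functions only forward through the CUDA hooks so that they stay valid in CPU-only builds. A simplified, hypothetical sketch of that indirection (names shortened; see CUDAHooksInterface.h above for the real interface):

#include <cstdint>
#include <stdexcept>

using DeviceIndex = int8_t;  // assumption: c10::DeviceIndex

struct HooksInterface {  // stands in for at::CUDAHooksInterface
  virtual ~HooksInterface() = default;
  virtual int64_t cuFFTGetPlanCacheSize(DeviceIndex) const {
    throw std::runtime_error("Cannot access cuFFT plan cache without ATen_cuda");
  }
};

struct CudaHooks final : HooksInterface {  // registered only in CUDA builds
  int64_t cuFFTGetPlanCacheSize(DeviceIndex) const override {
    return 0;  // would consult the per-device plan cache here
  }
};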
8 changes: 4 additions & 4 deletions aten/src/ATen/native/cuda/CuFFTPlanCache.h
@@ -524,9 +524,9 @@ class CuFFTParamsLRUCache {
// native function counterparts (at native/SpectralOps.cpp), i.e.,
// _cufft_get_plan_cache_max_size, _cufft_set_plan_cache_max_size
// _cufft_get_plan_cache_size, and _cufft_clear_plan_cache.
-int64_t cufft_get_plan_cache_max_size_impl(int64_t device_index);
-void cufft_set_plan_cache_max_size_impl(int64_t device_index, int64_t max_size);
-int64_t cufft_get_plan_cache_size_impl(int64_t device_index);
-void cufft_clear_plan_cache_impl(int64_t device_index);
+int64_t cufft_get_plan_cache_max_size_impl(DeviceIndex device_index);
+void cufft_set_plan_cache_max_size_impl(DeviceIndex device_index, int64_t max_size);
+int64_t cufft_get_plan_cache_size_impl(DeviceIndex device_index);
+void cufft_clear_plan_cache_impl(DeviceIndex device_index);

}}} // namespace at::native::detail
10 changes: 5 additions & 5 deletions aten/src/ATen/native/cuda/SpectralOps.cpp
@@ -133,7 +133,7 @@ static std::vector<std::unique_ptr<CuFFTParamsLRUCache>> plan_caches;
static std::mutex plan_caches_mutex;

static inline
-CuFFTParamsLRUCache &cufft_get_plan_cache(int64_t device_index) {
+CuFFTParamsLRUCache &cufft_get_plan_cache(DeviceIndex device_index) {
std::lock_guard<std::mutex> guard(plan_caches_mutex);

AT_ASSERT(device_index >= 0);
@@ -152,31 +152,31 @@ CuFFTParamsLRUCache &cufft_get_plan_cache(int64_t device_index) {

namespace detail {

-int64_t cufft_get_plan_cache_max_size_impl(int64_t device_index) {
+int64_t cufft_get_plan_cache_max_size_impl(DeviceIndex device_index) {
TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
"cufft_get_plan_cache_max_size: expected 0 <= device_index < ",
at::detail::getCUDAHooks().getNumGPUs(), "], but got device_index=",
device_index);
return cufft_get_plan_cache(device_index).max_size();
}

-void cufft_set_plan_cache_max_size_impl(int64_t device_index, int64_t max_size) {
+void cufft_set_plan_cache_max_size_impl(DeviceIndex device_index, int64_t max_size) {
TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
"cufft_set_plan_cache_max_size: expected 0 <= device_index < ",
at::detail::getCUDAHooks().getNumGPUs(), "], but got device_index=",
device_index);
return cufft_get_plan_cache(device_index).resize(max_size);
}

-int64_t cufft_get_plan_cache_size_impl(int64_t device_index) {
+int64_t cufft_get_plan_cache_size_impl(DeviceIndex device_index) {
TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
"cufft_get_plan_cache_size: expected 0 <= device_index < ",
at::detail::getCUDAHooks().getNumGPUs(), "], but got device_index=",
device_index);
return cufft_get_plan_cache(device_index).size();
}

-void cufft_clear_plan_cache_impl(int64_t device_index) {
+void cufft_clear_plan_cache_impl(DeviceIndex device_index) {
TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
"cufft_clear_plan_cache: expected 0 <= device_index < ",
at::detail::getCUDAHooks().getNumGPUs(), "], but got device_index=",
8 changes: 4 additions & 4 deletions aten/src/ATen/native/native_functions.yaml
@@ -2945,13 +2945,13 @@
CPU: _validate_compressed_sparse_indices_cpu
CUDA: _validate_compressed_sparse_indices_cuda

-- func: _cufft_get_plan_cache_size(int device_index) -> int
+- func: _cufft_get_plan_cache_size(DeviceIndex device_index) -> int

-- func: _cufft_get_plan_cache_max_size(int device_index) -> int
+- func: _cufft_get_plan_cache_max_size(DeviceIndex device_index) -> int

-- func: _cufft_set_plan_cache_max_size(int device_index, int max_size) -> ()
+- func: _cufft_set_plan_cache_max_size(DeviceIndex device_index, int max_size) -> ()

-- func: _cufft_clear_plan_cache(int device_index) -> ()
+- func: _cufft_clear_plan_cache(DeviceIndex device_index) -> ()

- func: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
device_check: NoCheck # TensorIterator
6 changes: 3 additions & 3 deletions c10/core/Device.h
@@ -174,13 +174,13 @@ struct C10_API Device final {
// This is safe to do, because backends that use the DeviceIndex
// have a later check when we actually try to switch to that device.
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-index_ == -1 || index_ >= 0,
+index_ >= -1,
"Device index must be -1 or non-negative, got ",
-(int)index_);
+static_cast<int>(index_));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
!is_cpu() || index_ <= 0,
"CPU device index must be -1 or zero, got ",
-(int)index_);
+static_cast<int>(index_));
}
};

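The casts in the assert messages above (now spelled static_cast) are there because streaming an int8_t-like DeviceIndex prints a character rather than a number. Minimal illustration, assuming DeviceIndex is an int8_t alias:

#include <cstdint>
#include <iostream>

int main() {
  int8_t index = 65;
  std::cout << index << "\n";                    // prints 'A'
  std::cout << static_cast<int>(index) << "\n";  // prints 65
}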
12 changes: 6 additions & 6 deletions c10/cuda/CUDAFunctions.cpp
@@ -148,9 +148,9 @@ void warn_or_error_on_sync() {
}
}

-c10::optional<int64_t> getDeviceIndexWithPrimaryContext() {
+c10::optional<DeviceIndex> getDeviceIndexWithPrimaryContext() {
// check current device first
-int64_t current_device_index = current_device();
+auto current_device_index = current_device();
if (current_device_index >= 0) {
if (hasPrimaryContext(current_device_index)) {
return current_device_index;
@@ -167,18 +167,18 @@ c10::optional<int64_t> getDeviceIndexWithPrimaryContext() {
}

namespace _internal {
-bool dummyHasPrimaryContext(C10_UNUSED int64_t device_index) {
+bool dummyHasPrimaryContext(C10_UNUSED DeviceIndex device_index) {
TORCH_CHECK(false, "Should never been called");
}
-bool (*hasPrimaryContext)(int64_t) = dummyHasPrimaryContext;
+bool (*hasPrimaryContext)(DeviceIndex) = dummyHasPrimaryContext;

// Private api to be called from CUDAHooks.cpp
-C10_CUDA_API void setHasPrimaryContext(bool (*func)(int64_t)) {
+C10_CUDA_API void setHasPrimaryContext(bool (*func)(DeviceIndex)) {
hasPrimaryContext = func ? func : dummyHasPrimaryContext;
}
} // namespace _internal

-bool hasPrimaryContext(int64_t device_index) {
+bool hasPrimaryContext(DeviceIndex device_index) {
return _internal::hasPrimaryContext(device_index);
}

4 changes: 2 additions & 2 deletions c10/cuda/CUDAFunctions.h
@@ -111,8 +111,8 @@ C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) {
C10_CUDA_CHECK(cudaStreamSynchronize(stream));
}

-C10_CUDA_API bool hasPrimaryContext(int64_t device_index);
-C10_CUDA_API c10::optional<int64_t> getDeviceIndexWithPrimaryContext();
+C10_CUDA_API bool hasPrimaryContext(DeviceIndex device_index);
+C10_CUDA_API c10::optional<DeviceIndex> getDeviceIndexWithPrimaryContext();

} // namespace cuda
} // namespace c10
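A hedged usage sketch of the updated helper (not from this PR): pick a device that already has a primary context, falling back to device 0 otherwise.

#include <c10/cuda/CUDAFunctions.h>

c10::DeviceIndex pick_device() {
  if (auto idx = c10::cuda::getDeviceIndexWithPrimaryContext()) {
    return *idx;  // a device whose primary context already exists
  }
  return 0;  // default device
}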
1 change: 1 addition & 0 deletions torch/csrc/jit/frontend/schema_type_parser.cpp
@@ -58,6 +58,7 @@ TypePtr SchemaTypeParser::parseBaseType() {
// use the custom class mechanism
// instead. @jerryzh
{"Device", c10::TypeFactory::get<DeviceObjType>()},
{"DeviceIndex", c10::TypeFactory::get<IntType>()},
{"Stream", c10::TypeFactory::get<StreamObjType>()},
{"Scalar", c10::TypeFactory::get<NumberType>()},
{"str", c10::TypeFactory::get<StringType>()},
1 change: 1 addition & 0 deletions torch/csrc/utils/python_arg_parser.cpp
@@ -43,6 +43,7 @@ static std::unordered_map<std::string, ParameterType> type_map = {
{"MemoryFormat", ParameterType::MEMORY_FORMAT},
{"QScheme", ParameterType::QSCHEME},
{"Device", ParameterType::DEVICE},
{"DeviceIndex", ParameterType::INT64},
{"Stream", ParameterType::STREAM},
{"std::string", ParameterType::STRING},
{"c10::string_view", ParameterType::STRING},
(3 more changed files not shown)
