diff --git a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h index ca28c9b0144ad2..81d05539f59962 100644 --- a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h +++ b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h @@ -82,6 +82,7 @@ namespace impl { c10::QScheme, c10::ScalarType, c10::Device, + c10::DeviceIndex, c10::Layout, c10::MemoryFormat, at::Dimname diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index 265a98a87e6277..e76f507dd25332 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -1869,6 +1870,13 @@ struct getTypePtr_ final { } }; +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return IntType::get(); + } +}; + template <> struct getMaybeFakeTypePtr_ final { static decltype(auto) call() { diff --git a/aten/src/ATen/core/op_registration/infer_schema.h b/aten/src/ATen/core/op_registration/infer_schema.h index e4c7e0e12ce0b4..a00ef76f460b92 100644 --- a/aten/src/ATen/core/op_registration/infer_schema.h +++ b/aten/src/ATen/core/op_registration/infer_schema.h @@ -38,8 +38,8 @@ constexpr int checkStaticTypes() { // Give nice error messages for some of the common error cases. // Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT static_assert(guts::conjunction< - bool_t::value || std::is_same::value || std::is_same::value>... - >::value, "INVALID TYPE: Only int64_t and bool are supported as an integral argument type"); + bool_t::value || std::is_same::value || std::is_same::value || std::is_same::value>... + >::value, "INVALID TYPE: Only int8_t, int64_t and bool are supported as an integral argument type"); static_assert(guts::conjunction< bool_t::value>... >::value, "INVALID TYPE: float is not supported as an argument type, use double instead"); diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp index d781f130951088..acb9b1931f045e 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp +++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp @@ -42,13 +42,13 @@ #include namespace c10::cuda::_internal { -void setHasPrimaryContext(bool (*func)(int64_t)); +void setHasPrimaryContext(bool (*func)(DeviceIndex)); } namespace at::cuda::detail { const at::cuda::NVRTC& nvrtc(); -int64_t current_device(); +DeviceIndex current_device(); static void (*magma_init_fn)() = nullptr; @@ -57,7 +57,7 @@ void set_magma_init_fn(void (*fn)()) { } namespace { -bool _hasPrimaryContext(int64_t device_index) { +bool _hasPrimaryContext(DeviceIndex device_index) { TORCH_CHECK(device_index >= 0 && device_index < at::cuda::device_count(), "hasPrimaryContext expects a valid device index, but got device_index=", device_index); unsigned int ctx_flags; @@ -226,7 +226,7 @@ const at::cuda::NVRTC& CUDAHooks::nvrtc() const { return at::cuda::detail::nvrtc(); } -int64_t current_device() { +DeviceIndex current_device() { int device; cudaError_t err = c10::cuda::GetDevice(&device); if (err == cudaSuccess) { @@ -235,11 +235,11 @@ int64_t current_device() { return -1; } -int64_t CUDAHooks::current_device() const { +DeviceIndex CUDAHooks::current_device() const { return at::cuda::detail::current_device(); } -bool CUDAHooks::hasPrimaryContext(int64_t device_index) const { +bool CUDAHooks::hasPrimaryContext(DeviceIndex device_index) const { return _hasPrimaryContext(device_index); } @@ -414,19 +414,19 @@ double CUDAHooks::batchnormMinEpsilonCuDNN() const { #endif } -int64_t CUDAHooks::cuFFTGetPlanCacheMaxSize(int64_t device_index) const { +int64_t CUDAHooks::cuFFTGetPlanCacheMaxSize(DeviceIndex device_index) const { return at::native::detail::cufft_get_plan_cache_max_size_impl(device_index); } -void CUDAHooks::cuFFTSetPlanCacheMaxSize(int64_t device_index, int64_t max_size) const { +void CUDAHooks::cuFFTSetPlanCacheMaxSize(DeviceIndex device_index, int64_t max_size) const { at::native::detail::cufft_set_plan_cache_max_size_impl(device_index, max_size); } -int64_t CUDAHooks::cuFFTGetPlanCacheSize(int64_t device_index) const { +int64_t CUDAHooks::cuFFTGetPlanCacheSize(DeviceIndex device_index) const { return at::native::detail::cufft_get_plan_cache_size_impl(device_index); } -void CUDAHooks::cuFFTClearPlanCache(int64_t device_index) const { +void CUDAHooks::cuFFTClearPlanCache(DeviceIndex device_index) const { at::native::detail::cufft_clear_plan_cache_impl(device_index); } @@ -434,7 +434,7 @@ int CUDAHooks::getNumGPUs() const { return at::cuda::device_count(); } -void CUDAHooks::deviceSynchronize(int64_t device_index) const { +void CUDAHooks::deviceSynchronize(DeviceIndex device_index) const { at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index)); c10::cuda::device_synchronize(); } diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.h b/aten/src/ATen/cuda/detail/CUDAHooks.h index c88454d7f0e59d..37b2cb48baec37 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.h +++ b/aten/src/ATen/cuda/detail/CUDAHooks.h @@ -29,8 +29,8 @@ struct CUDAHooks : public at::CUDAHooksInterface { bool hasCuSOLVER() const override; bool hasROCM() const override; const at::cuda::NVRTC& nvrtc() const override; - int64_t current_device() const override; - bool hasPrimaryContext(int64_t device_index) const override; + DeviceIndex current_device() const override; + bool hasPrimaryContext(DeviceIndex device_index) const override; Allocator* getCUDADeviceAllocator() const override; Allocator* getPinnedMemoryAllocator() const override; bool compiledWithCuDNN() const override; @@ -43,12 +43,12 @@ struct CUDAHooks : public at::CUDAHooksInterface { long versionCuDNN() const override; std::string showConfig() const override; double batchnormMinEpsilonCuDNN() const override; - int64_t cuFFTGetPlanCacheMaxSize(int64_t device_index) const override; - void cuFFTSetPlanCacheMaxSize(int64_t device_index, int64_t max_size) const override; - int64_t cuFFTGetPlanCacheSize(int64_t device_index) const override; - void cuFFTClearPlanCache(int64_t device_index) const override; + int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex device_index) const override; + void cuFFTSetPlanCacheMaxSize(DeviceIndex device_index, int64_t max_size) const override; + int64_t cuFFTGetPlanCacheSize(DeviceIndex device_index) const override; + void cuFFTClearPlanCache(DeviceIndex device_index) const override; int getNumGPUs() const override; - void deviceSynchronize(int64_t device_index) const override; + void deviceSynchronize(DeviceIndex device_index) const override; }; }}} // at::cuda::detail diff --git a/aten/src/ATen/detail/CUDAHooksInterface.h b/aten/src/ATen/detail/CUDAHooksInterface.h index d57ae1ba1619bb..ea746d332a3468 100644 --- a/aten/src/ATen/detail/CUDAHooksInterface.h +++ b/aten/src/ATen/detail/CUDAHooksInterface.h @@ -114,11 +114,11 @@ struct TORCH_API CUDAHooksInterface { TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP); } - virtual bool hasPrimaryContext(int64_t device_index) const { + virtual bool hasPrimaryContext(DeviceIndex device_index) const { TORCH_CHECK(false, "Cannot call hasPrimaryContext(", device_index, ") without ATen_cuda library. ", CUDA_HELP); } - virtual int64_t current_device() const { + virtual DeviceIndex current_device() const { return -1; } @@ -167,19 +167,19 @@ struct TORCH_API CUDAHooksInterface { "Cannot query batchnormMinEpsilonCuDNN() without ATen_cuda library. ", CUDA_HELP); } - virtual int64_t cuFFTGetPlanCacheMaxSize(int64_t /*device_index*/) const { + virtual int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex /*device_index*/) const { TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP); } - virtual void cuFFTSetPlanCacheMaxSize(int64_t /*device_index*/, int64_t /*max_size*/) const { + virtual void cuFFTSetPlanCacheMaxSize(DeviceIndex /*device_index*/, int64_t /*max_size*/) const { TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP); } - virtual int64_t cuFFTGetPlanCacheSize(int64_t /*device_index*/) const { + virtual int64_t cuFFTGetPlanCacheSize(DeviceIndex /*device_index*/) const { TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP); } - virtual void cuFFTClearPlanCache(int64_t /*device_index*/) const { + virtual void cuFFTClearPlanCache(DeviceIndex /*device_index*/) const { TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP); } @@ -187,7 +187,7 @@ struct TORCH_API CUDAHooksInterface { return 0; } - virtual void deviceSynchronize(int64_t /*device_index*/) const { + virtual void deviceSynchronize(DeviceIndex /*device_index*/) const { TORCH_CHECK(false, "Cannot synchronize CUDA device without ATen_cuda library. ", CUDA_HELP); } }; diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp index 3a6efbbfbad6e0..18108d422d2337 100644 --- a/aten/src/ATen/native/SpectralOps.cpp +++ b/aten/src/ATen/native/SpectralOps.cpp @@ -793,19 +793,19 @@ Tensor fft_ifftshift(const Tensor& x, at::OptionalIntArrayRef dim_opt) { // We call the following methods via CUDA hooks because they are really only // valid when CUDA is available. See native/cuda/CuFFTPlanCache.h for more details. -int64_t _cufft_get_plan_cache_max_size(int64_t device_index) { +int64_t _cufft_get_plan_cache_max_size(DeviceIndex device_index) { return detail::getCUDAHooks().cuFFTGetPlanCacheMaxSize(device_index); } -void _cufft_set_plan_cache_max_size(int64_t device_index, int64_t max_size) { +void _cufft_set_plan_cache_max_size(DeviceIndex device_index, int64_t max_size) { detail::getCUDAHooks().cuFFTSetPlanCacheMaxSize(device_index, max_size); } -int64_t _cufft_get_plan_cache_size(int64_t device_index) { +int64_t _cufft_get_plan_cache_size(DeviceIndex device_index) { return detail::getCUDAHooks().cuFFTGetPlanCacheSize(device_index); } -void _cufft_clear_plan_cache(int64_t device_index) { +void _cufft_clear_plan_cache(DeviceIndex device_index) { detail::getCUDAHooks().cuFFTClearPlanCache(device_index); } diff --git a/aten/src/ATen/native/cuda/CuFFTPlanCache.h b/aten/src/ATen/native/cuda/CuFFTPlanCache.h index 992399dca72e22..edeb8e8c82f80a 100644 --- a/aten/src/ATen/native/cuda/CuFFTPlanCache.h +++ b/aten/src/ATen/native/cuda/CuFFTPlanCache.h @@ -524,9 +524,9 @@ class CuFFTParamsLRUCache { // native function counterparts (at native/SpectralOps.cpp), i.e., // _cufft_get_plan_cache_max_size, _cufft_set_plan_cache_max_size // _cufft_get_plan_cache_size, and _cufft_clear_plan_cache. -int64_t cufft_get_plan_cache_max_size_impl(int64_t device_index); -void cufft_set_plan_cache_max_size_impl(int64_t device_index, int64_t max_size); -int64_t cufft_get_plan_cache_size_impl(int64_t device_index); -void cufft_clear_plan_cache_impl(int64_t device_index); +int64_t cufft_get_plan_cache_max_size_impl(DeviceIndex device_index); +void cufft_set_plan_cache_max_size_impl(DeviceIndex device_index, int64_t max_size); +int64_t cufft_get_plan_cache_size_impl(DeviceIndex device_index); +void cufft_clear_plan_cache_impl(DeviceIndex device_index); }}} // namespace at::native::detail diff --git a/aten/src/ATen/native/cuda/SpectralOps.cpp b/aten/src/ATen/native/cuda/SpectralOps.cpp index d9078343ac53b1..1b0c6c2991a41a 100644 --- a/aten/src/ATen/native/cuda/SpectralOps.cpp +++ b/aten/src/ATen/native/cuda/SpectralOps.cpp @@ -133,7 +133,7 @@ static std::vector> plan_caches; static std::mutex plan_caches_mutex; static inline -CuFFTParamsLRUCache &cufft_get_plan_cache(int64_t device_index) { +CuFFTParamsLRUCache &cufft_get_plan_cache(DeviceIndex device_index) { std::lock_guard guard(plan_caches_mutex); AT_ASSERT(device_index >= 0); @@ -152,7 +152,7 @@ CuFFTParamsLRUCache &cufft_get_plan_cache(int64_t device_index) { namespace detail { -int64_t cufft_get_plan_cache_max_size_impl(int64_t device_index) { +int64_t cufft_get_plan_cache_max_size_impl(DeviceIndex device_index) { TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(), "cufft_get_plan_cache_max_size: expected 0 <= device_index < ", at::detail::getCUDAHooks().getNumGPUs(), "], but got device_index=", @@ -160,7 +160,7 @@ int64_t cufft_get_plan_cache_max_size_impl(int64_t device_index) { return cufft_get_plan_cache(device_index).max_size(); } -void cufft_set_plan_cache_max_size_impl(int64_t device_index, int64_t max_size) { +void cufft_set_plan_cache_max_size_impl(DeviceIndex device_index, int64_t max_size) { TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(), "cufft_set_plan_cache_max_size: expected 0 <= device_index < ", at::detail::getCUDAHooks().getNumGPUs(), "], but got device_index=", @@ -168,7 +168,7 @@ void cufft_set_plan_cache_max_size_impl(int64_t device_index, int64_t max_size) return cufft_get_plan_cache(device_index).resize(max_size); } -int64_t cufft_get_plan_cache_size_impl(int64_t device_index) { +int64_t cufft_get_plan_cache_size_impl(DeviceIndex device_index) { TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(), "cufft_get_plan_cache_size: expected 0 <= device_index < ", at::detail::getCUDAHooks().getNumGPUs(), "], but got device_index=", @@ -176,7 +176,7 @@ int64_t cufft_get_plan_cache_size_impl(int64_t device_index) { return cufft_get_plan_cache(device_index).size(); } -void cufft_clear_plan_cache_impl(int64_t device_index) { +void cufft_clear_plan_cache_impl(DeviceIndex device_index) { TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(), "cufft_clear_plan_cache: expected 0 <= device_index < ", at::detail::getCUDAHooks().getNumGPUs(), "], but got device_index=", diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 8b7e99590b3599..700224093aa362 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -2945,13 +2945,13 @@ CPU: _validate_compressed_sparse_indices_cpu CUDA: _validate_compressed_sparse_indices_cuda -- func: _cufft_get_plan_cache_size(int device_index) -> int +- func: _cufft_get_plan_cache_size(DeviceIndex device_index) -> int -- func: _cufft_get_plan_cache_max_size(int device_index) -> int +- func: _cufft_get_plan_cache_max_size(DeviceIndex device_index) -> int -- func: _cufft_set_plan_cache_max_size(int device_index, int max_size) -> () +- func: _cufft_set_plan_cache_max_size(DeviceIndex device_index, int max_size) -> () -- func: _cufft_clear_plan_cache(int device_index) -> () +- func: _cufft_clear_plan_cache(DeviceIndex device_index) -> () - func: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor device_check: NoCheck # TensorIterator diff --git a/c10/core/Device.h b/c10/core/Device.h index 2a44a9c0ccc9cd..1f346e2f6750c3 100644 --- a/c10/core/Device.h +++ b/c10/core/Device.h @@ -174,13 +174,13 @@ struct C10_API Device final { // This is safe to do, because backends that use the DeviceIndex // have a later check when we actually try to switch to that device. TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - index_ == -1 || index_ >= 0, + index_ >= -1, "Device index must be -1 or non-negative, got ", - (int)index_); + static_cast(index_)); TORCH_INTERNAL_ASSERT_DEBUG_ONLY( !is_cpu() || index_ <= 0, "CPU device index must be -1 or zero, got ", - (int)index_); + static_cast(index_)); } }; diff --git a/c10/cuda/CUDAFunctions.cpp b/c10/cuda/CUDAFunctions.cpp index 26afdbb0d72dda..1c4b7756590d7a 100644 --- a/c10/cuda/CUDAFunctions.cpp +++ b/c10/cuda/CUDAFunctions.cpp @@ -148,9 +148,9 @@ void warn_or_error_on_sync() { } } -c10::optional getDeviceIndexWithPrimaryContext() { +c10::optional getDeviceIndexWithPrimaryContext() { // check current device first - int64_t current_device_index = current_device(); + auto current_device_index = current_device(); if (current_device_index >= 0) { if (hasPrimaryContext(current_device_index)) { return current_device_index; @@ -167,18 +167,18 @@ c10::optional getDeviceIndexWithPrimaryContext() { } namespace _internal { -bool dummyHasPrimaryContext(C10_UNUSED int64_t device_index) { +bool dummyHasPrimaryContext(C10_UNUSED DeviceIndex device_index) { TORCH_CHECK(false, "Should never been called"); } -bool (*hasPrimaryContext)(int64_t) = dummyHasPrimaryContext; +bool (*hasPrimaryContext)(DeviceIndex) = dummyHasPrimaryContext; // Private api to be called from CUDAHooks.cpp -C10_CUDA_API void setHasPrimaryContext(bool (*func)(int64_t)) { +C10_CUDA_API void setHasPrimaryContext(bool (*func)(DeviceIndex)) { hasPrimaryContext = func ? func : dummyHasPrimaryContext; } } // namespace _internal -bool hasPrimaryContext(int64_t device_index) { +bool hasPrimaryContext(DeviceIndex device_index) { return _internal::hasPrimaryContext(device_index); } diff --git a/c10/cuda/CUDAFunctions.h b/c10/cuda/CUDAFunctions.h index 31e3beb06bd476..2b9dc8eb5dfa2a 100644 --- a/c10/cuda/CUDAFunctions.h +++ b/c10/cuda/CUDAFunctions.h @@ -111,8 +111,8 @@ C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) { C10_CUDA_CHECK(cudaStreamSynchronize(stream)); } -C10_CUDA_API bool hasPrimaryContext(int64_t device_index); -C10_CUDA_API c10::optional getDeviceIndexWithPrimaryContext(); +C10_CUDA_API bool hasPrimaryContext(DeviceIndex device_index); +C10_CUDA_API c10::optional getDeviceIndexWithPrimaryContext(); } // namespace cuda } // namespace c10 diff --git a/torch/csrc/jit/frontend/schema_type_parser.cpp b/torch/csrc/jit/frontend/schema_type_parser.cpp index f702286a38994c..0fa26eddd8c179 100644 --- a/torch/csrc/jit/frontend/schema_type_parser.cpp +++ b/torch/csrc/jit/frontend/schema_type_parser.cpp @@ -58,6 +58,7 @@ TypePtr SchemaTypeParser::parseBaseType() { // use the custom class mechanism // instead. @jerryzh {"Device", c10::TypeFactory::get()}, + {"DeviceIndex", c10::TypeFactory::get()}, {"Stream", c10::TypeFactory::get()}, {"Scalar", c10::TypeFactory::get()}, {"str", c10::TypeFactory::get()}, diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 9dffd854fa9360..1c09d91d4f9d2e 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -43,6 +43,7 @@ static std::unordered_map type_map = { {"MemoryFormat", ParameterType::MEMORY_FORMAT}, {"QScheme", ParameterType::QSCHEME}, {"Device", ParameterType::DEVICE}, + {"DeviceIndex", ParameterType::INT64}, {"Stream", ParameterType::STREAM}, {"std::string", ParameterType::STRING}, {"c10::string_view", ParameterType::STRING}, diff --git a/torchgen/api/python.py b/torchgen/api/python.py index 96aa43be1060b5..ce7d3a2ea3ec3f 100644 --- a/torchgen/api/python.py +++ b/torchgen/api/python.py @@ -661,6 +661,7 @@ def argument_type_str( BaseTy.Storage, BaseTy.Layout, BaseTy.Device, + BaseTy.DeviceIndex, BaseTy.MemoryFormat, BaseTy.Dimname, BaseTy.Stream, @@ -907,7 +908,7 @@ def argument_type_str_pyi(t: Type) -> str: add_optional = True if isinstance(t, BaseType): - if t.name == BaseTy.int: + if t.name in [BaseTy.int, BaseTy.DeviceIndex]: ret = "_int" if t.name == BaseTy.SymInt: ret = "Union[_int, SymInt]" @@ -1255,6 +1256,8 @@ def arg_parser_unpack_method( return "scalartypeWithDefault" if has_default_init else "scalartype" elif t.name == BaseTy.Device: return "deviceWithDefault" if has_default_init else "device" + elif t.name == BaseTy.DeviceIndex: + return "toInt64" elif t.name == BaseTy.int: return "toInt64" elif t.name == BaseTy.SymInt: diff --git a/torchgen/api/types/types.py b/torchgen/api/types/types.py index 29f100e8c2e553..693623f973c4bf 100644 --- a/torchgen/api/types/types.py +++ b/torchgen/api/types/types.py @@ -62,6 +62,7 @@ dimVectorT = BaseCppType("at", "DimVector") layoutT = BaseCppType("at", "Layout") deviceT = BaseCppType("at", "Device") +deviceIndexT = BaseCppType("at", "DeviceIndex") scalarT = BaseCppType("at", "Scalar") optionalScalarRefT = BaseCppType("at", "OptionalScalarRef") memoryFormatT = BaseCppType("at", "MemoryFormat") @@ -111,6 +112,7 @@ BaseTy.DimVector: dimVectorT, BaseTy.Layout: layoutT, BaseTy.Device: deviceT, + BaseTy.DeviceIndex: deviceIndexT, BaseTy.Scalar: scalarT, BaseTy.MemoryFormat: memoryFormatT, BaseTy.QScheme: qschemeT, diff --git a/torchgen/model.py b/torchgen/model.py index 5af8d87448c1eb..9f47ad4051a01b 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -1803,6 +1803,7 @@ class BaseTy(Enum): bool = auto() Layout = auto() Device = auto() + DeviceIndex = auto() Scalar = auto() MemoryFormat = auto() QScheme = auto()